Skip to content

Commit adc32b6

Browse files
authored
Merge pull request #54 from dingraha/override_method
Use the new `override_method` for optional methods in `omjlcomps`
2 parents 45dcfff + b5a73d0 commit adc32b6

3 files changed

Lines changed: 60 additions & 63 deletions

File tree

python/omjlcomps/__init__.py

Lines changed: 57 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ class JuliaExplicitComp(om.ExplicitComponent):
9494
jlcomp : subtype of `OpenMDAOCore.AbstractExplicitComp`
9595
A Julia struct that subtypes `OpenMDAOCore.AbstractExplicitComp`.
9696
Used by `JuliaExplicitComp` to call Julia functions that mimic methods required by an OpenMDAO `ExplicitComponent` (e.g., `OpenMDAOCore.setup`, `OpenMDAOCore.compute!`, `OpenMDAOCore.compute_partials!`, etc.).
97+
noisy_julia_domain_error: bool
98+
If `True`, a `DomainError` thrown in the Julia code will be printed before being re-raised as an OpenMDAO `AnalysisError`.
99+
Otherwise only the `AnalysisError` will be thrown (and perhaps silenced by OpenMDAO, depending on its configuration).
97100
"""
98101

99102
def initialize(self):
@@ -117,9 +120,10 @@ def compute_partials(self, inputs, partials):
117120
jl.OpenMDAOCore.compute_partials_b(self._jlcomp, inputs_dict, partials_dict)
118121
except JuliaError as e:
119122
if jl.isa(e.exception, jl.DomainError):
123+
msg = f"caught Julia DomainError in {self}.compute_partials:\n{e}"
120124
if self.options['noisy_julia_domain_error']:
121-
print(f"caught Julia DomainError in {self}.compute_partials:\n{e}")
122-
raise AnalysisError(f"caught Julia DomainError in {self}.compute_partials:\n{e}")
125+
print(msg)
126+
raise AnalysisError(msg)
123127
else:
124128
raise e from None
125129

@@ -131,10 +135,7 @@ def compute_partials(self, inputs, partials):
131135
wrt_rel = wrt_abs.split(".")[-1]
132136
partials[of_obs, wrt_abs] = _only(partials_dict[of_rel, wrt_rel])
133137

134-
# https://www.ianlewis.org/en/dynamically-adding-method-classes-or-class-instanc
135-
self.compute_partials = MethodType(compute_partials, self)
136-
# Hmm...
137-
self._has_compute_partials = True
138+
self.override_method("compute_partials", compute_partials)
138139

139140
if jl.OpenMDAOCore.has_compute_jacvec_product(self._jlcomp):
140141
def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode):
@@ -146,9 +147,10 @@ def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode):
146147
jl.OpenMDAOCore.compute_jacvec_product_b(self._jlcomp, inputs_dict, d_inputs_dict, d_outputs_dict, mode)
147148
except JuliaError as e:
148149
if jl.isa(e.exception, jl.DomainError):
150+
msg = f"caught Julia DomainError in {self}.compute_jacvec_product:\n{e}"
149151
if self.options['noisy_julia_domain_error']:
150-
print(f"caught Julia DomainError in {self}.compute_jacvec_product:\n{e}")
151-
raise AnalysisError(f"caught Julia DomainError in {self}.compute_jacvec_product:\n{e}")
152+
print(msg)
153+
raise AnalysisError(msg)
152154
else:
153155
raise e from None
154156

@@ -165,10 +167,7 @@ def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode):
165167
else:
166168
raise ValueError(f"unknown mode = {mode} in {self}.compute_jacvec_product")
167169

168-
169-
self.compute_jacvec_product = MethodType(compute_jacvec_product, self)
170-
# https://github.com/OpenMDAO/OpenMDAO/pull/2802
171-
self.matrix_free = True
170+
self.override_method("compute_jacvec_product", compute_jacvec_product)
172171

173172
def setup_partials(self):
174173
_setup_partials_common(self)
@@ -182,8 +181,9 @@ def compute(self, inputs, outputs):
182181
jl.OpenMDAOCore.compute_b(self._jlcomp, inputs_dict, outputs_dict)
183182
except JuliaError as e:
184183
if jl.isa(e.exception, jl.DomainError):
184+
msg = f"caught Julia DomainError in {self}.compute:\n{e}"
185185
if self.options['noisy_julia_domain_error']:
186-
print(f"caught Julia DomainError in {self}.compute:\n{e}")
186+
print(msg)
187187
raise AnalysisError(f"caught Julia DomainError in {self}.compute:\n{e}")
188188
else:
189189
raise e from None
@@ -203,6 +203,9 @@ class JuliaImplicitComp(om.ImplicitComponent):
203203
jlcomp : subtype of `OpenMDAOCore.AbstractImplicitComp`
204204
A Julia struct that subtypes `OpenMDAOCore.AbstractImplicitComp`.
205205
Used by `JuliaImplicitComp` to call Julia functions that mimic methods required by an OpenMDAO `ImplicitComponent` (e.g., `OpenMDAOCore.setup`, `OpenMDAOCore.apply_nonlinear!`, `OpenMDAOCore.linearize!`, etc.).
206+
noisy_julia_domain_error: bool
207+
If `True`, a `DomainError` thrown in the Julia code will be printed before being re-raised as an OpenMDAO `AnalysisError`.
208+
Otherwise only the `AnalysisError` will be thrown (and perhaps silenced by OpenMDAO, depending on its configuration).
206209
"""
207210

208211
def initialize(self):
@@ -211,29 +214,6 @@ def initialize(self):
211214
def setup(self):
212215
_setup_common(self)
213216

214-
if jl.OpenMDAOCore.has_apply_nonlinear(self._jlcomp):
215-
def apply_nonlinear(self, inputs, outputs, residuals):
216-
inputs_dict = juliacall.convert(jl.Dict, {k: np.atleast_1d(v) for k, v in inputs.items()})
217-
outputs_dict = juliacall.convert(jl.Dict, {k: np.atleast_1d(v) for k, v in outputs.items()})
218-
residuals_dict = juliacall.convert(jl.Dict, {k: np.atleast_1d(v) for k, v in residuals.items()})
219-
220-
try:
221-
jl.OpenMDAOCore.apply_nonlinear_b(self._jlcomp, inputs_dict, outputs_dict, residuals_dict)
222-
except JuliaError as e:
223-
if jl.isa(e.exception, jl.DomainError):
224-
if self.options['noisy_julia_domain_error']:
225-
print(f"caught Julia DomainError in {self}.apply_nonlinear:\n{e}")
226-
raise AnalysisError(f"caught Julia DomainError in {self}.apply_nonlinear:\n{e}")
227-
else:
228-
raise e from None
229-
230-
# Handle scalar entries in residuals, which aren't passed by reference when constructing residuals_dict.
231-
for k in list(residuals.keys()):
232-
if not isinstance(residuals[k], np.ndarray):
233-
residuals[k] = _only(residuals_dict[k])
234-
235-
self.apply_nonlinear = MethodType(apply_nonlinear, self)
236-
237217
if jl.OpenMDAOCore.has_solve_nonlinear(self._jlcomp):
238218
def solve_nonlinear(self, inputs, outputs):
239219
inputs_dict = juliacall.convert(jl.Dict, {k: np.atleast_1d(v) for k, v in inputs.items()})
@@ -243,9 +223,10 @@ def solve_nonlinear(self, inputs, outputs):
243223
jl.OpenMDAOCore.solve_nonlinear_b(self._jlcomp, inputs_dict, outputs_dict)
244224
except JuliaError as e:
245225
if jl.isa(e.exception, jl.DomainError):
226+
msg = f"caught Julia DomainError in {self}.solve_nonlinear:\n{e}"
246227
if self.options['noisy_julia_domain_error']:
247-
print(f"caught Julia DomainError in {self}.solve_nonlinear:\n{e}")
248-
raise AnalysisError(f"caught Julia DomainError in {self}.solve_nonlinear:\n{e}")
228+
print(msg)
229+
raise AnalysisError(msg)
249230
else:
250231
raise e from None
251232

@@ -254,9 +235,7 @@ def solve_nonlinear(self, inputs, outputs):
254235
if not isinstance(outputs[k], np.ndarray):
255236
outputs[k] = _only(outputs_dict[k])
256237

257-
self.solve_nonlinear = MethodType(solve_nonlinear, self)
258-
# https://github.com/OpenMDAO/OpenMDAO/pull/2802
259-
self._has_solve_nl = True
238+
self.override_method("solve_nonlinear", solve_nonlinear)
260239

261240
if jl.OpenMDAOCore.has_linearize(self._jlcomp):
262241
def linearize(self, inputs, outputs, partials):
@@ -274,9 +253,10 @@ def linearize(self, inputs, outputs, partials):
274253
jl.OpenMDAOCore.linearize_b(self._jlcomp, inputs_dict, outputs_dict, partials_dict)
275254
except JuliaError as e:
276255
if jl.isa(e.exception, jl.DomainError):
256+
msg = f"caught Julia DomainError in {self}.linearize:\n{e}"
277257
if self.options['noisy_julia_domain_error']:
278-
print(f"caught Julia DomainError in {self}.linearize:\n{e}")
279-
raise AnalysisError(f"caught Julia DomainError in {self}.linearize:\n{e}")
258+
print(msg)
259+
raise AnalysisError(msg)
280260
else:
281261
raise e from None
282262

@@ -288,8 +268,7 @@ def linearize(self, inputs, outputs, partials):
288268
wrt_rel = wrt_abs.split(".")[-1]
289269
partials[of_obs, wrt_abs] = _only(partials_dict[of_rel, wrt_rel])
290270

291-
self.linearize = MethodType(linearize, self)
292-
self._has_linearize = True
271+
self.override_method("linearize", linearize)
293272

294273
if jl.OpenMDAOCore.has_apply_linear(self._jlcomp):
295274
def apply_linear(self, inputs, outputs, d_inputs, d_outputs, d_residuals, mode):
@@ -304,9 +283,10 @@ def apply_linear(self, inputs, outputs, d_inputs, d_outputs, d_residuals, mode):
304283
d_inputs_dict, d_outputs_dict, d_residuals_dict, mode)
305284
except JuliaError as e:
306285
if jl.isa(e.exception, jl.DomainError):
286+
msg = f"caught Julia DomainError in {self}.apply_linear:\n{e}"
307287
if self.options['noisy_julia_domain_error']:
308-
print(f"caught Julia DomainError in {self}.apply_linear:\n{e}")
309-
raise AnalysisError(f"caught Julia DomainError in {self}.apply_linear:\n{e}")
288+
print(msg)
289+
raise AnalysisError(msg)
310290
else:
311291
raise e from None
312292

@@ -327,9 +307,7 @@ def apply_linear(self, inputs, outputs, d_inputs, d_outputs, d_residuals, mode):
327307
else:
328308
raise ValueError(f"unknown mode = {mode} in {self}.apply_linear")
329309

330-
self.apply_linear = MethodType(apply_linear, self)
331-
# https://github.com/OpenMDAO/OpenMDAO/pull/2802
332-
self.matrix_free = True
310+
self.override_method("apply_linear", apply_linear)
333311

334312
if jl.OpenMDAOCore.has_solve_linear(self._jlcomp):
335313
def solve_linear(self, d_outputs, d_residuals, mode):
@@ -340,9 +318,10 @@ def solve_linear(self, d_outputs, d_residuals, mode):
340318
jl.OpenMDAOCore.solve_linear_b(self._jlcomp, d_outputs_dict, d_residuals_dict, mode)
341319
except JuliaError as e:
342320
if jl.isa(e.exception, jl.DomainError):
321+
msg = f"caught Julia DomainError in {self}.solve_linear:\n{e}"
343322
if self.options['noisy_julia_domain_error']:
344-
print(f"caught Julia DomainError in {self}.solve_linear:\n{e}")
345-
raise AnalysisError(f"caught Julia DomainError in {self}.solve_linear:\n{e}")
323+
print(msg)
324+
raise AnalysisError(msg)
346325
else:
347326
raise e from None
348327

@@ -359,7 +338,7 @@ def solve_linear(self, d_outputs, d_residuals, mode):
359338
else:
360339
raise ValueError(f"unknown mode = {mode} in {self}.solve_linear")
361340

362-
self.solve_linear = MethodType(solve_linear, self)
341+
self.override_method("solve_linear", solve_linear)
363342

364343
def setup_partials(self):
365344
_setup_partials_common(self)
@@ -377,9 +356,10 @@ def guess_nonlinear(self, inputs, outputs, residuals):
377356
jl.OpenMDAOCore.guess_nonlinear_b(self._jlcomp, inputs_dict, outputs_dict, residuals_dict)
378357
except JuliaError as e:
379358
if jl.isa(e.exception, jl.DomainError):
359+
msg = f"caught Julia DomainError in {self}.guess_nonlinear:\n{e}"
380360
if self.options['noisy_julia_domain_error']:
381-
print(f"caught Julia DomainError in {self}.guess_nonlinear:\n{e}")
382-
raise AnalysisError(f"caught Julia DomainError in {self}.guess_nonlinear:\n{e}")
361+
print(msg)
362+
raise AnalysisError(msg)
383363
else:
384364
raise e from None
385365

@@ -388,9 +368,28 @@ def guess_nonlinear(self, inputs, outputs, residuals):
388368
if not isinstance(outputs[k], np.ndarray):
389369
outputs[k] = _only(outputs_dict[k])
390370

391-
self.guess_nonlinear = MethodType(guess_nonlinear, self)
392-
# Hmm...
393-
self._has_guess = True
371+
self.override_method("guess_nonlinear", guess_nonlinear)
372+
373+
def apply_nonlinear(self, inputs, outputs, residuals):
374+
inputs_dict = juliacall.convert(jl.Dict, {k: np.atleast_1d(v) for k, v in inputs.items()})
375+
outputs_dict = juliacall.convert(jl.Dict, {k: np.atleast_1d(v) for k, v in outputs.items()})
376+
residuals_dict = juliacall.convert(jl.Dict, {k: np.atleast_1d(v) for k, v in residuals.items()})
377+
378+
try:
379+
jl.OpenMDAOCore.apply_nonlinear_b(self._jlcomp, inputs_dict, outputs_dict, residuals_dict)
380+
except JuliaError as e:
381+
if jl.isa(e.exception, jl.DomainError):
382+
msg = f"caught Julia DomainError in {self}.apply_nonlinear:\n{e}"
383+
if self.options['noisy_julia_domain_error']:
384+
print(msg)
385+
raise AnalysisError(msg)
386+
else:
387+
raise e from None
388+
389+
# Handle scalar entries in residuals, which aren't passed by reference when constructing residuals_dict.
390+
for k in list(residuals.keys()):
391+
if not isinstance(residuals[k], np.ndarray):
392+
residuals[k] = _only(residuals_dict[k])
394393

395394

396395
def to_jlsymstrdict(d):

python/omjlcomps/test/test_explicit_ad_shape_by_conn.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import os
33
import unittest
44

5-
import aviary.api as av
6-
75
import numpy as np
86
from numpy.random import rand
97

python/pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@ description = "Create OpenMDAO Components using the Julia programming language"
44
readme = "README.md"
55
keywords = ["openmdao_component"]
66
license = {text = "MIT"}
7-
version = "0.2.6"
7+
version = "0.2.7"
88

99
dependencies = [
10-
"openmdao~=3.36",
10+
"openmdao~=3.42",
1111
"juliapkg~=0.1.10",
1212
"juliacall~=0.9.13",
1313
]
1414

1515
[project.optional-dependencies]
16-
test = ["om-aviary"]
16+
test = ["aviary @ git+https://github.com/OpenMDAO/Aviary#egg=main"]
1717

1818
[project.entry-points.openmdao_component]
1919
juliaexplicitcomp = "omjlcomps:JuliaExplicitComp"

0 commit comments

Comments
 (0)