Skip to content

Commit 4127313

Browse files
Merge pull request #521 from RocketPy-Team/bug/function-2d-discretize
BUG: Invalid Arguments on Two Dimensional Discretize (HOTFIX).
2 parents fa3d9a7 + 223d598 commit 4127313

File tree

3 files changed

+59
-13
lines changed

3 files changed

+59
-13
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ straightforward as possible.
4040

4141
### Fixed
4242

43-
-
43+
- BUG: Invalid Arguments on Two Dimensional Discretize. [#521](https://github.com/RocketPy-Team/RocketPy/pull/521)
4444

4545
## [v1.1.4] - 2023-12-07
4646

rocketpy/mathutils/function.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def set_source(self, source):
189189
self : Function
190190
Returns the Function instance.
191191
"""
192-
_ = self._check_user_input(
192+
*_, interpolation, extrapolation = self._check_user_input(
193193
source,
194194
self.__inputs__,
195195
self.__outputs__,
@@ -277,10 +277,10 @@ def source_function(_):
277277
self.source = source
278278
# Update extrapolation method
279279
if self.__extrapolation__ is None:
280-
self.set_extrapolation()
280+
self.set_extrapolation(extrapolation)
281281
# Set default interpolation for point source if it hasn't
282282
if self.__interpolation__ is None:
283-
self.set_interpolation()
283+
self.set_interpolation(interpolation)
284284
else:
285285
# Updates interpolation coefficients
286286
self.set_interpolation(self.__interpolation__)
@@ -560,14 +560,12 @@ def set_discrete(
560560
# Create nodes to evaluate function
561561
xs = np.linspace(lower[0], upper[0], sam[0])
562562
ys = np.linspace(lower[1], upper[1], sam[1])
563-
xs, ys = np.meshgrid(xs, ys)
564-
xs, ys = xs.flatten(), ys.flatten()
565-
mesh = [[xs[i], ys[i]] for i in range(len(xs))]
563+
xs, ys = np.array(np.meshgrid(xs, ys)).reshape(2, xs.size * ys.size)
566564
# Evaluate function at all mesh nodes and convert it to matrix
567-
zs = np.array(self.get_value(mesh))
568-
self.set_source(np.concatenate(([xs], [ys], [zs])).transpose())
565+
zs = np.array(self.get_value(xs, ys))
569566
self.__interpolation__ = "shepard"
570567
self.__extrapolation__ = "natural"
568+
self.set_source(np.concatenate(([xs], [ys], [zs])).transpose())
571569
return self
572570

573571
def set_discrete_based_on_model(
@@ -664,11 +662,8 @@ def set_discrete_based_on_model(
664662
# Create nodes to evaluate function
665663
xs = model_function.source[:, 0]
666664
ys = model_function.source[:, 1]
667-
xs, ys = np.meshgrid(xs, ys)
668-
xs, ys = xs.flatten(), ys.flatten()
669-
mesh = [[xs[i], ys[i]] for i in range(len(xs))]
670665
# Evaluate function at all mesh nodes and convert it to matrix
671-
zs = np.array(self.get_value(mesh))
666+
zs = np.array(self.get_value(xs, ys))
672667
self.set_source(np.concatenate(([xs], [ys], [zs])).transpose())
673668

674669
interp = (
@@ -2860,6 +2855,8 @@ def _check_user_input(
28602855

28612856
# check source for data type
28622857
# if list or ndarray, check for dimensions, interpolation and extrapolation
2858+
if isinstance(source, Function):
2859+
source = source.get_source()
28632860
if isinstance(source, (list, np.ndarray, str, Path)):
28642861
# Deal with csv or txt
28652862
if isinstance(source, (str, Path)):

tests/test_function.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,55 @@ def test_multivariable_function_plot(mock_show):
372372
assert func.plot() == None
373373

374374

375+
def test_set_discrete_2d():
376+
"""Tests the set_discrete method of the Function for
377+
two dimensional domains.
378+
"""
379+
func = Function(lambda x, y: x**2 + y**2)
380+
discretized_func = func.set_discrete([-5, -7], [8, 10], [50, 100])
381+
382+
assert isinstance(discretized_func, Function)
383+
assert isinstance(func, Function)
384+
assert discretized_func.source.shape == (50 * 100, 3)
385+
assert np.isclose(discretized_func.source[0, 0], -5)
386+
assert np.isclose(discretized_func.source[0, 1], -7)
387+
assert np.isclose(discretized_func.source[-1, 0], 8)
388+
assert np.isclose(discretized_func.source[-1, 1], 10)
389+
390+
391+
def test_set_discrete_2d_simplified():
392+
"""Tests the set_discrete method of the Function for
393+
two dimensional domains with simplified inputs.
394+
"""
395+
source = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
396+
func = Function(source=source, inputs=["x", "y"], outputs=["z"])
397+
discretized_func = func.set_discrete(-1, 1, 10)
398+
399+
assert isinstance(discretized_func, Function)
400+
assert isinstance(func, Function)
401+
assert discretized_func.source.shape == (100, 3)
402+
assert np.isclose(discretized_func.source[0, 0], -1)
403+
assert np.isclose(discretized_func.source[0, 1], -1)
404+
assert np.isclose(discretized_func.source[-1, 0], 1)
405+
assert np.isclose(discretized_func.source[-1, 1], 1)
406+
407+
408+
def test_set_discrete_based_on_2d_model(func_2d_from_csv):
409+
"""Tests the set_discrete_based_on_model method with a 2d model
410+
Function.
411+
"""
412+
func = Function(lambda x, y: x**2 + y**2)
413+
discretized_func = func.set_discrete_based_on_model(func_2d_from_csv)
414+
415+
assert isinstance(discretized_func, Function)
416+
assert isinstance(func, Function)
417+
assert np.array_equal(
418+
discretized_func.source[:, :2], func_2d_from_csv.source[:, :2]
419+
)
420+
assert discretized_func.__interpolation__ == func_2d_from_csv.__interpolation__
421+
assert discretized_func.__extrapolation__ == func_2d_from_csv.__extrapolation__
422+
423+
375424
@pytest.mark.parametrize(
376425
"x,y,z_expected",
377426
[

0 commit comments

Comments
 (0)