Skip to content

Commit 7c73228

Browse files
committed
ENH: add vis args to classes .draw
1 parent 170e89c commit 7c73228

File tree

11 files changed

+88
-31
lines changed

11 files changed

+88
-31
lines changed

rocketpy/motors/hybrid_motor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,7 @@ def add_tank(self, tank, position):
624624
)
625625
reset_funcified_methods(self)
626626

627-
def draw(self, *, filename=None):
627+
def draw(self, vis_args=None, *, filename=None):
628628
"""Draws a representation of the HybridMotor.
629629
630630
Parameters
@@ -639,7 +639,7 @@ def draw(self, *, filename=None):
639639
-------
640640
None
641641
"""
642-
self.plots.draw(filename=filename)
642+
self.plots.draw(vis_args=vis_args, filename=filename)
643643

644644
def to_dict(self, **kwargs):
645645
data = super().to_dict(**kwargs)

rocketpy/motors/liquid_motor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def add_tank(self, tank, position):
480480
self.positioned_tanks.append({"tank": tank, "position": position})
481481
reset_funcified_methods(self)
482482

483-
def draw(self, *, filename=None):
483+
def draw(self, vis_args=None, *, filename=None):
484484
"""Draw a representation of the LiquidMotor.
485485
486486
Parameters
@@ -495,7 +495,7 @@ def draw(self, *, filename=None):
495495
-------
496496
None
497497
"""
498-
self.plots.draw(filename=filename)
498+
self.plots.draw(vis_args=vis_args, filename=filename)
499499

500500
def to_dict(self, **kwargs):
501501
data = super().to_dict(**kwargs)

rocketpy/motors/solid_motor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def propellant_I_13(self):
748748
def propellant_I_23(self):
749749
return 0
750750

751-
def draw(self, *, filename=None):
751+
def draw(self, vis_args=None, *, filename=None):
752752
"""Draw a representation of the SolidMotor.
753753
754754
Parameters
@@ -763,7 +763,7 @@ def draw(self, *, filename=None):
763763
-------
764764
None
765765
"""
766-
self.plots.draw(filename=filename)
766+
self.plots.draw(vis_args=vis_args, filename=filename)
767767

768768
def to_dict(self, **kwargs):
769769
data = super().to_dict(**kwargs)

rocketpy/motors/tank.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def underfill_height_exception(param_name, param):
591591
def _discretize_fluid_inputs(self):
592592
"""Uniformly discretizes the parameter of inputs of fluid data ."""
593593

594-
def draw(self, *, filename=None):
594+
def draw(self, vis_args=None, *, filename=None):
595595
"""Draws the tank geometry.
596596
597597
Parameters
@@ -606,7 +606,7 @@ def draw(self, *, filename=None):
606606
-------
607607
None
608608
"""
609-
self.plots.draw(filename=filename)
609+
self.plots.draw(vis_args=vis_args, filename=filename)
610610

611611
def info(self):
612612
"""Prints out a summary of the tank properties."""

rocketpy/plots/aero_surface_plots.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, aero_surface):
2525
self.aero_surface = aero_surface
2626

2727
@abstractmethod
28-
def draw(self, *, filename=None):
28+
def draw(self, vis_args=None, *, filename=None):
2929
pass
3030

3131
def lift(self):
@@ -54,7 +54,7 @@ class _NoseConePlots(_AeroSurfacePlots):
5454
"""Class that contains all nosecone plots. This class inherits from the
5555
_AeroSurfacePlots class."""
5656

57-
def draw(self, *, filename=None):
57+
def draw(self, vis_args=None, *, filename=None):
5858
"""Draw the nosecone shape along with some important information,
5959
including the center line and the center of pressure position.
6060
@@ -73,14 +73,17 @@ def draw(self, *, filename=None):
7373
# Create the vectors X and Y with the points of the curve
7474
nosecone_x, nosecone_y = self.aero_surface.shape_vec
7575

76+
if vis_args is None:
77+
vis_args = {}
78+
7679
# Figure creation and set up
7780
_, ax = plt.subplots()
7881
ax.set_xlim(-0.05, self.aero_surface.length * 1.02) # Horizontal size
7982
ax.set_ylim(
8083
-self.aero_surface.base_radius * 1.05, self.aero_surface.base_radius * 1.05
8184
) # Vertical size
8285
ax.set_aspect("equal") # Makes the graduation be the same on both axis
83-
ax.set_facecolor("#EEEEEE") # Background color
86+
ax.set_facecolor(vis_args.get("background", "#EEEEEE")) # Background color
8487
ax.grid(True, linestyle="--", linewidth=0.5)
8588

8689
cp_plot = (self.aero_surface.cpz, 0)
@@ -140,7 +143,7 @@ class _FinsPlots(_AeroSurfacePlots):
140143
_AeroSurfacePlots class."""
141144

142145
@abstractmethod
143-
def draw(self, *, filename=None):
146+
def draw(self, vis_args=None, *, filename=None):
144147
pass
145148

146149
def airfoil(self):
@@ -201,7 +204,7 @@ class _TrapezoidalFinsPlots(_FinsPlots):
201204
"""Class that contains all trapezoidal fin plots."""
202205

203206
# pylint: disable=too-many-statements
204-
def draw(self, *, filename=None):
207+
def draw(self, vis_args=None, *, filename=None):
205208
"""Draw the fin shape along with some important information, including
206209
the center line, the quarter line and the center of pressure position.
207210
@@ -291,10 +294,14 @@ def draw(self, *, filename=None):
291294
label="Mean Aerodynamic Chord",
292295
)
293296

297+
if vis_args is None:
298+
vis_args = {}
299+
294300
# Plotting
295301
fig = plt.figure(figsize=(7, 4))
296302
with plt.style.context("bmh"):
297303
ax = fig.add_subplot(111)
304+
fig.patch.set_facecolor(vis_args.get("background", "#EEEEEE"))
298305

299306
# Fin
300307
ax.add_line(leading_edge)
@@ -330,7 +337,7 @@ class _EllipticalFinsPlots(_FinsPlots):
330337
"""Class that contains all elliptical fin plots."""
331338

332339
# pylint: disable=too-many-statements
333-
def draw(self, *, filename=None):
340+
def draw(self, vis_args=None, *, filename=None):
334341
"""Draw the fin shape along with some important information.
335342
These being: the center line and the center of pressure position.
336343
@@ -383,10 +390,14 @@ def draw(self, *, filename=None):
383390
# Center of pressure
384391
cp_point = [self.aero_surface.cpz, self.aero_surface.Yma]
385392

393+
if vis_args is None:
394+
vis_args = {}
395+
386396
# Plotting
387397
fig = plt.figure(figsize=(7, 4))
388398
with plt.style.context("bmh"):
389399
ax = fig.add_subplot(111)
400+
fig.patch.set_facecolor(vis_args.get("background", "#EEEEEE"))
390401
ax.add_patch(ellipse)
391402
ax.add_line(yma_line)
392403
ax.add_line(center_line)
@@ -409,7 +420,7 @@ class _FreeFormFinsPlots(_FinsPlots):
409420
"""Class that contains all free form fin plots."""
410421

411422
# pylint: disable=too-many-statements
412-
def draw(self, *, filename=None):
423+
def draw(self, vis_args=None, *, filename=None):
413424
"""Draw the fin shape along with some important information.
414425
These being: the center line and the center of pressure position.
415426
@@ -443,10 +454,14 @@ def draw(self, *, filename=None):
443454
label="Mean Aerodynamic Chord",
444455
)
445456

457+
if vis_args is None:
458+
vis_args = {}
459+
446460
# Plotting
447461
fig = plt.figure(figsize=(7, 4))
448462
with plt.style.context("bmh"):
449463
ax = fig.add_subplot(111)
464+
fig.patch.set_facecolor(vis_args.get("background", "#EEEEEE"))
450465

451466
# Fin
452467
ax.scatter(
@@ -483,7 +498,7 @@ def draw(self, *, filename=None):
483498
class _TailPlots(_AeroSurfacePlots):
484499
"""Class that contains all tail plots."""
485500

486-
def draw(self, *, filename=None):
501+
def draw(self, vis_args=None, *, filename=None):
487502
# This will de done in the future
488503
pass
489504

@@ -498,7 +513,7 @@ def drag_coefficient_curve(self):
498513
else:
499514
return self.aero_surface.drag_coefficient.plot()
500515

501-
def draw(self, *, filename=None):
516+
def draw(self, vis_args=None, *, filename=None):
502517
raise NotImplementedError
503518

504519
def all(self):
@@ -514,12 +529,12 @@ def all(self):
514529
class _GenericSurfacePlots(_AeroSurfacePlots):
515530
"""Class that contains all generic surface plots."""
516531

517-
def draw(self, *, filename=None):
532+
def draw(self, vis_args=None, *, filename=None):
518533
pass
519534

520535

521536
class _LinearGenericSurfacePlots(_AeroSurfacePlots):
522537
"""Class that contains all linear generic surface plots."""
523538

524-
def draw(self, *, filename=None):
539+
def draw(self, vis_args=None, *, filename=None):
525540
pass

rocketpy/plots/hybrid_motor_plots.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def Kn(self, lower_limit=None, upper_limit=None, *, filename=None):
144144
lower=lower_limit, upper=upper_limit, filename=filename
145145
)
146146

147-
def draw(self, *, filename=None):
147+
def draw(self, vis_args=None, *, filename=None):
148148
"""Draw a representation of the HybridMotor.
149149
150150
Parameters
@@ -159,7 +159,12 @@ def draw(self, *, filename=None):
159159
-------
160160
None
161161
"""
162-
_, ax = plt.subplots(figsize=(8, 6), facecolor="#EEEEEE")
162+
if vis_args is None:
163+
vis_args = {}
164+
165+
_, ax = plt.subplots(
166+
figsize=(8, 6), facecolor=vis_args.get("background", "#EEEEEE")
167+
)
163168

164169
tanks_and_centers = self._generate_positioned_tanks(csys=self.motor._csys)
165170
nozzle = self._generate_nozzle(

rocketpy/plots/liquid_motor_plots.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class _LiquidMotorPlots(_MotorPlots):
1414
1515
"""
1616

17-
def draw(self, *, filename=None):
17+
def draw(self, vis_args=None, *, filename=None):
1818
"""Draw a representation of the LiquidMotor.
1919
2020
Parameters
@@ -29,7 +29,12 @@ def draw(self, *, filename=None):
2929
-------
3030
None
3131
"""
32-
_, ax = plt.subplots(figsize=(8, 6), facecolor="#EEEEEE")
32+
if vis_args is None:
33+
vis_args = {}
34+
35+
_, ax = plt.subplots(
36+
figsize=(8, 6), facecolor=vis_args.get("background", "#EEEEEE")
37+
)
3338

3439
tanks_and_centers = self._generate_positioned_tanks(csys=self.motor._csys)
3540
nozzle = self._generate_nozzle(

rocketpy/plots/solid_motor_plots.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,28 @@ def Kn(self, lower_limit=None, upper_limit=None):
113113

114114
self.motor.Kn.plot(lower=lower_limit, upper=upper_limit)
115115

116-
def draw(self, *, filename=None):
116+
def draw(self, vis_args=None, *, filename=None):
117117
"""Draw a representation of the SolidMotor.
118118
119119
Parameters
120120
----------
121+
vis_args : dict, optional
122+
Determines the visual aspects when drawing the solid motor. If
123+
``None``, default values are used. Default values are:
124+
125+
{
126+
"background": "#EEEEEE",
127+
"tail": "black",
128+
"nose": "black",
129+
"body": "black",
130+
"fins": "black",
131+
"motor": "black",
132+
"buttons": "black",
133+
"line_width": 2.0,
134+
}
135+
136+
A full list of color names can be found at:
137+
https://matplotlib.org/stable/gallery/color/named_colors
121138
filename : str | None, optional
122139
The path the plot should be saved to. By default None, in which case
123140
the plot will be shown instead of saved. Supported file endings are:
@@ -128,7 +145,19 @@ def draw(self, *, filename=None):
128145
-------
129146
None
130147
"""
131-
_, ax = plt.subplots(figsize=(8, 6), facecolor="#EEEEEE")
148+
if vis_args is None:
149+
vis_args = {
150+
"background": "#EEEEEE",
151+
"tail": "black",
152+
"nose": "black",
153+
"body": "black",
154+
"fins": "black",
155+
"motor": "black",
156+
"buttons": "black",
157+
"line_width": 2.0,
158+
}
159+
160+
_, ax = plt.subplots(figsize=(8, 6), facecolor=vis_args["background"])
132161

133162
nozzle = self._generate_nozzle(
134163
translate=(self.motor.nozzle_position, 0), csys=self.motor._csys

rocketpy/plots/tank_plots.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _generate_tank(self, translate=(0, 0), csys=1):
7272
# Don't set any plot config here. Use the draw methods for that
7373
return tank
7474

75-
def draw(self, *, filename=None):
75+
def draw(self, vis_args=None, *, filename=None):
7676
"""Draws the tank geometry.
7777
7878
Parameters
@@ -87,7 +87,10 @@ def draw(self, *, filename=None):
8787
-------
8888
None
8989
"""
90-
_, ax = plt.subplots(facecolor="#EEEEEE")
90+
if vis_args is None:
91+
vis_args = {}
92+
93+
_, ax = plt.subplots(facecolor=vis_args.get("background", "#EEEEEE"))
9194

9295
ax.add_patch(self._generate_tank())
9396

rocketpy/rocket/aero_surface/fins/fins.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def to_dict(self, **kwargs):
468468

469469
return data
470470

471-
def draw(self, *, filename=None):
471+
def draw(self, vis_args=None, *, filename=None):
472472
"""Draw the fin shape along with some important information, including
473473
the center line, the quarter line and the center of pressure position.
474474
@@ -484,4 +484,4 @@ def draw(self, *, filename=None):
484484
-------
485485
None
486486
"""
487-
self.plots.draw(filename=filename)
487+
self.plots.draw(vis_args=vis_args, filename=filename)

0 commit comments

Comments
 (0)