Skip to content

Commit f6d5726

Browse files
committed
Make it possible to pass shape from cli when generating python code
1 parent ad26d33 commit f6d5726

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

src/gotranx/cli/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from ..schemes import Scheme, get_scheme
99
from ..codegen import PythonFormat, CFormat
10+
from ..codegen.base import Shape
1011
from . import gotran2c, gotran2py, gotran2julia
1112
from . import utils
1213

@@ -286,6 +287,12 @@ def ode2py(
286287
"-b",
287288
help="Backend for the generated code",
288289
),
290+
shape: Shape = typer.Option(
291+
Shape.dynamic,
292+
"--shape",
293+
"-S",
294+
help="Shape of the output arrays",
295+
),
289296
):
290297
if fname is None:
291298
return typer.echo("No file specified")
@@ -295,6 +302,7 @@ def ode2py(
295302
delta = config_data.get("delta", delta)
296303
stiff_states = config_data.get("stiff_states", stiff_states)
297304
scheme = config_data.get("scheme", scheme)
305+
shape = Shape(config_data.get("shape", shape))
298306
scheme = utils.validate_scheme(scheme)
299307
py_config = config_data.get("python", {})
300308
format = PythonFormat(py_config.get("format", format))
@@ -310,6 +318,7 @@ def ode2py(
310318
delta=delta,
311319
format=format,
312320
backend=backend,
321+
shape=shape,
313322
)
314323

315324

src/gotranx/cli/gotran2py.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def main(
118118
delta: float = 1e-8,
119119
suffix: str = ".py",
120120
backend: Backend = Backend.numpy,
121+
shape: Shape = Shape.dynamic,
121122
) -> None:
122123
loglevel = logging.DEBUG if verbose else logging.INFO
123124
structlog.configure(
@@ -134,6 +135,7 @@ def main(
134135
stiff_states=stiff_states,
135136
delta=delta,
136137
backend=backend,
138+
shape=shape,
137139
)
138140
out = fname if outname is None else Path(outname)
139141
out_name = out.with_suffix(suffix=suffix)

0 commit comments

Comments
 (0)