Skip to content

Commit 9e1c380

Browse files
committed
better docstrings
1 parent 3714aa3 commit 9e1c380

13 files changed

Lines changed: 381 additions & 39 deletions

src/meow/arrays.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,19 +131,43 @@ def _assert_dtype(arr: np.ndarray, dtype: str) -> np.ndarray:
131131

132132

133133
def Dim(ndim: int, *, coerce: bool = True) -> AfterValidator: # noqa: N802
134-
"""Validator to ensure the array has a specific number of dimensions."""
134+
"""Validator to ensure the array has a specific number of dimensions.
135+
136+
Args:
137+
ndim: the required number of dimensions.
138+
coerce: if True, reshape the array to match; if False, raise on mismatch.
139+
140+
Returns:
141+
A pydantic AfterValidator that checks or coerces the array dimensions.
142+
"""
135143
f = _coerce_dim if coerce else _assert_dim
136144
return AfterValidator(partial(f, ndim=ndim))
137145

138146

139147
def DType(dtype: str, *, coerce: bool = True) -> AfterValidator: # noqa: N802
140-
"""Validator to ensure the array has a specific dtype."""
148+
"""Validator to ensure the array has a specific dtype.
149+
150+
Args:
151+
dtype: the required data type as a string (e.g. "float64").
152+
coerce: if True, cast the array to the dtype; if False, raise on mismatch.
153+
154+
Returns:
155+
A pydantic AfterValidator that checks or coerces the array dtype.
156+
"""
141157
f = _coerce_dtype if coerce else _assert_dtype
142158
return AfterValidator(partial(f, dtype=dtype))
143159

144160

145161
def Shape(*shape: int, coerce: bool = True) -> AfterValidator: # noqa: N802
146-
"""Validator to ensure the array has a specific shape."""
162+
"""Validator to ensure the array has a specific shape.
163+
164+
Args:
165+
*shape: the required dimensions of the array.
166+
coerce: if True, reshape the array to match; if False, raise on mismatch.
167+
168+
Returns:
169+
A pydantic AfterValidator that checks or coerces the array shape.
170+
"""
147171
f = _coerce_shape if coerce else _assert_shape
148172
return AfterValidator(partial(f, shape=shape))
149173

src/meow/base_model.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,14 @@ def _visualize(self, **_: Any) -> None:
181181

182182

183183
def cache(prop: Callable) -> Callable:
184-
"""Decorator to cache the result of a property method."""
184+
"""Decorator to cache the result of a property method.
185+
186+
Args:
187+
prop: the property method whose result should be cached.
188+
189+
Returns:
190+
A wrapped callable that caches and returns the computed value.
191+
"""
185192
prop_name = getattr(prop, "__name__", "")
186193
if not prop_name:
187194
return prop
@@ -201,7 +208,14 @@ def getter(self): # noqa: ANN001,ANN202
201208

202209

203210
def cached_property(method: Callable): # noqa: ANN201
204-
"""Decorator to cache the result of a property method."""
211+
"""Decorator to cache the result of a property method.
212+
213+
Args:
214+
method: the method to wrap as a cached property.
215+
216+
Returns:
217+
A property descriptor that caches the method's return value.
218+
"""
205219
return property(cache(method))
206220

207221

src/meow/cell.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,17 @@ def create_cells(
122122
Ls: Annotated[NDArray, Dim(1), DType("float64")],
123123
z_min: float = 0.0,
124124
) -> list[Cell]:
125-
"""Create multiple `Cell` objects with a `Mesh` and a collection of cell lengths."""
125+
"""Create multiple `Cell` objects with a `Mesh` and a collection of cell lengths.
126+
127+
Args:
128+
structures: the 3D structures shared by all cells.
129+
mesh: a single mesh or a list of meshes, one per cell.
130+
Ls: a 1D array of cell lengths.
131+
z_min: the starting z-coordinate for the first cell.
132+
133+
Returns:
134+
A list of Cell objects spanning the given lengths.
135+
"""
126136
Ls = np.asarray(Ls, float)
127137
if Ls.ndim != 1:
128138
msg = f"Ls should be 1D. Got shape: {Ls.shape}."

src/meow/eme/cascade.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,20 @@ def compute_s_matrix_sax(
2929
propagations_fn: Callable = compute_propagation_s_matrices,
3030
**_: Any,
3131
) -> sax.SDenseMM:
32-
"""Calculate the S-matrix for given sets of modes."""
32+
"""Calculate the S-matrix for given sets of modes.
33+
34+
Args:
35+
modes: Modal basis for each cell in the stack.
36+
cells: Cells from which to derive propagation lengths. Either cells
37+
or cell_lengths must be provided.
38+
cell_lengths: Optional explicit propagation lengths per cell.
39+
sax_backend: SAX backend used for circuit evaluation.
40+
interfaces_fn: Callable that computes interface S-matrices.
41+
propagations_fn: Callable that computes propagation S-matrices.
42+
43+
Returns:
44+
A tuple ``(S, port_map)`` in SAX dense multimode format.
45+
"""
3346
propagations = propagations_fn(modes, cells, cell_lengths=cell_lengths)
3447
interfaces = interfaces_fn(modes)
3548
net = _get_netlist(propagations, interfaces)
@@ -94,7 +107,15 @@ def _get_netlist(
94107

95108

96109
def downselect_s(S: sax.SDenseMM, ports: list[str]) -> sax.SDenseMM:
97-
"""Downselect the S-matrix to the given ports."""
110+
"""Downselect the S-matrix to the given ports.
111+
112+
Args:
113+
S: A tuple ``(S_matrix, port_map)`` in SAX dense multimode format.
114+
ports: Port names to keep.
115+
116+
Returns:
117+
A new ``(S_matrix, port_map)`` tuple containing only the selected ports.
118+
"""
98119
S_matrix, port_map = S
99120
idxs = [port_map[port] for port in ports]
100121
S_matrix = S_matrix[idxs, :][:, idxs]

src/meow/eme/interface.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,20 @@ def compute_interface_s_matrices(
221221
The same sharp edges apply here as for the single-interface solve:
222222
orthonormalization, inner-product choice, TSVD cutoff, and passivity method
223223
must all be interpreted consistently across the full stack.
224+
225+
Args:
226+
modes: Ordered list of modal bases across the stack.
227+
inner_product: Inner-product callable for forming overlap matrices.
228+
conjugate: Whether to use the conjugated formulation. Inferred from
229+
inner_product if None.
230+
tsvd_rcond: Relative singular-value cutoff for the TSVD solve.
231+
passivity_method: Method for enforcing S-matrix passivity.
232+
enforce_reciprocity: Whether to symmetrize each interface S-matrix.
233+
ignore_warnings: Whether to suppress numerical warnings.
234+
235+
Returns:
236+
Dict mapping interface names (e.g. ``"i_0_1"``) to SAX dense
237+
multimode S-matrix tuples.
224238
"""
225239
return {
226240
f"i_{i}_{i + 1}": compute_interface_s_matrix(

src/meow/eme/propagation.py

Lines changed: 100 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ def compute_propagation_s_matrix(modes: Modes, cell_length: float) -> sax.SDictM
2323
Each mode acquires a phase ``exp(2j * pi * neff / wl * cell_length)`` while
2424
propagating through the cell. Backward propagation is mirrored by the
2525
bidirectional port mapping in the returned SAX dictionary.
26+
27+
Args:
28+
modes: Modal basis for the cell.
29+
cell_length: Physical length of the cell.
30+
31+
Returns:
32+
SAX dictionary mapping port pairs to complex transmission values.
2633
"""
2734
s_dict = {
2835
(f"left@{i}", f"right@{i}"): jnp.exp(
@@ -67,7 +74,15 @@ def compute_propagation_s_matrices(
6774

6875

6976
def select_ports(S: sax.SDenseMM, ports: list[str]) -> sax.SDenseMM:
70-
"""Keep a subset of ports from an S-matrix."""
77+
"""Keep a subset of ports from an S-matrix.
78+
79+
Args:
80+
S: A tuple ``(s_matrix, port_map)`` in SAX dense multimode format.
81+
ports: Port names to retain.
82+
83+
Returns:
84+
A new ``(s_matrix, port_map)`` tuple with only the selected ports.
85+
"""
7186
s, pm = S
7287
idxs = jnp.array([pm[port] for port in ports], dtype=jnp.int32)
7388
s = s[idxs, :][:, idxs]
@@ -112,7 +127,16 @@ def pi_pairs(
112127
interfaces: dict[str, sax.SDenseMM],
113128
sax_backend: sax.Backend,
114129
) -> list[sax.STypeMM]:
115-
"""Return propagation-interface pairs for a full stack."""
130+
"""Return propagation-interface pairs for a full stack.
131+
132+
Args:
133+
propagations: Propagation S-matrices keyed by cell index.
134+
interfaces: Interface S-matrices keyed by adjacent cell pairs.
135+
sax_backend: SAX backend used for cascading.
136+
137+
Returns:
138+
List of cascaded propagation-interface S-matrices, one per cell.
139+
"""
116140
pairs: list[sax.STypeMM] = []
117141
for i in range(len(propagations)):
118142
propagation = propagations[f"p_{i}"]
@@ -128,7 +152,16 @@ def pi_pairs(
128152
def l2r_matrices(
129153
pairs: list[sax.STypeMM], identity: sax.SDenseMM, sax_backend: sax.Backend
130154
) -> list[sax.STypeMM]:
131-
"""Return cumulative left-to-right S-matrices."""
155+
"""Return cumulative left-to-right S-matrices.
156+
157+
Args:
158+
pairs: Propagation-interface pair S-matrices from ``pi_pairs``.
159+
identity: Identity-like S-matrix used as the initial accumulator.
160+
sax_backend: SAX backend used for cascading.
161+
162+
Returns:
163+
List of cumulative S-matrices from the left boundary up to each cell.
164+
"""
132165
matrices: list[sax.STypeMM] = [identity]
133166
for pair in pairs[:-1]:
134167
matrices.append(_connect_two(matrices[-1], pair, sax_backend))
@@ -138,7 +171,15 @@ def l2r_matrices(
138171
def r2l_matrices(
139172
pairs: list[sax.STypeMM], sax_backend: sax.Backend
140173
) -> list[sax.STypeMM]:
141-
"""Return cumulative right-to-left S-matrices."""
174+
"""Return cumulative right-to-left S-matrices.
175+
176+
Args:
177+
pairs: Propagation-interface pair S-matrices from ``pi_pairs``.
178+
sax_backend: SAX backend used for cascading.
179+
180+
Returns:
181+
List of cumulative S-matrices from each cell to the right boundary.
182+
"""
142183
matrices = [pairs[-1]]
143184
for pair in pairs[-2::-1]:
144185
matrices.append(_connect_two(pair, matrices[-1], sax_backend))
@@ -150,7 +191,15 @@ def split_square_matrix(
150191
) -> tuple[
151192
tuple[ComplexArray2D, ComplexArray2D], tuple[ComplexArray2D, ComplexArray2D]
152193
]:
153-
"""Split a square matrix into its four block submatrices."""
194+
"""Split a square matrix into its four block submatrices.
195+
196+
Args:
197+
matrix: Square matrix to split.
198+
idx: Row/column index at which to split.
199+
200+
Returns:
201+
Nested tuple ``((top_left, top_right), (bottom_left, bottom_right))``.
202+
"""
154203
if matrix.shape[0] != matrix.shape[1]:
155204
msg = "Matrix has to be square."
156205
raise ValueError(msg)
@@ -167,7 +216,18 @@ def compute_mode_amplitudes(
167216
excitation_l: ComplexArray1D,
168217
excitation_r: ComplexArray1D,
169218
) -> tuple[ComplexArray1D, ComplexArray1D]:
170-
"""Solve for the forward and backward modal amplitudes in one cell."""
219+
"""Solve for the forward and backward modal amplitudes in one cell.
220+
221+
Args:
222+
u: Cumulative left-to-right S-matrix (dense array).
223+
v: Cumulative right-to-left S-matrix (dense array).
224+
m: Number of right-side (forward) modes.
225+
excitation_l: Left boundary excitation vector.
226+
excitation_r: Right boundary excitation vector.
227+
228+
Returns:
229+
Tuple ``(forward, backward)`` of complex amplitude vectors.
230+
"""
171231
n = u.shape[0] - m
172232
_, [u21, u22] = split_square_matrix(u, n)
173233
[v11, v12], _ = split_square_matrix(v, m)
@@ -185,7 +245,18 @@ def propagate(
185245
excitation_l: ComplexArray1D,
186246
excitation_r: ComplexArray1D,
187247
) -> tuple[list[ComplexArray1D], list[ComplexArray1D]]:
188-
"""Propagate boundary excitations through cumulative S-matrices."""
248+
"""Propagate boundary excitations through cumulative S-matrices.
249+
250+
Args:
251+
l2rs: Cumulative left-to-right S-matrices for each cell.
252+
r2ls: Cumulative right-to-left S-matrices for each cell.
253+
excitation_l: Left boundary excitation vector.
254+
excitation_r: Right boundary excitation vector.
255+
256+
Returns:
257+
Tuple ``(forwards, backwards)`` where each element is a list of
258+
complex amplitude vectors, one per cell.
259+
"""
189260
forwards = []
190261
backwards = []
191262
for l2r, r2l in zip(l2rs, r2ls, strict=False):
@@ -212,7 +283,20 @@ def plot_fields(
212283
y: float,
213284
z: FloatArray1D,
214285
) -> tuple[ComplexArray2D, FloatArray1D]:
215-
"""Reconstruct an ``Ex(x, z)`` field slice from propagated modal amplitudes."""
286+
"""Reconstruct an ``Ex(x, z)`` field slice from propagated modal amplitudes.
287+
288+
Args:
289+
modes: Mode sets for each cell.
290+
cells: Cells defining the geometry and mesh.
291+
forwards: Forward amplitude vectors per cell.
292+
backwards: Backward amplitude vectors per cell.
293+
y: Transverse y-coordinate at which to sample the field.
294+
z: Global z-grid on which to reconstruct the field.
295+
296+
Returns:
297+
Tuple ``(field, x)`` where ``field`` is the complex ``Ex(z, x)``
298+
array and ``x`` is the transverse sampling grid.
299+
"""
216300
mesh_y = cells[0].mesh.y
217301
mesh_x = cells[0].mesh.x
218302
mesh_x = mesh_x[:-1] + np.diff(mesh_x) / 2
@@ -275,6 +359,14 @@ def track_modes(
275359
Unmatched modes are appended after the matched subset in their original
276360
order. This function does not change the physics of the interface solve; it
277361
only makes the local basis more continuous for reconstruction and plotting.
362+
363+
Args:
364+
modes: Ordered list of modal bases across the stack.
365+
inner_product_fn: Inner-product callable used to compute overlaps
366+
between neighboring mode sets.
367+
368+
Returns:
369+
Reordered and phase-aligned list of modal bases.
278370
"""
279371
if not modes:
280372
return []

src/meow/fde/lumerical.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,18 @@ def compute_modes_lumerical(
2727
post_process: Callable = post_process_modes,
2828
sim: Sim | None = None,
2929
) -> list[Mode]:
30-
"""Compute ``Modes` for a given ``FdeSpec` (Lumerical backend)."""
30+
"""Compute ``Modes`` for a given ``CrossSection`` (Lumerical backend).
31+
32+
Args:
33+
cs: the cross-section to solve modes for.
34+
num_modes: number of modes to compute.
35+
unit: unit scaling factor (default 1e-6 for micrometres).
36+
post_process: callable applied to the raw mode list before returning.
37+
sim: optional Lumerical simulation object; uses the global one if None.
38+
39+
Returns:
40+
The computed and post-processed list of modes.
41+
"""
3142
from lumapi import LumApiError # type: ignore[reportMissingImports]
3243

3344
sim = get_sim(sim=sim)
@@ -122,7 +133,14 @@ def create_lumerical_geometries(
122133
env: Environment,
123134
unit: float,
124135
) -> None:
125-
"""Create Lumerical geometries from a list of structures."""
136+
"""Create Lumerical geometries from a list of structures.
137+
138+
Args:
139+
sim: the Lumerical simulation object.
140+
structures: 3-D structures to add to the simulation.
141+
env: the environment containing material/wavelength info.
142+
unit: unit scaling factor.
143+
"""
126144
sim = get_sim(sim=sim)
127145
sim.switchtolayout()
128146
sim.deleteall()
@@ -131,7 +149,15 @@ def create_lumerical_geometries(
131149

132150

133151
def get_sim(**kwargs: Any) -> Sim:
134-
"""Get the Lumerical simulation object."""
152+
"""Get the Lumerical simulation object.
153+
154+
Args:
155+
**kwargs: keyword arguments; pass ``sim`` to set and return a specific
156+
simulation object.
157+
158+
Returns:
159+
The active Lumerical simulation object.
160+
"""
135161
global _sim # noqa: PLW0603
136162
sim = kwargs.get("sim", None)
137163
if sim is not None:

0 commit comments

Comments
 (0)