@@ -23,6 +23,13 @@ def compute_propagation_s_matrix(modes: Modes, cell_length: float) -> sax.SDictM
2323 Each mode acquires a phase ``exp(2j * pi * neff / wl * cell_length)`` while
2424 propagating through the cell. Backward propagation is mirrored by the
2525 bidirectional port mapping in the returned SAX dictionary.
26+
27+ Args:
28+ modes: Modal basis for the cell.
29+ cell_length: Physical length of the cell.
30+
31+ Returns:
32+ SAX dictionary mapping port pairs to complex transmission values.
2633 """
2734 s_dict = {
2835 (f"left@{ i } " , f"right@{ i } " ): jnp .exp (
@@ -67,7 +74,15 @@ def compute_propagation_s_matrices(
6774
6875
6976def select_ports (S : sax .SDenseMM , ports : list [str ]) -> sax .SDenseMM :
70- """Keep a subset of ports from an S-matrix."""
77+ """Keep a subset of ports from an S-matrix.
78+
79+ Args:
80+ S: A tuple ``(s_matrix, port_map)`` in SAX dense multimode format.
81+ ports: Port names to retain.
82+
83+ Returns:
84+ A new ``(s_matrix, port_map)`` tuple with only the selected ports.
85+ """
7186 s , pm = S
7287 idxs = jnp .array ([pm [port ] for port in ports ], dtype = jnp .int32 )
7388 s = s [idxs , :][:, idxs ]
@@ -112,7 +127,16 @@ def pi_pairs(
112127 interfaces : dict [str , sax .SDenseMM ],
113128 sax_backend : sax .Backend ,
114129) -> list [sax .STypeMM ]:
115- """Return propagation-interface pairs for a full stack."""
130+ """Return propagation-interface pairs for a full stack.
131+
132+ Args:
133+ propagations: Propagation S-matrices keyed by cell index.
134+ interfaces: Interface S-matrices keyed by adjacent cell pairs.
135+ sax_backend: SAX backend used for cascading.
136+
137+ Returns:
138+ List of cascaded propagation-interface S-matrices, one per cell.
139+ """
116140 pairs : list [sax .STypeMM ] = []
117141 for i in range (len (propagations )):
118142 propagation = propagations [f"p_{ i } " ]
@@ -128,7 +152,16 @@ def pi_pairs(
128152def l2r_matrices (
129153 pairs : list [sax .STypeMM ], identity : sax .SDenseMM , sax_backend : sax .Backend
130154) -> list [sax .STypeMM ]:
131- """Return cumulative left-to-right S-matrices."""
155+ """Return cumulative left-to-right S-matrices.
156+
157+ Args:
158+ pairs: Propagation-interface pair S-matrices from ``pi_pairs``.
159+ identity: Identity-like S-matrix used as the initial accumulator.
160+ sax_backend: SAX backend used for cascading.
161+
162+ Returns:
163+ List of cumulative S-matrices from the left boundary up to each cell.
164+ """
132165 matrices : list [sax .STypeMM ] = [identity ]
133166 for pair in pairs [:- 1 ]:
134167 matrices .append (_connect_two (matrices [- 1 ], pair , sax_backend ))
@@ -138,7 +171,15 @@ def l2r_matrices(
138171def r2l_matrices (
139172 pairs : list [sax .STypeMM ], sax_backend : sax .Backend
140173) -> list [sax .STypeMM ]:
141- """Return cumulative right-to-left S-matrices."""
174+ """Return cumulative right-to-left S-matrices.
175+
176+ Args:
177+ pairs: Propagation-interface pair S-matrices from ``pi_pairs``.
178+ sax_backend: SAX backend used for cascading.
179+
180+ Returns:
181+ List of cumulative S-matrices from each cell to the right boundary.
182+ """
142183 matrices = [pairs [- 1 ]]
143184 for pair in pairs [- 2 ::- 1 ]:
144185 matrices .append (_connect_two (pair , matrices [- 1 ], sax_backend ))
@@ -150,7 +191,15 @@ def split_square_matrix(
150191) -> tuple [
151192 tuple [ComplexArray2D , ComplexArray2D ], tuple [ComplexArray2D , ComplexArray2D ]
152193]:
153- """Split a square matrix into its four block submatrices."""
194+ """Split a square matrix into its four block submatrices.
195+
196+ Args:
197+ matrix: Square matrix to split.
198+ idx: Row/column index at which to split.
199+
200+ Returns:
201+ Nested tuple ``((top_left, top_right), (bottom_left, bottom_right))``.
202+ """
154203 if matrix .shape [0 ] != matrix .shape [1 ]:
155204 msg = "Matrix has to be square."
156205 raise ValueError (msg )
@@ -167,7 +216,18 @@ def compute_mode_amplitudes(
167216 excitation_l : ComplexArray1D ,
168217 excitation_r : ComplexArray1D ,
169218) -> tuple [ComplexArray1D , ComplexArray1D ]:
170- """Solve for the forward and backward modal amplitudes in one cell."""
219+ """Solve for the forward and backward modal amplitudes in one cell.
220+
221+ Args:
222+ u: Cumulative left-to-right S-matrix (dense array).
223+ v: Cumulative right-to-left S-matrix (dense array).
224+ m: Number of right-side (forward) modes.
225+ excitation_l: Left boundary excitation vector.
226+ excitation_r: Right boundary excitation vector.
227+
228+ Returns:
229+ Tuple ``(forward, backward)`` of complex amplitude vectors.
230+ """
171231 n = u .shape [0 ] - m
172232 _ , [u21 , u22 ] = split_square_matrix (u , n )
173233 [v11 , v12 ], _ = split_square_matrix (v , m )
@@ -185,7 +245,18 @@ def propagate(
185245 excitation_l : ComplexArray1D ,
186246 excitation_r : ComplexArray1D ,
187247) -> tuple [list [ComplexArray1D ], list [ComplexArray1D ]]:
188- """Propagate boundary excitations through cumulative S-matrices."""
248+ """Propagate boundary excitations through cumulative S-matrices.
249+
250+ Args:
251+ l2rs: Cumulative left-to-right S-matrices for each cell.
252+ r2ls: Cumulative right-to-left S-matrices for each cell.
253+ excitation_l: Left boundary excitation vector.
254+ excitation_r: Right boundary excitation vector.
255+
256+ Returns:
257+ Tuple ``(forwards, backwards)`` where each element is a list of
258+ complex amplitude vectors, one per cell.
259+ """
189260 forwards = []
190261 backwards = []
191262 for l2r , r2l in zip (l2rs , r2ls , strict = False ):
@@ -212,7 +283,20 @@ def plot_fields(
212283 y : float ,
213284 z : FloatArray1D ,
214285) -> tuple [ComplexArray2D , FloatArray1D ]:
215- """Reconstruct an ``Ex(x, z)`` field slice from propagated modal amplitudes."""
286+ """Reconstruct an ``Ex(x, z)`` field slice from propagated modal amplitudes.
287+
288+ Args:
289+ modes: Mode sets for each cell.
290+ cells: Cells defining the geometry and mesh.
291+ forwards: Forward amplitude vectors per cell.
292+ backwards: Backward amplitude vectors per cell.
293+ y: Transverse y-coordinate at which to sample the field.
294+ z: Global z-grid on which to reconstruct the field.
295+
296+ Returns:
297+ Tuple ``(field, x)`` where ``field`` is the complex ``Ex(z, x)``
298+ array and ``x`` is the transverse sampling grid.
299+ """
216300 mesh_y = cells [0 ].mesh .y
217301 mesh_x = cells [0 ].mesh .x
218302 mesh_x = mesh_x [:- 1 ] + np .diff (mesh_x ) / 2
@@ -275,6 +359,14 @@ def track_modes(
275359 Unmatched modes are appended after the matched subset in their original
276360 order. This function does not change the physics of the interface solve; it
277361 only makes the local basis more continuous for reconstruction and plotting.
362+
363+ Args:
364+ modes: Ordered list of modal bases across the stack.
365+ inner_product_fn: Inner-product callable used to compute overlaps
366+ between neighboring mode sets.
367+
368+ Returns:
369+ Reordered and phase-aligned list of modal bases.
278370 """
279371 if not modes :
280372 return []
0 commit comments