Skip to content

Commit a2f62c4

Browse files
committed
Refactor and fix _flatten_single_param_structs_into_groups
1 parent 6df3a76 commit a2f62c4

1 file changed

Lines changed: 89 additions & 56 deletions

File tree

src/styx/ir/optimize.py

Lines changed: 89 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import typing
1+
from dataclasses import dataclass
22
from typing import Generator, Callable, Iterable, NamedTuple
33

44
import styx.ir.core as ir
@@ -111,10 +111,7 @@ def _param_parent_location(param: ir.Param) -> _ParentLocation | None:
111111
return None
112112

113113

114-
T = typing.TypeVar("T")
115-
116-
117-
def _join_optionals(a: T | None, b: T | None, join: Callable[[T, T], T]) -> T | None:
114+
def _join_optionals[T](a: T | None, b: T | None, join: Callable[[T, T], T]) -> T | None:
118115
if a is None:
119116
if b is None:
120117
return None
@@ -136,69 +133,105 @@ def _merge_docs(a: ir.Documentation, b: ir.Documentation) -> ir.Documentation:
136133

137134
def _flatten_single_param_structs_into_groups(app: ir.App) -> ir.App:
138135
"""
139-
If a subcommand has a single param it should be merged into the parent.
140-
If both are nullable this does not always work.
141-
142-
Quite tricky - Really hope to never have to touch this again.
136+
Flatten structs that contain only a single parameter into their parent group.
143137
"""
144138

145-
needs_rerun = True
146-
while needs_rerun:
147-
needs_rerun = False
139+
@dataclass
140+
class TokenLocation:
141+
"""Location of a param within the tree."""
142+
143+
parent_struct: ir.Param[ir.Param.Struct]
144+
group_idx: int
145+
cmdarg_idx: int
146+
token_idx: int
147+
148+
@property
149+
def group(self) -> ir.ConditionalGroup:
150+
return self.parent_struct.body.groups[self.group_idx]
151+
152+
@property
153+
def cmdarg(self) -> ir.CmdArg:
154+
return self.group.cargs[self.cmdarg_idx]
155+
156+
def find_token_location(target: ir.Param) -> TokenLocation | None:
157+
"""Find where a param lives in its parent's token list."""
158+
parent = target.parent
159+
if parent is None or not isinstance(parent.body, ir.Param.Struct):
160+
return None
161+
162+
for group_idx, group in enumerate(parent.body.groups):
163+
for cmdarg_idx, cmdarg in enumerate(group.cargs):
164+
for token_idx, token in enumerate(cmdarg.tokens):
165+
if token is target:
166+
return TokenLocation(parent, group_idx, cmdarg_idx, token_idx)
167+
return None
168+
169+
def get_single_param(struct_body: ir.Param.Struct) -> ir.Param | None:
170+
"""Get the single param if struct has exactly one, else None."""
171+
params = list(struct_body.iter_params_shallow())
172+
return params[0] if len(params) == 1 else None
173+
174+
def find_flattening_candidate() -> tuple[ir.Param[ir.Param.Struct], ir.Param, TokenLocation] | None:
175+
"""Find a struct that can be safely flattened."""
148176
for struct in app.command.iter_structs_deep():
177+
if struct.parent is None:
178+
continue
149179
if isinstance(struct.parent.body, ir.Param.StructUnion):
150180
continue
151-
if struct.list_:
181+
if struct.list_ is not None:
182+
continue
183+
if struct.base.outputs:
152184
continue
153-
if _count(struct.body.iter_params_shallow()) != 1:
185+
186+
single_param = get_single_param(struct.body)
187+
if single_param is None:
154188
continue
155-
single_param = struct.body.iter_params_shallow().__next__()
189+
156190
if struct.nullable and single_param.nullable:
157191
continue
158192

159-
location = _param_parent_location(struct)
160-
assert location is not None
161-
162-
if struct.nullable:
163-
# merge all groups and use a single group with the param now nullable
164-
165-
single_param.nullable = True
166-
single_param.default_value = ir.Param.SetToNone
167-
new_cargs: list[ir.CmdArg] = []
168-
for g in struct.body.groups:
169-
for cmdarg in g.cargs:
170-
new_cargs.append(cmdarg)
171-
struct.body.groups = [ir.ConditionalGroup(cargs=new_cargs)]
172-
# todo: handle joins?
173-
174-
single_param.base.docs = _merge_docs(struct.base.docs, single_param.base.docs)
175-
176-
# replace the cmdarg containing the struct with all cmdargs from the struct
177-
# get all structs cmdargs (after flattening if nullable)
178-
struct_cmdargs = []
179-
for g in struct.body.groups:
180-
struct_cmdargs.extend(g.cargs)
181-
182-
# build new cmdargs list: before + struct's cmdargs + after
183-
new_cargs = (
184-
location.group.cargs[: location.cmdarg_idx] # cmdargs before the one with struct
185-
+ struct_cmdargs # all cmdargs from inside the struct
186-
+ location.group.cargs[(location.cmdarg_idx + 1) :] # cmdargs after
187-
)
188-
189-
# replace the group with the new cmdargs
190-
new_groups = (
191-
location.parent.body.groups[: location.group_idx]
192-
+ [ir.ConditionalGroup(cargs=new_cargs, join=location.group.join)] # preserve join
193-
+ location.parent.body.groups[(location.group_idx + 1) :]
194-
)
195-
196-
location.parent.body.groups = new_groups
197-
198-
app.command.setup_parent_references()
199-
needs_rerun = True
193+
location = find_token_location(struct)
194+
if location is None:
195+
continue
196+
197+
return struct, single_param, location
198+
199+
return None
200+
201+
while True:
202+
candidate = find_flattening_candidate()
203+
if candidate is None:
200204
break
201205

206+
struct, single_param, location = candidate
207+
208+
# Transfer nullability if needed
209+
if struct.nullable and not single_param.nullable:
210+
single_param.nullable = True
211+
single_param.default_value = ir.Param.SetToNone
212+
213+
# Merge docs
214+
single_param.base.docs = _merge_docs(struct.base.docs, single_param.base.docs)
215+
216+
# Collect all tokens from the struct's cmdargs (flattening the struct's internal structure)
217+
replacement_tokens: list[ir.Param | str] = []
218+
for group in struct.body.groups:
219+
for cmdarg in group.cargs:
220+
replacement_tokens.extend(cmdarg.tokens)
221+
222+
# Splice at the TOKEN level: replace the struct param with its internal tokens
223+
old_tokens = location.cmdarg.tokens
224+
new_tokens = (
225+
old_tokens[: location.token_idx] # tokens before the struct
226+
+ replacement_tokens # struct's internal tokens
227+
+ old_tokens[location.token_idx + 1 :] # tokens after the struct
228+
)
229+
230+
# Update the cmdarg's tokens in place
231+
location.cmdarg.tokens = new_tokens
232+
233+
app.command.setup_parent_references()
234+
202235
return app
203236

204237

0 commit comments

Comments
 (0)