55"""
66
77import ast
8+ import operator
89
910import numpy as np
1011
1112from nkigym .codegen .analysis import _OpCall
13+ from nkigym .ops .activation import NKIActivation
14+ from nkigym .ops .activation_1d import NKIActivation1D
1215from nkigym .ops .base import NKIOp
1316
17+ _BINOP_FNS : dict [type , object ] = {
18+ ast .Add : operator .add ,
19+ ast .Sub : operator .sub ,
20+ ast .Mult : operator .mul ,
21+ ast .Div : operator .truediv ,
22+ }
23+
1424
1525def find_func_def (source : str ) -> ast .FunctionDef :
1626 """Find the first FunctionDef in parsed source.
@@ -41,10 +51,28 @@ def _is_nkigym_call(call: ast.Call) -> bool:
4151 return isinstance (func , ast .Attribute ) and isinstance (func .value , ast .Name ) and func .value .id == "nkigym"
4252
4353
54+ def _eval_binop (node : ast .BinOp ) -> object :
55+ """Evaluate a binary operation on constant operands.
56+
57+ Args:
58+ node: AST BinOp node.
59+
60+ Returns:
61+ Result of the binary operation.
62+ """
63+ left = _eval_expr (node .left )
64+ right = _eval_expr (node .right )
65+ op_fn = _BINOP_FNS .get (type (node .op ))
66+ if op_fn is None :
67+ raise ValueError (f"Unsupported binary op: { ast .dump (node )} " )
68+ return op_fn (left , right )
69+
70+
4471def _eval_expr (node : ast .expr ) -> object :
4572 """Evaluate an AST expression to a Python object.
4673
47- Resolves ``np.X`` attribute accesses and literal constants.
74+ Resolves ``np.X`` attribute accesses, literal constants,
75+ binary operations, and unary negation.
4876
4977 Args:
5078 node: AST expression node.
@@ -53,11 +81,14 @@ def _eval_expr(node: ast.expr) -> object:
5381 The resolved Python object.
5482 """
5583 result = None
56- if isinstance (node , ast .Attribute ) and isinstance (node .value , ast .Name ):
57- if node .value .id == "np" :
58- result = getattr (np , node .attr )
84+ if isinstance (node , ast .Attribute ) and isinstance (node .value , ast .Name ) and node .value .id == "np" :
85+ result = getattr (np , node .attr )
5986 elif isinstance (node , ast .Constant ):
6087 result = node .value
88+ elif isinstance (node , ast .BinOp ):
89+ result = _eval_binop (node )
90+ elif isinstance (node , ast .UnaryOp ) and isinstance (node .op , ast .USub ):
91+ result = - _eval_expr (node .operand )
6192 if result is None :
6293 raise ValueError (f"Unsupported kwarg expression: { ast .dump (node )} " )
6394 return result
@@ -77,6 +108,45 @@ def _arg_name(node: ast.expr) -> str:
77108 return node .id
78109
79110
111+ def _maybe_reclassify_activation (op : _OpCall , output_axes_map : dict [str , tuple [str , ...]]) -> _OpCall :
112+ """Reclassify NKIActivation to NKIActivation1D if input is 1D.
113+
114+ Args:
115+ op: Parsed op call to check.
116+ output_axes_map: Maps variable name to output axes of its producer op.
117+
118+ Returns:
119+ Original or reclassified op call.
120+ """
121+ is_1d = (
122+ op .stmt_type is NKIActivation
123+ and op .input_vars [0 ] in output_axes_map
124+ and len (output_axes_map [op .input_vars [0 ]]) == 1
125+ )
126+ return op ._replace (stmt_type = NKIActivation1D ) if is_1d else op
127+
128+
129+ def _resolve_op_variants (op_calls : list [_OpCall ]) -> list [_OpCall ]:
130+ """Post-parse pass to reclassify ops based on producer output shapes.
131+
132+ Traces the SSA chain to determine operand dimensionality and
133+ reclassifies NKIActivation to NKIActivation1D when input is 1D.
134+
135+ Args:
136+ op_calls: Parsed op calls from the function body.
137+
138+ Returns:
139+ Op calls with reclassified types where appropriate.
140+ """
141+ output_axes_map : dict [str , tuple [str , ...]] = {}
142+ result : list [_OpCall ] = []
143+ for op in op_calls :
144+ resolved = _maybe_reclassify_activation (op , output_axes_map )
145+ output_axes_map [resolved .output_var ] = getattr (resolved .stmt_type , "OUTPUT_AXES" , ())
146+ result .append (resolved )
147+ return result
148+
149+
80150def parse_body (func_def : ast .FunctionDef ) -> list [_OpCall ]:
81151 """Parse function body into a list of _OpCall.
82152
@@ -91,7 +161,7 @@ def parse_body(func_def: ast.FunctionDef) -> list[_OpCall]:
91161 for node in func_def .body :
92162 if not _try_parse_node (node , op_calls , counter ):
93163 raise ValueError (f"Unsupported statement: { ast .dump (node )} " )
94- return op_calls
164+ return _resolve_op_variants ( op_calls )
95165
96166
97167def _try_parse_node (node : ast .stmt , op_calls : list [_OpCall ], counter : list [int ]) -> bool :
@@ -157,6 +227,28 @@ def _try_parse_return(node: ast.Return, op_calls: list[_OpCall], counter: list[i
157227 return result
158228
159229
230+ def _disambiguate_op (op_name : str , call : ast .Call ) -> str :
231+ """Disambiguate user function name to internal op registry key.
232+
233+ - ``activation`` with ``reduce_op`` kwarg → ``activation_reduce``
234+ - ``tensor_scalar`` with < 2 positional args → ``tensor_scalar_const``
235+
236+ Args:
237+ op_name: User-facing function name from AST.
238+ call: AST Call node with keyword arguments.
239+
240+ Returns:
241+ Internal op registry key.
242+ """
243+ kwarg_names = {kw .arg for kw in call .keywords }
244+ result = op_name
245+ if op_name == "activation" and "reduce_op" in kwarg_names :
246+ result = "activation_reduce"
247+ elif op_name == "tensor_scalar" and len (call .args ) < 2 :
248+ result = "tensor_scalar_const"
249+ return result
250+
251+
160252def _flatten_call (call : ast .Call , output : str , op_calls : list [_OpCall ], counter : list [int ]) -> None :
161253 """Flatten a nkigym call (possibly nested) into _OpCall entries.
162254
@@ -167,7 +259,7 @@ def _flatten_call(call: ast.Call, output: str, op_calls: list[_OpCall], counter:
167259 counter: Mutable counter for intermediate variable names.
168260 """
169261 assert isinstance (call .func , ast .Attribute )
170- op_name = call .func .attr
262+ op_name = _disambiguate_op ( call .func .attr , call )
171263 registry = NKIOp .all_ops ()
172264 if op_name not in registry :
173265 raise ValueError (f"Unknown op: { op_name !r} " )
0 commit comments