@@ -96,7 +96,8 @@ def concatenate_inputs(calculation_context: ShapeCalculationContext) -> List[Ten
9696 test_context : RandomizerTestContext = calculation_context .test_context
9797 rng_shape = test_context .rng_shape
9898 forward_kwargs = calculation_context .forward_kwargs
99- axis = forward_kwargs ["axis" ]
99+ axis_column = "axis" if "axis" in forward_kwargs else "dim"
100+ axis = forward_kwargs [axis_column ]
100101
101102 if axis >= len (output_shape ) or axis < 0 :
102103 axis %= len (output_shape )
@@ -234,7 +235,8 @@ def concatenate_adjust(node: RandomizerNode, test_context: RandomizerTestContext
234235
235236 input_num = node .input_num
236237 output_shape = node .output_shape
237- axis = node .forward_kwargs ["axis" ]
238+ axis_column = "axis" if "axis" in node .forward_kwargs else "dim"
239+ axis = node .forward_kwargs [axis_column ]
238240
239241 if not - len (output_shape ) <= axis < len (output_shape ):
240242 axis = None # must be recalculated
@@ -267,7 +269,7 @@ def concatenate_adjust(node: RandomizerNode, test_context: RandomizerTestContext
267269 # Axis 0 is not supported
268270 continue
269271 if input_num_range .operands_min <= mid_size :
270- node .forward_kwargs ["axis" ] = axis
272+ node .forward_kwargs [axis_column ] = axis
271273 node .input_num = rng_shape .randint (
272274 input_num_range .operands_min , min (mid_size , input_num_range .operands_max )
273275 )
0 commit comments