Skip to content

Commit 3c8b7eb

Browse files
committed
concatenate for pytorch RGG
1 parent ecc53ff commit 3c8b7eb

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

forge/test/random/rgg/shapes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def concatenate_inputs(calculation_context: ShapeCalculationContext) -> List[Ten
9696
test_context: RandomizerTestContext = calculation_context.test_context
9797
rng_shape = test_context.rng_shape
9898
forward_kwargs = calculation_context.forward_kwargs
99-
axis = forward_kwargs["axis"]
99+
axis_column = "axis" if "axis" in forward_kwargs else "dim"
100+
axis = forward_kwargs[axis_column]
100101

101102
if axis >= len(output_shape) or axis < 0:
102103
axis %= len(output_shape)
@@ -234,7 +235,8 @@ def concatenate_adjust(node: RandomizerNode, test_context: RandomizerTestContext
234235

235236
input_num = node.input_num
236237
output_shape = node.output_shape
237-
axis = node.forward_kwargs["axis"]
238+
axis_column = "axis" if "axis" in node.forward_kwargs else "dim"
239+
axis = node.forward_kwargs[axis_column]
238240

239241
if not -len(output_shape) <= axis < len(output_shape):
240242
axis = None # must be recalculated
@@ -267,7 +269,7 @@ def concatenate_adjust(node: RandomizerNode, test_context: RandomizerTestContext
267269
# Axis 0 is not supported
268270
continue
269271
if input_num_range.operands_min <= mid_size:
270-
node.forward_kwargs["axis"] = axis
272+
node.forward_kwargs[axis_column] = axis
271273
node.input_num = rng_shape.randint(
272274
input_num_range.operands_min, min(mid_size, input_num_range.operands_max)
273275
)

0 commit comments

Comments
 (0)