|
14 | 14 | from aesara.tensor.basic import get_scalar_constant_value
|
15 | 15 | from aesara.tensor.exceptions import NotScalarConstantError
|
16 | 16 | from aesara.tensor.math import minimum
|
17 |
| -from aesara.tensor.shape import shape_padleft |
| 17 | +from aesara.tensor.shape import shape_padleft, unbroadcast |
18 | 18 | from aesara.tensor.type import TensorType, integer_dtypes
|
19 | 19 | from aesara.updates import OrderedUpdates
|
20 | 20 |
|
@@ -751,7 +751,7 @@ def wrap_into_list(x):
|
751 | 751 | # defined in scan utils
|
752 | 752 | sit_sot_scan_inputs.append(
|
753 | 753 | expand_empty(
|
754 |
| - at.unbroadcast(shape_padleft(actual_arg), 0), |
| 754 | + unbroadcast(shape_padleft(actual_arg), 0), |
755 | 755 | actual_n_steps,
|
756 | 756 | )
|
757 | 757 | )
|
@@ -881,7 +881,7 @@ def wrap_into_list(x):
|
881 | 881 | # this will represent only a slice and it will have one
|
882 | 882 | # dimension less.
|
883 | 883 | if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
|
884 |
| - outputs[pos] = at.unbroadcast(shape_padleft(inner_out), 0) |
| 884 | + outputs[pos] = unbroadcast(shape_padleft(inner_out), 0) |
885 | 885 |
|
886 | 886 | if not return_list and len(outputs) == 1:
|
887 | 887 | outputs = outputs[0]
|
@@ -1010,7 +1010,7 @@ def wrap_into_list(x):
|
1010 | 1010 | sit_sot_inner_inputs.append(new_var)
|
1011 | 1011 | sit_sot_scan_inputs.append(
|
1012 | 1012 | expand_empty(
|
1013 |
| - at.unbroadcast(shape_padleft(input.variable), 0), |
| 1013 | + unbroadcast(shape_padleft(input.variable), 0), |
1014 | 1014 | actual_n_steps,
|
1015 | 1015 | )
|
1016 | 1016 | )
|
|
0 commit comments