Commit 7f8af9b

Ricardo authored and brandonwillard committed
Deprecate remaining uses of Rebroadcast in favor of Unbroadcast
1 parent ac52d68 commit 7f8af9b

18 files changed: +337 -538 lines

aesara/__init__.py (+1 -1)
@@ -147,7 +147,7 @@ def _as_symbolic(x, **kwargs) -> Variable:
 def get_scalar_constant_value(v):
     """Return the constant scalar (i.e. 0-D) value underlying variable `v`.
 
-    If `v` is the output of dim-shuffles, fills, allocs, rebroadcasts, cast
+    If `v` is the output of dim-shuffles, fills, allocs, cast, etc.
     this function digs through them.
 
     If ``aesara.sparse`` is also there, we will look over CSM `Op`.
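
For context on the updated docstring: the helper looks through value-preserving wrapper `Op`s to recover the underlying constant. A small, hedged illustration (it assumes the usual `aesara.tensor` constructors and is not part of this diff):

    import aesara
    import aesara.tensor as at

    # A 0-d constant wrapped in a no-op DimShuffle; the helper digs
    # through the wrapper to recover the underlying scalar value.
    v = at.constant(5).dimshuffle()
    assert aesara.get_scalar_constant_value(v) == 5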

aesara/compile/function/pfunc.py (+2 -2)
@@ -204,8 +204,8 @@ def clone_inputs(i):
            err_sug = (
                "If the difference is related to the broadcast pattern,"
                " you can call the"
-               " tensor.unbroadcast(var, axis_to_unbroadcast[, ...])"
-               " function to remove broadcastable dimensions."
+               " tensor.shape.unbroadcast(var, axis_to_unbroadcast[, ...])"
+               " function to mask broadcastable dimensions."
            )
 
            raise TypeError(err_msg, err_sug)
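
The helper the error message now points at lives in `aesara.tensor.shape`. A minimal sketch of what "masking" a broadcastable dimension means (the variable names here are illustrative):

    import aesara.tensor as at
    from aesara.tensor.shape import unbroadcast

    x = at.row("x")        # broadcastable pattern (True, False)
    y = unbroadcast(x, 0)  # axis 0 is no longer marked broadcastable

    assert x.type.broadcastable == (True, False)
    assert y.type.broadcastable == (False, False)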

aesara/ifelse.py (+2 -3)
@@ -23,8 +23,7 @@
 from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
 from aesara.graph.op import _NoPythonOp
 from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
-from aesara.tensor import basic
-from aesara.tensor.shape import Reshape, Shape, SpecifyShape
+from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
 
 
 __docformat__ = "restructedtext en"
@@ -451,7 +450,7 @@ def cond_make_inplace(fgraph, node):
     Shape,
     SpecifyShape,
     Reshape,
-    basic.Rebroadcast,
+    Unbroadcast,
     at.math.Dot,
     at.math.MaxAndArgmax,
     at.subtensor.Subtensor,
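
`Unbroadcast` simply takes `Rebroadcast`'s slot in the list of `Op`s the if-else optimizations are allowed to work through. For reference, a minimal `ifelse` graph those optimizations would apply to (a sketch, not taken from the diff):

    import aesara
    import aesara.tensor as at
    from aesara.ifelse import ifelse

    a, b = at.scalars("a", "b")
    out = ifelse(at.lt(a, b), a + b, a - b)  # lazily evaluates one branch

    f = aesara.function([a, b], out)
    assert f(1.0, 2.0) == 3.0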

aesara/link/jax/dispatch.py (+5 -14)
@@ -29,7 +29,6 @@
     Eye,
     Join,
     MakeVector,
-    Rebroadcast,
     ScalarFromTensor,
     TensorFromScalar,
 )
@@ -50,7 +49,7 @@
 from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
 from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
 from aesara.tensor.random.op import RandomVariable
-from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
+from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
 from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular
 from aesara.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -347,20 +346,12 @@ def specifyshape(x, *shape):
     return specifyshape
 
 
-@jax_funcify.register(Rebroadcast)
-def jax_funcify_Rebroadcast(op, **kwargs):
-    op_axis = op.axis
-
-    def rebroadcast(x):
-        for axis, value in op_axis.items():
-            if value and x.shape[axis] != 1:
-                raise ValueError(
-                    "Dimension %s in Rebroadcast's input was"
-                    " supposed to be 1 (got %s instead)" % (axis, x.shape[axis])
-                )
+@jax_funcify.register(Unbroadcast)
+def jax_funcify_Unbroadcast(op, **kwargs):
+    def unbroadcast(x):
         return x
 
-    return rebroadcast
+    return unbroadcast
 
 
 @jax_funcify.register(ViewOp)
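
Note that the JAX thunk for `Unbroadcast` is a pure identity: unlike `Rebroadcast`, which verified at run time that the target dimensions had length 1, `Unbroadcast` only relaxes Aesara's static broadcastable pattern, so there is nothing left to check. A hedged sketch of compiling such a graph to JAX (assumes the `jax` package and Aesara's `"JAX"` mode are available):

    import aesara
    import aesara.tensor as at
    from aesara.tensor.shape import unbroadcast

    x = at.row("x")
    # The Unbroadcast node lowers to the identity, so it adds no
    # run-time work to the jitted JAX function.
    f = aesara.function([x], unbroadcast(x, 0) + 1, mode="JAX")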

aesara/link/numba/dispatch/tensor_basic.py (+5 -14)
@@ -14,10 +14,10 @@
     Eye,
     Join,
     MakeVector,
-    Rebroadcast,
     ScalarFromTensor,
     TensorFromScalar,
 )
+from aesara.tensor.shape import Unbroadcast
 
 
 @numba_funcify.register(AllocEmpty)
@@ -195,22 +195,13 @@ def makevector({", ".join(input_names)}):
     return numba_basic.numba_njit(makevector_fn)
 
 
-@numba_funcify.register(Rebroadcast)
-def numba_funcify_Rebroadcast(op, **kwargs):
-    # Make sure op_axis only has ints. This way we can avoid literal_unroll
-    # which causes a segfault, see GH issue https://github.com/numba/numba/issues/8215
-    op_axis = tuple((axis, int(value)) for axis, value in op.axis.items())
-
+@numba_funcify.register(Unbroadcast)
+def numba_funcify_Unbroadcast(op, **kwargs):
     @numba_basic.numba_njit
-    def rebroadcast(x):
-        for axis, value in op_axis:
-            if value and x.shape[axis] != 1:
-                raise ValueError(
-                    ("Dimension in Rebroadcast's input was supposed to be 1")
-                )
+    def unbroadcast(x):
         return x
 
-    return rebroadcast
+    return unbroadcast
 
 
 @numba_funcify.register(TensorFromScalar)
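
The Numba version likewise becomes a jitted identity; the old `op_axis` normalization, which worked around a `literal_unroll` segfault, is no longer needed because there is no loop left to unroll. A standalone sketch of the pattern, using plain `numba.njit` in place of the project's `numba_basic.numba_njit` wrapper:

    import numba

    @numba.njit
    def unbroadcast(x):
        # Unbroadcast only changes Aesara's static type information,
        # so the run-time value passes through untouched.
        return x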

aesara/scan/basic.py (+4 -4)
@@ -14,7 +14,7 @@
 from aesara.tensor.basic import get_scalar_constant_value
 from aesara.tensor.exceptions import NotScalarConstantError
 from aesara.tensor.math import minimum
-from aesara.tensor.shape import shape_padleft
+from aesara.tensor.shape import shape_padleft, unbroadcast
 from aesara.tensor.type import TensorType, integer_dtypes
 from aesara.updates import OrderedUpdates
 
@@ -751,7 +751,7 @@ def wrap_into_list(x):
             # defined in scan utils
             sit_sot_scan_inputs.append(
                 expand_empty(
-                    at.unbroadcast(shape_padleft(actual_arg), 0),
+                    unbroadcast(shape_padleft(actual_arg), 0),
                     actual_n_steps,
                 )
             )
@@ -881,7 +881,7 @@ def wrap_into_list(x):
         # this will represent only a slice and it will have one
         # dimension less.
         if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
-            outputs[pos] = at.unbroadcast(shape_padleft(inner_out), 0)
+            outputs[pos] = unbroadcast(shape_padleft(inner_out), 0)
 
     if not return_list and len(outputs) == 1:
         outputs = outputs[0]
@@ -1010,7 +1010,7 @@ def wrap_into_list(x):
         sit_sot_inner_inputs.append(new_var)
         sit_sot_scan_inputs.append(
             expand_empty(
-                at.unbroadcast(shape_padleft(input.variable), 0),
+                unbroadcast(shape_padleft(input.variable), 0),
                 actual_n_steps,
             )
         )
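
The `unbroadcast(shape_padleft(...), 0)` idiom used throughout these hunks prepends a length-1 leading (time) axis and then masks its broadcastable flag, since the buffer is subsequently grown to `n_steps` entries by `expand_empty`. The idiom in isolation (a sketch; the variable `v` is illustrative):

    import aesara.tensor as at
    from aesara.tensor.shape import shape_padleft, unbroadcast

    v = at.vector("v")              # shape (n,)
    padded = shape_padleft(v)       # shape (1, n); axis 0 marked broadcastable
    state = unbroadcast(padded, 0)  # axis 0 treated as an ordinary dimension

    assert state.type.broadcastable == (False, False)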
