@@ -43,6 +43,7 @@
     torch.ops.aten.relu_: torch.ops.aten.relu,
     # squeeze_ is expected to change tensor's shape. So replace with new value
     torch.ops.aten.squeeze_: (torch.ops.aten.squeeze, True),
+    torch.ops.aten.sqrt_: torch.ops.aten.sqrt,
     torch.ops.aten.clamp_: torch.ops.aten.clamp,
     torch.ops.aten.clamp_min_: torch.ops.aten.clamp_min,
     torch.ops.aten.sigmoid_: torch.ops.aten.sigmoid,
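Note on the new `sqrt_` entry: this table rewrites in-place aten ops to their functional counterparts, with the result written back by the framework. A minimal eager-PyTorch sketch of the equivalence the entry relies on:

```python
import torch

# In-place sqrt_ mutates its input; the functional aten.sqrt returns a new
# tensor with the same values, which is what the table substitutes in.
x = torch.tensor([4.0, 9.0])
x.sqrt_()                                          # x becomes [2., 3.]
y = torch.ops.aten.sqrt(torch.tensor([4.0, 9.0]))  # out-of-place form
assert torch.equal(x, y)
```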
@@ -112,7 +113,11 @@ def _aten_add(x, y, *, alpha=1):
 
   assert x.dtype == y.dtype, (x.dtype, y.dtype)
   """
-  return x + y * alpha
+  res = x + y * alpha
+  if isinstance(x, float) or isinstance(y, float):
+    new_dtype = mappings.t2j_dtype(torch.get_default_dtype())
+    res = res.astype(new_dtype)
+  return res
 
 
 @op(torch.ops.aten.copy_, is_jax_function=False)
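The added cast mirrors PyTorch's type-promotion rule for Python scalars: a float operand promotes the result to torch's default dtype (normally float32), not to float64. A quick eager-mode illustration of the behavior being matched:

```python
import torch

# A Python float scalar promotes an integer tensor to torch's default
# dtype (float32 unless changed), which the JAX result is cast to above.
t = torch.tensor([1, 2, 3])    # int64
print((t + 1.5).dtype)         # torch.float32
```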
@@ -169,6 +174,16 @@ def _aten_cauchy_(x, median=0, sigma=1):
   return x.at[:].set(samples)
 
 
+@op(torch.ops.aten.atleast_2d)
+def _aten_atleast_2d(inputs):
+  return jnp.atleast_2d(inputs)
+
+
+@op(torch.ops.aten.atleast_1d)
+def _aten_atleast_1d(inputs):
+  return jnp.atleast_1d(inputs)
+
+
 # aten.complex
 @op(torch.ops.aten.complex)
 def _aten_complex(real, imag):
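`jnp.atleast_1d` / `jnp.atleast_2d` have the same shape semantics as their torch counterparts, so the lowerings can delegate directly. For reference:

```python
import jax.numpy as jnp

# Scalars are promoted to shape (1,) / (1, 1); 1-D arrays gain a leading
# axis under atleast_2d, matching torch.atleast_1d / torch.atleast_2d.
print(jnp.atleast_1d(jnp.asarray(5.0)).shape)   # (1,)
print(jnp.atleast_2d(jnp.asarray(5.0)).shape)   # (1, 1)
print(jnp.atleast_2d(jnp.arange(3)).shape)      # (1, 3)
```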
@@ -281,6 +296,10 @@ def _aten_mul(x, y):
   res = x * y
   if isinstance(x, float) or isinstance(y, float):
     res = res.astype(new_dtype)
+  else:
+    if (not isinstance(x, int)) and (not isinstance(y, int)):
+      if x.dtype == np.dtype(np.float64) or y.dtype == np.dtype(np.float64):
+        res = res.astype(new_dtype)
   return res
 
 
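The new `else` branch catches float64-typed array operands rather than Python floats. A minimal sketch of the promotion it checks for, assuming the runtime enables `jax_enable_x64` (without it, JAX demotes float64 inputs to float32 and never produces float64 here):

```python
import jax

# Assumption for illustration: x64 mode is enabled.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

a = jnp.ones(3, dtype=jnp.float32)
b = jnp.ones(3, dtype=jnp.float64)
# The product silently promotes to float64, which the lowering then
# casts back to torch's default dtype.
print((a * b).dtype)    # float64
```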
@@ -1284,6 +1303,9 @@ def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5):
 
   input_shape = input.shape
 
+  if 0 in input_shape:
+    return input, input, input
+
   # Reshape for group-wise normalization
   reshaped_input = jnp.reshape(input, (1, N * group, -1))
 
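The early return short-circuits degenerate zero-size inputs before the reshape-based normalization below. Eager PyTorch accepts such inputs and returns an empty result; a small sanity check of the behavior being preserved (shapes chosen for illustration):

```python
import torch
import torch.nn.functional as F

# Zero-size batch: group_norm returns an empty tensor of the same shape,
# so the lowering returns the input unchanged instead of computing
# group statistics.
x = torch.empty(0, 4, 8)
print(F.group_norm(x, num_groups=2).shape)   # torch.Size([0, 4, 8])
```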