Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the support for max_unpool1d, max_unpool2d, and max_unpool3d #7524 #8084

Merged
merged 5 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,6 @@
"nn.functional.max_pool1d",
"nn.functional.max_pool2d",
"nn.functional.max_pool3d",
"nn.functional.max_unpool1d",
"nn.functional.max_unpool2d",
"nn.functional.max_unpool3d",
"nn.functional.multi_head_attention_forward",
"nn.functional.multi_margin_loss",
"nn.functional.multilabel_margin_loss",
Expand Down
57 changes: 57 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4026,3 +4026,60 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
else:
s = None
return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s)


@op(torch.ops.aten.max_unpool3d)
def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0):
if output_size is None:
raise ValueError("output_size value is not set correctly. It cannot be None or empty.")

output_size = [input.shape[0], input.shape[1]] + output_size

# Initialize an output array of zeros with the provided output_size
output = jnp.zeros(output_size, dtype=input.dtype)

# Use numpy.ndindex to iterate over all indices of the input tensor
for idx in np.ndindex(input.shape):
max_index = indices[idx]

# Get the spatial dimensions of the output
spatial_dims = output_size[2:] # (D, H, W)

# Unravel the flat index to multi-dimensional index
unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)

# Combine batch, channel, and spatial indices
full_idx = idx[:2] + unpooled_spatial_idx

output = output.at[full_idx].set(input[idx])

return output

@op(torch.ops.aten.max_unpool2d)
def _aten_max_unpool2d(input, indices, output_size, stride=None, padding=0):
if output_size is None:
raise ValueError("output_size value is not set correctly. It cannot be None or empty.")

output_size = [input.shape[0], input.shape[1]] + output_size

# Initialize the output array with zeros
output = jnp.zeros(output_size, dtype=input.dtype)

# Use numpy.ndindex to iterate over all indices of the input tensor
for idx in np.ndindex(input.shape):
max_index = indices[idx]

# Get the spatial dimensions of the output (H, W)
spatial_dims = output_size[2:]

# Unravel the flat index to multi-dimensional index for 2D
unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)

# Combine batch, channel, and spatial indices
full_idx = idx[:2] + unpooled_spatial_idx

# Set the value in the output array at the corresponding location
output = output.at[full_idx].set(input[idx])

return output

Loading