@@ -131,55 +131,56 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
131131
132132Those operations are available for all tensor kinds: ` Int ` , ` Float ` , and ` Bool ` .
133133
134- | Burn | PyTorch Equivalent |
135- | ---------------------------------------------| ---------------------------------------------------------------------------|
136- | ` Tensor::cat(tensors, dim) ` | ` torch.cat(tensors, dim) ` |
137- | ` Tensor::empty(shape, device) ` | ` torch.empty(shape, device=device) ` |
138- | ` Tensor::from_primitive(primitive) ` | N/A |
139- | ` Tensor::stack(tensors, dim) ` | ` torch.stack(tensors, dim) ` |
140- | ` tensor.all() ` | ` tensor.all() ` |
141- | ` tensor.all_dim(dim) ` | ` tensor.all(dim) ` |
142- | ` tensor.any() ` | ` tensor.any() ` |
143- | ` tensor.any_dim(dim) ` | ` tensor.any(dim) ` |
144- | ` tensor.chunk(num_chunks, dim) ` | ` tensor.chunk(num_chunks, dim) ` |
145- | ` tensor.split(split_size, dim) ` | ` tensor.split(split_size, dim) ` |
146- | ` tensor.split_with_sizes(split_sizes, dim) ` | ` tensor.split([split_sizes], dim) ` |
147- | ` tensor.device() ` | ` tensor.device ` |
148- | ` tensor.dtype() ` | ` tensor.dtype ` |
149- | ` tensor.dims() ` | ` tensor.size() ` |
150- | ` tensor.equal(other) ` | ` x == y ` |
151- | ` tensor.expand(shape) ` | ` tensor.expand(shape) ` |
152- | ` tensor.flatten(start_dim, end_dim) ` | ` tensor.flatten(start_dim, end_dim) ` |
153- | ` tensor.flip(axes) ` | ` tensor.flip(axes) ` |
154- | ` tensor.into_data() ` | N/A |
155- | ` tensor.into_primitive() ` | N/A |
156- | ` tensor.into_scalar() ` | ` tensor.item() ` |
157- | ` tensor.narrow(dim, start, length) ` | ` tensor.narrow(dim, start, length) ` |
158- | ` tensor.not_equal(other) ` | ` x != y ` |
159- | ` tensor.permute(axes) ` | ` tensor.permute(axes) ` |
160- | ` tensor.movedim(src, dst) ` | ` tensor.movedim(src, dst) ` |
161- | ` tensor.repeat_dim(dim, times) ` | ` tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())]) ` |
162- | ` tensor.repeat(sizes) ` | ` tensor.repeat(sizes) ` |
163- | ` tensor.reshape(shape) ` | ` tensor.view(shape) ` |
164- | ` tensor.roll(shfts, dims) ` | ` tensor.roll(shifts, dims) ` |
165- | ` tensor.roll_dim(shift, dim) ` | ` tensor.roll([shift], [dim]) ` |
166- | ` tensor.select(dim, indices) ` | ` tensor.index_select(dim, indices) ` |
167- | ` tensor.select_assign(dim, indices, values) ` | N/A |
168- | ` tensor.shape() ` | ` tensor.shape ` |
169- | ` tensor.slice(s![range;step]) ` | ` tensor[(*ranges,)] ` or ` tensor[start:end:step] ` |
170- | ` tensor.slice_assign(ranges, values) ` | ` tensor[(*ranges,)] = values ` |
171- | ` tensor.slice_fill(ranges, value) ` | ` tensor[(*ranges,)] = value ` |
172- | ` tensor.slice_dim(dim, range) ` | N/A |
173- | ` tensor.squeeze(dim) ` | ` tensor.squeeze(dim) ` |
174- | ` tensor.swap_dims(dim1, dim2) ` | ` tensor.transpose(dim1, dim2) ` |
175- | ` tensor.take(dim, indices) ` | ` numpy.take(tensor, indices, dim) ` |
176- | ` tensor.to_data() ` | N/A |
177- | ` tensor.to_device(device) ` | ` tensor.to(device) ` |
178- | ` tensor.transpose() ` | ` tensor.T ` |
179- | ` tensor.t() ` | ` tensor.T ` |
180- | ` tensor.unsqueeze() ` | ` tensor.unsqueeze(0) ` |
181- | ` tensor.unsqueeze_dim(dim) ` | ` tensor.unsqueeze(dim) ` |
182- | ` tensor.unsqueeze_dims(dims) ` | N/A |
134+ | Burn | PyTorch Equivalent |
135+ | ----------------------------------------------| ---------------------------------------------------------------------------|
136+ | ` Tensor::cat(tensors, dim) ` | ` torch.cat(tensors, dim) ` |
137+ | ` Tensor::empty(shape, device) ` | ` torch.empty(shape, device=device) ` |
138+ | ` Tensor::from_primitive(primitive) ` | N/A |
139+ | ` Tensor::stack(tensors, dim) ` | ` torch.stack(tensors, dim) ` |
140+ | ` tensor.all() ` | ` tensor.all() ` |
141+ | ` tensor.all_dim(dim) ` | ` tensor.all(dim) ` |
142+ | ` tensor.any() ` | ` tensor.any() ` |
143+ | ` tensor.any_dim(dim) ` | ` tensor.any(dim) ` |
144+ | ` tensor.chunk(num_chunks, dim) ` | ` tensor.chunk(num_chunks, dim) ` |
145+ | ` tensor.split(split_size, dim) ` | ` tensor.split(split_size, dim) ` |
146+ | ` tensor.split_with_sizes(split_sizes, dim) ` | ` tensor.split([split_sizes], dim) ` |
147+ | ` tensor.device() ` | ` tensor.device ` |
148+ | ` tensor.dtype() ` | ` tensor.dtype ` |
149+ | ` tensor.dims() ` | ` tensor.size() ` |
150+ | ` tensor.equal(other) ` | ` x == y ` |
151+ | ` tensor.expand(shape) ` | ` tensor.expand(shape) ` |
152+ | ` tensor.flatten(start_dim, end_dim) ` | ` tensor.flatten(start_dim, end_dim) ` |
153+ | ` tensor.flip(axes) ` | ` tensor.flip(axes) ` |
154+ | ` tensor.into_data() ` | N/A |
155+ | ` tensor.into_primitive() ` | N/A |
156+ | ` tensor.into_scalar() ` | ` tensor.item() ` |
157+ | ` tensor.narrow(dim, start, length) ` | ` tensor.narrow(dim, start, length) ` |
158+ | ` tensor.not_equal(other) ` | ` x != y ` |
159+ | ` tensor.permute(axes) ` | ` tensor.permute(axes) ` |
160+ | ` tensor.movedim(src, dst) ` | ` tensor.movedim(src, dst) ` |
161+ | ` tensor.repeat_dim(dim, times) ` | ` tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())]) ` |
162+ | ` tensor.repeat(sizes) ` | ` tensor.repeat(sizes) ` |
163+ | ` tensor.reshape(shape) ` | ` tensor.view(shape) ` |
164+ | ` tensor.roll(shfts, dims) ` | ` tensor.roll(shifts, dims) ` |
165+ | ` tensor.roll_dim(shift, dim) ` | ` tensor.roll([shift], [dim]) ` |
166+ | ` tensor.select(dim, indices) ` | ` tensor.index_select(dim, indices) ` |
167+ | ` tensor.select_assign(dim, indices, values) ` | N/A |
168+ | ` tensor.shape() ` | ` tensor.shape ` |
169+ | ` tensor.slice(s![range;step]) ` | ` tensor[(*ranges,)] ` or ` tensor[start:end:step] ` |
170+ | ` tensor.slice_assign(ranges, values) ` | ` tensor[(*ranges,)] = values ` |
171+ | ` tensor.slice_fill(ranges, value) ` | ` tensor[(*ranges,)] = value ` |
172+ | ` tensor.slice_dim(dim, range) ` | N/A |
173+ | ` tensor.squeeze(dim) ` | ` tensor.squeeze(dim) ` |
174+ | ` tensor.swap_dims(dim1, dim2) ` | ` tensor.transpose(dim1, dim2) ` |
175+ | ` tensor.take(dim, indices) ` | ` numpy.take(tensor, indices, dim) ` |
176+ | ` tensor.to_data() ` | N/A |
177+ | ` tensor.to_device(device) ` | ` tensor.to(device) ` |
178+ | ` tensor.transpose() ` | ` tensor.T ` |
179+ | ` tensor.t() ` | ` tensor.T ` |
180+ | ` tensor.unfold(dim, size, step) ` | ` tensor.unfold(dim, size, step) ` |
181+ | ` tensor.unsqueeze() ` | ` tensor.unsqueeze(0) ` |
182+ | ` tensor.unsqueeze_dim(dim) ` | ` tensor.unsqueeze(dim) ` |
183+ | ` tensor.unsqueeze_dims(dims) ` | N/A |
183184
184185### Numeric Operations
185186
0 commit comments