diff --git a/src/arraymancer/tensor/private/p_shapeshifting.nim b/src/arraymancer/tensor/private/p_shapeshifting.nim
index a9dc6ccdd..40aebb3af 100644
--- a/src/arraymancer/tensor/private/p_shapeshifting.nim
+++ b/src/arraymancer/tensor/private/p_shapeshifting.nim
@@ -54,6 +54,16 @@ proc reshapeImpl*(t: AnyTensor, new_shape: varargs[int]|MetadataArray, result: v
   else: reshape_with_copy(t, new_shape, result)
 
+proc reshapeImplWithContig*(t: AnyTensor, new_shape: varargs[int]|MetadataArray, result: var AnyTensor, layout: OrderType) {.noSideEffect.}=
+  when compileOption("boundChecks"):
+    when new_shape is MetadataArray:
+      check_reshape(t, new_shape)
+    else:
+      check_reshape(t, new_shape.toMetadataArray)
+
+  # Force the requested layout before reshaping (copies unless already contiguous in that layout)
+  reshapeImpl(t.asContiguous(layout, force=true), new_shape, result)
+
 proc broadcastImpl*(t: var AnyTensor, shape: varargs[int]|MetadataArray) {.noSideEffect.}=
   when compileOption("boundChecks"):
     assert t.rank == shape.len
 
diff --git a/src/arraymancer/tensor/shapeshifting.nim b/src/arraymancer/tensor/shapeshifting.nim
index 6a933e361..e7b94a313 100644
--- a/src/arraymancer/tensor/shapeshifting.nim
+++ b/src/arraymancer/tensor/shapeshifting.nim
@@ -36,6 +36,39 @@ proc transpose*(t: Tensor): Tensor {.noInit,noSideEffect,inline.} =
   result.offset = t.offset
   result.storage = t.storage
 
+proc transpose*(t: Tensor, axes: seq[int]): Tensor {.noInit,inline.} =
+  ## Transpose a Tensor using the given permutation of its axes.
+  ##
+  ## Data is not copied or modified, only metadata is modified.
+  assert axes.len == t.rank, "The axes permutation must have one entry per tensor axis"
+
+  let
+    n = axes.len
+
+  var
+    perm = newSeqWith(t.rank, 0)
+    mrep = newSeqWith(t.rank, -1)
+    new_shape = t.shape
+    new_strides = t.strides
+
+  # Normalize negative axes and reject duplicate entries
+  for i in 0 ..< n:
+    var axis = axes[i]
+    if axis < 0:
+      axis += t.rank
+    assert axis >= 0 and axis < t.rank, "Out of bounds axis for the Tensor"
+    assert mrep[axis] == -1, "Axes cannot be repeated"
+    mrep[axis] = i
+    perm[i] = axis
+
+  for i in 0 ..< n:
+    new_shape[i] = t.shape[perm[i]]
+    new_strides[i] = t.strides[perm[i]]
+
+  result.shape = new_shape
+  result.strides = new_strides
+  result.offset = t.offset
+  result.storage = t.storage
+
 proc asContiguous*[T](t: Tensor[T], layout: OrderType = rowMajor, force: bool = false): Tensor[T] {.noInit.} =
   ## Transform a tensor with general striding to a Tensor with contiguous layout.
   ##
@@ -69,6 +102,19 @@ proc reshape*(t: Tensor, new_shape: varargs[int]): Tensor {.noInit.} =
   ##   - a tensor with the same data but reshaped.
   reshapeImpl(t, new_shape, result)
 
+proc reshape*(t: Tensor, new_shape: varargs[int], layout: OrderType): Tensor {.noInit.} =
+  ## Reshape a tensor. If possible no data copy is done and the returned tensor
+  ## shares data with the input. If input is not contiguous, this is not possible
+  ## and a copy will be made.
+  ##
+  ## Input:
+  ##   - a tensor
+  ##   - a new shape. Number of elements must be the same
+  ##   - a memory layout to use when reshaping the data
+  ## Returns:
+  ##   - a tensor with the same data but reshaped.
+  reshapeImplWithContig(t, new_shape, result, layout)
+
 proc reshape*(t: Tensor, new_shape: MetadataArray): Tensor {.noInit.} =
   ## Reshape a tensor. If possible no data copy is done and the returned tensor
   ## shares data with the input. If input is not contiguous, this is not possible
@@ -81,6 +127,19 @@ proc reshape*(t: Tensor, new_shape: MetadataArray): Tensor {.noInit.} =
   ##   - a tensor with the same data but reshaped.
   reshapeImpl(t, new_shape, result)
 
+proc reshape*(t: Tensor, new_shape: MetadataArray, layout: OrderType): Tensor {.noInit.} =
+  ## Reshape a tensor. If possible no data copy is done and the returned tensor
+  ## shares data with the input. If input is not contiguous, this is not possible
+  ## and a copy will be made.
+  ##
+  ## Input:
+  ##   - a tensor
+  ##   - a new shape. Number of elements must be the same
+  ##   - a memory layout to use when reshaping the data
+  ## Returns:
+  ##   - a tensor with the same data but reshaped.
+  reshapeImplWithContig(t, new_shape, result, layout)
+
 proc broadcast*[T](t: Tensor[T], shape: varargs[int]): Tensor[T] {.noInit,noSideEffect.}=
   ## Explicitly broadcast a tensor to the specified shape.
   ##
diff --git a/tests/tensor/test_shapeshifting.nim b/tests/tensor/test_shapeshifting.nim
index 3c932d4a0..2ec72c891 100644
--- a/tests/tensor/test_shapeshifting.nim
+++ b/tests/tensor/test_shapeshifting.nim
@@ -64,6 +64,33 @@ testSuite "Shapeshifting":
       check: a == [[1,2], [3,4]].toTensor()
 
+  test "Reshape with explicit order":
+    let a = toSeq(1..12).toTensor().reshape(3, 2, 2).asContiguous(rowMajor, force = true)
+    let b = toSeq(1..12).toTensor().reshape(3, 2, 2).asContiguous(colMajor, force = true)
+    check: a == b
+    # By default, reshape follows each tensor's own memory layout
+    check: a.reshape(6, 2) != b.reshape(6, 2)
+
+    # An explicit order makes both tensors reshape through the same memory layout
+    check: a.reshape(6, 2, colMajor) == b.reshape(6, 2)
+    check: a.reshape(6, 2) == b.reshape(6, 2, rowMajor)
+
+  test "Transpose with explicit permutation":
+    let a = toSeq(1..6).toTensor().reshape(1, 2, 3)
+    # Check permutations other than a full transpose
+    let b = a.transpose(@[0, 2, 1])
+    let c = a.transpose(@[2, 0, 1])
+
+    let expected_b = @[1, 4, 2, 5, 3, 6].toTensor().reshape(1, 3, 2)
+    check: b == expected_b
+    check: b.shape == [1, 3, 2]
+    check: b.strides == [6, 1, 3]
+
+    let expected_c = @[1, 4, 2, 5, 3, 6].toTensor().reshape(3, 1, 2)
+    check: c == expected_c
+    check: c.shape == [3, 1, 2]
+    check: c.strides == [1, 6, 3]
+
   test "Unsafe reshape":
     block:
       let a = toSeq(1..4).toTensor()
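
Usage sketch for the two new overloads (illustrative only, not part of the patch; assumes the patch is applied and uses the same helpers as the tests above, `toSeq` from sequtils and `toTensor`/`OrderType` from Arraymancer):

    import sequtils, arraymancer

    let t = toSeq(1..12).toTensor().reshape(3, 4)

    # Reshape through an explicit memory layout: the input is first made
    # contiguous in the requested order, so the elements are re-read in
    # that order before being laid out in the new shape.
    let r1 = t.reshape(2, 6, rowMajor)  # t is row-major, same as t.reshape(2, 6)
    let r2 = t.reshape(2, 6, colMajor)  # re-reads the data column by column

    # Permute axes without copying: only shape/stride metadata changes.
    let p = t.transpose(@[1, 0])
    assert p.shape == [4, 3]
    assert p == t.transpose  # for rank 2 this matches the plain transpose

The layout parameter is similar in spirit to numpy's `order` argument to `reshape`, and negative entries in the `transpose` permutation count from the last axis, as in `t.transpose(@[-1, 0])`.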