Skip to content

Commit 694f496

Browse files
committed
Require same dimension in global array as local data
1 parent 560970b commit 694f496

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

heat/core/dndarray.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,12 @@ def __init__(
8888

8989
# check for inconsistencies between local and global arrays
9090
assert str(array.device) == device.torch_device
91-
assert self.ndim >= array.ndim, (
91+
assert self.ndim == array.ndim, (
9292
f"Local dimension {array.ndim} exceeds global dimension {self.ndim}!"
9393
)
94-
if self.ndim == array.ndim:
95-
assert all([gshape[i] >= array.shape[i] for i in range(self.ndim)]), (
96-
f"Local shape {array.shape} is larger than global shape {gshape}"
97-
)
94+
assert all([gshape[i] >= array.shape[i] for i in range(self.ndim)]), (
95+
f"Local shape {array.shape} is larger than global shape {gshape}"
96+
)
9897
assert dtype == types.canonical_heat_type(array.dtype), (
9998
f"Local datatype {array.dtype} is incompatible with global datatype {dtype}"
10099
)

0 commit comments

Comments
 (0)