Skip to content

Commit

Permalink
Merge pull request #90 from osyris-project/revert_vector_slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
nvaytet authored Mar 9, 2022
2 parents 9c22f90 + 0640fae commit 2a0b761
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 11 deletions.
4 changes: 2 additions & 2 deletions docs/loading_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@
"metadata": {},
"outputs": [],
"source": [
"center = data['sink']['position'][0]\n",
"center = data['sink']['position'][0:1]\n",
"dx = 200 * osyris.units('au')"
]
},
Expand Down Expand Up @@ -353,7 +353,7 @@
"osyris.map(data[\"hydro\"][\"density\"],\n",
" {\"data\": data[\"sink\"][\"position\"], \"mode\": \"scatter\", \"c\": \"white\",\n",
" \"s\": 20. * osyris.units(\"au\"), \"alpha\": 0.7},\n",
" norm='log', direction=\"z\", origin=center)"
" norm='log', direction=\"z\", origin=center[0])"
]
}
],
Expand Down
11 changes: 4 additions & 7 deletions src/osyris/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,10 @@ def __init__(self, values=0, unit=None, parent=None, name=""):
# return self._array

def __getitem__(self, slice_):
out = self.__class__(values=self._array[slice_],
unit=self._unit,
parent=self._parent,
name=self._name)
if isinstance(slice_, int) and self.ndim > 1:
return out.reshape(1, len(out))
return out
return self.__class__(values=self._array[slice_],
unit=self._unit,
parent=self._parent,
name=self._name)

def __len__(self):
if self._array.shape:
Expand Down
4 changes: 2 additions & 2 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def test_slicing():

def test_slicing_vector():
a = osyris.Array(values=np.arange(12.).reshape(4, 3), unit='m')
assert all(np.ravel(a[2] == osyris.Array(values=[[6., 7., 8.]], unit='m')))
assert a[2].shape == (1, 3)
assert all(np.ravel(a[2:3] == osyris.Array(values=[[6., 7., 8.]], unit='m')))
assert a[2:3].shape == (1, 3)
assert all(
np.ravel(a[:2] == osyris.Array(values=[[0., 1., 2.], [3., 4., 5.]], unit='m')))

0 comments on commit 2a0b761

Please sign in to comment.