Skip to content

Commit 0a02b23

Browse files
authored
Implements head/tail based on iloc, and fixes bug in getitem. (#1057)
1 parent 1edf05d commit 0a02b23

File tree

4 files changed

+64
-12
lines changed

4 files changed

+64
-12
lines changed

mars/dataframe/indexing/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,22 @@
1414

1515

1616
def _install():
17-
from .iloc import iloc
17+
from .iloc import iloc, head, tail
1818
from .loc import loc
1919
from .set_index import set_index
2020
from .getitem import dataframe_getitem, series_getitem
2121
from ..operands import DATAFRAME_TYPE, SERIES_TYPE
2222

23-
for cls in DATAFRAME_TYPE:
23+
for cls in DATAFRAME_TYPE + SERIES_TYPE:
2424
setattr(cls, 'iloc', property(iloc))
2525
setattr(cls, 'loc', property(loc))
26+
setattr(cls, 'head', head)
27+
setattr(cls, 'tail', tail)
28+
29+
for cls in DATAFRAME_TYPE:
2630
setattr(cls, 'set_index', set_index)
2731
setattr(cls, '__getitem__', dataframe_getitem)
2832
for cls in SERIES_TYPE:
29-
setattr(cls, 'iloc', property(iloc))
30-
setattr(cls, 'loc', property(loc))
3133
setattr(cls, '__getitem__', series_getitem)
3234

3335

mars/dataframe/indexing/getitem.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -276,11 +276,11 @@ def tile_with_mask(cls, op):
276276
for idx, df_chunk in zip(out_chunk_indexes, df_chunks):
277277
mask_chunk = mask_chunks[df_chunk.index[0]]
278278
index_value = parse_index(out_df.index_value.to_pandas(), df_chunk)
279-
out_chunk = op.copy().reset_key().new_chunk([df_chunk, mask_chunk],
280-
shape=(np.nan, df_chunk.shape[1]), index=idx,
279+
out_chunk = op.copy().reset_key().new_chunk([df_chunk, mask_chunk], index=idx,
280+
shape=(np.nan, df_chunk.shape[1]),
281+
dtypes=df_chunk.dtypes,
281282
index_value=index_value,
282-
columns_value=df_chunk.columns_value,
283-
dtypes=df_chunk.dtypes)
283+
columns_value=df_chunk.columns_value)
284284
out_chunks.append(out_chunk)
285285

286286
else:
@@ -292,8 +292,10 @@ def tile_with_mask(cls, op):
292292
chunk_op = op.copy().reset_key()
293293
chunk_op._mask = op.mask.iloc[nsplits_acc[idx]:nsplits_acc[idx+1]]
294294
out_chunk = chunk_op.new_chunk([in_chunk], index=in_chunk.index,
295-
shape=(np.nan, in_chunk.shape[1]), dtypes=in_chunk.dtypes,
296-
index_value=in_df.index_value, columns_value=in_chunk.columns_value)
295+
shape=(np.nan, in_chunk.shape[1]),
296+
dtypes=in_chunk.dtypes,
297+
index_value=in_df.index_value,
298+
columns_value=in_chunk.columns_value)
297299
out_chunks.append(out_chunk)
298300

299301
nsplits = ((np.nan,) * in_df.chunk_shape[0], in_df.nsplits[1])
@@ -363,7 +365,7 @@ def execute(cls, ctx, op):
363365
mask = ctx[op.inputs[1].key]
364366
else:
365367
mask = op.mask
366-
ctx[op.outputs[0].key] = df[mask]
368+
ctx[op.outputs[0].key] = df[mask.reindex_like(ctx[op.inputs[0].key]).fillna(False)]
367369

368370

369371
_list_like_types = (list, np.ndarray, SERIES_TYPE, pd.Series, TENSOR_TYPE)
@@ -382,7 +384,9 @@ def dataframe_getitem(df, item):
382384
if col_name not in columns:
383385
raise KeyError('%s not in columns' % col_name)
384386
op = DataFrameIndex(col_names=item, object_type=ObjectType.dataframe)
385-
elif isinstance(item, _list_like_types) and astensor(item).dtype == np.bool:
387+
elif isinstance(item, _list_like_types):
388+
# NB: don't enforce the dtype of `item` to be `bool` since it may be unknown
389+
print('use mask: item = ', item)
386390
op = DataFrameIndex(mask=item, object_type=ObjectType.dataframe)
387391
else:
388392
if item not in columns:

mars/dataframe/indexing/iloc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,3 +432,11 @@ def execute(cls, ctx, op):
432432

433433
def iloc(a):
    """Wrap ``a`` in a ``DataFrameIloc`` object and return it.

    ``DataFrameIloc`` is defined elsewhere in this module; the returned
    object is what callers subscript to perform positional selection.
    """
    indexer = DataFrameIloc(a)
    return indexer
435+
436+
437+
def head(a, n=5):
    """Select the leading part of ``a`` through ``DataFrameIloc``.

    ``a`` is wrapped in ``DataFrameIloc`` and subscripted with the slice
    ``0:n``; ``n`` defaults to 5.  A negative ``n`` therefore yields a slice
    whose stop is counted from the end.
    """
    leading = slice(0, n)
    return DataFrameIloc(a)[leading]
439+
440+
441+
def tail(a, n=5):
    """Select the trailing part of ``a`` through ``DataFrameIloc``.

    ``a`` is wrapped in ``DataFrameIloc`` and subscripted with the slice
    ``-n:``; ``n`` defaults to 5.  A negative ``n`` turns the slice start
    into a positive offset, i.e. everything after the first ``|n|`` rows.

    ``n == 0`` is handled separately: ``-0`` is ``0``, so ``[-n:]`` would be
    ``[0:]`` and select every row, whereas pandas' ``tail(0)`` returns an
    empty selection.
    """
    if n == 0:
        # Empty slice, mirroring what ``head`` produces for n == 0 (``[0:0]``).
        return DataFrameIloc(a)[0:0]
    return DataFrameIloc(a)[-n:]

mars/dataframe/indexing/tests/test_indexing_execution.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,3 +456,41 @@ def testSeriesGetitem(self):
456456
series5 = series[selected]
457457
pd.testing.assert_series_equal(
458458
self.executor.execute_dataframe(series5, concat=True)[0], data[selected])
459+
460+
def testHead(self):
    """Check ``head`` on a chunked DataFrame against pandas' ``head``.

    Covers the default argument first, then positive and negative counts
    that are smaller than, close to, and larger than the 10-row input.
    """
    raw = pd.DataFrame(np.random.rand(10, 5), columns=['c1', 'c2', 'c3', 'c4', 'c5'])
    df = md.DataFrame(raw, chunk_size=2)

    # Default count.
    result = self.executor.execute_dataframe(df.head(), concat=True)[0]
    pd.testing.assert_frame_equal(result, raw.head())

    # Explicit counts, in the same order as the original assertions.
    for count in (3, -3, 8, -8, 13, -13):
        result = self.executor.execute_dataframe(df.head(count), concat=True)[0]
        pd.testing.assert_frame_equal(result, raw.head(count))
478+
479+
def testTail(self):
    """Check ``tail`` on a chunked DataFrame against pandas' ``tail``.

    Covers the default argument first, then positive and negative counts
    that are smaller than, close to, and larger than the 10-row input.
    """
    raw = pd.DataFrame(np.random.rand(10, 5), columns=['c1', 'c2', 'c3', 'c4', 'c5'])
    df = md.DataFrame(raw, chunk_size=2)

    # Default count.
    result = self.executor.execute_dataframe(df.tail(), concat=True)[0]
    pd.testing.assert_frame_equal(result, raw.tail())

    # Explicit counts, in the same order as the original assertions.
    for count in (3, -3, 8, -8, 13, -13):
        result = self.executor.execute_dataframe(df.tail(count), concat=True)[0]
        pd.testing.assert_frame_equal(result, raw.tail(count))

0 commit comments

Comments
 (0)