Skip to content

Commit 1edf05d

Browse files
wjsi, 继盛, and hekaisheng
authored
Fix consistency between tensor metadata and real outputs (#1085)
Co-authored-by: 继盛 <[email protected]>
Co-authored-by: hekaisheng <[email protected]>
1 parent e5c11e4 commit 1edf05d

File tree

21 files changed

+162
-79
lines changed

21 files changed

+162
-79
lines changed

mars/dataframe/align.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ def _gen_dataframe_chunks(splits, out_shape, left_or_right, df):
579579
chunk_kw = {
580580
'index_value': chunk.index_value if splits[0].isdummy() else None,
581581
'columns_value': chunk.columns_value if splits[1].isdummy() else None,
582+
'dtypes': chunk.dtypes if splits[1].isdummy() else None
582583
}
583584
align_op = DataFrameIndexAlign(
584585
stage=OperandStage.map, index_min_max=index_min_max,

mars/dataframe/arithmetic/core.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,12 @@ def _tile_scalar(cls, op):
195195
out_chunks.append(out_chunk)
196196

197197
new_op = op.copy()
198+
out = op.outputs[0]
198199
if isinstance(df, SERIES_TYPE):
199-
return new_op.new_seriess(op.inputs, df.shape, nsplits=tileable.nsplits, dtype=df.dtype,
200+
return new_op.new_seriess(op.inputs, df.shape, nsplits=tileable.nsplits, dtype=out.dtype,
200201
index_value=df.index_value, name=df.name, chunks=out_chunks)
201202
else:
202-
return new_op.new_dataframes(op.inputs, df.shape, nsplits=tileable.nsplits, dtypes=df.dtypes,
203+
return new_op.new_dataframes(op.inputs, df.shape, nsplits=tileable.nsplits, dtypes=out.dtypes,
203204
index_value=df.index_value, columns_value=df.columns_value,
204205
chunks=out_chunks)
205206

@@ -233,11 +234,12 @@ def _tile_with_tensor(cls, op):
233234
out_chunks.append(out_chunk)
234235

235236
new_op = op.copy()
237+
out = op.outputs[0]
236238
if isinstance(other, SERIES_TYPE):
237-
return new_op.new_seriess(op.inputs, other.shape, nsplits=other.nsplits, dtype=other.dtype,
238-
index_value=other.index_value, name=other.name, chunks=out_chunks)
239+
return new_op.new_seriess(op.inputs, other.shape, nsplits=other.nsplits, dtype=out.dtype,
240+
index_value=other.index_value, chunks=out_chunks)
239241
else:
240-
return new_op.new_dataframes(op.inputs, other.shape, nsplits=other.nsplits, dtypes=other.dtypes,
242+
return new_op.new_dataframes(op.inputs, other.shape, nsplits=other.nsplits, dtypes=out.dtypes,
241243
index_value=other.index_value, columns_value=other.columns_value,
242244
chunks=out_chunks)
243245

@@ -294,8 +296,17 @@ def _operator(self):
294296
def _calc_properties(cls, x1, x2=None, axis='columns'):
295297
if isinstance(x1, (DATAFRAME_TYPE, DATAFRAME_CHUNK_TYPE)) \
296298
and (x2 is None or np.isscalar(x2) or isinstance(x2, TENSOR_TYPE)):
297-
# FIXME infer the dtypes of result df properly
298-
return {'shape': x1.shape, 'dtypes': x1.dtypes,
299+
if x2 is None:
300+
dtypes = x1.dtypes
301+
elif np.isscalar(x2):
302+
dtypes = infer_dtypes(x1.dtypes, pd.Series(np.array(x2).dtype), cls._operator)
303+
elif x1.dtypes is not None and isinstance(x2, TENSOR_TYPE):
304+
dtypes = pd.Series(
305+
[infer_dtype(dt, x2.dtype, cls._operator) for dt in x1.dtypes],
306+
index=x1.dtypes.index)
307+
else:
308+
dtypes = x1.dtypes
309+
return {'shape': x1.shape, 'dtypes': dtypes,
299310
'columns_value': x1.columns_value, 'index_value': x1.index_value}
300311

301312
if isinstance(x1, (SERIES_TYPE, SERIES_CHUNK_TYPE)) \
@@ -310,7 +321,9 @@ def _calc_properties(cls, x1, x2=None, axis='columns'):
310321

311322
if x1.columns_value is not None and x2.columns_value is not None and \
312323
x1.columns_value.key == x2.columns_value.key:
313-
dtypes = x1.dtypes
324+
dtypes = pd.Series([infer_dtype(dt1, dt2, cls._operator) for dt1, dt2
325+
in zip(x1.dtypes, x2.dtypes)],
326+
index=x1.dtypes.index)
314327
columns = copy.copy(x1.columns_value)
315328
columns.value.should_be_monotonic = False
316329
column_shape = len(dtypes)
@@ -342,11 +355,12 @@ def _calc_properties(cls, x1, x2=None, axis='columns'):
342355
column_shape, dtypes, columns = np.nan, None, None
343356
if x1.columns_value is not None and x1.index_value is not None:
344357
if x1.columns_value.key == x2.index_value.key:
345-
dtypes = x1.dtypes
358+
dtypes = pd.Series([infer_dtype(dt, x2.dtype, cls._operator) for dt in x1.dtypes],
359+
index=x1.dtypes.index)
346360
columns = copy.copy(x1.columns_value)
347361
columns.value.should_be_monotonic = False
348362
column_shape = len(dtypes)
349-
else:
363+
else: # pragma: no cover
350364
dtypes = x1.dtypes # FIXME
351365
columns = infer_index_value(x1.columns_value, x2.index_value)
352366
columns.value.should_be_monotonic = True
@@ -359,10 +373,16 @@ def _calc_properties(cls, x1, x2=None, axis='columns'):
359373
index_shape, index = np.nan, None
360374
if x1.index_value is not None and x1.index_value is not None:
361375
if x1.index_value.key == x2.index_value.key:
362-
index = copy.copy(x1.columns_value)
376+
dtypes = pd.Series([infer_dtype(dt, x2.dtype, cls._operator) for dt in x1.dtypes],
377+
index=x1.dtypes.index)
378+
index = copy.copy(x1.index_value)
363379
index.value.should_be_monotonic = False
364380
index_shape = x1.shape[0]
365381
else:
382+
if x1.dtypes is not None:
383+
dtypes = pd.Series(
384+
[infer_dtype(dt, x2.dtype, cls._operator) for dt in x1.dtypes],
385+
index=x1.dtypes.index)
366386
index = infer_index_value(x1.index_value, x2.index_value)
367387
index.value.should_be_monotonic = True
368388
index_shape = np.nan

mars/dataframe/groupby/apply.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def tile(cls, op):
107107
columns_value=out_df.columns_value, index_value=out_df.index_value))
108108
else:
109109
chunks.append(new_op.new_chunk(
110-
inp_chunks, index=c.index, shape=(np.nan,), dtype=out_df.dtype,
110+
inp_chunks, name=out_df.name, index=c.index, shape=(np.nan,), dtype=out_df.dtype,
111111
index_value=out_df.index_value))
112112

113113
new_op = op.copy().reset_key()
@@ -129,7 +129,7 @@ def _infer_df_func_returns(self, in_object_type, in_dtypes, dtypes, index):
129129
if in_object_type == ObjectType.dataframe:
130130
empty_df = build_empty_df(in_dtypes, index=pd.RangeIndex(2))
131131
else:
132-
empty_df = build_empty_series(in_dtypes, index=pd.RangeIndex(2))
132+
empty_df = build_empty_series(in_dtypes[1], index=pd.RangeIndex(2), name=in_dtypes[0])
133133

134134
with np.errstate(all='ignore'):
135135
if self.is_transform:
@@ -148,10 +148,10 @@ def _infer_df_func_returns(self, in_object_type, in_dtypes, dtypes, index):
148148
new_dtypes = new_dtypes or infer_df.dtypes
149149
elif isinstance(infer_df, pd.Series):
150150
object_type = object_type or ObjectType.series
151-
new_dtypes = new_dtypes or infer_df.dtype
151+
new_dtypes = new_dtypes or (infer_df.name, infer_df.dtype)
152152
else:
153153
object_type = ObjectType.series
154-
new_dtypes = pd.Series(infer_df).dtype
154+
new_dtypes = (None, pd.Series(infer_df).dtype)
155155
except: # noqa: E722 # nosec
156156
pass
157157

@@ -164,7 +164,8 @@ def __call__(self, groupby, dtypes=None, index=None):
164164
in_df = groupby.inputs[0]
165165
in_dtypes = getattr(in_df, 'dtypes', None)
166166
if in_dtypes is None:
167-
in_dtypes = in_df.dtype
167+
in_dtypes = (in_df.name, in_df.dtype)
168+
168169
dtypes, index_value = self._infer_df_func_returns(
169170
in_df.op.object_type, in_dtypes, dtypes, index)
170171
for arg, desc in zip((self._object_type, dtypes, index_value),
@@ -178,8 +179,10 @@ def __call__(self, groupby, dtypes=None, index=None):
178179
return self.new_dataframe([groupby], shape=new_shape, dtypes=dtypes,
179180
index_value=index_value, columns_value=in_df.columns_value)
180181
else:
182+
name, dtype = dtypes
181183
new_shape = in_df.shape if self.is_transform else (np.nan,)
182-
return self.new_series([groupby], shape=new_shape, dtype=dtypes, index_value=index_value)
184+
return self.new_series([groupby], name=name, shape=new_shape, dtype=dtype,
185+
index_value=index_value)
183186

184187

185188
class GroupByApply(GroupByApplyTransform):

mars/dataframe/indexing/getitem.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,10 @@ def tile_with_mask(cls, op):
275275
out_chunks = []
276276
for idx, df_chunk in zip(out_chunk_indexes, df_chunks):
277277
mask_chunk = mask_chunks[df_chunk.index[0]]
278+
index_value = parse_index(out_df.index_value.to_pandas(), df_chunk)
278279
out_chunk = op.copy().reset_key().new_chunk([df_chunk, mask_chunk],
279280
shape=(np.nan, df_chunk.shape[1]), index=idx,
280-
index_value=df_chunk.index_value,
281+
index_value=index_value,
281282
columns_value=df_chunk.columns_value,
282283
dtypes=df_chunk.dtypes)
283284
out_chunks.append(out_chunk)

mars/dataframe/indexing/index_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ def set_chunk_index_info(cls,
475475
assert index_info.input_axis == 0, \
476476
'bool indexing on axis columns cannot be tensor'
477477

478-
index_value = parse_index(chunk_input.index_value.to_pandas(),
478+
index_value = parse_index(pd.Index([], chunk_input.index_value.to_pandas().dtype),
479479
chunk_input, index, store_data=False)
480480

481481
info = ChunkIndexAxisInfo(output_axis_index=output_axis_index,

mars/dataframe/merge/concat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def _call_series(self, objs):
269269
else:
270270
index_value = parse_index(index)
271271
return self.new_series(objs, shape=(row_length,), dtype=objs[0].dtype,
272-
index_value=index_value)
272+
index_value=index_value, name=objs[0].name)
273273
else:
274274
self._object_type = ObjectType.dataframe
275275
col_length = 0

mars/dataframe/reduction/aggregation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def _gen_map_chunks(cls, op, in_df, out_df, stage_infos: List[_stage_info],
285285
columns_value=chunk.columns_value)
286286
elif op.object_type == ObjectType.series:
287287
agg_chunk = map_op.new_chunk([chunk], shape=(out_df.shape[0], 1), index=new_index,
288-
index_value=out_df.index_value)
288+
index_value=out_df.index_value, name=out_df.name)
289289
else: # scalar target
290290
agg_chunk = map_op.new_chunk([chunk], shape=(1, 1), index=new_index)
291291
agg_chunks[agg_chunk.index] = agg_chunk
@@ -299,10 +299,11 @@ def _tile_single_chunk(cls, op: "DataFrameAggregate"):
299299
chunk_op = op.copy().reset_key()
300300
if op.object_type == ObjectType.dataframe:
301301
chunk = chunk_op.new_chunk(in_df.chunks, index=(0, 0), shape=out_df.shape,
302-
index_value=out_df.index_value, columns_value=out_df.columns_value)
302+
index_value=out_df.index_value, columns_value=out_df.columns_value,
303+
dtypes=out_df.dtypes)
303304
else:
304305
chunk = chunk_op.new_chunk(in_df.chunks, index=(0,), shape=out_df.shape,
305-
index_value=out_df.index_value)
306+
index_value=out_df.index_value, name=out_df.name)
306307

307308
tileable_op = op.copy().reset_key()
308309
kw = out_df.params.copy()

mars/dataframe/reduction/core.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ...operands import OperandStage
2222
from ...utils import lazy_import
2323
from ...serialize import BoolField, AnyField, DataTypeField, Int32Field
24-
from ..utils import parse_index, build_empty_df, validate_axis
24+
from ..utils import parse_index, build_empty_df, build_empty_series, validate_axis
2525
from ..operands import DataFrameOperandMixin, DataFrameOperand, ObjectType, DATAFRAME_TYPE
2626
from ..merge import DataFrameConcat
2727

@@ -113,20 +113,26 @@ class DataFrameReductionMixin(DataFrameOperandMixin):
113113
@classmethod
114114
def _tile_one_chunk(cls, op):
115115
df = op.outputs[0]
116-
params = df.params
117116

118117
chk = op.inputs[0].chunks[0]
119118
chunk_params = {k: v for k, v in chk.params.items()
120119
if k in df.params}
121120
chunk_params['shape'] = df.shape
122121
chunk_params['index'] = chk.index
122+
if op.object_type == ObjectType.series:
123+
chunk_params.update(dict(dtype=df.dtype, index_value=df.index_value))
124+
elif op.object_type == ObjectType.dataframe:
125+
chunk_params.update(dict(dtypes=df.dtypes, index_value=df.index_value,
126+
columns_value=df.columns_value))
127+
else:
128+
chunk_params.update(dict(dtype=df.dtype))
123129
new_chunk_op = op.copy().reset_key()
124130
chunk = new_chunk_op.new_chunk(op.inputs[0].chunks, kws=[chunk_params])
125131

126132
new_op = op.copy()
127133
nsplits = tuple((s,) for s in chunk.shape)
128-
params['chunks'] = [chunk]
129-
params['nsplits'] = nsplits
134+
params = df.params.copy()
135+
params.update(dict(chunks=[chunk], nsplits=nsplits))
130136
return new_op.new_tileables(op.inputs, kws=[params])
131137

132138
@classmethod
@@ -402,30 +408,45 @@ def execute(cls, ctx, op):
402408
def _call_dataframe(self, df):
403409
axis = getattr(self, 'axis', None) or 0
404410
level = getattr(self, 'level', None)
411+
skipna = getattr(self, 'skipna', None)
405412
numeric_only = getattr(self, 'numeric_only', None)
406413
self._axis = axis = validate_axis(axis, df)
407414
# TODO: enable specify level if we support groupby
408415
if level is not None:
409416
raise NotImplementedError('Not support specify level now')
410417

411418
empty_df = build_empty_df(df.dtypes)
412-
reduced_df = getattr(empty_df, getattr(self, '_func_name'))(axis=axis, level=level,
413-
numeric_only=numeric_only)
419+
func_name = getattr(self, '_func_name')
420+
if func_name == 'count':
421+
reduced_df = getattr(empty_df, func_name)(axis=axis, level=level, numeric_only=numeric_only)
422+
else:
423+
reduced_df = getattr(empty_df, func_name)(axis=axis, level=level, skipna=skipna,
424+
numeric_only=numeric_only)
414425
reduced_shape = (df.shape[0],) if axis == 1 else reduced_df.shape
415426
return self.new_series([df], shape=reduced_shape, dtype=reduced_df.dtype,
416-
index_value=parse_index(reduced_df.index))
427+
index_value=parse_index(reduced_df.index, store_data=axis == 0))
417428

418429
def _call_series(self, series):
419430
level = getattr(self, 'level', None)
420431
axis = getattr(self, 'axis', None)
432+
skipna = getattr(self, 'skipna', None)
433+
numeric_only = getattr(self, 'numeric_only', None)
421434
if axis == 'index':
422435
axis = 0
423436
self._axis = axis
424437
# TODO: enable specify level if we support groupby
425438
if level is not None:
426439
raise NotImplementedError('Not support specified level now')
427440

428-
return self.new_scalar([series], dtype=series.dtype)
441+
empty_series = build_empty_series(series.dtype)
442+
func_name = getattr(self, '_func_name')
443+
if func_name == 'count':
444+
reduced_series = empty_series.count(level=level)
445+
else:
446+
reduced_series = getattr(empty_series, func_name)(axis=axis, level=level, skipna=skipna,
447+
numeric_only=numeric_only)
448+
449+
return self.new_scalar([series], dtype=np.array(reduced_series).dtype)
429450

430451
def __call__(self, a):
431452
if isinstance(a, DATAFRAME_TYPE):
@@ -438,7 +459,7 @@ class DataFrameCumReductionMixin(DataFrameOperandMixin):
438459
@classmethod
439460
def _tile_one_chunk(cls, op):
440461
df = op.outputs[0]
441-
params = df.params
462+
params = df.params.copy()
442463

443464
chk = op.inputs[0].chunks[0]
444465
chunk_params = {k: v for k, v in chk.params.items()
@@ -525,7 +546,7 @@ def _tile_series(cls, op):
525546
new_op = op.copy().reset_key()
526547
return new_op.new_tileables(op.inputs, shape=in_series.shape, nsplits=in_series.nsplits,
527548
chunks=output_chunks, dtype=series.dtype,
528-
index_value=series.index_value)
549+
index_value=series.index_value, name=series.name)
529550

530551
@classmethod
531552
def tile(cls, op):

0 commit comments

Comments (0)