Skip to content

Commit 18981d2

Browse files
authored
[BACKPORT] Fix the error raised when inferring dtype in DataFrame.transform (#1424) (#1427)
1 parent e9c744b commit 18981d2

File tree

2 files changed

+43 -12 lines changed

mars/dataframe/base/tests/test_base_execution.py

+17
Original file line number | Diff line number | Diff line change
@@ -581,6 +581,23 @@ def rename_fn(f, new_name):
581581
result = self.executor.execute_dataframe(r, concat=True)[0]
582582
expected = s_raw.transform(['cumsum', lambda x: x + 1])
583583
pd.testing.assert_frame_equal(result, expected)
584+
585+
# test transform on string dtype
586+
df_raw = pd.DataFrame({'col1': ['str'] * 10, 'col2': ['string'] * 10})
587+
df = from_pandas_df(df_raw, chunk_size=3)
588+
589+
with self.assertRaises(TypeError):
590+
df['col1'].transform(lambda x: x + '_suffix')
591+
592+
r = df.transform(lambda x: x + '_suffix')
593+
result = self.executor.execute_dataframe(r, concat=True)[0]
594+
expected = df_raw.transform(lambda x: x + '_suffix')
595+
pd.testing.assert_frame_equal(result, expected)
596+
597+
r = df['col2'].transform(lambda x: x + '_suffix', dtype=np.dtype('str'))
598+
result = self.executor.execute_dataframe(r, concat=True)[0]
599+
expected = df_raw['col2'].transform(lambda x: x + '_suffix')
600+
pd.testing.assert_series_equal(result, expected)
584601
finally:
585602
options.chunk_store_limit = old_chunk_store_limit
586603

mars/dataframe/base/transform.py

+26 -12
Original file line number | Diff line number | Diff line change
@@ -171,21 +171,35 @@ def tile(cls, op: "TransformOperand"):
171171
def _infer_df_func_returns(self, in_dtypes, dtypes):
172172
if self.object_type == ObjectType.dataframe:
173173
empty_df = build_empty_df(in_dtypes, index=pd.RangeIndex(2))
174-
with np.errstate(all='ignore'):
175-
if self.call_agg:
176-
infer_df = empty_df.agg(self._func, axis=self._axis, *self.args, **self.kwds)
177-
else:
178-
infer_df = empty_df.transform(self._func, axis=self._axis, *self.args, **self.kwds)
174+
try:
175+
with np.errstate(all='ignore'):
176+
if self.call_agg:
177+
infer_df = empty_df.agg(self._func, axis=self._axis, *self.args, **self.kwds)
178+
else:
179+
infer_df = empty_df.transform(self._func, axis=self._axis, *self.args, **self.kwds)
180+
except: # noqa: E722
181+
infer_df = None
179182
else:
180183
empty_df = build_empty_series(in_dtypes[1], index=pd.RangeIndex(2), name=in_dtypes[0])
181-
with np.errstate(all='ignore'):
182-
if self.call_agg:
183-
infer_df = empty_df.agg(self._func, args=self.args, **self.kwds)
184-
else:
185-
infer_df = empty_df.transform(self._func, convert_dtype=self.convert_dtype,
186-
args=self.args, **self.kwds)
184+
try:
185+
with np.errstate(all='ignore'):
186+
if self.call_agg:
187+
infer_df = empty_df.agg(self._func, args=self.args, **self.kwds)
188+
else:
189+
infer_df = empty_df.transform(self._func, convert_dtype=self.convert_dtype,
190+
args=self.args, **self.kwds)
191+
except: # noqa: E722
192+
infer_df = None
193+
194+
if infer_df is None and dtypes is None:
195+
raise TypeError('Failed to infer dtype, please specify dtypes as arguments.')
196+
197+
if infer_df is None:
198+
is_df = self.object_type == ObjectType.dataframe
199+
else:
200+
is_df = isinstance(infer_df, pd.DataFrame)
187201

188-
if isinstance(infer_df, pd.DataFrame):
202+
if is_df:
189203
new_dtypes = dtypes or infer_df.dtypes
190204
self._object_type = ObjectType.dataframe
191205
else:

0 commit comments

Comments
 (0)