|
21 | 21 |
|
22 | 22 | logger = get_logger() |
23 | 23 |
|
24 | | -datasets_4 = version.parse(datasets.__version__) >= version.parse('4.0') |
25 | 24 | _pair_keys = ['messages', 'images', 'videos', 'audios', 'tools', 'objects'] |
26 | 25 |
|
27 | 26 |
|
@@ -55,6 +54,7 @@ def __init__(self, |
55 | 54 | self.traceback_limit = traceback_limit |
56 | 55 | self._traceback_counter = 0 |
57 | 56 | self.dataset_sample = dataset_sample |
| 57 | + self.datasets_4 = version.parse(datasets.__version__) >= version.parse('4.0') |
58 | 58 | if not isinstance(random_state, np.random.RandomState): |
59 | 59 | random_state = np.random.RandomState(random_state) |
60 | 60 | self.random_state = random_state |
@@ -244,17 +244,16 @@ def remove_useless_columns(dataset: DATASET_TYPE) -> DATASET_TYPE: |
244 | 244 | dataset = dataset.select_columns(k_list) |
245 | 245 | return dataset |
246 | 246 |
|
247 | | - @staticmethod |
248 | 247 | @contextmanager |
249 | | - def _patch_arrow_writer(): |
| 248 | + def _patch_arrow_writer(self): |
250 | 249 | # fix AI-ModelScope/ms_agent_for_agentfabric:all |
251 | 250 | from datasets.arrow_writer import ArrowWriter |
252 | 251 |
|
253 | | - def _new_init(self, schema=None, features=None, *args, **kwargs): |
| 252 | + def _new_init(_self, schema=None, features=None, *args, **kwargs): |
254 | 253 |
|
255 | 254 | if features is not None: |
256 | 255 |
|
257 | | - if datasets_4: |
| 256 | + if self.datasets_4: |
258 | 257 | from datasets.features import Json, List |
259 | 258 | messages_feature = List(Json()) |
260 | 259 | for key in ['messages', 'rejected_messages', 'positive_messages', 'negative_messages']: |
@@ -283,7 +282,7 @@ def _new_init(self, schema=None, features=None, *args, **kwargs): |
283 | 282 | 'bbox_type': Value(dtype='string'), |
284 | 283 | 'image_id': Sequence(feature=Value(dtype='int64'), length=-1), |
285 | 284 | } |
286 | | - ArrowWriter.__origin_init__(self, schema, features, *args, **kwargs) |
| 285 | + ArrowWriter.__origin_init__(_self, schema, features, *args, **kwargs) |
287 | 286 |
|
288 | 287 | ArrowWriter.__origin_init__ = ArrowWriter.__init__ |
289 | 288 | ArrowWriter.__init__ = _new_init |
|
0 commit comments