Conversation

@sghng sghng commented Oct 14, 2025

In Python 3.14, the signature of _Pickler._batch_setitems changed; it now takes an additional obj argument:

# pickle.py

def _batch_setitems(self, items, obj):
    # Helper to batch up SETITEMS sequences; proto >= 1 only
    save = self.save
    write = self.write

To accommodate this, dill has the following compatibility code:

if sys.hexversion < 0x30E00A1:
    pickler._batch_setitems(iter(source.items()))
else:
    pickler._batch_setitems(iter(source.items()), obj=obj)

Because datasets overrides _batch_setitems with the old two-argument signature, dill's new call with obj=obj fails and the datasets package emits this error:

│ /Users/sghuang/mamba/envs/ds/lib/python3.14/site-packages/dill/_dill.py:1262 in save_module_dict │
│                                                                                                  │
│   1259 │   │   if is_dill(pickler, child=False) and pickler._session:                            │
│   1260 │   │   │   # we only care about session the first pass thru                              │
│   1261 │   │   │   pickler._first_pass = False                                                   │
│ ❱ 1262 │   │   StockPickler.save_dict(pickler, obj)                                              │
│   1263 │   │   logger.trace(pickler, "# D2")                                                     │
│   1264 │   return                                                                                │
│   1265                                                                                           │
│                                                                                                  │
│ /Users/sghuang/mamba/envs/ds/lib/python3.14/pickle.py:1133 in save_dict                          │
│                                                                                                  │
│   1130 │   │   print(f"Line number: {inspect.getsourcelines(method)[1]}")                        │
│   1131 │   │   print(f"Full path: {inspect.getmodule(method)}")                                  │
│   1132 │   │   print(f"Class: {method.__qualname__}")                                            │
│ ❱ 1133 │   │   self._batch_setitems(obj.items(), obj)                                            │
│   1134 │                                                                                         │
│   1135 │   dispatch[dict] = save_dict                                                            │
│   1136                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: Pickler._batch_setitems() takes 2 positional arguments but 3 were given
[NOTE] when serializing datasets.table.InMemoryTable state
[NOTE] when serializing datasets.table.InMemoryTable object
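
For context, the mismatch is easy to reproduce with any pickler subclass that overrides this private method using the pre-3.14 signature. A hypothetical minimal example (not datasets code; pickle._Pickler is the pure-Python pickler that dill builds on):

import io
import pickle

class MyPickler(pickle._Pickler):
    # Pre-3.14 signature: no `obj` parameter.
    def _batch_setitems(self, items):
        return super()._batch_setitems(items)

# On Python 3.14, save_dict calls self._batch_setitems(obj.items(), obj),
# so dumping any dict raises:
# TypeError: _batch_setitems() takes 2 positional arguments but 3 were given
MyPickler(io.BytesIO()).dump({"a": 1})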

To fix it, we update the signature of the _batch_setitems method defined in utils/_dill.py.

This fix should be backward compatible, since the compatibility is handled by dill.

This should close #7813.

Similar to joblib/joblib#1658.

Related to uqfoundation/dill#724.

@Qubitium Qubitium commented Oct 15, 2025

@sghng There is a regression with Python 3.13.8 when lm-eval calls datasets.load_dataset:

self = <test_llama3_2.TestLlama3_2 testMethod=test_llama3_2>

    def test_llama3_2(self):
>       self.quant_lm_eval()

tests/models/test_llama3_2.py:35: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/models/model_test.py:773: in quant_lm_eval
    self.model, _ = self.quantModel(self.NATIVE_MODEL_ID, batch_size=self.QUANT_BATCH_SIZE, trust_remote_code=self.TRUST_REMOTE_CODE, dtype=self.TORCH_DTYPE)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tests/models/model_test.py:588: in quantModel
    reuse_candidates = self.perform_post_quant_validation(path, trust_remote_code=trust_remote_code)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tests/models/model_test.py:250: in perform_post_quant_validation
    arc_records[backend] = self.run_arc_challenge_eval(model, backend, trust_remote_code=trust_remote_code)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tests/models/model_test.py:217: in run_arc_challenge_eval
    task_results = self.lm_eval(
tests/models/model_test.py:753: in lm_eval
    raise e
tests/models/model_test.py:699: in lm_eval
    results = GPTQModel.eval(
gptqmodel/models/auto.py:474: in eval
    results = simple_evaluate(
../lm-evaluation-harness/lm_eval/utils.py:456: in _wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
../lm-evaluation-harness/lm_eval/evaluator.py:283: in simple_evaluate
    task_dict = get_task_dict(
../lm-evaluation-harness/lm_eval/tasks/__init__.py:635: in get_task_dict
    task_name_from_string_dict = task_manager.load_task_or_group(
../lm-evaluation-harness/lm_eval/tasks/__init__.py:426: in load_task_or_group
    collections.ChainMap(
../lm-evaluation-harness/lm_eval/tasks/__init__.py:428: in <lambda>
    lambda task: self._load_individual_task_or_group(task),
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../lm-evaluation-harness/lm_eval/tasks/__init__.py:326: in _load_individual_task_or_group
    return _load_task(task_config, task=name_or_config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../lm-evaluation-harness/lm_eval/tasks/__init__.py:286: in _load_task
    task_object = ConfigurableTask(config=config)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../lm-evaluation-harness/lm_eval/api/task.py:865: in __init__
    self.download(self.config.dataset_kwargs)
../lm-evaluation-harness/lm_eval/api/task.py:997: in download
    self.dataset = datasets.load_dataset(
../datasets/src/datasets/load.py:1397: in load_dataset
    builder_instance = load_dataset_builder(
../datasets/src/datasets/load.py:1185: in load_dataset_builder
    builder_instance._use_legacy_cache_dir_if_possible(dataset_module)
../datasets/src/datasets/builder.py:612: in _use_legacy_cache_dir_if_possible
    self._check_legacy_cache2(dataset_module) or self._check_legacy_cache() or None
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../datasets/src/datasets/builder.py:485: in _check_legacy_cache2
    config_id = self.config.name + "-" + Hasher.hash({"data_files": self.config.data_files})
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../datasets/src/datasets/fingerprint.py:188: in hash
    return cls.hash_bytes(dumps(value))
                          ^^^^^^^^^^^^
../datasets/src/datasets/utils/_dill.py:120: in dumps
    dump(obj, file)
../datasets/src/datasets/utils/_dill.py:114: in dump
    Pickler(file, recurse=True).dump(obj)
../vm313t/lib/python3.13t/site-packages/dill/_dill.py:428: in dump
    StockPickler.dump(self, obj)
/usr/lib/python3.13/pickle.py:484: in dump
    self.save(obj)
../datasets/src/datasets/utils/_dill.py:70: in save
    dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id)
../vm313t/lib/python3.13t/site-packages/dill/_dill.py:422: in save
    StockPickler.save(self, obj, save_persistent_id)
/usr/lib/python3.13/pickle.py:558: in save
    f(self, obj)  # Call unbound method with explicit self
    ^^^^^^^^^^^^
../vm313t/lib/python3.13t/site-packages/dill/_dill.py:1262: in save_module_dict
    StockPickler.save_dict(pickler, obj)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <datasets.utils._dill.Pickler object at 0x4eaf175e390>
obj = {'data_files': {'train': ['hf://datasets/allenai/ai2_arc@210d026faf9955653af8916fad021475a3f00453/ARC-Challenge/train-...//datasets/allenai/ai2_arc@210d026faf9955653af8916fad021475a3f00453/ARC-Challenge/validation-00000-of-00001.parquet']}}

    def save_dict(self, obj):
        if self.bin:
            self.write(EMPTY_DICT)
        else:   # proto 0 -- can't use EMPTY_DICT
            self.write(MARK + DICT)
    
        self.memoize(obj)
>       self._batch_setitems(obj.items())
E       TypeError: Pickler._batch_setitems() missing 1 required positional argument: 'obj'

/usr/lib/python3.13/pickle.py:990: TypeError

Python 3.13.8

(vm313t) root@gpu-base:~/datasets# pip show dill datasets transformers lm-eval
Name: dill
Version: 0.4.0
Summary: serialize all of Python
Home-page: https://github.com/uqfoundation/dill
Author: Mike McKerns
Author-email: mmckerns@uqfoundation.org
License: BSD-3-Clause
Location: /root/vm313t/lib/python3.13t/site-packages
Requires: 
Required-by: datasets, evaluate, GPTQModel, lm_eval, multiprocess
---
Name: datasets
Version: 4.2.1.dev0 <-- this PR
Summary: HuggingFace community-driven open-source library of datasets
Home-page: https://github.com/huggingface/datasets
Author: HuggingFace Inc.
Author-email: thomas@huggingface.co
License: Apache 2.0
Location: /root/vm313t/lib/python3.13t/site-packages
Editable project location: /root/datasets
Requires: dill, filelock, fsspec, httpx, huggingface-hub, multiprocess, numpy, packaging, pandas, pyarrow, pyyaml, requests, tqdm, xxhash
Required-by: evaluate, GPTQModel, lm_eval
---
Name: transformers
Version: 4.57.1
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /root/vm313t/lib/python3.13t/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: GPTQModel, lm_eval, peft, tokenicer
---
Name: lm_eval
Version: 0.4.9.1
Summary: A framework for evaluating language models
Home-page: https://github.com/EleutherAI/lm-evaluation-harness
Author: 
Author-email: EleutherAI <contact@eleuther.ai>
License: MIT
Location: /root/vm313t/lib/python3.13t/site-packages
Editable project location: /root/lm-evaluation-harness
Requires: accelerate, datasets, dill, evaluate, jsonlines, more_itertools, numexpr, peft, pybind11, pytablewriter, rouge-score, sacrebleu, scikit-learn, sqlitedict, torch, tqdm-multiprocess, transformers, word2number, zstandard
Required-by: 

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lhoestq lhoestq (Member) left a comment

I guess you can use *args as in the Joblib PR to fix the issue in 3.13 and keep backward compatibility

@sghng sghng (Author) commented Oct 15, 2025

@lhoestq Using *args works well. In fact, we might as well pass **kwargs through too, to make the intent clearer.

It also appears to me that there are a couple of other changes worth making.

def _batch_setitems(self, items, obj):
    if self._legacy_no_dict_keys_sorting:
        return super()._batch_setitems(items)

Python dictionaries have been officially insertion-ordered since Python 3.7, which reached end-of-life in mid-2023. I tried installing datasets on Python 3.8 and it resolves to version 3.1, so I think it's safe to drop this legacy branch in the latest version.

    dill.Pickler._batch_setitems(self, items)

This line should simply be return super()._batch_setitems(items), since we're already extending dill.Pickler. (The missing return in the original version is probably a bug.)
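
Putting both points together, the override could become a plain passthrough. A minimal sketch, assuming the surrounding Pickler subclass in utils/_dill.py (the real method also keeps the dict-key sorting, presumably for deterministic hashing):

import dill

class Pickler(dill.Pickler):
    def _batch_setitems(self, items, *args, **kwargs):
        # Forward whatever extra arguments the running Python's pickle
        # implementation supplies: nothing on <= 3.13, the containing
        # dict (`obj`) on 3.14+.
        return super()._batch_setitems(items, *args, **kwargs)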

@sghng sghng requested a review from lhoestq October 15, 2025 19:20
@sghng sghng changed the title from "fix: pass obj to _batch_setitems() for py3.14" to "fix: better args passthrough for _batch_setitems()" Oct 15, 2025
@lhoestq lhoestq (Member) left a comment


lgtm !
