Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions lmms_eval/caching/cache.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import hashlib
import os
import pickle

import dill

from lmms_eval.loggers.utils import _handle_non_serializable
from lmms_eval.loggers.utils import _handle_non_serializable, is_serializable
from lmms_eval.utils import eval_logger

MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
Expand Down Expand Up @@ -43,17 +44,18 @@ def save_to_cache(file_name, obj):
serializable_obj = []

for item in obj:
sub_serializable_obj = []
for subitem in item:
if hasattr(subitem, "arguments"): # we need to handle the arguments specially since doc_to_visual is callable method and not serializable
serializable_arguments = tuple(arg if not callable(arg) else None for arg in subitem.arguments)
subitem.arguments = serializable_arguments
sub_serializable_obj.append(_handle_non_serializable(subitem))
serializable_obj.append(sub_serializable_obj)

eval_logger.debug(f"Saving {file_path} to cache...")
with open(file_path, "wb") as file:
file.write(dill.dumps(serializable_obj))
try:
with open(file_path, "wb") as file:
file.write(dill.dumps(serializable_obj))
except (pickle.PickleError, dill.PicklingError, TypeError, AttributeError):
with open(file_path, "wb") as file:
file.write(dill.dumps([[subitem if is_serializable(subitem) else _handle_non_serializable(subitem) for subitem in item] for item in obj]))


# NOTE the "key" param is to allow for flexibility
Expand Down
9 changes: 9 additions & 0 deletions lmms_eval/loggers/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pickle
import re
import subprocess
from pathlib import Path
Expand Down Expand Up @@ -32,6 +33,14 @@ def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
return result, removed


def is_serializable(o: Any) -> bool:
try:
pickle.dumps(o)
return True
except (pickle.PickleError, TypeError, AttributeError):
return False


def _handle_non_serializable(o: Any) -> Union[int, str, list]:
"""Handle non-serializable objects by converting them to serializable types.

Expand Down