From 3db3350b83dcfac827f43358e62a221a29520e18 Mon Sep 17 00:00:00 2001
From: antonallote
Date: Fri, 9 Jan 2026 15:24:52 +0100
Subject: [PATCH 1/8] add safetensors as dependency

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index 39cf750c66..ebd93792c1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,6 +13,7 @@ pptree>=3.1
 python-dateutil>=2.8.2
 pytorch_revgrad>=0.2.0
 regex>=2022.1.18
+safetensors>=0.4.0
 scikit-learn>=1.0.2
 segtok>=1.5.11
 sqlitedict>=2.0.0

From 52bf96bc6818121beed9b23328a35b28e0c289ea Mon Sep 17 00:00:00 2001
From: antonallote
Date: Fri, 9 Jan 2026 15:35:31 +0100
Subject: [PATCH 2/8] add utils for safetensors

---
 flair/safetensors_utils.py | 188 +++++++++++++++++++++++++++++++++++++
 1 file changed, 188 insertions(+)
 create mode 100644 flair/safetensors_utils.py

diff --git a/flair/safetensors_utils.py b/flair/safetensors_utils.py
new file mode 100644
index 0000000000..b1f5c73c0f
--- /dev/null
+++ b/flair/safetensors_utils.py
@@ -0,0 +1,188 @@
+import json
+from pathlib import Path
+from typing import Any, Union
+
+import torch
+from safetensors.torch import load_file as safetensors_load_file
+from safetensors.torch import save_file as safetensors_save_file
+
+
+def _is_tensor(value: Any) -> bool:
+    return isinstance(value, torch.Tensor)
+
+
+def _flatten_dict(
+    d: dict[str, Any], parent_key: str = "", sep: str = "."
+) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
+    tensors: dict[str, torch.Tensor] = {}
+    metadata: dict[str, Any] = {}
+
+    for key, value in d.items():
+        new_key = f"{parent_key}{sep}{key}" if parent_key else key
+
+        if _is_tensor(value):
+            tensors[new_key] = value
+            metadata[key] = {"__tensor_key__": new_key}
+        elif isinstance(value, dict):
+            nested_tensors, nested_metadata = _flatten_dict(value, new_key, sep)
+            tensors.update(nested_tensors)
+            metadata[key] = nested_metadata
+        elif isinstance(value, list):
+            processed_list, list_tensors = _process_list(value, new_key, sep)
+            tensors.update(list_tensors)
+            metadata[key] = processed_list
+        else:
+            metadata[key] = value
+
+    return tensors, metadata
+
+
+def _process_list(
+    lst: list, parent_key: str, sep: str
+) -> tuple[list, dict[str, torch.Tensor]]:
+    tensors: dict[str, torch.Tensor] = {}
+    processed: list = []
+
+    for i, item in enumerate(lst):
+        item_key = f"{parent_key}{sep}{i}"
+
+        if _is_tensor(item):
+            tensors[item_key] = item
+            processed.append({"__tensor_key__": item_key})
+        elif isinstance(item, dict):
+            nested_tensors, nested_metadata = _flatten_dict(item, item_key, sep)
+            tensors.update(nested_tensors)
+            processed.append(nested_metadata)
+        elif isinstance(item, list):
+            nested_list, nested_tensors = _process_list(item, item_key, sep)
+            tensors.update(nested_tensors)
+            processed.append(nested_list)
+        else:
+            processed.append(item)
+
+    return processed, tensors
+
+
+def _unflatten_dict(
+    metadata: dict[str, Any], tensors: dict[str, torch.Tensor]
+) -> dict[str, Any]:
+    result: dict[str, Any] = {}
+
+    for key, value in metadata.items():
+        if isinstance(value, dict):
+            if "__tensor_key__" in value:
+                tensor_key = value["__tensor_key__"]
+                result[key] = tensors[tensor_key]
+            else:
+                result[key] = _unflatten_dict(value, tensors)
+        elif isinstance(value, list):
+            result[key] = _unflatten_list(value, tensors)
+        else:
+            result[key] = value
+
+    return result
+
+
+def _unflatten_list(lst: list, tensors: dict[str, torch.Tensor]) -> list:
+    result: list = []
+
+    for item in lst:
+        if isinstance(item, dict):
+            if "__tensor_key__" in item:
+                tensor_key = item["__tensor_key__"]
+                result.append(tensors[tensor_key])
+            else:
+                result.append(_unflatten_dict(item, tensors))
+        elif isinstance(item, list):
+            result.append(_unflatten_list(item, tensors))
+        else:
+            result.append(item)
+
+    return result
+
+
+def separate_tensors_and_metadata(
+    state_dict: dict[str, Any],
+) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
+    return _flatten_dict(state_dict)
+
+
+def combine_tensors_and_metadata(
+    tensors: dict[str, torch.Tensor], metadata: dict[str, Any]
+) -> dict[str, Any]:
+    return _unflatten_dict(metadata, tensors)
+
+
+def _json_serializer(obj: Any) -> Any:
+    if isinstance(obj, torch.dtype):
+        return {"__torch_dtype__": str(obj)}
+    if isinstance(obj, torch.device):
+        return {"__torch_device__": str(obj)}
+    if isinstance(obj, type):
+        return {"__class__": f"{obj.__module__}.{obj.__name__}"}
+    if hasattr(obj, "to_dict"):
+        return obj.to_dict()
+    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
+
+
+def _json_deserializer(obj: dict) -> Any:
+    if "__torch_dtype__" in obj:
+        dtype_str = obj["__torch_dtype__"]
+        dtype_map = {
+            "torch.float32": torch.float32,
+            "torch.float64": torch.float64,
+            "torch.float16": torch.float16,
+            "torch.bfloat16": torch.bfloat16,
+            "torch.int32": torch.int32,
+            "torch.int64": torch.int64,
+            "torch.int16": torch.int16,
+            "torch.int8": torch.int8,
+            "torch.uint8": torch.uint8,
+            "torch.bool": torch.bool,
+        }
+        return dtype_map.get(dtype_str, torch.float32)
+    if "__torch_device__" in obj:
+        return torch.device(obj["__torch_device__"])
+    return obj
+
+
+class SafetensorsSerializer:
+    TENSORS_FILENAME = "model.safetensors"
+    METADATA_FILENAME = "model_metadata.json"
+
+    @classmethod
+    def save(cls, state_dict: dict[str, Any], path: Union[str, Path]) -> None:
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        tensors, metadata = separate_tensors_and_metadata(state_dict)
+
+        tensors_path = path / cls.TENSORS_FILENAME
+        if tensors:
+            safetensors_save_file(tensors, str(tensors_path))
+        else:
+            safetensors_save_file({}, str(tensors_path))
+
+        metadata_path = path / cls.METADATA_FILENAME
+        with open(metadata_path, "w", encoding="utf-8") as f:
+            json.dump(metadata, f, indent=2, default=_json_serializer)
+
+    @classmethod
+    def load(cls, path: Union[str, Path]) -> dict[str, Any]:
+        path = Path(path)
+
+        tensors_path = path / cls.TENSORS_FILENAME
+        tensors = safetensors_load_file(str(tensors_path))
+
+        metadata_path = path / cls.METADATA_FILENAME
+        with open(metadata_path, encoding="utf-8") as f:
+            metadata = json.load(f, object_hook=_json_deserializer)
+
+        return combine_tensors_and_metadata(tensors, metadata)
+
+    @classmethod
+    def is_safetensors_model(cls, path: Union[str, Path]) -> bool:
+        path = Path(path)
+        if path.is_dir():
+            return (path / cls.TENSORS_FILENAME).exists() and (path / cls.METADATA_FILENAME).exists()
+        return False
\ No newline at end of file
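
Note: a minimal round-trip sketch of the utilities introduced above (the nested state dict is invented for illustration; only functions defined in this patch are used):

    import torch
    from flair.safetensors_utils import combine_tensors_and_metadata, separate_tensors_and_metadata

    state = {"weights": {"layer": torch.ones(2, 2)}, "config": {"dropout": 0.1}}
    tensors, metadata = separate_tensors_and_metadata(state)
    # tensors == {"weights.layer": <2x2 tensor>} is what goes into model.safetensors;
    # metadata replaces the tensor with {"__tensor_key__": "weights.layer"},
    # so the JSON side stays tensor-free.
    restored = combine_tensors_and_metadata(tensors, metadata)
    assert torch.equal(restored["weights"]["layer"], state["weights"]["layer"])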
From 9ca0b55ce6f29764c8a0a7c1f2349a12d023415e Mon Sep 17 00:00:00 2001
From: antonallote
Date: Fri, 9 Jan 2026 15:37:35 +0100
Subject: [PATCH 3/8] add flag and enable SafetensorsSerializer usage in save method

---
 flair/nn/model.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/flair/nn/model.py b/flair/nn/model.py
index 1ca2731d48..d0f14fcd15 100644
--- a/flair/nn/model.py
+++ b/flair/nn/model.py
@@ -22,6 +22,7 @@
 from flair.embeddings import Embeddings
 from flair.embeddings.base import load_embeddings
 from flair.file_utils import Tqdm, load_torch_state
+from flair.safetensors_utils import SafetensorsSerializer
 from flair.training_utils import EmbeddingStorageMode, Result, store_embeddings
 
 import importlib

From f69a311d86c27da0c4f4bd504855158f00140c30 Mon Sep 17 00:00:00 2001
From: antonallote
Date: Fri, 9 Jan 2026 15:44:22 +0100
Subject: [PATCH 4/8] deprecation warnings

---
 flair/nn/model.py | 29 ++++++++++++++++++++++-------
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/flair/nn/model.py b/flair/nn/model.py
index d0f14fcd15..414de7daf0 100644
--- a/flair/nn/model.py
+++ b/flair/nn/model.py
@@ -2,6 +2,7 @@
 import itertools
 import logging
 import typing
+import warnings
 from abc import ABC, abstractmethod
 from collections import Counter
 from pathlib import Path
@@ -26,6 +27,13 @@
 from flair.training_utils import EmbeddingStorageMode, Result, store_embeddings
 
 import importlib
+
+def _load_state(model_path: Union[str, Path]) -> dict[str, Any]:
+    path = Path(model_path)
+    if SafetensorsSerializer.is_safetensors_model(path):
+        return SafetensorsSerializer.load(path)
+    return load_torch_state(str(path))
+
 log = logging.getLogger("flair")
 
 
@@ -274,21 +282,28 @@ def _fetch_model(model_identifier: str):
         """
         return model_identifier
 
-    def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
+    def save(
+        self,
+        model_file: Union[str, Path],
+        checkpoint: bool = False,
+        use_safetensors: bool = True,
+    ) -> None:
         """Saves the current model to the provided file.
 
         Args:
-            model_file: The model file.
+            model_file: The model file path. For safetensors format, this will be a directory.
             checkpoint: This parameter is currently unused.
+            use_safetensors: If True (default), save using safetensors format. If False, use pickle.
         """
         model_state = self._get_state_dict()
 
-        # write out a "model card" if one is set
        if self.model_card is not None:
             model_state["model_card"] = self.model_card
 
-        # save model
-        torch.save(model_state, str(model_file), pickle_protocol=4)
+        if use_safetensors:
+            SafetensorsSerializer.save(model_state, model_file)
+        else:
+            torch.save(model_state, str(model_file), pickle_protocol=4)
 
     @property
     def license_info(self) -> str:
@@ -338,7 +353,7 @@ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model":
 
         # if the model cannot be fetched, load as a file
         try:
-            state = model_path if isinstance(model_path, dict) else load_torch_state(str(model_path))
+            state = model_path if isinstance(model_path, dict) else _load_state(model_path)
         except Exception:
             log.error("-" * 80)
             log.error(
@@ -372,7 +387,7 @@ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model":
 
         # if this class is not abstract, fetch the model and load it
         if not isinstance(model_path, dict):
             model_file = cls._fetch_model(str(model_path))
-            state = load_torch_state(model_file)
+            state = _load_state(model_file)
         else:
             state = model_path
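
Note: with the flag in place, both formats sit behind the same save() call. A usage sketch, assuming a SequenceTagger checkpoint; the model identifier and output paths are illustrative only:

    from flair.models import SequenceTagger

    tagger = SequenceTagger.load("flair/ner-english-fast")
    tagger.save("ner-tagger-st")                         # default: safetensors directory
    tagger.save("ner-tagger.pt", use_safetensors=False)  # explicit opt-out: legacy pickle file
    reloaded = SequenceTagger.load("ner-tagger-st")      # _load_state() auto-detects the format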
From 00188d0694e127ea35ca91dc8ac5588499cb8df5 Mon Sep 17 00:00:00 2001
From: antonallote
Date: Fri, 9 Jan 2026 15:48:46 +0100
Subject: [PATCH 5/8] bytes deserialization + warning

---
 flair/nn/model.py          |  6 ++++++
 flair/safetensors_utils.py | 10 ++++++++++
 2 files changed, 16 insertions(+)

diff --git a/flair/nn/model.py b/flair/nn/model.py
index 414de7daf0..a4eca58128 100644
--- a/flair/nn/model.py
+++ b/flair/nn/model.py
@@ -32,6 +32,12 @@ def _load_state(model_path: Union[str, Path]) -> dict[str, Any]:
     path = Path(model_path)
     if SafetensorsSerializer.is_safetensors_model(path):
         return SafetensorsSerializer.load(path)
+    warnings.warn(
+        "Loading model from pickle format. Pickle is deprecated due to security concerns. "
+        "Consider re-saving the model with model.save() to convert to safetensors format.",
+        FutureWarning,
+        stacklevel=3,
+    )
     return load_torch_state(str(path))
 
 log = logging.getLogger("flair")

diff --git a/flair/safetensors_utils.py b/flair/safetensors_utils.py
index b1f5c73c0f..80b1c826ef 100644
--- a/flair/safetensors_utils.py
+++ b/flair/safetensors_utils.py
@@ -1,4 +1,6 @@
+import base64
 import json
+from io import BytesIO
 from pathlib import Path
 from typing import Any, Union
 
@@ -114,6 +116,10 @@ def combine_tensors_and_metadata(
 
 
 def _json_serializer(obj: Any) -> Any:
+    if isinstance(obj, bytes):
+        return {"__bytes__": base64.b64encode(obj).decode("ascii")}
+    if isinstance(obj, BytesIO):
+        return {"__bytesio__": base64.b64encode(obj.getvalue()).decode("ascii")}
     if isinstance(obj, torch.dtype):
         return {"__torch_dtype__": str(obj)}
     if isinstance(obj, torch.device):
@@ -126,6 +132,10 @@ def _json_serializer(obj: Any) -> Any:
 
 
 def _json_deserializer(obj: dict) -> Any:
+    if "__bytes__" in obj:
+        return base64.b64decode(obj["__bytes__"])
+    if "__bytesio__" in obj:
+        return BytesIO(base64.b64decode(obj["__bytesio__"]))
     if "__torch_dtype__" in obj:
         dtype_str = obj["__torch_dtype__"]
         dtype_map = {
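
Note: a sketch of the bytes handling added here (the payload is invented; _json_serializer and _json_deserializer are the hooks defined in flair/safetensors_utils.py):

    import json

    from flair.safetensors_utils import _json_deserializer, _json_serializer

    payload = {"tokenizer_blob": b"raw bytes from a checkpoint"}
    encoded = json.dumps(payload, default=_json_serializer)  # bytes -> {"__bytes__": <base64>}
    decoded = json.loads(encoded, object_hook=_json_deserializer)
    assert decoded["tokenizer_blob"] == payload["tokenizer_blob"]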
" + "Consider re-saving the model with model.save() to convert to safetensors format.", + FutureWarning, + stacklevel=3, + ) return load_torch_state(str(path)) log = logging.getLogger("flair") diff --git a/flair/safetensors_utils.py b/flair/safetensors_utils.py index b1f5c73c0f..80b1c826ef 100644 --- a/flair/safetensors_utils.py +++ b/flair/safetensors_utils.py @@ -1,4 +1,6 @@ +import base64 import json +from io import BytesIO from pathlib import Path from typing import Any, Union @@ -114,6 +116,10 @@ def combine_tensors_and_metadata( def _json_serializer(obj: Any) -> Any: + if isinstance(obj, bytes): + return {"__bytes__": base64.b64encode(obj).decode("ascii")} + if isinstance(obj, BytesIO): + return {"__bytesio__": base64.b64encode(obj.getvalue()).decode("ascii")} if isinstance(obj, torch.dtype): return {"__torch_dtype__": str(obj)} if isinstance(obj, torch.device): @@ -126,6 +132,10 @@ def _json_serializer(obj: Any) -> Any: def _json_deserializer(obj: dict) -> Any: + if "__bytes__" in obj: + return base64.b64decode(obj["__bytes__"]) + if "__bytesio__" in obj: + return BytesIO(base64.b64decode(obj["__bytesio__"])) if "__torch_dtype__" in obj: dtype_str = obj["__torch_dtype__"] dtype_map = { From 186922abb1833515dc379109905a6a4d6dc053c2 Mon Sep 17 00:00:00 2001 From: antonallote Date: Fri, 9 Jan 2026 15:50:48 +0100 Subject: [PATCH 6/8] dict support --- flair/safetensors_utils.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/flair/safetensors_utils.py b/flair/safetensors_utils.py index 80b1c826ef..dc006d0d43 100644 --- a/flair/safetensors_utils.py +++ b/flair/safetensors_utils.py @@ -126,6 +126,19 @@ def _json_serializer(obj: Any) -> Any: return {"__torch_device__": str(obj)} if isinstance(obj, type): return {"__class__": f"{obj.__module__}.{obj.__name__}"} + if isinstance(obj, torch.nn.Module): + raise TypeError( + f"Cannot serialize torch.nn.Module '{type(obj).__name__}' to safetensors. 
" + "Models with custom decoders should use pickle format: model.save(path, use_safetensors=False)" + ) + if type(obj).__name__ == "Dictionary" and hasattr(obj, "idx2item"): + return { + "__flair_dictionary__": True, + "idx2item": [item.decode("utf-8") for item in obj.idx2item], + "add_unk": obj.add_unk, + "multi_label": getattr(obj, "multi_label", False), + "span_labels": getattr(obj, "span_labels", False), + } if hasattr(obj, "to_dict"): return obj.to_dict() raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") @@ -136,6 +149,16 @@ def _json_deserializer(obj: dict) -> Any: return base64.b64decode(obj["__bytes__"]) if "__bytesio__" in obj: return BytesIO(base64.b64decode(obj["__bytesio__"])) + if "__flair_dictionary__" in obj: + from flair.data import Dictionary + + d = Dictionary(add_unk=False) + for item in obj["idx2item"]: + d.add_item(item) + d.add_unk = obj.get("add_unk", True) + d.multi_label = obj.get("multi_label", False) + d.span_labels = obj.get("span_labels", False) + return d if "__torch_dtype__" in obj: dtype_str = obj["__torch_dtype__"] dtype_map = { From 887c384616bf2e1e8d2bd3f359769a614c3e2ae4 Mon Sep 17 00:00:00 2001 From: antonallote Date: Wed, 4 Feb 2026 20:46:16 +0100 Subject: [PATCH 7/8] unit testds --- tests/test_safetensors_utils.py | 43 +++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/test_safetensors_utils.py diff --git a/tests/test_safetensors_utils.py b/tests/test_safetensors_utils.py new file mode 100644 index 0000000000..1a3319fea7 --- /dev/null +++ b/tests/test_safetensors_utils.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from flair.safetensors_utils import SafetensorsSerializer + + +@pytest.mark.parametrize( + "state_dict", + [ + {"weight": torch.randn(3, 3)}, + {"a": torch.randn(2), "b": {"nested": torch.randn(4)}}, + {"config": "value", "tensor": torch.randn(5)}, + ], +) +def test_save_load_roundtrip(tmp_path, state_dict): + """Verify state dict keys are preserved through save/load cycle.""" + SafetensorsSerializer.save(state_dict, tmp_path / "model") + loaded = SafetensorsSerializer.load(tmp_path / "model") + assert loaded.keys() == state_dict.keys() + + +@pytest.mark.parametrize( + "state_dict", + [ + {"weight": torch.randn(3, 3)}, + {"nested": {"deep": {"tensor": torch.randn(2)}}}, + ], +) +def test_tensors_preserved(tmp_path, state_dict): + SafetensorsSerializer.save(state_dict, tmp_path / "model") + loaded = SafetensorsSerializer.load(tmp_path / "model") + original_tensor = list(state_dict.values())[0] + loaded_tensor = list(loaded.values())[0] + while isinstance(original_tensor, dict): + original_tensor = list(original_tensor.values())[0] + loaded_tensor = list(loaded_tensor.values())[0] + assert torch.allclose(original_tensor, loaded_tensor) + + +def test_is_safetensors_model(tmp_path): + assert not SafetensorsSerializer.is_safetensors_model(tmp_path / "nonexistent") + SafetensorsSerializer.save({"w": torch.randn(2)}, tmp_path / "model") + assert SafetensorsSerializer.is_safetensors_model(tmp_path / "model") From ef0a33276c9982c5643c694001f7597162b0d980 Mon Sep 17 00:00:00 2001 From: antonallote Date: Wed, 4 Feb 2026 20:47:23 +0100 Subject: [PATCH 8/8] docstrings --- tests/test_safetensors_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_safetensors_utils.py b/tests/test_safetensors_utils.py index 1a3319fea7..86b24e9e2c 100644 --- a/tests/test_safetensors_utils.py +++ b/tests/test_safetensors_utils.py @@ -27,6 +27,7 @@ def 
From ef0a33276c9982c5643c694001f7597162b0d980 Mon Sep 17 00:00:00 2001
From: antonallote
Date: Wed, 4 Feb 2026 20:47:23 +0100
Subject: [PATCH 8/8] docstrings

---
 tests/test_safetensors_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/test_safetensors_utils.py b/tests/test_safetensors_utils.py
index 1a3319fea7..86b24e9e2c 100644
--- a/tests/test_safetensors_utils.py
+++ b/tests/test_safetensors_utils.py
@@ -27,6 +27,7 @@ def test_save_load_roundtrip(tmp_path, state_dict):
     ],
 )
 def test_tensors_preserved(tmp_path, state_dict):
+    """Verify tensor values remain numerically identical after serialization."""
     SafetensorsSerializer.save(state_dict, tmp_path / "model")
     loaded = SafetensorsSerializer.load(tmp_path / "model")
     original_tensor = list(state_dict.values())[0]
@@ -38,6 +39,7 @@
 
 
 def test_is_safetensors_model(tmp_path):
+    """Verify safetensors format detection distinguishes valid models from missing paths."""
     assert not SafetensorsSerializer.is_safetensors_model(tmp_path / "nonexistent")
     SafetensorsSerializer.save({"w": torch.randn(2)}, tmp_path / "model")
    assert SafetensorsSerializer.is_safetensors_model(tmp_path / "model")
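
Note: taken together, the series gives a migration path from pickle to safetensors. A closing sketch ("old-tagger.pt" is an illustrative legacy checkpoint path, and SequenceTagger stands in for any concrete flair model class):

    import warnings

    from flair.models import SequenceTagger
    from flair.safetensors_utils import SafetensorsSerializer

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        tagger = SequenceTagger.load("old-tagger.pt")  # pickle load now emits a FutureWarning
    assert any(issubclass(w.category, FutureWarning) for w in caught)

    tagger.save("old-tagger-st")  # one save() converts to model.safetensors + model_metadata.json
    assert SafetensorsSerializer.is_safetensors_model("old-tagger-st")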