36 changes: 29 additions & 7 deletions flair/nn/model.py
@@ -2,6 +2,7 @@
import itertools
import logging
import typing
import warnings
from abc import ABC, abstractmethod
from collections import Counter
from pathlib import Path
@@ -22,9 +23,23 @@
from flair.embeddings import Embeddings
from flair.embeddings.base import load_embeddings
from flair.file_utils import Tqdm, load_torch_state
from flair.safetensors_utils import SafetensorsSerializer
from flair.training_utils import EmbeddingStorageMode, Result, store_embeddings
import importlib


def _load_state(model_path: Union[str, Path]) -> dict[str, Any]:
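    """Load a model state dict from disk, auto-detecting safetensors vs. legacy pickle format."""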
path = Path(model_path)
if SafetensorsSerializer.is_safetensors_model(path):
return SafetensorsSerializer.load(path)
warnings.warn(
"Loading model from pickle format. Pickle is deprecated due to security concerns. "
"Consider re-saving the model with model.save() to convert to safetensors format.",
FutureWarning,
stacklevel=3,
)
return load_torch_state(str(path))

log = logging.getLogger("flair")


@@ -273,21 +288,28 @@ def _fetch_model(model_identifier: str):
"""
return model_identifier

def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
def save(
self,
model_file: Union[str, Path],
checkpoint: bool = False,
use_safetensors: bool = True,
) -> None:
"""Saves the current model to the provided file.

Args:
model_file: The model file.
model_file: The model file path. For safetensors format, this will be a directory.
checkpoint: This parameter is currently unused.
use_safetensors: If True (default), save using safetensors format. If False, use pickle.
"""
model_state = self._get_state_dict()

# write out a "model card" if one is set
if self.model_card is not None:
model_state["model_card"] = self.model_card

# save model
torch.save(model_state, str(model_file), pickle_protocol=4)
if use_safetensors:
SafetensorsSerializer.save(model_state, model_file)
else:
torch.save(model_state, str(model_file), pickle_protocol=4)

@property
def license_info(self) -> str:
@@ -337,7 +359,7 @@ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model":

# if the model cannot be fetched, load as a file
try:
state = model_path if isinstance(model_path, dict) else load_torch_state(str(model_path))
state = model_path if isinstance(model_path, dict) else _load_state(model_path)
except Exception:
log.error("-" * 80)
log.error(
@@ -371,7 +393,7 @@ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model":
# if this class is not abstract, fetch the model and load it
if not isinstance(model_path, dict):
model_file = cls._fetch_model(str(model_path))
state = load_torch_state(model_file)
state = _load_state(model_file)
else:
state = model_path

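A minimal usage sketch of the new save/load behavior, assuming a `SequenceTagger` (any flair `Model` subclass behaves the same way); the model identifier and paths are illustrative:

```python
from flair.models import SequenceTagger

tagger = SequenceTagger.load("flair/ner-english")  # illustrative model id

# default: writes a directory containing model.safetensors + model_metadata.json
tagger.save("my-tagger")

# opt out to keep the legacy single-file pickle format
tagger.save("my-tagger.pt", use_safetensors=False)

# load() auto-detects the format; pickle files now emit a FutureWarning
reloaded = SequenceTagger.load("my-tagger")
```
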
221 changes: 221 additions & 0 deletions flair/safetensors_utils.py
@@ -0,0 +1,221 @@
import base64
import json
from io import BytesIO
from pathlib import Path
from typing import Any, Union

import torch
from safetensors.torch import load_file as safetensors_load_file
from safetensors.torch import save_file as safetensors_save_file


def _is_tensor(value: Any) -> bool:
return isinstance(value, torch.Tensor)


def _flatten_dict(
d: dict[str, Any], parent_key: str = "", sep: str = "."
) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
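    """Recursively pull tensors out of a nested dict, leaving {"__tensor_key__": ...} placeholders in the metadata."""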
tensors: dict[str, torch.Tensor] = {}
metadata: dict[str, Any] = {}

for key, value in d.items():
new_key = f"{parent_key}{sep}{key}" if parent_key else key

if _is_tensor(value):
tensors[new_key] = value
metadata[key] = {"__tensor_key__": new_key}
elif isinstance(value, dict):
nested_tensors, nested_metadata = _flatten_dict(value, new_key, sep)
tensors.update(nested_tensors)
metadata[key] = nested_metadata
elif isinstance(value, list):
processed_list, list_tensors = _process_list(value, new_key, sep)
tensors.update(list_tensors)
metadata[key] = processed_list
else:
metadata[key] = value

return tensors, metadata


def _process_list(
lst: list, parent_key: str, sep: str
) -> tuple[list, dict[str, torch.Tensor]]:
tensors: dict[str, torch.Tensor] = {}
processed: list = []

for i, item in enumerate(lst):
item_key = f"{parent_key}{sep}{i}"

if _is_tensor(item):
tensors[item_key] = item
processed.append({"__tensor_key__": item_key})
elif isinstance(item, dict):
nested_tensors, nested_metadata = _flatten_dict(item, item_key, sep)
tensors.update(nested_tensors)
processed.append(nested_metadata)
elif isinstance(item, list):
nested_list, nested_tensors = _process_list(item, item_key, sep)
tensors.update(nested_tensors)
processed.append(nested_list)
else:
processed.append(item)

return processed, tensors


def _unflatten_dict(
metadata: dict[str, Any], tensors: dict[str, torch.Tensor]
) -> dict[str, Any]:
result: dict[str, Any] = {}

for key, value in metadata.items():
if isinstance(value, dict):
if "__tensor_key__" in value:
tensor_key = value["__tensor_key__"]
result[key] = tensors[tensor_key]
else:
result[key] = _unflatten_dict(value, tensors)
elif isinstance(value, list):
result[key] = _unflatten_list(value, tensors)
else:
result[key] = value

return result


def _unflatten_list(lst: list, tensors: dict[str, torch.Tensor]) -> list:
result: list = []

for item in lst:
if isinstance(item, dict):
if "__tensor_key__" in item:
tensor_key = item["__tensor_key__"]
result.append(tensors[tensor_key])
else:
result.append(_unflatten_dict(item, tensors))
elif isinstance(item, list):
result.append(_unflatten_list(item, tensors))
else:
result.append(item)

return result


def separate_tensors_and_metadata(
state_dict: dict[str, Any],
) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
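    """Split a (possibly nested) state dict into a flat tensor dict and JSON-serializable metadata."""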
return _flatten_dict(state_dict)


def combine_tensors_and_metadata(
tensors: dict[str, torch.Tensor], metadata: dict[str, Any]
) -> dict[str, Any]:
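    """Rebuild the original nested state dict from flat tensors and their metadata skeleton."""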
return _unflatten_dict(metadata, tensors)


def _json_serializer(obj: Any) -> Any:
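    """json.dump fallback: encode bytes, BytesIO, torch dtypes/devices, classes, and flair Dictionary as tagged dicts."""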
if isinstance(obj, bytes):
return {"__bytes__": base64.b64encode(obj).decode("ascii")}
if isinstance(obj, BytesIO):
return {"__bytesio__": base64.b64encode(obj.getvalue()).decode("ascii")}
if isinstance(obj, torch.dtype):
return {"__torch_dtype__": str(obj)}
if isinstance(obj, torch.device):
return {"__torch_device__": str(obj)}
if isinstance(obj, type):
return {"__class__": f"{obj.__module__}.{obj.__name__}"}
if isinstance(obj, torch.nn.Module):
raise TypeError(
f"Cannot serialize torch.nn.Module '{type(obj).__name__}' to safetensors. "
"Models with custom decoders should use pickle format: model.save(path, use_safetensors=False)"
)
if type(obj).__name__ == "Dictionary" and hasattr(obj, "idx2item"):
return {
"__flair_dictionary__": True,
"idx2item": [item.decode("utf-8") for item in obj.idx2item],
"add_unk": obj.add_unk,
"multi_label": getattr(obj, "multi_label", False),
"span_labels": getattr(obj, "span_labels", False),
}
if hasattr(obj, "to_dict"):
return obj.to_dict()
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _json_deserializer(obj: dict) -> Any:
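    """json.load object hook: reverse the tagged-dict encodings produced by _json_serializer."""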
if "__bytes__" in obj:
return base64.b64decode(obj["__bytes__"])
if "__bytesio__" in obj:
return BytesIO(base64.b64decode(obj["__bytesio__"]))
if "__flair_dictionary__" in obj:
from flair.data import Dictionary

d = Dictionary(add_unk=False)
for item in obj["idx2item"]:
d.add_item(item)
d.add_unk = obj.get("add_unk", True)
d.multi_label = obj.get("multi_label", False)
d.span_labels = obj.get("span_labels", False)
return d
if "__torch_dtype__" in obj:
dtype_str = obj["__torch_dtype__"]
dtype_map = {
"torch.float32": torch.float32,
"torch.float64": torch.float64,
"torch.float16": torch.float16,
"torch.bfloat16": torch.bfloat16,
"torch.int32": torch.int32,
"torch.int64": torch.int64,
"torch.int16": torch.int16,
"torch.int8": torch.int8,
"torch.uint8": torch.uint8,
"torch.bool": torch.bool,
}
return dtype_map.get(dtype_str, torch.float32)
if "__torch_device__" in obj:
return torch.device(obj["__torch_device__"])
return obj


class SafetensorsSerializer:
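    """Persists a flair state dict as a directory: tensors in a safetensors file, everything else in a JSON sidecar."""
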
TENSORS_FILENAME = "model.safetensors"
METADATA_FILENAME = "model_metadata.json"

@classmethod
def save(cls, state_dict: dict[str, Any], path: Union[str, Path]) -> None:
path = Path(path)
path.mkdir(parents=True, exist_ok=True)

tensors, metadata = separate_tensors_and_metadata(state_dict)

tensors_path = path / cls.TENSORS_FILENAME
        # safetensors writes a valid file even for an empty tensor dict
        safetensors_save_file(tensors, str(tensors_path))

metadata_path = path / cls.METADATA_FILENAME
with open(metadata_path, "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2, default=_json_serializer)

@classmethod
def load(cls, path: Union[str, Path]) -> dict[str, Any]:
path = Path(path)

tensors_path = path / cls.TENSORS_FILENAME
tensors = safetensors_load_file(str(tensors_path))

metadata_path = path / cls.METADATA_FILENAME
with open(metadata_path, encoding="utf-8") as f:
metadata = json.load(f, object_hook=_json_deserializer)

return combine_tensors_and_metadata(tensors, metadata)

@classmethod
def is_safetensors_model(cls, path: Union[str, Path]) -> bool:
path = Path(path)
if path.is_dir():
return (path / cls.TENSORS_FILENAME).exists() and (path / cls.METADATA_FILENAME).exists()
return False
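
To make the `__tensor_key__` scheme concrete, here is a small round-trip sketch using the module's public helpers (the state dict is illustrative):

```python
import torch

from flair.safetensors_utils import (
    combine_tensors_and_metadata,
    separate_tensors_and_metadata,
)

state = {"embedding": {"weight": torch.zeros(2, 3)}, "tag_type": "ner"}
tensors, metadata = separate_tensors_and_metadata(state)
# tensors  == {"embedding.weight": <2x3 tensor>}
# metadata == {"embedding": {"weight": {"__tensor_key__": "embedding.weight"}},
#              "tag_type": "ner"}

restored = combine_tensors_and_metadata(tensors, metadata)
assert torch.equal(restored["embedding"]["weight"], state["embedding"]["weight"])
assert restored["tag_type"] == "ner"
```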
1 change: 1 addition & 0 deletions requirements.txt
@@ -13,6 +13,7 @@ pptree>=3.1
python-dateutil>=2.8.2
pytorch_revgrad>=0.2.0
regex>=2022.1.18
safetensors>=0.4.0
scikit-learn>=1.0.2
segtok>=1.5.11
sqlitedict>=2.0.0
45 changes: 45 additions & 0 deletions tests/test_safetensors_utils.py
@@ -0,0 +1,45 @@
import pytest
import torch

from flair.safetensors_utils import SafetensorsSerializer


@pytest.mark.parametrize(
"state_dict",
[
{"weight": torch.randn(3, 3)},
{"a": torch.randn(2), "b": {"nested": torch.randn(4)}},
{"config": "value", "tensor": torch.randn(5)},
],
)
def test_save_load_roundtrip(tmp_path, state_dict):
"""Verify state dict keys are preserved through save/load cycle."""
SafetensorsSerializer.save(state_dict, tmp_path / "model")
loaded = SafetensorsSerializer.load(tmp_path / "model")
assert loaded.keys() == state_dict.keys()


@pytest.mark.parametrize(
"state_dict",
[
{"weight": torch.randn(3, 3)},
{"nested": {"deep": {"tensor": torch.randn(2)}}},
],
)
def test_tensors_preserved(tmp_path, state_dict):
"""Verify tensor values remain numerically identical after serialization."""
SafetensorsSerializer.save(state_dict, tmp_path / "model")
loaded = SafetensorsSerializer.load(tmp_path / "model")
original_tensor = list(state_dict.values())[0]
loaded_tensor = list(loaded.values())[0]
while isinstance(original_tensor, dict):
original_tensor = list(original_tensor.values())[0]
loaded_tensor = list(loaded_tensor.values())[0]
assert torch.allclose(original_tensor, loaded_tensor)


def test_is_safetensors_model(tmp_path):
"""Verify safetensors format detection distinguishes valid models from missing paths."""
assert not SafetensorsSerializer.is_safetensors_model(tmp_path / "nonexistent")
SafetensorsSerializer.save({"w": torch.randn(2)}, tmp_path / "model")
assert SafetensorsSerializer.is_safetensors_model(tmp_path / "model")
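
A possible extension to this suite, sketched here rather than part of the PR: a round-trip check for the non-tensor metadata types handled by the JSON hooks (bytes and torch.dtype); the test name is a suggestion:

```python
def test_metadata_types_roundtrip(tmp_path):
    """Sketch: verify bytes and torch.dtype survive the JSON metadata encoding."""
    state_dict = {"raw": b"\x00\x01binary", "dtype": torch.float16, "tensor": torch.randn(2)}
    SafetensorsSerializer.save(state_dict, tmp_path / "model")
    loaded = SafetensorsSerializer.load(tmp_path / "model")
    assert loaded["raw"] == state_dict["raw"]
    assert loaded["dtype"] is torch.float16
    assert torch.allclose(loaded["tensor"], state_dict["tensor"])
```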