Skip to content

Commit efb8907

Browse files
feat: make ParameterMeta JSON serializable (#9)
1 parent 6062db9 commit efb8907

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

checkpoint_engine/ps.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
from collections import defaultdict
1313
from datetime import timedelta
1414
from functools import lru_cache
15-
from typing import TYPE_CHECKING, Any, BinaryIO, NamedTuple
15+
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
1616

1717
import numpy as np
1818
import requests
1919
import torch
2020
import torch.distributed as dist
2121
import zmq
2222
from loguru import logger
23-
from pydantic import BaseModel, ConfigDict
23+
from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator, WithJsonSchema
2424
from safetensors.torch import safe_open
2525
from torch.multiprocessing.reductions import reduce_tensor
2626

@@ -38,16 +38,47 @@ class FileMeta(TypedDict):
3838
tp_concat_dim: int
3939

4040

41-
class ParameterMeta(BaseModel):
42-
# now all classes are changed to pydantic BaseModel
43-
# it will directly report validation errors for unknown types
44-
# like torch.dtype, torch.Size, so we need this configuration
45-
# see https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.validate_assignment
46-
model_config = ConfigDict(arbitrary_types_allowed=True)
41+
def _dt_validate(value: Any) -> torch.dtype:
42+
if isinstance(value, str):
43+
if not value.startswith("torch."):
44+
raise ValueError(f"dtype {value} should start with torch.")
45+
try:
46+
value = getattr(torch, value.split(".")[1])
47+
except AttributeError as e:
48+
raise ValueError(f"unknown dtype: {value}") from e
49+
if not isinstance(value, torch.dtype):
50+
raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
51+
return value
52+
53+
54+
def _dtype_to_str(dtype: torch.dtype) -> str:
    # Serialize a dtype as its string form, e.g. "torch.float32".
    return str(dtype)


# Field type for torch dtypes: validated through _dt_validate, serialized as a
# string, and described as a string in the serialization-mode JSON schema.
_TorchDtype = Annotated[
    torch.dtype,
    PlainValidator(_dt_validate),
    PlainSerializer(_dtype_to_str, return_type=str),
    WithJsonSchema({"type": "string"}, mode="serialization"),
]
60+
4761

62+
def _size_validate(value: Any) -> torch.Size:
63+
if isinstance(value, list | tuple):
64+
return torch.Size(value)
65+
if not isinstance(value, torch.Size):
66+
raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
67+
return value
68+
69+
70+
def _size_to_tuple(size: torch.Size) -> tuple:
    # Serialize a size as a plain tuple of its dimensions.
    return tuple(size)


# Field type for tensor shapes: validated through _size_validate, serialized
# as a tuple, and described as an array of integers in the serialization-mode
# JSON schema.
_TorchSize = Annotated[
    torch.Size,
    PlainValidator(_size_validate),
    PlainSerializer(_size_to_tuple, return_type=tuple),
    WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
]
76+
77+
78+
class ParameterMeta(BaseModel):
    # Metadata for a single parameter: its name, dtype and shape.
    # `dtype` and `shape` use the annotated aliases rather than the raw torch
    # types, so validation and serialization go through the custom
    # validator/serializer attached to each alias.
    name: str
    dtype: _TorchDtype
    shape: _TorchSize
5182

5283

5384
class BucketRange(NamedTuple):

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ requires-python = ">=3.10"
88
dependencies = [
99
"torch>=2.5.0",
1010
"fastapi",
11-
"pydantic",
11+
"pydantic>=2.0.0",
1212
"safetensors",
1313
"pyzmq",
1414
"uvicorn",
1515
"loguru",
1616
"numpy",
17+
"requests",
1718
]
1819

1920
[project.optional-dependencies]

0 commit comments

Comments
 (0)