Skip to content

Commit b7d7b83

Browse files
committed
feat: make ParameterMeta JSON serializable
1 parent 6062db9 commit b7d7b83

File tree

2 files changed

+40
-11
lines changed

2 files changed

+40
-11
lines changed

checkpoint_engine/ps.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
from collections import defaultdict
1313
from datetime import timedelta
1414
from functools import lru_cache
15-
from typing import TYPE_CHECKING, Any, BinaryIO, NamedTuple
15+
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
1616

1717
import numpy as np
1818
import requests
1919
import torch
2020
import torch.distributed as dist
2121
import zmq
2222
from loguru import logger
23-
from pydantic import BaseModel, ConfigDict
23+
from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator, WithJsonSchema
2424
from safetensors.torch import safe_open
2525
from torch.multiprocessing.reductions import reduce_tensor
2626

@@ -38,16 +38,44 @@ class FileMeta(TypedDict):
3838
tp_concat_dim: int
3939

4040

41-
class ParameterMeta(BaseModel):
42-
# now all classes are changed to pydantic BaseModel
43-
# it will directly report validation errors for unknown types
44-
# like torch.dtype, torch.Size, so we need this configuration
45-
# see https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.validate_assignment
46-
model_config = ConfigDict(arbitrary_types_allowed=True)
41+
def _dt_validate(value: Any) -> torch.dtype:
42+
if isinstance(value, str):
43+
assert value.startswith("torch."), f"dtype {value} should start with torch."
44+
try:
45+
value = getattr(torch, value.split(".")[1])
46+
except AttributeError as e:
47+
raise ValueError(f"unknown dtype: {value}") from e
48+
assert isinstance(value, torch.dtype), f"dtype {value} should be torch.dtype, got {type(value)}"
49+
return value
50+
51+
52+
_TorchDtype = Annotated[
53+
torch.dtype,
54+
PlainValidator(_dt_validate),
55+
PlainSerializer(lambda x: str(x), return_type=str),
56+
WithJsonSchema({"type": "string"}, mode="serialization"),
57+
]
58+
4759

60+
def _size_validate(value: Any) -> torch.Size:
61+
if isinstance(value, list | tuple):
62+
return torch.Size(value)
63+
assert isinstance(value, torch.Size), f"size {value} should be torch.Size, got {type(value)}"
64+
return value
65+
66+
67+
_TorchSize = Annotated[
68+
torch.Size,
69+
PlainValidator(_size_validate),
70+
PlainSerializer(lambda x: tuple(x), return_type=tuple),
71+
WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
72+
]
73+
74+
75+
class ParameterMeta(BaseModel):
4876
name: str
49-
dtype: torch.dtype
50-
shape: torch.Size
77+
dtype: _TorchDtype
78+
shape: _TorchSize
5179

5280

5381
class BucketRange(NamedTuple):

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ requires-python = ">=3.10"
88
dependencies = [
99
"torch>=2.5.0",
1010
"fastapi",
11-
"pydantic",
11+
"pydantic>=2.0.0",
1212
"safetensors",
1313
"pyzmq",
1414
"uvicorn",
1515
"loguru",
1616
"numpy",
17+
"requests",
1718
]
1819

1920
[project.optional-dependencies]

0 commit comments

Comments
 (0)