1212from collections import defaultdict
1313from datetime import timedelta
1414from functools import lru_cache
15- from typing import TYPE_CHECKING , Any , BinaryIO , NamedTuple
15+ from typing import TYPE_CHECKING , Annotated , Any , BinaryIO , NamedTuple
1616
1717import numpy as np
1818import requests
1919import torch
2020import torch .distributed as dist
2121import zmq
2222from loguru import logger
23- from pydantic import BaseModel , ConfigDict
23+ from pydantic import BaseModel , ConfigDict , PlainSerializer , PlainValidator , WithJsonSchema
2424from safetensors .torch import safe_open
2525from torch .multiprocessing .reductions import reduce_tensor
2626
@@ -38,16 +38,44 @@ class FileMeta(TypedDict):
3838 tp_concat_dim : int
3939
4040
41- class ParameterMeta (BaseModel ):
42- # now all classes are changed to pydantic BaseModel
43- # it will directly report validation errors for unknown types
44- # like torch.dtype, torch.Size, so we need this configuration
45- # see https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.validate_assignment
46- model_config = ConfigDict (arbitrary_types_allowed = True )
41+ def _dt_validate (value : Any ) -> torch .dtype :
42+ if isinstance (value , str ):
43+ assert value .startswith ("torch." ), f"dtype { value } should start with torch."
44+ try :
45+ value = getattr (torch , value .split ("." )[1 ])
46+ except AttributeError as e :
47+ raise ValueError (f"unknown dtype: { value } " ) from e
48+ assert isinstance (value , torch .dtype ), f"dtype { value } should be torch.dtype, got { type (value )} "
49+ return value
50+
51+
52+ _TorchDtype = Annotated [
53+ torch .dtype ,
54+ PlainValidator (_dt_validate ),
55+ PlainSerializer (lambda x : str (x ), return_type = str ),
56+ WithJsonSchema ({"type" : "string" }, mode = "serialization" ),
57+ ]
58+
4759
60+ def _size_validate (value : Any ) -> torch .Size :
61+ if isinstance (value , list | tuple ):
62+ return torch .Size (value )
63+ assert isinstance (value , torch .Size ), f"size { value } should be torch.Size, got { type (value )} "
64+ return value
65+
66+
67+ _TorchSize = Annotated [
68+ torch .Size ,
69+ PlainValidator (_size_validate ),
70+ PlainSerializer (lambda x : tuple (x ), return_type = tuple ),
71+ WithJsonSchema ({"type" : "array" , "items" : {"type" : "integer" }}, mode = "serialization" ),
72+ ]
73+
74+
75+ class ParameterMeta (BaseModel ):
4876 name : str
49- dtype : torch . dtype
50- shape : torch . Size
77+ dtype : _TorchDtype
78+ shape : _TorchSize
5179
5280
5381class BucketRange (NamedTuple ):
0 commit comments