1212from collections import defaultdict
1313from datetime import timedelta
1414from functools import lru_cache
15- from typing import TYPE_CHECKING , Any , BinaryIO , NamedTuple
15+ from typing import TYPE_CHECKING , Annotated , Any , BinaryIO , NamedTuple
1616
1717import numpy as np
1818import requests
1919import torch
2020import torch .distributed as dist
2121import zmq
2222from loguru import logger
23- from pydantic import BaseModel , ConfigDict
23+ from pydantic import BaseModel , ConfigDict , PlainSerializer , PlainValidator , WithJsonSchema
2424from safetensors .torch import safe_open
2525from torch .multiprocessing .reductions import reduce_tensor
2626
@@ -38,16 +38,47 @@ class FileMeta(TypedDict):
3838 tp_concat_dim : int
3939
4040
41- class ParameterMeta (BaseModel ):
42- # now all classes are changed to pydantic BaseModel
43- # it will directly report validation errors for unknown types
44- # like torch.dtype, torch.Size, so we need this configuration
45- # see https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.validate_assignment
46- model_config = ConfigDict (arbitrary_types_allowed = True )
41+ def _dt_validate (value : Any ) -> torch .dtype :
42+ if isinstance (value , str ):
43+ if not value .startswith ("torch." ):
44+ raise ValueError (f"dtype { value } should start with torch." )
45+ try :
46+ value = getattr (torch , value .split ("." )[1 ])
47+ except AttributeError as e :
48+ raise ValueError (f"unknown dtype: { value } " ) from e
49+ if not isinstance (value , torch .dtype ):
50+ raise TypeError (f"dtype { value } should be torch.dtype, got { type (value )} " )
51+ return value
52+
53+
54+ _TorchDtype = Annotated [
55+ torch .dtype ,
56+ PlainValidator (_dt_validate ),
57+ PlainSerializer (lambda x : str (x ), return_type = str ),
58+ WithJsonSchema ({"type" : "string" }, mode = "serialization" ),
59+ ]
60+
4761
62+ def _size_validate (value : Any ) -> torch .Size :
63+ if isinstance (value , list | tuple ):
64+ return torch .Size (value )
65+ if not isinstance (value , torch .Size ):
66+ raise TypeError (f"size { value } should be torch.Size, got { type (value )} " )
67+ return value
68+
69+
70+ _TorchSize = Annotated [
71+ torch .Size ,
72+ PlainValidator (_size_validate ),
73+ PlainSerializer (lambda x : tuple (x ), return_type = tuple ),
74+ WithJsonSchema ({"type" : "array" , "items" : {"type" : "integer" }}, mode = "serialization" ),
75+ ]
76+
77+
78+ class ParameterMeta (BaseModel ):
4879 name : str
49- dtype : torch . dtype
50- shape : torch . Size
80+ dtype : _TorchDtype
81+ shape : _TorchSize
5182
5283
5384class BucketRange (NamedTuple ):
0 commit comments