Skip to content

Commit e82869c

Browse files
committed
feat: change assert to raise
1 parent b7d7b83 commit e82869c

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

checkpoint_engine/ps.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,14 @@ class FileMeta(TypedDict):
4040

4141
def _dt_validate(value: Any) -> torch.dtype:
4242
if isinstance(value, str):
43-
assert value.startswith("torch."), f"dtype {value} should start with torch."
43+
if not value.startswith("torch."):
44+
raise ValueError(f"dtype {value} should start with torch.")
4445
try:
4546
value = getattr(torch, value.split(".")[1])
4647
except AttributeError as e:
4748
raise ValueError(f"unknown dtype: {value}") from e
48-
assert isinstance(value, torch.dtype), f"dtype {value} should be torch.dtype, got {type(value)}"
49+
if not isinstance(value, torch.dtype):
50+
raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
4951
return value
5052

5153

@@ -60,7 +62,8 @@ def _dt_validate(value: Any) -> torch.dtype:
6062
def _size_validate(value: Any) -> torch.Size:
6163
if isinstance(value, list | tuple):
6264
return torch.Size(value)
63-
assert isinstance(value, torch.Size), f"size {value} should be torch.Size, got {type(value)}"
65+
if not isinstance(value, torch.Size):
66+
raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
6467
return value
6568

6669

0 commit comments

Comments
 (0)