4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ import unittest
7
8
import urllib .error
8
9
from unittest import TestCase
9
10
from unittest .mock import MagicMock
10
-
11
- from torchft .checkpointing import CheckpointServer
11
+ from io import BytesIO
12
+ import torch
13
+ from typing import Tuple
14
+ from checkpointing import CheckpointServer , TensorMetadata , write_state_dict , read_state_dict
12
15
13
16
14
17
class TestCheckpointing (TestCase ):
@@ -33,3 +36,67 @@ def test_checkpoint_server(self) -> None:
33
36
CheckpointServer .load_from_address (addr )
34
37
35
38
server .shutdown ()
39
+
40
+ def setUp (self ):
41
+ self .file = BytesIO ()
42
+
43
+ def test_scalar_tensor (self ):
44
+ tensor = torch .tensor (42 , dtype = torch .int32 )
45
+ state_dict = {'scalar' : tensor }
46
+ write_state_dict (state_dict , self .file )
47
+ self .file .seek (0 )
48
+
49
+ result = read_state_dict (self .file )
50
+ self .assertTrue (torch .equal (result ['scalar' ], tensor ))
51
+
52
+ def test_strided_tensor (self ):
53
+ base_tensor = torch .arange (16 , dtype = torch .float32 ).reshape (4 , 4 )
54
+ strided_tensor = base_tensor [::2 , ::2 ]
55
+ state_dict = {'strided' : strided_tensor }
56
+ write_state_dict (state_dict , self .file )
57
+ self .file .seek (0 )
58
+
59
+ result = read_state_dict (self .file )
60
+ self .assertTrue (torch .equal (result ['strided' ], strided_tensor ))
61
+
62
+ def test_tensor_with_offset (self ):
63
+ base_tensor = torch .arange (10 , dtype = torch .float64 )
64
+ offset_tensor = base_tensor [2 :]
65
+ state_dict = {'offset' : offset_tensor }
66
+ write_state_dict (state_dict , self .file )
67
+ self .file .seek (0 )
68
+
69
+ result = read_state_dict (self .file )
70
+ self .assertTrue (torch .equal (result ['offset' ], offset_tensor ))
71
+
72
+ def test_nested_tensors (self ):
73
+ tensor1 = torch .tensor ([1 , 2 , 3 ], dtype = torch .int32 )
74
+ tensor2 = torch .tensor ([[1.5 , 2.5 ], [3.5 , 4.5 ]], dtype = torch .float64 )
75
+ state_dict = {'nested' : {'tensor1' : tensor1 , 'tensor2' : tensor2 }}
76
+ write_state_dict (state_dict , self .file )
77
+ self .file .seek (0 )
78
+
79
+ result = read_state_dict (self .file )
80
+ self .assertTrue (torch .equal (result ['nested' ]['tensor1' ], tensor1 ))
81
+ self .assertTrue (torch .equal (result ['nested' ]['tensor2' ], tensor2 ))
82
+
83
+ def test_various_data_types (self ):
84
+ tensor_float32 = torch .tensor ([1.0 , 2.0 , 3.0 ], dtype = torch .float32 )
85
+ tensor_int16 = torch .tensor ([1 , 2 , 3 ], dtype = torch .int16 )
86
+ tensor_bool = torch .tensor ([True , False , True ], dtype = torch .bool )
87
+ state_dict = {
88
+ 'float32' : tensor_float32 ,
89
+ 'int16' : tensor_int16 ,
90
+ 'bool' : tensor_bool ,
91
+ }
92
+ write_state_dict (state_dict , self .file )
93
+ self .file .seek (0 )
94
+
95
+ result = read_state_dict (self .file )
96
+ self .assertTrue (torch .equal (result ['float32' ], tensor_float32 ))
97
+ self .assertTrue (torch .equal (result ['int16' ], tensor_int16 ))
98
+ self .assertTrue (torch .equal (result ['bool' ], tensor_bool ))
99
+
100
+
101
+ if __name__ == '__main__' :
102
+ unittest .main ()
0 commit comments