Skip to content

Commit ffcf384

Browse files
delocktjruwase
andauthored
Abstract accelerator (step 1) (#2504)
* Establish building block of abstract accelerator * Change .*Tensor variable to @Property * [op builder] add op builder reflection to allow enumerate of builders in all_ops.py and builder_names.py * change @abstractproperty to @Property @AbstractMethod Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent c5f8585 commit ffcf384

File tree

7 files changed

+605
-0
lines changed

7 files changed

+605
-0
lines changed

deepspeed/accelerator/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .abstract_accelerator import DeepSpeedAccelerator
2+
from .real_accelerator import get_accelerator, set_accelerator
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
import abc
2+
from abc import ABC
3+
4+
5+
class DeepSpeedAccelerator(ABC):
6+
def __init__(self):
7+
self._name = None
8+
self._communication_backend_name = None
9+
10+
# Device APIs
11+
@abc.abstractmethod
12+
def device_name(self, device_index):
13+
...
14+
15+
@abc.abstractmethod
16+
def device(self, device_index):
17+
...
18+
19+
@abc.abstractmethod
20+
def set_device(self, device_index):
21+
...
22+
23+
@abc.abstractmethod
24+
def current_device(self):
25+
...
26+
27+
@abc.abstractmethod
28+
def current_device_name(self):
29+
...
30+
31+
@abc.abstractmethod
32+
def device_count(self):
33+
...
34+
35+
@abc.abstractmethod
36+
def synchronize(self, device_index=None):
37+
...
38+
39+
# RNG APIs
40+
@abc.abstractmethod
41+
def random(self):
42+
...
43+
44+
@abc.abstractmethod
45+
def set_rng_state(self, new_state, device_index=None):
46+
...
47+
48+
@abc.abstractmethod
49+
def get_rng_state(self, device_index=None):
50+
...
51+
52+
@abc.abstractmethod
53+
def manual_seed(self, seed):
54+
...
55+
56+
@abc.abstractmethod
57+
def manual_seed_all(self, seed):
58+
...
59+
60+
@abc.abstractmethod
61+
def initial_seed(self, seed):
62+
...
63+
64+
@abc.abstractmethod
65+
def default_generator(self, device_index):
66+
...
67+
68+
# Streams/Events
69+
@abc.abstractmethod
70+
def Stream(self, device=None, priority=0, **kwargs):
71+
...
72+
73+
@abc.abstractmethod
74+
def StreamContext(self, stream):
75+
...
76+
77+
@abc.abstractmethod
78+
def stream(self, stream):
79+
...
80+
81+
@abc.abstractmethod
82+
def current_stream(self, device_index=None):
83+
...
84+
85+
@abc.abstractmethod
86+
def default_stream(self, device_index=None):
87+
...
88+
89+
@abc.abstractmethod
90+
def Event(self, **kwargs):
91+
...
92+
93+
# Memory management
94+
@abc.abstractmethod
95+
def empty_cache(self):
96+
...
97+
98+
@abc.abstractmethod
99+
def memory_allocated(self, device_index=None):
100+
...
101+
102+
@abc.abstractmethod
103+
def max_memory_allocated(self, device_index=None):
104+
...
105+
106+
@abc.abstractmethod
107+
def reset_max_memory_allocated(self, device_index=None):
108+
...
109+
110+
@abc.abstractmethod
111+
def memory_cached(self, device_index=None):
112+
...
113+
114+
@abc.abstractmethod
115+
def max_memory_cached(self, device_index=None):
116+
...
117+
118+
@abc.abstractmethod
119+
def reset_max_memory_cached(self, device_index=None):
120+
...
121+
122+
@abc.abstractmethod
123+
def memory_stats(self, device_index=None):
124+
...
125+
126+
@abc.abstractmethod
127+
def reset_peak_memory_stats(self, device_index=None):
128+
...
129+
130+
@abc.abstractmethod
131+
def memory_reserved(self, device_index=None):
132+
...
133+
134+
@abc.abstractmethod
135+
def max_memory_reserved(self, device_index=None):
136+
...
137+
138+
@abc.abstractmethod
139+
def total_memory(self, device_index=None):
140+
...
141+
142+
# Data types
143+
@abc.abstractmethod
144+
def is_bf16_supported(self):
145+
...
146+
147+
@abc.abstractmethod
148+
def is_fp16_supported(self):
149+
...
150+
151+
# Misc
152+
@abc.abstractmethod
153+
def amp(self):
154+
...
155+
156+
@abc.abstractmethod
157+
def is_available(self):
158+
...
159+
160+
@abc.abstractmethod
161+
def range_push(self, msg):
162+
...
163+
164+
@abc.abstractmethod
165+
def range_pop(self):
166+
...
167+
168+
@abc.abstractmethod
169+
def lazy_call(self, callback):
170+
...
171+
172+
@abc.abstractmethod
173+
def communication_backend_name(self):
174+
...
175+
176+
# Tensor operations
177+
@property
178+
@abc.abstractmethod
179+
def BFloat16Tensor(self):
180+
...
181+
182+
@property
183+
@abc.abstractmethod
184+
def ByteTensor(self):
185+
...
186+
187+
@property
188+
@abc.abstractmethod
189+
def DoubleTensor(self):
190+
...
191+
192+
@property
193+
@abc.abstractmethod
194+
def FloatTensor(self):
195+
...
196+
197+
@property
198+
@abc.abstractmethod
199+
def HalfTensor(self):
200+
...
201+
202+
@property
203+
@abc.abstractmethod
204+
def IntTensor(self):
205+
...
206+
207+
@property
208+
@abc.abstractmethod
209+
def LongTensor(self):
210+
...
211+
212+
@abc.abstractmethod
213+
def pin_memory(self, tensor):
214+
...
215+
216+
@abc.abstractmethod
217+
def on_accelerator(self, tensor):
218+
...
219+
220+
@abc.abstractmethod
221+
def op_builder_dir(self):
222+
...
223+
224+
@abc.abstractmethod
225+
def create_op_builder(self, class_name):
226+
...
227+
228+
@abc.abstractmethod
229+
def build_extension(self):
230+
...

0 commit comments

Comments
 (0)