Skip to content

Commit 5a8bf3f

Browse files
mmhabtjruwasemrwyattii
authored
Implement some APIs of HPU accelerator (#4935)
Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Michael Wyatt <[email protected]>
1 parent 7d51139 commit 5a8bf3f

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

accelerator/hpu_accelerator.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -108,28 +108,28 @@ def reset_max_memory_allocated(self, device_index=None):
108108
return self.hpu.reset_max_memory_allocated()
109109

110110
def memory_cached(self, device_index=None):
111-
return 0
111+
return self.hpu.memory_cached(device_index)
112112

113113
def max_memory_cached(self, device_index=None):
114-
return 0
114+
return self.hpu.max_memory_cached(device_index)
115115

116116
def reset_max_memory_cached(self, device_index=None):
117-
return 0
117+
return None
118118

119119
def memory_stats(self, device_index=None):
120-
return {}
120+
return self.hpu.memory_stats(device_index)
121121

122122
def reset_peak_memory_stats(self, device_index=None):
123-
self.hpu.reset_peak_memory_stats()
123+
self.hpu.reset_peak_memory_stats(device_index)
124124

125125
def memory_reserved(self, device_index=None):
126-
return 0
126+
return self.hpu.memory_reserved(device_index)
127127

128128
def max_memory_reserved(self, device_index=None):
129-
return 0
129+
return self.hpu.max_memory_reserved(device_index)
130130

131131
def total_memory(self, device_index=None):
132-
return 0
132+
return self.memory_stats(device_index)['Limit']
133133

134134
def available_memory(self, device_index=None):
135135
return self.total_memory(device_index) - self.memory_allocated(device_index)
@@ -186,31 +186,31 @@ def replay_graph(self, graph):
186186
# Tensor operations
187187
@property
188188
def BFloat16Tensor(self):
189-
return torch.hpu.BFloat16Tensor
189+
return self.hpu.BFloat16Tensor
190190

191191
@property
192192
def ByteTensor(self):
193-
return torch.hpu.ByteTensor
193+
return self.hpu.ByteTensor
194194

195195
@property
196196
def DoubleTensor(self):
197-
return torch.hpu.DoubleTensor
197+
return self.hpu.DoubleTensor
198198

199199
@property
200200
def FloatTensor(self):
201-
return torch.hpu.FloatTensor
201+
return self.hpu.FloatTensor
202202

203203
@property
204204
def HalfTensor(self):
205-
return torch.hpu.HalfTensor
205+
return self.hpu.HalfTensor
206206

207207
@property
208208
def IntTensor(self):
209-
return torch.hpu.IntTensor
209+
return self.hpu.IntTensor
210210

211211
@property
212212
def LongTensor(self):
213-
return torch.hpu.LongTensor
213+
return self.hpu.LongTensor
214214

215215
def pin_memory(self, tensor, align_bytes=1):
216216
return tensor.pin_memory(self.device())

0 commit comments

Comments
 (0)