Skip to content

Commit 34c1e29

Browse files
authored
enable autoround cases on XPU (#38167)
* enable autoround cases on XPU Signed-off-by: Matrix Yao <[email protected]> * fix style Signed-off-by: Matrix Yao <[email protected]> --------- Signed-off-by: Matrix Yao <[email protected]>
1 parent 0f77ca7 commit 34c1e29

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

src/transformers/testing_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3013,6 +3013,11 @@ def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable],
30133013
"cpu": 0,
30143014
"default": 0,
30153015
}
3016+
BACKEND_SYNCHRONIZE = {
3017+
"cuda": torch.cuda.synchronize,
3018+
"cpu": None,
3019+
"default": None,
3020+
}
30163021
BACKEND_TORCH_ACCELERATOR_MODULE = {
30173022
"cuda": torch.cuda,
30183023
"cpu": None,
@@ -3025,6 +3030,7 @@ def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable],
30253030
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None}
30263031
BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0}
30273032
BACKEND_MEMORY_ALLOCATED = {"default": 0}
3033+
BACKEND_SYNCHRONIZE = {"default": None}
30283034
BACKEND_TORCH_ACCELERATOR_MODULE = {"default": None}
30293035

30303036

@@ -3052,6 +3058,7 @@ def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable],
30523058
BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
30533059
BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
30543060
BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated
3061+
BACKEND_SYNCHRONIZE["xpu"] = torch.xpu.synchronize
30553062
BACKEND_TORCH_ACCELERATOR_MODULE["xpu"] = torch.xpu
30563063

30573064

@@ -3085,6 +3092,10 @@ def backend_memory_allocated(device: str):
30853092
return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED)
30863093

30873094

3095+
def backend_synchronize(device: str):
3096+
return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
3097+
3098+
30883099
def backend_torch_accelerator_module(device: str):
30893100
return _device_agnostic_dispatch(device, BACKEND_TORCH_ACCELERATOR_MODULE)
30903101

tests/quantization/autoround/test_auto_round.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@
1717

1818
from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer
1919
from transformers.testing_utils import (
20+
backend_empty_cache,
21+
backend_synchronize,
2022
require_accelerate,
2123
require_auto_round,
2224
require_intel_extension_for_pytorch,
25+
require_torch_accelerator,
2326
require_torch_gpu,
24-
require_torch_multi_gpu,
27+
require_torch_multi_accelerator,
2528
slow,
2629
torch_device,
2730
)
@@ -33,7 +36,7 @@
3336

3437

3538
@slow
36-
@require_torch_gpu
39+
@require_torch_accelerator
3740
@require_auto_round
3841
@require_accelerate
3942
class AutoRoundTest(unittest.TestCase):
@@ -50,24 +53,27 @@ class AutoRoundTest(unittest.TestCase):
5053
EXPECTED_OUTPUTS.add(
5154
"There is a girl who likes adventure, and she has been exploring the world for many years. She has visited every country in Europe and has even traveled to some of the most remote parts of Africa. She enjoys hiking through the mountains and discovering"
5255
)
56+
EXPECTED_OUTPUTS.add(
57+
"There is a girl who likes adventure, and she has been exploring the world for many years. She has visited every country in Europe and has even traveled to some of the most remote parts of Africa. She has also climbed mountains and explored caves"
58+
)
5359

54-
device_map = "cuda"
60+
device_map = torch_device
5561

5662
# called only once for all test in this class
5763
@classmethod
5864
def setUpClass(cls):
5965
"""
6066
Setup quantized model
6167
"""
62-
torch.cuda.synchronize()
68+
backend_synchronize(torch_device)
6369
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
6470
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
6571
cls.model_name, device_map=cls.device_map, torch_dtype=torch.float16
6672
)
6773

6874
def tearDown(self):
6975
gc.collect()
70-
torch.cuda.empty_cache()
76+
backend_empty_cache(torch_device)
7177
gc.collect()
7278

7379
def test_quantized_model(self):
@@ -128,14 +134,15 @@ def test_save_pretrained(self):
128134
)
129135

130136
quantized_model.save_pretrained(tmpdirname)
131-
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="cuda")
137+
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device)
132138

133139
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
134140

135141
output = model.generate(**input_ids, max_new_tokens=40, do_sample=False)
136-
self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
142+
output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True)
143+
self.assertIn(output_tokens, self.EXPECTED_OUTPUTS)
137144

138-
@require_torch_multi_gpu
145+
@require_torch_multi_accelerator
139146
def test_quantized_model_multi_gpu(self):
140147
"""
141148
Simple test that checks if the quantized model is working properly with multiple GPUs
@@ -159,7 +166,7 @@ def test_convert_from_gptq(self):
159166
quantization_config = AutoRoundConfig()
160167

161168
model = AutoModelForCausalLM.from_pretrained(
162-
model_name, device_map="cuda", quantization_config=quantization_config, torch_dtype="auto"
169+
model_name, device_map=torch_device, quantization_config=quantization_config, torch_dtype="auto"
163170
)
164171
tokenizer = AutoTokenizer.from_pretrained(model_name)
165172

@@ -185,6 +192,7 @@ def test_convert_from_awq_cpu(self):
185192
inputs = tokenizer(text, return_tensors="pt").to(model.device)
186193
tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])
187194

195+
@require_torch_gpu
188196
def test_mixed_bits(self):
189197
"""
190198
Simple test that checks if auto-round work properly with mixed bits
@@ -203,7 +211,9 @@ def test_mixed_bits(self):
203211
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, layer_config=layer_config)
204212
with tempfile.TemporaryDirectory() as tmpdirname:
205213
autoround.quantize_and_save(output_dir=tmpdirname)
206-
model = AutoModelForCausalLM.from_pretrained(tmpdirname, torch_dtype=torch.float16, device_map="cuda")
214+
model = AutoModelForCausalLM.from_pretrained(
215+
tmpdirname, torch_dtype=torch.float16, device_map=torch_device
216+
)
207217
text = "There is a girl who likes adventure,"
208218
inputs = tokenizer(text, return_tensors="pt").to(model.device)
209219
tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])

0 commit comments

Comments
 (0)