diff --git a/torchbenchmark/models/nanogpt/model.py b/torchbenchmark/models/nanogpt/model.py
index 3f01aa6e1b..d8d87c3758 100644
--- a/torchbenchmark/models/nanogpt/model.py
+++ b/torchbenchmark/models/nanogpt/model.py
@@ -349,7 +349,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
)
- # Create AdamW optimizer and use the fused version if it is available
+ # Create AdamW optimizer; use the fused version if it is available and the device supports it
fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
- use_fused = fused_available and device_type == "cuda"
+ use_fused = fused_available and device_type in ["cuda", "xpu"]
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(
optim_groups, lr=learning_rate, betas=betas, **extra_args
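
For context, the selection logic this hunk touches can be sketched standalone as below (a minimal sketch of the patched behavior; make_adamw is an invented name, and params/lr/betas are placeholder arguments):

import inspect

import torch


def make_adamw(params, lr, betas, device_type):
    # Older PyTorch releases do not expose the `fused` keyword, so probe
    # the installed torch.optim.AdamW signature before requesting it.
    fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
    # After this patch, fused AdamW is requested on XPU as well as CUDA.
    use_fused = fused_available and device_type in ["cuda", "xpu"]
    extra_args = dict(fused=True) if use_fused else dict()
    return torch.optim.AdamW(params, lr=lr, betas=betas, **extra_args)
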
diff --git a/torchbenchmark/util/extra_args.py b/torchbenchmark/util/extra_args.py
index 16c4840256..15c4bd62e6 100644
--- a/torchbenchmark/util/extra_args.py
+++ b/torchbenchmark/util/extra_args.py
@@ -37,9 +37,9 @@ def check_precision(
if precision == "bypass":
return True
if precision == "fp16":
- return model.device == "cuda" and hasattr(model, "enable_fp16")
+ return model.device in ["cuda", "xpu"] and hasattr(model, "enable_fp16")
if precision == "tf32":
- return model.device == "cuda"
+ return model.device in ["cuda", "xpu"]
if precision == "amp":
return True
if precision == "fx_int8":
@@ -47,9 +47,9 @@ def check_precision(
if precision == "bf16":
return True
if precision == "amp_fp16":
- if model.test == "eval" and model.device == "cuda":
+ if model.test == "eval" and model.device in ["cuda", "xpu"]:
return True
- if model.test == "train" and model.device == "cuda":
+ if model.test == "train" and model.device in ["cuda", "xpu"]:
return hasattr(model, "enable_amp") or is_staged_train_test(model)
if precision == "amp_bf16":
if model.test == "eval" and model.device == "cpu":
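
The amp_fp16 branches above now treat XPU exactly like CUDA: eval is always eligible, while train additionally requires an enable_amp hook or a staged train test. A compressed sketch of that rule (amp_fp16_ok is an invented helper; the boolean arguments stand in for the hasattr and is_staged_train_test checks used in check_precision):

def amp_fp16_ok(device: str, test: str, has_enable_amp: bool, is_staged_train: bool) -> bool:
    # Mirrors the updated check_precision branch: XPU follows the CUDA path.
    if device not in ["cuda", "xpu"]:
        return False
    if test == "eval":
        return True
    # Training with AMP also needs an enable_amp hook or a staged train test.
    return has_enable_amp or is_staged_train
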
@@ -87,13 +87,13 @@ def get_precision_default(model: "torchbenchmark.util.model.BenchmarkModel") ->
if (
hasattr(model, "DEFAULT_EVAL_CUDA_PRECISION")
and model.test == "eval"
- and model.device == "cuda"
+ and model.device in ["cuda", "xpu"]
):
return model.DEFAULT_EVAL_CUDA_PRECISION
if (
hasattr(model, "DEFAULT_TRAIN_CUDA_PRECISION")
and model.test == "train"
- and model.device == "cuda"
+ and model.device in ["cuda", "xpu"]
):
return model.DEFAULT_TRAIN_CUDA_PRECISION
if hasattr(model, "DEFAULT_PRECISION"):
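
Since the same membership test now appears throughout extra_args.py, the pattern reduces to a single predicate (a hypothetical refactor, not part of this patch; ACCELERATORS and supports_cuda_style_precision are invented names):

ACCELERATORS = ("cuda", "xpu")


def supports_cuda_style_precision(device: str) -> bool:
    # Devices that take the CUDA precision paths: fp16, tf32, amp_fp16,
    # and the DEFAULT_{EVAL,TRAIN}_CUDA_PRECISION defaults.
    return device in ACCELERATORS

Note that the attribute names DEFAULT_EVAL_CUDA_PRECISION and DEFAULT_TRAIN_CUDA_PRECISION keep their CUDA wording even though, after this patch, they also supply the defaults on XPU.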