Commit 0e0f19a

Refactor model loading to use a dtype map for better device compatibility
1 parent b9f4e80 commit 0e0f19a

File tree

1 file changed (+9, -16 lines)

app.py

Lines changed: 9 additions & 16 deletions
@@ -37,22 +37,15 @@
 # Load Nari model and config
 print("Loading Nari model...")
 try:
-    if device.type == "cpu":
-        # CPU performs better on float32
-        print(f"Using device: {device}, attempting to load model with float32")
-        model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32", device=device)
-    elif device.type == "mps":
-        # MPS (Apple Silicon) prefers float32 due to poor float16 support
-        print(f"Using device: {device}, attempting to load model with float32")
-        model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32", device=device)
-    elif device.type == "cuda":
-        # CUDA (NVIDIA) benefits from float16
-        print(f"Using device: {device}, attempting to load model with float16")
-        model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16", device=device)
-    else:
-        # Fallback
-        print(f"Unknown device type '{device.type}', defaulting to float16")
-        model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16", device=device)
+    dtype_map = {
+        "cpu": "float32",
+        "mps": "float32",   # Apple M series – better with float32
+        "cuda": "float16",  # NVIDIA – better with float16
+    }
+
+    dtype = dtype_map.get(device.type, "float16")
+    print(f"Using device: {device}, attempting to load model with {dtype}")
+    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype=dtype, device=device)
 except Exception as e:
     print(f"Error loading Nari model: {e}")
     raise
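For reference, a minimal standalone sketch of the dtype-map pattern this commit introduces. It assumes only that torch is installed; the device-resolution order (cuda, then mps, then cpu) is an illustrative assumption, since app.py builds `device` before this hunk. The final Dia.from_pretrained call is copied from the diff but left commented out so the snippet runs without the Dia package.

# Minimal sketch (not from the commit): resolve a device, then pick a
# compute dtype from the same map the commit adds.
import torch

# Assumption: app.py resolves `device` earlier; cuda > mps > cpu shown
# here for illustration only.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

dtype_map = {
    "cpu": "float32",   # CPU generally runs faster in float32
    "mps": "float32",   # Apple Silicon has poor float16 support
    "cuda": "float16",  # NVIDIA GPUs benefit from float16
}

# Unknown device types fall back to float16, as in the commit.
dtype = dtype_map.get(device.type, "float16")
print(f"Using device: {device}, attempting to load model with {dtype}")

# model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype=dtype, device=device)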

0 commit comments