|
# Load Nari model and config
print("Loading Nari model...")
try:
    # Compute dtype per device type: CUDA (NVIDIA) benefits from float16,
    # while CPU and MPS (Apple Silicon) perform better with float32
    # (MPS has poor float16 support).
    dtype_map = {
        "cpu": "float32",
        "mps": "float32",  # Apple M series – better with float32
        "cuda": "float16",  # NVIDIA – better with float16
    }

    # Unknown device types fall back to float16.
    dtype = dtype_map.get(device.type, "float16")
    print(f"Using device: {device}, attempting to load model with {dtype}")
    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype=dtype, device=device)
except Exception as e:
    # Surface the failure to the console, then re-raise so startup aborts
    # rather than continuing with no model.
    print(f"Error loading Nari model: {e}")
    raise
|
0 commit comments