Skip to content

Commit 3315415

Browse files
authored
Merge pull request #10 from philnach/fix/pytorch-inf-bfloat16-memory
Switch pytorch_inf model-loading dtype from float32 to "auto" to fix OOM on 16GB machines
2 parents dea3b07 + acfa7d6 commit 3315415

4 files changed

Lines changed: 6 additions & 6 deletions

File tree

scenarios/macos/mac_pytorch_inf/mac_pytorch_inf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
class MacPytorchInf(core.app_scenario.Scenario):
1515

1616
module = __module__.split('.')[-1]
17-
prep_version = "5"
17+
prep_version = "6"
1818
resources = module + "_resources"
1919

2020

scenarios/macos/mac_pytorch_inf/mac_pytorch_inf_resources/inference.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ def setup_model(model_name, device):
185185
print("Downloading model...")
186186
model = AutoModelForCausalLM.from_pretrained(
187187
model_name,
188-
dtype=torch.float16 if device == 'cuda' else torch.float32,
188+
torch_dtype="auto",
189189
device_map="auto" if device == 'cuda' else None
190190
)
191191
# model.resize_token_embeddings(len(tokenizer))
@@ -371,7 +371,7 @@ def main():
371371
print("Loading model...")
372372
model = AutoModelForCausalLM.from_pretrained(
373373
model_name,
374-
dtype=torch.float16 if device == 'cuda' else torch.float32,
374+
torch_dtype="auto",
375375
device_map="auto" if device == 'cuda' else None
376376
)
377377
model.resize_token_embeddings(len(tokenizer))

scenarios/windows/pytorch_inf/pytorch_inf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
class PytorchInf(core.app_scenario.Scenario):
1515

1616
module = __module__.split('.')[-1]
17-
prep_version = "9"
17+
prep_version = "10"
1818
# prep_scenarios = [(module, prep_version)]
1919
resources = module + "_resources"
2020

scenarios/windows/pytorch_inf/pytorch_inf_resources/inference.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -188,7 +188,7 @@ def setup_model(model_name, device):
188188
print("Downloading model...")
189189
model = AutoModelForCausalLM.from_pretrained(
190190
model_name,
191-
dtype=torch.float16 if device == 'cuda' else torch.float32,
191+
torch_dtype="auto",
192192
device_map="auto" if device == 'cuda' else None
193193
)
194194
# model.resize_token_embeddings(len(tokenizer))
@@ -374,7 +374,7 @@ def main():
374374
print("Loading model...")
375375
model = AutoModelForCausalLM.from_pretrained(
376376
model_name,
377-
dtype=torch.float16 if device == 'cuda' else torch.float32,
377+
torch_dtype="auto",
378378
device_map="auto" if device == 'cuda' else None
379379
)
380380
model.resize_token_embeddings(len(tokenizer))

0 commit comments

Comments
 (0)