Skip to content

Commit 08bcb94

Browse files
committed
quantized model is state_dict not the actual model. make corrections
1 parent 4c6e25d commit 08bcb94

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

app/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@ def load_model(version: str = "v2"):
5353
download_from_s3(s3, f"{S3_BASE_PATH}/{version}/quantized_model.pth", local_quant_path)
5454

5555
logger.info("⚙️ Loading full quantized DistilBERT model...")
56-
model = torch.load(local_quant_path, map_location="cpu")
56+
# initialize model with same architecture
57+
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
58+
# load weights
59+
state_dict = torch.load(local_quant_path, map_location="cpu")
60+
model.load_state_dict(state_dict)
61+
5762
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
5863

5964
else:

0 commit comments

Comments
 (0)