We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4c6e25d commit 08bcb94Copy full SHA for 08bcb94
app/utils.py
@@ -53,7 +53,12 @@ def load_model(version: str = "v2"):
53
download_from_s3(s3, f"{S3_BASE_PATH}/{version}/quantized_model.pth", local_quant_path)
54
55
logger.info("⚙️ Loading full quantized DistilBERT model...")
56
- model = torch.load(local_quant_path, map_location="cpu")
+ # initialize model with same architecture
57
+ model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
58
+ # load weights
59
+ state_dict = torch.load(local_quant_path, map_location="cpu")
60
+ model.load_state_dict(state_dict)
61
+
62
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
63
64
else:
0 commit comments