Skip to content

Commit 0cc2dd9

Browse files
committed
load correct model
1 parent 1f8433c commit 0cc2dd9

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

app/utils.py

Lines changed: 7 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ def load_model(version: str = "v2"):
     """
     Loads a model version from S3.
     - v1: full Hugging Face model (config + tokenizer + weights)
-    - v2: quantized PyTorch model (.pth)
+    - v2: quantized PyTorch model (entire model saved)
     """
     logger.info(f"🔍 Loading model version: {version}")
     s3 = boto3.client("s3")
@@ -30,7 +30,6 @@ def load_model(version: str = "v2"):
     os.makedirs(local_model_path, exist_ok=True)
 
     if version == "v1":
-        # Download all required files from S3 if not already cached
         files = [
             "config.json",
             "model.safetensors",
@@ -48,19 +47,13 @@ def load_model(version: str = "v2"):
         model = DistilBertForSequenceClassification.from_pretrained(local_model_path)
 
     elif version == "v2":
-        # v2: quantized model (state dict only)
-        local_quant_path = os.path.join(local_model_path, "quantized_model.pth")
-
+        local_quant_path = os.path.join(local_model_path, "quantized_model_full.pth")
         if not os.path.exists(local_quant_path):
-            logger.info("📥 Downloading quantized model from S3...")
-            download_from_s3(s3, f"{S3_BASE_PATH}/{version}/quantized_model.pth", local_quant_path)
-
-        logger.info("⚙️ Loading DistilBERT config and applying quantized weights...")
-        config = DistilBertConfig.from_pretrained("distilbert-base-uncased", num_labels=2)
-        model = DistilBertForSequenceClassification(config)
-        state_dict = torch.load(local_quant_path, map_location="cpu")
-        model.load_state_dict(state_dict)
+            logger.info("📥 Downloading full quantized model from S3...")
+            download_from_s3(s3, f"{S3_BASE_PATH}/{version}/quantized_model_full.pth", local_quant_path)
 
+        logger.info("⚙️ Loading full quantized DistilBERT model...")
+        model = torch.load(local_quant_path, map_location="cpu")
         tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
 
     else:
@@ -72,6 +65,7 @@ def load_model(version: str = "v2"):
 
 
 
+
 def predict_text(model, tokenizer, text: str):
     """Run inference with automatic handling for both model types."""
     try:

0 commit comments

Comments
 (0)