@@ -21,7 +21,7 @@ def load_model(version: str = "v2"):
2121 """
2222 Loads a model version from S3.
2323 - v1: full Hugging Face model (config + tokenizer + weights)
24- - v2: quantized PyTorch model (.pth )
24+ - v2: quantized PyTorch model (entire model saved )
2525 """
2626 logger .info (f"🔍 Loading model version: { version } " )
2727 s3 = boto3 .client ("s3" )
@@ -30,7 +30,6 @@ def load_model(version: str = "v2"):
3030 os .makedirs (local_model_path , exist_ok = True )
3131
3232 if version == "v1" :
33- # Download all required files from S3 if not already cached
3433 files = [
3534 "config.json" ,
3635 "model.safetensors" ,
@@ -48,19 +47,13 @@ def load_model(version: str = "v2"):
4847 model = DistilBertForSequenceClassification .from_pretrained (local_model_path )
4948
5049 elif version == "v2" :
51- # v2: quantized model (state dict only)
52- local_quant_path = os .path .join (local_model_path , "quantized_model.pth" )
53-
50+ local_quant_path = os .path .join (local_model_path , "quantized_model_full.pth" )
5451 if not os .path .exists (local_quant_path ):
55- logger .info ("📥 Downloading quantized model from S3..." )
56- download_from_s3 (s3 , f"{ S3_BASE_PATH } /{ version } /quantized_model.pth" , local_quant_path )
57-
58- logger .info ("⚙️ Loading DistilBERT config and applying quantized weights..." )
59- config = DistilBertConfig .from_pretrained ("distilbert-base-uncased" , num_labels = 2 )
60- model = DistilBertForSequenceClassification (config )
61- state_dict = torch .load (local_quant_path , map_location = "cpu" )
62- model .load_state_dict (state_dict )
52+ logger .info ("📥 Downloading full quantized model from S3..." )
53+ download_from_s3 (s3 , f"{ S3_BASE_PATH } /{ version } /quantized_model_full.pth" , local_quant_path )
6354
55+ logger .info ("⚙️ Loading full quantized DistilBERT model..." )
56+ model = torch .load (local_quant_path , map_location = "cpu" )
6457 tokenizer = DistilBertTokenizer .from_pretrained ("distilbert-base-uncased" )
6558
6659 else :
@@ -72,6 +65,7 @@ def load_model(version: str = "v2"):
7265
7366
7467
68+
7569def predict_text (model , tokenizer , text : str ):
7670 """Run inference with automatic handling for both model types."""
7771 try :
0 commit comments