99from langchain_community .document_loaders import PyMuPDFLoader , TextLoader
1010from langchain_core .runnables import RunnablePassthrough
1111from langchain_core .prompts import ChatPromptTemplate
12+ from transformers import AutoModelForCausalLM , AutoTokenizer , BitsAndBytesConfig
1213from typing import Optional , List
1314import app .prompts as prompts
15+ from app .config import settings
16+ import logging
17+ logger = logging .getLogger ("sugar-ai" )
1418
1519def format_docs (docs ):
1620 """Return document content separated by newlines"""
@@ -34,25 +38,39 @@ def extract_answer_from_output(outputs):
3438
3539class RAGAgent :
3640 """Retrieval-Augmented Generation agent for Sugar-AI"""
37-
38- def __init__ (self , model : str = "google/gemma-3-27b-it" , quantize : bool = True ):
39- # disable quantization if CUDA is not available
40- self .use_quant = quantize and torch .cuda .is_available ()
41- self .model_name = model
42-
43- if self .use_quant :
44- from transformers import AutoModelForCausalLM , AutoTokenizer , BitsAndBytesConfig
41+
42+ def __init__ (self , model : Optional [str ] = None , quantize : bool = True ):
43+ # 1) Determine model name with clear precedence:
44+ # explicit argument > DEV_MODEL_NAME (if DEV_MODE) > PROD_MODEL_NAME > DEFAULT_MODEL
45+ if model :
46+ self .model_name = model
47+ logger .info ("Using explicit model argument: %s" , self .model_name )
48+ else :
49+ if getattr (settings , "DEV_MODE" , False ):
50+ # prefer DEV_MODEL_NAME, then fallback to DEFAULT_MODEL
51+ self .model_name = getattr (settings , "DEV_MODEL_NAME" , settings .DEFAULT_MODEL )
52+ logger .info ("DEV_MODE active: using lightweight model %s" , self .model_name )
53+ else :
54+ # production: prefer PROD_MODEL_NAME, else DEFAULT_MODEL
55+ self .model_name = getattr (settings , "PROD_MODEL_NAME" , settings .DEFAULT_MODEL )
56+ logger .info ("Using production model %s" , self .model_name )
57+
58+ # 2) Compute quantization/device choices. Keep quantization off in DEV_MODE by default.
59+ self .use_quant = quantize and torch .cuda .is_available () and not getattr (settings , "DEV_MODE" , False )
60+ device = 0 if torch .cuda .is_available () and not getattr (settings , "DEV_MODE" , False ) else - 1
61+ dtype = torch .float16 if device == 0 else torch .float32
4562
63+ if self .use_quant :
4664 bnb_config = BitsAndBytesConfig (
4765 load_in_4bit = True ,
4866 bnb_4bit_compute_dtype = torch .float16 ,
4967 bnb_4bit_use_double_quant = True ,
5068 bnb_4bit_quant_type = "nf4"
5169 )
5270
53- tokenizer = AutoTokenizer .from_pretrained (model )
71+ tokenizer = AutoTokenizer .from_pretrained (self . model_name )
5472 model_obj = AutoModelForCausalLM .from_pretrained (
55- model ,
73+ self . model_name ,
5674 quantization_config = bnb_config ,
5775 torch_dtype = torch .float16 ,
5876 device_map = "auto"
@@ -75,11 +93,11 @@ def __init__(self, model: str = "google/gemma-3-27b-it", quantize: bool = True):
7593 else :
7694 self .model = pipeline (
7795 "text-generation" ,
78- model = model ,
96+ model = self . model_name ,
7997 max_new_tokens = 1024 ,
8098 truncation = True ,
81- torch_dtype = torch . float16 ,
82- device = 0 if torch . cuda . is_available () else - 1 ,
99+ torch_dtype = dtype , # Use the dynamic dtype
100+ device = device , # Use the dynamic device
83101 )
84102
85103 self .simplify_model = self .model
@@ -97,7 +115,7 @@ def set_model(self, model: str) -> None:
97115 self .model_name = model
98116 self .model = pipeline (
99117 "text-generation" ,
100- model = model ,
118+ model = self . model_name ,
101119 max_length = 1024 ,
102120 truncation = True ,
103121 torch_dtype = torch .float16