Skip to content

Commit 2aa5e9e

Browse files
fix: resolve AI model double-initialization and formatting
Previously, the RAGAgent was initializing twice (once on module import and once during startup), which unnecessarily doubled RAM consumption. This commit fixes the issue by implementing lazy loading:

- Moved the model instantiation strictly into FastAPI's startup_event.
- Injected the singleton instance into the API routes only after the application fully boots.
- Resolved formatting inconsistencies to align with the codebase standards.
1 parent 8d5ead8 commit 2aa5e9e

File tree

6 files changed

+86
-29
lines changed

6 files changed

+86
-29
lines changed

.example.env

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ MODEL_CHANGE_PASSWORD=sugarai2024
55

66
DEFAULT_MODEL=Qwen/Qwen2-1.5B-Instruct
77

8+
#For local development
9+
DEV_MODE=1 #0 for default model
10+
DEV_MODEL_NAME=HuggingFaceTB/SmolLM-135M-Instruct
11+
PROD_MODEL_NAME=Qwen/Qwen2-1.5B-Instruct
12+
813
DOC_PATHS=["./docs/Pygame Documentation.pdf", "./docs/Python GTK+3 Documentation.pdf", "./docs/Sugar Toolkit Documentation.pdf"]
914

1015
PORT=8000

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ The FastAPI server provides endpoints to interact with Sugar-AI.
3939
```sh
4040
pip install -r requirements.txt
4141
```
42+
## Local Development (DEV_MODE)
43+
44+
By default, Sugar-AI loads large language models intended for production use.
45+
These models may require significant memory and can cause startup failures
46+
on low-memory contributor machines.
47+
48+
To improve the local development experience, Sugar-AI provides a development
49+
mode that uses a lightweight, CPU-friendly model.
50+
51+
### Enable DEV_MODE
52+
53+
```bash
54+
DEV_MODE=1 python main.py
55+
```
56+
4257

4358
### Run the server
4459

app/ai.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@
99
from langchain_community.document_loaders import PyMuPDFLoader, TextLoader
1010
from langchain_core.runnables import RunnablePassthrough
1111
from langchain_core.prompts import ChatPromptTemplate
12+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
1213
from typing import Optional, List
1314
import app.prompts as prompts
15+
from app.config import settings
16+
import logging
17+
logger = logging.getLogger("sugar-ai")
1418

1519
def format_docs(docs):
1620
"""Return document content separated by newlines"""
@@ -34,25 +38,39 @@ def extract_answer_from_output(outputs):
3438

3539
class RAGAgent:
3640
"""Retrieval-Augmented Generation agent for Sugar-AI"""
37-
38-
def __init__(self, model: str = "google/gemma-3-27b-it", quantize: bool = True):
39-
# disable quantization if CUDA is not available
40-
self.use_quant = quantize and torch.cuda.is_available()
41-
self.model_name = model
42-
43-
if self.use_quant:
44-
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
41+
42+
def __init__(self, model: Optional[str] = None, quantize: bool = True):
43+
# 1) Determine model name with clear precedence:
44+
# explicit argument > DEV_MODEL_NAME (if DEV_MODE) > PROD_MODEL_NAME > DEFAULT_MODEL
45+
if model:
46+
self.model_name = model
47+
logger.info("Using explicit model argument: %s", self.model_name)
48+
else:
49+
if getattr(settings, "DEV_MODE", False):
50+
# prefer DEV_MODEL_NAME, then fallback to DEFAULT_MODEL
51+
self.model_name = getattr(settings, "DEV_MODEL_NAME", settings.DEFAULT_MODEL)
52+
logger.info("DEV_MODE active: using lightweight model %s", self.model_name)
53+
else:
54+
# production: prefer PROD_MODEL_NAME, else DEFAULT_MODEL
55+
self.model_name = getattr(settings, "PROD_MODEL_NAME", settings.DEFAULT_MODEL)
56+
logger.info("Using production model %s", self.model_name)
57+
58+
# 2) Compute quantization/device choices. Keep quantization off in DEV_MODE by default.
59+
self.use_quant = quantize and torch.cuda.is_available() and not getattr(settings, "DEV_MODE", False)
60+
device = 0 if torch.cuda.is_available() and not getattr(settings, "DEV_MODE", False) else -1
61+
dtype = torch.float16 if device == 0 else torch.float32
4562

63+
if self.use_quant:
4664
bnb_config = BitsAndBytesConfig(
4765
load_in_4bit=True,
4866
bnb_4bit_compute_dtype=torch.float16,
4967
bnb_4bit_use_double_quant=True,
5068
bnb_4bit_quant_type="nf4"
5169
)
5270

53-
tokenizer = AutoTokenizer.from_pretrained(model)
71+
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
5472
model_obj = AutoModelForCausalLM.from_pretrained(
55-
model,
73+
self.model_name,
5674
quantization_config=bnb_config,
5775
torch_dtype=torch.float16,
5876
device_map="auto"
@@ -75,11 +93,11 @@ def __init__(self, model: str = "google/gemma-3-27b-it", quantize: bool = True):
7593
else:
7694
self.model = pipeline(
7795
"text-generation",
78-
model=model,
96+
model=self.model_name,
7997
max_new_tokens=1024,
8098
truncation=True,
81-
torch_dtype=torch.float16,
82-
device=0 if torch.cuda.is_available() else -1,
99+
torch_dtype=dtype, # Use the dynamic dtype
100+
device=device, # Use the dynamic device
83101
)
84102

85103
self.simplify_model = self.model
@@ -97,7 +115,7 @@ def set_model(self, model: str) -> None:
97115
self.model_name = model
98116
self.model = pipeline(
99117
"text-generation",
100-
model=model,
118+
model=self.model_name,
101119
max_length=1024,
102120
truncation=True,
103121
torch_dtype=torch.float16

app/config.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,24 @@
22
Configuration settings for Sugar-AI.
33
"""
44
import os
5-
import json
6-
from pydantic_settings import BaseSettings
5+
from pydantic_settings import BaseSettings, SettingsConfigDict
6+
from pydantic import Field
77
from typing import Dict, List, Any, Optional
8-
from dotenv import load_dotenv
9-
10-
load_dotenv()
118

129
class Settings(BaseSettings):
1310
"""Application settings loaded from environment variables"""
14-
API_KEYS: Dict[str, Dict[str, Any]] = json.loads(os.getenv("API_KEYS", "{}"))
15-
MODEL_CHANGE_PASSWORD: str = os.getenv("MODEL_CHANGE_PASSWORD", "")
16-
DEFAULT_MODEL: str = os.getenv("DEFAULT_MODEL", "Qwen/Qwen2-1.5B-Instruct")
17-
DOC_PATHS: List[str] = json.loads(os.getenv("DOC_PATHS", '["./docs/Pygame Documentation.pdf", "./docs/Python GTK+3 Documentation.pdf", "./docs/Sugar Toolkit Documentation.pdf"]'))
18-
MAX_DAILY_REQUESTS: int = int(os.getenv("MAX_DAILY_REQUESTS", 100))
11+
12+
# Dev mode (THIS MUST EXIST)
13+
DEV_MODE: bool = os.getenv("DEV_MODE", "0") == "1"
14+
DEV_MODEL_NAME: str | None = None
15+
PROD_MODEL_NAME: str | None = None
16+
DEFAULT_MODEL: str | None = None
1917

18+
API_KEYS: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
19+
MODEL_CHANGE_PASSWORD: str = ""
20+
DOC_PATHS: List[str] = Field(default_factory=list)
21+
MAX_DAILY_REQUESTS: int = 100
22+
2023
# OAuth
2124
github_client_id: Optional[str] = None
2225
github_client_secret: Optional[str] = None
@@ -34,4 +37,4 @@ class Config:
3437
env_file = ".env"
3538
extra = "allow" # this allows extra attribute if we have any
3639

37-
settings = Settings()
40+
settings = Settings()

app/routes/api.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,8 @@ class PromptedLLMRequest(BaseModel):
3838
# setup logging
3939
logger = logging.getLogger("sugar-ai")
4040

41-
# load ai agent and document paths
42-
agent = RAGAgent(model=settings.DEFAULT_MODEL)
43-
agent.retriever = agent.setup_vectorstore(settings.DOC_PATHS)
41+
# Initialize the agent
42+
agent = None
4443

4544
# user quotas tracking
4645
user_quotas: Dict[str, Dict] = {}

main.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424
import os
2525

2626
from app import create_app
27+
from app.ai import RAGAgent
2728
from app.database import get_db
2829
from app.auth import sync_env_keys_to_db
2930
from app.config import settings
31+
from app.routes import api
3032

3133
# setup logging
3234
logger = logging.getLogger("sugar-ai")
@@ -38,7 +40,22 @@ async def startup_event():
3840
"""Initialize data on app startup"""
3941
db = next(get_db())
4042
sync_env_keys_to_db(db)
41-
logger.info(f"Starting Sugar-AI with model: {settings.DEFAULT_MODEL}")
43+
if settings.DEV_MODE:
44+
active_model = settings.DEV_MODEL_NAME
45+
logger.info(f"DEV_MODE active. Loading lightweight model: {active_model}")
46+
else:
47+
active_model = settings.PROD_MODEL_NAME
48+
logger.info(f"PRODUCTION mode. Loading full model: {active_model}")
49+
50+
initialized_agent = RAGAgent(model=active_model)
51+
initialized_agent.retriever = initialized_agent.setup_vectorstore(settings.DOC_PATHS)
52+
53+
# Inject this instance into the API module
54+
# This updates the 'agent = None' in api.py to be the real loaded model
55+
api.agent = initialized_agent
56+
57+
app.state.agent = initialized_agent
58+
4259

4360
if __name__ == "__main__":
4461
port = int(os.getenv("PORT", 8000))

0 commit comments

Comments (0)