Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .example.env
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ MODEL_CHANGE_PASSWORD=sugarai2024

DEFAULT_MODEL=Qwen/Qwen2-1.5B-Instruct

# For local development
DEV_MODE=1 # set to 0 to use the production/default model instead
DEV_MODEL_NAME=HuggingFaceTB/SmolLM-135M-Instruct
PROD_MODEL_NAME=Qwen/Qwen2-1.5B-Instruct

DOC_PATHS=["./docs/Pygame Documentation.pdf", "./docs/Python GTK+3 Documentation.pdf", "./docs/Sugar Toolkit Documentation.pdf"]

PORT=8000
Expand Down
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ The FastAPI server provides endpoints to interact with Sugar-AI.
```sh
pip install -r requirements.txt
```
## Local Development (DEV_MODE)

By default, Sugar-AI loads large language models intended for production use.
These models may require significant memory and can cause startup failures
on low-memory contributor machines.

To improve the local development experience, Sugar-AI provides a development
mode that uses a lightweight, CPU-friendly model.

### Enable DEV_MODE

```bash
DEV_MODE=1 python main.py
```

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're supposed to close this code block you started. It's missing the three backticks.


### Run the server

Expand Down
46 changes: 32 additions & 14 deletions app/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
from langchain_community.document_loaders import PyMuPDFLoader, TextLoader
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Optional, List
import app.prompts as prompts
from app.config import settings
import logging
logger = logging.getLogger("sugar-ai")

def format_docs(docs):
"""Return document content separated by newlines"""
Expand All @@ -34,25 +38,39 @@ def extract_answer_from_output(outputs):

class RAGAgent:
"""Retrieval-Augmented Generation agent for Sugar-AI"""

def __init__(self, model: str = "google/gemma-3-27b-it", quantize: bool = True):
# disable quantization if CUDA is not available
self.use_quant = quantize and torch.cuda.is_available()
self.model_name = model

if self.use_quant:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

def __init__(self, model: Optional[str] = None, quantize: bool = True):
# 1) Determine model name with clear precedence:
# explicit argument > DEV_MODEL_NAME (if DEV_MODE) > PROD_MODEL_NAME > DEFAULT_MODEL
if model:
self.model_name = model
logger.info("Using explicit model argument: %s", self.model_name)
else:
if getattr(settings, "DEV_MODE", False):
# prefer DEV_MODEL_NAME, then fallback to DEFAULT_MODEL
self.model_name = getattr(settings, "DEV_MODEL_NAME", settings.DEFAULT_MODEL)
logger.info("DEV_MODE active: using lightweight model %s", self.model_name)
else:
# production: prefer PROD_MODEL_NAME, else DEFAULT_MODEL
self.model_name = getattr(settings, "PROD_MODEL_NAME", settings.DEFAULT_MODEL)
logger.info("Using production model %s", self.model_name)

# 2) Compute quantization/device choices. Keep quantization off in DEV_MODE by default.
self.use_quant = quantize and torch.cuda.is_available() and not getattr(settings, "DEV_MODE", False)
device = 0 if torch.cuda.is_available() and not getattr(settings, "DEV_MODE", False) else -1
dtype = torch.float16 if device == 0 else torch.float32

if self.use_quant:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)

tokenizer = AutoTokenizer.from_pretrained(model)
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
self.model_name,
quantization_config=bnb_config,
torch_dtype=torch.float16,
device_map="auto"
Expand All @@ -75,11 +93,11 @@ def __init__(self, model: str = "google/gemma-3-27b-it", quantize: bool = True):
else:
self.model = pipeline(
"text-generation",
model=model,
model=self.model_name,
max_new_tokens=1024,
truncation=True,
torch_dtype=torch.float16,
device=0 if torch.cuda.is_available() else -1,
torch_dtype=dtype, # Use the dynamic dtype
device=device, # Use the dynamic device
)

self.simplify_model = self.model
Expand All @@ -97,7 +115,7 @@ def set_model(self, model: str) -> None:
self.model_name = model
self.model = pipeline(
"text-generation",
model=model,
model=self.model_name,
max_length=1024,
truncation=True,
torch_dtype=torch.float16
Expand Down
25 changes: 14 additions & 11 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@
Configuration settings for Sugar-AI.
"""
import os
import json
from pydantic_settings import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field
from typing import Dict, List, Any, Optional
from dotenv import load_dotenv

load_dotenv()

class Settings(BaseSettings):
"""Application settings loaded from environment variables"""
API_KEYS: Dict[str, Dict[str, Any]] = json.loads(os.getenv("API_KEYS", "{}"))
MODEL_CHANGE_PASSWORD: str = os.getenv("MODEL_CHANGE_PASSWORD", "")
DEFAULT_MODEL: str = os.getenv("DEFAULT_MODEL", "Qwen/Qwen2-1.5B-Instruct")
DOC_PATHS: List[str] = json.loads(os.getenv("DOC_PATHS", '["./docs/Pygame Documentation.pdf", "./docs/Python GTK+3 Documentation.pdf", "./docs/Sugar Toolkit Documentation.pdf"]'))
MAX_DAILY_REQUESTS: int = int(os.getenv("MAX_DAILY_REQUESTS", 100))

# Dev mode (THIS MUST EXIST)
DEV_MODE: bool = os.getenv("DEV_MODE", "0") == "1"
DEV_MODEL_NAME: str | None = None
PROD_MODEL_NAME: str | None = None
DEFAULT_MODEL: str | None = None

API_KEYS: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
MODEL_CHANGE_PASSWORD: str = ""
DOC_PATHS: List[str] = Field(default_factory=list)
MAX_DAILY_REQUESTS: int = 100

# OAuth
github_client_id: Optional[str] = None
github_client_secret: Optional[str] = None
Expand All @@ -34,4 +37,4 @@ class Config:
env_file = ".env"
extra = "allow" # this allows extra attribute if we have any

settings = Settings()
settings = Settings()
5 changes: 2 additions & 3 deletions app/routes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ class PromptedLLMRequest(BaseModel):
# setup logging
logger = logging.getLogger("sugar-ai")

# load ai agent and document paths
agent = RAGAgent(model=settings.DEFAULT_MODEL)
agent.retriever = agent.setup_vectorstore(settings.DOC_PATHS)
# Initialize the agent
agent = None

# user quotas tracking
user_quotas: Dict[str, Dict] = {}
Expand Down
19 changes: 18 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
import os

from app import create_app
from app.ai import RAGAgent
from app.database import get_db
from app.auth import sync_env_keys_to_db
from app.config import settings
from app.routes import api

# setup logging
logger = logging.getLogger("sugar-ai")
Expand All @@ -38,7 +40,22 @@ async def startup_event():
"""Initialize data on app startup"""
db = next(get_db())
sync_env_keys_to_db(db)
logger.info(f"Starting Sugar-AI with model: {settings.DEFAULT_MODEL}")
if settings.DEV_MODE:
active_model = settings.DEV_MODEL_NAME
logger.info(f"DEV_MODE active. Loading lightweight model: {active_model}")
else:
active_model = settings.PROD_MODEL_NAME
logger.info(f"PRODUCTION mode. Loading full model: {active_model}")

initialized_agent = RAGAgent(model=active_model)
initialized_agent.retriever = initialized_agent.setup_vectorstore(settings.DOC_PATHS)

# Inject this instance into the API module
# This updates the 'agent = None' in api.py to be the real loaded model
api.agent = initialized_agent

app.state.agent = initialized_agent


if __name__ == "__main__":
port = int(os.getenv("PORT", 8000))
Expand Down