Skip to content

Commit 41d6bb7

Browse files
authored
Merge pull request (Refactoring of VLM code structure)
Refactoring of VLM code for centralized config and training structure
2 parents cb023ea + 42334fc commit 41d6bb7

File tree

22 files changed

+209
-58
lines changed

22 files changed

+209
-58
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
 data/**/*.mp4
+data/**/*.json
+data/**/*.txt
+data/**/*.csv
 models/**/*.pth
 !**/.gitkeep
 

configs/vlm/hardware/v100.yaml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
+training:
+  # batch_size: 8
+  # gradient_accumulation_steps: 4  # Effective batch = 32
+
+# model:
+#   torch_dtype: "float16"
+
+# accelerate:
+#   use_accelerate: true
+#   mixed_precision: "fp16"  # V100 supports fp16, NOT bf16
+#   gradient_checkpointing: true
+
+# peft:
+#   use_peft: true
+#   peft_method: "lora"
+#   r: 8
+#   alpha: 16
+#   dropout: 0.1
+
+# quantization:
+#   load_in_4bit: true
+#   bnb_4bit_quant_type: "nf4"
+#   bnb_4bit_compute_dtype: "float16"
+
+# device: "cuda"

configs/vlm/serve.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
+# Configurations for serving the app
+
+model:
+  model_id: "Qwen/Qwen2.5-VL"
+
+inference:
+  batch_size: 4
+  max_batch_wait_ms: 1000
+
+server:
+  host: "0.0.0.0"
+  port: 8000

configs/vlm/train.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
+training:
+  batch_size: 8
+  learning_rate: 5e-5
+
+model:
+  model_id: "Qwen/Qwen2.5-VL"
+
+# data:
+
+logging:
+  log_level: "INFO"
+
+# checkpoint:
+#   save_steps: 1000
+#   save_total_limit: 3
+#   output_dir: "experiments"
+
+peft:
+  use_peft: false

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies = [
     "pillow>=11.3.0",
     "pydantic>=2.12.2",
     "pydantic-settings>=2.12.0",
+    "pyyaml>=6.0.3",
     "qwen-vl-utils>=0.0.14",
     "ruff>=0.13.2",
    "threaded-videocapture>=1.0.1",

scripts/perception/.gitkeep

Whitespace-only changes.

scripts/perception/train_videomae.example

Whitespace-only changes.

src/iris/config/__init__.py

Whitespace-only changes.

src/iris/perception/README.md

Lines changed: 0 additions & 13 deletions
This file was deleted.

src/iris/server/app.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from PIL import Image
 
+from iris.server.config import ServerConfig
 from iris.server.dependencies import get_server_state
-from iris.vlm.inference.model_loader import load_model_and_processor
 from iris.vlm.inference.queue.jobs import SingleFrameJob
 from iris.vlm.inference.queue.queue import InferenceQueue
+from iris.vlm.models import load_model_and_processor
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+config = ServerConfig()
+
 
 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
@@ -25,10 +28,12 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     state = get_server_state()
 
     logger.info("Loading model...")
-    state.model, state.processor = load_model_and_processor("smolvlm2")
+    state.model, state.processor = load_model_and_processor(config.model_key)
 
     logger.info("Starting inference queue...")
-    state.queue = InferenceQueue(max_queue_size=10, num_workers=1)
+    state.queue = InferenceQueue(
+        max_queue_size=config.max_queue_size, num_workers=config.num_workers
+    )
     await state.queue.start()
 
     state.model_loaded = True

Comments (0)