Skip to content

Commit 57274d9

Browse files
committed
feat: GAP-1 — persist training history with metrics
- Add ModelTrainingRun table to capture every training run - Store timing, config snapshot, final metrics, epoch-by-epoch curves - Save history in model_run.py after model.fit() completes - Add 3 new endpoints: GET /api/model/{model_id}/training-runs GET /api/model/{model_id}/training-run/{run_id}/metrics POST /api/model/{model_id}/training-run/{run_id}/set-as-best - Alembic migration: 9e13d23b8149 - Register ModelTrainingRun in migrations/env.py
1 parent ef5b6c6 commit 57274d9

7 files changed

Lines changed: 485 additions & 11 deletions

File tree

tensormap-backend/app/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
validation_exception_handler,
2222
)
2323
from app.middleware import RequestLoggingMiddleware
24-
from app.routers import data_process, data_upload, deep_learning, health, project
24+
from app.routers import data_process, data_upload, deep_learning, health, project, training_run
2525
from app.shared.logging_config import get_logger
2626
from app.socketio_instance import sio
2727

@@ -66,6 +66,7 @@ async def lifespan(app: FastAPI):
6666
app.include_router(data_process.router, prefix=settings.api_base)
6767
app.include_router(deep_learning.router, prefix=settings.api_base)
6868
app.include_router(project.router, prefix=settings.api_base)
69+
app.include_router(training_run.router, prefix=settings.api_base)
6970

7071
# Wrap FastAPI with SocketIO so socket.io requests are handled,
7172
# and everything else passes through to FastAPI.

tensormap-backend/app/models/ml.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ class ModelBasic(SQLModel, table=True):
4343

4444
project: Optional["Project"] = Relationship(back_populates="models")
4545
file: Optional["DataFile"] = Relationship(back_populates="model_basic")
46+
training_runs: list["ModelTrainingRun"] = Relationship(
47+
back_populates="model",
48+
sa_relationship_kwargs={"cascade": "all,delete"},
49+
)
4650
configs: list["ModelConfigs"] = Relationship(
4751
back_populates="model",
4852
sa_relationship_kwargs={"cascade": "all,delete"},
@@ -68,5 +72,6 @@ class ModelConfigs(SQLModel, table=True):
6872

6973
# Resolve forward references
7074
from app.models.data import DataFile # noqa: E402
75+
from app.models.training_run import ModelTrainingRun # noqa: E402
7176

7277
ModelBasic.model_rebuild()
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from datetime import datetime
2+
from typing import Optional
3+
4+
from sqlalchemy import JSON, Column, DateTime, ForeignKey, String, func
5+
from sqlmodel import Field, Relationship, SQLModel
6+
7+
8+
class ModelTrainingRun(SQLModel, table=True):
9+
"""Records every training run for a model, capturing metrics and config."""
10+
11+
__tablename__ = "model_training_run"
12+
13+
id: int | None = Field(default=None, primary_key=True)
14+
model_id: int = Field(
15+
sa_column=Column(ForeignKey("model_basic.id", ondelete="CASCADE"), index=True, nullable=False)
16+
)
17+
18+
# Timing
19+
started_at: datetime = Field(sa_column=Column(DateTime, nullable=False))
20+
completed_at: datetime | None = Field(default=None, sa_column=Column(DateTime, nullable=True))
21+
duration_seconds: float | None = Field(default=None, nullable=True)
22+
23+
# Config snapshot at time of training
24+
epochs_configured: int | None = Field(default=None, nullable=True)
25+
batch_size_configured: int | None = Field(default=None, nullable=True)
26+
training_split_configured: float | None = Field(default=None, nullable=True)
27+
optimizer: str | None = Field(default=None, max_length=50, nullable=True)
28+
loss_fn: str | None = Field(default=None, max_length=50, nullable=True)
29+
metric_name: str | None = Field(default=None, max_length=50, nullable=True)
30+
31+
# Final results
32+
final_train_loss: float | None = Field(default=None, nullable=True)
33+
final_train_metric: float | None = Field(default=None, nullable=True)
34+
final_val_loss: float | None = Field(default=None, nullable=True)
35+
final_val_metric: float | None = Field(default=None, nullable=True)
36+
37+
# Full epoch-by-epoch curves stored as JSON arrays
38+
epoch_losses: list | None = Field(default=None, sa_column=Column(JSON, nullable=True))
39+
epoch_metrics: list | None = Field(default=None, sa_column=Column(JSON, nullable=True))
40+
epoch_val_losses: list | None = Field(default=None, sa_column=Column(JSON, nullable=True))
41+
epoch_val_metrics: list | None = Field(default=None, sa_column=Column(JSON, nullable=True))
42+
43+
# Status: in_progress | success | failed | best
44+
status: str = Field(default="in_progress", sa_column=Column(String(20), nullable=False))
45+
error_message: str | None = Field(default=None, nullable=True)
46+
47+
created_on: datetime | None = Field(default=None, sa_column=Column(DateTime, server_default=func.now()))
48+
49+
model: Optional["ModelBasic"] = Relationship(back_populates="training_runs")
50+
51+
52+
from app.models.ml import ModelBasic # noqa: E402
53+
54+
ModelTrainingRun.model_rebuild()
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from fastapi import APIRouter, Depends, HTTPException
2+
from sqlmodel import Session, select
3+
4+
from app.database import get_db
5+
from app.models.ml import ModelBasic
6+
from app.models.training_run import ModelTrainingRun
7+
from app.shared.logging_config import get_logger
8+
9+
logger = get_logger(__name__)
10+
11+
12+
logger = get_logger(__name__)
13+
router = APIRouter(tags=["Training Runs"])
14+
15+
16+
@router.get("/model/{model_id}/training-runs")
17+
def get_training_runs(
18+
model_id: int,
19+
offset: int = 0,
20+
limit: int = 50,
21+
db: Session = Depends(get_db),
22+
):
23+
"""List all training runs for a model, newest first. Limited to prevent large result sets."""
24+
model = db.get(ModelBasic, model_id)
25+
if not model:
26+
raise HTTPException(status_code=404, detail="Model not found")
27+
28+
runs = db.exec(
29+
select(ModelTrainingRun)
30+
.where(ModelTrainingRun.model_id == model_id)
31+
.order_by(ModelTrainingRun.created_on.desc())
32+
.offset(offset)
33+
.limit(limit)
34+
).all()
35+
36+
return {
37+
"model_id": model_id,
38+
"model_name": model.model_name,
39+
"returned_runs": len(runs),
40+
"offset": offset,
41+
"limit": limit,
42+
"runs": [
43+
{
44+
"id": r.id,
45+
"status": r.status,
46+
"started_at": r.started_at,
47+
"completed_at": r.completed_at,
48+
"duration_seconds": r.duration_seconds,
49+
"epochs_configured": r.epochs_configured,
50+
"batch_size_configured": r.batch_size_configured,
51+
"final_train_loss": r.final_train_loss,
52+
"final_train_metric": r.final_train_metric,
53+
"final_val_loss": r.final_val_loss,
54+
"final_val_metric": r.final_val_metric,
55+
"metric_name": r.metric_name,
56+
"error_message": r.error_message,
57+
}
58+
for r in runs
59+
],
60+
}
61+
62+
63+
@router.get("/model/{model_id}/training-run/{run_id}/metrics")
64+
def get_training_run_metrics(model_id: int, run_id: int, db: Session = Depends(get_db)):
65+
"""Get detailed epoch-by-epoch metrics for a specific training run."""
66+
run = db.exec(
67+
select(ModelTrainingRun).where(
68+
ModelTrainingRun.id == run_id,
69+
ModelTrainingRun.model_id == model_id,
70+
)
71+
).first()
72+
73+
if not run:
74+
raise HTTPException(status_code=404, detail="Training run not found")
75+
76+
return {
77+
"id": run.id,
78+
"model_id": run.model_id,
79+
"status": run.status,
80+
"started_at": run.started_at,
81+
"completed_at": run.completed_at,
82+
"duration_seconds": run.duration_seconds,
83+
"config": {
84+
"epochs": run.epochs_configured,
85+
"batch_size": run.batch_size_configured,
86+
"training_split": run.training_split_configured,
87+
"optimizer": run.optimizer,
88+
"loss_fn": run.loss_fn,
89+
"metric": run.metric_name,
90+
},
91+
"results": {
92+
"final_train_loss": run.final_train_loss,
93+
"final_train_metric": run.final_train_metric,
94+
"final_val_loss": run.final_val_loss,
95+
"final_val_metric": run.final_val_metric,
96+
},
97+
"curves": {
98+
"epoch_losses": run.epoch_losses,
99+
"epoch_metrics": run.epoch_metrics,
100+
"epoch_val_losses": run.epoch_val_losses,
101+
"epoch_val_metrics": run.epoch_val_metrics,
102+
},
103+
"error_message": run.error_message,
104+
}
105+
106+
107+
@router.post("/model/{model_id}/training-run/{run_id}/set-as-best")
108+
def set_as_best(model_id: int, run_id: int, db: Session = Depends(get_db)):
109+
"""Mark a training run as the best run for this model."""
110+
run = db.exec(
111+
select(ModelTrainingRun).where(
112+
ModelTrainingRun.id == run_id,
113+
ModelTrainingRun.model_id == model_id,
114+
)
115+
).first()
116+
117+
if not run:
118+
raise HTTPException(status_code=404, detail="Training run not found")
119+
if run.status not in ("success", "best"):
120+
raise HTTPException(status_code=400, detail="Can only mark successful runs as best")
121+
122+
# Clear previous best for this model
123+
all_runs = db.exec(select(ModelTrainingRun).where(ModelTrainingRun.model_id == model_id)).all()
124+
for r in all_runs:
125+
if r.status == "best":
126+
r.status = "success"
127+
db.add(r)
128+
129+
run.status = "best"
130+
db.add(run)
131+
db.commit()
132+
db.refresh(run)
133+
134+
return {"message": f"Run #{run_id} marked as best", "run_id": run_id}

0 commit comments

Comments
 (0)