Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
name: Tests

on:
  # SECURITY: use `pull_request`, not `pull_request_target`. The previous
  # configuration ran with a privileged token and repository secrets while
  # explicitly checking out the fork's head SHA
  # (repository: ...head.repo.full_name, ref: ...head.sha), which lets any
  # fork PR execute arbitrary code with write access ("pwn request").
  # `pull_request` runs untrusted code with a read-only token and no secrets,
  # and actions/checkout then checks out the PR merge ref by default, so the
  # repository/ref overrides are no longer needed.
  pull_request:
    branches: [main]

jobs:
  backend-unit-tests:
    name: Backend Unit Tests
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - uses: astral-sh/setup-uv@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - name: Install dependencies
        working-directory: tensormap-backend
        run: uv sync --frozen --extra dev

      - name: Run backend unit tests
        working-directory: tensormap-backend
        run: uv run pytest tests/ -v --tb=short

  integration-tests:
    name: Integration Tests
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - uses: astral-sh/setup-uv@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - name: Install dependencies
        working-directory: tensormap-backend
        run: uv sync --frozen --extra dev

      - name: Run integration tests
        working-directory: tensormap-backend
        run: uv run pytest ../tests/integration/ -v --tb=short
26 changes: 26 additions & 0 deletions tensormap-backend/app/routers/data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from app.services.data_process import (
add_target_service,
augment_image_service,
delete_one_target_by_id_service,
get_all_targets_service,
get_column_stats_service,
Expand Down Expand Up @@ -95,3 +96,28 @@ def preprocess(file_id: uuid_pkg.UUID, request: PreprocessRequest, db: Session =
logger.debug("Preprocessing file_id=%s with %d transformations", file_id, len(request.transformations))
body, status_code = preprocess_data(db, file_id=file_id, transformations=request.transformations)
return JSONResponse(status_code=status_code, content=body)


@router.post("/data/augment/image/{file_id}")
def augment_image(
    file_id: uuid_pkg.UUID,
    technique: str = Query(
        "flip_horizontal",
        pattern="^(flip_horizontal|flip_vertical|rotate_90|brightness|zoom|gaussian_noise|random_crop)$",
    ),
    db: Session = Depends(get_db),
):
    """Generate a synthetic variant of an uploaded image via augmentation.

    The ``technique`` query parameter selects one of:
    ``flip_horizontal`` (mirror along the vertical axis),
    ``flip_vertical`` (mirror along the horizontal axis),
    ``rotate_90`` (90-degree rotation),
    ``brightness`` (brightness increased by 20%),
    ``zoom`` (zoom to 90% then resize),
    ``gaussian_noise`` (additive Gaussian noise), or
    ``random_crop`` (crop to 85% then resize).
    Delegates all work to ``augment_image_service``.
    """
    logger.debug("Applying %s augmentation to file_id=%s", technique, file_id)
    payload, status = augment_image_service(db, file_id=file_id, technique=technique)
    return JSONResponse(status_code=status, content=payload)
55 changes: 55 additions & 0 deletions tensormap-backend/app/routers/deep_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
from app.database import get_db
from app.schemas.deep_learning import ModelNameRequest, ModelSaveRequest, ModelValidateRequest, TrainingConfigRequest
from app.services.deep_learning import (
compare_runs_service,
delete_model_service,
export_model_service,
get_available_model_list,
get_code_service,
get_model_graph_service,
interpret_model_service,
model_save_service,
model_validate_service,
run_code_service,
tune_hyperparameters_service,
update_training_config_service,
)
from app.shared.logging_config import get_logger
Expand Down Expand Up @@ -111,3 +115,54 @@ def get_model_list(
"""Return a paginated list of saved model names, optionally filtered by project."""
body, status_code = get_available_model_list(db, project_id=project_id, offset=offset, limit=limit)
return JSONResponse(status_code=status_code, content=body)


@router.get("/model/interpret/{model_name}")
def interpret_model(
    model_name: str,
    file_id: uuid_pkg.UUID | None = Query(None),
    project_id: uuid_pkg.UUID | None = Query(None),
    db: Session = Depends(get_db),
):
    """Produce an interpretability report for a previously trained model.

    Optional ``file_id``/``project_id`` query parameters are forwarded
    unchanged to ``interpret_model_service``, which builds the analysis.
    """
    logger.debug("Interpreting model %s", model_name)
    payload, status = interpret_model_service(
        db, model_name=model_name, file_id=file_id, project_id=project_id
    )
    return JSONResponse(status_code=status, content=payload)


@router.get("/model/export/{model_name}")
def export_model(
    model_name: str,
    # Renamed from `format` to avoid shadowing the builtin; `alias="format"`
    # keeps the public HTTP query parameter name unchanged (?format=...).
    export_format: str = Query("savedmodel", alias="format", pattern="^(savedmodel|tflite|onnx)$"),
    project_id: uuid_pkg.UUID | None = Query(None),
    db: Session = Depends(get_db),
):
    """Export a trained model in the specified format.

    Query parameters:
    - ``format``: one of ``savedmodel``, ``tflite``, ``onnx`` (default ``savedmodel``).
    - ``project_id``: optional project scope forwarded to the service layer.
    """
    logger.debug("Exporting model %s as %s", model_name, export_format)
    body, status_code = export_model_service(
        db, model_name=model_name, export_format=export_format, project_id=project_id
    )
    return JSONResponse(status_code=status_code, content=body)


@router.get("/model/compare")
def compare_runs(
    project_id: uuid_pkg.UUID | None = Query(None),
    limit: int = Query(10, ge=1, le=100),
    db: Session = Depends(get_db),
):
    """Compare metrics across multiple training runs for a project.

    ``limit`` caps how many runs are compared (1-100, default 10); the
    comparison itself is computed by ``compare_runs_service``.
    """
    logger.debug("Comparing runs for project %s", project_id)
    payload, status = compare_runs_service(db, project_id=project_id, limit=limit)
    return JSONResponse(status_code=status, content=payload)


@router.get("/model/tune/{model_name}")
def tune_hyperparameters(
    model_name: str,
    file_id: uuid_pkg.UUID | None = Query(None),
    project_id: uuid_pkg.UUID | None = Query(None),
    db: Session = Depends(get_db),
):
    """Perform hyperparameter tuning for a model via grid search.

    Optional ``file_id``/``project_id`` query parameters are passed through
    to ``tune_hyperparameters_service``, which runs the search.
    """
    logger.debug("Tuning hyperparameters for %s", model_name)
    payload, status = tune_hyperparameters_service(
        db, model_name=model_name, file_id=file_id, project_id=project_id
    )
    return JSONResponse(status_code=status, content=payload)
84 changes: 84 additions & 0 deletions tensormap-backend/app/services/data_process.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import uuid as uuid_pkg
from collections.abc import Callable
from typing import Any
Expand Down Expand Up @@ -370,3 +371,86 @@ def preprocess_data(db: Session, file_id: uuid_pkg.UUID, transformations: list)
except Exception as e:
logger.exception("Error preprocessing data: %s", str(e))
return _resp(500, False, f"Error preprocessing data: {e}")


def augment_image_service(
    db: Session,
    file_id: uuid_pkg.UUID,
    technique: str = "flip_horizontal",
) -> tuple:
    """Apply an image augmentation technique to generate a synthetic variant.

    Supported techniques:
    - flip_horizontal: Mirror image along vertical axis
    - flip_vertical: Mirror image along horizontal axis
    - rotate_90: Rotate image by 90 degrees
    - brightness: Increase brightness by 20%
    - zoom: Center-crop to 90% then resize back to the original size
    - gaussian_noise: Add Gaussian noise (sigma 0.05 in normalized [0, 1] space)
    - random_crop: Randomly crop an 85% square then resize back

    Args:
        db: Database session used to look up the DataFile record.
        file_id: Primary key of the image's DataFile row.
        technique: Augmentation to apply (see list above).

    Returns:
        A ``(body, status_code)`` tuple as produced by ``_resp``; on success
        the payload contains ``output_path`` of the written image.
    """
    file_record = db.get(DataFile, file_id)
    if not file_record:
        return _resp(404, False, "Image file not found")

    file_path = file_record.file_path
    if not file_path or not os.path.exists(file_path):
        return _resp(404, False, "Image file not found on disk")

    # Validate the extension BEFORE creating the output directory, so a
    # rejected request leaves no empty "augmented_<id>" directory behind
    # (the original code called os.makedirs first).
    supported_formats = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}
    _, ext = os.path.splitext(file_path)
    if ext.lower() not in supported_formats:
        return _resp(400, False, f"Unsupported image format: {ext}")

    settings = get_settings()
    output_dir = os.path.join(settings.UPLOAD_DIRECTORY, f"augmented_{file_id}")
    os.makedirs(output_dir, exist_ok=True)

    try:
        # Lazy import: Pillow is optional; a missing install is reported as a
        # 500 below instead of breaking module import.
        from PIL import Image, ImageEnhance

        # Context manager guarantees the underlying file handle is closed even
        # if an augmentation step raises (the original leaked it).
        with Image.open(file_path) as original:
            if technique == "flip_horizontal":
                augmented = original.transpose(Image.FLIP_LEFT_RIGHT)
            elif technique == "flip_vertical":
                augmented = original.transpose(Image.FLIP_TOP_BOTTOM)
            elif technique == "rotate_90":
                # expand=True grows the canvas so non-square images aren't clipped.
                augmented = original.rotate(90, expand=True)
            elif technique == "brightness":
                augmented = ImageEnhance.Brightness(original).enhance(1.2)
            elif technique == "zoom":
                width, height = original.size
                new_width = int(width * 0.9)
                new_height = int(height * 0.9)
                left = (width - new_width) // 2
                top = (height - new_height) // 2
                cropped = original.crop((left, top, left + new_width, top + new_height))
                augmented = cropped.resize((width, height), Image.LANCZOS)
            elif technique == "gaussian_noise":
                # Work in normalized float space so the noise scale (0.05) is
                # independent of bit depth, then clip back to valid range.
                np_img = np.array(original.convert("RGB")).astype(np.float32) / 255.0
                noise = np.random.normal(0, 0.05, np_img.shape)
                np_img = np.clip(np_img + noise, 0, 1)
                augmented = Image.fromarray((np_img * 255).astype(np.uint8))
            elif technique == "random_crop":
                width, height = original.size
                crop_size = int(min(width, height) * 0.85)
                left = np.random.randint(0, width - crop_size + 1)
                top = np.random.randint(0, height - crop_size + 1)
                cropped = original.crop((left, top, left + crop_size, top + crop_size))
                augmented = cropped.resize((width, height), Image.LANCZOS)
            else:
                return _resp(400, False, f"Unknown technique: {technique}")

            # Include the technique in the filename: the original fixed name
            # ("augmented_0<ext>") silently overwrote the output of any
            # previously applied technique for the same file.
            output_path = os.path.join(output_dir, f"augmented_{technique}{ext}")
            augmented.save(output_path)

        logger.info("Applied %s augmentation to file %s", technique, file_id)
        return _resp(200, True, f"Generated augmented image using {technique}", {"output_path": output_path})

    except ImportError:
        return _resp(500, False, "Pillow not installed")
    except Exception as e:
        logger.exception("Error augmenting image: %s", str(e))
        return _resp(500, False, f"Error augmenting image: {e}")
Loading
Loading