Skip to content

Commit 9bf2dc4

Browse files
committed
Merge main into interpretability-clean
2 parents 580e0bf + 30da92f commit 9bf2dc4

6 files changed

Lines changed: 223 additions & 0 deletions

File tree

tensormap-backend/app/routers/data_process.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from app.services.data_process import (
1313
add_target_service,
14+
augment_image_service,
1415
delete_one_target_by_id_service,
1516
get_all_targets_service,
1617
get_column_stats_service,
@@ -95,3 +96,28 @@ def preprocess(file_id: uuid_pkg.UUID, request: PreprocessRequest, db: Session =
9596
logger.debug("Preprocessing file_id=%s with %d transformations", file_id, len(request.transformations))
9697
body, status_code = preprocess_data(db, file_id=file_id, transformations=request.transformations)
9798
return JSONResponse(status_code=status_code, content=body)
99+
100+
101+
@router.post("/data/augment/image/{file_id}")
102+
def augment_image(
103+
file_id: uuid_pkg.UUID,
104+
technique: str = Query(
105+
"flip_horizontal",
106+
pattern="^(flip_horizontal|flip_vertical|rotate_90|brightness|zoom|gaussian_noise|random_crop)$",
107+
),
108+
db: Session = Depends(get_db),
109+
):
110+
"""Apply image augmentation techniques to generate synthetic variants.
111+
112+
Supported techniques:
113+
- flip_horizontal: Mirror along vertical axis
114+
- flip_vertical: Mirror along horizontal axis
115+
- rotate_90: Rotate by 90 degrees
116+
- brightness: Increase brightness by 20%
117+
- zoom: Zoom to 90% then resize
118+
- gaussian_noise: Add Gaussian noise
119+
- random_crop: Crop to 85% then resize
120+
"""
121+
logger.debug("Applying %s augmentation to file_id=%s", technique, file_id)
122+
body, status_code = augment_image_service(db, file_id=file_id, technique=technique)
123+
return JSONResponse(status_code=status_code, content=body)

tensormap-backend/app/routers/deep_learning.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from app.schemas.deep_learning import ModelNameRequest, ModelSaveRequest, ModelValidateRequest, TrainingConfigRequest
1111
from app.services.deep_learning import (
1212
delete_model_service,
13+
export_model_service,
1314
get_available_model_list,
1415
get_code_service,
1516
get_model_graph_service,
@@ -125,3 +126,16 @@ def interpret_model(
125126
logger.debug("Interpreting model %s", model_name)
126127
body, status_code = interpret_model_service(db, model_name=model_name, file_id=file_id, project_id=project_id)
127128
return JSONResponse(status_code=status_code, content=body)
129+
130+
131+
@router.get("/model/export/{model_name}")
132+
def export_model(
133+
model_name: str,
134+
format: str = Query("savedmodel", pattern="^(savedmodel|tflite|onnx)$"),
135+
project_id: uuid_pkg.UUID | None = Query(None),
136+
db: Session = Depends(get_db),
137+
):
138+
"""Export a trained model in the specified format."""
139+
logger.debug("Exporting model %s as %s", model_name, format)
140+
body, status_code = export_model_service(db, model_name=model_name, export_format=format, project_id=project_id)
141+
return JSONResponse(status_code=status_code, content=body)

tensormap-backend/app/services/data_process.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import uuid as uuid_pkg
23
from collections.abc import Callable
34
from typing import Any
@@ -370,3 +371,86 @@ def preprocess_data(db: Session, file_id: uuid_pkg.UUID, transformations: list)
370371
except Exception as e:
371372
logger.exception("Error preprocessing data: %s", str(e))
372373
return _resp(500, False, f"Error preprocessing data: {e}")
374+
375+
376+
def augment_image_service(
377+
db: Session,
378+
file_id: uuid_pkg.UUID,
379+
technique: str = "flip_horizontal",
380+
) -> tuple:
381+
"""Apply image augmentation techniques to generate synthetic variants.
382+
383+
Supported techniques:
384+
- flip_horizontal: Mirror image along vertical axis
385+
- flip_vertical: Mirror image along horizontal axis
386+
- rotate_90: Rotate image by 90 degrees
387+
- brightness: Adjust brightness by 20%
388+
- zoom: Zoom to 90% then resize
389+
- gaussian_noise: Add Gaussian noise
390+
- random_crop: Crop to 85% then resize
391+
"""
392+
file_record = db.get(DataFile, file_id)
393+
if not file_record:
394+
return _resp(404, False, "Image file not found")
395+
396+
file_path = file_record.file_path
397+
if not file_path or not os.path.exists(file_path):
398+
return _resp(404, False, "Image file not found on disk")
399+
400+
settings = get_settings()
401+
output_dir = os.path.join(settings.UPLOAD_DIRECTORY, f"augmented_{file_id}")
402+
os.makedirs(output_dir, exist_ok=True)
403+
404+
supported_formats = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}
405+
_, ext = os.path.splitext(file_path)
406+
if ext.lower() not in supported_formats:
407+
return _resp(400, False, f"Unsupported image format: {ext}")
408+
409+
try:
410+
from PIL import Image, ImageEnhance
411+
412+
original = Image.open(file_path)
413+
414+
if technique == "flip_horizontal":
415+
augmented = original.transpose(Image.FLIP_LEFT_RIGHT)
416+
elif technique == "flip_vertical":
417+
augmented = original.transpose(Image.FLIP_TOP_BOTTOM)
418+
elif technique == "rotate_90":
419+
augmented = original.rotate(90, expand=True)
420+
elif technique == "brightness":
421+
enhancer = ImageEnhance.Brightness(original)
422+
augmented = enhancer.enhance(1.2)
423+
elif technique == "zoom":
424+
width, height = original.size
425+
new_width = int(width * 0.9)
426+
new_height = int(height * 0.9)
427+
left = (width - new_width) // 2
428+
top = (height - new_height) // 2
429+
cropped = original.crop((left, top, left + new_width, top + new_height))
430+
augmented = cropped.resize((width, height), Image.LANCZOS)
431+
elif technique == "gaussian_noise":
432+
np_img = np.array(original.convert("RGB")).astype(np.float32) / 255.0
433+
noise = np.random.normal(0, 0.05, np_img.shape)
434+
np_img = np.clip(np_img + noise, 0, 1)
435+
augmented = Image.fromarray((np_img * 255).astype(np.uint8))
436+
elif technique == "random_crop":
437+
width, height = original.size
438+
crop_size = int(min(width, height) * 0.85)
439+
left = np.random.randint(0, width - crop_size + 1)
440+
top = np.random.randint(0, height - crop_size + 1)
441+
cropped = original.crop((left, top, left + crop_size, top + crop_size))
442+
augmented = cropped.resize((width, height), Image.LANCZOS)
443+
else:
444+
return _resp(400, False, f"Unknown technique: {technique}")
445+
446+
output_path = os.path.join(output_dir, f"augmented_0{ext}")
447+
augmented.save(output_path)
448+
449+
logger.info("Applied %s augmentation to file %s", technique, file_id)
450+
return _resp(200, True, f"Generated augmented image using {technique}", {"output_path": output_path})
451+
452+
except ImportError:
453+
return _resp(500, False, "Pillow not installed")
454+
except Exception as e:
455+
logger.exception("Error augmenting image: %s", str(e))
456+
return _resp(500, False, f"Error augmenting image: {e}")

tensormap-backend/app/services/deep_learning.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,3 +601,63 @@ def interpret_model_service(
601601
except Exception as e:
602602
logger.exception("Interpretability failed: %s", str(e))
603603
return {"success": False, "message": f"Interpretability failed: {e}", "data": None}, 500
604+
605+
606+
def export_model_service(
607+
db: Session, model_name: str, export_format: str = "savedmodel", project_id: uuid_pkg.UUID | None = None
608+
) -> tuple:
609+
"""Export a trained model to various formats.
610+
611+
Supports: savedmodel (TensorFlow), tflite (TensorFlow Lite), onnx (ONNX).
612+
"""
613+
from app.models import ModelBasic
614+
615+
stmt = select(ModelBasic).where(ModelBasic.model_name == model_name)
616+
if project_id is not None:
617+
stmt = stmt.where(ModelBasic.project_id == project_id)
618+
model = db.exec(stmt).first()
619+
620+
if not model:
621+
return {"success": False, "message": f"Model '{model_name}' not found", "data": None}, 404
622+
623+
model_path = os.path.join(MODEL_GENERATION_LOCATION, model_name + MODEL_GENERATION_TYPE)
624+
if not os.path.exists(model_path):
625+
return {"success": False, "message": "Model file not found", "data": None}, 404
626+
627+
try:
628+
loaded_model = tf.keras.models.load_model(model_path)
629+
except Exception as e:
630+
logger.error("Failed to load model: %s", e)
631+
return {"success": False, "message": f"Could not load model: {e}", "data": None}, 400
632+
633+
export_dir = os.path.join(MODEL_GENERATION_LOCATION, f"{model_name}_export", export_format)
634+
os.makedirs(export_dir, exist_ok=True)
635+
636+
try:
637+
if export_format == "savedmodel":
638+
saved_path = os.path.join(export_dir, model_name)
639+
loaded_model.save(saved_path)
640+
return {"success": True, "message": f"Exported to {saved_path}", "data": {"path": saved_path}}, 200
641+
642+
if export_format == "tflite":
643+
converter = tf.lite.TFLiteConverter.from_keras_model(loaded_model)
644+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
645+
tflite_model = converter.convert()
646+
tflite_path = os.path.join(export_dir, f"{model_name}.tflite")
647+
with open(tflite_path, "wb") as f:
648+
f.write(tflite_model)
649+
return {"success": True, "message": f"Exported to {tflite_path}", "data": {"path": tflite_path}}, 200
650+
651+
if export_format == "onnx":
652+
try:
653+
import tf2onnx
654+
except ImportError:
655+
return {"success": False, "message": "ONNX not available", "data": None}, 501
656+
onnx_path = os.path.join(export_dir, f"{model_name}.onnx")
657+
tf2onnx.convert.from_keras(loaded_model, output_path=onnx_path)
658+
return {"success": True, "message": f"Exported to {onnx_path}", "data": {"path": onnx_path}}, 200
659+
660+
return {"success": False, "message": f"Unsupported: {export_format}"}, 400
661+
except Exception as e:
662+
logger.error("Export failed: %s", e)
663+
return {"success": False, "message": f"Export failed: {e}"}, 500
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Unit tests for model export service."""
2+
3+
import sys
4+
from unittest.mock import MagicMock
5+
6+
sys.modules.setdefault("tensorflow", MagicMock())
7+
sys.modules.setdefault("flatten_json", MagicMock())
8+
9+
10+
class TestExportModelService:
11+
def test_import(self):
12+
from app.services.deep_learning import export_model_service
13+
14+
assert callable(export_model_service)
15+
16+
def test_router_import(self):
17+
from app.routers.deep_learning import export_model
18+
19+
assert callable(export_model)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Unit tests for image augmentation service."""
2+
3+
import sys
4+
from unittest.mock import MagicMock
5+
6+
sys.modules.setdefault("tensorflow", MagicMock())
7+
sys.modules.setdefault("flatten_json", MagicMock())
8+
sys.modules.setdefault("pandas", MagicMock())
9+
10+
11+
class TestImageAugmentation:
12+
def test_import(self):
13+
from app.services.data_process import augment_image_service
14+
15+
assert callable(augment_image_service)
16+
17+
def test_router_import(self):
18+
from app.routers.data_process import augment_image
19+
20+
assert callable(augment_image)

0 commit comments

Comments
 (0)