Skip to content

Commit 30da92f

Browse files
committed
Merge branch 'main' into pr312-export-clean
2 parents 28dd4dd + a2220ac commit 30da92f

3 files changed

Lines changed: 130 additions & 0 deletions

File tree

tensormap-backend/app/routers/data_process.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from app.services.data_process import (
1313
add_target_service,
14+
augment_image_service,
1415
delete_one_target_by_id_service,
1516
get_all_targets_service,
1617
get_column_stats_service,
@@ -95,3 +96,28 @@ def preprocess(file_id: uuid_pkg.UUID, request: PreprocessRequest, db: Session =
9596
logger.debug("Preprocessing file_id=%s with %d transformations", file_id, len(request.transformations))
9697
body, status_code = preprocess_data(db, file_id=file_id, transformations=request.transformations)
9798
return JSONResponse(status_code=status_code, content=body)
99+
100+
101+
@router.post("/data/augment/image/{file_id}")
102+
def augment_image(
103+
file_id: uuid_pkg.UUID,
104+
technique: str = Query(
105+
"flip_horizontal",
106+
pattern="^(flip_horizontal|flip_vertical|rotate_90|brightness|zoom|gaussian_noise|random_crop)$",
107+
),
108+
db: Session = Depends(get_db),
109+
):
110+
"""Apply image augmentation techniques to generate synthetic variants.
111+
112+
Supported techniques:
113+
- flip_horizontal: Mirror along vertical axis
114+
- flip_vertical: Mirror along horizontal axis
115+
- rotate_90: Rotate by 90 degrees
116+
- brightness: Increase brightness by 20%
117+
- zoom: Zoom to 90% then resize
118+
- gaussian_noise: Add Gaussian noise
119+
- random_crop: Crop to 85% then resize
120+
"""
121+
logger.debug("Applying %s augmentation to file_id=%s", technique, file_id)
122+
body, status_code = augment_image_service(db, file_id=file_id, technique=technique)
123+
return JSONResponse(status_code=status_code, content=body)

tensormap-backend/app/services/data_process.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import uuid as uuid_pkg
23
from collections.abc import Callable
34
from typing import Any
@@ -370,3 +371,86 @@ def preprocess_data(db: Session, file_id: uuid_pkg.UUID, transformations: list)
370371
except Exception as e:
371372
logger.exception("Error preprocessing data: %s", str(e))
372373
return _resp(500, False, f"Error preprocessing data: {e}")
374+
375+
376+
def augment_image_service(
377+
db: Session,
378+
file_id: uuid_pkg.UUID,
379+
technique: str = "flip_horizontal",
380+
) -> tuple:
381+
"""Apply image augmentation techniques to generate synthetic variants.
382+
383+
Supported techniques:
384+
- flip_horizontal: Mirror image along vertical axis
385+
- flip_vertical: Mirror image along horizontal axis
386+
- rotate_90: Rotate image by 90 degrees
387+
- brightness: Adjust brightness by 20%
388+
- zoom: Zoom to 90% then resize
389+
- gaussian_noise: Add Gaussian noise
390+
- random_crop: Crop to 85% then resize
391+
"""
392+
file_record = db.get(DataFile, file_id)
393+
if not file_record:
394+
return _resp(404, False, "Image file not found")
395+
396+
file_path = file_record.file_path
397+
if not file_path or not os.path.exists(file_path):
398+
return _resp(404, False, "Image file not found on disk")
399+
400+
settings = get_settings()
401+
output_dir = os.path.join(settings.UPLOAD_DIRECTORY, f"augmented_{file_id}")
402+
os.makedirs(output_dir, exist_ok=True)
403+
404+
supported_formats = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}
405+
_, ext = os.path.splitext(file_path)
406+
if ext.lower() not in supported_formats:
407+
return _resp(400, False, f"Unsupported image format: {ext}")
408+
409+
try:
410+
from PIL import Image, ImageEnhance
411+
412+
original = Image.open(file_path)
413+
414+
if technique == "flip_horizontal":
415+
augmented = original.transpose(Image.FLIP_LEFT_RIGHT)
416+
elif technique == "flip_vertical":
417+
augmented = original.transpose(Image.FLIP_TOP_BOTTOM)
418+
elif technique == "rotate_90":
419+
augmented = original.rotate(90, expand=True)
420+
elif technique == "brightness":
421+
enhancer = ImageEnhance.Brightness(original)
422+
augmented = enhancer.enhance(1.2)
423+
elif technique == "zoom":
424+
width, height = original.size
425+
new_width = int(width * 0.9)
426+
new_height = int(height * 0.9)
427+
left = (width - new_width) // 2
428+
top = (height - new_height) // 2
429+
cropped = original.crop((left, top, left + new_width, top + new_height))
430+
augmented = cropped.resize((width, height), Image.LANCZOS)
431+
elif technique == "gaussian_noise":
432+
np_img = np.array(original.convert("RGB")).astype(np.float32) / 255.0
433+
noise = np.random.normal(0, 0.05, np_img.shape)
434+
np_img = np.clip(np_img + noise, 0, 1)
435+
augmented = Image.fromarray((np_img * 255).astype(np.uint8))
436+
elif technique == "random_crop":
437+
width, height = original.size
438+
crop_size = int(min(width, height) * 0.85)
439+
left = np.random.randint(0, width - crop_size + 1)
440+
top = np.random.randint(0, height - crop_size + 1)
441+
cropped = original.crop((left, top, left + crop_size, top + crop_size))
442+
augmented = cropped.resize((width, height), Image.LANCZOS)
443+
else:
444+
return _resp(400, False, f"Unknown technique: {technique}")
445+
446+
output_path = os.path.join(output_dir, f"augmented_0{ext}")
447+
augmented.save(output_path)
448+
449+
logger.info("Applied %s augmentation to file %s", technique, file_id)
450+
return _resp(200, True, f"Generated augmented image using {technique}", {"output_path": output_path})
451+
452+
except ImportError:
453+
return _resp(500, False, "Pillow not installed")
454+
except Exception as e:
455+
logger.exception("Error augmenting image: %s", str(e))
456+
return _resp(500, False, f"Error augmenting image: {e}")
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Unit tests for image augmentation service."""
2+
3+
import sys
4+
from unittest.mock import MagicMock
5+
6+
sys.modules.setdefault("tensorflow", MagicMock())
7+
sys.modules.setdefault("flatten_json", MagicMock())
8+
sys.modules.setdefault("pandas", MagicMock())
9+
10+
11+
class TestImageAugmentation:
12+
def test_import(self):
13+
from app.services.data_process import augment_image_service
14+
15+
assert callable(augment_image_service)
16+
17+
def test_router_import(self):
18+
from app.routers.data_process import augment_image
19+
20+
assert callable(augment_image)

0 commit comments

Comments
 (0)