|
2 | 2 | import sys |
3 | 3 | import base64 |
4 | 4 | import urllib.request |
| 5 | +import boto3 |
| 6 | +from botocore.config import Config |
5 | 7 |
|
6 | 8 | import cv2 |
7 | 9 | import numpy as np |
8 | 10 | import torch |
9 | 11 | import runpod |
10 | 12 |
|
| 13 | +# Initialize S3 client for R2 (Lazy) |
| 14 | +s3_client = None |
| 15 | + |
| 16 | +def get_s3_client(): |
| 17 | + global s3_client |
| 18 | + if s3_client is None: |
| 19 | + access_key = os.environ.get("R2_ACCESS_KEY_ID") |
| 20 | + secret_key = os.environ.get("R2_SECRET_ACCESS_KEY") |
| 21 | + endpoint = os.environ.get("R2_ENDPOINT") |
| 22 | + if all([access_key, secret_key, endpoint]): |
| 23 | + s3_client = boto3.client( |
| 24 | + "s3", |
| 25 | + endpoint_url=endpoint, |
| 26 | + aws_access_key_id=access_key, |
| 27 | + aws_secret_access_key=secret_key, |
| 28 | + config=Config(signature_version="s3v4"), |
| 29 | + region_name="auto", |
| 30 | + ) |
| 31 | + return s3_client |
| 32 | + |
| 33 | +def upload_to_r2(buffer, key, content_type="image/jpeg"): |
| 34 | + client = get_s3_client() |
| 35 | + if client is None: |
| 36 | + return None |
| 37 | + |
| 38 | + bucket = os.environ.get("R2_BUCKET_NAME", "stockgen-ai") |
| 39 | + try: |
| 40 | + client.put_object( |
| 41 | + Bucket=bucket, |
| 42 | + Key=key, |
| 43 | + Body=buffer, |
| 44 | + ContentType=content_type, |
| 45 | + CacheControl="public, max-age=31536000" |
| 46 | + ) |
| 47 | + custom_domain = os.environ.get("R2_CUSTOM_DOMAIN") |
| 48 | + if custom_domain: |
| 49 | + return f"{custom_domain.rstrip('/')}/{key}" |
| 50 | + return f"{os.environ.get('R2_ENDPOINT').rstrip('/')}/{bucket}/{key}" |
| 51 | + except Exception as e: |
| 52 | + print(f"❌ R2 Upload Failed: {str(e)}") |
| 53 | + return None |
| 54 | + |
11 | 55 | # Patch basicsr compatibility with newer PyTorch/torchvision |
12 | 56 | import torchvision.transforms.functional as F_tv |
13 | 57 | import types |
@@ -240,13 +284,24 @@ def handler(job): |
240 | 284 | if not success: |
241 | 285 | return {"error": "Failed to encode output image"} |
242 | 286 |
|
| 287 | + # Direct-to-R2 Upload Logic |
| 288 | + r2_url = None |
| 289 | + r2_key = job_input.get("r2_key") # Target path like "users/uid/batches/bid/jid.jpg" |
| 290 | + |
| 291 | + if r2_key: |
| 292 | + print(f"☁️ [Direct-to-R2] Uploading result to: {r2_key}") |
| 293 | + r2_url = upload_to_r2(encoded_img.tobytes(), r2_key) |
| 294 | + if r2_url: |
| 295 | + print(f"✅ [Direct-to-R2] Success: {r2_url}") |
| 296 | + |
243 | 297 | b64 = base64.b64encode(encoded_img).decode("utf-8") |
244 | 298 |
|
245 | 299 | h, w = img.shape[:2] |
246 | 300 | oh, ow = output.shape[:2] |
247 | 301 |
|
248 | 302 | return { |
249 | | - "image": b64, |
| 303 | + "image": b64 if not r2_url else None, # Skip Base64 if R2 success to save bandwidth |
| 304 | + "r2_url": r2_url, |
250 | 305 | "image_format": "jpg", |
251 | 306 | "model": model_name, |
252 | 307 | "face_enhance_applied": face_enhance, |
|
0 commit comments