diff --git a/README.md b/README.md
index 6c050b5..c3d435f 100644
--- a/README.md
+++ b/README.md
@@ -49,6 +49,7 @@ A GenAI-powered catalog enrichment system that transforms basic product images i
 **AI Models:**
 - NVIDIA Nemotron VLM (vision-language model)
 - NVIDIA Nemotron LLM (prompt planning)
+- NVIDIA Embeddings (Policy Compliance)
 - FLUX models (image generation)
 - Microsoft TRELLIS (3D generation)
@@ -56,6 +57,7 @@ A GenAI-powered catalog enrichment system that transforms basic product images i
 - Docker & Docker Compose
 - NVIDIA NIM containers
 - HuggingFace model hosting
+- Milvus vector database for policy PDF retrieval
 
 ## Minimum System Requirements
@@ -67,10 +69,11 @@ For self-hosting the NIM microservices locally, the following GPU requirements a
 |-------|---------|-------------|-----------------|
 | Nemotron-Nano-12B-V2-VL | Vision-Language Analysis | 1× A100 | 1× H100 |
 | Nemotron-Nano-V3 | Prompt Planning (LLM) | 1× A100 | 1× H100 |
+| nv-embedqa | Embeddings (Policy Compliance) | 1× A100 | 1× H100 |
 | FLUX Kontext Dev | Image Generation | 1× H100 | 1× H100 |
-| Microsoft TRELLIS | 3D Asset Generation | 1× L40S | 1× L40S |
+| Microsoft TRELLIS | 3D Asset Generation | 1× L40S | 1× H100 |
 
-**Total recommended setup**: 3× H100 + 1× L40S (or 4× H100 for uniform configuration)
+**Total recommended setup**: 3× H100 + 1× L40S (or 4× H100 for uniform configuration). The embeddings model can be deployed on the same GPU as the FLUX or TRELLIS models.
 
 ### Deployment Options
@@ -146,6 +149,10 @@ Make sure you have accepted [https://huggingface.co/black-forest-labs/FLUX.1-Kon
   trellis:
     url: "http://localhost:8004/v1/infer" # Your TRELLIS NIM endpoint
+
+  embeddings:
+    url: "http://localhost:8005/v1" # Your Embeddings NIM endpoint
+    model: "nvidia/nv-embedqa-e5-v5"
 ```
 
 See the **[Docker Deployment Guide](docs/DOCKER.md)** for instructions on deploying these NIMs.
@@ -166,7 +173,7 @@
 The frontend at `http://localhost:3000`.
 
 ### Docker Deployment (Self-Hosted NIMs)
 
-The Docker deployment includes all required self-hosted NVIDIA NIM containers (Nemotron VLM, Nemotron LLM, FLUX, and TRELLIS). The `shared/config/config.yaml` is pre-configured with the correct service URLs for Docker networking.
+The Docker deployment includes all required self-hosted NVIDIA NIM containers (Nemotron VLM, Nemotron LLM, FLUX, and TRELLIS). If you want to use uploaded policy PDFs in the UI, start the companion Milvus stack from `docker-compose.rag.yml` as well. The `shared/config/config.yaml` is pre-configured with the correct service URLs for Docker networking.
 
 For complete Docker deployment instructions, see the **[Docker Deployment Guide](docs/DOCKER.md)**.
@@ -185,15 +192,27 @@ For complete Docker deployment instructions, see the **[Docker Deployment Guide]
    chmod a+w "$LOCAL_NIM_CACHE"
    ```
 
-3. **Start all services**:
+3. **Create the shared Docker network**:
+   ```bash
+   docker network create catalog-network || true
+   ```
+
+4. **Start the policy RAG stack**:
+   ```bash
+   docker compose -f docker-compose.rag.yml up -d
+   ```
+
+5. **Start the application stack**:
    ```bash
-   docker-compose up -d
+   docker compose up -d
    ```
 
-4. **Access the application**:
+6. **Access the application**:
    - Frontend: `http://localhost:3000`
    - Backend API: `http://localhost:8000`
    - Health Check: `http://localhost:8000/health`
+   - Milvus: `localhost:19530`
+   - MinIO Console: `http://localhost:9001`
 
 ## API Endpoints
@@ -211,7 +230,7 @@ For detailed API documentation with request/response examples, see **[API Docume
 
 ## License
 
-GOVERNING TERMS: The Blueprint scripts are governed by Apache License, Version 2.0, and enables use of separate open source and proprietary software governed by their respective licenses: [NVIDIA-Nemotron-Nano-12B-v2-VL](https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/nemotron-nano-12b-v2-vl?version=1), [Nemotron-Nano-V3](https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/nemotron-3-nano?version=1.7.0), [FLUX.1-Kontext-Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/blob/main/LICENSE.md), and [Microsoft TRELLIS](https://catalog.ngc.nvidia.com/orgs/nim/teams/microsoft/containers/trellis?version=1).
+GOVERNING TERMS: The Blueprint scripts are governed by Apache License, Version 2.0, and enable use of separate open source and proprietary software governed by their respective licenses: [NVIDIA-Nemotron-Nano-12B-v2-VL](https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/nemotron-nano-12b-v2-vl?version=1), [Nemotron-Nano-V3](https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/nemotron-3-nano?version=1.7.0), [nv-embedqa-e5-v5](https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/nv-embedqa-e5-v5?version=latest), [FLUX.1-Kontext-Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/blob/main/LICENSE.md), and [Microsoft TRELLIS](https://catalog.ngc.nvidia.com/orgs/nim/teams/microsoft/containers/trellis?version=1).
 
 ADDITIONAL INFORMATION: 
 FLUX.1-Kontext-Dev license: [https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/blob/main/LICENSE.md](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/blob/main/LICENSE.md).
@@ -219,4 +238,4 @@ FLUX.1-Kontext-Dev license: [https://huggingface.co/black-forest-labs/FLUX.1-Kon
 
 Third-Party Community Consideration: The FLUX Kontext model is not owned or developed by NVIDIA. This model has been developed and built to a third-party’s requirements for this application and use case; see link to: black-forest-labs/FLUX.1-Kontext-dev Model Card - [https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).
 
-This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use.
\ No newline at end of file
+This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use.
diff --git a/deploy/1_Deploy_Catalog_Enrichment.ipynb b/deploy/1_Deploy_Catalog_Enrichment.ipynb
index 59c35fe..d71f5c3 100644
--- a/deploy/1_Deploy_Catalog_Enrichment.ipynb
+++ b/deploy/1_Deploy_Catalog_Enrichment.ipynb
@@ -357,7 +357,7 @@
   "source": [
    "\n",
    "## Spin Up Blueprint\n",
-    "Docker compose scripts are provided which spin up the microservices on a single node. This docker-compose yaml file will start the agents as well as dependant microservices. This may take up to **15 minutes** to complete.\n"
+    "Docker compose scripts are provided which spin up the microservices on a single node. 
Start by creating the shared Docker network, then launch the Milvus policy RAG stack from `docker-compose.rag.yml`, and finally bring up the main application stack. This may take up to **15 minutes** to complete.\n" ] }, { @@ -369,6 +369,8 @@ }, "outputs": [], "source": [ + "!docker network create catalog-network || true\n", + "!docker compose -f docker-compose.rag.yml up -d > /dev/null 2>&1\n", "!docker compose up -d > /dev/null 2>&1" ] }, @@ -413,7 +415,7 @@ "id": "7d90c358-f0e9-4607-8b88-32a44ffce74e", "metadata": {}, "source": [ - "This command should produce similiar output in the following format:" + "These commands should produce similar output in the following format:" ] }, { @@ -430,6 +432,10 @@ "nim-llm 2025-12-16 18:30:24 +0000 UTC Up 1 minutes\n", "nim-trellis 2025-12-16 18:30:24 +0000 UTC Up 1 minutes\n", "nim-flux 2025-12-16 18:30:24 +0000 UTC Up 1 minutes\n", + "embedqa 2025-12-16 18:30:24 +0000 UTC Up 1 minutes\n", + "milvus-etcd 2025-12-16 18:30:24 +0000 UTC Up 1 minutes (healthy)\n", + "milvus-minio 2025-12-16 18:30:24 +0000 UTC Up 1 minutes (healthy)\n", + "milvus-standalone 2025-12-16 18:30:24 +0000 UTC Up 1 minutes (healthy)\n", "```" ] }, @@ -529,7 +535,7 @@ "\n", "## Stopping Services and Cleaning Up\n", "\n", - "To shut down the microservices, run the following command" + "To shut down the microservices, run the following commands" ] }, { @@ -539,7 +545,8 @@ "metadata": {}, "outputs": [], "source": [ - "!docker compose down > /dev/null 2>&1" + "!docker compose down > /dev/null 2>&1\n", + "!docker compose -f docker-compose.rag.yml down > /dev/null 2>&1" ] }, { @@ -577,7 +584,8 @@ "\n", "**Explanation:** When running the blueprint for the first time, all models need to be downloaded from their respective sources. Depending on your internet connection speed, this process can take 20-30 minutes or longer. 
The models include:\n",
    "- NVIDIA Nemotron VLM\n",
-    "- NVIDIA Nemotron LLM \n",
+    "- NVIDIA Nemotron LLM \n",
+    "- NVIDIA Embeddings \n",
    "- FLUX image generation model\n",
    "- TRELLIS 3D asset generation model\n",
    "\n",
@@ -596,7 +604,7 @@
   "source": [
    "## LICENSE\n",
    "\n",
-    "GOVERNING TERMS: The Blueprint scripts are governed by Apache License, Version 2.0, and enables use of separate open source and proprietary software governed by their respective licenses: [NVIDIA-Nemotron-Nano-12B-v2-VL](https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/nemotron-nano-12b-v2-vl?version=1), [Nemotron-Nano-V3](https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/nemotron-3-nano?version=1.7.0), [FLUX.1-Kontext-Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/blob/main/LICENSE.md), and [Microsoft TRELLIS](https://catalog.ngc.nvidia.com/orgs/nim/teams/microsoft/containers/trellis?version=1).\n",
+    "GOVERNING TERMS: The Blueprint scripts are governed by Apache License, Version 2.0, and enable use of separate open source and proprietary software governed by their respective licenses: [NVIDIA-Nemotron-Nano-12B-v2-VL](https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/nemotron-nano-12b-v2-vl?version=1), [Nemotron-Nano-V3](https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/nemotron-3-nano?version=1.7.0), [nv-embedqa-e5-v5](https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/nv-embedqa-e5-v5?version=latest), [FLUX.1-Kontext-Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/blob/main/LICENSE.md), and [Microsoft TRELLIS](https://catalog.ngc.nvidia.com/orgs/nim/teams/microsoft/containers/trellis?version=1).\n",
    "\n",
    "ADDITIONAL INFORMATION: \n",
    "FLUX.1-Kontext-Dev license: [https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/blob/main/LICENSE.md](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/blob/main/LICENSE.md).\n",
diff --git a/docker-compose.rag.yml b/docker-compose.rag.yml
new file mode 100644
index 0000000..21888e0
--- /dev/null
+++ b/docker-compose.rag.yml
@@ -0,0 +1,73 @@
+services:
+  milvus-etcd:
+    image: quay.io/coreos/etcd:v3.5.5
+    container_name: milvus-etcd
+    environment:
+      - ETCD_AUTO_COMPACTION_MODE=revision
+      - ETCD_AUTO_COMPACTION_RETENTION=1000
+      - ETCD_QUOTA_BACKEND_BYTES=4294967296
+      - ETCD_SNAPSHOT_COUNT=50000
+    volumes:
+      - etcd_data:/etcd
+    command: etcd -listen-client-urls=http://0.0.0.0:2379 -advertise-client-urls=http://milvus-etcd:2379 --data-dir /etcd
+    healthcheck:
+      test: ["CMD", "etcdctl", "endpoint", "health"]
+      interval: 30s
+      timeout: 20s
+      retries: 3
+    networks:
+      - catalog-network
+
+  milvus-minio:
+    image: minio/minio:RELEASE.2023-03-20T20-16-18Z
+    container_name: milvus-minio
+    environment:
+      MINIO_ACCESS_KEY: minioadmin
+      MINIO_SECRET_KEY: minioadmin
+    volumes:
+      - minio_data:/minio_data
+    command: minio server /minio_data --console-address ":9001"
+    ports:
+      - "9001:9001"
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
+      interval: 30s
+      timeout: 20s
+      retries: 3
+    networks:
+      - catalog-network
+
+  milvus-standalone:
+    image: milvusdb/milvus:v2.4.0
+    container_name: milvus-standalone
+    command: ["milvus", "run", "standalone"]
+    ports:
+      - "19530:19530"
+      - "9091:9091"
+    environment:
+      ETCD_ENDPOINTS: milvus-etcd:2379
+      MINIO_ADDRESS: milvus-minio:9000
+    volumes:
+      - milvus_data:/var/lib/milvus
+    depends_on:
+      milvus-etcd:
+        condition: service_healthy
+      milvus-minio:
+        condition: service_healthy
+    healthcheck:
+      test: 
["CMD", "curl", "-f", "http://localhost:9091/healthz"] + interval: 30s + timeout: 20s + retries: 3 + networks: + - catalog-network + +networks: + catalog-network: + external: true + name: catalog-network + +volumes: + etcd_data: + minio_data: + milvus_data: diff --git a/docker-compose.yml b/docker-compose.yml index b605183..56e04c6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -22,9 +22,14 @@ services: container_name: catalog-enrichment-backend environment: - NGC_API_KEY=${NGC_API_KEY} + - NVIDIA_API_KEY=${NVIDIA_API_KEY:-} + - NVIDIA_API_BASE_URL=${NVIDIA_API_BASE_URL:-https://integrate.api.nvidia.com/v1} - HF_TOKEN=${HF_TOKEN} + - MILVUS_HOST=${MILVUS_HOST:-milvus-standalone} + - MILVUS_PORT=${MILVUS_PORT:-19530} volumes: - ./data/outputs:/app/data/outputs + - ./data/policies:/app/data/policies - ./shared/config:/app/shared/config:ro depends_on: - vlm-nim @@ -105,6 +110,33 @@ services: networks: - catalog-network + # NVIDIA NIM - Embedding Model + embedqa: + image: nvcr.io/nim/nvidia/nv-embedqa-e5-v5:1.6 + container_name: embedqa + ports: + - "8005:8000" + environment: + - NGC_API_KEY=${NGC_API_KEY} + volumes: + - ${LOCAL_NIM_CACHE:-~/.cache/nim}:/opt/nim/.cache + user: "${UID:-1000}" + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ['2'] + capabilities: [gpu] + restart: "no" + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/v1/health/ready"] + interval: 30s + timeout: 10s + retries: 5 + networks: + - catalog-network + # Trellis Model - 3D Asset Generation trellis-nim: image: nvcr.io/nim/microsoft/trellis:1.0.1 @@ -157,9 +189,9 @@ services: networks: catalog-network: - driver: bridge + external: true + name: catalog-network volumes: nim-cache: driver: local - diff --git a/docs/API.md b/docs/API.md index c74acbe..8f3cb2a 100644 --- a/docs/API.md +++ b/docs/API.md @@ -14,7 +14,7 @@ Returns a plaintext greeting message. **Response**: ``` -Welcome to Catalog Enrichment API +Catalog Enrichment Backend ``` ### GET `/health` @@ -48,9 +48,110 @@ The API provides a modular approach for optimal performance and flexibility: --- -## 1️⃣ Fast VLM Analysis: `/vlm/analyze` +## 1️⃣ Policy Library: `/policies` -Extract product fields using NVIDIA Nemotron VLM (no image generation). +Manage the persistent PDF policy library used during analysis. + +Policy documents are handled as a persistent single-user RAG library: +- uploaded PDFs are parsed and normalized into structured policy summaries +- normalized policy records are embedded and stored in Milvus +- `/vlm/analyze` automatically performs semantic retrieval against the loaded policy library +- the compliance classifier receives the analyzed product plus the retrieved policy records + +### GET `/policies` + +Returns metadata for the currently loaded policy library. + +### Response Schema + +```json +{ + "documents": [ + { + "document_hash": "string", + "filename": "string", + "file_size": 12345, + "chunk_count": 10, + "created_at": 1735689600, + "updated_at": 1735689600 + } + ] +} +``` + +`chunk_count` is the number of indexed policy records generated from the normalized PDF, not the raw page count. 
+ +### POST `/policies` + +**Content-Type**: `multipart/form-data` + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `files` | file[] | Yes | One or more PDF files to add to the persistent policy library | +| `locale` | string | No | Locale used when normalizing newly uploaded policies (default: `en-US`) | + +### POST Example + +```bash +curl -X POST \ + -F "locale=en-US" \ + -F "files=@policy-a.pdf;type=application/pdf" \ + -F "files=@policy-b.pdf;type=application/pdf" \ + http://localhost:8000/policies +``` + +### POST Response Schema + +```json +{ + "documents": [ + { + "document_hash": "string", + "filename": "string", + "file_size": 12345, + "chunk_count": 10, + "created_at": 1735689600, + "updated_at": 1735689600 + } + ], + "results": [ + { + "document_hash": "string", + "filename": "string", + "chunk_count": 10, + "already_loaded": false, + "processed": true + } + ] +} +``` + +Notes: +- repeated uploads of the same PDF are deduplicated by content hash +- `already_loaded=true` means the document was already present in the library +- `processed=true` means the upload was newly parsed, normalized, embedded, and indexed + +### DELETE `/policies` + +Clears the persistent policy library, including stored PDF artifacts and vector embeddings. + +```bash +curl -X DELETE http://localhost:8000/policies +``` + +### DELETE Response + +```json +{ + "status": "ok" +} +``` + +--- + +## 2️⃣ Fast VLM Analysis: `/vlm/analyze` + +Extract product fields using NVIDIA Nemotron VLM and, when policies are loaded, run policy retrieval plus compliance classification. **Endpoint**: `POST /vlm/analyze` **Content-Type**: `multipart/form-data` @@ -64,6 +165,10 @@ Extract product fields using NVIDIA Nemotron VLM (no image generation). | `product_data` | JSON string | No | Existing product data to augment | | `brand_instructions` | string | No | Custom brand voice, tone, style, and taxonomy guidelines | +When one or more policy PDFs have been loaded through `/policies`, this endpoint also: +- retrieves semantically relevant normalized policy records from Milvus using the VLM title/description/categories/tags/colors +- runs a compliance classifier against the analyzed product and the retrieved policy records + ### Product Data Schema (Optional) ```json @@ -85,10 +190,28 @@ Extract product fields using NVIDIA Nemotron VLM (no image generation). "categories": ["string"], "tags": ["string"], "colors": ["string"], - "locale": "string" + "locale": "string", + "policy_decision": { + "status": "pass | fail", + "label": "string", + "summary": "string", + "matched_policies": [ + { + "document_name": "string", + "policy_title": "string", + "rule_title": "string", + "reason": "string", + "evidence": ["string"] + } + ], + "warnings": ["string"], + "evidence_note": "string" + } } ``` +`policy_decision` is included only when the policy library contains at least one loaded document. + ### Usage Examples #### Image Only (Generation Mode) @@ -136,13 +259,21 @@ curl -X POST \ "categories": ["accessories"], "tags": ["black leather", "gold accents", "evening bag", "rectangular shape"], "colors": ["black", "gold"], - "locale": "en-US" + "locale": "en-US", + "policy_decision": { + "status": "pass", + "label": "Policy Check Passed", + "summary": "No loaded policy appears applicable to this product.", + "matched_policies": [], + "warnings": [], + "evidence_note": "Policy retrieval did not return any candidate matches for this product." 
+  }
 }
 ```
 
 ---
 
-## 2️⃣ Image Generation: `/generate/variation`
+## 3️⃣ Image Generation: `/generate/variation`
 
 Generate culturally-appropriate product variations using FLUX models based on VLM analysis results.
@@ -203,7 +334,7 @@ curl -X POST \
 
 ---
 
-## 3️⃣ 3D Asset Generation: `/generate/3d`
+## 4️⃣ 3D Asset Generation: `/generate/3d`
 
 Generate interactive 3D GLB models from 2D product images using Microsoft's TRELLIS model.
diff --git a/docs/DOCKER.md b/docs/DOCKER.md
index 63a04c5..56470dd 100644
--- a/docs/DOCKER.md
+++ b/docs/DOCKER.md
@@ -12,6 +12,8 @@ The application consists of the following services:
 - **LLM NIM** (Port 8002): Large Language Model for text generation
 - **Flux NIM** (Port 8003): Image generation model for product variations
 - **Trellis NIM** (Port 8004): 3D asset generation model
+- **Embeddings NIM** (Port 8005): Embeddings for policy compliance
+- **Milvus Stack** (Ports 19530, 9091, 9001): Persistent vector search for loaded policy PDFs
 
 ## Prerequisites
@@ -43,12 +45,19 @@
 mkdir -p "$LOCAL_NIM_CACHE"
 chmod a+w "$LOCAL_NIM_CACHE"
 ```
 
+### 3. Create Shared Docker Network
+
+```bash
+docker network create catalog-network || true
+```
+
 ## Running the Application
 
 ### Start All Services
 
 ```bash
 docker-compose up -d
+docker compose -f docker-compose.rag.yml up -d
 ```
 
 ### Start Specific Services
@@ -62,6 +71,9 @@ docker-compose up -d vlm-nim
 
 # Start all NIM models
 docker-compose up -d vlm-nim llm-nim flux-nim trellis-nim
+
+# Start the persistent policy RAG stack
+docker compose -f docker-compose.rag.yml up -d
 ```
 
 ### View Logs
@@ -73,6 +85,7 @@
 # Specific service
 docker-compose logs -f backend
 docker-compose logs -f frontend
+docker compose -f docker-compose.rag.yml logs -f milvus-standalone
 ```
 
 ### Stop Services
@@ -83,6 +96,7 @@
 
 # Stop and remove volumes
 docker-compose down -v
+docker compose -f docker-compose.rag.yml down -v
 ```
 
 ## Building Images
@@ -113,6 +127,9 @@ Once all services are running:
 - **Frontend UI**: http://localhost:3000
 - **Backend API**: http://localhost:8000
 - **Health Check**: http://localhost:8000/health
+- **Milvus gRPC**: localhost:19530
+- **Milvus health**: localhost:9091
+- **MinIO Console**: http://localhost:9001
 
 ## GPU Configuration
@@ -155,6 +172,7 @@
 docker-compose ps
 ```
 
 ```bash
 docker-compose logs backend
 docker-compose logs vlm-nim
+docker compose -f docker-compose.rag.yml logs milvus-standalone
 ```
 
 ### Restart a Service
@@ -177,6 +195,7 @@ docker-compose up -d
 
 ```bash
 docker-compose down --rmi all
+docker compose -f docker-compose.rag.yml down -v
 ```
 
 ### Clean Up Cache
@@ -184,4 +203,3 @@ docker-compose down --rmi all
 
 ```bash
 rm -rf ~/.cache/nim/*
 ```
-
diff --git a/pyproject.toml b/pyproject.toml
index 0dd2742..1f4cdcd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,9 @@ dependencies = [
     "requests==2.32.5",
     "httpx==0.28.1",
     "pillow==12.1.1",
+    "pymilvus==2.5.10",
     "pyyaml==6.0.3",
+    "pypdf==6.1.1",
 ]
 
 [tool.hatch.build.targets.wheel]
@@ -34,5 +36,3 @@ dev-dependencies = [
     "pytest-asyncio==0.25.2",
     "httpx==0.28.1",
 ]
-
-
diff --git a/shared/config/config.yaml b/shared/config/config.yaml
index 9b07f86..d6c977b 100644
--- a/shared/config/config.yaml
+++ b/shared/config/config.yaml
@@ -1,5 +1,6 @@
 # Configuration file for catalog-enrichment endpoints
 
+
 vlm:
   url: "http://nim-vlm:8000/v1"
   model: "nvidia/nemotron-nano-12b-v2-vl"
@@ -13,3 +14,19 @@ flux:
 
 trellis:
   url: "http://nim-trellis:8000/v1/infer"
+
+embeddings:
+  url: "http://embedqa:8000/v1"
+  model: "nvidia/nv-embedqa-e5-v5"
+
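+# Milvus vector store used for policy PDF retrieval. The host and port below
+# can be overridden at runtime via the MILVUS_HOST and MILVUS_PORT environment variables.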
+milvus: + host: "milvus-standalone" + port: 19530 + collection: "policy_chunks" + alias: "policy_library" + +policy_library: + storage_dir: "data/policies" + db_path: "data/policies/library.db" + top_k: 8 + min_relevance_score: 0.3 diff --git a/src/backend/config.py b/src/backend/config.py index f15772c..8fada26 100644 --- a/src/backend/config.py +++ b/src/backend/config.py @@ -14,6 +14,7 @@ # limitations under the License. import yaml +import os from pathlib import Path from typing import Dict, Any, Optional import logging @@ -49,6 +50,9 @@ def _get_section_config(self, section: str, required_fields: list) -> Dict[str, raise ValueError(f"{section.upper()} {field} not configured") result[field] = value return result + + def _get_optional_section_config(self, section: str) -> Dict[str, Any]: + return self._config_data.get(section, {}) or {} def get_vlm_config(self) -> Dict[str, str]: return self._get_section_config('vlm', ['url', 'model']) @@ -62,6 +66,33 @@ def get_flux_config(self) -> Dict[str, str]: def get_trellis_config(self) -> Dict[str, str]: return self._get_section_config('trellis', ['url']) + def get_embeddings_config(self) -> Dict[str, str]: + config = self._get_optional_section_config('embeddings') + return { + "url": os.getenv("NVIDIA_API_BASE_URL") or config.get("url") or "https://integrate.api.nvidia.com/v1", + "model": config.get("model") or "nvidia/nv-embedqa-e5-v5", + } + + def get_milvus_config(self) -> Dict[str, Any]: + config = self._get_optional_section_config('milvus') + return { + "host": os.getenv("MILVUS_HOST") or config.get("host") or "localhost", + "port": os.getenv("MILVUS_PORT") or str(config.get("port") or "19530"), + "collection": os.getenv("MILVUS_COLLECTION") or config.get("collection") or "policy_chunks", + "alias": config.get("alias") or "policy_library", + } + + def get_policy_library_config(self) -> Dict[str, Any]: + config = self._get_optional_section_config('policy_library') + return { + "storage_dir": os.getenv("POLICY_LIBRARY_STORAGE_DIR") or config.get("storage_dir") or "data/policies", + "db_path": os.getenv("POLICY_LIBRARY_DB_PATH") or config.get("db_path") or "data/policies/library.db", + "top_k": int(os.getenv("POLICY_LIBRARY_TOP_K") or config.get("top_k") or 8), + "min_relevance_score": float( + os.getenv("POLICY_LIBRARY_MIN_RELEVANCE_SCORE") or config.get("min_relevance_score") or 0.3 + ), + } + _config_instance: Optional[Config] = None @@ -70,4 +101,4 @@ def get_config() -> Config: global _config_instance if _config_instance is None: _config_instance = Config() - return _config_instance \ No newline at end of file + return _config_instance diff --git a/src/backend/main.py b/src/backend/main.py index 52a31d3..05be8c1 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -26,7 +26,9 @@ import httpx from openai import APIConnectionError -from backend.vlm import run_vlm_analysis +from backend.policy import evaluate_policy_compliance +from backend.policy_library import PolicyLibrary +from backend.vlm import extract_vlm_observation, build_enriched_vlm_result from backend.image import generate_image_variation from backend.trellis import generate_3d_asset from backend.config import get_config @@ -35,6 +37,7 @@ logger = logging.getLogger("catalog_enrichment.api") VALID_LOCALES = {"en-US", "en-GB", "en-AU", "en-CA", "es-ES", "es-MX", "es-AR", "es-CO", "fr-FR", "fr-CA"} +policy_library = PolicyLibrary() @asynccontextmanager @@ -42,6 +45,7 @@ async def lifespan(app: FastAPI): if not logging.getLogger().handlers: 
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s - %(message)s") logging.getLogger("httpx").setLevel(logging.WARNING) + policy_library.initialize() logger.info("App startup complete") yield @@ -172,7 +176,66 @@ async def vlm_analyze( image_bytes, content_type = validation_result logger.info(f"Running VLM analysis: locale={locale} mode={'augmentation' if product_json else 'generation'}") - result = run_vlm_analysis(image_bytes, content_type, locale, product_json, brand_instructions) + vlm_observation = await asyncio.to_thread(extract_vlm_observation, image_bytes, content_type) + + enrichment_task = asyncio.to_thread( + build_enriched_vlm_result, + vlm_observation, + locale, + product_json, + brand_instructions, + ) + retrieval_task = asyncio.to_thread( + policy_library.retrieve_context, + { + "title": vlm_observation.get("title", ""), + "description": vlm_observation.get("description", ""), + "categories": vlm_observation.get("categories", []), + "tags": vlm_observation.get("tags", []), + "colors": vlm_observation.get("colors", []), + }, + ) + result, policy_contexts = await asyncio.gather(enrichment_task, retrieval_task) + if policy_contexts: + logger.info("Policy retrieval returned %d candidate policy record(s); running compliance evaluation.", len(policy_contexts)) + product_snapshot = { + "locale": locale, + "title": vlm_observation.get("title", ""), + "description": vlm_observation.get("description", ""), + "categories": vlm_observation.get("categories", []), + "tags": vlm_observation.get("tags", []), + "colors": vlm_observation.get("colors", []), + "generated_catalog_fields": { + "title": result.get("title", ""), + "description": result.get("description", ""), + "categories": result.get("categories", []), + "tags": result.get("tags", []), + "colors": result.get("colors", []), + }, + "product_data": product_json or {}, + } + result["policy_decision"] = await asyncio.to_thread( + evaluate_policy_compliance, + product_snapshot, + policy_contexts, + locale, + ) + logger.info( + "Policy evaluation complete: status=%s matches=%d warnings=%d", + result["policy_decision"].get("status"), + len(result["policy_decision"].get("matched_policies", [])), + len(result["policy_decision"].get("warnings", [])), + ) + elif policy_library.list_documents(): + logger.info("Policy retrieval returned no candidates; treating loaded policies as not relevant to this product.") + result["policy_decision"] = { + "status": "pass", + "label": "Policy Check Passed", + "summary": "No loaded policy appears applicable to this product.", + "matched_policies": [], + "warnings": [], + "evidence_note": "Policy retrieval did not return any candidate matches for this product.", + } payload = { "title": result.get("title", ""), @@ -185,6 +248,8 @@ async def vlm_analyze( if result.get("enhanced_product"): payload["enhanced_product"] = result["enhanced_product"] + if result.get("policy_decision"): + payload["policy_decision"] = result["policy_decision"] logger.info(f"/vlm/analyze success: title_len={len(payload['title'])} desc_len={len(payload['description'])} locale={locale}") return JSONResponse(payload) @@ -199,6 +264,45 @@ async def vlm_analyze( return JSONResponse({"detail": str(exc)}, status_code=500) +@app.get("/policies") +async def list_policies() -> JSONResponse: + try: + return JSONResponse({"documents": policy_library.list_documents()}) + except Exception as exc: + logger.exception("/policies list exception: %s", exc) + return JSONResponse({"detail": str(exc)}, status_code=500) + + 
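+# Note: PolicyLibrary.ingest_documents deduplicates uploads by SHA-256 content hash,
+# so re-posting an identical PDF only refreshes its updated_at timestamp.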
+@app.post("/policies") +async def upload_policies( + files: list[UploadFile] = File(...), + locale: str = Form("en-US"), +) -> JSONResponse: + try: + if locale not in VALID_LOCALES: + return JSONResponse({"detail": f"Invalid locale. Supported locales: {sorted(VALID_LOCALES)}"}, status_code=400) + + uploads, error_response = await _validate_policy_uploads(files, "/policies") + if error_response: + return error_response + + results = policy_library.ingest_documents(uploads, locale=locale) + return JSONResponse({"documents": policy_library.list_documents(), "results": results}) + except Exception as exc: + logger.exception("/policies upload exception: %s", exc) + return JSONResponse({"detail": str(exc)}, status_code=500) + + +@app.delete("/policies") +async def clear_policies() -> JSONResponse: + try: + policy_library.clear() + return JSONResponse({"status": "ok"}) + except Exception as exc: + logger.exception("/policies clear exception: %s", exc) + return JSONResponse({"detail": str(exc)}, status_code=500) + + @app.post("/generate/variation") async def generate_variation( image: UploadFile = File(...), @@ -288,6 +392,42 @@ async def _validate_image(image: UploadFile, endpoint: str): return (image_bytes, content_type), None +async def _validate_policy_uploads(policy_files: list[UploadFile], endpoint: str): + if not policy_files: + return None, JSONResponse({"detail": "At least one PDF file is required"}, status_code=400) + + uploads = [] + invalid_files = [] + + for policy_file in policy_files: + logger.info( + "POST %s policy filename=%s content_type=%s", + endpoint, + getattr(policy_file, "filename", None), + getattr(policy_file, "content_type", None), + ) + + filename = getattr(policy_file, "filename", None) or "policy.pdf" + content_type = getattr(policy_file, "content_type", None) or "application/pdf" + if content_type != "application/pdf" and not filename.lower().endswith(".pdf"): + invalid_files.append(filename) + continue + + pdf_bytes = await policy_file.read() + if not pdf_bytes: + invalid_files.append(filename) + continue + uploads.append({"filename": filename, "bytes": pdf_bytes}) + + if invalid_files: + return None, JSONResponse( + {"detail": f"Policy files must be non-empty PDFs. Invalid files: {', '.join(sorted(invalid_files))}"}, + status_code=400, + ) + + return uploads, None + + @app.post("/generate/3d") async def generate_3d( image: UploadFile = File(...), @@ -397,4 +537,4 @@ async def generate_3d( }, status_code=exc.response.status_code) except Exception as exc: logger.exception(f"/generate/3d exception: {exc}") - return JSONResponse({"detail": str(exc)}, status_code=500) \ No newline at end of file + return JSONResponse({"detail": str(exc)}, status_code=500) diff --git a/src/backend/policy.py b/src/backend/policy.py new file mode 100644 index 0000000..9a16b11 --- /dev/null +++ b/src/backend/policy.py @@ -0,0 +1,501 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+from io import BytesIO
+from typing import Any, Dict, List
+
+from openai import OpenAI
+from pypdf import PdfReader
+
+from backend.config import get_config
+from backend.utils import parse_llm_json
+
+logger = logging.getLogger("catalog_enrichment.policy")
+
+MAX_POLICY_TEXT_CHARS = 12000
+MAX_POLICY_SUMMARY_CHARS = 6000
+NGC_API_KEY_NOT_SET_ERROR = "NGC_API_KEY is not set"
+LOCALE_CONFIG = {
+    "en-US": {"language": "English", "region": "United States", "country": "United States", "context": "American English with US terminology"},
+    "en-GB": {"language": "English", "region": "United Kingdom", "country": "United Kingdom", "context": "British English with UK terminology"},
+    "en-AU": {"language": "English", "region": "Australia", "country": "Australia", "context": "Australian English"},
+    "en-CA": {"language": "English", "region": "Canada", "country": "Canada", "context": "Canadian English"},
+    "es-ES": {"language": "Spanish", "region": "Spain", "country": "Spain", "context": "Peninsular Spanish"},
+    "es-MX": {"language": "Spanish", "region": "Mexico", "country": "Mexico", "context": "Mexican Spanish"},
+    "es-AR": {"language": "Spanish", "region": "Argentina", "country": "Argentina", "context": "Argentinian Spanish"},
+    "es-CO": {"language": "Spanish", "region": "Colombia", "country": "Colombia", "context": "Colombian Spanish"},
+    "fr-FR": {"language": "French", "region": "France", "country": "France", "context": "Metropolitan French"},
+    "fr-CA": {"language": "French", "region": "Canada", "country": "Canada", "context": "Quebec French"},
+}
+
+def extract_text_from_pdf_bytes(pdf_bytes: bytes) -> str:
+    """Extract text from a PDF byte stream."""
+    reader = PdfReader(BytesIO(pdf_bytes))
+    parts: List[str] = []
+
+    for page in reader.pages:
+        page_text = page.extract_text() or ""
+        page_text = page_text.strip()
+        if page_text:
+            parts.append(page_text)
+
+    return "\n\n".join(parts).strip()
+
+
+def summarize_policy_document(document_name: str, document_text: str, locale: str = "en-US") -> Dict[str, Any]:
+    """Convert a policy PDF into compact structured rules for indexing and retrieval."""
+    if not (api_key := os.getenv("NGC_API_KEY")):
+        raise RuntimeError(NGC_API_KEY_NOT_SET_ERROR)
+
+    llm_config = get_config().get_llm_config()
+    client = OpenAI(base_url=llm_config["url"], api_key=api_key)
+    info = LOCALE_CONFIG.get(locale, LOCALE_CONFIG["en-US"])
+    truncated_text = document_text[:MAX_POLICY_TEXT_CHARS]
+
+    prompt = f"""/no_think You are a policy normalization assistant for an e-commerce catalog team.
+
+Convert the policy document below into concise structured JSON for downstream compliance checks.
+
+DOCUMENT NAME:
+{document_name}
+
+TARGET MARKET CONTEXT:
+{info["region"]} ({info["context"]})
+
+POLICY DOCUMENT TEXT:
+{truncated_text}
+
+Return ONLY valid JSON with this schema:
+{{
+  "document_name": "{document_name}",
+  "policy_title": "<short policy title>",
+  "summary": "<2-3 sentence summary>",
+  "blocking_rules": [
+    {{
+      "title": "<rule title>",
+      "conditions": ["<condition>", "<condition>"],
+      "signals": ["<observable signal>", "<observable signal>"]
+    }}
+  ],
+  "permitted_rules": [
+    {{
+      "title": "<rule title>",
+      "conditions": ["<condition>", "<condition>"]
+    }}
+  ],
+  "required_evidence": ["<evidence item>", "<...>"],
+  "notes": ["<note>", "<...>"]
+}}
+
+Rules:
+- Keep the output compact and focused on classifying products against pass/fail policy checks.
+- Prefer observable signals, packaging text, listing text, and ingredient/regulatory markers. 
+- If the document contains examples, convert them into explicit rules/signals. +- Do not quote long passages verbatim. +""" + + completion = client.chat.completions.create( + model=llm_config["model"], + messages=[{"role": "system", "content": "/no_think"}, {"role": "user", "content": prompt}], + temperature=0.1, + top_p=0.9, + max_tokens=1600, + stream=True, + extra_body={"reasoning_budget": 8192, "chat_template_kwargs": {"enable_thinking": False}}, + ) + + text = "".join( + chunk.choices[0].delta.content + for chunk in completion + if chunk.choices[0].delta and chunk.choices[0].delta.content + ) + + parsed = parse_llm_json(text, extract_braces=True, strip_comments=True) + if parsed is not None: + parsed.setdefault("document_name", document_name) + parsed.setdefault("policy_title", document_name) + parsed.setdefault("summary", "") + parsed.setdefault("blocking_rules", []) + parsed.setdefault("permitted_rules", []) + parsed.setdefault("required_evidence", []) + parsed.setdefault("notes", []) + return parsed + + logger.warning("Policy summary parse failed for %s; falling back to minimal summary", document_name) + return { + "document_name": document_name, + "policy_title": document_name, + "summary": truncated_text[:400], + "blocking_rules": [], + "permitted_rules": [], + "required_evidence": [], + "notes": ["Automatic policy summary fallback was used for this document."], + } + + +def _prepare_policy_context(policy_context: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Reduce duplicate document-level context while preserving retrieved policy records.""" + prepared: List[Dict[str, Any]] = [] + document_hashes_with_summary: set[str] = set() + + for item in policy_context: + document_hash = str(item.get("document_hash", "")) + prepared_item = { + "document_hash": document_hash, + "document_name": item.get("document_name"), + "policy_title": item.get("policy_title"), + "chunk_index": item.get("chunk_index"), + "score": item.get("score"), + "chunk_text": item.get("chunk_text"), + } + if document_hash and document_hash not in document_hashes_with_summary and item.get("document_summary"): + prepared_item["document_summary"] = item.get("document_summary") + document_hashes_with_summary.add(document_hash) + elif item.get("document_summary"): + prepared_item["document_summary"] = item.get("document_summary") + elif any( + key in item + for key in ("summary", "blocking_rules", "permitted_rules", "required_evidence", "notes") + ): + prepared_item["document_summary"] = { + key: item.get(key) + for key in ("document_name", "policy_title", "summary", "blocking_rules", "permitted_rules", "required_evidence", "notes") + if key in item + } + prepared.append(prepared_item) + + return prepared + + +def _format_product_snapshot_for_policy(product_snapshot: Dict[str, Any]) -> str: + primary_lines = [ + f"Observed title: {product_snapshot.get('title', '')}", + f"Observed description: {product_snapshot.get('description', '')}", + f"Observed categories: {', '.join(product_snapshot.get('categories', []))}", + f"Observed tags: {', '.join(product_snapshot.get('tags', []))}", + f"Observed colors: {', '.join(product_snapshot.get('colors', []))}", + ] + + generated = product_snapshot.get("generated_catalog_fields") or {} + secondary_lines = [] + if generated: + secondary_lines = [ + f"Generated title: {generated.get('title', '')}", + f"Generated description: {generated.get('description', '')}", + f"Generated categories: {', '.join(generated.get('categories', []))}", + f"Generated tags: {', 
'.join(generated.get('tags', []))}",
+        ]
+
+    sections = [
+        "PRIMARY PRODUCT EVIDENCE:",
+        "\n".join(line for line in primary_lines if line.strip()),
+    ]
+    if secondary_lines:
+        sections.extend(
+            [
+                "SECONDARY GENERATED CATALOG CONTEXT:",
+                "\n".join(line for line in secondary_lines if line.strip()),
+            ]
+        )
+    return "\n\n".join(section for section in sections if section.strip())
+
+
+def _format_policy_context_for_policy(prepared_policy_context: List[Dict[str, Any]]) -> str:
+    sections: List[str] = []
+    for item in prepared_policy_context:
+        document_summary = item.get("document_summary") or {}
+        blocking_rules = document_summary.get("blocking_rules") or []
+        permitted_rules = document_summary.get("permitted_rules") or []
+        required_evidence = document_summary.get("required_evidence") or []
+        blocking_titles = ", ".join(
+            str(rule.get("title", "")).strip()
+            for rule in blocking_rules
+            if str(rule.get("title", "")).strip()
+        )
+        permitted_titles = ", ".join(
+            str(rule.get("title", "")).strip()
+            for rule in permitted_rules
+            if str(rule.get("title", "")).strip()
+        )
+        section_lines = [
+            f"Document: {item.get('document_name', '')}",
+            f"Policy title: {item.get('policy_title', '')}",
+            f"Chunk index: {item.get('chunk_index', '')}",
+            f"Similarity score: {item.get('score', '')}",
+            f"Policy summary: {document_summary.get('summary') or item.get('summary', '')}",
+            f"Blocking rules: {blocking_titles}",
+            f"Permitted rules: {permitted_titles}",
+            f"Required evidence: {', '.join(str(entry) for entry in required_evidence if str(entry).strip())}",
+            f"Retrieved chunk: {item.get('chunk_text', '')}",
+        ]
+        sections.append("\n".join(line for line in section_lines if line.strip()))
+    return "\n\n---\n\n".join(sections)
+
+
+def _is_policy_decision_consistent(decision: Dict[str, Any]) -> bool:
+    status = str(decision.get("status", "pass"))
+    matched_policies = decision.get("matched_policies")
+    if not isinstance(matched_policies, list):
+        return False
+    if status == "pass" and matched_policies:
+        return False
+    if status == "fail" and not matched_policies:
+        return False
+    return True
+
+
+def _repair_policy_decision(
+    client: OpenAI,
+    model: str,
+    locale_info: Dict[str, str],
+    product_json: str,
+    policy_json: str,
+    product_evidence_text: str,
+    policy_evidence_text: str,
+    candidate_decision: Dict[str, Any],
+) -> Dict[str, Any] | None:
+    candidate_json = json.dumps(candidate_decision, ensure_ascii=False)
+    prompt = f"""/no_think You are repairing a malformed catalog compliance decision.
+
+The candidate JSON below is internally inconsistent. Rewrite it so the final JSON is both accurate and structurally valid.
+
+TARGET MARKET CONTEXT:
+{locale_info["region"]} ({locale_info["context"]})
+
+PRODUCT SNAPSHOT:
+{product_json}
+
+RETRIEVED POLICY CONTEXT:
+{policy_json}
+
+FOCUSED PRODUCT EVIDENCE:
+{product_evidence_text}
+
+FOCUSED POLICY EVIDENCE:
+{policy_evidence_text}
+
+INCONSISTENT CANDIDATE DECISION:
+{candidate_json}
+
+Return ONLY valid JSON with this schema:
+{{
+  "status": "pass" | "fail",
+  "label": "<short status label>",
+  "summary": "<short explanation>",
+  "matched_policies": [
+    {{
+      "document_name": "<document name>",
+      "policy_title": "<policy title>",
+      "rule_title": "<rule title>",
+      "reason": "<why the rule applies>",
+      "evidence": ["<evidence>", "<evidence>"]
+    }}
+  ],
+  "warnings": ["<warning>", "<...>"],
+  "evidence_note": "<short note on the evidence used>"
+}}
+
+Rules:
+- Keep the decision faithful to the supplied product and policy context.
+- If status is "pass", matched_policies must be empty.
+- If status is "fail", matched_policies must contain at least one supporting rule match. 
- Keep the response concise and internally consistent.
+"""
+
+    completion = client.chat.completions.create(
+        model=model,
+        messages=[{"role": "system", "content": "/no_think"}, {"role": "user", "content": prompt}],
+        temperature=0.1,
+        top_p=0.9,
+        max_tokens=900,
+        stream=True,
+        extra_body={"reasoning_budget": 4096, "chat_template_kwargs": {"enable_thinking": False}},
+    )
+
+    text = "".join(
+        chunk.choices[0].delta.content
+        for chunk in completion
+        if chunk.choices[0].delta and chunk.choices[0].delta.content
+    )
+    return parse_llm_json(text, extract_braces=True, strip_comments=True)
+
+
+def evaluate_policy_compliance(
+    product_snapshot: Dict[str, Any],
+    policy_context: List[Dict[str, Any]],
+    locale: str = "en-US",
+) -> Dict[str, Any]:
+    """Classify the analyzed product against retrieved policy context."""
+    if not (api_key := os.getenv("NGC_API_KEY")):
+        raise RuntimeError(NGC_API_KEY_NOT_SET_ERROR)
+
+    llm_config = get_config().get_llm_config()
+    client = OpenAI(base_url=llm_config["url"], api_key=api_key)
+    info = LOCALE_CONFIG.get(locale, LOCALE_CONFIG["en-US"])
+
+    prepared_policy_context = _prepare_policy_context(policy_context)
+    policy_json = json.dumps(prepared_policy_context, ensure_ascii=False)[:MAX_POLICY_SUMMARY_CHARS * max(len(prepared_policy_context), 1)]
+    product_json = json.dumps(product_snapshot, ensure_ascii=False)
+    product_evidence_text = _format_product_snapshot_for_policy(product_snapshot)
+    policy_evidence_text = _format_policy_context_for_policy(prepared_policy_context)
+
+    prompt = f"""/no_think You are a catalog compliance reviewer.
+
+Review the product below against the uploaded policy summaries. The UI supports two statuses:
+- pass
+- fail
+
+Choose the best-fit classification based on the observed product title, description, and retrieved policy records.
+
+TARGET MARKET CONTEXT:
+{info["region"]} ({info["context"]})
+
+PRODUCT SNAPSHOT:
+{product_json}
+
+RETRIEVED POLICY CONTEXT:
+{policy_json}
+
+FOCUSED PRODUCT EVIDENCE:
+{product_evidence_text}
+
+FOCUSED POLICY EVIDENCE:
+{policy_evidence_text}
+
+Return ONLY valid JSON with this schema:
+{{
+  "status": "pass" | "fail",
+  "label": "<short status label>",
+  "summary": "<short explanation>",
+  "matched_policies": [
+    {{
+      "document_name": "<document name>",
+      "policy_title": "<policy title>",
+      "rule_title": "<rule title>",
+      "reason": "<why the rule applies>",
+      "evidence": ["<evidence>", "<evidence>"]
+    }}
+  ],
+  "warnings": ["<warning>", "<...>"],
+  "evidence_note": "<short note on the evidence used>"
+}}
+
+Rules:
+- Use "fail" if any policy clearly disallows the product.
+- matched_policies must be empty when status is "pass".
+- Be specific and short.
+- Base the decision only on the supplied product snapshot and policies.
+- Treat the top-level product fields as the primary evidence source. Those fields represent the raw product observation.
+- Treat generated_catalog_fields as secondary context only.
+- Prefer direct product evidence from the title, visible text, form, components, and retrieved policy records over polished marketing language.
+- Do not require exact literal keyword equality when close lexical variants, inflections, or obvious wording variants point to the same product type and the product's form or function also aligns with the policy.
+- Prefer "fail" when the product's observed title, visible text, or described function clearly names or strongly implies a blocked product family in the policy and there is no stronger allowed-category match.
+- Treat blocking-rule conditions, listed keywords, and listed signals as alternative supporting indicators unless the policy explicitly says all of them are required together. 
+- Do not require every example component or every listed signal to be present when the product already strongly matches a blocked product family through title, visible text, or described purpose. +- Do not assume a product passes just because the listing does not explicitly state an end use if the retrieved policies define blocking by function, form, components, or keywords. +- Use the retrieved policy records as the policy source of truth. +- Before returning JSON, verify that status, summary, matched_policies, warnings, and evidence_note are internally consistent. +- If status is "pass", summary must clearly say that no retrieved policy blocks the product. +- If status is "fail", summary must clearly say that the product does not comply and matched_policies must contain the supporting rule matches. +""" + + completion = client.chat.completions.create( + model=llm_config["model"], + messages=[{"role": "system", "content": "/no_think"}, {"role": "user", "content": prompt}], + temperature=0.1, + top_p=0.9, + max_tokens=1200, + stream=True, + extra_body={"reasoning_budget": 8192, "chat_template_kwargs": {"enable_thinking": False}}, + ) + + text = "".join( + chunk.choices[0].delta.content + for chunk in completion + if chunk.choices[0].delta and chunk.choices[0].delta.content + ) + + parsed = parse_llm_json(text, extract_braces=True, strip_comments=True) + if parsed is not None: + parsed_status = str(parsed.get("status", "pass")) + if parsed_status not in {"pass", "fail"}: + parsed_status = "pass" + parsed["status"] = parsed_status + parsed.setdefault( + "label", + "Policy Check Failed" if parsed["status"] == "fail" else "Policy Check Passed", + ) + parsed.setdefault("summary", "") + parsed.setdefault("matched_policies", []) + parsed.setdefault("warnings", []) + parsed.setdefault("evidence_note", "") + if parsed["status"] == "pass": + parsed["matched_policies"] = [] + if not _is_policy_decision_consistent(parsed): + logger.warning( + "Policy decision was internally inconsistent; attempting repair. 
status=%s matched=%d", + parsed.get("status"), + len(parsed.get("matched_policies", [])) if isinstance(parsed.get("matched_policies"), list) else -1, + ) + repaired = _repair_policy_decision( + client, + llm_config["model"], + info, + product_json, + policy_json, + product_evidence_text, + policy_evidence_text, + parsed, + ) + if repaired is not None: + repaired_status = str(repaired.get("status", "pass")) + if repaired_status not in {"pass", "fail"}: + repaired_status = "pass" + repaired["status"] = repaired_status + repaired.setdefault( + "label", + "Policy Check Failed" if repaired["status"] == "fail" else "Policy Check Passed", + ) + repaired.setdefault("summary", "") + repaired.setdefault("matched_policies", []) + repaired.setdefault("warnings", []) + repaired.setdefault("evidence_note", "") + if repaired["status"] == "pass": + repaired["matched_policies"] = [] + if _is_policy_decision_consistent(repaired): + return repaired + logger.warning("Policy decision repair failed; using fallback pass result") + return { + "status": "pass", + "label": "Policy Check Passed", + "summary": "No retrieved policy blocks this product.", + "matched_policies": [], + "warnings": ["Policy evaluation used a fallback pass result because the model response was internally inconsistent."], + "evidence_note": "Fallback decision based on inconsistent model output.", + } + return parsed + + logger.warning("Policy compliance parse failed; falling back to pass result") + return { + "status": "pass", + "label": "Policy Check Passed", + "summary": "No retrieved policy blocks this product.", + "matched_policies": [], + "warnings": ["Policy evaluation used a fallback pass result because the model response was malformed."], + "evidence_note": "Fallback decision based on parser failure.", + } diff --git a/src/backend/policy_library.py b/src/backend/policy_library.py new file mode 100644 index 0000000..7f2a395 --- /dev/null +++ b/src/backend/policy_library.py @@ -0,0 +1,450 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import hashlib +import json +import logging +import os +import shutil +import sqlite3 +import time +from pathlib import Path +from typing import Any, Dict, List, Sequence + +from openai import OpenAI +from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility + +from backend.config import get_config +from backend.policy import extract_text_from_pdf_bytes, summarize_policy_document + +logger = logging.getLogger("catalog_enrichment.policy_library") + +EMBEDDING_API_KEY_ERROR = "NVIDIA_API_KEY or NGC_API_KEY is not set" +MAX_QUERY_WORDS = 160 +MAX_EMBED_TOTAL_WORDS = 190 + + +def _limit_words(text: str, max_words: int) -> str: + words = text.split() + if len(words) <= max_words: + return text.strip() + return " ".join(words[:max_words]).strip() + + +def build_policy_query(product_snapshot: Dict[str, Any]) -> str: + """Build a compact retrieval query from analyzed product evidence.""" + parts = [ + f"Title: {product_snapshot.get('title', '')}", + f"Description: {product_snapshot.get('description', '')}", + f"Categories: {', '.join(product_snapshot.get('categories', []))}", + f"Tags: {', '.join(product_snapshot.get('tags', []))}", + f"Colors: {', '.join(product_snapshot.get('colors', []))}", + ] + return _limit_words("\n".join(part for part in parts if part.strip()), MAX_QUERY_WORDS) + + +class PolicyLibrary: + """Persistent single-user policy document library backed by SQLite and Milvus.""" + + def __init__(self) -> None: + config = get_config() + self._policy_config = config.get_policy_library_config() + self._milvus_config = config.get_milvus_config() + self._embedding_config = config.get_embeddings_config() + self._storage_dir = Path(self._policy_config["storage_dir"]) + self._db_path = Path(self._policy_config["db_path"]) + self._top_k = int(self._policy_config["top_k"]) + self._min_relevance_score = float(self._policy_config["min_relevance_score"]) + self._collection_name = str(self._milvus_config["collection"]) + self._milvus_alias = str(self._milvus_config["alias"]) + self._connected = False + + def initialize(self) -> None: + self._storage_dir.mkdir(parents=True, exist_ok=True) + self._db_path.parent.mkdir(parents=True, exist_ok=True) + with self._connect_db() as conn: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS policy_documents ( + document_hash TEXT PRIMARY KEY, + filename TEXT NOT NULL, + file_size INTEGER NOT NULL, + chunk_count INTEGER NOT NULL, + summary_json TEXT NOT NULL, + text_path TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ) + """ + ) + conn.commit() + + def list_documents(self) -> List[Dict[str, Any]]: + with self._connect_db() as conn: + rows = conn.execute( + """ + SELECT document_hash, filename, file_size, chunk_count, created_at, updated_at + FROM policy_documents + ORDER BY updated_at DESC + """ + ).fetchall() + + return [ + { + "document_hash": row["document_hash"], + "filename": row["filename"], + "file_size": row["file_size"], + "chunk_count": row["chunk_count"], + "created_at": row["created_at"], + "updated_at": row["updated_at"], + } + for row in rows + ] + + def ingest_documents(self, uploads: Sequence[Dict[str, Any]], locale: str = "en-US") -> List[Dict[str, Any]]: + results = [] + for upload in uploads: + filename = upload["filename"] + pdf_bytes = upload["bytes"] + document_hash = hashlib.sha256(pdf_bytes).hexdigest() + existing = self._get_document(document_hash) + if existing is not None: + self._touch_document(document_hash) + results.append( + { + "document_hash": 
document_hash, + "filename": existing["filename"], + "chunk_count": existing["chunk_count"], + "already_loaded": True, + "processed": False, + } + ) + continue + + extracted_text = extract_text_from_pdf_bytes(pdf_bytes) + if not extracted_text: + raise ValueError(f"Unable to extract text from PDF: {filename}") + + normalized_text = extracted_text.strip() + summary = summarize_policy_document(filename, normalized_text, locale) + records = self._build_policy_entries(filename, summary) + embedding_inputs = [ + self._format_policy_entry_for_embedding(entry_text) + for entry_text in records + ] + vectors = self._embed_texts(embedding_inputs, input_type="passage") + if not vectors: + raise RuntimeError(f"No embeddings were returned for {filename}") + + self._ensure_collection(len(vectors[0])) + self._replace_document_vectors(document_hash, filename, summary, records, vectors) + self._persist_document(document_hash, filename, len(pdf_bytes), len(records), summary, normalized_text) + results.append( + { + "document_hash": document_hash, + "filename": filename, + "chunk_count": len(records), + "already_loaded": False, + "processed": True, + } + ) + + return results + + def retrieve_context(self, product_snapshot: Dict[str, Any]) -> List[Dict[str, Any]]: + if not self.list_documents(): + return [] + + if not self._collection_exists(): + return [] + + query_text = build_policy_query(product_snapshot) + if not query_text.strip(): + return [] + + query_vector = self._embed_texts([query_text], input_type="query")[0] + collection = self._get_collection(load=True) + results = collection.search( + data=[query_vector], + anns_field="embedding", + param={"metric_type": "COSINE", "params": {}}, + limit=self._top_k, + output_fields=["document_hash", "document_name", "policy_title", "summary", "chunk_text", "chunk_index"], + ) + + raw_hits = [] + document_hashes = set() + for hit in results[0]: + entity = hit.entity + document_hash = entity.get("document_hash") + if document_hash: + document_hashes.add(document_hash) + raw_hits.append((hit, entity)) + + if raw_hits: + logger.info( + "Policy retrieval candidate scores: %s", + ", ".join(f"{float(hit.score):.4f}" for hit, _ in raw_hits[: min(5, len(raw_hits))]), + ) + top_score = float(raw_hits[0][0].score) + if top_score < self._min_relevance_score: + logger.info( + "Policy retrieval skipped classification: top score %.4f below min_relevance_score %.4f", + top_score, + self._min_relevance_score, + ) + return [] + + document_summaries = self._get_document_summaries(document_hashes) + retrieved = [] + for hit, entity in raw_hits: + document_hash = entity.get("document_hash") + retrieved.append( + { + "document_hash": document_hash, + "document_name": entity.get("document_name"), + "policy_title": entity.get("policy_title"), + "summary": entity.get("summary"), + "chunk_text": entity.get("chunk_text"), + "chunk_index": entity.get("chunk_index"), + "score": float(hit.score), + "document_summary": document_summaries.get(document_hash, {}), + } + ) + return retrieved + + def clear(self) -> None: + if self._collection_exists(): + utility.drop_collection(self._collection_name, using=self._milvus_alias) + if self._db_path.exists(): + with self._connect_db() as conn: + conn.execute("DELETE FROM policy_documents") + conn.commit() + if self._storage_dir.exists(): + shutil.rmtree(self._storage_dir) + self._storage_dir.mkdir(parents=True, exist_ok=True) + self._db_path.parent.mkdir(parents=True, exist_ok=True) + self.initialize() + + def _connect_db(self) -> sqlite3.Connection: 
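+        """Open a SQLite connection with rows addressable by column name."""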
+ conn = sqlite3.connect(self._db_path) + conn.row_factory = sqlite3.Row + return conn + + def _get_document(self, document_hash: str) -> sqlite3.Row | None: + with self._connect_db() as conn: + row = conn.execute( + """ + SELECT document_hash, filename, chunk_count + FROM policy_documents + WHERE document_hash = ? + """, + (document_hash,), + ).fetchone() + return row + + def _touch_document(self, document_hash: str) -> None: + now = int(time.time()) + with self._connect_db() as conn: + conn.execute( + """ + UPDATE policy_documents + SET updated_at = ? + WHERE document_hash = ? + """, + (now, document_hash), + ) + conn.commit() + + def _persist_document( + self, + document_hash: str, + filename: str, + file_size: int, + chunk_count: int, + summary: Dict[str, Any], + extracted_text: str, + ) -> None: + now = int(time.time()) + document_dir = self._storage_dir / document_hash + document_dir.mkdir(parents=True, exist_ok=True) + text_path = document_dir / "text.txt" + summary_path = document_dir / "summary.json" + text_path.write_text(extracted_text, encoding="utf-8") + summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + + with self._connect_db() as conn: + conn.execute( + """ + INSERT OR REPLACE INTO policy_documents ( + document_hash, filename, file_size, chunk_count, summary_json, text_path, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + document_hash, + filename, + file_size, + chunk_count, + json.dumps(summary, ensure_ascii=False), + str(text_path), + now, + now, + ), + ) + conn.commit() + + def _get_document_summaries(self, document_hashes: Sequence[str]) -> Dict[str, Dict[str, Any]]: + unique_hashes = [document_hash for document_hash in dict.fromkeys(document_hashes) if document_hash] + if not unique_hashes: + return {} + + placeholders = ",".join("?" 
for _ in unique_hashes) + with self._connect_db() as conn: + rows = conn.execute( + f""" + SELECT document_hash, summary_json + FROM policy_documents + WHERE document_hash IN ({placeholders}) + """, + unique_hashes, + ).fetchall() + + summaries: Dict[str, Dict[str, Any]] = {} + for row in rows: + try: + summaries[row["document_hash"]] = json.loads(row["summary_json"]) + except json.JSONDecodeError: + logger.warning("Failed to decode stored summary_json for %s", row["document_hash"]) + summaries[row["document_hash"]] = {} + return summaries + + def _build_policy_entries(self, filename: str, summary: Dict[str, Any]) -> List[str]: + entries: List[str] = [] + + overview_lines = [ + f"Document: {filename}", + f"Policy Title: {summary.get('policy_title', filename)}", + f"Summary: {summary.get('summary', '')}", + "Rule Type: Overview", + ] + required_evidence = [str(item) for item in summary.get("required_evidence", []) if str(item).strip()] + notes = [str(item) for item in summary.get("notes", []) if str(item).strip()] + if required_evidence: + overview_lines.append(f"Required Evidence: {'; '.join(required_evidence)}") + if notes: + overview_lines.append(f"Notes: {'; '.join(notes)}") + entries.append("\n".join(overview_lines)) + + for rule_type, rules in ( + ("Blocking", summary.get("blocking_rules", [])), + ("Permitted", summary.get("permitted_rules", [])), + ): + for rule in rules: + lines = [ + f"Document: {filename}", + f"Policy Title: {summary.get('policy_title', filename)}", + f"Summary: {summary.get('summary', '')}", + f"Rule Type: {rule_type}", + f"Rule Title: {rule.get('title', '')}", + f"Conditions: {'; '.join(str(item) for item in rule.get('conditions', []) if str(item).strip())}", + ] + signals = [str(item) for item in rule.get("signals", []) if str(item).strip()] + if signals: + lines.append(f"Signals: {'; '.join(signals)}") + entries.append("\n".join(line for line in lines if line.strip())) + + return [_limit_words(entry, MAX_EMBED_TOTAL_WORDS) for entry in entries if entry.strip()] + + def _format_policy_entry_for_embedding(self, entry_text: str) -> str: + return _limit_words(entry_text, MAX_EMBED_TOTAL_WORDS) + + def _embed_texts(self, texts: Sequence[str], input_type: str) -> List[List[float]]: + if not texts: + return [] + api_key = os.getenv("NVIDIA_API_KEY") or os.getenv("NGC_API_KEY") + if not api_key: + raise RuntimeError(EMBEDDING_API_KEY_ERROR) + client = OpenAI(api_key=api_key, base_url=self._embedding_config["url"]) + response = client.embeddings.create( + input=list(texts), + model=self._embedding_config["model"], + encoding_format="float", + extra_body={"input_type": input_type, "truncate": "NONE"}, + ) + return [item.embedding for item in response.data] + + def _connect_milvus(self) -> None: + if self._connected: + return + connections.connect( + alias=self._milvus_alias, + host=self._milvus_config["host"], + port=self._milvus_config["port"], + ) + self._connected = True + + def _collection_exists(self) -> bool: + self._connect_milvus() + return utility.has_collection(self._collection_name, using=self._milvus_alias) + + def _ensure_collection(self, dimension: int) -> None: + self._connect_milvus() + if utility.has_collection(self._collection_name, using=self._milvus_alias): + return + + fields = [ + FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=128), + FieldSchema(name="document_hash", dtype=DataType.VARCHAR, max_length=64), + FieldSchema(name="document_name", dtype=DataType.VARCHAR, max_length=512), + 
FieldSchema(name="policy_title", dtype=DataType.VARCHAR, max_length=512), + FieldSchema(name="summary", dtype=DataType.VARCHAR, max_length=4096), + FieldSchema(name="chunk_text", dtype=DataType.VARCHAR, max_length=16384), + FieldSchema(name="chunk_index", dtype=DataType.INT64), + FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension), + ] + schema = CollectionSchema(fields=fields, description="Persistent normalized policy records") + collection = Collection(name=self._collection_name, schema=schema, using=self._milvus_alias) + collection.create_index("embedding", {"index_type": "AUTOINDEX", "metric_type": "COSINE", "params": {}}) + collection.load() + + def _get_collection(self, load: bool = False) -> Collection: + self._connect_milvus() + collection = Collection(name=self._collection_name, using=self._milvus_alias) + if load: + collection.load() + return collection + + def _replace_document_vectors( + self, + document_hash: str, + filename: str, + summary: Dict[str, Any], + records: Sequence[str], + vectors: Sequence[Sequence[float]], + ) -> None: + collection = self._get_collection(load=True) + collection.delete(expr=f'document_hash == "{document_hash}"') + entities = [ + [f"{document_hash}:{index}" for index in range(len(records))], + [document_hash] * len(records), + [filename] * len(records), + [str(summary.get("policy_title", filename))] * len(records), + [str(summary.get("summary", ""))] * len(records), + list(records), + list(range(len(records))), + [list(vector) for vector in vectors], + ] + collection.insert(entities) + collection.flush() diff --git a/src/backend/vlm.py b/src/backend/vlm.py index 4ce8f61..2242fc8 100644 --- a/src/backend/vlm.py +++ b/src/backend/vlm.py @@ -17,7 +17,7 @@ import json import base64 import logging -from typing import Optional, List, Dict, Any +from typing import Optional, Dict, Any from dotenv import load_dotenv from openai import OpenAI @@ -290,7 +290,7 @@ def _call_vlm(image_bytes: bytes, content_type: str) -> Dict[str, Any]: TASK: 1. Describe the product itself - its materials, design, and features -2. Include any visible brand names or product text shown on the item +2. Include any visible brand names, packaging text, ingredient text, regulatory labels, ratings, warnings, or other product text shown on the item 3. 
Write in ENGLISH - be accurate about what you see CATEGORIES - Choose ONLY from this allowed set: {categories_str} @@ -333,12 +333,59 @@ def _call_vlm(image_bytes: bytes, content_type: str) -> Dict[str, Any]: return parsed return {"title": "", "description": text.strip(), "categories": ["uncategorized"], "tags": [], "colors": []} + +def extract_vlm_observation(image_bytes: bytes, content_type: str) -> Dict[str, Any]: + """Run only the raw VLM observation step.""" + if not image_bytes: + raise ValueError("image_bytes is required") + if not isinstance(content_type, str) or not content_type.startswith("image/"): + raise ValueError("content_type must be an image/* MIME type") + + vlm_result = _call_vlm(image_bytes, content_type) + logger.info( + "VLM analysis complete (English): title_len=%d desc_len=%d categories=%s", + len(vlm_result.get("title", "")), + len(vlm_result.get("description", "")), + vlm_result.get("categories", []), + ) + return vlm_result + + +def build_enriched_vlm_result( + vlm_result: Dict[str, Any], + locale: str = "en-US", + product_data: Optional[Dict[str, Any]] = None, + brand_instructions: Optional[str] = None, +) -> Dict[str, Any]: + """Build enriched catalog fields from a raw VLM observation.""" + enhanced = _call_nemotron_enhance(vlm_result, product_data, locale, brand_instructions) + logger.info("Nemotron enhance complete: keys=%s", list(enhanced.keys())) + + categories = ( + enhanced.get("categories") + if enhanced.get("categories") and isinstance(enhanced.get("categories"), list) + else vlm_result.get("categories", ["uncategorized"]) + ) + + result = { + "title": enhanced.get("title", vlm_result.get("title", "")), + "description": enhanced.get("description", vlm_result.get("description", "")), + "categories": categories, + "tags": enhanced.get("tags", vlm_result.get("tags", [])), + "colors": enhanced.get("colors", vlm_result.get("colors", [])), + } + + if product_data: + result["enhanced_product"] = {**product_data, **enhanced} + + return result + def run_vlm_analysis( image_bytes: bytes, content_type: str, locale: str = "en-US", product_data: Optional[Dict[str, Any]] = None, - brand_instructions: Optional[str] = None + brand_instructions: Optional[str] = None, ) -> Dict[str, Any]: """ Run VLM analysis on an image to extract product fields. 
@@ -352,40 +399,10 @@ def run_vlm_analysis( locale: Target locale for analysis product_data: Optional existing product data to augment brand_instructions: Optional brand-specific tone/style instructions - + Returns: - Dict with title, description, categories, tags, colors, enhanced_product (if augmentation) + Dict with title, description, categories, tags, colors, and enhanced_product (if augmentation) """ logger.info("Running VLM analysis: locale=%s mode=%s brand_instructions=%s", locale, "augmentation" if product_data else "generation", bool(brand_instructions)) - - if not image_bytes: - raise ValueError("image_bytes is required") - if not isinstance(content_type, str) or not content_type.startswith("image/"): - raise ValueError("content_type must be an image/* MIME type") - - # Run VLM analysis (always in English) - vlm_result = _call_vlm(image_bytes, content_type) - logger.info("VLM analysis complete (English): title_len=%d desc_len=%d categories=%s", - len(vlm_result.get("title", "")), len(vlm_result.get("description", "")), vlm_result.get("categories", [])) - - # Always enhance VLM output with Nemotron (handles all scenarios) - enhanced = _call_nemotron_enhance(vlm_result, product_data, locale, brand_instructions) - logger.info("Nemotron enhance complete: keys=%s", list(enhanced.keys())) - - categories = (enhanced.get("categories") if enhanced.get("categories") and isinstance(enhanced.get("categories"), list) - else vlm_result.get("categories", ["uncategorized"])) - - result = { - "title": enhanced.get("title", vlm_result.get("title", "")), - "description": enhanced.get("description", vlm_result.get("description", "")), - "categories": categories, - "tags": enhanced.get("tags", vlm_result.get("tags", [])), - "colors": enhanced.get("colors", vlm_result.get("colors", [])) - } - - # If product data was provided, merge enhanced fields back into original product_data - # to preserve untouched fields (price, SKU, specs, etc.) 
- if product_data: - result["enhanced_product"] = {**product_data, **enhanced} - - return result + vlm_result = extract_vlm_observation(image_bytes, content_type) + return build_enriched_vlm_result(vlm_result, locale, product_data, brand_instructions) diff --git a/src/ui/app/page.tsx b/src/ui/app/page.tsx index 5946b41..39b3d7a 100644 --- a/src/ui/app/page.tsx +++ b/src/ui/app/page.tsx @@ -1,15 +1,16 @@ 'use client'; import { Stack } from '@/kui-foundations-react-external'; -import { useState, useRef } from 'react'; +import { useEffect, useState, useRef } from 'react'; import { Nebula } from '@/kui-foundations-react-external/nebula'; import { Header } from '@/components/Header'; import { ImageUploadCard } from '@/components/ImageUploadCard'; import { FieldsCard } from '@/components/FieldsCard'; +import { AdvancedOptionsCard } from '@/components/AdvancedOptionsCard'; import { GeneratedVariationsSection } from '@/components/GeneratedVariationsSection'; -import { ProductFields, AugmentedData, ImageMetadata, SUPPORTED_LOCALES } from '@/types'; -import { analyzeImage, generateImageVariation, generate3DModel, prepareProductData } from '@/lib/api'; -import { formatFileSize } from '@/lib/utils'; +import { ProductFields, AugmentedData, PolicyDocument, PolicyUploadResult, SUPPORTED_LOCALES } from '@/types'; +import { analyzeImage, clearPolicies, generateImageVariation, generate3DModel, listPolicies, prepareProductData, uploadPolicies } from '@/lib/api'; + function Home() { const [uploadedImage, setUploadedImage] = useState(null); @@ -17,8 +18,11 @@ function Home() { const [isUploading, setIsUploading] = useState(false); const [isAnalyzingFields, setIsAnalyzingFields] = useState(false); const [isGeneratingImage, setIsGeneratingImage] = useState(false); - const [imageMetadata, setImageMetadata] = useState(null); const [locale, setLocale] = useState('en-US'); + const [loadedPolicies, setLoadedPolicies] = useState([]); + const [policyUploadResults, setPolicyUploadResults] = useState([]); + const [policyUploadError, setPolicyUploadError] = useState(null); + const [isUploadingPolicies, setIsUploadingPolicies] = useState(false); const [fields, setFields] = useState({ title: '', description: '', @@ -38,6 +42,11 @@ function Home() { const [enableVariation2, setEnableVariation2] = useState(true); const [enable3D, setEnable3D] = useState(true); const fileInputRef = useRef(null); + const policyFileInputRef = useRef(null); + + useEffect(() => { + void refreshPolicies({ silent: true }); + }, []); const handleFileUpload = async (file: File) => { if (!['image/png', 'image/jpeg', 'image/jpg'].includes(file.type)) { @@ -56,11 +65,6 @@ function Home() { reader.onload = (e) => { const img = new window.Image(); img.onload = () => { - setImageMetadata({ - name: file.name, - size: formatFileSize(file.size), - dimensions: `${img.width} × ${img.height}` - }); setIsUploading(false); }; img.src = e.target?.result as string; @@ -84,7 +88,6 @@ function Home() { const handleReset = () => { setUploadedImage(null); setUploadedFile(null); - setImageMetadata(null); setAugmentedData(null); setGeneratedImages([null, null]); setQualityScores([null, null]); @@ -92,11 +95,79 @@ function Home() { setGenerated3DModel(null); setModel3DError(null); setLocale('en-US'); + setPolicyUploadResults([]); + setPolicyUploadError(null); setFields({ title: '', description: '', color: '', categories: '', tags: '', price: '', brandInstructions: '' }); setEnableVariation1(true); setEnableVariation2(true); setEnable3D(true); if (fileInputRef.current) 
fileInputRef.current.value = ''; + if (policyFileInputRef.current) policyFileInputRef.current.value = ''; + }; + + const refreshPolicies = async ({ silent = false }: { silent?: boolean } = {}) => { + try { + const documents = await listPolicies(); + setLoadedPolicies(documents); + if (!silent) { + setPolicyUploadError(null); + } + } catch (error) { + console.error('Error loading policy library:', error); + setLoadedPolicies([]); + if (!silent) { + setPolicyUploadError(error instanceof Error ? error.message : 'Failed to load policy library'); + } + } + }; + + const handlePolicyFilesUpload = async (files: FileList | null) => { + const selectedFiles = Array.from(files || []); + if (selectedFiles.length === 0) { + return; + } + + const invalidFile = selectedFiles.find((file) => file.type !== 'application/pdf' && !file.name.toLowerCase().endsWith('.pdf')); + if (invalidFile) { + setPolicyUploadError(`"${invalidFile.name}" is not a PDF.`); + return; + } + + try { + setIsUploadingPolicies(true); + setPolicyUploadError(null); + const response = await uploadPolicies(selectedFiles, locale); + setLoadedPolicies(response.documents || []); + setPolicyUploadResults(response.results || []); + } catch (error) { + console.error('Error uploading policy PDFs:', error); + setPolicyUploadError(error instanceof Error ? error.message : 'Failed to upload policy PDFs'); + } finally { + setIsUploadingPolicies(false); + if (policyFileInputRef.current) { + policyFileInputRef.current.value = ''; + } + } + }; + + const handleClearPolicyLibrary = async () => { + if (!window.confirm('This will clear all loaded policy PDFs and embeddings. Continue?')) { + return; + } + + try { + setIsUploadingPolicies(true); + setPolicyUploadError(null); + await clearPolicies(); + setLoadedPolicies([]); + setPolicyUploadResults([]); + setAugmentedData(prev => prev ? { ...prev, policyDecision: undefined } : prev); + } catch (error) { + console.error('Error clearing policy library:', error); + setPolicyUploadError(error instanceof Error ? 
error.message : 'Failed to clear policy library');
+    } finally {
+      setIsUploadingPolicies(false);
+    }
+  };

   const handleGenerate = async () => {
@@ -117,7 +188,7 @@ function Home() {
         file: uploadedFile,
         locale,
         productData,
-        brandInstructions: fields.brandInstructions
+        brandInstructions: fields.brandInstructions
       });

       setAugmentedData({
@@ -125,7 +196,8 @@ function Home() {
         description: analyzeData.description || '',
         colors: analyzeData.colors || [],
         tags: analyzeData.tags || [],
-        categories: analyzeData.categories || []
+        categories: analyzeData.categories || [],
+        policyDecision: analyzeData.policyDecision
       });

       setIsAnalyzingFields(false);
@@ -152,9 +224,6 @@ function Home() {
       setGeneratedImages(prev => [result.imageUrl, prev[1]]);
       setQualityScores(prev => [result.qualityScore, prev[1]]);
       setQualityIssues(prev => [result.qualityIssues, prev[1]]);
-      if (result.qualityIssues && result.qualityIssues.length > 0) {
-        console.log('[Variation 1] Quality issues:', result.qualityIssues);
-      }
     } catch (error) {
       console.error('Error generating variation 1:', error);
     }
@@ -170,9 +239,6 @@ function Home() {
       setGeneratedImages(prev => [prev[0], result.imageUrl]);
       setQualityScores(prev => [prev[0], result.qualityScore]);
       setQualityIssues(prev => [prev[0], result.qualityIssues]);
-      if (result.qualityIssues && result.qualityIssues.length > 0) {
-        console.log('[Variation 2] Quality issues:', result.qualityIssues);
-      }
     } catch (error) {
       console.error('Error generating variation 2:', error);
     }
@@ -280,28 +346,29 @@ function Home() {
         }}
         style={{ display: 'none' }}
       />
+      <input
+        ref={policyFileInputRef}
+        type="file"
+        accept="application/pdf"
+        multiple
+        onChange={(e) => {
+          void handlePolicyFilesUpload(e.target.files);
+        }}
+        style={{ display: 'none' }}
+      />
fileInputRef.current?.click()} onDragOver={handleDragOver} onDrop={handleDrop} onLocaleChange={setLocale} - onBrandInstructionsChange={(value) => setFields(prev => ({ ...prev, brandInstructions: value }))} - onEnableVariation1Change={setEnableVariation1} - onEnableVariation2Change={setEnableVariation2} - onEnable3DChange={setEnable3D} onGenerate={handleGenerate} onReset={handleReset} /> @@ -315,6 +382,29 @@ function Home() { />
+          <AdvancedOptionsCard
+            brandInstructions={fields.brandInstructions}
+            loadedPolicies={loadedPolicies}
+            policyUploadResults={policyUploadResults}
+            policyUploadError={policyUploadError}
+            isUploadingPolicies={isUploadingPolicies}
+            enableVariation1={enableVariation1}
+            enableVariation2={enableVariation2}
+            enable3D={enable3D}
+            isAnalyzingFields={isAnalyzingFields}
+            isGeneratingImage={isGeneratingImage}
+            onBrandInstructionsChange={(value) => setFields(prev => ({ ...prev, brandInstructions: value }))}
+            onPolicyFileSelect={() => policyFileInputRef.current?.click()}
+            onClearPolicyLibrary={() => {
+              void handleClearPolicyLibrary();
+            }}
+            onEnableVariation1Change={setEnableVariation1}
+            onEnableVariation2Change={setEnableVariation2}
+            onEnable3DChange={setEnable3D}
+          />
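The `page.tsx` hunks above import `uploadPolicies`, `listPolicies`, and `clearPolicies` from `@/lib/api`, which imply matching backend routes in front of the policy store from the first hunk. Below is a minimal sketch of that glue, assuming a FastAPI app and an `add_pdfs`-style entry point; the route paths, the `PolicyStore` name and constructor, and the `add_pdfs` signature are assumptions, not code from this diff:

```python
# Hypothetical wiring (not part of this diff): one way the UI's policy calls
# could map onto the store methods shown in the first hunk. Route paths,
# the PolicyStore name/constructor, and the add_pdfs signature are assumed.
from fastapi import FastAPI, File, UploadFile

app = FastAPI()
policy_store = PolicyStore()  # assumed name and constructor; only methods appear in the diff

@app.post("/policies")
async def upload_policies(files: list[UploadFile] = File(...), locale: str = "en-US"):
    # Read each PDF into memory and hand it to the store (signature assumed)
    payloads = [(f.filename, await f.read()) for f in files]
    results = policy_store.add_pdfs(payloads, locale)
    return {"results": results, "documents": policy_store.list_documents()}

@app.get("/policies")
def list_policies():
    return {"documents": policy_store.list_documents()}

@app.delete("/policies")
def clear_policies():
    # Drops the Milvus collection, the SQLite rows, and the on-disk text cache
    policy_store.clear()
    return {"status": "cleared"}
```

Note that deduplication happens inside `add_pdfs` itself: a re-uploaded PDF matches on its document hash and comes back with `already_loaded: True` instead of being re-summarized and re-embedded.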
+  onBrandInstructionsChange: (value: string) => void;
+  onPolicyFileSelect: () => void;
+  onClearPolicyLibrary: () => void;
+  onEnableVariation1Change: (value: boolean) => void;
+  onEnableVariation2Change: (value: boolean) => void;
+  onEnable3DChange: (value: boolean) => void;
+}
+
+function PolicyUploadProgress() {
+  const [stageIndex, setStageIndex] = useState(0);
+
+  useEffect(() => {
+    const interval = setInterval(() => {
+      setStageIndex((prev) => (prev + 1) % UPLOAD_STAGES.length);
+    }, 2400);
+    return () => clearInterval(interval);
+  }, []);
+
+  return (
+    <div>
+      <div>
+        <span>{UPLOAD_STAGES[stageIndex]}</span>
+        <span>...</span>
+      </div>
+      <div>
+        {UPLOAD_STAGES.map((label, i) => (
+          <span key={label} data-active={i === stageIndex} />
+        ))}
+      </div>
+    </div>
+  );
+}
+
+export function AdvancedOptionsCard({
+  brandInstructions,
+  loadedPolicies,
+  policyUploadResults,
+  policyUploadError,
+  isUploadingPolicies,
+  enableVariation1,
+  enableVariation2,
+  enable3D,
+  isAnalyzingFields,
+  isGeneratingImage,
+  onBrandInstructionsChange,
+  onPolicyFileSelect,
+  onClearPolicyLibrary,
+  onEnableVariation1Change,
+  onEnableVariation2Change,
+  onEnable3DChange,
+}: Props) {
+  const [isCollapsed, setIsCollapsed] = useState(false);
+  const disabled = isAnalyzingFields || isGeneratingImage;
+
+  return (
+    <div>
+      <div>
+        <span>Configuration</span>
+        <h3>Advanced options</h3>
+        <p>Fine-tune generation with brand rules, compliance policies, and output toggles.</p>
+        <button
+          type="button"
+          aria-label={isCollapsed ? 'Expand advanced options' : 'Collapse advanced options'}
+          onClick={() => setIsCollapsed((prev) => !prev)}
+        />
+      </div>
+      {!isCollapsed && (
+        <div>
+          {/* Brand Instructions */}
+          <div>
+            <label>Brand instructions</label>
+            <p>Add voice, tone, taxonomy, or content rules that should guide generation.</p>
+            <span>Optional</span>
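Stepping back, the `vlm.py` refactor is what lets policy compliance slot into analysis: `run_vlm_analysis` is now a thin wrapper over `extract_vlm_observation` and `build_enriched_vlm_result`, so a caller can run policy retrieval against the raw observation before enrichment produces the final fields. A rough orchestration sketch under that assumption follows; the `PolicyStore()` wiring and the `policy_context` key are illustrative stand-ins, not this repository's actual backend code:

```python
# Illustrative orchestration (assumption, not from this diff): the point of
# splitting run_vlm_analysis is that retrieval can run on the raw observation
# before enrichment. PolicyStore() and "policy_context" are hypothetical.
from vlm import extract_vlm_observation, build_enriched_vlm_result

def analyze_with_policies(image_bytes: bytes, content_type: str, locale: str = "en-US") -> dict:
    # Step 1: raw English observation from the VLM
    observation = extract_vlm_observation(image_bytes, content_type)

    # Step 2: retrieve matching policy chunks; returns [] when no PDFs are
    # loaded or the top score falls below min_relevance_score
    store = PolicyStore()  # hypothetical wiring; the class name is not shown in the diff
    hits = store.retrieve_context(observation)
    if hits:
        # Each hit carries chunk_text, score, and the parent document summary,
        # which a downstream classifier could turn into a policyDecision.
        observation["policy_context"] = [h["chunk_text"] for h in hits]

    # Step 3: locale-aware enrichment of the (possibly annotated) observation
    return build_enriched_vlm_result(observation, locale=locale)
```

This mirrors what the UI expects: `analyzeImage` now returns a `policyDecision` alongside the enriched fields, which is only possible because the observation and enrichment steps are separately callable.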