diff --git a/README.md b/README.md index 17a3e68..bd68f1a 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ PhotoSort is a powerful desktop application focused on speed designed to streaml * **Ratings & Labels**: Assign star ratings for quick categorization. * **Blur Detection**: Automatically identify and flag blurry photos. * **AI Orientation Detection**: Auto-detects the correct image orientation using a lightweight ONNX model and proposes rotations. + * **AI Best Shot Picker**: Let a vision-language model evaluate selected images and highlight the best shot without leaving the flow. * **Similarity Analysis**: Group visually similar images to easily spot duplicates or near-duplicates. * **Fast Processing**: Intensive operations (scanning, thumbnailing, analysis) run once in batch to ensure fast image scrolling. * **Optimized Image Handling**: Supports a wide range of formats, including various RAW types, with efficient caching. @@ -98,6 +99,17 @@ To use the **Auto Rotate Images** feature (`Ctrl+R`), you need to download the p The application will automatically detect and load the model when you use the rotation detection feature. +### AI Best Shot Picker + +PhotoSort can now analyze a set of similar photos and automatically pick the best shot. Configure the integration inside the app under `Settings → Preferences → AI Best Shot`: + +- Set the **API URL** to match your server (for LM Studio the default is `http://localhost:1234/v1`; update the IP/port if you run it elsewhere). +- Provide an **API Key** if your endpoint requires one. +- Choose the **Model** you want the picker to use. We recommend loading **Qwen3 VL** in LM Studio. +- Adjust the **Timeout** to fit your hardware and model size. + +Once configured, select two or more images and trigger the analysis via `Image → Pick Best Shot with AI…` or `Ctrl+B` (`Cmd+B` on macOS). The result dialog highlights the chosen image and the reasoning so you can jump straight to the winner. + ### Exporting Logs To capture detailed logs for debugging, you can enable file logging by setting an environment variable before running the application. diff --git a/requirements.txt b/requirements.txt index 8cf5304..2fd564b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ pyexiv2 piexif onnxruntime torchvision +openai diff --git a/src/core/ai/__init__.py b/src/core/ai/__init__.py new file mode 100644 index 0000000..72eaaf3 --- /dev/null +++ b/src/core/ai/__init__.py @@ -0,0 +1,13 @@ +""" +AI Module for PhotoSort + +This module contains AI-powered features for PhotoSort, including: +- Best shot picker: Automatically select the best image from a group + +The AI features use vision language models through an OpenAI-compatible API, +with default support for LM Studio local server. +""" + +from .best_shot_picker import BestShotPicker, BestShotResult, BestShotPickerError + +__all__ = ["BestShotPicker", "BestShotResult", "BestShotPickerError"] diff --git a/src/core/ai/best_shot_picker.py b/src/core/ai/best_shot_picker.py new file mode 100644 index 0000000..d79936e --- /dev/null +++ b/src/core/ai/best_shot_picker.py @@ -0,0 +1,634 @@ +""" +AI Best Shot Picker Module +Communicates with LM Studio (or compatible OpenAI API) to select the best image +from a group of images using vision model analysis. +""" + +import base64 +import io +import logging +import mimetypes +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional, Dict + +try: + from PIL import Image, ImageDraw, ImageFont +except ImportError: # pragma: no cover - Pillow is an optional dependency in some envs + Image = None # type: ignore + ImageDraw = None # type: ignore + ImageFont = None # type: ignore + +from openai import OpenAI + +logger = logging.getLogger(__name__) + +# Suppress noisy debug logs from third-party libraries +logging.getLogger("httpcore").setLevel(logging.WARNING) +logging.getLogger("httpx").setLevel(logging.WARNING) +logging.getLogger("openai").setLevel(logging.WARNING) + + +@dataclass +class BestShotResult: + """Result from best shot analysis.""" + + best_image_index: int # Index of the best image (0-based) + best_image_path: str # Path to the best image + reasoning: str # AI's reasoning for the selection + confidence: str # Confidence level (if provided) + raw_response: str # Full AI response + + +class BestShotPickerError(Exception): + """Exception raised when best shot picking fails.""" + + pass + + +class BestShotPicker: + """ + AI-powered best shot picker using vision models via OpenAI-compatible API. + + This class communicates with LM Studio or any OpenAI-compatible API endpoint + to analyze multiple images and select the best one based on various criteria + like focus, composition, exposure, and overall quality. + """ + + def __init__( + self, + base_url: str = "http://localhost:1234/v1", + api_key: str = "not-needed", + model: str = "local-model", + timeout: int = 120, + ): + """ + Initialize the BestShotPicker. + + Args: + base_url: The base URL for the API endpoint (default: LM Studio local) + api_key: API key (not needed for local LM Studio) + model: Model identifier (placeholder for LM Studio) + timeout: Request timeout in seconds (default: 120 for vision models) + """ + self.base_url = base_url + self.api_key = api_key + self.model = model + self.timeout = timeout + self.client = None + self.debug_overlay_enabled = True + logger.info("AI best shot debug overlays enabled (default)") + + def _initialize_client(self): + """Initialize the OpenAI client.""" + if self.client is None: + self.client = OpenAI( + base_url=self.base_url, api_key=self.api_key, timeout=self.timeout + ) + + def _get_base64_image(self, image_path: str) -> str: + """ + Convert an image file to Base64-encoded string. + + Args: + image_path: Path to the image file + + Returns: + Base64-encoded string of the image + + Raises: + FileNotFoundError: If the image file doesn't exist + BestShotPickerError: If encoding fails + """ + try: + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + except FileNotFoundError: + raise FileNotFoundError(f"Image file not found: {image_path}") + except Exception as e: + raise BestShotPickerError(f"Failed to encode image {image_path}: {e}") + + def _get_overlay_encoded_image( + self, + image_path: str, + mime_type: str, + label: str, + pil_image: Optional[Any] = None, + ) -> tuple[str, str]: + """Return the encoded payload for an image with a debug label overlay.""" + + if Image is None or ImageDraw is None or ImageFont is None: + raise BestShotPickerError( + "Pillow is required for debug overlay rendering" + ) + + try: + if pil_image is not None: + img = pil_image.convert("RGBA") + else: + with Image.open(image_path) as opened_img: + img = opened_img.convert("RGBA") + + draw = ImageDraw.Draw(img) + width, height = img.size + if width <= 0 or height <= 0: + raise BestShotPickerError( + f"Image {image_path} has invalid dimensions" + ) + + text = (label or "?").strip() + base_dimension = max(48, int(min(width, height) * 0.18)) + padding = max(6, base_dimension // 6) + + try: + font = ImageFont.truetype("DejaVuSans-Bold.ttf", base_dimension) + except (OSError, IOError): # pragma: no cover - font fallback + font = ImageFont.load_default() + + text_bbox = draw.textbbox((0, 0), text, font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + + rect_width = text_width + padding * 2 + rect_height = text_height + padding * 2 + rect_coords = ( + padding, + padding, + padding + rect_width, + padding + rect_height, + ) + + draw.rectangle(rect_coords, fill=(0, 0, 0, 192)) + + text_x = padding + (rect_width - text_width) / 2 + text_y = padding + (rect_height - text_height) / 2 + draw.text((text_x, text_y), text, fill=(255, 255, 255, 255), font=font) + + return self._encode_pil_image(img, mime_type) + + except FileNotFoundError: + raise + except Exception as exc: + source = image_path if pil_image is None else f"{image_path} (preview)" + raise BestShotPickerError( + f"Failed to render debug overlay for {source}: {exc}" + ) from exc + + def _encode_pil_image( + self, pil_image: Any, mime_type: str + ) -> tuple[str, str]: + if Image is None: + raise BestShotPickerError("Pillow is required for image encoding") + + buffer = io.BytesIO() + if mime_type == "image/jpeg": + pil_image.convert("RGB").save(buffer, format="JPEG", quality=90) + effective_mime = "image/jpeg" + else: + pil_image.convert("RGBA").save(buffer, format="PNG") + effective_mime = "image/png" + + encoded = base64.b64encode(buffer.getvalue()).decode("utf-8") + return encoded, effective_mime + + def _prepare_encoded_image( + self, + image_path: str, + mime_type: str, + debug_label: str | None, + pil_image: Optional[Any] = None, + ) -> tuple[str, str]: + """ + Prepare the base64 payload for an image, optionally adding a debug label. + + Returns: + Tuple containing (base64_data, mime_type_used). + """ + + if debug_label: + return self._get_overlay_encoded_image( + image_path, mime_type, debug_label, pil_image + ) + + if pil_image is not None: + return self._encode_pil_image(pil_image, mime_type) + + return self._get_base64_image(image_path), mime_type + + def prepare_preview_payload( + self, + image_path: str, + pil_image: Any, + overlay_label: Optional[str], + mime_type: str = "image/jpeg", + ) -> tuple[str, str]: + """Encode a pre-generated preview image for AI analysis.""" + + label = overlay_label if self.debug_overlay_enabled else None + return self._prepare_encoded_image( + image_path=image_path, + mime_type=mime_type, + debug_label=label, + pil_image=pil_image, + ) + + def _get_image_url(self, image_path: str) -> str: + """ + Convert an image file path to a file:// URL or return existing URL. + + Args: + image_path: Path to the image file or URL + + Returns: + URL string for the image + + Raises: + FileNotFoundError: If the image file doesn't exist + """ + # Check if it's already a URL (http/https) + if image_path.startswith(("http://", "https://")): + return image_path + + # Convert local file path to file:// URL + try: + path = Path(image_path) + if not path.exists(): + raise FileNotFoundError(f"Image file not found: {image_path}") + + # Convert to absolute path and create file:// URL + absolute_path = path.resolve() + return absolute_path.as_uri() + except Exception as e: + raise BestShotPickerError(f"Failed to create URL for image {image_path}: {e}") + + def _build_prompt( + self, image_count: int, include_debug_instruction: bool = False + ) -> str: + """ + Build the prompt for the AI to analyze images. + + Args: + image_count: Number of images being analyzed + + Returns: + Formatted prompt string + """ + prompt = f"""You are an expert photography critic tasked with selecting the best image from a set of {image_count} images. + +Analyze each image based on the following criteria: +- **Sharpness and Focus**: Is the subject in focus? Are there motion blur or focus issues? +- **Color/Lightining**: Color & Lighting – accuracy, contrast, saturation, white balance, and visual appeal. +- **Composition**: Does the image follow framing, subject placement, use of space principles? +- **Subject Expression**: For portraits, does the subject have a good expression (eyes open, natural smile)? +- **Technical Quality**: Are there any artifacts, noise, or technical issues? +- **Overall Appeal**: Which image is most visually appealing? +- **Editing Potential**: – how well the image could respond to color grading, retouching, or enhancement. +- **Subject Sharpness** – focus quality, motion blur, and clarity of the main subject. + +If any person’s eyes are closed, the photo automatically receives a low rating (1–2) regardless of other factors, unless it's a clear visual choice. + +Please analyze each image and then provide your response in the following format: + +**Best Image**: [Image number, 1-{image_count}] +**Confidence**: [High/Medium/Low] +**Reasoning**: [Brief explanation of why this image is the best] + +Be decisive and pick ONE image as the best, even if the differences are subtle.""" + + if include_debug_instruction: + prompt += ( + "\n\nEach image has a bold verification number in the corner. " + "In your response add a new line exactly as follows:" + "\n**Overlay Number**: [the number printed on the selected image]" + "\nAlso reference the overlay numbers when comparing images in your reasoning so we can confirm alignment." + ) + + return prompt + + def _parse_response( + self, response: str, image_paths: list[str] + ) -> BestShotResult: + """ + Parse the AI response to extract the best image selection. + + Args: + response: Raw response from the AI + image_paths: List of image paths in order + + Returns: + BestShotResult with parsed information + + Raises: + BestShotPickerError: If parsing fails + """ + try: + # Look for "Best Image: X" pattern + import re + + # Try to find the image number in various formats + patterns = [ + r"(?i)\*\*Best Image\*\*:\s*(?:Image\s*)?(\d+)", # **Best Image**: Image 3 + r"(?i)Best Image:\s*(?:Image\s*)?(\d+)", # Best Image: 3 + r"(?i)Image\s*(\d+)\s+is\s+(?:the\s+)?best", # Image 3 is best + r"(?i)select(?:ed)?\s+(?:image\s*)?(\d+)", # Selected image 3 + r"(?i)choose\s+(?:image\s*)?(\d+)", # Choose image 3 + ] + + best_index = None + for pattern in patterns: + match = re.search(pattern, response) + if match: + image_num = int(match.group(1)) + logger.debug(f"Parsed image number from response: {image_num}") + if 1 <= image_num <= len(image_paths): + best_index = image_num - 1 # Convert to 0-based index + logger.debug(f"Converted to 0-based index: {best_index}") + logger.debug(f"Corresponding path: {image_paths[best_index]}") + break + else: + logger.warning(f"Image number {image_num} out of range (1-{len(image_paths)})") + + if best_index is None: + logger.warning(f"Could not parse image number from response: {response}") + # Default to first image if parsing fails + best_index = 0 + reasoning = "Failed to parse AI response. Defaulting to first image." + confidence = "Unknown" + else: + # Extract reasoning + reasoning_match = re.search( + r"(?i)\*\*Reasoning\*\*:\s*(.+?)(?:\n\n|\Z)", response, re.DOTALL + ) + if reasoning_match: + reasoning = reasoning_match.group(1).strip() + else: + # Try to extract any explanation text + reasoning = response.split("**Reasoning**:")[-1].strip() + if not reasoning: + reasoning = "No detailed reasoning provided." + + # Extract confidence + confidence_match = re.search( + r"(?i)\*\*Confidence\*\*:\s*(\w+)", response + ) + if confidence_match: + confidence = confidence_match.group(1) + else: + confidence = "Not specified" + + return BestShotResult( + best_image_index=best_index, + best_image_path=image_paths[best_index], + reasoning=reasoning, + confidence=confidence, + raw_response=response, + ) + + except Exception as e: + logger.error(f"Failed to parse AI response: {e}") + raise BestShotPickerError(f"Failed to parse AI response: {e}") + + def select_best_image( + self, + image_paths: list[str], + max_tokens: int = 1000, + stream: bool = False, + preview_overrides: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> BestShotResult: + """ + Analyze multiple images and select the best one. + + Args: + image_paths: List of paths to images to analyze + max_tokens: Maximum tokens in the response + stream: Whether to stream the response + preview_overrides: Optional mapping of image paths to pre-encoded + preview payloads containing ``base64``, ``mime_type``, and an + optional ``overlay_label``. + + Returns: + BestShotResult containing the selection and reasoning + + Raises: + ValueError: If image_paths is empty or contains only one image + BestShotPickerError: If the API call or analysis fails + """ + if not image_paths: + raise ValueError("No images provided for analysis") + + if len(image_paths) == 1: + logger.info("Only one image provided, returning it as the best") + return BestShotResult( + best_image_index=0, + best_image_path=image_paths[0], + reasoning="Only one image provided", + confidence="High", + raw_response="Single image - no comparison needed", + ) + + logger.info(f"Analyzing {len(image_paths)} images to select the best one") + + prepared_images: list[dict[str, Any]] = [] + skipped_paths: list[str] = [] + overrides = preview_overrides or {} + mime_type_map = { + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".jpe": "image/jpeg", + ".jfif": "image/jpeg", + ".png": "image/png", + ".webp": "image/webp", + ".gif": "image/gif", + ".bmp": "image/bmp", + ".tif": "image/tiff", + ".tiff": "image/tiff", + ".heif": "image/heif", + ".heic": "image/heic", + ".avif": "image/avif", + } + + for original_index, image_path in enumerate(image_paths): + debug_label = ( + str(len(prepared_images) + 1) if self.debug_overlay_enabled else None + ) + + try: + override_payload = overrides.get(image_path) + overlay_label = None + + if override_payload: + base64_image = override_payload.get("base64") + if not base64_image: + logger.warning( + "Preview override missing base64 data for %s; falling back to source", + image_path, + ) + override_payload = None + else: + effective_mime = override_payload.get( + "mime_type", "image/jpeg" + ) + overlay_label = override_payload.get("overlay_label") or debug_label + + if not override_payload: + ext = Path(image_path).suffix.lower() + mime_type = mime_type_map.get(ext) + if not mime_type: + guessed_type, _ = mimetypes.guess_type(image_path) + mime_type = guessed_type or "image/jpeg" + if not mime_type.startswith("image/"): + logger.debug( + "Unsupported mime type %s for %s, defaulting to image/jpeg", + mime_type, + image_path, + ) + mime_type = "image/jpeg" + + try: + base64_image, effective_mime = self._prepare_encoded_image( + image_path, mime_type, debug_label + ) + overlay_label = debug_label + except BestShotPickerError as overlay_error: + logger.debug( + "Failed to prepare debug overlay for %s: %s. " + "Falling back to raw encoding without overlay.", + image_path, + overlay_error, + ) + base64_image = self._get_base64_image(image_path) + effective_mime = mime_type + overlay_label = None + + except FileNotFoundError: + logger.warning(f"Image not found: {image_path}, skipping") + skipped_paths.append(image_path) + continue + except BestShotPickerError as e: + logger.warning(f"Failed to encode image {image_path}: {e}, skipping") + skipped_paths.append(image_path) + continue + + prepared_images.append( + { + "original_index": original_index, + "path": image_path, + "base64": base64_image, + "mime_type": effective_mime, + "overlay_label": overlay_label, + } + ) + + if not prepared_images: + logger.error("No valid images available for analysis after preprocessing") + raise BestShotPickerError("No valid images to analyze") + + if skipped_paths: + logger.info( + "Skipping %d invalid image(s) before analysis", len(skipped_paths) + ) + + logger.info("Image order being sent to AI:") + for position, item in enumerate(prepared_images, 1): + logger.info(" Position %d: %s", position, Path(item["path"]).name) + + self._initialize_client() + + content = [ + { + "type": "text", + "text": self._build_prompt( + len(prepared_images), self.debug_overlay_enabled + ), + } + ] + + for position, item in enumerate(prepared_images, 1): + image_name = Path(item["path"]).name + logger.debug( + "Adding Image %d: %s (original index %d)", + position, + image_name, + item["original_index"], + ) + encoded_data = item.pop("base64") + description_lines = [f"\n**Image {position}** ({image_name}):"] + if self.debug_overlay_enabled and item.get("overlay_label"): + description_lines.append( + f"Overlay number visible on this image: **{item['overlay_label']}**" + ) + + content.append( + { + "type": "text", + "text": "\n".join(description_lines), + } + ) + content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:{item['mime_type']};base64,{encoded_data}" + }, + } + ) + + # Make API call + try: + logger.debug(f"Sending request to {self.base_url}") + completion = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": content}], + max_tokens=max_tokens, + stream=stream, + ) + + response_text = completion.choices[0].message.content + logger.debug(f"Received response: {response_text[:200]}...") + + # Parse and return result + analysis_paths = [item["path"] for item in prepared_images] + result = self._parse_response(response_text, analysis_paths) + + chosen_entry = prepared_images[result.best_image_index] + result.best_image_index = chosen_entry["original_index"] + result.best_image_path = image_paths[result.best_image_index] + + logger.info( + "Selected image %d/%d: %s", + result.best_image_index + 1, + len(image_paths), + Path(result.best_image_path).name, + ) + + return result + + except Exception as e: + logger.error(f"API call failed: {e}") + raise BestShotPickerError(f"Failed to analyze images: {e}") + + def test_connection(self) -> bool: + """ + Test if the API endpoint is accessible and responding. + + Returns: + True if connection is successful, False otherwise + """ + try: + self._initialize_client() + # Simple test with a basic text-only message + response = self.client.chat.completions.create( + model=self.model, + messages=[ + { + "role": "user", + "content": [{"type": "text", "text": "Hello, respond with OK"}], + } + ], + max_tokens=10, + ) + logger.debug(f"Connection test successful: {response.choices[0].message.content if response.choices else 'No response'}") + return True + except Exception as e: + logger.error(f"Connection test failed: {e}") + return False diff --git a/src/core/app_settings.py b/src/core/app_settings.py index d839370..074f5a6 100644 --- a/src/core/app_settings.py +++ b/src/core/app_settings.py @@ -52,6 +52,10 @@ def from_string(cls, value: str) -> "PerformanceMode": CUSTOM_THREAD_COUNT_KEY = ( "Performance/CustomThreadCount" # User-defined thread count for custom mode ) +AI_BEST_SHOT_API_URL_KEY = "AI/BestShotApiUrl" # API URL for best shot picker +AI_BEST_SHOT_API_KEY_KEY = "AI/BestShotApiKey" # API key for best shot picker +AI_BEST_SHOT_MODEL_KEY = "AI/BestShotModel" # Model name for best shot picker +AI_BEST_SHOT_TIMEOUT_KEY = "AI/BestShotTimeout" # Timeout for best shot picker API calls # Default values DEFAULT_PREVIEW_CACHE_SIZE_GB = 2.0 # Default to 2 GB for preview cache @@ -62,6 +66,10 @@ def from_string(cls, value: str) -> "PerformanceMode": DEFAULT_UPDATE_CHECK_ENABLED = True # Default to enable automatic update checks DEFAULT_PERFORMANCE_MODE = PerformanceMode.BALANCED # Default to balanced mode DEFAULT_CUSTOM_THREAD_COUNT = 4 # Default custom thread count +DEFAULT_AI_BEST_SHOT_API_URL = "http://localhost:1234/v1" # Default LM Studio URL +DEFAULT_AI_BEST_SHOT_API_KEY = "not-needed" # Default API key for local LM Studio +DEFAULT_AI_BEST_SHOT_MODEL = "local-model" # Default model name +DEFAULT_AI_BEST_SHOT_TIMEOUT = 120 # Default timeout in seconds # --- UI Constants --- # Grid view settings @@ -379,3 +387,58 @@ def calculate_max_workers(min_workers: int = 1, max_workers: int = None) -> int: workers = min(max_workers, workers) return workers + + +# --- AI Best Shot Picker Settings --- +def get_ai_best_shot_api_url() -> str: + """Gets the API URL for the best shot picker.""" + settings = _get_settings() + return settings.value( + AI_BEST_SHOT_API_URL_KEY, DEFAULT_AI_BEST_SHOT_API_URL, type=str + ) + + +def set_ai_best_shot_api_url(url: str): + """Sets the API URL for the best shot picker.""" + settings = _get_settings() + settings.setValue(AI_BEST_SHOT_API_URL_KEY, url) + + +def get_ai_best_shot_api_key() -> str: + """Gets the API key for the best shot picker.""" + settings = _get_settings() + return settings.value( + AI_BEST_SHOT_API_KEY_KEY, DEFAULT_AI_BEST_SHOT_API_KEY, type=str + ) + + +def set_ai_best_shot_api_key(api_key: str): + """Sets the API key for the best shot picker.""" + settings = _get_settings() + settings.setValue(AI_BEST_SHOT_API_KEY_KEY, api_key) + + +def get_ai_best_shot_model() -> str: + """Gets the model name for the best shot picker.""" + settings = _get_settings() + return settings.value(AI_BEST_SHOT_MODEL_KEY, DEFAULT_AI_BEST_SHOT_MODEL, type=str) + + +def set_ai_best_shot_model(model: str): + """Sets the model name for the best shot picker.""" + settings = _get_settings() + settings.setValue(AI_BEST_SHOT_MODEL_KEY, model) + + +def get_ai_best_shot_timeout() -> int: + """Gets the timeout for best shot picker API calls.""" + settings = _get_settings() + return settings.value( + AI_BEST_SHOT_TIMEOUT_KEY, DEFAULT_AI_BEST_SHOT_TIMEOUT, type=int + ) + + +def set_ai_best_shot_timeout(timeout: int): + """Sets the timeout for best shot picker API calls.""" + settings = _get_settings() + settings.setValue(AI_BEST_SHOT_TIMEOUT_KEY, timeout) diff --git a/src/ui/app_controller.py b/src/ui/app_controller.py index 1cb8f3a..5132248 100644 --- a/src/ui/app_controller.py +++ b/src/ui/app_controller.py @@ -270,6 +270,9 @@ def load_folder(self, folder_path: str): self.main_window.cluster_sort_combo.setCurrentIndex(0) self.main_window.menu_manager.group_by_similarity_action.setEnabled(False) self.main_window.menu_manager.group_by_similarity_action.setChecked(False) + self.main_window.menu_manager.pick_best_shots_for_clusters_action.setEnabled( + False + ) self.main_window.file_system_model.clear() self.main_window.file_system_model.setColumnCount(1) @@ -610,6 +613,9 @@ def handle_clustering_complete(self, cluster_results_dict: Dict[str, int]): self.main_window.statusBar().showMessage( "Clustering did not produce results.", 3000 ) + self.main_window.menu_manager.pick_best_shots_for_clusters_action.setEnabled( + False + ) return self.main_window.update_loading_text("Clustering complete. Updating view...") @@ -626,6 +632,9 @@ def handle_clustering_complete(self, cluster_results_dict: Dict[str, int]): ): self.main_window.menu_manager.cluster_sort_action.setVisible(True) self.main_window.cluster_sort_combo.setEnabled(True) + self.main_window.menu_manager.pick_best_shots_for_clusters_action.setEnabled( + bool(cluster_ids) + ) if self.main_window.group_by_similarity_mode: self.main_window._rebuild_model_view() self.main_window.hide_loading_overlay() @@ -636,6 +645,9 @@ def handle_similarity_error(self, message): self.main_window.menu_manager.analyze_similarity_action.setEnabled( bool(self.app_state.image_files_data) ) + self.main_window.menu_manager.pick_best_shots_for_clusters_action.setEnabled( + False + ) self.main_window.hide_loading_overlay() def handle_blur_detection_progress( diff --git a/src/ui/app_state.py b/src/ui/app_state.py index df36f2a..a068717 100644 --- a/src/ui/app_state.py +++ b/src/ui/app_state.py @@ -31,6 +31,7 @@ def __init__(self): ) # Instance of the new disk cache for ratings self.exif_disk_cache = ExifCache() # Instance of the new disk cache for EXIF data, now reads size from app_settings self.marked_for_deletion: set = set() # Set of file paths marked for deletion + self.best_shot_paths: set[str] = set() # Paths flagged as AI-selected best shots # Could also hold current folder path, filter states, etc. if desired. self.current_folder_path: Optional[str] = None @@ -46,6 +47,7 @@ def clear_all_file_specific_data(self): self.cluster_results.clear() self.embeddings_cache.clear() self.marked_for_deletion.clear() # Clear marked for deletion set + self.best_shot_paths.clear() if self.rating_disk_cache: self.rating_disk_cache.clear() # Decide if folder clear should wipe the whole disk cache if self.exif_disk_cache: @@ -71,6 +73,7 @@ def remove_data_for_path(self, file_path: str): date_removed = self.date_cache.pop(file_path, None) cluster_removed = self.cluster_results.pop(file_path, None) embedding_removed = self.embeddings_cache.pop(file_path, None) + self.best_shot_paths.discard(file_path) logger.debug( f"Removed data for {os.path.basename(file_path)}: " @@ -114,6 +117,10 @@ def update_path(self, old_path: str, new_path: str): if self.focused_image_path == old_path: self.focused_image_path = new_path + if old_path in self.best_shot_paths: + self.best_shot_paths.discard(old_path) + self.best_shot_paths.add(new_path) + # Add more methods as needed, e.g., to get specific data, # update blur status, etc. def update_blur_status(self, file_path: str, is_blurred: Optional[bool]): diff --git a/src/ui/controllers/best_shot_picker_controller.py b/src/ui/controllers/best_shot_picker_controller.py new file mode 100644 index 0000000..238ac8a --- /dev/null +++ b/src/ui/controllers/best_shot_picker_controller.py @@ -0,0 +1,285 @@ +""" +Best Shot Picker Controller +Manages the AI-powered best shot selection feature. +""" + +import logging +from pathlib import Path +from typing import Optional, TYPE_CHECKING + +from PyQt6.QtCore import Qt, pyqtSignal, QItemSelectionModel +from PyQt6.QtWidgets import ( + QMessageBox, + QDialog, + QLabel, + QPushButton, + QVBoxLayout, + QHBoxLayout, +) + +from core.ai.best_shot_picker import BestShotResult + +if TYPE_CHECKING: + from ui.main_window import MainWindow + +logger = logging.getLogger(__name__) + + +class _BestShotProgressDialog(QDialog): + """Frameless, styled dialog for long-running AI analysis.""" + + cancelled = pyqtSignal() + + def __init__(self, parent=None): + super().__init__(parent) + self.setModal(True) + self.setWindowTitle("AI Best Shot Picker") + self.setObjectName("aiBestShotProgressDialog") + self.setAttribute(Qt.WidgetAttribute.WA_DeleteOnClose) + self.setWindowFlags(self.windowFlags() | Qt.WindowType.FramelessWindowHint) + self.setMinimumWidth(360) + + layout = QVBoxLayout(self) + layout.setSpacing(16) + layout.setContentsMargins(24, 24, 24, 24) + + title_label = QLabel("Analyzing Selection") + title_label.setObjectName("aiBestShotProgressTitle") + title_label.setStyleSheet("font-weight: bold; font-size: 14px;") + layout.addWidget(title_label) + + self.status_label = QLabel("Preparing analysis...") + self.status_label.setObjectName("aiBestShotProgressStatus") + self.status_label.setWordWrap(True) + self.status_label.setStyleSheet("color: #767676; font-size: 11px;") + layout.addWidget(self.status_label) + + layout.addStretch() + + button_row = QHBoxLayout() + button_row.addStretch() + self.cancel_button = QPushButton("Cancel Analysis") + self.cancel_button.setObjectName("aiBestShotProgressCancelButton") + self.cancel_button.clicked.connect(self.cancelled.emit) + button_row.addWidget(self.cancel_button) + layout.addLayout(button_row) + + def update_status(self, message: str): + if message: + self.status_label.setText(message) + + def set_cancel_enabled(self, enabled: bool): + self.cancel_button.setEnabled(enabled) + + +class BestShotPickerController: + """ + Controller for AI-powered best shot selection. + + Manages the workflow of: + 1. Getting selected images from the UI + 2. Running the AI analysis in a background thread + 3. Presenting the results to the user + 4. Optionally selecting the best image in the UI + """ + + def __init__(self, main_window: "MainWindow"): + """Initialize the controller.""" + self.main_window = main_window + self.worker_manager = self.main_window.worker_manager + self.current_image_paths: list[str] = [] + self.progress_dialog: Optional[_BestShotProgressDialog] = None + + self.worker_manager.best_shot_progress.connect(self._on_progress) + self.worker_manager.best_shot_result_ready.connect(self._on_result_ready) + self.worker_manager.best_shot_error.connect(self._on_error) + self.worker_manager.best_shot_finished.connect(self._on_finished) + + def can_pick_best_shot(self) -> bool: + """Return True when 2+ images are selected and no analysis is running.""" + if self.worker_manager.is_best_shot_running(): + return False + + selected_paths = self.main_window.selection_controller.get_selected_file_paths() + return len(selected_paths) >= 2 + + def start_analysis(self): + """Start the best shot analysis for the current selection.""" + if self.worker_manager.is_best_shot_running(): + QMessageBox.information( + self.main_window, + "Analysis In Progress", + "AI best shot analysis is already running. Please wait for it to finish.", + ) + return + + selected_paths = self.main_window.selection_controller.get_selected_file_paths() + + if len(selected_paths) < 2: + QMessageBox.warning( + self.main_window, + "Not Enough Images", + "Please select at least 2 images to pick the best shot.", + ) + return + + logger.info(f"Starting best shot analysis for {len(selected_paths)} images") + logger.info("Images retrieved from selection (in order):") + for idx, path in enumerate(selected_paths, 1): + logger.info(" Selection position %d: %s", idx, Path(path).name) + + self.current_image_paths = list(selected_paths) + + self._show_progress_dialog() + self._set_action_enabled(False) + + try: + self.worker_manager.start_best_shot_analysis(self.current_image_paths) + except ValueError as exc: + logger.error("Failed to start best shot analysis: %s", exc) + self._close_progress_dialog() + self._set_action_enabled(True) + QMessageBox.critical(self.main_window, "Analysis Failed", str(exc)) + + def _show_progress_dialog(self): + self._close_progress_dialog() + self.progress_dialog = _BestShotProgressDialog(self.main_window) + self.progress_dialog.cancelled.connect(self._on_cancel) + self.progress_dialog.update_status("Testing API connection...") + self.progress_dialog.show() + + def _on_progress(self, message: str): + """Display progress updates from the worker in the dialog.""" + if self.progress_dialog and message: + self.progress_dialog.update_status(message) + + def _on_result_ready(self, result: BestShotResult): + """Handle analysis result and show it to the user.""" + logger.info("Best shot selected: %s", result.best_image_path) + self.main_window.update_best_shot_labels( + [result.best_image_path], replace=True + ) + self._close_progress_dialog() + self._show_result_dialog(result) + + def _on_error(self, error_message: str): + """Handle an error emitted by the worker.""" + logger.error("Best shot analysis error: %s", error_message) + self._close_progress_dialog() + QMessageBox.critical( + self.main_window, + "Analysis Failed", + f"Failed to analyze images:\n\n{error_message}", + ) + self._set_action_enabled(True) + + def _on_finished(self, success: bool): + """Handle worker completion regardless of result.""" + logger.info("Best shot analysis finished (success: %s)", success) + self._set_action_enabled(True) + if not success: + self.main_window.statusBar().showMessage( + "AI best shot analysis cancelled.", 5000 + ) + self._close_progress_dialog() + self.current_image_paths = [] + + def _on_cancel(self): + """Handle user cancelling the dialog.""" + self.worker_manager.stop_best_shot_analysis() + self.main_window.statusBar().showMessage( + "Cancelling AI best shot analysis...", 3000 + ) + if self.progress_dialog: + self.progress_dialog.set_cancel_enabled(False) + self.progress_dialog.update_status("Cancelling analysis...") + + def cleanup(self): + """Cleanup resources when shutting down.""" + self.worker_manager.stop_best_shot_analysis() + self._close_progress_dialog() + self.current_image_paths = [] + self._set_action_enabled(True) + + def _close_progress_dialog(self): + if self.progress_dialog: + try: + self.progress_dialog.close() + finally: + self.progress_dialog.deleteLater() + self.progress_dialog = None + + def _set_action_enabled(self, enabled: bool): + try: + action = self.main_window.menu_manager.pick_best_shot_action + action.setEnabled(enabled) + except Exception: # pragma: no cover - defensive + logger.debug("Failed to toggle best shot action state", exc_info=True) + + def _show_result_dialog(self, result: BestShotResult): + """Show a dialog with the analysis results.""" + image_name = Path(result.best_image_path).name + image_num = result.best_image_index + 1 + total_images = len(self.current_image_paths) + + message = f"""

Best Image Selected

+

Image: {image_name}
+Position: {image_num} of {total_images}
+Confidence: {result.confidence}

+ +

Reasoning:
+{result.reasoning}

+ +

Would you like to select this image in the viewer?

""" + + reply = QMessageBox.question( + self.main_window, + "Best Shot Selected", + message, + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.Yes, + ) + + if reply == QMessageBox.StandardButton.Yes: + self._select_image_in_ui(result.best_image_path) + + def _select_image_in_ui(self, image_path: str): + """Select the specified image in the UI.""" + try: + logger.info("Selecting image in UI: %s", image_path) + + active_view = self.main_window._get_active_file_view() + if not active_view: + logger.warning("No active view available") + return + + proxy_index = self.main_window._find_proxy_index_for_path(image_path) + if not proxy_index or not proxy_index.isValid(): + logger.warning("Could not find index for path: %s", image_path) + QMessageBox.information( + self.main_window, + "Navigation Failed", + "Could not find the selected image in the current view.\n\n" + "The image may be filtered out or in a different folder.", + ) + return + + selection_model = active_view.selectionModel() + if selection_model: + selection_model.setCurrentIndex( + proxy_index, + QItemSelectionModel.SelectionFlag.ClearAndSelect, + ) + active_view.scrollTo(proxy_index) + logger.info("Successfully selected best image in UI") + else: + logger.warning("No selection model available") + + except Exception as exc: # pragma: no cover - UI safety + logger.error("Failed to select image in UI: %s", exc) + QMessageBox.warning( + self.main_window, + "Selection Failed", + f"Could not select the image in the UI:\n{exc}", + ) + diff --git a/src/ui/controllers/cluster_best_shot_controller.py b/src/ui/controllers/cluster_best_shot_controller.py new file mode 100644 index 0000000..cd76308 --- /dev/null +++ b/src/ui/controllers/cluster_best_shot_controller.py @@ -0,0 +1,355 @@ +"""Controller for running the AI best shot picker across similarity clusters.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import List, Optional, TYPE_CHECKING + +from PyQt6.QtCore import Qt, pyqtSignal, QItemSelectionModel +from PyQt6.QtWidgets import ( + QDialog, + QHBoxLayout, + QLabel, + QListWidget, + QListWidgetItem, + QMessageBox, + QPushButton, + QVBoxLayout, + QApplication, +) +from PyQt6.QtCore import QEventLoop + +from core.ai.best_shot_picker import BestShotResult + +if TYPE_CHECKING: + from ui.main_window import MainWindow + +logger = logging.getLogger(__name__) + + +class _ClusterBestShotProgressDialog(QDialog): + """Progress dialog specialised for cluster-wide analysis.""" + + cancelled = pyqtSignal() + + def __init__(self, parent=None): + super().__init__(parent) + self.setModal(True) + self.setWindowTitle("AI Cluster Best Shots") + self.setObjectName("aiClusterBestShotProgressDialog") + self.setWindowFlags(self.windowFlags() | Qt.WindowType.FramelessWindowHint) + self.setMinimumWidth(380) + + layout = QVBoxLayout(self) + layout.setSpacing(16) + layout.setContentsMargins(24, 24, 24, 24) + + title_label = QLabel("Analyzing Similarity Groups") + title_label.setStyleSheet("font-weight: bold; font-size: 14px;") + layout.addWidget(title_label) + + self.cluster_label = QLabel("Cluster 0 of 0") + self.cluster_label.setStyleSheet("font-size: 12px; color: #4a4a4a;") + layout.addWidget(self.cluster_label) + + self.status_label = QLabel("Preparing analysis...") + self.status_label.setWordWrap(True) + self.status_label.setStyleSheet("color: #767676; font-size: 11px;") + layout.addWidget(self.status_label) + + layout.addStretch() + + button_row = QHBoxLayout() + button_row.addStretch() + self.cancel_button = QPushButton("Cancel Analysis") + self.cancel_button.clicked.connect(self.cancelled.emit) + button_row.addWidget(self.cancel_button) + layout.addLayout(button_row) + + def update_cluster_position(self, current: int, total: int): + self.cluster_label.setText(f"Cluster {current} of {total}") + + def update_status(self, message: str): + if message: + self.status_label.setText(message) + + def set_cancel_enabled(self, enabled: bool): + self.cancel_button.setEnabled(enabled) + + +@dataclass +class ClusterBestShotSummaryItem: + cluster_id: int + result: BestShotResult + image_paths: List[str] + index: int + total: int + + +class ClusterBestShotSummaryDialog(QDialog): + """Displays the results for each cluster and offers quick navigation.""" + + def __init__( + self, + items: List[ClusterBestShotSummaryItem], + select_callback, + parent=None, + ): + super().__init__(parent) + self._items = items + self._select_callback = select_callback + self.setWindowTitle("AI Cluster Best Shots") + self.setMinimumWidth(500) + + layout = QVBoxLayout(self) + layout.setSpacing(14) + layout.setContentsMargins(20, 20, 20, 20) + + heading = QLabel( + f"Analyzed {len(items)} cluster{'s' if len(items) != 1 else ''}." + ) + heading.setStyleSheet("font-weight: bold; font-size: 13px;") + layout.addWidget(heading) + + self.list_widget = QListWidget() + self.list_widget.setSelectionMode(QListWidget.SelectionMode.ExtendedSelection) + self.list_widget.itemActivated.connect(self._handle_item_activated) + layout.addWidget(self.list_widget) + + for item in items: + filename = item.result.best_image_path.split("/")[-1] + text = ( + f"Group {item.cluster_id}: {filename} — Confidence {item.result.confidence}" + ) + list_item = QListWidgetItem(text) + list_item.setToolTip(item.result.reasoning) + list_item.setData(Qt.ItemDataRole.UserRole, item.result.best_image_path) + self.list_widget.addItem(list_item) + + button_row = QHBoxLayout() + self.select_button = QPushButton("Select Winners in Viewer") + self.select_button.clicked.connect(self._select_all) + button_row.addWidget(self.select_button) + + close_button = QPushButton("Close") + close_button.clicked.connect(self.accept) + button_row.addWidget(close_button) + + layout.addLayout(button_row) + + def _select_all(self): + paths = [ + item.result.best_image_path + for item in self._items + if item.result.best_image_path + ] + self._select_callback(paths) + + def _handle_item_activated(self, list_item: QListWidgetItem): + path = list_item.data(Qt.ItemDataRole.UserRole) + if path: + self._select_callback([path]) + + +class ClusterBestShotController: + """Coordinates running the AI best shot picker across similarity clusters.""" + + def __init__(self, main_window: "MainWindow"): + self.main_window = main_window + self.worker_manager = main_window.worker_manager + self.progress_dialog: Optional[_ClusterBestShotProgressDialog] = None + self._results: List[ClusterBestShotSummaryItem] = [] + self._total_clusters = 0 + + self.worker_manager.best_shot_clusters_progress.connect(self._on_progress) + self.worker_manager.best_shot_clusters_result.connect(self._on_cluster_result) + self.worker_manager.best_shot_clusters_error.connect(self._on_error) + self.worker_manager.best_shot_clusters_finished.connect(self._on_finished) + + def start_analysis(self): + if self.worker_manager.is_best_shot_clusters_running(): + QMessageBox.information( + self.main_window, + "Analysis In Progress", + "AI cluster best shot analysis is already running.", + ) + return + + cluster_payloads = self._build_cluster_input() + if not cluster_payloads: + QMessageBox.information( + self.main_window, + "No Similarity Groups", + "Similarity analysis has not been run or produced no groups.", + ) + return + + self.main_window.update_best_shot_labels([], replace=True) + self._results.clear() + self._total_clusters = len(cluster_payloads) + self._show_progress_dialog() + self._set_action_enabled(False) + + try: + self.worker_manager.start_best_shot_clusters(cluster_payloads) + except ValueError as exc: + logger.error("Failed to start cluster best shot analysis: %s", exc) + self._close_progress_dialog() + self._set_action_enabled(True) + QMessageBox.critical(self.main_window, "Analysis Failed", str(exc)) + + def _build_cluster_input(self) -> List[tuple[int, List[str]]]: + cluster_results = getattr(self.main_window.app_state, "cluster_results", {}) + if not cluster_results: + return [] + + sort_mode = self.main_window.cluster_sort_combo.currentText() + cluster_info = self.main_window.similarity_controller.prepare_clusters( + sort_mode + ) + images_by_cluster = cluster_info.get("images_by_cluster", {}) + sorted_cluster_ids = cluster_info.get("sorted_cluster_ids", []) + + cluster_payloads: List[tuple[int, List[str]]] = [] + for cluster_id in sorted_cluster_ids: + file_data_list = images_by_cluster.get(cluster_id, []) + image_paths = [ + file_data.get("path") + for file_data in file_data_list + if isinstance(file_data, dict) and file_data.get("path") + ] + if image_paths: + cluster_payloads.append((cluster_id, image_paths)) + return cluster_payloads + + def _show_progress_dialog(self): + self._close_progress_dialog() + self.progress_dialog = _ClusterBestShotProgressDialog(self.main_window) + self.progress_dialog.cancelled.connect(self._on_cancel) + self.progress_dialog.update_cluster_position(0, self._total_clusters) + self.progress_dialog.update_status("Connecting to AI service...") + self.progress_dialog.show() + self._process_ui_events() + + def _on_progress(self, current: int, total: int, message: str): + if self.progress_dialog: + self.progress_dialog.update_cluster_position(current, total) + self.progress_dialog.update_status(message) + self._process_ui_events() + + def _on_cluster_result(self, payload: object): + if not isinstance(payload, dict): + return + result_obj = payload.get("result") + if not isinstance(result_obj, BestShotResult): + return + summary_item = ClusterBestShotSummaryItem( + cluster_id=payload.get("cluster_id"), + result=result_obj, + image_paths=payload.get("image_paths", []), + index=payload.get("index", 0), + total=payload.get("total", self._total_clusters), + ) + self._results.append(summary_item) + if self.progress_dialog: + filename = result_obj.best_image_path.split("/")[-1] + self.progress_dialog.update_status( + f"Selected {filename} for group {summary_item.cluster_id}" + ) + self._process_ui_events() + if result_obj.best_image_path: + self.main_window.update_best_shot_labels( + [result_obj.best_image_path], replace=False + ) + + def _on_error(self, message: str): + logger.error("Cluster best shot analysis error: %s", message) + self._close_progress_dialog() + self._set_action_enabled(True) + QMessageBox.critical( + self.main_window, + "Analysis Failed", + f"Failed to analyze clusters:\n\n{message}", + ) + + def _on_finished(self, success: bool, summary: object): + logger.info("Cluster best shot analysis finished (success: %s)", success) + self._set_action_enabled(True) + self._close_progress_dialog() + if not success and not self._results: + return + + dialog = ClusterBestShotSummaryDialog(self._results, self._select_paths_in_ui) + dialog.exec() + + def _on_cancel(self): + logger.info("User requested cluster best shot cancellation") + self.worker_manager.stop_best_shot_clusters() + if self.progress_dialog: + self.progress_dialog.set_cancel_enabled(False) + self.progress_dialog.update_status("Cancelling analysis...") + self.main_window.statusBar().showMessage( + "Cancelling AI cluster best shot analysis...", + 3000, + ) + self._process_ui_events() + + def cleanup(self): + self.worker_manager.stop_best_shot_clusters() + self._close_progress_dialog() + self._set_action_enabled(True) + self._results.clear() + + def _close_progress_dialog(self): + if self.progress_dialog: + try: + self.progress_dialog.close() + finally: + self.progress_dialog.deleteLater() + self.progress_dialog = None + + def _set_action_enabled(self, enabled: bool): + try: + action = self.main_window.menu_manager.pick_best_shots_for_clusters_action + action.setEnabled(enabled) + except Exception: # pragma: no cover - defensive + logger.debug("Failed to toggle cluster best shot action", exc_info=True) + + def _process_ui_events(self): + try: + QApplication.processEvents( + QEventLoop.ProcessEventsFlag.ExcludeUserInputEvents + | QEventLoop.ProcessEventsFlag.ExcludeSocketNotifiers + ) + except Exception: + logger.debug("Failed to pump UI events", exc_info=True) + + def _select_paths_in_ui(self, paths: List[str]): + if not paths: + return + view = self.main_window._get_active_file_view() + if not view: + return + selection_model = view.selectionModel() + if not selection_model: + return + + selection_model.clearSelection() + first_proxy = None + for path in paths: + proxy_index = self.main_window._find_proxy_index_for_path(path) + if proxy_index and proxy_index.isValid(): + selection_model.select( + proxy_index, + QItemSelectionModel.SelectionFlag.Select + | QItemSelectionModel.SelectionFlag.Rows, + ) + if first_proxy is None: + first_proxy = proxy_index + if first_proxy: + view.scrollTo(first_proxy) + self.main_window.statusBar().showMessage( + f"Selected {len(paths)} best image{'s' if len(paths) != 1 else ''}.", + 4000, + ) diff --git a/src/ui/controllers/deletion_mark_controller.py b/src/ui/controllers/deletion_mark_controller.py index 0af7b69..98c9cd5 100644 --- a/src/ui/controllers/deletion_mark_controller.py +++ b/src/ui/controllers/deletion_mark_controller.py @@ -1,6 +1,7 @@ # Renamed from deletion_controller.py to align with class name DeletionMarkController from __future__ import annotations from typing import Optional, List, Callable, Iterable, Tuple +import os from PyQt6.QtGui import QStandardItem, QColor from PyQt6.QtWidgets import QApplication from PyQt6.QtCore import Qt @@ -34,13 +35,18 @@ def apply_presentation( else False ) + basename = os.path.basename(file_path) + is_best = file_path in getattr(self.app_state, "best_shot_paths", set()) pres = build_presentation( - basename=item.text().split(" ")[0] - if item.text() - else file_path.split("\\")[-1], + basename=basename, is_marked=is_marked, is_blurred=is_blurred, + is_best=is_best, ) + data = item.data(Qt.ItemDataRole.UserRole) + if isinstance(data, dict): + data["is_best_shot"] = is_best + item.setData(data, Qt.ItemDataRole.UserRole) if pres.is_marked: item.setForeground(self.ORANGE) elif pres.is_blurred: diff --git a/src/ui/dialog_manager.py b/src/ui/dialog_manager.py index 132ed99..351f592 100644 --- a/src/ui/dialog_manager.py +++ b/src/ui/dialog_manager.py @@ -22,6 +22,7 @@ QStyle, QRadioButton, QSlider, + QWidget, ) from PyQt6.QtCore import Qt, QSize, QUrl, QEventLoop, QThread from PyQt6.QtGui import QIcon, QDesktopServices @@ -37,6 +38,7 @@ ModelNotFoundError, ) from workers.thumbnail_preload_worker import ThumbnailPreloadWorker +from ui.helpers.ai_best_shot_settings import AIBestShotSettingsWidget logger = logging.getLogger(__name__) @@ -430,23 +432,28 @@ def show_preferences_dialog(self): main_layout.setSpacing(20) main_layout.setContentsMargins(25, 25, 25, 25) + content_widget = QWidget() + content_layout = QVBoxLayout(content_widget) + content_layout.setSpacing(20) + content_layout.setContentsMargins(0, 0, 0, 0) + # Title title_label = QLabel("Preferences") title_label.setObjectName("aboutTitle") - main_layout.addWidget(title_label) + content_layout.addWidget(title_label) # Performance Mode Section perf_section_label = QLabel("Performance Mode") perf_section_label.setObjectName("preferencesSectionLabel") perf_section_label.setStyleSheet("font-weight: bold; font-size: 13px;") - main_layout.addWidget(perf_section_label) + content_layout.addWidget(perf_section_label) # Description desc_label = QLabel( "Control how many CPU threads PhotoSort uses for processing:" ) desc_label.setWordWrap(True) - main_layout.addWidget(desc_label) + content_layout.addWidget(desc_label) # Radio buttons for performance mode balanced_radio = QRadioButton("Balanced (Recommended)") @@ -520,12 +527,12 @@ def on_custom_toggled(checked): custom_radio.toggled.connect(on_custom_toggled) # Add all radio options to layout - main_layout.addWidget(balanced_radio) - main_layout.addWidget(balanced_desc) - main_layout.addWidget(performance_radio) - main_layout.addWidget(perf_desc) - main_layout.addWidget(custom_radio) - main_layout.addLayout(custom_control_layout) + content_layout.addWidget(balanced_radio) + content_layout.addWidget(balanced_desc) + content_layout.addWidget(performance_radio) + content_layout.addWidget(perf_desc) + content_layout.addWidget(custom_radio) + content_layout.addLayout(custom_control_layout) # Note note_label = QLabel("Note: Changes take effect immediately for new operations.") @@ -533,9 +540,21 @@ def on_custom_toggled(checked): note_label.setStyleSheet( "color: #888; font-style: italic; font-size: 11px; margin-top: 10px;" ) - main_layout.addWidget(note_label) + content_layout.addWidget(note_label) + + # AI Best Shot settings section + ai_settings_widget = AIBestShotSettingsWidget(dialog) + ai_settings_widget.setObjectName("aiBestShotSettings") + content_layout.addWidget(ai_settings_widget) + + content_layout.addStretch() + + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setFrameShape(QFrame.Shape.NoFrame) + scroll_area.setWidget(content_widget) - main_layout.addStretch() + main_layout.addWidget(scroll_area) # Buttons button_layout = QHBoxLayout() @@ -558,6 +577,7 @@ def save_preferences(): else: # custom_radio.isChecked() set_performance_mode(PerformanceMode.CUSTOM) set_custom_thread_count(thread_count_slider.value()) + ai_settings_widget.apply_settings() logger.info( f"Preferences saved: mode={get_performance_mode().value}, custom_threads={get_custom_thread_count()}" ) diff --git a/src/ui/helpers/ai_best_shot_settings.py b/src/ui/helpers/ai_best_shot_settings.py new file mode 100644 index 0000000..a582461 --- /dev/null +++ b/src/ui/helpers/ai_best_shot_settings.py @@ -0,0 +1,319 @@ +""" +AI Best Shot Picker Settings UI Component + +This module provides a reusable widget for configuring AI Best Shot Picker settings. +Can be integrated into the preferences dialog or used standalone. + +Usage: + settings_widget = AIBestShotSettingsWidget() + # Add to a dialog or layout + dialog_layout.addWidget(settings_widget) + + # Get current values + config = settings_widget.get_configuration() + + # Apply settings + settings_widget.apply_settings() +""" + +import logging +from PyQt6.QtCore import Qt, pyqtSignal +from PyQt6.QtWidgets import ( + QWidget, + QVBoxLayout, + QHBoxLayout, + QLabel, + QLineEdit, + QPushButton, + QSpinBox, + QGroupBox, + QFormLayout, + QMessageBox, + QDialog, + QDialogButtonBox, +) + +from core import app_settings +from core.ai.best_shot_picker import BestShotPicker + +logger = logging.getLogger(__name__) + + +class AIBestShotSettingsWidget(QWidget): + """ + Widget for configuring AI Best Shot Picker settings. + + Provides inputs for: + - API URL + - API Key + - Model name + - Timeout + + Includes a test connection button to verify settings. + """ + + settings_changed = pyqtSignal() # Emitted when settings are modified + + def __init__(self, parent=None): + super().__init__(parent) + self._setup_ui() + self._load_settings() + + def _setup_ui(self): + """Set up the user interface.""" + layout = QVBoxLayout(self) + layout.setSpacing(15) + + # Section title + title = QLabel("AI Best Shot Picker") + title.setStyleSheet("font-weight: bold; font-size: 13px;") + layout.addWidget(title) + + # Description + desc = QLabel( + "Configure the AI service for automatic best image selection. " + "Uses vision language models to analyze and compare images." + ) + desc.setWordWrap(True) + desc.setStyleSheet("color: #666; font-size: 11px;") + layout.addWidget(desc) + + # Settings group + settings_group = QGroupBox("API Configuration") + settings_layout = QFormLayout(settings_group) + settings_layout.setSpacing(10) + + # API URL + self.api_url_input = QLineEdit() + self.api_url_input.setPlaceholderText("http://localhost:1234/v1") + self.api_url_input.textChanged.connect(self.settings_changed.emit) + settings_layout.addRow("API URL:", self.api_url_input) + + # API Key + self.api_key_input = QLineEdit() + self.api_key_input.setPlaceholderText("not-needed (for local LM Studio)") + self.api_key_input.setEchoMode(QLineEdit.EchoMode.Password) + self.api_key_input.textChanged.connect(self.settings_changed.emit) + + # Add show/hide button for API key + api_key_layout = QHBoxLayout() + api_key_layout.addWidget(self.api_key_input) + + show_key_button = QPushButton("Show") + show_key_button.setMaximumWidth(60) + show_key_button.setCheckable(True) + show_key_button.toggled.connect(self._toggle_api_key_visibility) + api_key_layout.addWidget(show_key_button) + + settings_layout.addRow("API Key:", api_key_layout) + + # Model name + self.model_input = QLineEdit() + self.model_input.setPlaceholderText("local-model") + self.model_input.textChanged.connect(self.settings_changed.emit) + settings_layout.addRow("Model:", self.model_input) + + # Timeout + self.timeout_spinbox = QSpinBox() + self.timeout_spinbox.setMinimum(10) + self.timeout_spinbox.setMaximum(600) + self.timeout_spinbox.setValue(120) + self.timeout_spinbox.setSuffix(" seconds") + self.timeout_spinbox.valueChanged.connect(self.settings_changed.emit) + settings_layout.addRow("Timeout:", self.timeout_spinbox) + + layout.addWidget(settings_group) + + # Test connection button + test_button_layout = QHBoxLayout() + test_button_layout.addStretch() + + self.test_button = QPushButton("Test Connection") + self.test_button.clicked.connect(self._test_connection) + test_button_layout.addWidget(self.test_button) + + layout.addLayout(test_button_layout) + + # Help text + help_text = QLabel( + "Quick Setup with LM Studio:
" + "1. Download LM Studio from lmstudio.ai
" + "2. Install a vision model (e.g., qwen2-vl-7b)
" + "3. Load the model with mmproj file
" + "4. Start the local server
" + "5. Use default settings above" + ) + help_text.setWordWrap(True) + help_text.setOpenExternalLinks(True) + help_text.setStyleSheet( + "background-color: #f0f0f0; padding: 10px; " + "border-radius: 5px; font-size: 11px;" + ) + layout.addWidget(help_text) + + layout.addStretch() + + def _toggle_api_key_visibility(self, show: bool): + """Toggle API key visibility.""" + if show: + self.api_key_input.setEchoMode(QLineEdit.EchoMode.Normal) + else: + self.api_key_input.setEchoMode(QLineEdit.EchoMode.Password) + + def _load_settings(self): + """Load current settings from app_settings.""" + self.api_url_input.setText(app_settings.get_ai_best_shot_api_url()) + self.api_key_input.setText(app_settings.get_ai_best_shot_api_key()) + self.model_input.setText(app_settings.get_ai_best_shot_model()) + self.timeout_spinbox.setValue(app_settings.get_ai_best_shot_timeout()) + + def _test_connection(self): + """Test the connection to the AI service.""" + logger.info("Testing AI Best Shot Picker connection...") + + # Disable button during test + self.test_button.setEnabled(False) + self.test_button.setText("Testing...") + + try: + # Create picker with current settings + picker = BestShotPicker( + base_url=self.api_url_input.text() or app_settings.DEFAULT_AI_BEST_SHOT_API_URL, + api_key=self.api_key_input.text() or app_settings.DEFAULT_AI_BEST_SHOT_API_KEY, + model=self.model_input.text() or app_settings.DEFAULT_AI_BEST_SHOT_MODEL, + timeout=self.timeout_spinbox.value(), + ) + + # Test connection + if picker.test_connection(): + QMessageBox.information( + self, + "Connection Successful", + "Successfully connected to the AI service!\n\n" + "The best shot picker is ready to use.", + ) + logger.info("AI connection test successful") + else: + QMessageBox.warning( + self, + "Connection Failed", + "Could not connect to the AI service.\n\n" + "Please check:\n" + "• LM Studio is running\n" + "• A vision model is loaded\n" + "• The local server is started\n" + "• The API URL is correct", + ) + logger.warning("AI connection test failed") + + except Exception as e: + QMessageBox.critical( + self, + "Connection Error", + f"An error occurred while testing the connection:\n\n{str(e)}", + ) + logger.error(f"AI connection test error: {e}") + + finally: + # Re-enable button + self.test_button.setEnabled(True) + self.test_button.setText("Test Connection") + + def apply_settings(self): + """Apply the current settings to app_settings.""" + api_url = self.api_url_input.text() or app_settings.DEFAULT_AI_BEST_SHOT_API_URL + api_key = self.api_key_input.text() or app_settings.DEFAULT_AI_BEST_SHOT_API_KEY + model = self.model_input.text() or app_settings.DEFAULT_AI_BEST_SHOT_MODEL + timeout = self.timeout_spinbox.value() + + app_settings.set_ai_best_shot_api_url(api_url) + app_settings.set_ai_best_shot_api_key(api_key) + app_settings.set_ai_best_shot_model(model) + app_settings.set_ai_best_shot_timeout(timeout) + + logger.info( + f"AI Best Shot Picker settings saved: " + f"url={api_url}, model={model}, timeout={timeout}" + ) + + def get_configuration(self) -> dict: + """ + Get the current configuration as a dictionary. + + Returns: + dict: Configuration with keys: api_url, api_key, model, timeout + """ + return { + "api_url": self.api_url_input.text() or app_settings.DEFAULT_AI_BEST_SHOT_API_URL, + "api_key": self.api_key_input.text() or app_settings.DEFAULT_AI_BEST_SHOT_API_KEY, + "model": self.model_input.text() or app_settings.DEFAULT_AI_BEST_SHOT_MODEL, + "timeout": self.timeout_spinbox.value(), + } + + def reset_to_defaults(self): + """Reset all settings to their default values.""" + self.api_url_input.setText(app_settings.DEFAULT_AI_BEST_SHOT_API_URL) + self.api_key_input.setText(app_settings.DEFAULT_AI_BEST_SHOT_API_KEY) + self.model_input.setText(app_settings.DEFAULT_AI_BEST_SHOT_MODEL) + self.timeout_spinbox.setValue(app_settings.DEFAULT_AI_BEST_SHOT_TIMEOUT) + self.settings_changed.emit() + + +# Standalone dialog for testing +class AIBestShotSettingsDialog(QDialog): + """Standalone dialog for AI Best Shot Picker settings.""" + + def __init__(self, parent=None): + super().__init__(parent) + self.setWindowTitle("AI Best Shot Picker Settings") + self.setModal(True) + self.setMinimumWidth(500) + + layout = QVBoxLayout(self) + + # Add settings widget + self.settings_widget = AIBestShotSettingsWidget(self) + layout.addWidget(self.settings_widget) + + # Add buttons + button_box = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | + QDialogButtonBox.StandardButton.Cancel | + QDialogButtonBox.StandardButton.RestoreDefaults + ) + button_box.accepted.connect(self.accept) + button_box.rejected.connect(self.reject) + + # Connect restore defaults button + restore_button = button_box.button(QDialogButtonBox.StandardButton.RestoreDefaults) + if restore_button: + restore_button.clicked.connect(self.settings_widget.reset_to_defaults) + + layout.addWidget(button_box) + + def accept(self): + """Apply settings and close dialog.""" + self.settings_widget.apply_settings() + super().accept() + + +if __name__ == "__main__": + """ + Test the settings widget standalone. + Usage: python -m ui.helpers.ai_best_shot_settings + """ + import sys + from PyQt6.QtWidgets import QApplication + + app = QApplication(sys.argv) + + # Test the dialog + dialog = AIBestShotSettingsDialog() + if dialog.exec(): + print("Settings saved!") + config = dialog.settings_widget.get_configuration() + print(f"Configuration: {config}") + else: + print("Cancelled") + + sys.exit(0) diff --git a/src/ui/helpers/deletion_utils.py b/src/ui/helpers/deletion_utils.py index d368d3b..260cacc 100644 --- a/src/ui/helpers/deletion_utils.py +++ b/src/ui/helpers/deletion_utils.py @@ -11,9 +11,15 @@ class DeletionPresentation: text: str is_marked: bool is_blurred: Optional[bool] + is_best: bool = False -def build_item_text(basename: str, is_marked: bool, is_blurred: Optional[bool]) -> str: +def build_item_text( + basename: str, + is_marked: bool, + is_blurred: Optional[bool], + is_best: bool = False, +) -> str: """Return the display text given mark + blur states. Rules (mirrors legacy inline logic): @@ -21,7 +27,10 @@ def build_item_text(basename: str, is_marked: bool, is_blurred: Optional[bool]) - Append (Blurred) when blurred. - Order: filename (DELETED) (Blurred) """ - parts = [basename] + parts = [] + if is_best: + parts.append("[BEST]") + parts.append(basename) if is_marked: parts.append("(DELETED)") if is_blurred: @@ -30,10 +39,14 @@ def build_item_text(basename: str, is_marked: bool, is_blurred: Optional[bool]) def build_presentation( - basename: str, is_marked: bool, is_blurred: Optional[bool] + basename: str, + is_marked: bool, + is_blurred: Optional[bool], + is_best: bool = False, ) -> DeletionPresentation: return DeletionPresentation( - text=build_item_text(basename, is_marked, is_blurred), + text=build_item_text(basename, is_marked, is_blurred, is_best), is_marked=is_marked, is_blurred=is_blurred, + is_best=is_best, ) diff --git a/src/ui/main_window.py b/src/ui/main_window.py index 53f0d22..2af9947 100644 --- a/src/ui/main_window.py +++ b/src/ui/main_window.py @@ -23,6 +23,7 @@ Optional, Any, Tuple, + Iterable, ) # Import List and Dict for type hinting, Optional, Any, Tuple from PyQt6.QtCore import ( Qt, @@ -83,6 +84,7 @@ from ui.controllers.similarity_controller import SimilarityController from ui.controllers.preview_controller import PreviewController from ui.controllers.metadata_controller import MetadataController +from ui.controllers.cluster_best_shot_controller import ClusterBestShotController logger = logging.getLogger(__name__) @@ -219,6 +221,8 @@ def __init__(self, initial_folder=None): self.similarity_controller = SimilarityController(self) self.preview_controller = PreviewController(self) self.metadata_controller = MetadataController(self) + self.best_shot_controller = None # Lazy initialization on first use + self.cluster_best_shot_controller = None # Hotkey controller wraps navigation key handling self.hotkey_controller = HotkeyController(self) @@ -902,7 +906,57 @@ def populate_cluster_filter(self, cluster_ids: List[int]) -> None: self.cluster_filter_combo.addItems( ["All Clusters"] + [f"Cluster {cid}" for cid in cluster_ids] ) - self.cluster_filter_combo.setEnabled(bool(cluster_ids)) + has_clusters = bool(cluster_ids) + self.cluster_filter_combo.setEnabled(has_clusters) + try: + action = self.menu_manager.pick_best_shots_for_clusters_action + allow_click = has_clusters and not self.worker_manager.is_best_shot_clusters_running() + action.setEnabled(allow_click) + except Exception: # pragma: no cover - defensive + logger.debug("Failed to toggle cluster best shot menu state", exc_info=True) + + def update_best_shot_labels(self, paths: Iterable[str], *, replace: bool) -> None: + """Update best-shot labels for the provided image paths.""" + normalized = {os.path.normpath(p) for p in paths if p} + current = set(getattr(self.app_state, "best_shot_paths", set())) + target = normalized if replace else current.union(normalized) + + if target == current: + return + + removed = current - target + added = target - current + self.app_state.best_shot_paths = target + + for path in removed: + self._apply_best_shot_presentation(path, is_best=False) + for path in added: + self._apply_best_shot_presentation(path, is_best=True) + + def _apply_best_shot_presentation(self, file_path: str, is_best: bool) -> None: + file_data = self.app_state.get_file_data_by_path(file_path) + if file_data is not None: + file_data["is_best_shot"] = is_best + + proxy_index = self._find_proxy_index_for_path(file_path) + if proxy_index and proxy_index.isValid(): + source_index = self.proxy_model.mapToSource(proxy_index) + item = self.file_system_model.itemFromIndex(source_index) + else: + item = None + + is_blurred = None + if item: + data = item.data(Qt.ItemDataRole.UserRole) + if isinstance(data, dict): + data["is_best_shot"] = is_best + is_blurred = data.get("is_blurred") + item.setData(data, Qt.ItemDataRole.UserRole) + + self.deletion_controller.apply_presentation( + item, file_path, is_blurred + ) + def get_selected_file_paths(self) -> List[str]: # For MetadataController # Prefer SelectionController; fall back only if something unexpected occurs @@ -1459,6 +1513,19 @@ def closeEvent(self, event): return logger.info("Stopping all workers on application close.") + if self.best_shot_controller: + try: + self.best_shot_controller.cleanup() + except Exception: + logger.debug("Failed to cleanup best shot controller on close", exc_info=True) + if self.cluster_best_shot_controller: + try: + self.cluster_best_shot_controller.cleanup() + except Exception: + logger.debug( + "Failed to cleanup cluster best shot controller on close", + exc_info=True, + ) self.worker_manager.stop_all_workers() # Use WorkerManager to stop all event.accept() @@ -3604,6 +3671,26 @@ def _mark_selection_for_deletion(self): override_selected_paths=original_selection_paths ) + def _pick_best_shot_with_ai(self): + """Launch the AI best shot picker for the current selection.""" + # Lazy initialization of controller + if self.best_shot_controller is None: + from ui.controllers.best_shot_picker_controller import ( + BestShotPickerController, + ) + + self.best_shot_controller = BestShotPickerController(self) + + # Start the analysis + self.best_shot_controller.start_analysis() + + def _pick_best_shots_for_similarity_clusters(self): + """Launch the AI best shot picker across all similarity clusters.""" + if self.cluster_best_shot_controller is None: + self.cluster_best_shot_controller = ClusterBestShotController(self) + + self.cluster_best_shot_controller.start_analysis() + def _mark_image_for_deletion(self, file_path: str): """Marks a single image for deletion, updating the model in-place.""" if not file_path: diff --git a/src/ui/menu_manager.py b/src/ui/menu_manager.py index c2980c1..99abe84 100644 --- a/src/ui/menu_manager.py +++ b/src/ui/menu_manager.py @@ -1,6 +1,7 @@ import logging import os import subprocess +import sys from typing import TYPE_CHECKING from PyQt6.QtCore import QPoint, Qt @@ -63,6 +64,8 @@ def __init__(self, main_window: "MainWindow"): self.unmark_for_delete_action: QAction self.commit_deletions_action: QAction self.clear_marked_deletions_action: QAction + self.pick_best_shot_action: QAction + self.pick_best_shots_for_clusters_action: QAction # Viewer Actions self.zoom_in_action: QAction @@ -213,6 +216,26 @@ def _create_actions(self): self.clear_marked_deletions_action.setShortcut(QKeySequence("Alt+D")) main_win.addAction(self.clear_marked_deletions_action) + # AI Best Shot Picker action + self.pick_best_shot_action = QAction("Pick Best Shot with AI...", main_win) + self.pick_best_shot_action.setShortcut(QKeySequence("Ctrl+B")) + self.pick_best_shot_action.setToolTip( + "Use AI to select the best image from the current selection (requires 2+ images)" + ) + main_win.addAction(self.pick_best_shot_action) + + self.pick_best_shots_for_clusters_action = QAction( + "Pick Cluster Best Shots with AI...", main_win + ) + self.pick_best_shots_for_clusters_action.setShortcut( + QKeySequence("Ctrl+Shift+B") + ) + self.pick_best_shots_for_clusters_action.setEnabled(False) + self.pick_best_shots_for_clusters_action.setToolTip( + "Run AI best shot analysis across every similarity group" + ) + main_win.addAction(self.pick_best_shots_for_clusters_action) + # About action self.about_action = QAction("&About", main_win) self.about_action.setShortcut(QKeySequence("F12")) @@ -318,6 +341,10 @@ def _create_image_menu(self, menu_bar): image_menu.addAction(self.rotate_180_action) image_menu.addSeparator() + image_menu.addAction(self.pick_best_shot_action) + image_menu.addAction(self.pick_best_shots_for_clusters_action) + image_menu.addSeparator() + image_menu.addAction(self.mark_for_delete_action) image_menu.addAction(self.commit_deletions_action) image_menu.addAction(self.clear_marked_deletions_action) @@ -362,7 +389,10 @@ def _create_settings_menu(self, menu_bar): settings_menu = menu_bar.addMenu("&Settings") self.preferences_action = QAction("Preferences...", main_win) - self.preferences_action.setShortcut(QKeySequence("F10")) + if sys.platform == "darwin": + self.preferences_action.setShortcut(QKeySequence.StandardKey.Preferences) + else: + self.preferences_action.setShortcut(QKeySequence("F10")) settings_menu.addAction(self.preferences_action) settings_menu.addSeparator() @@ -462,6 +492,14 @@ def _guarded_show_rotation_view(): main_win._clear_all_deletion_marks ) + # AI Best Shot Picker + self.pick_best_shot_action.triggered.connect( + main_win._pick_best_shot_with_ai + ) + self.pick_best_shots_for_clusters_action.triggered.connect( + main_win._pick_best_shots_for_similarity_clusters + ) + # Zoom actions self.zoom_in_action.triggered.connect( main_win.advanced_image_viewer._zoom_in_all diff --git a/src/ui/worker_manager.py b/src/ui/worker_manager.py index b24a1e1..c6fa584 100644 --- a/src/ui/worker_manager.py +++ b/src/ui/worker_manager.py @@ -1,6 +1,7 @@ import logging +import os from PyQt6.QtCore import QObject, pyqtSignal, QThread -from typing import List, Dict, Any, Optional, TYPE_CHECKING +from typing import List, Dict, Any, Optional, Tuple, TYPE_CHECKING # Import worker classes from core.file_scanner import FileScanner @@ -16,7 +17,10 @@ from workers.rating_writer_worker import RatingWriterWorker from workers.rotation_application_worker import RotationApplicationWorker from workers.thumbnail_preload_worker import ThumbnailPreloadWorker -from core.image_pipeline import ImagePipeline +from workers.best_shot_picker_worker import BestShotPickerWorker +from workers.cluster_best_shot_worker import ClusterBestShotWorker +from core.image_pipeline import ImagePipeline, PRELOAD_MAX_RESOLUTION +from core.image_processing.raw_image_processor import is_raw_extension from workers.rating_loader_worker import ( RatingLoaderWorker, ) @@ -108,6 +112,18 @@ class WorkerManager(QObject): thumbnail_preload_finished = pyqtSignal() thumbnail_preload_error = pyqtSignal(str) + # Best Shot Picker Signals + best_shot_progress = pyqtSignal(str) + best_shot_result_ready = pyqtSignal(object) + best_shot_finished = pyqtSignal(bool) + best_shot_error = pyqtSignal(str) + + # Cluster Best Shot Picker Signals + best_shot_clusters_progress = pyqtSignal(int, int, str) + best_shot_clusters_result = pyqtSignal(object) + best_shot_clusters_finished = pyqtSignal(bool, object) + best_shot_clusters_error = pyqtSignal(str) + def __init__( self, image_pipeline_instance: ImagePipeline, parent: Optional[QObject] = None ): @@ -147,6 +163,11 @@ def __init__( self.update_check_thread: Optional[QThread] = None self.update_check_worker: Optional[UpdateCheckWorker] = None + self.best_shot_thread: Optional[QThread] = None + self.best_shot_worker: Optional[BestShotPickerWorker] = None + self.best_shot_clusters_thread: Optional[QThread] = None + self.best_shot_clusters_worker: Optional[ClusterBestShotWorker] = None + def _terminate_thread( self, thread: Optional[QThread], worker_stop_method: Optional[callable] = None ): @@ -546,6 +567,8 @@ def stop_all_workers(self): self.stop_rotation_application() self.stop_thumbnail_preload() self.stop_cuda_detection() + self.stop_best_shot_clusters() + self.stop_best_shot_analysis() logger.info("All workers stop requested.") def is_file_scanner_running(self) -> bool: @@ -639,6 +662,8 @@ def is_any_worker_running(self) -> bool: or self.is_rating_writer_running() or self.is_rotation_application_running() or self.is_thumbnail_preload_running() + or self.is_best_shot_running() + or self.is_best_shot_clusters_running() ) # --- Rating Writer Management --- @@ -844,3 +869,181 @@ def stop_thumbnail_preload(self): self.thumbnail_preload_worker = None else: self.thumbnail_preload_thread = temp_thread + + def _collect_best_shot_preview_payloads( + self, image_paths: List[str] + ) -> Dict[str, Dict[str, Any]]: + """Return cached preview images suitable for AI overrides.""" + + overrides: Dict[str, Dict[str, Any]] = {} + + for image_path in image_paths: + try: + normalized_path = os.path.normpath(image_path) + ext = os.path.splitext(normalized_path)[1].lower() + apply_auto_edits = is_raw_extension(ext) + + cache_key = ( + normalized_path, + PRELOAD_MAX_RESOLUTION, + apply_auto_edits, + ) + + pil_image = self.image_pipeline.preview_cache.get(cache_key) + if pil_image is None: + continue + + overrides[image_path] = { + "pil_image": pil_image, + "mime_type": "image/jpeg", + } + + except Exception: + logger.exception( + "Failed to retrieve cached preview for %s", image_path + ) + + return overrides + + def _cleanup_best_shot_worker(self): + """Clean up the best shot picker worker and thread.""" + if self.best_shot_thread is not None: + self.best_shot_thread.quit() + self.best_shot_thread.wait() + self.best_shot_thread = None + if self.best_shot_worker: + self.best_shot_worker.deleteLater() + self.best_shot_worker = None + + def is_best_shot_running(self) -> bool: + return ( + self.best_shot_thread is not None + and self.best_shot_thread.isRunning() + ) + + def start_best_shot_analysis(self, image_paths: List[str]): + """Start the AI best shot analysis worker.""" + if self.is_best_shot_running(): + logger.warning("Best shot analysis is already running") + return + + if not image_paths: + raise ValueError("No images provided for best shot analysis") + + preview_payloads = self._collect_best_shot_preview_payloads(image_paths) + if preview_payloads: + logger.info( + "Using %d cached preview(s) for best shot analysis", + len(preview_payloads), + ) + else: + logger.info("No cached previews available; AI will load from disk") + + self.best_shot_thread = QThread() + self.best_shot_worker = BestShotPickerWorker() + self.best_shot_worker.moveToThread(self.best_shot_thread) + + self.best_shot_worker.progress.connect(self.best_shot_progress.emit) + self.best_shot_worker.result_ready.connect(self.best_shot_result_ready.emit) + self.best_shot_worker.error.connect(self.best_shot_error.emit) + self.best_shot_worker.finished.connect(self.best_shot_finished.emit) + self.best_shot_worker.finished.connect(self._cleanup_best_shot_worker) + + self.best_shot_thread.started.connect( + lambda: self.best_shot_worker.analyze_images( + list(image_paths), preview_payloads + ) + ) + + self.best_shot_thread.start() + logger.info("Best shot picker thread started.") + + def stop_best_shot_analysis(self): + """Stop the best shot analysis worker.""" + worker_stop = self.best_shot_worker.stop if self.best_shot_worker else None + temp_thread, _ = self._terminate_thread(self.best_shot_thread, worker_stop) + if temp_thread is None: + self.best_shot_thread = None + if self.best_shot_worker: + self.best_shot_worker.deleteLater() + self.best_shot_worker = None + else: + self.best_shot_thread = temp_thread + + def _cleanup_best_shot_clusters_worker(self): + if self.best_shot_clusters_thread is not None: + self.best_shot_clusters_thread.quit() + self.best_shot_clusters_thread.wait() + self.best_shot_clusters_thread = None + if self.best_shot_clusters_worker: + self.best_shot_clusters_worker.deleteLater() + self.best_shot_clusters_worker = None + + def is_best_shot_clusters_running(self) -> bool: + return ( + self.best_shot_clusters_thread is not None + and self.best_shot_clusters_thread.isRunning() + ) + + def start_best_shot_clusters( + self, cluster_payloads: List[Tuple[int, List[str]]] + ): + if self.is_best_shot_clusters_running(): + logger.warning("Cluster best shot analysis is already running") + return + + if not cluster_payloads: + raise ValueError("No clusters provided for best shot analysis") + + prepared_clusters: List[Dict[str, Any]] = [] + for cluster_id, image_paths in cluster_payloads: + if not image_paths: + continue + normalized_paths = [path for path in image_paths if path] + if not normalized_paths: + continue + previews = self._collect_best_shot_preview_payloads(normalized_paths) + prepared_clusters.append( + { + "cluster_id": cluster_id, + "image_paths": normalized_paths, + "preview_payloads": previews if previews else None, + } + ) + + if not prepared_clusters: + raise ValueError("No valid images available across clusters for analysis") + + self.best_shot_clusters_thread = QThread() + self.best_shot_clusters_worker = ClusterBestShotWorker() + self.best_shot_clusters_worker.moveToThread(self.best_shot_clusters_thread) + + worker = self.best_shot_clusters_worker + thread = self.best_shot_clusters_thread + + worker.progress.connect(self.best_shot_clusters_progress.emit) + worker.cluster_result_ready.connect(self.best_shot_clusters_result.emit) + worker.error.connect(self.best_shot_clusters_error.emit) + worker.finished.connect(self.best_shot_clusters_finished.emit) + worker.finished.connect(lambda *_: self._cleanup_best_shot_clusters_worker()) + + thread.started.connect(lambda: worker.analyze_clusters(prepared_clusters)) + thread.start() + logger.info("Cluster best shot picker thread started.") + + def stop_best_shot_clusters(self): + worker_stop = ( + self.best_shot_clusters_worker.stop + if self.best_shot_clusters_worker + else None + ) + temp_thread, _ = self._terminate_thread( + self.best_shot_clusters_thread, worker_stop + ) + if temp_thread is None: + self.best_shot_clusters_thread = None + if self.best_shot_clusters_worker: + self.best_shot_clusters_worker.deleteLater() + self.best_shot_clusters_worker = None + else: + self.best_shot_clusters_thread = temp_thread diff --git a/src/workers/best_shot_picker_worker.py b/src/workers/best_shot_picker_worker.py new file mode 100644 index 0000000..e76032b --- /dev/null +++ b/src/workers/best_shot_picker_worker.py @@ -0,0 +1,174 @@ +""" +Best Shot Picker Worker +Background worker for AI-powered best shot selection without blocking the UI. +""" + +import logging +from typing import Any, Dict, List, Optional + +from PyQt6.QtCore import QObject, pyqtSignal + +from core.ai.best_shot_picker import BestShotPicker, BestShotPickerError, BestShotResult +from core import app_settings + +logger = logging.getLogger(__name__) + + +class BestShotPickerWorker(QObject): + """Worker for analyzing images and selecting the best shot in a background thread.""" + + # Signals + progress = pyqtSignal(str) # Progress message + result_ready = pyqtSignal(object) # BestShotResult object + finished = pyqtSignal(bool) # Success status + error = pyqtSignal(str) # Error message + + def __init__(self, parent=None): + super().__init__(parent) + self._is_running = True + self.picker = None + + def stop(self): + """Signal the worker to stop processing.""" + self._is_running = False + + def analyze_images( + self, + image_paths: List[str], + preview_pil_map: Optional[Dict[str, Dict[str, Any]]] = None, + ): + """ + Analyze images and select the best one. + + Args: + image_paths: List of image file paths to analyze + preview_pil_map: Optional mapping of image paths to cached preview + payloads containing PIL images (and optional metadata) that + should be sent to the AI instead of reloading from disk. + """ + self._is_running = True + + try: + # Get settings + api_url = app_settings.get_ai_best_shot_api_url() + api_key = app_settings.get_ai_best_shot_api_key() + model = app_settings.get_ai_best_shot_model() + timeout = app_settings.get_ai_best_shot_timeout() + + logger.info( + f"Starting best shot analysis for {len(image_paths)} images " + f"using API: {api_url}" + ) + + # Log the order of images being analyzed + for idx, path in enumerate(image_paths, 1): + from pathlib import Path + logger.info(f" Image {idx}: {Path(path).name}") + + # Create picker instance + self.picker = BestShotPicker( + base_url=api_url, api_key=api_key, model=model, timeout=timeout + ) + + # Emit progress + self.progress.emit("Testing API connection...") + + # Test connection + if not self.picker.test_connection(): + error_msg = ( + "Failed to connect to AI API. Please check your settings and " + "ensure LM Studio (or compatible server) is running." + ) + logger.error(error_msg) + self.error.emit(error_msg) + self.finished.emit(False) + return + + self.progress.emit("Connection established. Preparing analysis...") + + if not self._is_running: + logger.info("Analysis cancelled by user") + self.finished.emit(False) + return + + # Analyze images + self.progress.emit( + f"Analyzing {len(image_paths)} image(s) with AI..." + ) + + preview_overrides: Dict[str, Dict[str, Any]] = {} + if preview_pil_map: + for image_path, payload in preview_pil_map.items(): + pil_image = payload.get("pil_image") + if pil_image is None: + continue + + mime_type = payload.get("mime_type", "image/jpeg") + overlay_label = payload.get("overlay_label") + + try: + encoded_data, effective_mime = self.picker.prepare_preview_payload( + image_path=image_path, + pil_image=pil_image, + overlay_label=overlay_label, + mime_type=mime_type, + ) + except BestShotPickerError as encode_error: + logger.debug( + "Failed to prepare preview override for %s: %s", + image_path, + encode_error, + ) + continue + + override_entry: Dict[str, Any] = { + "base64": encoded_data, + "mime_type": effective_mime, + } + if overlay_label: + override_entry["overlay_label"] = overlay_label + + preview_overrides[image_path] = override_entry + + if preview_overrides: + logger.info( + "Prepared %d cached preview override(s) for AI analysis", + len(preview_overrides), + ) + + result = self.picker.select_best_image( + image_paths, preview_overrides=preview_overrides or None + ) + + if not self._is_running: + logger.info("Analysis cancelled by user") + self.finished.emit(False) + return + + self.progress.emit("Analysis complete.") + + # Emit result + logger.info( + f"Best shot selected: {result.best_image_path} " + f"(confidence: {result.confidence})" + ) + self.result_ready.emit(result) + self.finished.emit(True) + + except ValueError as e: + error_msg = f"Invalid input: {e}" + logger.error(error_msg) + self.error.emit(error_msg) + self.finished.emit(False) + + except BestShotPickerError as e: + error_msg = f"Analysis failed: {e}" + logger.error(error_msg) + self.error.emit(error_msg) + self.finished.emit(False) + + except Exception as e: + error_msg = f"Unexpected error during analysis: {e}" + logger.exception(error_msg) + self.error.emit(error_msg) + self.finished.emit(False) diff --git a/src/workers/cluster_best_shot_worker.py b/src/workers/cluster_best_shot_worker.py new file mode 100644 index 0000000..2398700 --- /dev/null +++ b/src/workers/cluster_best_shot_worker.py @@ -0,0 +1,186 @@ +""" +Cluster Best Shot Worker +Runs the AI best shot picker across every similarity cluster without blocking the UI. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +from PyQt6.QtCore import QObject, pyqtSignal + +from core import app_settings +from core.ai.best_shot_picker import BestShotPicker, BestShotPickerError, BestShotResult + +logger = logging.getLogger(__name__) + + +class ClusterBestShotWorker(QObject): + """Background worker that iterates through similarity clusters.""" + + progress = pyqtSignal(int, int, str) # current cluster, total clusters, message + cluster_result_ready = pyqtSignal(object) # payload describing a cluster result + finished = pyqtSignal(bool, object) # success flag, summary list + error = pyqtSignal(str) # fatal error message + + def __init__(self, parent: Optional[QObject] = None): + super().__init__(parent) + self._is_running = True + self._picker: Optional[BestShotPicker] = None + + def stop(self): + """Signal the worker to stop after the current iteration.""" + self._is_running = False + + # pylint: disable=too-many-locals + def analyze_clusters(self, clusters: List[Dict[str, Any]]): + """Analyze each cluster and emit results progressively.""" + + summary: List[Dict[str, Any]] = [] + total_clusters = len(clusters) + + if total_clusters == 0: + self.finished.emit(True, summary) + return + + try: + self._is_running = True + + self._picker = BestShotPicker( + base_url=app_settings.get_ai_best_shot_api_url(), + api_key=app_settings.get_ai_best_shot_api_key(), + model=app_settings.get_ai_best_shot_model(), + timeout=app_settings.get_ai_best_shot_timeout(), + ) + + requires_api = any( + len(cluster.get("image_paths", [])) > 1 for cluster in clusters + ) + + if requires_api: + self.progress.emit(0, total_clusters, "Testing AI service connectivity...") + if not self._picker.test_connection(): + error_msg = ( + "Failed to connect to AI service. Verify settings under " + "Preferences → AI Best Shot and ensure the server is running." + ) + logger.error(error_msg) + self.error.emit(error_msg) + self.finished.emit(False, summary) + return + + for index, cluster in enumerate(clusters, start=1): + if not self._is_running: + logger.info("Cluster best shot analysis cancelled by user") + self.finished.emit(False, summary) + return + + cluster_id = cluster.get("cluster_id") + image_paths: List[str] = [ + path + for path in cluster.get("image_paths", []) + if isinstance(path, str) and path + ] + + image_count = len(image_paths) + status_message = ( + f"Cluster {cluster_id} ({image_count} image" + f"{'s' if image_count != 1 else ''})" + ) + self.progress.emit(index, total_clusters, status_message) + + if not image_paths: + logger.info("Skipping empty cluster %s", cluster_id) + continue + + try: + if image_count == 1: + # No API call needed; only one candidate. + only_path = image_paths[0] + result = BestShotResult( + best_image_index=0, + best_image_path=only_path, + reasoning=( + "Cluster contains a single image. Selected by default." + ), + confidence="High", + raw_response="Single image cluster", + ) + else: + preview_payloads = self._build_preview_overrides( + cluster.get("preview_payloads"), + image_paths, + ) + result = self._picker.select_best_image( # type: ignore[union-attr] + image_paths, + preview_overrides=preview_payloads or None, + ) + except BestShotPickerError as exc: + error_msg = f"Cluster {cluster_id}: {exc}" + logger.error(error_msg) + self.error.emit(error_msg) + self.finished.emit(False, summary) + return + except Exception as exc: # pragma: no cover - defensive + logger.exception("Unexpected failure analyzing cluster %s", cluster_id) + self.error.emit(str(exc)) + self.finished.emit(False, summary) + return + + payload = { + "cluster_id": cluster_id, + "result": result, + "image_paths": image_paths, + "index": index, + "total": total_clusters, + } + summary.append(payload) + self.cluster_result_ready.emit(payload) + + self.finished.emit(True, summary) + + except Exception as exc: # pragma: no cover - defensive + logger.exception("Cluster best shot worker crashed: %s", exc) + self.error.emit(str(exc)) + self.finished.emit(False, summary) + + def _build_preview_overrides( + self, + preview_payloads: Optional[Dict[str, Dict[str, Any]]], + image_paths: List[str], + ) -> Dict[str, Dict[str, Any]]: + if not preview_payloads or not self._picker: + return {} + + overrides: Dict[str, Dict[str, Any]] = {} + for image_path in image_paths: + payload = preview_payloads.get(image_path) + if not payload: + continue + pil_image = payload.get("pil_image") + if pil_image is None: + continue + mime_type = payload.get("mime_type", "image/jpeg") + overlay_label = payload.get("overlay_label") + try: + base64_data, effective_mime = self._picker.prepare_preview_payload( # type: ignore[union-attr] + image_path=image_path, + pil_image=pil_image, + overlay_label=overlay_label, + mime_type=mime_type, + ) + except BestShotPickerError as exc: + logger.debug( + "Skipping cached preview override for %s: %s", image_path, exc + ) + continue + + override_entry: Dict[str, Any] = { + "base64": base64_data, + "mime_type": effective_mime, + } + if overlay_label: + override_entry["overlay_label"] = overlay_label + overrides[image_path] = override_entry + return overrides diff --git a/tests/test_best_shot_picker.py b/tests/test_best_shot_picker.py new file mode 100644 index 0000000..01b1a0a --- /dev/null +++ b/tests/test_best_shot_picker.py @@ -0,0 +1,250 @@ +""" +Tests for AI Best Shot Picker functionality +""" + +import os +import pytest +from unittest.mock import Mock, patch, MagicMock +from pathlib import Path + +from core.ai.best_shot_picker import ( + BestShotPicker, + BestShotResult, + BestShotPickerError, +) + + +@pytest.fixture +def sample_images(tmp_path): + """Create some sample image files for testing.""" + images = [] + for i in range(3): + img_path = tmp_path / f"test_image_{i}.jpg" + img_path.write_bytes(b"fake image data") + images.append(str(img_path)) + return images + + +@pytest.fixture +def mock_openai_client(): + """Mock OpenAI client for testing.""" + with patch("core.ai.best_shot_picker.OpenAI") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # Mock successful completion + mock_completion = MagicMock() + mock_completion.choices = [MagicMock()] + mock_completion.choices[0].message.content = """ +**Best Image**: Image 2 +**Confidence**: High +**Reasoning**: This image has the best focus and composition. +""" + mock_client.chat.completions.create.return_value = mock_completion + + yield mock_client + + +class TestBestShotPicker: + """Test suite for BestShotPicker class.""" + + def test_initialization_default_params(self): + """Test initialization with default parameters.""" + picker = BestShotPicker() + assert picker.base_url == "http://localhost:1234/v1" + assert picker.api_key == "not-needed" + assert picker.model == "local-model" + assert picker.timeout == 120 + + def test_initialization_custom_params(self): + """Test initialization with custom parameters.""" + picker = BestShotPicker( + base_url="http://custom:8080/v1", + api_key="custom-key", + model="custom-model", + timeout=60, + ) + assert picker.base_url == "http://custom:8080/v1" + assert picker.api_key == "custom-key" + assert picker.model == "custom-model" + assert picker.timeout == 60 + + def test_get_base64_image_success(self, sample_images): + """Test successful Base64 encoding of an image.""" + picker = BestShotPicker() + base64_str = picker._get_base64_image(sample_images[0]) + assert isinstance(base64_str, str) + assert len(base64_str) > 0 + + def test_get_base64_image_file_not_found(self): + """Test Base64 encoding with non-existent file.""" + picker = BestShotPicker() + with pytest.raises(FileNotFoundError): + picker._get_base64_image("/nonexistent/file.jpg") + + def test_build_prompt(self): + """Test prompt building.""" + picker = BestShotPicker() + prompt = picker._build_prompt(3) + assert "3 images" in prompt + assert "Sharpness and Focus" in prompt + assert "Best Image" in prompt + + def test_parse_response_standard_format(self, sample_images): + """Test parsing a standard formatted response.""" + picker = BestShotPicker() + response = """ +**Best Image**: Image 2 +**Confidence**: High +**Reasoning**: This image has excellent focus and proper exposure. +""" + result = picker._parse_response(response, sample_images) + assert result.best_image_index == 1 # 0-based index + assert result.best_image_path == sample_images[1] + assert result.confidence == "High" + assert "excellent focus" in result.reasoning + + def test_parse_response_alternative_formats(self, sample_images): + """Test parsing various response formats.""" + picker = BestShotPicker() + + # Test format without markdown + response1 = "Best Image: 1\nThis is the best one." + result1 = picker._parse_response(response1, sample_images) + assert result1.best_image_index == 0 + + # Test format with "Image X is best" + response2 = "After analysis, Image 3 is best because of composition." + result2 = picker._parse_response(response2, sample_images) + assert result2.best_image_index == 2 + + def test_parse_response_invalid_defaults_to_first(self, sample_images): + """Test that unparseable response defaults to first image.""" + picker = BestShotPicker() + response = "This response has no clear image selection." + result = picker._parse_response(response, sample_images) + assert result.best_image_index == 0 + assert "Failed to parse" in result.reasoning + + def test_select_best_image_single_image(self, sample_images): + """Test selecting best from a single image.""" + picker = BestShotPicker() + result = picker.select_best_image([sample_images[0]]) + assert result.best_image_index == 0 + assert result.best_image_path == sample_images[0] + assert "Only one image" in result.reasoning + + def test_select_best_image_empty_list(self): + """Test error when no images provided.""" + picker = BestShotPicker() + with pytest.raises(ValueError, match="No images provided"): + picker.select_best_image([]) + + def test_select_best_image_success(self, sample_images, mock_openai_client): + """Test successful image selection.""" + picker = BestShotPicker() + result = picker.select_best_image(sample_images) + + assert result.best_image_index == 1 # Response says "Image 2" + assert result.best_image_path == sample_images[1] + assert result.confidence == "High" + assert "best focus" in result.reasoning + + # Verify API was called + mock_openai_client.chat.completions.create.assert_called_once() + + def test_select_best_image_api_error(self, sample_images): + """Test handling of API errors.""" + with patch("core.ai.best_shot_picker.OpenAI") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.chat.completions.create.side_effect = Exception("API Error") + + picker = BestShotPicker() + with pytest.raises(BestShotPickerError, match="Failed to analyze images"): + picker.select_best_image(sample_images) + + def test_select_best_image_skips_missing_files( + self, sample_images, mock_openai_client + ): + """Test that missing files are skipped during analysis.""" + # Add a non-existent file to the list + images_with_missing = sample_images + ["/nonexistent/image.jpg"] + + picker = BestShotPicker() + result = picker.select_best_image(images_with_missing) + + # Should still work and return a valid result + assert result.best_image_index in [0, 1, 2] + + def test_test_connection_success(self): + """Test successful connection test.""" + with patch("core.ai.best_shot_picker.OpenAI") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_completion = MagicMock() + mock_completion.choices = [MagicMock()] + mock_client.chat.completions.create.return_value = mock_completion + + picker = BestShotPicker() + assert picker.test_connection() is True + + def test_test_connection_failure(self): + """Test failed connection test.""" + with patch("core.ai.best_shot_picker.OpenAI") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.chat.completions.create.side_effect = Exception( + "Connection failed" + ) + + picker = BestShotPicker() + assert picker.test_connection() is False + + +class TestBestShotResult: + """Test suite for BestShotResult dataclass.""" + + def test_result_creation(self): + """Test creating a BestShotResult.""" + result = BestShotResult( + best_image_index=2, + best_image_path="/path/to/image.jpg", + reasoning="Great composition", + confidence="High", + raw_response="Full response text", + ) + + assert result.best_image_index == 2 + assert result.best_image_path == "/path/to/image.jpg" + assert result.reasoning == "Great composition" + assert result.confidence == "High" + assert result.raw_response == "Full response text" + + +def test_best_shot_picker_integration(sample_images, mock_openai_client): + """Integration test for the full best shot picking workflow.""" + picker = BestShotPicker( + base_url="http://localhost:1234/v1", + api_key="test-key", + model="test-model", + timeout=60, + ) + + # Test connection first + assert picker.test_connection() is True + + # Select best image + result = picker.select_best_image(sample_images) + + # Verify result + assert isinstance(result, BestShotResult) + assert 0 <= result.best_image_index < len(sample_images) + assert result.best_image_path in sample_images + assert len(result.reasoning) > 0 + assert result.confidence in ["High", "Medium", "Low", "Not specified"] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])