Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/cleanvision/dataset/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Sized
from typing import List, Union
from typing import List, Union, Optional

from PIL import Image

Expand All @@ -20,7 +20,7 @@ def __len__(self) -> int:
"""Returns the number of examples in the dataset"""
raise NotImplementedError

def __getitem__(self, item: Union[int, str]) -> Image.Image:
def __getitem__(self, item: Union[int, str]) -> Optional[Image.Image]:
"""Returns the image at a given index"""
raise NotImplementedError

Expand Down
2 changes: 1 addition & 1 deletion src/cleanvision/dataset/hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __len__(self) -> int:

def __getitem__(self, item: Union[int, str]) -> Optional[Image.Image]:
try:
image = self._data[item][self._image_key]
image: Image.Image = self._data[item][self._image_key]
return image
except Exception as e:
print(f"Could not load image at index: {item}\n", e)
Expand Down
3 changes: 2 additions & 1 deletion src/cleanvision/dataset/torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def __len__(self) -> int:
return len(self._data)

def __getitem__(self, item: Union[int, str]) -> Image.Image:
return self._data[item][self._image_idx]
img: Image.Image = self._data[item][self._image_idx]
return img

def get_name(self, index: Union[int, str]) -> str:
return f"idx: {index}"
Expand Down
10 changes: 6 additions & 4 deletions src/cleanvision/imagelab.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from __future__ import annotations

import random
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -502,6 +502,7 @@ def _visualize(
scores = sorted_df.head(num_images)[get_score_colname(issue_type)]
indices = scores.index.tolist()
images = [self._dataset[i] for i in indices]
images = cast(list[Image.Image], images)

# construct title info
title_info = {"scores": [f"score : {x:.4f}" for x in scores]}
Expand All @@ -526,6 +527,7 @@ def _visualize(
image_sets = []
for indices in image_sets_indices:
image_sets.append([self._dataset[index] for index in indices])
image_sets = cast(list[list[Image.Image]], image_sets)

title_info_sets = []
for s in image_sets_indices:
Expand Down Expand Up @@ -620,7 +622,7 @@ def visualize(
elif image_files is not None:
if len(image_files) == 0:
raise ValueError("image_files list is empty.")
images = [Image.open(path) for path in image_files]
images: List[Image.Image] = [Image.open(path) for path in image_files]
title_info = {"path": [path.split("/")[-1] for path in image_files]}
VizManager.individual_images(
images,
Expand All @@ -629,7 +631,7 @@ def visualize(
cell_size=cell_size,
)
elif indices:
images = [self._dataset[i] for i in indices]
images = [cast(Image.Image, self._dataset[i]) for i in indices]
title_info = {"name": [self._dataset.get_name(i) for i in indices]}
VizManager.individual_images(
images,
Expand All @@ -644,7 +646,7 @@ def visualize(
image_indices = random.sample(
self._dataset.index, min(num_images, len(self._dataset))
)
images = [self._dataset[i] for i in image_indices]
images = [cast(Image.Image, self._dataset[i]) for i in image_indices]
title_info = {
"name": [self._dataset.get_name(i) for i in image_indices]
}
Expand Down
14 changes: 12 additions & 2 deletions src/cleanvision/issue_managers/duplicate_issue_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,20 @@

def get_hash(image: Image.Image, params: Dict[str, Any]) -> str:
hash_type, hash_size = params["hash_type"], params.get("hash_size", None)
supported_types = ["md5", "whash", "phash", "ahash", "dhash", "chash"]
if hash_type not in supported_types:
raise ValueError(
f"Hash type `{hash_type}` is not supported. Must be one of: {supported_types}"
)

if hash_type == "md5":
pixels = np.asarray(image)
return hashlib.md5(pixels.tobytes()).hexdigest()
elif hash_type == "whash":

if not isinstance(hash_size, int):
raise ValueError("hash_size must be declared as a int in params")

if hash_type == "whash":
return str(imagehash.whash(image, hash_size=hash_size))
elif hash_type == "phash":
return str(imagehash.phash(image, hash_size=hash_size))
Expand All @@ -31,7 +41,7 @@ def get_hash(image: Image.Image, params: Dict[str, Any]) -> str:
elif hash_type == "chash":
return str(imagehash.colorhash(image, binbits=hash_size))
else:
raise ValueError("Hash type not supported")
raise ValueError("hash_type not supported")


def compute_hash(
Expand Down
6 changes: 3 additions & 3 deletions src/cleanvision/utils/viz_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Dict
from typing import List, Tuple, Dict, Sequence

import math
import matplotlib.axes
Expand All @@ -9,7 +9,7 @@
class VizManager:
@staticmethod
def individual_images(
images: List[Image.Image],
images: Sequence[Image.Image],
title_info: Dict[str, List[str]],
ncols: int,
cell_size: Tuple[int, int],
Expand Down Expand Up @@ -86,7 +86,7 @@ def construct_titles(title_info: Dict[str, List[str]], cell_width: int) -> List[


def plot_image_grid(
images: List[Image.Image],
images: Sequence[Image.Image],
title_info: Dict[str, List[str]],
ncols: int,
cell_size: Tuple[int, int],
Expand Down