Skip to content

Commit 972f060

Browse files
authored
Odd size revamp (#247)
* fixed formatting * updated flake8 check * Fixed test * Fixed test * Removed unncessary dependencies in dev req file * Fixed odd size tests * Fixed mypy error' * Fixed typing syntax * Updated odd size title key * Update info['statistics'] with describe stats * Updated tutorial notebook * Fixed tests
1 parent 4b4932d commit 972f060

File tree

9 files changed

+110
-92
lines changed

9 files changed

+110
-92
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ jobs:
2929
- name: Install dependencies
3030
run: |
3131
python -m pip install --upgrade pip
32-
pip install pytest pytest-cov psutil -e ".[all]"
32+
pip install -e ".[all]"
33+
pip install -r requirements-dev.txt
3334
shell: bash
3435
- name: Test with coverage
3536
run: pytest --verbose --cov=src/cleanvision/ --cov-config .coveragerc --cov-report=xml tests/

docs/source/tutorials/tutorial.ipynb

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@
107107
"cell_type": "markdown",
108108
"metadata": {},
109109
"source": [
110-
"### 1. Using CleanVision to detect default issue types"
110+
"### 1. Using CleanVision to detect issues in your dataset"
111111
]
112112
},
113113
{
@@ -124,9 +124,6 @@
124124
"# Initialize imagelab with your dataset\n",
125125
"imagelab = Imagelab(data_path=dataset_path)\n",
126126
"\n",
127-
"# Visualize a few sample images from the dataset\n",
128-
"imagelab.visualize(num_images=8)\n",
129-
"\n",
130127
"# Find issues\n",
131128
"imagelab.find_issues()"
132129
]
@@ -153,17 +150,17 @@
153150
"cell_type": "markdown",
154151
"metadata": {},
155152
"source": [
156-
"The main way to interface with your data is via the `Imagelab` class. This class can be used to understand the issues in your dataset at a high level (global overview) and low level (issues and quality scores for each image) as well as additional information about the dataset. It has three main attributes:\n",
153+
"The main way to interface with your data is via the [Imagelab](https://cleanvision.readthedocs.io/en/latest/cleanvision/imagelab.html#cleanvision.imagelab.Imagelab) class. This class can be used to understand the issues in your dataset at a high level (global overview) and low level (issues and quality scores for each image) as well as additional information about the dataset. It has three main attributes:\n",
154+
"\n",
157155
"- `Imagelab.issue_summary`\n",
158156
"- `Imagelab.issues`\n",
159157
"- `Imagelab.info`\n",
160158
"\n",
161159
"#### imagelab.issue_summary\n",
162-
"Dataframe with global summary of all issue types detected in your dataset and the overall prevalence of each type.\n",
160+
"This is a Dataframe containing a comprehensive summary of all detected issue types within your dataset, along with their respective prevalence levels. Each row in this summary includes the following information:\n",
163161
"\n",
164-
"In each row:\\\n",
165-
"`issue_type` - name of the issue\\\n",
166-
"`num_images` - number of images of that issue type found in the dataset"
162+
"`issue_type`: The name of the detected issue.\\\n",
163+
"`num_images`: The number of images exhibiting the identified issue within the dataset."
167164
]
168165
},
169166
{
@@ -301,7 +298,7 @@
301298
"tags": []
302299
},
303300
"source": [
304-
"You can see **entropy** values for each image in the dataset as shown below."
301+
"You can see **size** statistics for the dataset below. Here we observe, both the 25th and 75th percentile are 256 for the dataset, hence images that are further away from this range are detected as oddly sized."
305302
]
306303
},
307304
{
@@ -310,7 +307,7 @@
310307
"metadata": {},
311308
"outputs": [],
312309
"source": [
313-
"imagelab.info[\"statistics\"][\"entropy\"]"
310+
"imagelab.info[\"statistics\"][\"size\"]"
314311
]
315312
},
316313
{

requirements-dev.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@ mypy
44
pre-commit
55
pytest
66
pytest-cov
7-
pytest-lazy-fixture
8-
datasets>=2.7.0
9-
torchvision>=0.12.0
107
black
118
build
129
flake8

src/cleanvision/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
import sys
2+
from typing import Any, Union
3+
24
from cleanvision.imagelab import Imagelab as _Imagelab
35

46
PYTHON_VERSION_INFO = sys.version_info
57

68

7-
def get_version() -> str:
9+
def get_version() -> Union[str, Any]:
810
if sys.version_info.major >= 3 and sys.version_info.minor >= 8:
911
import importlib.metadata
1012

1113
return importlib.metadata.version("cleanvision")
1214
else:
1315
import importlib_metadata
1416

15-
return importlib_metadata.version("cleanvision") # type:ignore
17+
return importlib_metadata.version("cleanvision")
1618

1719

1820
try:

src/cleanvision/imagelab.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -502,9 +502,7 @@ def _visualize(
502502
if show_id:
503503
title_info["ids"] = [f"id : {i}" for i in indices]
504504
if issue_type == IssueType.ODD_SIZE.value:
505-
title_info["size"] = [
506-
f"original size: {image.size}" for image in images
507-
]
505+
title_info["size"] = [f"size: {image.size}" for image in images]
508506

509507
if images:
510508
VizManager.individual_images(

src/cleanvision/issue_managers/image_property.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import math
22
from abc import ABC, abstractmethod
3-
from typing import List, Dict, Any, Union, overload
3+
from typing import Any, Dict, List, Optional, Union, overload
44

55
import numpy as np
66
import pandas as pd
7-
from PIL import ImageStat, ImageFilter
7+
from PIL import ImageFilter, ImageStat
88
from PIL.Image import Image
99

1010
from cleanvision.issue_managers import IssueType
@@ -48,12 +48,16 @@ def get_scores(
4848
return
4949

5050
def mark_issue(
51-
self, scores: pd.DataFrame, threshold: float, issue_type: str
51+
self,
52+
scores: pd.DataFrame,
53+
issue_type: str,
54+
threshold: Optional[float] = None,
5255
) -> pd.DataFrame:
5356
is_issue = pd.DataFrame(index=scores.index)
54-
is_issue[get_is_issue_colname(issue_type)] = (
55-
scores[get_score_colname(issue_type)] < threshold
56-
)
57+
is_issue_colname, score_colname = get_is_issue_colname(
58+
issue_type
59+
), get_score_colname(issue_type)
60+
is_issue[is_issue_colname] = scores[score_colname] < threshold
5761
return is_issue
5862

5963

@@ -294,8 +298,8 @@ def calc_color_space(image: Image) -> str:
294298

295299

296300
def calc_image_area_sqrt(image: Image) -> float:
297-
size = image.size
298-
return math.sqrt(size[0] * size[1])
301+
w, h = image.size
302+
return math.sqrt(w) * math.sqrt(h)
299303

300304

301305
class ColorSpaceProperty(ImageProperty):
@@ -326,12 +330,14 @@ def get_scores(
326330
return scores
327331

328332
def mark_issue(
329-
self, scores: pd.DataFrame, threshold: float, issue_type: str
333+
self, scores: pd.DataFrame, issue_type: str, threshold: Optional[float] = None
330334
) -> pd.DataFrame:
331335
is_issue = pd.DataFrame(index=scores.index)
332-
is_issue[get_is_issue_colname(issue_type)] = (
333-
1 - scores[get_score_colname(issue_type)]
334-
).astype("bool")
336+
is_issue_colname, score_colname = get_is_issue_colname(
337+
issue_type
338+
), get_score_colname(issue_type)
339+
340+
is_issue[is_issue_colname] = (1 - scores[score_colname]).astype("bool")
335341
return is_issue
336342

337343

@@ -344,6 +350,7 @@ def score_columns(self) -> List[str]:
344350

345351
def __init__(self) -> None:
346352
self._score_columns = [self.name]
353+
self.threshold = 0.5 # todo: this ensures that the scores are evenly distributed across the range
347354

348355
def calculate(self, image: Image) -> Dict[str, Union[float, str]]:
349356
return {self.name: calc_image_area_sqrt(image)}
@@ -352,35 +359,49 @@ def get_scores(
352359
self,
353360
raw_scores: pd.DataFrame,
354361
issue_type: str,
362+
iqr_factor: float = 3.0,
355363
**kwargs: Any,
356364
) -> pd.DataFrame:
357365
super().get_scores(raw_scores, issue_type, **kwargs)
358366
assert raw_scores is not None
359367

360-
image_size_scores = raw_scores[self.score_columns[0]]
361-
median_image_size = image_size_scores.median()
362-
size_ratios = image_size_scores / median_image_size
363-
364-
# Computing the values of the two divisions
365-
size_division_1 = size_ratios
366-
size_division_2 = 1.0 / size_ratios
368+
size = raw_scores[self.name]
369+
q1, q3 = np.percentile(size, [25, 75])
370+
size_iqr = q3 - q1
371+
min_threshold, max_threshold = (
372+
q1 - iqr_factor * size_iqr,
373+
q3 + iqr_factor * size_iqr,
374+
)
375+
mid_threshold = (min_threshold + max_threshold) / 2
376+
threshold_gap = max_threshold - min_threshold
377+
distance = np.absolute(size - mid_threshold)
378+
379+
if threshold_gap > 0:
380+
norm_value = threshold_gap
381+
self.threshold = 0.5
382+
elif threshold_gap == 0:
383+
norm_value = mid_threshold
384+
self.threshold = 1.0
385+
else:
386+
raise ValueError("threshold_gap should be non negative")
367387

368-
# Using np.minimum to determine the element-wise minimum value between the two divisions
369-
size_scores = np.minimum(size_division_1, size_division_2)
388+
norm_dist = distance / norm_value
389+
score_values = 1 - np.clip(norm_dist, 0, 1)
370390

371391
scores = pd.DataFrame(index=raw_scores.index)
372-
scores[get_score_colname(issue_type)] = size_scores
392+
scores[get_score_colname(issue_type)] = score_values
373393
return scores
374394

375395
def mark_issue(
376-
self, scores: pd.DataFrame, threshold: float, issue_type: str
396+
self, scores: pd.DataFrame, issue_type: str, threshold: Optional[float] = None
377397
) -> pd.DataFrame:
398+
threshold = self.threshold if threshold is None else threshold
399+
is_issue_colname, score_colname = get_is_issue_colname(
400+
issue_type
401+
), get_score_colname(issue_type)
402+
378403
is_issue = pd.DataFrame(index=scores.index)
379-
is_issue[get_is_issue_colname(issue_type)] = np.where(
380-
scores[get_score_colname(issue_type)] < 1.0 / threshold,
381-
True,
382-
False,
383-
)
404+
is_issue[is_issue_colname] = scores[score_colname] < threshold
384405
return is_issue
385406

386407

src/cleanvision/issue_managers/image_property_issue_manager.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,27 @@
11
import multiprocessing
2-
from typing import Dict, Any, List, Set, Optional, Union
2+
from typing import Any, Dict, List, Optional, Set, Union
33

44
import pandas as pd
55
from tqdm.auto import tqdm
66

77
from cleanvision.dataset.base_dataset import Dataset
8-
from cleanvision.issue_managers import register_issue_manager, IssueType
8+
from cleanvision.issue_managers import IssueType, register_issue_manager
99
from cleanvision.issue_managers.image_property import (
10-
BrightnessProperty,
1110
AspectRatioProperty,
12-
EntropyProperty,
1311
BlurrinessProperty,
12+
BrightnessProperty,
1413
ColorSpaceProperty,
14+
EntropyProperty,
1515
ImageProperty,
1616
SizeProperty,
1717
)
1818
from cleanvision.utils.base_issue_manager import IssueManager
1919
from cleanvision.utils.constants import (
2020
IMAGE_PROPERTY,
21-
MAX_PROCS,
2221
IMAGE_PROPERTY_ISSUE_TYPES_LIST,
22+
MAX_PROCS,
2323
)
24-
from cleanvision.utils.utils import (
25-
get_is_issue_colname,
26-
update_df,
27-
)
24+
from cleanvision.utils.utils import get_is_issue_colname, update_df
2825

2926

3027
def compute_scores(
@@ -72,7 +69,7 @@ def get_default_params(self) -> Dict[str, Any]:
7269
"color_threshold": 0.18,
7370
},
7471
IssueType.GRAYSCALE.value: {},
75-
IssueType.ODD_SIZE.value: {"threshold": 10.0},
72+
IssueType.ODD_SIZE.value: {"iqr_factor": 3.0},
7673
}
7774

7875
def update_params(self, params: Dict[str, Any]) -> None:
@@ -203,11 +200,15 @@ def update_issues(
203200
score_columns = agg_computations[score_column_names]
204201

205202
issue_scores = self.image_properties[issue_type].get_scores(
206-
score_columns, issue_type, **self.params[issue_type]
203+
raw_scores=score_columns,
204+
issue_type=issue_type,
205+
**self.params[issue_type],
207206
)
208207

209208
is_issue = self.image_properties[issue_type].mark_issue(
210-
issue_scores, self.params[issue_type].get("threshold"), issue_type
209+
scores=issue_scores,
210+
issue_type=issue_type,
211+
threshold=self.params[issue_type].get("threshold"),
211212
)
212213
self.issues = self.issues.join(issue_scores)
213214
self.issues = self.issues.join(is_issue)
@@ -240,23 +241,20 @@ def update_info(self, agg_computations: pd.DataFrame) -> None:
240241
issue_type: self.image_properties[issue_type].name
241242
for issue_type in self.issue_types
242243
}
243-
issue_columns = {
244-
issue_type: [
245-
col
246-
for col in agg_computations.columns
247-
if col.startswith(property_names[issue_type] + "_")
248-
]
249-
for issue_type in self.issue_types
250-
}
251244

252245
for issue_type in self.issue_types:
253-
self.info["statistics"][property_names[issue_type]] = agg_computations[
254-
property_names[issue_type]
246+
property_name = property_names[issue_type]
247+
248+
self.info["statistics"][property_name] = agg_computations[
249+
property_name
250+
].describe()
251+
252+
issue_columns = [
253+
col for col in agg_computations.columns if col.startswith(property_name)
255254
]
255+
256256
self.info[issue_type] = (
257-
agg_computations[issue_columns[issue_type]]
258-
if len(issue_columns[issue_type]) > 0
259-
else {}
257+
agg_computations[issue_columns] if len(issue_columns) > 0 else {}
260258
)
261259

262260
def update_summary(self) -> None:

tests/test_image_property_helpers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,18 @@
22
import pandas as pd
33
import pytest
44
from PIL import Image
5+
from pytest import approx
56

67
import cleanvision
7-
import math
88
from cleanvision.issue_managers import IssueType
99
from cleanvision.issue_managers.image_property import (
1010
BrightnessProperty,
11-
calculate_brightness,
12-
get_image_mode,
1311
calc_aspect_ratio,
12+
calc_blurriness,
1413
calc_entropy,
1514
calc_image_area_sqrt,
16-
calc_blurriness,
15+
calculate_brightness,
16+
get_image_mode,
1717
)
1818
from cleanvision.utils.utils import get_is_issue_colname, get_score_colname
1919

@@ -54,8 +54,8 @@ def test_calc_bluriness():
5454

5555
def test_calc_area():
5656
img = Image.new("RGB", (200, 200), (255, 0, 0))
57-
area = calc_image_area_sqrt(img) # img.size[0] * img.size[1]
58-
assert area == math.sqrt(200 * 200)
57+
area = calc_image_area_sqrt(img)
58+
assert area == approx(200)
5959

6060

6161
@pytest.mark.parametrize(
@@ -137,5 +137,5 @@ def test_get_scores(self, image_property, issue_type, expected_output):
137137
],
138138
)
139139
def test_mark_issue(self, image_property, scores, threshold, expected_mark):
140-
mark = image_property.mark_issue(scores, threshold, "fake_issue")
140+
mark = image_property.mark_issue(scores, "fake_issue", threshold)
141141
assert all(mark == expected_mark)

0 commit comments

Comments
 (0)