Skip to content

Commit 462b300

Browse files
Refactor _get_class_balancing_data() (#191)
1 parent 25e8c85 commit 462b300

File tree

2 files changed

+19
-80
lines changed

2 files changed

+19
-80
lines changed

lightly_studio/src/lightly_studio/selection/select_via_db.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from numpy.typing import NDArray
1313
from sqlmodel import Session
1414

15-
from lightly_studio.models.annotation.annotation_base import AnnotationBaseTable
1615
from lightly_studio.models.tag import TagCreate
1716
from lightly_studio.resolvers import (
1817
annotation_label_resolver,
@@ -75,16 +74,16 @@ def _aggregate_class_distributions(
7574
def _process_explicit_target_distribution(
7675
session: Session,
7776
target_distribution: dict[str, float],
78-
annotations: Sequence[AnnotationBaseTable],
77+
annotation_label_ids: Sequence[UUID],
7978
) -> tuple[dict[UUID, float], set[UUID], float]:
8079
"""Processes the explicit target distribution.
8180
8281
Args:
8382
session: The SQLAlchemy session.
8483
target_distribution:
8584
A dictionary mapping annotation label names to their target proportions.
86-
annotations:
87-
A sequence of all annotations to consider for class balancing.
85+
annotation_label_ids:
86+
A sequence of all annotation label IDs to consider for class balancing.
8887
8988
Returns:
9089
Tuple of:
@@ -111,7 +110,7 @@ def _process_explicit_target_distribution(
111110
label_id_to_target[annotation_label.annotation_label_id] = target
112111
total_targets += target
113112

114-
all_label_ids = {a.annotation_label_id for a in annotations}
113+
all_label_ids = set(annotation_label_ids)
115114
unused_label_ids = all_label_ids - set(label_id_to_target.keys())
116115
# `total_targets` can be more or less than 1.0. Both can be ignored, selection will still
117116
# try correctly to reach the target.
@@ -122,18 +121,18 @@ def _process_explicit_target_distribution(
122121
def _get_class_balancing_data(
123122
session: Session,
124123
strat: AnnotationClassBalancingStrategy,
125-
annotations: Sequence[AnnotationBaseTable],
124+
annotation_label_ids: Sequence[UUID],
126125
input_sample_ids: Sequence[UUID],
127126
sample_id_to_annotation_label_ids: Mapping[UUID, list[UUID]],
128127
) -> tuple[NDArray[np.float32], list[float]]:
129128
"""Helper function to get class balancing data."""
130129
if strat.target_distribution == "uniform":
131-
target_keys_set = {a.annotation_label_id for a in annotations}
130+
target_keys_set = set(annotation_label_ids)
132131
target_keys = list(target_keys_set)
133132
target_values = [1.0 / len(target_keys)] * len(target_keys)
134133
elif strat.target_distribution == "input":
135134
# Count the number of times each label appears in the input
136-
input_label_count = Counter(a.annotation_label_id for a in annotations)
135+
input_label_count = Counter(annotation_label_ids)
137136
target_keys, target_values = (
138137
list(input_label_count.keys()),
139138
list(input_label_count.values()),
@@ -143,18 +142,18 @@ def _get_class_balancing_data(
143142
_process_explicit_target_distribution(
144143
session=session,
145144
target_distribution=strat.target_distribution,
146-
annotations=annotations,
145+
annotation_label_ids=annotation_label_ids,
147146
)
148147
)
149148
if len(unused_label_ids) >= 1:
150149
other_uuid = uuid4()
151150
# Handle the case when not all classes have a target.
152151
# We replace UUIDs that are present in `unused_label_ids` for `other_uuid` and the
153152
# target for `other_uuid` is `remaining_ratio`.
154-
for annotation_label_ids in sample_id_to_annotation_label_ids.values():
155-
for i, label_id in enumerate(annotation_label_ids):
153+
for sample_annotation_label_ids in sample_id_to_annotation_label_ids.values():
154+
for i, label_id in enumerate(sample_annotation_label_ids):
156155
if label_id in unused_label_ids:
157-
annotation_label_ids[i] = other_uuid
156+
sample_annotation_label_ids[i] = other_uuid
158157
label_id_to_target[other_uuid] = remaining_ratio
159158

160159
target_keys, target_values = (
@@ -230,6 +229,7 @@ def select_via_database(
230229
session=session,
231230
filters=AnnotationsFilter(sample_ids=input_sample_ids),
232231
).annotations
232+
annotation_label_ids = [a.annotation_label_id for a in annotations]
233233
sample_id_to_annotation_label_ids = defaultdict(list)
234234
for annotation in annotations:
235235
sample_id_to_annotation_label_ids[annotation.parent_sample_id].append(
@@ -239,7 +239,7 @@ def select_via_database(
239239
class_distributions, target_values = _get_class_balancing_data(
240240
session=session,
241241
strat=strat,
242-
annotations=annotations,
242+
annotation_label_ids=annotation_label_ids,
243243
input_sample_ids=input_sample_ids,
244244
sample_id_to_annotation_label_ids=sample_id_to_annotation_label_ids,
245245
)

lightly_studio/tests/selection/test_select_via_db.py

Lines changed: 6 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from pytest_mock import MockerFixture
1010
from sqlmodel import Session
1111

12-
from lightly_studio.models.annotation.annotation_base import AnnotationBaseTable, AnnotationType
1312
from lightly_studio.models.tag import TagCreate
1413
from lightly_studio.resolvers import (
1514
image_resolver,
@@ -732,31 +731,11 @@ def test_get_class_balancing_data_input(test_db: Session) -> None:
732731
label_id_dog = UUID("00000000-0000-0000-0000-000000000002")
733732
sample_id_1 = UUID("11111111-1111-1111-1111-111111111111")
734733
sample_id_2 = UUID("22222222-2222-2222-2222-222222222222")
735-
dataset_id = uuid4()
736-
737-
ann_cat_1 = AnnotationBaseTable(
738-
annotation_label_id=label_id_cat,
739-
parent_sample_id=sample_id_1,
740-
dataset_id=dataset_id,
741-
annotation_type=AnnotationType.CLASSIFICATION,
742-
)
743-
ann_cat_2 = AnnotationBaseTable(
744-
annotation_label_id=label_id_cat,
745-
parent_sample_id=sample_id_2,
746-
dataset_id=dataset_id,
747-
annotation_type=AnnotationType.CLASSIFICATION,
748-
)
749-
ann_dog_1 = AnnotationBaseTable(
750-
annotation_label_id=label_id_dog,
751-
parent_sample_id=sample_id_2,
752-
dataset_id=dataset_id,
753-
annotation_type=AnnotationType.CLASSIFICATION,
754-
)
755734

756735
# The order of target keys depends on the insertion order in this list.
757736
# 'cat' appears first, 'dog' appears second.
758737
# Target Keys: [cat, dog]
759-
all_annotations = [ann_cat_1, ann_cat_2, ann_dog_1]
738+
all_annotation_labels = [label_id_cat, label_id_cat, label_id_dog]
760739
input_sample_ids = [sample_id_1, sample_id_2]
761740

762741
sample_id_to_annotation_label_ids = {
@@ -769,7 +748,7 @@ def test_get_class_balancing_data_input(test_db: Session) -> None:
769748
class_dist, target_vals = _get_class_balancing_data(
770749
session=test_db,
771750
strat=strat,
772-
annotations=all_annotations,
751+
annotation_label_ids=all_annotation_labels,
773752
input_sample_ids=input_sample_ids,
774753
sample_id_to_annotation_label_ids=sample_id_to_annotation_label_ids,
775754
)
@@ -790,28 +769,8 @@ def test_get_class_balancing_data_uniform(test_db: Session) -> None:
790769
label_id_dog = UUID("00000000-0000-0000-0000-000000000002")
791770
sample_id_1 = UUID("11111111-1111-1111-1111-111111111111")
792771
sample_id_2 = UUID("22222222-2222-2222-2222-222222222222")
793-
dataset_id = uuid4()
794772

795-
ann_cat_1 = AnnotationBaseTable(
796-
annotation_label_id=label_id_cat,
797-
parent_sample_id=sample_id_1,
798-
dataset_id=dataset_id,
799-
annotation_type=AnnotationType.CLASSIFICATION,
800-
)
801-
ann_cat_2 = AnnotationBaseTable(
802-
annotation_label_id=label_id_cat,
803-
parent_sample_id=sample_id_2,
804-
dataset_id=dataset_id,
805-
annotation_type=AnnotationType.CLASSIFICATION,
806-
)
807-
ann_dog_1 = AnnotationBaseTable(
808-
annotation_label_id=label_id_dog,
809-
parent_sample_id=sample_id_2,
810-
dataset_id=dataset_id,
811-
annotation_type=AnnotationType.CLASSIFICATION,
812-
)
813-
814-
all_annotations = [ann_cat_1, ann_cat_2, ann_dog_1]
773+
all_annotation_labels = [label_id_cat, label_id_cat, label_id_dog]
815774
input_sample_ids = [sample_id_1, sample_id_2]
816775

817776
sample_id_to_annotation_label_ids = {
@@ -824,7 +783,7 @@ def test_get_class_balancing_data_uniform(test_db: Session) -> None:
824783
class_dist, target_vals = _get_class_balancing_data(
825784
session=test_db,
826785
strat=strat,
827-
annotations=all_annotations,
786+
annotation_label_ids=all_annotation_labels,
828787
input_sample_ids=input_sample_ids,
829788
sample_id_to_annotation_label_ids=sample_id_to_annotation_label_ids,
830789
)
@@ -849,28 +808,8 @@ def test_get_class_balancing_data_target(test_db: Session) -> None:
849808

850809
sample_id_1 = UUID("11111111-1111-1111-1111-111111111111")
851810
sample_id_2 = UUID("22222222-2222-2222-2222-222222222222")
852-
dataset_id = uuid4()
853-
854-
ann_cat_1 = AnnotationBaseTable(
855-
annotation_label_id=label_id_cat,
856-
parent_sample_id=sample_id_1,
857-
dataset_id=dataset_id,
858-
annotation_type=AnnotationType.CLASSIFICATION,
859-
)
860-
ann_cat_2 = AnnotationBaseTable(
861-
annotation_label_id=label_id_cat,
862-
parent_sample_id=sample_id_2,
863-
dataset_id=dataset_id,
864-
annotation_type=AnnotationType.CLASSIFICATION,
865-
)
866-
ann_dog_1 = AnnotationBaseTable(
867-
annotation_label_id=label_id_dog,
868-
parent_sample_id=sample_id_2,
869-
dataset_id=dataset_id,
870-
annotation_type=AnnotationType.CLASSIFICATION,
871-
)
872811

873-
all_annotations = [ann_cat_1, ann_cat_2, ann_dog_1]
812+
all_annotation_labels = [label_id_cat, label_id_cat, label_id_dog]
874813
input_sample_ids = [sample_id_1, sample_id_2]
875814

876815
sample_id_to_annotation_label_ids = {
@@ -888,7 +827,7 @@ def test_get_class_balancing_data_target(test_db: Session) -> None:
888827
class_dist, target_vals = _get_class_balancing_data(
889828
session=test_db,
890829
strat=strat,
891-
annotations=all_annotations,
830+
annotation_label_ids=all_annotation_labels,
892831
input_sample_ids=input_sample_ids,
893832
sample_id_to_annotation_label_ids=sample_id_to_annotation_label_ids,
894833
)

0 commit comments

Comments
 (0)