99from pytest_mock import MockerFixture
1010from sqlmodel import Session
1111
12- from lightly_studio .models .annotation .annotation_base import AnnotationBaseTable , AnnotationType
1312from lightly_studio .models .tag import TagCreate
1413from lightly_studio .resolvers import (
1514 image_resolver ,
@@ -732,31 +731,11 @@ def test_get_class_balancing_data_input(test_db: Session) -> None:
732731 label_id_dog = UUID ("00000000-0000-0000-0000-000000000002" )
733732 sample_id_1 = UUID ("11111111-1111-1111-1111-111111111111" )
734733 sample_id_2 = UUID ("22222222-2222-2222-2222-222222222222" )
735- dataset_id = uuid4 ()
736-
737- ann_cat_1 = AnnotationBaseTable (
738- annotation_label_id = label_id_cat ,
739- parent_sample_id = sample_id_1 ,
740- dataset_id = dataset_id ,
741- annotation_type = AnnotationType .CLASSIFICATION ,
742- )
743- ann_cat_2 = AnnotationBaseTable (
744- annotation_label_id = label_id_cat ,
745- parent_sample_id = sample_id_2 ,
746- dataset_id = dataset_id ,
747- annotation_type = AnnotationType .CLASSIFICATION ,
748- )
749- ann_dog_1 = AnnotationBaseTable (
750- annotation_label_id = label_id_dog ,
751- parent_sample_id = sample_id_2 ,
752- dataset_id = dataset_id ,
753- annotation_type = AnnotationType .CLASSIFICATION ,
754- )
755734
756735 # The order of target keys depends on the insertion order in this list.
757736 # 'cat' appears first, 'dog' appears second.
758737 # Target Keys: [cat, dog]
759- all_annotations = [ann_cat_1 , ann_cat_2 , ann_dog_1 ]
738+ all_annotation_labels = [label_id_cat , label_id_cat , label_id_dog ]
760739 input_sample_ids = [sample_id_1 , sample_id_2 ]
761740
762741 sample_id_to_annotation_label_ids = {
@@ -769,7 +748,7 @@ def test_get_class_balancing_data_input(test_db: Session) -> None:
769748 class_dist , target_vals = _get_class_balancing_data (
770749 session = test_db ,
771750 strat = strat ,
772- annotations = all_annotations ,
751+ annotation_label_ids = all_annotation_labels ,
773752 input_sample_ids = input_sample_ids ,
774753 sample_id_to_annotation_label_ids = sample_id_to_annotation_label_ids ,
775754 )
@@ -790,28 +769,8 @@ def test_get_class_balancing_data_uniform(test_db: Session) -> None:
790769 label_id_dog = UUID ("00000000-0000-0000-0000-000000000002" )
791770 sample_id_1 = UUID ("11111111-1111-1111-1111-111111111111" )
792771 sample_id_2 = UUID ("22222222-2222-2222-2222-222222222222" )
793- dataset_id = uuid4 ()
794772
795- ann_cat_1 = AnnotationBaseTable (
796- annotation_label_id = label_id_cat ,
797- parent_sample_id = sample_id_1 ,
798- dataset_id = dataset_id ,
799- annotation_type = AnnotationType .CLASSIFICATION ,
800- )
801- ann_cat_2 = AnnotationBaseTable (
802- annotation_label_id = label_id_cat ,
803- parent_sample_id = sample_id_2 ,
804- dataset_id = dataset_id ,
805- annotation_type = AnnotationType .CLASSIFICATION ,
806- )
807- ann_dog_1 = AnnotationBaseTable (
808- annotation_label_id = label_id_dog ,
809- parent_sample_id = sample_id_2 ,
810- dataset_id = dataset_id ,
811- annotation_type = AnnotationType .CLASSIFICATION ,
812- )
813-
814- all_annotations = [ann_cat_1 , ann_cat_2 , ann_dog_1 ]
773+ all_annotation_labels = [label_id_cat , label_id_cat , label_id_dog ]
815774 input_sample_ids = [sample_id_1 , sample_id_2 ]
816775
817776 sample_id_to_annotation_label_ids = {
@@ -824,7 +783,7 @@ def test_get_class_balancing_data_uniform(test_db: Session) -> None:
824783 class_dist , target_vals = _get_class_balancing_data (
825784 session = test_db ,
826785 strat = strat ,
827- annotations = all_annotations ,
786+ annotation_label_ids = all_annotation_labels ,
828787 input_sample_ids = input_sample_ids ,
829788 sample_id_to_annotation_label_ids = sample_id_to_annotation_label_ids ,
830789 )
@@ -849,28 +808,8 @@ def test_get_class_balancing_data_target(test_db: Session) -> None:
849808
850809 sample_id_1 = UUID ("11111111-1111-1111-1111-111111111111" )
851810 sample_id_2 = UUID ("22222222-2222-2222-2222-222222222222" )
852- dataset_id = uuid4 ()
853-
854- ann_cat_1 = AnnotationBaseTable (
855- annotation_label_id = label_id_cat ,
856- parent_sample_id = sample_id_1 ,
857- dataset_id = dataset_id ,
858- annotation_type = AnnotationType .CLASSIFICATION ,
859- )
860- ann_cat_2 = AnnotationBaseTable (
861- annotation_label_id = label_id_cat ,
862- parent_sample_id = sample_id_2 ,
863- dataset_id = dataset_id ,
864- annotation_type = AnnotationType .CLASSIFICATION ,
865- )
866- ann_dog_1 = AnnotationBaseTable (
867- annotation_label_id = label_id_dog ,
868- parent_sample_id = sample_id_2 ,
869- dataset_id = dataset_id ,
870- annotation_type = AnnotationType .CLASSIFICATION ,
871- )
872811
873- all_annotations = [ann_cat_1 , ann_cat_2 , ann_dog_1 ]
812+ all_annotation_labels = [label_id_cat , label_id_cat , label_id_dog ]
874813 input_sample_ids = [sample_id_1 , sample_id_2 ]
875814
876815 sample_id_to_annotation_label_ids = {
@@ -888,7 +827,7 @@ def test_get_class_balancing_data_target(test_db: Session) -> None:
888827 class_dist , target_vals = _get_class_balancing_data (
889828 session = test_db ,
890829 strat = strat ,
891- annotations = all_annotations ,
830+ annotation_label_ids = all_annotation_labels ,
892831 input_sample_ids = input_sample_ids ,
893832 sample_id_to_annotation_label_ids = sample_id_to_annotation_label_ids ,
894833 )
0 commit comments