@@ -90,8 +90,8 @@ def dset(db: Session) -> models.Dataset:
9090
9191
9292@pytest .fixture
93- def img (db : Session , dset : models .Dataset ) -> models .Image :
94- img = models .Image (uid = "uid" , dataset_id = dset .id , height = 1000 , width = 2000 )
93+ def img (db : Session , dset : models .Dataset ) -> models .Datum :
94+ img = models .Datum (uid = "uid" , dataset_id = dset .id , height = 1000 , width = 2000 )
9595 db .add (img )
9696 db .commit ()
9797
@@ -118,7 +118,12 @@ def groundtruths(
118118 detections. These detections are taken from a torchmetrics unit test (see test_metrics.py)
119119 """
120120 dataset_name = "test dataset"
121- crud .create_dataset (db , dataset = schemas .DatasetCreate (name = dataset_name ))
121+ crud .create_dataset (
122+ db ,
123+ dataset = schemas .DatasetCreate (
124+ name = dataset_name , type = schemas .DatumTypes .IMAGE
125+ ),
126+ )
122127 gts_per_img = [
123128 {"boxes" : [[214.1500 , 41.2900 , 562.4100 , 285.0700 ]], "labels" : ["4" ]},
124129 {
@@ -205,7 +210,9 @@ def predictions(
205210 """
206211 model_name = "test model"
207212 dset_name = "test dataset"
208- crud .create_model (db , schemas .Model (name = model_name ))
213+ crud .create_model (
214+ db , schemas .Model (name = model_name , type = schemas .DatumTypes .IMAGE )
215+ )
209216
210217 # predictions for four images taken from
211218 # https://github.com/Lightning-AI/metrics/blob/107dbfd5fb158b7ae6d76281df44bd94c836bfce/tests/unittests/detection/test_map.py#L59
0 commit comments