
Commit 9e9ba66

Merge pull request #300 from JdeRobot/issue-299
Update computational cost estimation and add CLI support
2 parents: 4b556f3 + 222ebd7

File tree: 7 files changed (+346, -187 lines)


detectionmetrics/cli/__init__.py

Lines changed: 124 additions & 0 deletions
@@ -1,5 +1,129 @@
+from detectionmetrics import datasets
+from detectionmetrics import models
 from detectionmetrics.cli.evaluate import evaluate
+from detectionmetrics.cli.computational_cost import computational_cost
 
 REGISTRY = {
     "evaluate": evaluate,
+    "computational_cost": computational_cost,
 }
+
+
+def get_model(task, input_type, model_format, model, ontology, model_cfg):
+    # Init model from registry
+    model_name = f"{model_format}_{input_type}_{task}"
+    if model_name not in models.REGISTRY:
+        raise ValueError(
+            f"Model format not supported: {model_format}. "
+            f"Must be one of {models.REGISTRY.keys()}",
+        )
+    return models.REGISTRY[model_name](model, model_cfg, ontology)
+
+
+def get_dataset(
+    task,
+    input_type,
+    dataset_format,
+    dataset_fname,
+    dataset_dir,
+    split_dir,
+    train_dataset_dir,
+    val_dataset_dir,
+    test_dataset_dir,
+    images_dir,
+    labels_dir,
+    data_suffix,
+    label_suffix,
+    ontology,
+    split,
+):
+    # Check if required data is available
+    if dataset_format == "gaia":
+        if dataset_fname is None:
+            raise ValueError("--dataset is required for 'gaia' format")
+
+    elif dataset_format in ["rellis3d", "wildscenes"]:
+        if dataset_dir is None:
+            raise ValueError(
+                "--dataset_dir is required for 'rellis3d' and 'wildscenes' formats"
+            )
+        if split_dir is None:
+            raise ValueError(
+                "--split_dir is required for 'rellis3d' and 'wildscenes' formats"
+            )
+
+        if dataset_format == "rellis3d" and ontology is None:
+            raise ValueError("--dataset_ontology is required for 'rellis3d' format")
+
+    elif dataset_format in ["goose", "generic"]:
+        if "train" in split and train_dataset_dir is None:
+            raise ValueError(
+                "--train_dataset_dir is required for 'train' split in 'goose' and 'generic' formats"
+            )
+        elif "val" in split and val_dataset_dir is None:
+            raise ValueError(
+                "--val_dataset_dir is required for 'val' split in 'goose' and 'generic' formats"
+            )
+        elif "test" in split and test_dataset_dir is None:
+            raise ValueError(
+                "--test_dataset_dir is required for 'test' split in 'goose' and 'generic' formats"
+            )
+
+        if dataset_format == "generic":
+            if data_suffix is None:
+                raise ValueError("--data_suffix is required for 'generic' format")
+            if label_suffix is None:
+                raise ValueError("--label_suffix is required for 'generic' format")
+            if ontology is None:
+                raise ValueError("--dataset_ontology is required for 'generic' format")
+
+    elif dataset_format == "rugd":
+        if images_dir is None:
+            raise ValueError("--images_dir is required for 'rugd' format")
+        if labels_dir is None:
+            raise ValueError("--labels_dir is required for 'rugd' format")
+
+    else:
+        raise ValueError(f"Dataset format not supported: {dataset_format}")
+
+    # Get arguments to init dataset
+    if dataset_format == "gaia":
+        dataset_args = {"dataset_fname": dataset_fname}
+    elif dataset_format == "rellis3d":
+        dataset_args = {
+            "dataset_dir": dataset_dir,
+            "split_dir": split_dir,
+            "ontology_fname": ontology,
+        }
+    elif dataset_format == "goose":
+        dataset_args = {
+            "train_dataset_dir": train_dataset_dir,
+            "val_dataset_dir": val_dataset_dir,
+            "test_dataset_dir": test_dataset_dir,
+        }
+    elif dataset_format == "generic":
+        dataset_args = {
+            "data_suffix": data_suffix,
+            "label_suffix": label_suffix,
+            "ontology_fname": ontology,
+            "train_dataset_dir": train_dataset_dir,
+            "val_dataset_dir": val_dataset_dir,
+            "test_dataset_dir": test_dataset_dir,
+        }
+    elif dataset_format == "rugd":
+        dataset_args = {
+            "images_dir": images_dir,
+            "labels_dir": labels_dir,
+            "ontology_fname": ontology,
+        }
+    else:
+        raise ValueError(f"Dataset format not supported: {dataset_format}")
+
+    # Init dataset from registry
+    dataset_name = f"{dataset_format}_{input_type}_{task}"
+    if dataset_name not in datasets.REGISTRY:
+        raise ValueError(
+            f"Dataset format not supported: {dataset_format}. "
+            f"Must be one of {datasets.REGISTRY.keys()}",
+        )
+    return datasets.REGISTRY[dataset_name](**dataset_args)
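
For orientation: both helpers above dispatch through registries keyed as "{format}_{input_type}_{task}". A minimal sketch of calling get_model, assuming a "torch_image_segmentation" entry exists in models.REGISTRY; all file paths below are hypothetical placeholders, not files from this PR:

from detectionmetrics import cli

# Resolves models.REGISTRY["torch_image_segmentation"] and instantiates it.
# Placeholder paths; substitute your own trained model and JSON files.
model = cli.get_model(
    task="segmentation",
    input_type="image",
    model_format="torch",
    model="model.pt",           # TorchScript checkpoint
    ontology="ontology.json",   # model output ontology
    model_cfg="cfg.json",       # normalization parameters, image size, etc.
)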

detectionmetrics/cli/batch.py

Lines changed: 54 additions & 29 deletions
@@ -21,7 +21,8 @@ def batch(command, jobs_cfg):
         jobs_cfg["model"] = [jobs_cfg["model"]]
 
     # Same for dataset
-    if not isinstance(jobs_cfg["dataset"], list):
+    has_dataset = "dataset" in jobs_cfg
+    if has_dataset and not isinstance(jobs_cfg["dataset"], list):
         jobs_cfg["dataset"] = [jobs_cfg["dataset"]]
 
     # Build list of model configurations
@@ -54,11 +55,16 @@ def batch(command, jobs_cfg):
 
     # Build list of jobs (IDs must be unique)
     all_jobs = {}
-    for model_cfg, dataset_cfg in product(model_cfgs, jobs_cfg["dataset"]):
-        job_id = f"{model_cfg['id']}-{dataset_cfg['id']}"
+    job_iter = product(model_cfgs, jobs_cfg["dataset"]) if has_dataset else model_cfgs
+    for job_components in job_iter:
+        if not isinstance(job_components, tuple):
+            job_components = (job_components,)
+
+        job_id = "-".join([str(jc["id"]) for jc in job_components])
         if job_id in all_jobs:
             raise ValueError(f"Job ID {job_id} is not unique")
-        all_jobs[job_id] = (model_cfg, dataset_cfg)
+
+        all_jobs[job_id] = job_components
 
     print("\n" + "-" * 80)
     print(f"{len(all_jobs)} job(s) will be executed:")
@@ -69,7 +75,7 @@ def batch(command, jobs_cfg):
     # Start processing jobs
     pbar = tqdm(all_jobs.items(), total=len(all_jobs), leave=True)
    preds_outdir = None
-    for job_id, (model_cfg, dataset_cfg) in pbar:
+    for job_id, job_components in pbar:
         job_out_fname = os.path.join(jobs_cfg["outdir"], f"{job_id}.csv")
         if jobs_cfg.get("store_results_per_sample", False):
             preds_outdir = os.path.join(jobs_cfg["outdir"], f"preds-{job_id}")
@@ -84,31 +90,50 @@ def batch(command, jobs_cfg):
 
         ctx = click.get_current_context()
         try:
-            result = ctx.invoke(
-                cli_registry[command],
-                task=jobs_cfg["task"],
-                input_type=jobs_cfg["input_type"],
-                model_format=model_cfg["format"],
-                model=model_cfg["path"],
-                model_ontology=model_cfg["ontology"],
-                model_cfg=model_cfg["cfg"],
-                dataset_format=dataset_cfg["format"],
-                dataset_fname=dataset_cfg.get("fname", None),
-                dataset_dir=dataset_cfg.get("dir", None),
-                split_dir=dataset_cfg.get("split_dir", None),
-                train_dataset_dir=dataset_cfg.get("train_dir", None),
-                val_dataset_dir=dataset_cfg.get("val_dir", None),
-                test_dataset_dir=dataset_cfg.get("test_dir", None),
-                images_dir=dataset_cfg.get("data_dir", None),
-                labels_dir=dataset_cfg.get("labels_dir", None),
-                data_suffix=dataset_cfg.get("data_suffix", None),
-                label_suffix=dataset_cfg.get("label_suffix", None),
-                dataset_ontology=dataset_cfg.get("ontology", None),
-                split=dataset_cfg["split"],
-                ontology_translation=jobs_cfg.get("ontology_translation", None),
-                out_fname=job_out_fname,
-                predictions_outdir=preds_outdir,
+            params = {
+                "task": jobs_cfg["task"],
+                "input_type": jobs_cfg["input_type"],
+            }
+
+            model_cfg = job_components[0]
+            params.update(
+                {
+                    "model_format": model_cfg["format"],
+                    "model": model_cfg["path"],
+                    "model_ontology": model_cfg["ontology"],
+                    "model_cfg": model_cfg["cfg"],
+                    # "image_size": model_cfg.get("image_size", None),
+                }
             )
+            if has_dataset:
+                dataset_cfg = job_components[1]
+                params.update(
+                    {
+                        "dataset_format": dataset_cfg.get("format", None),
+                        "dataset_fname": dataset_cfg.get("fname", None),
+                        "dataset_dir": dataset_cfg.get("dir", None),
+                        "split_dir": dataset_cfg.get("split_dir", None),
+                        "train_dataset_dir": dataset_cfg.get("train_dir", None),
+                        "val_dataset_dir": dataset_cfg.get("val_dir", None),
+                        "test_dataset_dir": dataset_cfg.get("test_dir", None),
+                        "images_dir": dataset_cfg.get("data_dir", None),
+                        "labels_dir": dataset_cfg.get("labels_dir", None),
+                        "data_suffix": dataset_cfg.get("data_suffix", None),
+                        "label_suffix": dataset_cfg.get("label_suffix", None),
+                        "dataset_ontology": dataset_cfg.get("ontology", None),
+                        "split": dataset_cfg["split"],
+                        "ontology_translation": jobs_cfg.get(
+                            "ontology_translation", None
+                        ),
+                    }
+                )
+
+            params.update({"out_fname": job_out_fname})
+            if preds_outdir is not None:
+                params.update({"predictions_outdir": preds_outdir})
+
+            result = ctx.invoke(cli_registry[command], **params)
+
         except Exception as e:
             print(f"Error processing job {job_id}: {e}")
             continue
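
The net effect: when the jobs file has no "dataset" section (as for computational_cost), each job is a 1-tuple holding only a model config and the job ID is just the model ID; with datasets, jobs remain the full model x dataset product. A runnable sketch of just the ID-building logic, using hypothetical configs:

from itertools import product

model_cfgs = [{"id": "unet"}, {"id": "deeplab"}]
dataset_cfgs = [{"id": "goose"}]

for has_dataset in (True, False):
    # Mirrors the new batch.py logic: product with datasets, plain models without
    job_iter = product(model_cfgs, dataset_cfgs) if has_dataset else model_cfgs
    for job_components in job_iter:
        if not isinstance(job_components, tuple):
            job_components = (job_components,)
        print("-".join(str(jc["id"]) for jc in job_components))
# Prints: unet-goose, deeplab-goose (with datasets), then unet, deeplab (without).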
detectionmetrics/cli/computational_cost.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+import click
+
+from detectionmetrics import cli
+from detectionmetrics.utils.io import read_json
+
+
+@click.command(name="computational_cost", help="Estimate model computational cost")
+@click.argument("task", type=click.Choice(["segmentation"], case_sensitive=False))
+@click.argument(
+    "input_type", type=click.Choice(["image", "lidar"], case_sensitive=False)
+)
+# model
+@click.option(
+    "--model_format",
+    type=click.Choice(
+        ["torch", "tensorflow", "tensorflow_explicit"], case_sensitive=False
+    ),
+    show_default=True,
+    default="torch",
+    help="Trained model format",
+)
+@click.option(
+    "--model",
+    type=click.Path(exists=True),
+    required=True,
+    help="Trained model filename (TorchScript) or directory (TensorFlow SavedModel)",
+)
+@click.option(
+    "--model_ontology",
+    type=click.Path(exists=True, dir_okay=False),
+    required=True,
+    help="JSON file containing model output ontology",
+)
+@click.option(
+    "--model_cfg",
+    type=click.Path(exists=True, dir_okay=False),
+    required=True,
+    help="JSON file with model configuration (norm. parameters, image size, etc.)",
+)
+@click.option(
+    "--image_size",
+    type=(int, int),
+    required=False,
+    help="Dummy image size used for computational cost estimation",
+)
+# output
+@click.option(
+    "--out_fname",
+    type=click.Path(writable=True),
+    help="CSV file where the computational cost estimation results will be stored",
+)
+def computational_cost(
+    task,
+    input_type,
+    model_format,
+    model,
+    model_ontology,
+    model_cfg,
+    image_size,
+    out_fname,
+):
+    """Estimate model computational cost"""
+
+    if image_size is None:
+        parsed_model_cfg = read_json(model_cfg)
+        if "image_size" in parsed_model_cfg:
+            image_size = parsed_model_cfg["image_size"]
+        else:
+            raise ValueError(
+                "Image size must be provided either as an argument or in the model configuration file"
+            )
+
+    model = cli.get_model(
+        task, input_type, model_format, model, model_ontology, model_cfg
+    )
+    results = model.get_computational_cost(image_size)
+    results.to_csv(out_fname)
+
+    return results
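
For completeness, a sketch of exercising the new subcommand with click's test runner. The paths are placeholders (click.Path(exists=True) means they must point at real files on disk), and --image_size takes two integers because the option is declared with type=(int, int):

from click.testing import CliRunner

from detectionmetrics.cli.computational_cost import computational_cost

runner = CliRunner()
# Hypothetical paths; substitute a real TorchScript model and its JSON files.
result = runner.invoke(
    computational_cost,
    [
        "segmentation",
        "image",
        "--model_format", "torch",
        "--model", "model.pt",
        "--model_ontology", "ontology.json",
        "--model_cfg", "cfg.json",
        "--image_size", "512", "512",
        "--out_fname", "cost.csv",
    ],
)
print(result.output)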
