Skip to content

Commit 5ef153f

Browse files
authored
Update main function of bundle application (#1585)
Signed-off-by: Andres <[email protected]>
1 parent 5e9732e commit 5ef153f

File tree

1 file changed

+55
-28
lines changed

1 file changed

+55
-28
lines changed

sample-apps/monaibundle/main.py

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,11 @@ def init_scoring_methods(self) -> Dict[str, ScoringMethod]:
133133

134134
def main():
135135
import argparse
136+
import shutil
136137
from pathlib import Path
137138

138-
from monailabel.config import settings
139+
from monailabel.utils.others.generic import device_list, file_ext
139140

140-
settings.MONAI_LABEL_DATASTORE_AUTO_RELOAD = False
141-
settings.MONAI_LABEL_DATASTORE_FILE_EXT = ["*.png", "*.jpg", "*.jpeg", ".nii", ".nii.gz"]
142141
os.putenv("MASTER_ADDR", "127.0.0.1")
143142
os.putenv("MASTER_PORT", "1234")
144143

@@ -154,43 +153,71 @@ def main():
154153

155154
parser = argparse.ArgumentParser()
156155
parser.add_argument("-s", "--studies", default=studies)
156+
parser.add_argument("-m", "--model", default="wholeBody_ct_segmentation")
157+
parser.add_argument("-t", "--test", default="infer", choices=("train", "infer", "batch_infer"))
157158
args = parser.parse_args()
158159

159160
app_dir = os.path.dirname(__file__)
160161
studies = args.studies
162+
conf = {
163+
"models": args.model,
164+
"preload": "false",
165+
}
166+
167+
app = MyApp(app_dir, studies, conf)
168+
169+
# Infer
170+
if args.test == "infer":
171+
sample = app.next_sample(request={"strategy": "first"})
172+
image_id = sample["id"]
173+
image_path = sample["path"]
174+
175+
# Run on all devices
176+
for device in device_list():
177+
res = app.infer(request={"model": args.model, "image": image_id, "device": device})
178+
label = res["file"]
179+
label_json = res["params"]
180+
test_dir = os.path.join(args.studies, "test_labels")
181+
os.makedirs(test_dir, exist_ok=True)
182+
183+
label_file = os.path.join(test_dir, image_id + file_ext(image_path))
184+
shutil.move(label, label_file)
185+
186+
print(label_json)
187+
print(f"++++ Image File: {image_path}")
188+
print(f"++++ Label File: {label_file}")
189+
break
190+
return
191+
192+
# Batch Infer
193+
if args.test == "batch_infer":
194+
app.batch_infer(
195+
request={
196+
"model": args.model,
197+
"multi_gpu": False,
198+
"save_label": True,
199+
"label_tag": "original",
200+
"max_workers": 1,
201+
"max_batch_size": 0,
202+
}
203+
)
204+
return
161205

162-
app = MyApp(app_dir, studies, {"preload": "false", "models": "spleen_deepedit_annotation"})
163-
# train(app)
164-
infer(app)
165-
166-
167-
def infer(app):
168-
import json
169-
import shutil
170-
171-
res = app.infer(
172-
request={
173-
"model": "spleen_deepedit_annotation",
174-
"image": "image",
175-
}
176-
)
177-
178-
print(json.dumps(res, indent=2))
179-
shutil.move(res["label"], os.path.join(app.studies, "test"))
180-
logger.info("All Done!")
181-
182-
183-
def train(app):
206+
# Train
184207
app.train(
185208
request={
186-
"model": "spleen_deepedit_annotation",
187-
"max_epochs": 2,
209+
"model": args.model,
210+
"max_epochs": 10,
211+
"dataset": "Dataset", # PersistentDataset, CacheDataset
212+
"train_batch_size": 1,
213+
"val_batch_size": 1,
188214
"multi_gpu": False,
189215
"val_split": 0.1,
190-
"val_interval": 1,
191216
},
192217
)
193218

194219

195220
if __name__ == "__main__":
221+
# export PYTHONPATH=~/Projects/MONAILabel:`pwd`
222+
# python main.py
196223
main()

0 commit comments

Comments
 (0)