@@ -133,12 +133,11 @@ def init_scoring_methods(self) -> Dict[str, ScoringMethod]:
133
133
134
134
def main ():
135
135
import argparse
136
+ import shutil
136
137
from pathlib import Path
137
138
138
- from monailabel .config import settings
139
+ from monailabel .utils . others . generic import device_list , file_ext
139
140
140
- settings .MONAI_LABEL_DATASTORE_AUTO_RELOAD = False
141
- settings .MONAI_LABEL_DATASTORE_FILE_EXT = ["*.png" , "*.jpg" , "*.jpeg" , ".nii" , ".nii.gz" ]
142
141
os .putenv ("MASTER_ADDR" , "127.0.0.1" )
143
142
os .putenv ("MASTER_PORT" , "1234" )
144
143
@@ -154,43 +153,71 @@ def main():
154
153
155
154
parser = argparse .ArgumentParser ()
156
155
parser .add_argument ("-s" , "--studies" , default = studies )
156
+ parser .add_argument ("-m" , "--model" , default = "wholeBody_ct_segmentation" )
157
+ parser .add_argument ("-t" , "--test" , default = "infer" , choices = ("train" , "infer" , "batch_infer" ))
157
158
args = parser .parse_args ()
158
159
159
160
app_dir = os .path .dirname (__file__ )
160
161
studies = args .studies
162
+ conf = {
163
+ "models" : args .model ,
164
+ "preload" : "false" ,
165
+ }
166
+
167
+ app = MyApp (app_dir , studies , conf )
168
+
169
+ # Infer
170
+ if args .test == "infer" :
171
+ sample = app .next_sample (request = {"strategy" : "first" })
172
+ image_id = sample ["id" ]
173
+ image_path = sample ["path" ]
174
+
175
+ # Run on all devices
176
+ for device in device_list ():
177
+ res = app .infer (request = {"model" : args .model , "image" : image_id , "device" : device })
178
+ label = res ["file" ]
179
+ label_json = res ["params" ]
180
+ test_dir = os .path .join (args .studies , "test_labels" )
181
+ os .makedirs (test_dir , exist_ok = True )
182
+
183
+ label_file = os .path .join (test_dir , image_id + file_ext (image_path ))
184
+ shutil .move (label , label_file )
185
+
186
+ print (label_json )
187
+ print (f"++++ Image File: { image_path } " )
188
+ print (f"++++ Label File: { label_file } " )
189
+ break
190
+ return
191
+
192
+ # Batch Infer
193
+ if args .test == "batch_infer" :
194
+ app .batch_infer (
195
+ request = {
196
+ "model" : args .model ,
197
+ "multi_gpu" : False ,
198
+ "save_label" : True ,
199
+ "label_tag" : "original" ,
200
+ "max_workers" : 1 ,
201
+ "max_batch_size" : 0 ,
202
+ }
203
+ )
204
+ return
161
205
162
- app = MyApp (app_dir , studies , {"preload" : "false" , "models" : "spleen_deepedit_annotation" })
163
- # train(app)
164
- infer (app )
165
-
166
-
167
- def infer (app ):
168
- import json
169
- import shutil
170
-
171
- res = app .infer (
172
- request = {
173
- "model" : "spleen_deepedit_annotation" ,
174
- "image" : "image" ,
175
- }
176
- )
177
-
178
- print (json .dumps (res , indent = 2 ))
179
- shutil .move (res ["label" ], os .path .join (app .studies , "test" ))
180
- logger .info ("All Done!" )
181
-
182
-
183
- def train (app ):
206
+ # Train
184
207
app .train (
185
208
request = {
186
- "model" : "spleen_deepedit_annotation" ,
187
- "max_epochs" : 2 ,
209
+ "model" : args .model ,
210
+ "max_epochs" : 10 ,
211
+ "dataset" : "Dataset" , # PersistentDataset, CacheDataset
212
+ "train_batch_size" : 1 ,
213
+ "val_batch_size" : 1 ,
188
214
"multi_gpu" : False ,
189
215
"val_split" : 0.1 ,
190
- "val_interval" : 1 ,
191
216
},
192
217
)
193
218
194
219
195
220
if __name__ == "__main__" :
221
+ # export PYTHONPATH=~/Projects/MONAILabel:`pwd`
222
+ # python main.py
196
223
main ()
0 commit comments