1313from utils .logger import get_logger , set_tracker
1414from geo_inference .geo_inference import GeoInference
1515from utils .utils import get_device_ids , get_key_def , set_device
16- from utils .utils import get_device_ids , get_key_def , set_device
1716
1817# Set the logging file
1918logging = get_logger (__name__ )
@@ -25,10 +24,7 @@ def stac_input_to_temp_csv(input_stac_item: Union[str, Path]) -> Path:
2524 csv .writer (fh ).writerow ([str (input_stac_item ), None , "inference" , Path (input_stac_item ).stem ])
2625 return Path (stac_temp_csv )
2726
28-
29- def calc_inference_chunk_size (
30- gpu_devices_dict : dict , max_pix_per_mb_gpu : int = 200 , default : int = 512
31- ) -> int :
27+ def calc_inference_chunk_size (gpu_devices_dict : dict , max_pix_per_mb_gpu : int = 200 , default : int = 512 ) -> int :
3228 """
3329 Calculate maximum chunk_size that could fit on GPU during inference based on thumb rule with hardcoded
3430 "pixels per MB of GPU RAM" as threshold. Threshold based on inference with a large model (Deeplabv3_resnet101)
@@ -39,9 +35,7 @@ def calc_inference_chunk_size(
3935 if not gpu_devices_dict :
4036 return default
4137 # get max ram for smallest gpu
42- smallest_gpu_ram = min (
43- gpu_info ["max_ram" ] for _ , gpu_info in gpu_devices_dict .items ()
44- )
38+ smallest_gpu_ram = min (gpu_info ['max_ram' ] for _ , gpu_info in gpu_devices_dict .items ())
4539 # rule of thumb to determine max chunk size based on approximate max pixels a gpu can handle during inference
4640 max_chunk_size = sqrt (max_pix_per_mb_gpu * smallest_gpu_ram )
4741 max_chunk_size_rd = int (max_chunk_size - (max_chunk_size % 256 )) # round to the closest multiple of 256
@@ -63,13 +57,9 @@ def main(params:Union[DictConfig, Dict]):
6357 # Set the device
6458 num_devices = get_key_def ('gpu' , params ['inference' ], default = 0 , expected_type = (int , bool ))
6559 if num_devices > 1 :
66- logging .warning (
67- "Inference is not yet implemented for multi-gpu use. Will request only 1 GPU."
68- )
60+ logging .warning (f"Inference is not yet implemented for multi-gpu use. Will request only 1 GPU." )
6961 num_devices = 1
70- max_used_ram = get_key_def (
71- "max_used_ram" , params ["inference" ], default = 25 , expected_type = int
72- )
62+ max_used_ram = get_key_def ('max_used_ram' , params ['inference' ], default = 25 , expected_type = int )
7363 if not (0 <= max_used_ram <= 100 ):
7464 raise ValueError (f'\n Max used ram parameter should be a percentage. Got { max_used_ram } .' )
7565 max_used_perc = get_key_def ('max_used_perc' , params ['inference' ], default = 25 , expected_type = int )
@@ -95,23 +85,13 @@ def main(params:Union[DictConfig, Dict]):
9585 validate_path_exists = True )
9686
9787 if raw_data_csv and input_stac_item :
98- raise ValueError (
99- 'Input imagery should be either a csv or a stac item. Got inputs from both "raw_data_csv" '
100- 'and "input stac item".'
101- )
102-
103- if global_params ["input_stac_item" ]:
104- raw_data_csv = stac_input_to_temp_csv (global_params ["input_stac_item" ])
105- if not all (
106- [SingleBandItemEO .is_valid_cname (band ) for band in global_params ["bands" ]]
107- ):
108- logging .warning (
109- f"Requested bands are not valid stac item common names. Got: { global_params ['bands' ]} "
110- )
111- # returns red, blue, green
112- bands = [
113- SingleBandItemEO .band_to_cname (band ) for band in global_params ["bands" ]
114- ]
88+ raise ValueError (f"Input imagery should be either a csv of stac item. Got inputs from both \" raw_data_csv\" "
89+ f"and \" input stac item\" " )
90+ if input_stac_item :
91+ raw_data_csv = stac_input_to_temp_csv (input_stac_item )
92+ if not all ([SingleBandItemEO .is_valid_cname (band ) for band in bands_requested ]):
93+ logging .warning (f"Requested bands are not valid stac item common names. Got: { bands_requested } " )
94+ bands_requested = [SingleBandItemEO .band_to_cname (band ) for band in bands_requested ]
11595 logging .warning (f"Will request: { bands_requested } " )
11696
11797 # LOGGING PARAMETERS
@@ -121,24 +101,13 @@ def main(params:Union[DictConfig, Dict]):
121101 set_tracker (mode = 'inference' , type = 'mlflow' , task = 'segmentation' , experiment_name = exper_name , run_name = run_name ,
122102 tracker_uri = tracker_uri , params = params , keys2log = ['general' , 'dataset' , 'model' , 'inference' ])
123103
124- set_tracker (
125- mode = "inference" ,
126- type = "mlflow" ,
127- task = "segmentation" ,
128- experiment_name = exper_name ,
129- run_name = run_name ,
130- tracker_uri = tracker_uri ,
131- params = params ,
132- keys2log = ["general" , "dataset" , "model" , "inference" ],
133- )
134-
135104 # GET LIST OF INPUT IMAGES FOR INFERENCE
136105 list_aois = aois_from_csv (
137106 csv_path = raw_data_csv ,
138- bands_requested = bands ,
139- download_data = global_params [ " download_data" ] ,
140- data_dir = global_params [ "raw_data_dir" ] ,
141- equalize_clahe_clip_limit = global_params [ " clahe_clip_limit" ] ,
107+ bands_requested = bands_requested ,
108+ download_data = download_data ,
109+ data_dir = data_dir ,
110+ equalize_clahe_clip_limit = clahe_clip_limit ,
142111 )
143112
144113 # Create the inference object
0 commit comments