Skip to content

Commit b952c1a

Browse files
author
marjan.asgari
committed
update_ruff
1 parent e2d10fb commit b952c1a

File tree

9 files changed

+110
-199
lines changed

9 files changed

+110
-199
lines changed

dataset/aoi.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,6 @@ def __init__(
237237
f"Extended check: {False}"
238238
)
239239

240-
logging.info(f"\n\tSuccessfully parsed Rasters \n: {raster_parsed}\n")
241240
# If stac item input, keep Stac item object as attribute
242241
if is_stac_item(self.raster_raw_input):
243242
item = SingleBandItemEO(
@@ -869,4 +868,4 @@ def aois_from_csv(
869868
f"Failed to create AOI:\n{aoi_dict}\n"
870869
f"Index: {i}"
871870
)
872-
return aois
871+
return aois

dataset/stacitem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def band_to_cname(input_band: str):
9797
Naive conversion of a band to a valid common name
9898
See: https://github.com/stac-extensions/eo/issues/13
9999
"""
100-
bands_ref = (("red", "R"), ("green", "G"), ("blue", "B"), ("nir", "N"))
100+
bands_ref = (("red", "R"), ("green", "G"), ("blue", "B"), ('nir', "N"))
101101
if isinstance(input_band, int) and 1 <= input_band <= 4:
102102
return bands_ref[input_band-1][0]
103103
elif isinstance(input_band, str) and len(input_band) == 1:

inference_segmentation.py

Lines changed: 15 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from utils.logger import get_logger, set_tracker
1414
from geo_inference.geo_inference import GeoInference
1515
from utils.utils import get_device_ids, get_key_def, set_device
16-
from utils.utils import get_device_ids, get_key_def, set_device
1716

1817
# Set the logging file
1918
logging = get_logger(__name__)
@@ -25,10 +24,7 @@ def stac_input_to_temp_csv(input_stac_item: Union[str, Path]) -> Path:
2524
csv.writer(fh).writerow([str(input_stac_item), None, "inference", Path(input_stac_item).stem])
2625
return Path(stac_temp_csv)
2726

28-
29-
def calc_inference_chunk_size(
30-
gpu_devices_dict: dict, max_pix_per_mb_gpu: int = 200, default: int = 512
31-
) -> int:
27+
def calc_inference_chunk_size(gpu_devices_dict: dict, max_pix_per_mb_gpu: int = 200, default: int = 512) -> int:
3228
"""
3329
Calculate maximum chunk_size that could fit on GPU during inference based on thumb rule with hardcoded
3430
"pixels per MB of GPU RAM" as threshold. Threshold based on inference with a large model (Deeplabv3_resnet101)
@@ -39,9 +35,7 @@ def calc_inference_chunk_size(
3935
if not gpu_devices_dict:
4036
return default
4137
# get max ram for smallest gpu
42-
smallest_gpu_ram = min(
43-
gpu_info["max_ram"] for _, gpu_info in gpu_devices_dict.items()
44-
)
38+
smallest_gpu_ram = min(gpu_info['max_ram'] for _, gpu_info in gpu_devices_dict.items())
4539
# rule of thumb to determine max chunk size based on approximate max pixels a gpu can handle during inference
4640
max_chunk_size = sqrt(max_pix_per_mb_gpu * smallest_gpu_ram)
4741
max_chunk_size_rd = int(max_chunk_size - (max_chunk_size % 256)) # round to the closest multiple of 256
@@ -63,13 +57,9 @@ def main(params:Union[DictConfig, Dict]):
6357
# Set the device
6458
num_devices = get_key_def('gpu', params['inference'], default=0, expected_type=(int, bool))
6559
if num_devices > 1:
66-
logging.warning(
67-
"Inference is not yet implemented for multi-gpu use. Will request only 1 GPU."
68-
)
60+
logging.warning(f"Inference is not yet implemented for multi-gpu use. Will request only 1 GPU.")
6961
num_devices = 1
70-
max_used_ram = get_key_def(
71-
"max_used_ram", params["inference"], default=25, expected_type=int
72-
)
62+
max_used_ram = get_key_def('max_used_ram', params['inference'], default=25, expected_type=int)
7363
if not (0 <= max_used_ram <= 100):
7464
raise ValueError(f'\nMax used ram parameter should be a percentage. Got {max_used_ram}.')
7565
max_used_perc = get_key_def('max_used_perc', params['inference'], default=25, expected_type=int)
@@ -95,23 +85,13 @@ def main(params:Union[DictConfig, Dict]):
9585
validate_path_exists=True)
9686

9787
if raw_data_csv and input_stac_item:
98-
raise ValueError(
99-
'Input imagery should be either a csv or a stac item. Got inputs from both "raw_data_csv" '
100-
'and "input stac item".'
101-
)
102-
103-
if global_params["input_stac_item"]:
104-
raw_data_csv = stac_input_to_temp_csv(global_params["input_stac_item"])
105-
if not all(
106-
[SingleBandItemEO.is_valid_cname(band) for band in global_params["bands"]]
107-
):
108-
logging.warning(
109-
f"Requested bands are not valid stac item common names. Got: {global_params['bands']}"
110-
)
111-
# returns red, blue, green
112-
bands = [
113-
SingleBandItemEO.band_to_cname(band) for band in global_params["bands"]
114-
]
88+
raise ValueError(f"Input imagery should be either a csv of stac item. Got inputs from both \"raw_data_csv\" "
89+
f"and \"input stac item\"")
90+
if input_stac_item:
91+
raw_data_csv = stac_input_to_temp_csv(input_stac_item)
92+
if not all([SingleBandItemEO.is_valid_cname(band) for band in bands_requested]):
93+
logging.warning(f"Requested bands are not valid stac item common names. Got: {bands_requested}")
94+
bands_requested = [SingleBandItemEO.band_to_cname(band) for band in bands_requested]
11595
logging.warning(f"Will request: {bands_requested}")
11696

11797
# LOGGING PARAMETERS
@@ -121,24 +101,13 @@ def main(params:Union[DictConfig, Dict]):
121101
set_tracker(mode='inference', type='mlflow', task='segmentation', experiment_name=exper_name, run_name=run_name,
122102
tracker_uri=tracker_uri, params=params, keys2log=['general', 'dataset', 'model', 'inference'])
123103

124-
set_tracker(
125-
mode="inference",
126-
type="mlflow",
127-
task="segmentation",
128-
experiment_name=exper_name,
129-
run_name=run_name,
130-
tracker_uri=tracker_uri,
131-
params=params,
132-
keys2log=["general", "dataset", "model", "inference"],
133-
)
134-
135104
# GET LIST OF INPUT IMAGES FOR INFERENCE
136105
list_aois = aois_from_csv(
137106
csv_path=raw_data_csv,
138-
bands_requested=bands,
139-
download_data=global_params["download_data"],
140-
data_dir=global_params["raw_data_dir"],
141-
equalize_clahe_clip_limit=global_params["clahe_clip_limit"],
107+
bands_requested=bands_requested,
108+
download_data=download_data,
109+
data_dir=data_dir,
110+
equalize_clahe_clip_limit=clahe_clip_limit,
142111
)
143112

144113
# Create the inference object

tests/data/inference/test.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
tests/data/inference/test.tif
1+
https://datacube-stage.services.geo.ca/api/collections/worldview-2-ortho-pansharp/items/BC6P002-052652307020_01_P002-WV02

tests/test_inference_dask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
import tracemalloc
1111
import psutil
1212

13-
if str(Path(__file__).parents[0]) not in sys.path:
14-
sys.path.insert(0, str(Path(__file__).parents[0]))
13+
if str(Path(__file__).parents[1]) not in sys.path:
14+
sys.path.insert(0, str(Path(__file__).parents[1]))
1515
from utils.logger import get_logger
1616
from utils.aoiutils import aois_from_csv
1717

utils/aoiutils.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ def aois_from_csv(
2020
) -> list:
2121
"""
2222
Creates list of AOIs by parsing a csv file referencing input data.
23-
23+
2424
.. note::
25-
See AOI docstring for information on other parameters and
25+
See AOI docstring for information on other parameters and
2626
see the dataset docs for details on expected structure of csv.
2727
2828
Args:
29-
csv_path (Union[str, Path]): path to csv file containing list of input data.
29+
csv_path (Union[str, Path]): path to csv file containing list of input data.
3030
bands_requested (List, optional): _description_. Defaults to [].
3131
attr_field_filter (str, optional): _description_. Defaults to None.
3232
attr_values_filter (str, optional): _description_. Defaults to None.
@@ -37,17 +37,13 @@ def aois_from_csv(
3737
3838
Returns:
3939
list: list of AOIs objects.
40-
"""
40+
"""
4141
aois = []
4242
data_list = read_csv(csv_path)
43-
logging.info(
44-
f"\n\tSuccessfully read csv file: {Path(csv_path).name}\n"
45-
f"\tNumber of rows: {len(data_list)}\n"
46-
f"\tCopying first row:\n{data_list[0]}\n"
47-
)
48-
with tqdm(
49-
enumerate(data_list), desc="Creating AOI's", total=len(data_list)
50-
) as _tqdm:
43+
logging.info(f'\n\tSuccessfully read csv file: {Path(csv_path).name}\n'
44+
f'\tNumber of rows: {len(data_list)}\n'
45+
f'\tCopying first row:\n{data_list[0]}\n')
46+
with tqdm(enumerate(data_list), desc="Creating AOI's", total=len(data_list)) as _tqdm:
5147
for i, aoi_dict in _tqdm:
5248
_tqdm.set_postfix_str(f"Image: {Path(aoi_dict['tif']).stem}")
5349
try:
@@ -64,34 +60,32 @@ def aois_from_csv(
6460
logging.debug(new_aoi)
6561
aois.append(new_aoi)
6662
except FileNotFoundError as e:
67-
logging.error(
68-
f"{e}\nGround truth file may not exist or is empty.\n"
69-
f"Failed to create AOI:\n{aoi_dict}\n"
70-
f"Index: {i}"
71-
)
63+
logging.error(f"{e}\nGround truth file may not exist or is empty.\n"
64+
f"Failed to create AOI:\n{aoi_dict}\n"
65+
f"Index: {i}")
7266
return aois
7367

7468

7569
def aois_from_csv_change_detection(
76-
csv_path: Union[str, Path],
77-
bands_requested: List = [],
78-
attr_field_filter: str = None,
79-
attr_values_filter: str = None,
80-
download_data: bool = False,
81-
data_dir: str = "data",
82-
for_multiprocessing=False,
83-
write_dest_raster=False,
84-
equalize_clahe_clip_limit: int = 0,
70+
csv_path: Union[str, Path],
71+
bands_requested: List = [],
72+
attr_field_filter: str = None,
73+
attr_values_filter: str = None,
74+
download_data: bool = False,
75+
data_dir: str = "data",
76+
for_multiprocessing = False,
77+
write_dest_raster = False,
78+
equalize_clahe_clip_limit: int = 0,
8579
) -> dict:
8680
"""
8781
Creates list of AOIs by parsing a csv file referencing input data.
88-
82+
8983
.. note::
90-
See AOI docstring for information on other parameters and
84+
See AOI docstring for information on other parameters and
9185
see the dataset docs for details on expected structure of csv.
9286
9387
Args:
94-
csv_path (Union[str, Path]): path to csv file containing list of input data.
88+
csv_path (Union[str, Path]): path to csv file containing list of input data.
9589
bands_requested (List, optional): _description_. Defaults to [].
9690
attr_field_filter (str, optional): _description_. Defaults to None.
9791
attr_values_filter (str, optional): _description_. Defaults to None.
@@ -103,7 +97,7 @@ def aois_from_csv_change_detection(
10397
10498
Returns:
10599
dict: dictionary of list of AOIs objects.
106-
"""
100+
"""
107101
aois = {}
108102
data_dict = read_csv_change_detection(csv_path)
109103
logging.info(

0 commit comments

Comments
 (0)