|
1 | | -import torch |
2 | | -import opensr_utils |
3 | | -from omegaconf import OmegaConf |
| 1 | +# inference.py |
| 2 | + |
4 | 3 | import os |
| 4 | +import torch |
| 5 | +from model.SRGAN import SRGAN_model |
5 | 6 |
|
6 | | -# set visible GPUs and device |
7 | | -os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
8 | | -device = "cuda" if torch.cuda.is_available() else "cpu" |
9 | 7 |
|
10 | | -# Load Model and weights |
11 | | -from model.SRGAN import SRGAN_model |
12 | | -model = SRGAN_model(config_file_path="configs/config_20m.yaml") |
13 | | -model = model.eval() |
14 | | -model.load_from_checkpoint("checkpoints/srgan-20m-6band/last.ckpt", strict=False) |
15 | | - |
16 | | -# Set up Sen2 Inference Pipeline |
17 | | -sen2_path = "data/S2A_MSIL2A_20230901T104031_N0509_R137_T31TFJ_20230901T130204.SAFE" # Set Path to file or folder |
18 | | -sr_object = opensr_utils.large_file_processing( |
19 | | - root=sen2_path, # File or Folder path |
20 | | - model=model, # SR model |
21 | | - window_size=(128, 128), # LR window size for model input |
22 | | - factor=4, # SR factor (10m → 2.5m) |
23 | | - overlap=12, # overlapping pixels for mosaic stitching |
24 | | - eliminate_border_px=2, # No of discarded border pixels per prediction |
25 | | - device=device, # "cuda" for GPU-accelerated inference |
26 | | - gpus=[0], # pass GPU ID (int) or list of GPUs |
27 | | - save_preview=False, # save a low-res preview of the output, and a tif georef |
28 | | - debug=False, |
29 | | - ) |
30 | | -sr_object.start_super_resolution() |
| 8 | +def load_model(config_path=None, ckpt_path=None, device=None): |
| 9 | + """Build SRGAN model and (optionally) load weights. Safe to call from tests.""" |
| 10 | + if device is None: |
| 11 | + device = "cuda" if torch.cuda.is_available() else "cpu" |
| 12 | + |
| 13 | + model = SRGAN_model(config_file_path=config_path).eval().to(device) |
| 14 | + |
| 15 | + if ckpt_path: |
| 16 | + # Try Lightning API first (without 'strict'); fall back to raw state_dict |
| 17 | + try: |
| 18 | + model = SRGAN_model.load_from_checkpoint( |
| 19 | + ckpt_path, map_location=device |
| 20 | + ).eval().to(device) |
| 21 | + except TypeError: |
| 22 | + state = torch.load(ckpt_path, map_location=device) |
| 23 | + state = state.get("state_dict", state) |
| 24 | + model.load_state_dict(state, strict=False) |
| 25 | + |
| 26 | + return model, device |
| 27 | + |
| 28 | + |
| 29 | +def run_sen2_inference( |
| 30 | + sen2_path=None, |
| 31 | + config_path=None, |
| 32 | + ckpt_path=None, |
| 33 | + gpus=None, |
| 34 | + window_size=(128, 128), |
| 35 | + factor=4, |
| 36 | + overlap=12, |
| 37 | + eliminate_border_px=2, |
| 38 | + save_preview=False, |
| 39 | + debug=False, |
| 40 | +): |
| 41 | + """Run Sentinel-2 SR inference. Kept out of import-time for CI.""" |
| 42 | + if gpus is not None and len(gpus) > 0: |
| 43 | + os.environ.setdefault("CUDA_VISIBLE_DEVICES", ",".join(map(str, gpus))) |
| 44 | + |
| 45 | + model, device = load_model(config_path=config_path, ckpt_path=ckpt_path) |
| 46 | + |
| 47 | + import opensr_utils |
| 48 | + |
| 49 | + sr_object = opensr_utils.large_file_processing( |
| 50 | + root=sen2_path, |
| 51 | + model=model, |
| 52 | + window_size=window_size, |
| 53 | + factor=factor, |
| 54 | + overlap=overlap, |
| 55 | + eliminate_border_px=eliminate_border_px, |
| 56 | + device=device, |
| 57 | + gpus=gpus if gpus is not None else ([0] if device == "cuda" else []), |
| 58 | + save_preview=save_preview, |
| 59 | + debug=debug, |
| 60 | + ) |
| 61 | + sr_object.start_super_resolution() |
| 62 | + return sr_object |
| 63 | + |
| 64 | + |
| 65 | +def main(): |
| 66 | + os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0") |
| 67 | + |
| 68 | + # ---- Define placeholders ---- |
| 69 | + sen2_path = "data/S2A_MSIL2A_EXAMPLE.SAFE" |
| 70 | + config_path = "configs/config_20m.yaml" |
| 71 | + ckpt_path = "checkpoints/srgan-20m-6band/last.ckpt" |
| 72 | + gpus = [0] |
| 73 | + |
| 74 | + run_sen2_inference( |
| 75 | + sen2_path=sen2_path, |
| 76 | + config_path=config_path, |
| 77 | + ckpt_path=ckpt_path, |
| 78 | + gpus=gpus, |
| 79 | + ) |
| 80 | + |
| 81 | + |
| 82 | +if __name__ == "__main__": |
| 83 | + main() |
0 commit comments