Description
I was recently working on creating a memory-efficient and disk-space-friendly inference function. I settled on rasterio windowed reading and writing: read in a patch of data, run inference on it, and iteratively write the result to disk. This way you never need to accumulate the full set of predictions in memory.
Another related problem I ran into was tiling artifacts. I still haven't come across a perfect solution (it seems there isn't one). There seem to be multiple options, including test-time augmentation (rotating and mirroring) and smoothing of the predictions (especially important for segmentation tasks).
See the code below for an implementation. There is also commented-out code for the test-time augmentation; the tradeoff is longer inference time. The smoothing ideas are from this repo: https://github.com/Vooban/Smoothly-Blend-Image-Patches.
Here is an older version of the code from a blog post I started but never finished: https://share.note.sx/pkm19n3b#4IGBUHzyXiCfr7FaOBTv9vI4Or0nlJWji9//9gTFMhI
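For reference, the core read-predict-write loop underneath all of this is just rasterio's block windows. A minimal sketch, assuming a single-band float output and a hypothetical predict() standing in for the model:

import numpy as np
import rasterio as rio

def predict(patch: np.ndarray) -> np.ndarray:
    # hypothetical stand-in for the model: mean over bands
    return patch.mean(axis=0).astype("float32")

with rio.open("input.tif") as src:
    profile = src.profile.copy()
    profile.update(count=1, dtype="float32", tiled=True, blockxsize=256, blockysize=256)
    with rio.open("output.tif", "w", **profile) as dst:
        for _, window in dst.block_windows():
            # boundless read zero-fills pixels that fall outside the image
            patch = src.read(window=window, boundless=True, fill_value=0.0)
            dst.write(predict(patch), 1, window=window)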
Source code
import numpy as np
import torch
from scipy.signal.windows import triang  # triangular window, basis for the spline window
# Spline window function
def _spline_window(window_size, power=2):
"""
Squared spline window function for smooth transition.
"""
intersection = int(window_size / 4)
wind_outer = (abs(2 * triang(window_size)) ** power) / 2 # Use correct function
wind_outer[intersection:-intersection] = 0
wind_inner = 1 - (abs(2 * (triang(window_size) - 1)) ** power) / 2
wind_inner[:intersection] = 0
wind_inner[-intersection:] = 0
wind = wind_inner + wind_outer
wind = wind / np.average(wind)
return wind
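# Quick sanity check (assumption: a mean of exactly 1.0 means summed-and-normalized
# overlapping windows preserve overall intensity, and the window is strictly
# positive so every pixel gets nonzero weight):
#   w = _spline_window(256)
#   assert abs(w.mean() - 1.0) < 1e-6 and (w > 0).all()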
cached_2d_windows = dict()
def _window_2D(window_size, power=2):
"""
Generate and return a 2D spline window.
"""
global cached_2d_windows
key = "{}_{}".format(window_size, power)
if key in cached_2d_windows:
wind = cached_2d_windows[key]
else:
wind = _spline_window(window_size, power)
wind = np.outer(wind, wind) # Create a 2D window by outer product
cached_2d_windows[key] = wind
    return torch.tensor(wind, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # shape [1, 1, H, W]
def split_tensor_to_patches(tensor, patch_size=256, overlap=128):
"""
Efficiently split a tensor into overlapping patches using unfold.
"""
stride = patch_size - overlap
patches = tensor.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
batch, channels, h_steps, w_steps, h_patch, w_patch = patches.shape
patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
patches = patches.view(-1, channels, patch_size, patch_size)
return patches, (h_steps, w_steps)
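# Note: torch.Tensor.unfold keeps only full windows, so any pixels beyond the
# last full (patch_size, stride) step are silently dropped. Choose sizes so that
# (H - patch_size) and (W - patch_size) are multiples of the stride; e.g.
# H = W = 512 with patch_size = 256 and overlap = 128 gives an exact 3x3 grid.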
def reconstruct_from_patches(patches, grid_size, output_size, patch_size=256, overlap=128, power=2):
"""
Reconstruct tensor from patches with smooth blending using a spline window.
"""
batch_size, out_channels, height, width = output_size
h_steps, w_steps = grid_size
stride = patch_size - overlap
# Create reconstruction tensor
reconstructed = torch.zeros(output_size, device=patches.device)
counter = torch.zeros_like(reconstructed)
# Create spline weight matrix
window_2d = _window_2D(patch_size, power=power).to(patches.device)
# Calculate number of patches per batch
patches_per_batch = h_steps * w_steps
# Process each batch
for b in range(batch_size):
batch_patches = patches[b * patches_per_batch:(b + 1) * patches_per_batch]
# Process each patch
for idx in range(patches_per_batch):
# Calculate grid position
i = idx // w_steps
j = idx % w_steps
y_start = i * stride
x_start = j * stride
# Apply weight and add to reconstruction
current_patch = batch_patches[idx].unsqueeze(0) # Add batch dimension [1, C, H, W]
# Create a 2D window map the same size as the patch
patch_window = window_2d.repeat(1, out_channels, 1, 1)
# Apply the 2D window to the patch
weighted_patch = current_patch * patch_window
# Add the weighted patch to the reconstruction tensor, with proper positioning
reconstructed[b:b+1, :, y_start:y_start + patch_size,
x_start:x_start + patch_size] += weighted_patch
# Keep track of the total window weights at each pixel
counter[b:b+1, :, y_start:y_start + patch_size,
x_start:x_start + patch_size] += patch_window
# Normalize by counter to blend overlapping regions
reconstructed = reconstructed / (counter + 1e-8)
return reconstructed
# Example usage and testing
if __name__ == "__main__":
# Create sample tensor
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sample_tensor = torch.randn(1, 14, 512, 512, device=device)
# Parameters
patch_size = 256
overlap = 128
# Split into patches
patches, grid_size = split_tensor_to_patches(sample_tensor, patch_size, overlap)
print(f"Patches shape: {patches.shape}")
print(f"Grid size: {grid_size}")
    # Simulate a model output with a different channel count (14 -> 2 classes)
    modified_patches = patches[:, :2]
    print(f"Modified patches shape: {modified_patches.shape}")
# Reconstruct
reconstructed = reconstruct_from_patches(
modified_patches,
grid_size,
(1, 2, 512, 512),
patch_size,
overlap
)
print(f"Reconstructed shape: {reconstructed.shape}")
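    # Optional round-trip check: because overlapping patches come from the same
    # tensor, spline-weighted blending should reproduce the input (up to float
    # error) when the patches are fed back unmodified.
    roundtrip = reconstruct_from_patches(
        patches, grid_size, sample_tensor.shape, patch_size, overlap
    )
    print(f"Round-trip max error: {(roundtrip - sample_tensor).abs().max().item():.2e}")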
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Callable, Optional

import rasterio as rio
import torch
import torchvision.transforms.functional as F  # rotations/flips for optional TTA
from rasterio.windows import Window
from tqdm.auto import tqdm
def inference(
infile: str,
imgTransforms: Callable[[dict], dict],
model: torch.nn.Module,
outfile: str,
patchSize: int = 256,
overlap: int = 128,
num_workers: int = 1,
device: Optional[str] = None
) -> None:
"""
Run inference using model on infile block-by-block and write to a new file (outfile).
If the infile image width/height is not exactly divisible by 32, padding
is added for inference and removed prior to saving the outfile.
Args:
infile (str): Path to input image/covariates.
imgTransforms (Callable): Function to transform input images.
model (torch.nn.Module): Loaded trained model/checkpoint.
outfile (str): Path to save the predicted image.
patchSize (int): Must be a multiple of 32. Size independent of model input size.
overlap (int): Number of overlapping pixels between patches.
num_workers (int): Number of workers to parallelize across.
device (str, optional): Device to run the model on.
Returns:
None: A TIFF file is saved to the outfile destination.
# Example usage:
# infile = 'path/to/input.tif'
# imgTransforms = some_torchgeo_transforms_function
# model = some_loaded_pytorch_model
# outfile = 'path/to/output.tif'
# inference(infile, imgTransforms, model, outfile, patchSize=256, overlap=16, num_workers=4, device='cuda')
"""
# Open the input file using rasterio
with rio.open(infile) as src:
# Set up logging
logger = logging.getLogger(__name__)
# Create a destination dataset based on source parameters
profile = src.profile
        profile.update(blockxsize=patchSize, blockysize=patchSize, tiled=True,
                       count=1, dtype="uint8", compress="lzw")  # single-band class map; argmax labels fit in uint8
# Open the output file with the updated profile
with rio.open(Path(outfile), "w", **profile) as dst:
# Get all windows (patches) in the destination dataset
windows = [window for ij, window in dst.block_windows()]
# Create locks for reading and writing to ensure thread safety
read_lock = threading.Lock()
write_lock = threading.Lock()
def process(window: Window) -> None:
"""
Process a single window (patch) by reading it, transforming it, running the model on it,
and writing the result to the output file.
"""
# Acquire the read lock to safely read from the input file
with read_lock:
col_off = window.col_off - overlap
row_off = window.row_off - overlap
width = patchSize + overlap * 2
height = patchSize + overlap * 2
# Create a window with overlap
overlap_window = Window(
col_off=col_off,
row_off=row_off,
width=width,
height=height
)
# Read the data from the input file within the overlap window
src_array = src.read(boundless=True, window=overlap_window, fill_value = 0.0)
            src_array = torch.from_numpy(src_array)
            # Apply the image transformations; the result is expected to be a
            # batched tensor of shape [1, C, H, W] for split_tensor_to_patches
            image = imgTransforms({"image": src_array})["image"]
            def d4_transformations(image):
                # Generate the eight D4 (dihedral group) transformations
transformations = [
image, # Original
F.rotate(image, 90), # Rotate 90 degrees
F.rotate(image, 180), # Rotate 180 degrees
F.rotate(image, 270), # Rotate 270 degrees
F.hflip(image), # Horizontal flip
F.vflip(image), # Vertical flip
F.vflip(F.rotate(image, 90)), # Vertical flip + Rotate 90 degrees
F.hflip(F.rotate(image, 90)) # Horizontal flip + Rotate 90 degrees
]
return torch.stack(transformations)
def inverse_d4_transformations(outputs):
# Define the inverse transformations for D4
inverses = [
lambda x: x, # Original
lambda x: F.rotate(x, -90), # Rotate -90 degrees
lambda x: F.rotate(x, -180), # Rotate -180 degrees
lambda x: F.rotate(x, -270), # Rotate -270 degrees
lambda x: F.hflip(x), # Horizontal flip
lambda x: F.vflip(x), # Vertical flip
lambda x: F.rotate(F.vflip(x), -90), # Rotate -90 degrees after vertical flip
lambda x: F.rotate(F.hflip(x), -90) # Rotate -90 degrees after horizontal flip
]
# Apply each inverse transformation to the corresponding output
return torch.stack([inverses[i](output) for i, output in enumerate(outputs)])
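            # Note: for square patches, torch.rot90 and torch.flip implement the
            # same D4 transforms exactly via indexing (no resampling) and are
            # typically cheaper than F.rotate, e.g.
            # torch.rot90(x, k=1, dims=(-2, -1)) for a 90-degree rotation.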
# Apply D4 transformations to create a batch
# d4_batch = d4_transformations(image)
# Move batch to the device
# d4_batch = d4_batch.to(device, dtype=torch.float)
image = image.to(device, dtype=torch.float)
patches, grid_size = split_tensor_to_patches(image, patchSize, overlap)
# Set the model to evaluation mode
model.eval()
with torch.no_grad():
# Run the model on the padded image
# softmax = torch.nn.Softmax2d()
output = model(torch.nan_to_num(patches, nan=0.0, neginf=0.0, posinf=0.0))
# Apply the inverse transformations
# outputs = inverse_d4_transformations(output)
# Take the mean of the batch along the batch dimension
# output = outputs.mean(dim=0)
                # Blend the patch predictions back into a single tensor; the
                # padded window is patchSize + 2 * overlap pixels on each side
                output_tensor = reconstruct_from_patches(
                    output,
                    grid_size,
                    (1, output.shape[1], height, width),
                    patchSize,
                    overlap
                )
                # Crop back to the destination block (drops the overlap border;
                # edge blocks can be smaller than patchSize) and take the most
                # likely class per pixel
                result = torch.argmax(
                    output_tensor[:, :, overlap:overlap + int(window.height),
                                  overlap:overlap + int(window.width)],
                    dim=1,
                ).squeeze().detach().cpu()
# Acquire the write lock to safely write to the output file
with write_lock:
                dst.write(result.numpy().astype("uint8"), 1, window=window)
# Use a ThreadPoolExecutor to process the windows in parallel
with tqdm(total=len(windows), desc=os.path.basename(outfile)) as pbar:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {executor.submit(process, window): window for window in windows}
try:
for future in as_completed(futures):
future.result() # Wait for the future to complete
pbar.update(1) # Update the progress bar
except Exception as ex:
logger.error('Error during inference: %s', ex)
executor.shutdown(wait=False, cancel_futures=True)
raise ex
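For completeness, a hedged usage sketch: the paths, checkpoint, and normalization below are placeholders, and the transform only needs to accept and return a {"image": tensor} sample with a batch dimension.

def my_transforms(sample: dict) -> dict:
    # hypothetical normalization; swap in your torchgeo/Kornia pipeline
    image = sample["image"].unsqueeze(0).float() / 10000.0  # [C, H, W] -> [1, C, H, W]
    return {"image": image}

model = torch.load("checkpoint.pt", map_location="cuda")  # placeholder checkpoint
inference("input.tif", my_transforms, model, "output.tif",
          patchSize=256, overlap=128, num_workers=2, device="cuda")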