Skip to content

A more efficient inference function + smoothing #87

@Geethen

Description

@Geethen

I was recently working on creating a memory-efficient and disk-space-friendly inference function. I settled on rasterio windowed reading and writing: it reads in a patch of data, performs inference on it, and writes the prediction to disk, repeating this patch by patch. This way you do not need to accumulate the predictions in memory.

Another related problem I ran into was tiling artifacts. I still haven't come across a perfect solution (it seems like there isn't one). There seem to be multiple options, including test-time augmentation (rotating and mirroring) and smoothing of predictions (especially important for segmentation tasks).

See the code below for the implementation. The test-time augmentation code is included but commented out; the tradeoff is longer inference time. The smoothing ideas are from this repo: https://github.com/Vooban/Smoothly-Blend-Image-Patches.

Here is an older version of the code that I was putting together for a blog post but did not complete: https://share.note.sx/pkm19n3b#4IGBUHzyXiCfr7FaOBTv9vI4Or0nlJWji9//9gTFMhI

Source code

import numpy as np
import torch
from scipy.signal.windows import triang  # Import the correct function
import matplotlib.pyplot as plt  # for optional visualization

# Spline window function
def _spline_window(window_size, power=2):
    """
    Squared spline window function for smooth transition.
    """
    intersection = int(window_size / 4)
    wind_outer = (abs(2 * triang(window_size)) ** power) / 2  # Use correct function
    wind_outer[intersection:-intersection] = 0

    wind_inner = 1 - (abs(2 * (triang(window_size) - 1)) ** power) / 2
    wind_inner[:intersection] = 0
    wind_inner[-intersection:] = 0

    wind = wind_inner + wind_outer
    wind = wind / np.average(wind)
    return wind

# Module-level cache of 2D NumPy windows, keyed by "<window_size>_<power>".
cached_2d_windows = {}

def _window_2D(window_size, power=2):
    """
    Return the 2D spline window as a float32 tensor.

    The 2D window is the outer product of the 1D spline window with
    itself. The NumPy array is stored in ``cached_2d_windows`` under the
    key ``"<window_size>_<power>"`` so it is only computed once per
    parameter pair; a new tensor is created from it on every call.

    Args:
        window_size: Side length of the square window.
        power: Exponent forwarded to ``_spline_window``.

    Returns:
        A tensor of shape ``(1, 1, window_size, window_size)``.
    """
    cache_key = "{}_{}".format(window_size, power)
    window_np = cached_2d_windows.get(cache_key)
    if window_np is None:
        profile_1d = _spline_window(window_size, power)
        window_np = np.outer(profile_1d, profile_1d)
        cached_2d_windows[cache_key] = window_np

    as_tensor = torch.tensor(window_np, dtype=torch.float32)
    # Prepend batch and channel axes so the window broadcasts over patches.
    return as_tensor[None, None]

def split_tensor_to_patches(tensor, patch_size=256, overlap=128):
    """
    Cut a 4D tensor into overlapping square patches with ``Tensor.unfold``.

    Patches are taken every ``patch_size - overlap`` pixels along
    dimension 2 and dimension 3. Positions where a full patch no longer
    fits are not produced.

    Args:
        tensor: Input of shape ``(batch, channels, H, W)``.
        patch_size: Side length of each patch.
        overlap: Number of pixels shared by neighbouring patches.

    Returns:
        A tuple ``(patches, (h_steps, w_steps))`` where ``patches`` has
        shape ``(batch * h_steps * w_steps, channels, patch_size,
        patch_size)``, ordered batch-major, then by grid row, then by
        grid column.
    """
    step = patch_size - overlap

    # Slide along dimension 2 first, then dimension 3; each unfold appends
    # the in-patch axis as a new trailing dimension.
    rows = tensor.unfold(2, patch_size, step)
    grid = rows.unfold(3, patch_size, step)
    _, n_channels, h_steps, w_steps, _, _ = grid.shape

    # Bring the grid axes next to the batch axis so they can be flattened.
    grid = grid.permute(0, 2, 3, 1, 4, 5).contiguous()
    flat = grid.view(-1, n_channels, patch_size, patch_size)
    return flat, (h_steps, w_steps)

def reconstruct_from_patches(patches, grid_size, output_size, patch_size=256, overlap=128, power=2):
    """
    Reconstruct a tensor from overlapping patches using spline-window blending.

    Every patch is multiplied by a 2D spline window and accumulated at its
    grid position; the accumulated values are then divided by the
    accumulated window weights, so overlapping regions become a weighted
    average of the contributing patches.

    Args:
        patches: Tensor of shape ``(batch * h_steps * w_steps, channels,
            patch_size, patch_size)``, ordered batch-major, then grid row,
            then grid column.
        grid_size: ``(h_steps, w_steps)`` number of patches along each axis.
        output_size: ``(batch, channels, height, width)`` of the result.
        patch_size: Side length of each patch.
        overlap: Number of pixels shared by neighbouring patches.
        power: Exponent of the spline window.

    Returns:
        A tensor of shape ``output_size`` on the same device as ``patches``.
        Pixels not covered by any patch stay 0.
    """
    batch_size, out_channels, _height, _width = output_size
    h_steps, w_steps = grid_size
    stride = patch_size - overlap

    # Accumulators for the weighted predictions and for the weights themselves.
    reconstructed = torch.zeros(output_size, device=patches.device)
    counter = torch.zeros_like(reconstructed)

    # The per-channel window is identical for every patch, so build it once
    # instead of repeating it inside the inner loop.
    window_2d = _window_2D(patch_size, power=power).to(patches.device)
    patch_window = window_2d.repeat(1, out_channels, 1, 1)

    patches_per_batch = h_steps * w_steps

    for b in range(batch_size):
        batch_patches = patches[b * patches_per_batch:(b + 1) * patches_per_batch]

        for idx in range(patches_per_batch):
            # Row-major grid position of this patch.
            i, j = divmod(idx, w_steps)
            y_start = i * stride
            x_start = j * stride

            # [1, C, patch_size, patch_size] so it lines up with the target slice.
            current_patch = batch_patches[idx].unsqueeze(0)

            reconstructed[b:b + 1, :, y_start:y_start + patch_size,
                          x_start:x_start + patch_size] += current_patch * patch_window

            # Track the total window weight received by each pixel.
            counter[b:b + 1, :, y_start:y_start + patch_size,
                    x_start:x_start + patch_size] += patch_window

    # Normalise by the accumulated weights; the epsilon avoids division by
    # zero where no patch contributed.
    reconstructed = reconstructed / (counter + 1e-8)

    return reconstructed

# Example usage and testing
if __name__ == "__main__":
    # Use the default CUDA device when available. A hard-coded index such as
    # "cuda:1" fails on machines that expose only a single GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sample_tensor = torch.randn(1, 14, 512, 512, device=device)

    # Parameters
    patch_size = 256
    overlap = 128

    # Split into patches
    patches, grid_size = split_tensor_to_patches(sample_tensor, patch_size, overlap)
    print(f"Patches shape: {patches.shape}")
    print(f"Grid size: {grid_size}")

    # Simulate a model that changes the channel dimension (14 -> 2)
    out_channels = 2
    modified_patches = patches[:, :out_channels]
    print(f"Modified patches shape: {modified_patches.shape}")

    # Reconstruct at the spatial size of the input rather than a
    # hard-coded shape, so the demo stays valid if the sample changes.
    batch_size, _, height, width = sample_tensor.shape
    reconstructed = reconstruct_from_patches(
        modified_patches,
        grid_size,
        (batch_size, out_channels, height, width),
        patch_size,
        overlap
    )
    print(f"Reconstructed shape: {reconstructed.shape}")

import logging
import os
import math
from pathlib import Path
from typing import Callable, Optional

import rasterio as rio
import torch
# import torch.nn.functional as F
import torchvision.transforms.functional as F

import torchvision.transforms as transforms
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.auto import tqdm
from rasterio.windows import Window

import os
import ee
import threading

def inference(
    infile: str,
    imgTransforms: Callable[[dict], dict],
    model: torch.nn.Module,
    outfile: str,
    patchSize: int = 256,
    overlap: int = 128,
    num_workers: int = 1,
    device: Optional[str] = None
) -> None:
    """
    Run ``model`` over ``infile`` block by block and write class indices to ``outfile``.

    The output raster is tiled with ``patchSize`` x ``patchSize`` blocks. For
    every output block, a region padded by ``overlap`` pixels on each side is
    read from the input (areas outside the raster are filled with 0), split
    into overlapping ``patchSize`` patches, passed through the model, blended
    back together with a spline window, and the central block is reduced
    with ``argmax`` over the channel axis and written to band 1. Blocks on
    the right/bottom edge are cropped to the raster extent before writing.

    Args:
        infile (str): Path to input image/covariates.
        imgTransforms (Callable): Receives ``{"image": tensor}`` and must
            return a dict whose ``"image"`` entry is a ``(C, H, W)`` or
            ``(1, C, H, W)`` tensor.
        model (torch.nn.Module): Loaded trained model/checkpoint. It is put
            into eval mode; it is NOT moved to ``device``.
        outfile (str): Path to save the predicted image.
        patchSize (int): Patch size fed to the model and output tile size.
            The original author requires a multiple of 32; this is not
            validated here.
        overlap (int): Number of overlapping pixels between patches. Must
            satisfy ``0 <= overlap < patchSize``. The padded region is only
            fully tiled when ``2 * overlap`` is a multiple of
            ``patchSize - overlap``; otherwise a warning is logged.
        num_workers (int): Number of worker threads.
        device (str, optional): Device the input patches are moved to. When
            ``None``, the device of the model's first parameter is used (if
            the model has parameters).

    Returns:
        None: A TIFF file is saved to the outfile destination.

    Raises:
        ValueError: If ``overlap`` is negative or not smaller than ``patchSize``.

    # Example usage:
        # infile = 'path/to/input.tif'
        # imgTransforms = some_torchgeo_transforms_function
        # model = some_loaded_pytorch_model
        # outfile = 'path/to/output.tif'
        # inference(infile, imgTransforms, model, outfile, patchSize=256, overlap=128, num_workers=4, device='cuda')
    """
    logger = logging.getLogger(__name__)

    # Fail before the output file is created rather than inside a worker.
    if not 0 <= overlap < patchSize:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < patchSize "
            f"(got overlap={overlap}, patchSize={patchSize})"
        )

    # The padded region has side patchSize + 2 * overlap and is tiled with
    # stride patchSize - overlap; it is fully covered only when the extra
    # 2 * overlap pixels are a whole number of strides.
    stride = patchSize - overlap
    if (2 * overlap) % stride != 0:
        logger.warning(
            "2 * overlap (%d) is not a multiple of patchSize - overlap (%d); "
            "part of each padded region is not covered by any patch and is "
            "reconstructed as zeros.",
            2 * overlap, stride,
        )

    padded_size = patchSize + 2 * overlap

    # Resolve the device once. With device=None, follow the model so inputs
    # and weights end up on the same device.
    run_device = device
    if run_device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            run_device = first_param.device

    # Optional test-time-augmentation helpers. They are currently unused:
    # to enable TTA, stack d4_transformations(...) into a batch, run the
    # model on it, apply inverse_d4_transformations(...) to the outputs and
    # average over the batch dimension (at the cost of longer inference).
    def d4_transformations(image):
        # Generate the eight D4 transformations
        transformations = [
            image,                                # Original
            F.rotate(image, 90),                  # Rotate 90 degrees
            F.rotate(image, 180),                 # Rotate 180 degrees
            F.rotate(image, 270),                 # Rotate 270 degrees
            F.hflip(image),                       # Horizontal flip
            F.vflip(image),                       # Vertical flip
            F.vflip(F.rotate(image, 90)),         # Vertical flip + Rotate 90 degrees
            F.hflip(F.rotate(image, 90))          # Horizontal flip + Rotate 90 degrees
        ]
        return torch.stack(transformations)

    def inverse_d4_transformations(outputs):
        # Inverse of each D4 transformation, in the same order as above
        inverses = [
            lambda x: x,                          # Original
            lambda x: F.rotate(x, -90),           # Rotate -90 degrees
            lambda x: F.rotate(x, -180),          # Rotate -180 degrees
            lambda x: F.rotate(x, -270),          # Rotate -270 degrees
            lambda x: F.hflip(x),                 # Horizontal flip
            lambda x: F.vflip(x),                 # Vertical flip
            lambda x: F.rotate(F.vflip(x), -90),  # Rotate -90 degrees after vertical flip
            lambda x: F.rotate(F.hflip(x), -90)   # Rotate -90 degrees after horizontal flip
        ]
        return torch.stack([inverses[i](output) for i, output in enumerate(outputs)])

    # Open the input file using rasterio
    with rio.open(infile) as src:
        # Create a destination dataset based on source parameters
        profile = src.profile
        profile.update(blockxsize=patchSize, blockysize=patchSize, tiled=True, count=1, compress='lzw')

        # Open the output file with the updated profile
        with rio.open(Path(outfile), "w", **profile) as dst:
            # One window per output tile
            windows = [window for _, window in dst.block_windows()]

            # Locks guarding access to the shared dataset handles
            read_lock = threading.Lock()
            write_lock = threading.Lock()

            # Eval mode does not depend on the window, so set it once.
            model.eval()

            def process(window: Window) -> None:
                """
                Read one padded window, run the model on it, and write the
                central block to the output file.
                """
                # NOTE: as in the original implementation, the lock is held
                # for the read AND the inference, so inference is serialized
                # across workers; only writes overlap with it.
                with read_lock:
                    # Window grown by `overlap` pixels on every side
                    overlap_window = Window(
                        col_off=window.col_off - overlap,
                        row_off=window.row_off - overlap,
                        width=padded_size,
                        height=padded_size
                    )

                    # boundless=True pads regions outside the raster with fill_value
                    src_array = src.read(boundless=True, window=overlap_window, fill_value=0.0)
                    src_array = torch.from_numpy(src_array)

                    # Apply the image transformations
                    image = imgTransforms({"image": src_array})['image']

                    # split_tensor_to_patches needs (batch, C, H, W)
                    if image.dim() == 3:
                        image = image.unsqueeze(0)

                    image = image.to(device=run_device, dtype=torch.float)

                    patches, grid_size = split_tensor_to_patches(image, patchSize, overlap)

                    with torch.no_grad():
                        # Replace NaN/inf before they reach the model
                        output = model(torch.nan_to_num(patches, nan=0.0, neginf=0.0, posinf=0.0))

                    # Size the reconstruction from the actual tensors rather
                    # than assuming 2 classes and a 512 x 512 region.
                    output_tensor = reconstruct_from_patches(
                        output,
                        grid_size,
                        (image.shape[0], output.shape[1], image.shape[-2], image.shape[-1]),
                        patchSize,
                        overlap
                    )

                    # Drop the overlap margin, then pick the best class per pixel
                    core = output_tensor[:, :, overlap:patchSize + overlap, overlap:patchSize + overlap]
                    result = torch.argmax(core, dim=1).squeeze(0).detach().cpu()

                    # Edge tiles are smaller than patchSize: keep only the
                    # part that lies inside the raster.
                    result = result[:int(window.height), :int(window.width)]

                # Acquire the write lock to safely write to the output file
                with write_lock:
                    dst.write(result.numpy(), 1, window=window)

            # Use a ThreadPoolExecutor to process the windows in parallel
            with tqdm(total=len(windows), desc=os.path.basename(outfile)) as pbar:
                with ThreadPoolExecutor(max_workers=num_workers) as executor:
                    futures = {executor.submit(process, window): window for window in windows}

                    try:
                        for future in as_completed(futures):
                            future.result()  # Re-raises any worker exception
                            pbar.update(1)  # Update the progress bar
                    except Exception as ex:
                        logger.error('Error during inference: %s', ex)
                        # Drop queued windows; running ones finish on their own
                        executor.shutdown(wait=False, cancel_futures=True)
                        raise

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions