Skip to content

Retina Detection Method #28

@Oli4

Description

@Oli4

It is often useful to have a method that robustly detects the retina when no layer segmentations are available — for example, if you want to focus your visualization on the retina and therefore crop to this region. I used the code below for this purpose. It probably makes sense to evaluate its speed and robustness. It should be as device- and contrast-agnostic as possible.

def compute_retina_mask(image, threshold=2, max_iterations=10, upper=None, lower=None):
    """
    Create a binary mask covering the retina region of a 2D image.

    If a boundary is not supplied, it is estimated as follows:
    - A 5% band at the top and bottom of a copy of the image is zeroed.
    - NaNs are replaced by the image mean, and the image is Gaussian-smoothed.
    - In every column the two most prominent intensity peaks (``peak_finder``)
      serve as upper/lower boundary candidates.
    - A polynomial is fitted through the detected candidates with iterative
      outlier removal (``iteratively_remove_outliers``) and evaluated on
      every column.
    - The fitted upper boundary is shifted 10 rows up, and the fitted lower
      boundary 10 rows down.

    Args:
        image (array): 2D input image (rows x columns). It is not modified.
        threshold (int): Outlier threshold, in standard deviations of the fit
            residuals.
        max_iterations (int): Maximum number of outlier-removal iterations.
        upper (array, optional): Per-column upper boundary row. Used as-is
            (no margin added); estimated from the image if None.
        lower (array, optional): Per-column lower boundary row. Used as-is
            (no margin added); estimated from the image if None.
    Returns:
        array: Mask with the same shape and dtype as ``image``. Rows in
        ``[upper, lower)`` of each column are 1 and all others are 0.
        Boundaries are clipped to the image height.
    Raises:
        ValueError: If a boundary must be estimated but no column yielded two
            peaks.
    """
    logger.debug("Computing retina mask")
    image = np.copy(image)
    n_rows, n_cols = image.shape[0], image.shape[1]

    if upper is None or lower is None:
        # Zero a 5% band at the top and bottom to suppress noisy border
        # intensities. Guard against border == 0: image[-0:] selects the whole
        # image and would blank it.
        border = int(n_rows * 0.05)
        if border > 0:
            image[:border, :] = 0
            image[-border:, :] = 0

        logger.debug(f"Gaussian filtering: {image.dtype}")
        image[np.isnan(image)] = np.nanmean(image)
        image = ndimage.gaussian_filter(image, 3)

        logger.debug("Peak finder")
        # Shape (2, n_cols): per-column row index of the first/second peak, -1 if not found.
        result = np.apply_along_axis(peak_finder, 0, image)
        logger.debug("Peak finder done")

        logger.debug("Fitting polynomials")
        if upper is None:
            upper = _fit_boundary(result[0], n_cols, threshold, max_iterations) - 10
        if lower is None:
            lower = _fit_boundary(result[1], n_cols, threshold, max_iterations) + 10

    # Clip to the valid row range. A negative start index would wrap around and
    # select rows from the bottom of the image instead of starting at row 0.
    upper = np.clip(np.rint(upper).astype(int), 0, n_rows)
    lower = np.clip(np.rint(lower).astype(int), 0, n_rows)
    mask = np.zeros_like(image)

    for col in range(n_cols):
        mask[upper[col] : lower[col], col] = 1

    return mask


def _fit_boundary(positions, n_cols, threshold, max_iterations):
    """Fit a polynomial through the detected (!= -1) per-column positions and evaluate it on all columns."""
    detected = positions != -1
    if not np.any(detected):
        raise ValueError("No retina boundary candidates found in any column")
    # Fit only real detections (x = their column indices) rather than padding
    # missing columns with artificial values that would bias the fit.
    columns = np.flatnonzero(detected)
    coeffs = iteratively_remove_outliers(
        positions[detected], x=columns, threshold=threshold, max_iterations=max_iterations
    )[0]
    return np.poly1d(coeffs)(np.arange(n_cols))


def peak_finder(data, window_size=10, prominence=1 / 250):
    """
    Find the two most prominent peaks in a 1D signal.

    The signal is smoothed twice with a moving average of ``window_size``
    samples before ``scipy.signal.find_peaks`` is applied.

    Args:
        data (array): 1D input signal.
        window_size (int): Moving-average window length used for smoothing.
            Clamped to the signal length so returned indices stay inside
            ``data``.
        prominence (float): Minimum peak prominence passed to ``find_peaks``.
    Returns:
        array: Indices of the two most prominent peaks in ascending order, or
        ``[-1, -1]`` if fewer than two peaks are found.
    """
    data = np.asarray(data, dtype=float)
    if data.size == 0:
        return np.array([-1, -1])

    # np.convolve(mode="same") returns max(len(data), len(kernel)) samples. A
    # kernel longer than the signal would therefore yield indices past the
    # end of ``data``.
    window = max(1, min(int(window_size), data.size))
    kernel = np.ones(window) / window
    smoothed = np.convolve(data, kernel, mode="same")
    smoothed = np.convolve(smoothed, kernel, mode="same")

    peaks, properties = find_peaks(smoothed, prominence=prominence)
    if len(peaks) < 2:
        return np.array([-1, -1])

    top_two = np.argsort(properties["prominences"])[-2:]
    # ``peaks`` is in ascending position order, so sorting the selected indices
    # returns the two positions in ascending order.
    return peaks[np.sort(top_two)]


def remove_outliers(data, residuals, threshold=2):
    """
    Remove data points whose residual deviates strongly from the mean residual.

    A point is kept when ``|residual - mean(residuals)| < threshold * std(residuals)``.
    If all residuals are identical (zero standard deviation), every point is
    kept. The strict comparison against ``threshold * 0`` would otherwise
    discard the entire data set.

    Args:
        data (array): Input data array, indexable by a boolean mask.
        residuals (array): Residuals of the data (same length as ``data``).
        threshold (int): Threshold for removing outliers, in standard
            deviations.
    Returns:
        tuple: ``(data[mask], mask)``, where ``mask`` is True for kept points.
    """
    residuals = np.asarray(residuals)
    std_dev = np.std(residuals)
    if std_dev == 0:
        mask = np.ones(residuals.shape, dtype=bool)
    else:
        mask = np.abs(residuals - np.mean(residuals)) < threshold * std_dev
    return data[mask], mask


def iteratively_remove_outliers(data, x=None, threshold=3, max_iterations=5, degree=3):
    """
    Fit a polynomial, iteratively discarding outliers and refitting.

    A polynomial of ``degree`` is fitted to the points. Points flagged by
    ``remove_outliers`` are then dropped and the polynomial is refitted.
    Iteration stops in any of these cases:
    - no point is removed;
    - removal would leave at most ``degree`` points (the current fit is kept);
    - ``max_iterations`` removal rounds have been performed.

    Args:
        data (array): Input data array (y values).
        x (array, optional): X-axis values. Defaults to ``arange(len(data))``.
        threshold (int): Threshold for removing outliers, in standard
            deviations.
        max_iterations (int): Maximum number of outlier-removal iterations.
        degree (int): Polynomial degree passed to ``np.polyfit``.
    Returns:
        tuple: ``(coeffs, data, x)``. These are the polynomial coefficients
        (highest power first, as from ``np.polyfit``) together with the
        remaining points they were fitted to.
    """
    data = np.asarray(data)
    x = np.arange(len(data)) if x is None else np.asarray(x)

    # Fit up front so ``coeffs`` is defined even when max_iterations == 0.
    coeffs = np.polyfit(x, data, degree)
    for iteration in range(max_iterations):
        logger.debug(f"Fit iteration {iteration}, data length {len(data)}")
        residuals = data - np.poly1d(coeffs)(x)
        new_data, mask = remove_outliers(data, residuals, threshold)

        if len(new_data) == len(data):
            break  # converged: nothing left to remove
        if len(new_data) <= degree:
            break  # too few points for a determined fit; keep the current one

        data, x = new_data, x[mask]
        # Refit so the returned coefficients always describe the returned points.
        coeffs = np.polyfit(x, data, degree)

    return coeffs, data, x

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions