Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 52 additions & 22 deletions micro_sam/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,25 +593,39 @@ def get_model_names() -> Iterable:
#


def _to_image(input_):
# we require the input to be uint8
if input_.dtype != np.dtype("uint8"):
# first normalize the input to [0, 1]
input_ = input_.astype("float32") - input_.min()
input_ = input_ / input_.max()
# then bring to [0, 255] and cast to uint8
input_ = (input_ * 255).astype("uint8")

if input_.ndim == 2:
image = np.concatenate([input_[..., None]] * 3, axis=-1)
elif input_.ndim == 3 and input_.shape[-1] == 3:
image = input_
else:
raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image.")
def _normalize_channel(input_, min_val=None, max_val=None):
# First normalize the input to [0, 1].
input_ = input_.astype("float32")
min_val = np.percentile(input_, 1) if min_val is None else min_val
input_ = input_ - min_val
max_val = np.percentile(input_, 99) if max_val is None else max_val
input_ = input_ / (max_val + 1e-7)
# Then bring it to [0, 255] and cast to uint8.
input_ = (np.clip(input_, 0, 1) * 255).astype("uint8")
return input_


def _to_image(input_, min_=None, max_=None):
# Explicitly return a numpy array for compatibility with torchvision,
# because the input_ array could be something like dask array.
image = np.array(input_)

if image.ndim == 2:
image = image[..., None]
elif image.ndim != 3 or (image.shape[-1] not in (1, 3)):
raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either grayscale or RGB image.")

image_normalized = np.zeros(image.shape, dtype="uint8")
for c in range(image.shape[2]):
min_val = None if min_ is None else min_[c]
max_val = None if max_ is None else max_[c]
image_normalized[..., c] = _normalize_channel(image[..., c], min_val=min_val, max_val=max_val)

# explicitly return a numpy array for compatibility with torchvision
# because the input_ array could be something like dask array
return np.array(image)
if image_normalized.shape[-1] == 1:
image_normalized = np.concatenate([image_normalized] * 3, axis=-1)
assert image_normalized.shape[2] == 3, f"{image_normalized.shape}"

return image_normalized


@torch.no_grad
Expand Down Expand Up @@ -690,6 +704,7 @@ def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init
pbar_init(n_tiles, "Compute Image Embeddings 2D tiled")

n_batches = int(np.ceil(n_tiles / batch_size))
input_ = _to_image(input_)
for batch_id in range(n_batches):
tile_start = batch_id * batch_size
tile_stop = min(tile_start + batch_size, n_tiles)
Expand All @@ -698,7 +713,7 @@ def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init
for tile_id in range(tile_start, tile_stop):
tile = tiling.getBlockWithHalo(tile_id, list(halo))
outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end))
tile_input = _to_image(input_[outer_tile])
tile_input = input_[outer_tile]
batched_images.append(tile_input)

batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
Expand All @@ -715,6 +730,18 @@ def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init
return features


def _precompute_stats(input_):
stats = {}
input_ = input_[..., None] if input_.ndim == 3 else input_
assert input_.ndim == 4
for z in range(input_.shape[0]):
min_ = {c: np.percentile(input_[z, ..., c], 1) for c in range(input_.shape[3])}
# TODO double check
max_ = {c: np.percentile(input_[z, ..., c], 99) - min_[c] for c in range(input_.shape[3])}
stats[z] = {"min": min_, "max": max_}
return stats


def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size):
assert input_.ndim == 3

Expand All @@ -733,6 +760,9 @@ def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init
# We batch across the z axis.
n_batches = int(np.ceil(n_slices / batch_size))

# Precompute min and max for each slice.
stats = _precompute_stats(input_)

for tile_id in range(n_tiles):
tile = tiling.getBlockWithHalo(tile_id, list(halo))
outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end))
Expand All @@ -744,7 +774,7 @@ def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init

batched_images = []
for z in range(z_start, z_stop):
tile_input = _to_image(input_[z][outer_tile])
tile_input = _to_image(input_[z][outer_tile], min_=stats[z]["min"], max_=stats[z]["max"])
batched_images.append(tile_input)

batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
Expand Down Expand Up @@ -858,8 +888,8 @@ def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_u
# Skip feature computation in case of partial features in non-zero slice.
if partial_features and np.count_nonzero(features[z]) != 0:
continue
tile_input = _to_image(input_[z])
batched_images.append(tile_input)
batch_input = _to_image(input_[z])
batched_images.append(batch_input)
batched_z.append(z)

batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
Expand Down
Loading