
virtual_stain_flow - For developing virtual staining models

Overview

virtual_stain_flow is a framework for the reproducible development and training of image-to-image translation models that enable virtual staining (the prediction of "virtual" stains) from label-free microscopy images.
The package provides comprehensive experiment tracking that spans the entire model development workflow, from dataset construction and augmentation through model customization to training.


Supported Model Architectures

  • U-Net – A classical encoder–decoder architecture with skip connections.
    This lightweight design has been widely used for virtual staining tasks, as demonstrated by Ounkomol et al., 2018.

  • WGAN-GP – A Wasserstein GAN with gradient penalty.
    This generative adversarial setup combines a U-Net generator with a convolutional discriminator regularized via gradient penalty for stable training.
    As shown by Cross-Zamirski et al., 2022, adversarial training enhances the realism of synthetic stains.

  • ConvNeXt-UNet – A fully convolutional architecture inspired by recent computer vision advances.
    Drawing from Liu et al., 2022 and Liu et al., 2025, this variant incorporates transformer-like architectural refinements to improve the fidelity of virtual staining details, at the cost of higher computational demand.

Example model prediction

[Figure: input / target / prediction comparison] Prediction (DNA, Hoechst 33342) generated by virtual_stain_flow using the ConvNeXt-UNet model from brightfield microscopy images of the U2-OS cell line.

Core Components

The package is organized into subpackages, each used in the Quick Start below:

  • models – model architectures and compute blocks (e.g., UNet, Conv2DNormActBlock, Conv2DConvNeXtBlock).
  • datasets – dataset construction from image file indexes (e.g., BaseImageDataset).
  • transforms – image pre- and post-processing (e.g., MaxScaleNormalize).
  • trainers – training loops (e.g., Trainer).
  • vsf_logging – experiment tracking (e.g., MlflowLogger).

Quick Start

Check out the examples/ directory for complete training scripts and tutorials demonstrating various use cases and configurations.

Defining a UNet model that generates staining for 3 target channels from 1 input channel (phase, in this example).

  • Here the compute block (comp_block) is set to the conventional Conv2D > Normalize > ReLU block. Alternative compute blocks are available in this package, including the Conv2DConvNeXtBlock; a sketch using it follows the code below.
from virtual_stain_flow.models import UNet, Conv2DNormActBlock

model = UNet(
    input_channels=1,   # one label-free input channel (phase)
    output_channels=3,  # three predicted stain channels
    comp_block=Conv2DNormActBlock,  # Conv2D > Normalize > ReLU compute block
    _num_units=2,
    depth=4
)
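
A minimal sketch of the ConvNeXt-UNet variant mentioned above, assuming Conv2DConvNeXtBlock accepts the same constructor arguments as Conv2DNormActBlock (check the package documentation to confirm):

from virtual_stain_flow.models import UNet, Conv2DConvNeXtBlock

# Sketch only: swap the compute block to obtain the ConvNeXt-UNet variant,
# assuming the remaining constructor arguments carry over unchanged.
convnext_model = UNet(
    input_channels=1,
    output_channels=3,
    comp_block=Conv2DConvNeXtBlock,
    _num_units=2,
    depth=4
)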

Building a dataset tailored to the UNet model specification.

  • The input channel is configured as "phase", which matches the model's input_channels=1.
  • Likewise, the target channels are configured as ["dapi", "tubulin", "actin"] to match the model's output_channels=3.
  • Note that these channel keys must exist in the supplied file_index_df as columns of image filepaths (see the sketch after the code below).
  • Specify post-processing appropriate to the image bit depth and/or the model's output activation. By default, all models apply a sigmoid output activation, and a max-scale normalization (dividing by the maximum pixel value) is used.
from virtual_stain_flow.datasets import BaseImageDataset
from virtual_stain_flow.transforms import MaxScaleNormalize

dataset = BaseImageDataset(
    file_index=file_index_df,  # one filepath column per channel key
    input_channel_keys="phase",  # matches input_channels=1
    target_channel_keys=["dapi", "tubulin", "actin"],  # matches output_channels=3
    transform=MaxScaleNormalize(normalization_factor='16bit')  # scale by the 16-bit max pixel value
)
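
For illustration, a hypothetical file_index_df might look like the following; the column names are the channel keys and the paths are placeholders:

import pandas as pd

# Hypothetical file index (paths are placeholders): one row per imaging site,
# one column of image filepaths per channel key.
file_index_df = pd.DataFrame({
    "phase":   ["images/site01_phase.tiff", "images/site02_phase.tiff"],
    "dapi":    ["images/site01_dapi.tiff", "images/site02_dapi.tiff"],
    "tubulin": ["images/site01_tubulin.tiff", "images/site02_tubulin.tiff"],
    "actin":   ["images/site01_actin.tiff", "images/site02_actin.tiff"],
})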

Training the model with MLflow-backed experiment tracking.

from virtual_stain_flow.trainers import Trainer
from virtual_stain_flow.vsf_logging import MlflowLogger

logger = MlflowLogger(experiment_name="virtual_staining")
trainer = Trainer(model, dataset)
trainer.train(logger=logger)
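
After training, a minimal inference sketch (continuing from the snippets above) might look like the following; it assumes the model is a standard PyTorch module and that the dataset yields (input, target) tensor pairs, which should be verified against the package's actual interfaces:

import torch

model.eval()  # assumes model is a torch.nn.Module
with torch.no_grad():
    input_image, target = dataset[0]  # assumed (input, target) tensor pair
    prediction = model(input_image.unsqueeze(0))  # add a batch dimension
print(prediction.shape)  # e.g., (1, 3, H, W): one image, three stain channels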

Installation

pip install virtual-stain-flow

License

See the LICENSE file for full details.
