virtual_stain_flow is a framework for the reproducible development and training of image-to-image translation models that enable virtual staining (the prediction of "virtual" stains) from label-free microscopy images.
The package provides comprehensive experiment tracking that spans the entire model development workflow, from dataset construction and augmentation through model customization and training.
Available model architectures:

- U-Net – A classical encoder–decoder architecture with skip connections. This lightweight design has been widely used for virtual staining tasks, as demonstrated by Ounkomol et al., 2018.
- wGAN-GP – A Wasserstein GAN with gradient penalty. This generative adversarial setup combines a U-Net generator with a convolutional discriminator regularized via gradient penalty for stable training. As shown by Cross-Zamirski et al., 2022, adversarial training enhances the realism of synthetic stains.
- ConvNeXt-UNet – A fully convolutional architecture inspired by recent computer vision advances. Drawing from Liu et al., 2022 and Liu et al., 2025, this variant incorporates transformer-like architectural refinements to improve the fidelity of virtual staining details, at the cost of higher computational demand.
*Prediction (DNA, Hoechst 33342) generated by virtual_stain_flow using the ConvNeXt-UNet model, from brightfield microscopy images of the U2-OS cell line.*
- datasets/ - Data loading and preprocessing pipelines
- models/ - Virtual staining models and building blocks
- trainers/ - Training loops
- transforms/ - Image normalization and augmentation
- vsf_logging/ - Experiment tracking and logging
Check out the examples/ directory for complete training scripts and tutorials demonstrating various use cases and configurations.
Defining a `UNet` model that generates 3 target staining channels from 1 input channel (phase here).

- Here the compute block (`comp_block`) is set to the conventional Conv2D > Normalize > ReLU. Alternative compute blocks are available in this package, including the `Conv2DConvNeXtBlock` (see the variant sketch after the model definition below).
```python
from virtual_stain_flow.models import UNet, Conv2DNormActBlock

model = UNet(
    input_channels=1,
    output_channels=3,
    comp_block=Conv2DNormActBlock,
    _num_units=2,
    depth=4
)
```
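Swapping in the ConvNeXt-style compute block mentioned above yields the ConvNeXt-UNet variant without changing the rest of the specification. A minimal sketch, assuming `Conv2DConvNeXtBlock` is importable from `virtual_stain_flow.models` alongside `Conv2DNormActBlock` (the import location is an assumption):

```python
# Sketch: same U-Net topology, but with ConvNeXt-style compute blocks
# (finer virtual staining detail at higher computational cost).
from virtual_stain_flow.models import UNet, Conv2DConvNeXtBlock  # import path assumed

convnext_model = UNet(
    input_channels=1,
    output_channels=3,
    comp_block=Conv2DConvNeXtBlock,
    _num_units=2,
    depth=4
)
```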
Building a dataset tailored to the `UNet` model specification.

- The input channel is configured as `['phase']`, which matches the `input_channels=1` of the model specification.
- Likewise, the target channels are configured as `["dapi", "tubulin", "actin"]` to match the `output_channels=3` model setting.
- Note these channel keys must exist in the supplied `file_index_df` as columns of image filepaths (a hypothetical example follows this list).
- Specify the appropriate post-processing to match the image bit depth and/or the model output activation. By default, all models apply a sigmoid activation to their output, and a max-scale normalization (dividing by the maximum pixel value) is used.
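For illustration only, a `file_index_df` satisfying these requirements might be assembled as below; the filenames and layout are hypothetical, not a format prescribed by the package:

```python
import pandas as pd

# Hypothetical file index: one row per field of view, and one column
# of image filepaths per channel key used by the dataset.
file_index_df = pd.DataFrame({
    "phase":   ["fov_000_phase.tif",   "fov_001_phase.tif"],
    "dapi":    ["fov_000_dapi.tif",    "fov_001_dapi.tif"],
    "tubulin": ["fov_000_tubulin.tif", "fov_001_tubulin.tif"],
    "actin":   ["fov_000_actin.tif",   "fov_001_actin.tif"],
})
```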
```python
from virtual_stain_flow.datasets import BaseImageDataset
from virtual_stain_flow.transforms import MaxScaleNormalize

dataset = BaseImageDataset(
    file_index=file_index_df,
    input_channel_keys="phase",
    target_channel_keys=["dapi", "tubulin", "actin"],
    transform=MaxScaleNormalize(normalization_factor='16bit')
)
```
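As a quick sanity check, and assuming the dataset follows the usual map-style convention of returning an (input, target) pair per index (an assumption, not a documented guarantee), the configured channels can be verified like so:

```python
# Fetch one sample and inspect its shapes (map-style indexing assumed).
input_image, target_image = dataset[0]
print(input_image.shape)   # expect 1 channel: phase
print(target_image.shape)  # expect 3 channels: dapi, tubulin, actin
```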
Training the model on the dataset, with experiment tracking handled by the MLflow logger:

```python
from virtual_stain_flow.trainers import Trainer
from virtual_stain_flow.vsf_logging import MlflowLogger

logger = MlflowLogger(experiment_name="virtual_staining")
trainer = Trainer(model, dataset)
trainer.train(logger=logger)
```
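Assuming `MlflowLogger` writes to a standard local MLflow tracking store (an assumption about the backend configuration), the logged runs can afterwards be browsed by running `mlflow ui` from the directory containing the store.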
Install with pip:

```bash
pip install virtual-stain-flow
```

See the LICENSE file for full details.