Skip to content

Latest commit

 

History

History
67 lines (57 loc) · 1.66 KB

README.md

File metadata and controls

67 lines (57 loc) · 1.66 KB

Examples

Pipeline Architecture

Train

from studiosr import Evaluator, Trainer
from studiosr.data import DIV2K
from studiosr.models import SwinIR

# Training configuration for 4x super-resolution on DIV2K.
# Fix: the original line had a trailing comma (`dataset_dir="...",`), which
# made `dataset_dir` a one-element tuple instead of a string.
dataset_dir = "path/to/dataset_dir"
scale = 4
size = 64
dataset = DIV2K(
    dataset_dir=dataset_dir,
    scale=scale,
    size=size,
    transform=True,  # data augmentations
    to_tensor=True,
    download=True,  # if you don't have the dataset
)
evaluator = Evaluator(scale=scale)

# Train with the Trainer's default configuration.
model = SwinIR(scale=scale)
trainer = Trainer(model, dataset, evaluator)
trainer.run()

# Train with the model's training configuration.
model = SwinIR(scale=scale)
config = model.get_training_config()
trainer = Trainer(model, dataset, evaluator, **config)
trainer.run()

Evaluate

from studiosr import Evaluator
from studiosr.models import SwinIR
from studiosr.utils import get_device

# Evaluate a pretrained SwinIR checkpoint on a standard benchmark dataset,
# reporting PSNR and SSIM.
scale = 2  # supported scales: 2, 3, 4
dataset = "Set5"  # one of: Set5, Set14, BSD100, Urban100, Manga109
device = get_device()

evaluator = Evaluator(dataset, scale=scale)
model = SwinIR.from_pretrained(scale=scale).eval().to(device)
psnr, ssim = evaluator(model.inference)

# Evaluation with self-ensemble
psnr, ssim = evaluator(model.inference_with_self_ensemble)

Benchmark

from studiosr import benchmark
from studiosr.models import RCAN, HAN, SwinIR, HAT
from studiosr.utils import get_device

# Run the benchmark suite for every pretrained model at every scale factor.
device = get_device()
model_classes = [RCAN, HAN, SwinIR, HAT]
scales = [2, 3, 4]
for model_class in model_classes:
    for scale in scales:
        net = model_class.from_pretrained(scale=scale).eval().to(device)
        print(f"Benchmark -> {model_class.__name__}")
        benchmark(net.inference, scale=scale)