Skip to content

Commit 8b4f94d

Browse files
committed
increased tests
1 parent fc75331 commit 8b4f94d

4 files changed

Lines changed: 192 additions & 1 deletion

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,6 @@ If you use this work, please cite:
134134

135135
Developed within the **ESA Φ-lab / OpenSR** initiative. Simon Donike is the main contributor and maintainer of the repository. Cesar Aybar and Julio Contreras contributed the datasets as well as implementation, documentation and publishing support. Prof. Luis Gómez-Chova contributed the remote sensing-specific perspective and signal processing advice.
136136
> The development history of this code began in 2020 with the implementation of an SR-GAN for a MSc thesis project. Since then, over several iterations, the codebase has been expanded and many training tweaks implemented, based on the experiences made training SR-GANs for the OpenSR project. The fundamental training outline, training tweaks, normalizations, and inference procedures are built upon that experience.
137-
The added complexity that came with (a) the implementation of many different models and blocks, (b) more data sources, (c) according normalizations, and (d) complex testing and documentation structures, was handled to varying degrees with the help of *Codex*. Specifically, the docs, the automated testing workflows, and the normalizer class are in part AI generated. This code and its functionalities have been verified and tested to the best of my ability.
137+
The added complexity that came with (a) the implementation of many different models and blocks, (b) more data sources, (c) according normalizations, and (d) testing and documentation structures, was handled to varying degrees with the help of *Codex*. Specifically, the docs, the automated testing workflows, and the normalizer class are in (large) parts AI generated. This code and its functionalities have been verified and tested to the best of my ability.
138138

139139
---

docs/getting-started.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22

33
This guide walks through installing dependencies, configuring datasets, and launching your first ESA OpenSR experiment. The stack supports Python 3.10-3.12, PyTorch Lightning, and Weights & Biases for experiment tracking.
44

5+
## Try it in Colab first
6+
7+
For the fastest start, open the interactive notebook in Google Colab and run through the introduction without setting up a local environment.
8+
9+
<p align="center">
10+
<a href="https://colab.research.google.com/drive/16W0FWr6py1J8P4po7JbNDMaepHUM97yL?usp=sharing">
11+
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab">
12+
</a>
13+
</p>
14+
515
> 💡 **Only need inference?** Install the published package instead: `python -m pip install opensr-srgan`. It exposes `load_from_config` and `load_inference_model` so you can instantiate models without cloning the repository. Continue with the rest of this guide when you want to train, fine-tune, or otherwise modify the codebase.
616
717
## 1. Install the environment

tests/test_model_blocks/test_ema.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,116 @@ def test_invalid_decay_raises():
8383
model = TinyNet()
8484
with pytest.raises(ValueError):
8585
ExponentialMovingAverage(model, decay=1.5)
86+
87+
88+
def test_update_with_num_updates_uses_warmup_decay_and_updates_buffer():
89+
model = TinyNet()
90+
ema = ExponentialMovingAverage(model, decay=0.9, use_num_updates=True)
91+
old_shadow = ema.shadow_params["lin.weight"].clone()
92+
93+
with torch.no_grad():
94+
model.lin.weight.add_(1.0)
95+
model.scale.fill_(3.0)
96+
97+
ema.update(model)
98+
99+
expected_decay = min(0.9, (1 + 1) / (10 + 1))
100+
expected = old_shadow.lerp(model.lin.weight.detach(), 1.0 - expected_decay)
101+
assert ema.num_updates == 1
102+
assert torch.allclose(ema.shadow_params["lin.weight"], expected)
103+
assert torch.allclose(ema.shadow_buffers["scale"], torch.tensor([3.0]))
104+
105+
106+
def test_register_and_update_skip_frozen_parameters():
107+
model = TinyNet()
108+
model.lin.bias.requires_grad_(False)
109+
110+
ema = ExponentialMovingAverage(model, decay=0.5, use_num_updates=False)
111+
assert set(ema.shadow_params) == {"lin.weight"}
112+
113+
with torch.no_grad():
114+
model.lin.bias.add_(10.0)
115+
ema.update(model)
116+
117+
assert "lin.bias" not in ema.shadow_params
118+
119+
120+
def test_update_registers_new_trainable_parameters_and_buffers():
121+
model = TinyNet()
122+
ema = ExponentialMovingAverage(model, decay=0.5, use_num_updates=False)
123+
124+
model.extra = nn.Parameter(torch.full((1,), 4.0))
125+
model.register_buffer("offset", torch.full((1,), 2.0))
126+
127+
ema.update(model)
128+
129+
assert torch.allclose(ema.shadow_params["extra"], torch.tensor([4.0]))
130+
assert torch.allclose(ema.shadow_buffers["offset"], torch.tensor([2.0]))
131+
132+
133+
def test_apply_to_swaps_buffers_and_rejects_reapply_before_restore():
134+
model = TinyNet()
135+
ema = ExponentialMovingAverage(model, decay=0.9)
136+
original_scale = model.scale.clone()
137+
138+
ema.shadow_buffers["scale"].fill_(5.0)
139+
ema.apply_to(model)
140+
141+
assert torch.allclose(model.scale, torch.tensor([5.0]))
142+
with pytest.raises(RuntimeError, match="already applied"):
143+
ema.apply_to(model)
144+
145+
ema.restore(model)
146+
assert torch.allclose(model.scale, original_scale)
147+
assert not ema.collected_params
148+
assert not ema.collected_buffers
149+
150+
151+
def test_restore_without_apply_is_noop():
152+
model = TinyNet()
153+
ema = ExponentialMovingAverage(model, decay=0.9)
154+
original_weight = model.lin.weight.clone()
155+
original_scale = model.scale.clone()
156+
157+
ema.restore(model)
158+
159+
assert torch.allclose(model.lin.weight, original_weight)
160+
assert torch.allclose(model.scale, original_scale)
161+
162+
163+
def test_average_parameters_restores_after_exception():
164+
model = TinyNet()
165+
ema = ExponentialMovingAverage(model, decay=0.9)
166+
ema.shadow_params["lin.weight"].add_(2.0)
167+
original_weight = model.lin.weight.clone()
168+
169+
with pytest.raises(RuntimeError, match="boom"):
170+
with ema.average_parameters(model):
171+
assert torch.allclose(model.lin.weight, ema.shadow_params["lin.weight"])
172+
raise RuntimeError("boom")
173+
174+
assert torch.allclose(model.lin.weight, original_weight)
175+
assert not ema.collected_params
176+
177+
178+
def test_load_state_dict_restores_device_and_clears_collected_caches():
179+
model = TinyNet()
180+
ema = ExponentialMovingAverage(model, decay=0.9)
181+
state = ema.state_dict()
182+
state["decay"] = 0.25
183+
state["num_updates"] = None
184+
state["device"] = "cpu"
185+
186+
ema2 = ExponentialMovingAverage(model, decay=0.1)
187+
ema2.apply_to(model)
188+
assert ema2.collected_params
189+
190+
ema2.load_state_dict(state)
191+
192+
assert ema2.decay == pytest.approx(0.25)
193+
assert ema2.num_updates is None
194+
assert ema2.device == torch.device("cpu")
195+
assert not ema2.collected_params
196+
assert not ema2.collected_buffers
197+
for tensor in ema2.shadow_params.values():
198+
assert tensor.device.type == "cpu"

tests/test_models/test_generators.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,74 @@ def test_esrgan_generator_warns_about_srresnet_specific_options(capsys):
167167
)
168168

169169

170+
@pytest.mark.parametrize("block_type", ["res", "rcab", "rrdb", "lka"])
171+
@pytest.mark.parametrize(
172+
"scale, expected_shape",
173+
[
174+
(2, (2, 3, 10, 12)),
175+
(4, (2, 3, 20, 24)),
176+
(8, (2, 3, 40, 48)),
177+
],
178+
)
179+
def test_flexible_generator_forward_scales_output_for_all_blocks(
180+
block_type, scale, expected_shape
181+
):
182+
generator = FlexibleGenerator(
183+
in_channels=3,
184+
n_channels=16,
185+
n_blocks=1,
186+
small_kernel=3,
187+
large_kernel=3,
188+
scale=scale,
189+
block_type=block_type,
190+
)
191+
lr = torch.randn(2, 3, 5, 6)
192+
193+
sr = generator(lr)
194+
195+
assert sr.shape == expected_shape
196+
197+
198+
def test_flexible_generator_rejects_invalid_configuration():
199+
with pytest.raises(ValueError, match="scale must be one of"):
200+
FlexibleGenerator(scale=1)
201+
202+
with pytest.raises(ValueError, match="block_type must be one of"):
203+
FlexibleGenerator(block_type="unknown")
204+
205+
206+
@pytest.mark.parametrize(
207+
"scale, expected_shape",
208+
[
209+
(1, (2, 2, 5, 6)),
210+
(2, (2, 2, 10, 12)),
211+
(4, (2, 2, 20, 24)),
212+
],
213+
)
214+
def test_esrgan_generator_forward_scales_output(scale, expected_shape):
215+
generator = ESRGANGenerator(
216+
in_channels=2,
217+
out_channels=2,
218+
n_features=8,
219+
n_blocks=1,
220+
growth_channels=4,
221+
scale=scale,
222+
)
223+
lr = torch.randn(2, 2, 5, 6)
224+
225+
sr = generator(lr)
226+
227+
assert sr.shape == expected_shape
228+
229+
230+
def test_esrgan_generator_rejects_invalid_configuration():
231+
with pytest.raises(ValueError, match="power-of-two scales"):
232+
ESRGANGenerator(scale=3)
233+
234+
with pytest.raises(ValueError, match="at least one RRDB block"):
235+
ESRGANGenerator(n_blocks=0)
236+
237+
170238
def test_stochastic_generator_forward_noise_paths():
171239
generator = StochasticGenerator(
172240
in_channels=2,

0 commit comments

Comments
 (0)