-
Notifications
You must be signed in to change notification settings - Fork 115
/
Copy pathtest_correlation1d.py
43 lines (39 loc) · 2.04 KB
/
test_correlation1d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor
from mmflow.models.utils.correlation1d import Correlation1D
# Shared inputs for the correlation test: a (1, 2, 3, 4) integer tensor
# holding the values 0..23, plus a second tensor offset by one everywhere.
_feat1 = torch.arange(24).view(1, 2, 3, 4)
_feat2 = _feat1 + 1
# Unpack the input dimensions so expected output shapes can be spelled out.
b, c, h, w = _feat1.shape
def test_correlation():
    """Compare ``Correlation1D`` outputs with hard-coded reference volumes.

    The module is applied twice to the module-level ``_feat1``/``_feat2``
    pair: once with ``False`` and once with ``True`` as the third argument.
    The first output is checked against a ``(b, h, w, w)`` reference and the
    second against a ``(b, w, h, h)`` one. Shapes are checked first, then
    values with ``torch.allclose(..., atol=1e-4)``.
    """
    # Reference for the ``False`` call; shape (1, 3, 4, 4) == (b, h, w, w).
    # NOTE(review): entries appear to be channel dot products scaled by
    # 1/sqrt(c), e.g. (0*1 + 12*13) / sqrt(2) ~= 110.3087 -- confirm.
    expected_x = Tensor([[[[110.3087, 118.7939, 127.2792, 135.7645],
                           [120.2082, 130.1077, 140.0071, 149.9066],
                           [130.1077, 141.4214, 152.7351, 164.0488],
                           [140.0071, 152.7351, 165.4630, 178.1909]],
                          [[206.4752, 220.6173, 234.7595, 248.9016],
                           [222.0315, 237.5879, 253.1442, 268.7006],
                           [237.5879, 254.5584, 271.5290, 288.4996],
                           [253.1442, 271.5290, 289.9138, 308.2986]],
                          [[347.8965, 367.6955, 387.4945, 407.2935],
                           [369.1097, 390.3229, 411.5362, 432.7494],
                           [390.3229, 412.9504, 435.5778, 458.2052],
                           [411.5362, 435.5778, 459.6194, 483.6610]]]])
    # Reference for the ``True`` call; shape (1, 4, 3, 3) == (b, w, h, h).
    expected_y = Tensor([[[[110.3087, 144.2498, 178.1909],
                           [149.9066, 206.4752, 263.0437],
                           [189.5046, 268.7006, 347.8965]],
                          [[130.1077, 169.7056, 209.3036],
                           [175.3625, 237.5879, 299.8133],
                           [220.6173, 305.4701, 390.3229]],
                          [[152.7351, 197.9899, 243.2447],
                           [203.6468, 271.5290, 339.4113],
                           [254.5584, 345.0681, 435.5778]],
                          [[178.1909, 229.1026, 280.0143],
                           [234.7595, 308.2986, 381.8377],
                           [291.3280, 387.4945, 483.6610]]]])

    correlation = Correlation1D()
    # The boolean third argument switches between the x (False) and y (True)
    # variants exercised by this test.
    actual_x = correlation(_feat1, _feat2, False)
    actual_y = correlation(_feat1, _feat2, True)

    cases = (
        (actual_x, expected_x, (b, h, w, w)),
        (actual_y, expected_y, (b, w, h, h)),
    )
    # Shape checks for both directions first, then the numeric comparisons.
    for actual, _, expected_shape in cases:
        assert actual.size() == expected_shape
    for actual, expected, _ in cases:
        assert torch.allclose(actual, expected, atol=1e-4)