-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path test_cross.py
More file actions
61 lines (48 loc) · 2.06 KB
/
test_cross.py
File metadata and controls
61 lines (48 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
Test the cross approximation method.
"""
import torchtt as tntt
import torch as tn
import numpy as np
import pytest
def err_rel(t, ref):
    """Return the relative error of ``t`` with respect to ``ref``.

    Both tensors are measured with ``tn.linalg.norm`` (no ``dim``/``ord``), i.e.
    the Euclidean norm of the flattened tensor.

    Args:
        t: Tensor being checked.
        ref: Reference tensor.

    Returns:
        ``||t - ref|| / ||ref||`` as a NumPy scalar, or ``np.inf`` when the
        shapes differ, so a shape mismatch fails any ``< tol`` assertion
        instead of raising inside the subtraction.
    """
    if ref.shape != t.shape:
        return np.inf
    return tn.linalg.norm(t - ref).numpy() / tn.linalg.norm(ref).numpy()
def test_dmrg_cross_interpolation():
    """
    Check ``tntt.interpolate.dmrg_cross`` with a vectorized callback.

    The callback is evaluated on a batch of multi-indices (presumably one
    row per sample, one column per mode -- it sums along dim 1) and returns
    1 / (6 + i1 + i2 + i3 + i4). The resulting TT is compared with the same
    function evaluated densely on the full 20^4 grid.
    """
    def func1(I):
        # Adding 1 to each of the four indices contributes +4, so this is
        # 1 / (6 + sum of indices), matching the dense reference below.
        return 1 / (2 + (I + 1).sum(1).to(tn.float64))

    shape = [20] * 4
    approx = tntt.interpolate.dmrg_cross(func1, shape, eps=1e-7)

    grids = tntt.meshgrid([tn.arange(0, n, dtype=tn.float64) for n in shape])
    reference = 1 / (6 + sum(g.full() for g in grids))

    assert err_rel(approx.full(), reference) < 1e-6
def test_dmrg_cross_interpolation_nonvect():
    """
    Check ``tntt.interpolate.dmrg_cross`` with a non-vectorized callback.

    With ``eval_vect=False`` the callback receives the four indices as
    separate arguments and returns 1 / (6 + i1 + i2 + i3 + i4). The
    resulting TT is compared with the same function evaluated densely on
    the full 20^4 grid.
    """
    def func1(I, J, K, L):
        return 1 / (6 + I + J + K + L)

    shape = [20] * 4
    approx = tntt.interpolate.dmrg_cross(func1, shape, eps=1e-7, eval_vect=False)

    grids = tntt.meshgrid([tn.arange(0, n, dtype=tn.float64) for n in shape])
    reference = 1 / (6 + sum(g.full() for g in grids))

    assert err_rel(approx.full(), reference) < 1e-6
def test_function_interpolate_multivariable():
    """
    Check ``tntt.interpolate.function_interpolate`` on several TT inputs.

    The four coordinate grids produced by ``tntt.meshgrid`` are passed as a
    list. The callback sums its input along dim 1 (presumably one column per
    input tensor) and returns 1 / (6 + x1 + x2 + x3 + x4). The result is
    compared with that function evaluated densely on the grids.
    """
    def func1(I):
        # (I + 1) summed over four columns adds 4, giving 1 / (6 + sum).
        return 1 / (2 + (I + 1).sum(1).to(tn.float64))

    shape = [20] * 4
    grids = tntt.meshgrid([tn.arange(0, n, dtype=tn.float64) for n in shape])
    reference = 1 / (6 + sum(g.full() for g in grids))

    approx = tntt.interpolate.function_interpolate(func1, grids, 1e-8)
    assert err_rel(approx.full(), reference) < 1e-7
def test_function_interpolate_univariate():
    """
    Check ``tntt.interpolate.function_interpolate`` on a single TT input.

    A TT is built from the dense tensor 1 / (6 + i1 + i2 + i3 + i4), mapped
    through the elementwise logarithm, and compared with ``tn.log`` applied
    to the dense tensor directly.
    """
    shape = [20] * 4
    grids = tntt.meshgrid([tn.arange(0, n, dtype=tn.float64) for n in shape])
    dense = 1 / (6 + sum(g.full() for g in grids))

    tt_input = tntt.TT(dense)
    # Parameter named ``v`` so it does not shadow any outer local.
    result = tntt.interpolate.function_interpolate(lambda v: tn.log(v), tt_input, eps=1e-7)

    assert err_rel(result.full(), tn.log(dense)) < 1e-6