-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest_band_actor.py
77 lines (60 loc) · 2.2 KB
/
test_band_actor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy as np
from pfb.utils.naming import xds_from_url
from daskms.fsspec_store import DaskMSStore
from omegaconf import OmegaConf
from pfb.utils.dist import band_actor
from pfb.opt.primal_dual import get_ratio
import time
def band_id_from_path(ds_name, width=4):
    """Return the band id encoded in a dataset path, or ``None``.

    The id is the ``width`` characters that follow the last ``'band'`` in
    the final path component, e.g. ``'.../ms0000_band0001.zarr'`` gives
    ``'0001'``.  Only the final component is searched so that a parent
    directory whose name happens to contain ``'band'`` is never mistaken
    for the band tag.

    Parameters
    ----------
    ds_name : str
        '/'-separated dataset path (e.g. an entry from an fsspec glob).
    width : int, optional
        Number of characters making up the band id (default 4).

    Returns
    -------
    str or None
        The band id, or ``None`` if the final component has no ``'band'``.
    """
    name = ds_name.rstrip('/').rsplit('/', 1)[-1]
    pos = name.rfind('band')
    if pos < 0:
        # str.find/rfind return -1 when absent; without this guard the
        # offset arithmetic would silently slice an unrelated substring.
        return None
    start = pos + len('band')
    return name[start:start + width]


def scan_datasets(ds_names, xds, band='0001'):
    """Select the paths of one band and find global uv/frequency maxima.

    Parameters
    ----------
    ds_names : list of str
        Dataset paths, assumed to be in the same order as ``xds``.
    xds : sequence
        Dataset objects exposing ``uv_max`` and ``max_freq`` attributes.
    band : str, optional
        Band id (as returned by ``band_id_from_path``) to select.

    Returns
    -------
    ds_list : list of str
        Paths from ``ds_names`` whose band id equals ``band``.
    uv_max, max_freq
        Running ``np.maximum`` of ``uv_max`` / ``max_freq`` over *all*
        paired datasets (not only the selected band); ``0.0`` if none.
    """
    if len(ds_names) != len(xds):
        # zip() silently truncates to the shorter input, so a mismatch
        # means names and datasets may be paired incorrectly.
        print(f"Warning: found {len(ds_names)} paths but {len(xds)} "
              "datasets; name/dataset pairing may be wrong")
    ds_list = []
    uv_max = 0.0
    max_freq = 0.0
    for ds_name, ds in zip(ds_names, xds):
        uv_max = np.maximum(uv_max, ds.uv_max)
        max_freq = np.maximum(max_freq, ds.max_freq)
        if band_id_from_path(ds_name) == band:
            ds_list.append(ds_name)
    return ds_list, uv_max, max_freq


def main(argv=None):
    """Build a ``band_actor`` for one band and time a single ``pd_update``.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments; ``None`` means ``sys.argv[1:]``.  With no
        arguments the original hard-coded paths and thread count are used.
    """
    import argparse
    from pfb.parser.schemas import schema

    parser = argparse.ArgumentParser(
        description='Time a single primal-dual update of band_actor.')
    parser.add_argument(
        '--xds', default='/home/landman/testing/pfb/out/data_I.xds',
        help='path/URL of the xds store to read')
    parser.add_argument(
        '--out', default='/home/landman/testing/pfb/out/',
        help='output directory passed to band_actor')
    parser.add_argument('--nthreads', type=int, default=7)
    args = parser.parse_args(argv)

    xds_name = args.xds
    # NOTE(review): the second return value of xds_from_url is discarded and
    # the store is re-globbed; scan_datasets assumes both enumerate the
    # datasets in the same order -- confirm.
    xds, _ = xds_from_url(xds_name)
    xds_store = DaskMSStore(xds_name)
    xds_list = xds_store.fs.glob(f'{xds_store.url}/*')
    ds_list, uv_max, max_freq = scan_datasets(xds_list, xds, band='0001')

    # Start from the spotless schema defaults (dashes -> underscores),
    # then override the few options this test cares about.
    init_args = {
        key.replace("-", "_"): spec["default"]
        for key, spec in schema.spotless["inputs"].items()
    }
    opts = OmegaConf.create(init_args)
    opts['nthreads'] = args.nthreads
    opts['field_of_view'] = 2.0
    opts['super_resolution_factor'] = 4.0

    # NOTE(review): the positional 1 presumably identifies the band
    # (cf. band '0001' selected above) -- confirm against band_actor.
    actor = band_actor(ds_list, opts, 1, args.out, uv_max, max_freq)
    (nx, ny, nymax, nxmax, cell_rad, ra, dec,
     x0, y0, freq_out, time_out) = actor.get_image_info()
    print(f"Image size set to ({nx}, {ny})")

    bases = tuple(opts.bases.split(','))
    nbasis = len(bases)

    model = np.zeros((1, nx, ny))
    residual, wsum = actor.set_image_data_products(model[0], 0,
                                                   from_cache=False)
    residual /= wsum
    rms = np.std(residual)
    print('rms = ', rms)

    # Primal-dual step parameters (hess_norm is hard-coded for this test).
    hess_norm = 100.0
    gamma = 1
    nu = nbasis
    lam = rms
    sigma = hess_norm / (2.0 * gamma) / nu

    l1weight = np.ones((nbasis, nymax, nxmax))
    ratio = np.zeros(l1weight.shape, dtype=l1weight.dtype)
    actor.set_wsum(wsum)

    # Result is unused; the call is kept in case it updates actor state.
    actor.cg_update()
    vtilde, _ = actor.init_pd_params(hess_norm, nbasis, gamma=gamma)
    vtilde = vtilde[None]  # prepend a length-1 axis
    # get_ratio's return value is ignored, so it presumably fills `ratio`
    # in place -- confirm.
    get_ratio(vtilde, l1weight, sigma, lam, ratio)

    # perf_counter is monotonic and high-resolution, unlike time.time,
    # so it is the right clock for measuring an interval.
    ti = time.perf_counter()
    vtilde, eps_num, eps_den, bandid = actor.pd_update(ratio)
    print(time.perf_counter() - ti)


if __name__ == '__main__':
    main()