|
2 | 2 | from datetime import datetime |
3 | 3 | from pathlib import Path |
4 | 4 |
|
| 5 | +import h5py |
5 | 6 | import numpy as np |
6 | 7 | import pandas as pd |
7 | 8 | import pytest |
8 | | -import h5py |
9 | 9 |
|
| 10 | +import turn_by_turn as tbt |
| 11 | +from tests.test_lhc_and_general import compare_tbt, create_data |
| 12 | +from turn_by_turn.doros import DEFAULT_BUNCH_ID, DataKeys, read_tbt, write_tbt |
10 | 13 | from turn_by_turn.structures import TbtData, TransverseData |
11 | | -from tests.test_lhc_and_general import create_data, compare_tbt |
12 | | - |
13 | | -from turn_by_turn.doros import N_ORBIT_SAMPLES, read_tbt, write_tbt, DEFAULT_BUNCH_ID, POSITIONS |
14 | 14 |
|
15 | 15 | INPUTS_DIR = Path(__file__).parent / "inputs" |
16 | 16 |
|
17 | | -@pytest.mark.parametrize("filename", ["test_doros.h5", "test_doros_2024-09-29.h5"]) |
18 | | -def test_read_write_real_data(tmp_path, filename): |
19 | | - tbt = read_tbt(INPUTS_DIR / filename, bunch_id=10) |
| 17 | +@pytest.mark.parametrize("datatype", DataKeys.types()) |
| 18 | +def test_read_write_real_data(tmp_path, datatype): |
| 19 | + tbt_data = read_tbt(INPUTS_DIR / "test_doros.h5", bunch_id=10, data_type=datatype) |
20 | 20 |
|
21 | | - assert tbt.nbunches == 1 |
22 | | - assert len(tbt.matrices) == 1 |
23 | | - assert tbt.nturns == 50000 |
24 | | - assert tbt.matrices[0].X.shape == (3, tbt.nturns) |
25 | | - assert tbt.matrices[0].Y.shape == (3, tbt.nturns) |
26 | | - assert len(set(tbt.matrices[0].X.index)) == 3 |
27 | | - assert np.all(tbt.matrices[0].X.index == tbt.matrices[0].Y.index) |
| 21 | + assert tbt_data.nbunches == 1 |
| 22 | + assert len(tbt_data.matrices) == 1 |
| 23 | + assert tbt_data.nturns == 50000 |
| 24 | + assert tbt_data.matrices[0].X.shape == (3, tbt_data.nturns) |
| 25 | + assert tbt_data.matrices[0].Y.shape == (3, tbt_data.nturns) |
| 26 | + assert len(set(tbt_data.matrices[0].X.index)) == 3 |
| 27 | + assert np.all(tbt_data.matrices[0].X.index == tbt_data.matrices[0].Y.index) |
28 | 28 |
|
29 | 29 | file_path = tmp_path / "test_file.h5" |
30 | | - write_tbt(tbt, file_path) |
31 | | - new = read_tbt(file_path, bunch_id=10) |
32 | | - compare_tbt(tbt, new, no_binary=False) |
| 30 | + write_tbt(file_path, tbt_data, data_type=datatype) |
| 31 | + new = read_tbt(file_path, bunch_id=10, data_type=datatype) |
| 32 | + compare_tbt(tbt_data, new, no_binary=False) |
33 | 33 |
|
34 | 34 |
|
35 | | -def test_write_read(tmp_path): |
36 | | - tbt = _tbt_data() |
| 35 | +@pytest.mark.parametrize("datatype", DataKeys.types()) |
| 36 | +def test_write_read(tmp_path, datatype): |
| 37 | + tbt_data = _tbt_data() |
37 | 38 | file_path = tmp_path / "test_file.h5" |
38 | | - write_tbt(tbt, file_path) |
39 | | - new = read_tbt(file_path) |
40 | | - compare_tbt(tbt, new, no_binary=False) |
| 39 | + write_tbt(file_path, tbt_data, data_type=datatype) |
| 40 | + new = read_tbt(file_path, data_type=datatype) |
| 41 | + compare_tbt(tbt_data, new, no_binary=False) |
| 42 | + |
| 43 | + |
| 44 | +@pytest.mark.parametrize("datatype", ["doros_oscillations", "doros_positions"]) |
| 45 | +def test_write_read_via_io_module(tmp_path, datatype): |
| 46 | + tbt_data = _tbt_data() |
| 47 | + file_path = tmp_path / "test_file.h5" |
| 48 | + tbt.write(file_path, tbt_data, datatype=datatype) |
| 49 | + new = tbt.read(file_path, datatype=datatype) |
| 50 | + compare_tbt(tbt_data, new, no_binary=False) |
41 | 51 |
|
42 | 52 |
|
43 | 53 | def test_read_raises_different_bpm_lengths(tmp_path): |
44 | | - tbt = _tbt_data() |
| 54 | + tbt_data = _tbt_data() |
45 | 55 | file_path = tmp_path / "test_file.h5" |
46 | | - write_tbt(tbt, file_path) |
| 56 | + data_type = DataKeys.OSCILLATIONS |
| 57 | + write_tbt(file_path, tbt_data, data_type=data_type) |
| 58 | + keys = DataKeys.get_data_keys(data_type) |
47 | 59 |
|
48 | | - bpm = tbt.matrices[0].X.index[0] |
| 60 | + bpm = tbt_data.matrices[0].X.index[0] |
49 | 61 |
|
50 | 62 | # modify the BPM lengths in the file |
51 | 63 | with h5py.File(file_path, "r+") as h5f: |
52 | 64 | delta = 10 |
53 | | - del h5f[bpm][N_ORBIT_SAMPLES] |
54 | | - h5f[bpm][N_ORBIT_SAMPLES] = [tbt.matrices[0].X.shape[1] - delta] |
55 | | - for key in POSITIONS.values(): |
| 65 | + del h5f[bpm][keys.n_samples] |
| 66 | + h5f[bpm][keys.n_samples] = [tbt_data.matrices[0].X.shape[1] - delta] |
| 67 | + for key in keys.data.values(): |
56 | 68 | data = h5f[bpm][key][:-delta] |
57 | 69 | del h5f[bpm][key] |
58 | 70 | h5f[bpm][key] = data |
59 | 71 |
|
60 | 72 | with pytest.raises(ValueError) as e: |
61 | | - read_tbt(file_path) |
| 73 | + read_tbt(file_path, data_type=DataKeys.OSCILLATIONS) |
62 | 74 | assert "Not all BPMs have the same number of turns!" in str(e) |
63 | 75 |
|
64 | 76 |
|
65 | 77 | def test_read_raises_on_different_bpm_lengths_in_data(tmp_path): |
66 | | - tbt = _tbt_data() |
| 78 | + tbt_data = _tbt_data() |
67 | 79 | file_path = tmp_path / "test_file.h5" |
68 | | - write_tbt(tbt, file_path) |
| 80 | + data_type = DataKeys.OSCILLATIONS |
| 81 | + keys = DataKeys.get_data_keys(data_type) |
| 82 | + |
| 83 | + write_tbt(file_path, tbt_data, data_type=data_type) |
69 | 84 |
|
70 | | - bpms = [tbt.matrices[0].X.index[i] for i in (0, 2)] |
| 85 | + bpms = [tbt_data.matrices[0].X.index[i] for i in (0, 2)] |
71 | 86 |
|
72 | 87 | # modify the BPM lengths in the file |
73 | 88 | with h5py.File(file_path, "r+") as h5f: |
74 | 89 | for bpm in bpms: |
75 | | - del h5f[bpm][N_ORBIT_SAMPLES] |
76 | | - h5f[bpm][N_ORBIT_SAMPLES] = [tbt.matrices[0].X.shape[1] + 10] |
| 90 | + del h5f[bpm][keys.n_samples] |
| 91 | + h5f[bpm][keys.n_samples] = [tbt_data.matrices[0].X.shape[1] + 10] |
77 | 92 |
|
78 | 93 | with pytest.raises(ValueError) as e: |
79 | | - read_tbt(file_path) |
| 94 | + read_tbt(file_path, data_type=data_type) |
80 | 95 | assert "Found BPMs with different data lengths" in str(e) |
81 | 96 | assert all(bpm in str(e) for bpm in bpms) |
82 | 97 |
|
|
0 commit comments