Skip to content

Commit 2973619

Browse files
committed
docstring and fixing tests
1 parent ad97776 commit 2973619

File tree

13 files changed

+118
-71
lines changed

13 files changed

+118
-71
lines changed

tests/test_doros.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from __future__ import annotations
2+
13
from datetime import datetime
24
from pathlib import Path
5+
from typing import TYPE_CHECKING
36

47
import h5py
58
import numpy as np
@@ -11,6 +14,9 @@
1114
from turn_by_turn.doros import DEFAULT_BUNCH_ID, DataKeys, read_tbt, write_tbt
1215
from turn_by_turn.structures import TbtData, TransverseData
1316

17+
if TYPE_CHECKING:
18+
from turn_by_turn.constants import MetaDict
19+
1420
INPUTS_DIR = Path(__file__).parent / "inputs"
1521

1622

@@ -100,6 +106,9 @@ def _tbt_data() -> TbtData:
100106
"""TbT data for testing. Adding random noise, so that the data is different per BPM."""
101107
nturns = 2000
102108
bpms = ["TBPM1", "TBPM2", "TBPM3", "TBPM4"]
109+
meta: MetaDict = {
110+
"date": datetime.now(),
111+
}
103112

104113
return TbtData(
105114
matrices=[
@@ -126,7 +135,7 @@ def _tbt_data() -> TbtData:
126135
),
127136
)
128137
],
129-
date=datetime.now(),
130-
bunch_ids=[DEFAULT_BUNCH_ID],
131138
nturns=nturns,
139+
bunch_ids=[DEFAULT_BUNCH_ID],
140+
meta=meta,
132141
)

tests/test_madng.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_write_ng(_ng_file: Path, tmp_path: Path, example_fake_tbt: TbtData):
3232

3333
new_tbt = read_tbt(from_tbt, datatype="madng")
3434
compare_tbt(written_tbt, new_tbt, no_binary=True)
35-
assert written_tbt.date == new_tbt.date
35+
assert written_tbt.meta["date"] == new_tbt.meta["date"]
3636

3737

3838
def test_error_ng(_error_file: Path):

tests/test_ptc_trackone.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def test_read_ptc_raises_on_invalid_file(_invalid_ptc_file):
2424

2525
def test_read_ptc_defaults_date(_ptc_file_no_date):
2626
new = ptc.read_tbt(_ptc_file_no_date)
27-
assert new.date.day == datetime.today().day
28-
assert new.date.tzname() == "UTC"
27+
assert new.meta["date"].day == datetime.today().day
28+
assert new.meta["date"].tzname() == "UTC"
2929

3030

3131
def test_read_ptc_sci(_ptc_file_sci):
@@ -85,7 +85,7 @@ def _original_trackone(track: bool = False) -> TbtData:
8585
Y=pd.DataFrame(index=names, data=[[0.0011, 0.00077614, -0.00022749, -0.00103188]]),
8686
),
8787
]
88-
return TbtData(matrix, None, [0, 1] if track else [1, 2], 4)
88+
return TbtData(matrix, nturns=4, bunch_ids=[0, 1] if track else [1, 2])
8989

9090

9191
def _original_simulation_data() -> TbtData:
@@ -112,9 +112,9 @@ def _original_simulation_data() -> TbtData:
112112
E=pd.DataFrame(index=names, data=[[500.00088, 500.00088, 500.00088, 500.00088]]),
113113
),
114114
]
115-
return TbtData(
116-
matrices, date=None, bunch_ids=[0, 1], nturns=4
117-
) # [0, 1] for bunch_ids because it's from tracking
115+
116+
# [0, 1] for bunch_ids because it's from tracking
117+
return TbtData(matrices, bunch_ids=[0, 1], nturns=4)
118118

119119

120120
# ----- Fixtures ----- #

tests/test_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def test_compare_average_tbtdata():
122122
)
123123
for i in range(npart)
124124
],
125-
date=datetime.now(),
126125
bunch_ids=range(npart),
127126
nturns=10,
128127
)
@@ -142,7 +141,6 @@ def test_compare_average_tbtdata():
142141
),
143142
)
144143
],
145-
date=datetime.now(),
146144
bunch_ids=[1],
147145
nturns=10,
148146
)

turn_by_turn/ascii.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,16 @@ def _write_header(tbt_data: TbtData, bunch_id: int, output_file: TextIO) -> None
9090
"""
9191
Write the appropriate headers for a ``TbtData`` object's given bunch_id in the ASCII format.
9292
"""
93+
# fmt: off
9394
output_file.write(f"#{ASCII_ID} v1\n")
94-
output_file.write(
95-
f"#Created: {datetime.now().strftime('%Y-%m-%d at %H:%M:%S')} By: Python turn_by_turn Package\n"
96-
)
95+
output_file.write(f"#Created: {datetime.now().strftime(ACQ_DATE_FORMAT)} By: Python turn_by_turn Package\n")
9796
output_file.write(f"#Number of turns: {tbt_data.nturns}\n")
98-
output_file.write(
99-
f"#Number of horizontal monitors: {tbt_data.matrices[bunch_id].X.index.size}\n"
100-
)
97+
output_file.write(f"#Number of horizontal monitors: {tbt_data.matrices[bunch_id].X.index.size}\n")
10198
output_file.write(f"#Number of vertical monitors: {tbt_data.matrices[bunch_id].Y.index.size}\n")
102-
output_file.write(f"#Acquisition date: {tbt_data.date.strftime('%Y-%m-%d at %H:%M:%S')}\n")
99+
# fmt: on
100+
101+
if date := tbt_data.meta.get("date"):
102+
output_file.write(f"#Acquisition date: {date.strftime(ACQ_DATE_FORMAT)}\n")
103103

104104

105105
def _write_tbt_data(tbt_data: TbtData, bunch_id: int, output_file: TextIO) -> None:

turn_by_turn/iota.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import logging
1212
from enum import Enum
1313
from pathlib import Path
14-
from typing import TYPE_CHECKING, Literal
14+
from typing import Literal
1515

1616
import h5py
1717
import numpy as np

turn_by_turn/lhc.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from turn_by_turn.ascii import is_ascii_file
2020
from turn_by_turn.ascii import read_tbt as read_ascii
21-
from turn_by_turn.constants import PLANE_TO_NUM, PLANES
21+
from turn_by_turn.constants import PLANE_TO_NUM, PLANES, MetaDict
2222
from turn_by_turn.structures import TbtData, TransverseData
2323
from turn_by_turn.utils import matrices_to_array
2424

@@ -68,14 +68,21 @@ def read_tbt(file_path: str | Path) -> TbtData:
6868
bpm_names = sdds_file.values[BPM_NAMES]
6969
nbpms = len(bpm_names)
7070
data = {k: sdds_file.values[POSITIONS[k]].reshape((nbpms, nbunches, nturns)) for k in PLANES}
71+
7172
matrices = [
7273
TransverseData(
7374
X=pd.DataFrame(index=bpm_names, data=data["X"][:, idx, :], dtype=float),
7475
Y=pd.DataFrame(index=bpm_names, data=data["Y"][:, idx, :], dtype=float),
7576
)
7677
for idx in range(nbunches)
7778
]
78-
return TbtData(matrices, date, bunch_ids, nturns)
79+
80+
meta: MetaDict = {
81+
"file": file_path,
82+
"source_datatype": "lhc",
83+
"date": date,
84+
}
85+
return TbtData(matrices, nturns=nturns, bunch_ids=bunch_ids, meta=meta)
7986

8087

8188
def write_tbt(output_path: str | Path, tbt_data: TbtData) -> None:
@@ -101,7 +108,7 @@ def write_tbt(output_path: str | Path, tbt_data: TbtData) -> None:
101108
sdds.classes.Array(POSITIONS["Y"], "float"),
102109
]
103110
values = [
104-
tbt_data.date.timestamp() * 1e9,
111+
tbt_data.meta.get("date", datetime.now(tz=tz.tzutc())).timestamp() * 1e9,
105112
tbt_data.nbunches,
106113
tbt_data.nturns,
107114
tbt_data.bunch_ids,

turn_by_turn/madng.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,15 @@
2323

2424
import pandas as pd
2525

26+
from turn_by_turn.structures import TbtData, TransverseData
27+
2628
if TYPE_CHECKING:
27-
from pathlib import Path # Only used for type hinting
29+
from pathlib import Path
2830

2931
import tfs
3032

31-
from turn_by_turn.structures import TbtData, TransverseData
33+
from turn_by_turn.constants import MetaDict
34+
3235

3336
LOGGER = logging.getLogger(__name__)
3437

@@ -72,7 +75,10 @@ def read_tbt(file_path: str | Path) -> TbtData:
7275

7376
LOGGER.debug("Starting to read TBT data from dataframe")
7477
df = tfs.read(file_path)
75-
return convert_to_tbt(df)
78+
tbt_data = convert_to_tbt(df)
79+
tbt_data.meta["file"] = file_path
80+
tbt_data.meta["source_datatype"] = "madng"
81+
return tbt_data
7682

7783

7884
def convert_to_tbt(df: pd.DataFrame | tfs.TfsDataFrame) -> TbtData:
@@ -104,6 +110,8 @@ def convert_to_tbt(df: pd.DataFrame | tfs.TfsDataFrame) -> TbtData:
104110
LOGGER.debug("The 'tfs' package is not installed. Assuming a pandas DataFrame.")
105111
is_tfs_df = False
106112

113+
meta: MetaDict = {}
114+
107115
if is_tfs_df:
108116
date_str = df.headers.get(DATE)
109117
time_str = df.headers.get(TIME)
@@ -112,11 +120,10 @@ def convert_to_tbt(df: pd.DataFrame | tfs.TfsDataFrame) -> TbtData:
112120
time_str = df.attrs.get(TIME)
113121

114122
# Combine the date and time into a datetime object
115-
date = None
116123
if date_str and time_str:
117-
date = datetime.strptime(f"{date_str} {time_str}", "%d/%m/%y %H:%M:%S")
124+
meta["date"] = datetime.strptime(f"{date_str} {time_str}", "%d/%m/%y %H:%M:%S")
118125
elif date_str:
119-
date = datetime.strptime(date_str, "%d/%m/%y")
126+
meta["date"] = datetime.strptime(date_str, "%d/%m/%y")
120127

121128
nturns = int(df.iloc[-1].loc[TURN])
122129
npart = int(df.iloc[-1].loc[PARTICLE_ID])
@@ -161,7 +168,7 @@ def convert_to_tbt(df: pd.DataFrame | tfs.TfsDataFrame) -> TbtData:
161168
matrices.append(TransverseData(**tracking_data_dict))
162169

163170
LOGGER.debug("Finished reading TBT data")
164-
return TbtData(matrices=matrices, bunch_ids=list(particle_ids), nturns=nturns, date=date)
171+
return TbtData(matrices=matrices, bunch_ids=list(particle_ids), nturns=nturns, meta=meta)
165172

166173

167174
def write_tbt(output_path: str | Path, tbt_data: TbtData) -> None:
@@ -236,8 +243,10 @@ def write_tbt(output_path: str | Path, tbt_data: TbtData) -> None:
236243
headers = {
237244
HNAME: "TbtData",
238245
ORIGIN: "Python",
239-
DATE: tbt_data.date.strftime("%d/%m/%y"),
240-
TIME: tbt_data.date.strftime("%H:%M:%S"),
241246
REFCOL: NAME,
242247
}
248+
if date := tbt_data.meta.get("date"):
249+
headers[DATE] = date.strftime("%d/%m/%y")
250+
headers[TIME] = date.strftime("%H:%M:%S")
251+
243252
tfs.write(output_path, merged_df, headers_dict=headers, save_index=NAME)

turn_by_turn/ptc.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import copy
1414
import logging
1515
from collections import namedtuple
16+
from dataclasses import dataclass, field
1617
from datetime import datetime
1718
from pathlib import Path
1819
from typing import TYPE_CHECKING, Any
@@ -21,7 +22,7 @@
2122
import pandas as pd
2223
from dateutil import tz
2324

24-
from turn_by_turn.constants import PLANES
25+
from turn_by_turn.constants import PLANES, MetaDict
2526
from turn_by_turn.errors import PTCFormatError
2627
from turn_by_turn.structures import TbtData, TransverseData
2728

@@ -66,21 +67,27 @@ def read_tbt(file_path: str | Path) -> TbtData:
6667
lines = lines[header_length:]
6768

6869
# parameters
69-
bpms, particles, column_indices, n_turns, n_particles = _read_from_first_turn(lines)
70+
params = _read_from_first_turn(lines)
7071

7172
# read into dict first for speed then convert to DFs
7273
matrices = [
73-
{p: {bpm: np.zeros(n_turns) for bpm in bpms} for p in PLANES} for _ in range(n_particles)
74+
{p: {bpm: np.zeros(params.n_turns) for bpm in params.bpms} for p in PLANES} for _ in range(params.n_particles)
7475
]
75-
matrices = _read_data(lines, matrices, column_indices)
76-
for bunch in range(n_particles):
76+
matrices = _read_data(lines, matrices, params.column_indices)
77+
for bunch in range(params.n_particles):
7778
matrices[bunch] = TransverseData(
7879
X=pd.DataFrame(matrices[bunch]["X"]).transpose(),
7980
Y=pd.DataFrame(matrices[bunch]["Y"]).transpose(),
8081
)
8182

8283
LOGGER.debug(f"Read Tbt matrices from: '{file_path.absolute()}'")
83-
return TbtData(matrices=matrices, date=date, bunch_ids=particles, nturns=n_turns)
84+
85+
meta: MetaDict = {
86+
"date": date,
87+
"file": file_path,
88+
"source_datatype": "ptc",
89+
}
90+
return TbtData(matrices=matrices, nturns=params.n_turns, bunch_ids=params.particles, meta=meta)
8491

8592

8693
def _read_header(lines: Sequence[str]) -> tuple[datetime, int]:
@@ -105,19 +112,23 @@ def _read_header(lines: Sequence[str]) -> tuple[datetime, int]:
105112
return datetime.strptime(f"{date_str[DATE]} {date_str[TIME]}", TIME_FORMAT), idx_line
106113

107114

108-
def _read_from_first_turn(
109-
lines: Sequence[str],
110-
) -> tuple[list[str], list[int], dict[Any, Any], int, int]:
115+
@dataclass(slots=True)
116+
class TbTParams:
117+
""" Parameters read from the first turn of the file. """
118+
bpms: list[str] = field(default_factory=list)
119+
particles: list[int] = field(default_factory=list)
120+
column_indices: dict[Any, Any] | None = None
121+
n_turns: int = 0
122+
n_particles: int = 0
123+
124+
125+
def _read_from_first_turn(lines: Sequence[str]) -> TbTParams:
111126
"""
112127
Reads the BPMs, particles, column indices and number of turns and particles from the matrices of
113128
the first turn.
114129
"""
115130
LOGGER.debug("Reading first turn to define boundary parameters.")
116-
bpms = []
117-
particles = []
118-
column_indices = None
119-
n_turns = 0
120-
n_particles = 0
131+
data = TbTParams()
121132
first_segment = True
122133

123134
for line in lines:
@@ -126,37 +137,38 @@ def _read_from_first_turn(
126137
continue
127138

128139
if parts[0] == NAMES: # read column names
129-
if column_indices is not None:
140+
if data.column_indices is not None:
130141
raise KeyError(f"{NAMES} are defined twice in tbt file!")
131-
column_indices = _parse_column_names_to_indices(parts[1:])
142+
data.column_indices = _parse_column_names_to_indices(parts[1:])
132143
continue
133144

134145
if parts[0] == SEGMENTS: # read segments, append to bunch_id
135146
segment = Segment(*parts[1:])
136147
if segment.name == SEGMENT_MARKER[0]: # start of first segment
137-
n_turns = int(segment.turns) - 1
138-
n_particles = int(segment.particles)
148+
data.n_turns = int(segment.turns) - 1
149+
data.n_particles = int(segment.particles)
139150

140151
elif segment.name == SEGMENT_MARKER[1]: # end of first segment
141152
break
142153

143154
else:
144155
first_segment = False
145-
bpms.append(segment.name)
156+
data.bpms.append(segment.name)
146157

147158
elif first_segment:
148-
if column_indices is None:
159+
if data.column_indices is None:
149160
LOGGER.error("Columns not defined in Tbt file")
150161
raise PTCFormatError
151162

152-
new_data = _parse_data(column_indices, parts)
163+
new_data = _parse_data(data.column_indices, parts)
153164
particle = int(float(new_data[COLPARTICLE]))
154-
particles.append(particle)
165+
data.particles.append(particle)
155166

156-
if len(particles) == 0:
157-
LOGGER.error("No matrices found in TbT file")
158-
raise PTCFormatError
159-
return bpms, particles, column_indices, n_turns, n_particles
167+
if len(data.particles) == 0:
168+
msg = "No particles found in TbT file"
169+
LOGGER.error(msg)
170+
raise PTCFormatError(msg)
171+
return data
160172

161173

162174
def _read_data(

0 commit comments

Comments (0)