Skip to content

Commit 08d2548

Browse files
authored
Pass TbTData to Hole-in-one (#557)
1 parent 9ae47e5 commit 08d2548

File tree

5 files changed

+131
-43
lines changed

5 files changed

+131
-43
lines changed

omc3/harpy/handler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,15 @@
5656

5757

5858
def run_per_bunch(
59-
tbt_data: TbtData, harpy_input: DotDict, file: Path
59+
tbt_data: TbtData, harpy_input: DotDict, output_filename: str
6060
) -> dict[str, tfs.TfsDataFrame]:
6161
"""
6262
Cleans data, analyses frequencies and searches for resonances.
6363
6464
Args:
6565
tbt_data: single bunch `TbtData`.
6666
harpy_input: Analysis settings taken from the commandline.
67+
output_filename: Name of the output files (placed in `harpy_input.outputdir`).
6768
6869
Returns:
6970
Dictionary with a `TfsDataFrame` per plane.
@@ -73,7 +74,7 @@ def run_per_bunch(
7374
model = tfs.read(harpy_input.model, index=COL_NAME).loc[:, COL_S]
7475

7576
bpm_datas, usvs, lins, bad_bpms = {}, {}, {}, {}
76-
output_file_path = harpy_input.outputdir / file.name
77+
output_file_path = harpy_input.outputdir / output_filename
7778

7879
for plane in PLANES:
7980
bpm_data = _get_cut_tbt_matrix(tbt_data, harpy_input.turns, plane)

omc3/hole_in_one.py

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,14 @@
5555
from omc3.utils.contexts import timeit
5656

5757
if TYPE_CHECKING:
58-
from collections.abc import Generator
58+
from collections.abc import Generator, Iterable
5959

6060
from generic_parser import DotDict
6161

6262
LOGGER = logging_tools.get_logger(__name__)
6363

6464
DEFAULT_CONFIG_FILENAME = "analysis_{time:s}.ini"
65+
DATATYPE_TBT: str = "tbt_data"
6566

6667

6768
def hole_in_one_params() -> EntryPointParameters:
@@ -89,7 +90,8 @@ def hole_in_one_entrypoint(opt: DotDict, rest: list[str]) -> None:
8990
Action: ``store_true``
9091
9192
Harpy Kwargs:
92-
- **files**: TbT files to analyse
93+
- **files**: TbT files to analyse.
94+
Can also be the TbtData objects directly, if 'tbt_data' is chosen as datatype.
9395
9496
Flags: **--files**
9597
Required: ``True``
@@ -405,58 +407,68 @@ def _run_harpy(harpy_options: DotDict) -> list[Path]:
405407
iotools.create_dirs(harpy_options.outputdir)
406408
with timeit(lambda spanned: LOGGER.info(f"Total time for Harpy: {spanned}")):
407409
lins = []
408-
tbt_datas = [
409-
(tbt.read_tbt(file, datatype=harpy_options.tbt_datatype), file)
410-
for file in harpy_options.files
411-
]
410+
tbt_datas = _parse_tbt_data(harpy_options.files, harpy_options.tbt_datatype)
412411
for tbt_data, file in tbt_datas:
413412
lins.extend(
414413
[
415-
handler.run_per_bunch(bunch_data, harpy_options, bunch_file)
416-
for bunch_data, bunch_file in _add_suffix_and_iter_bunches(
414+
handler.run_per_bunch(bunch_data, harpy_options, name_for_bunch)
415+
for bunch_data, name_for_bunch in _add_suffix_and_iter_bunches(
417416
tbt_data, harpy_options, file
418417
)
419418
]
420419
)
421420
return lins
422421

423422

423+
def _parse_tbt_data(files: Iterable[Path | str | tbt.TbtData], tbt_datatype: str
424+
) -> list[tuple[tbt.TbtData, str]]:
425+
"""Parse the turn-by-turn data reading given files or TbtData objects."""
426+
if tbt_datatype == DATATYPE_TBT:
427+
try:
428+
return [(file, Path(file.meta["file"]).name) for file in files]
429+
except KeyError as e:
430+
raise KeyError(
431+
"To determine output naming for hole-in-one, "
432+
"the given TbT objects must contain a 'file' entry in their meta-data."
433+
) from e
434+
435+
return [(tbt.read_tbt(file, datatype=tbt_datatype), Path(file).name) for file in files]
436+
437+
424438
def _add_suffix_and_iter_bunches(
425-
tbt_data: tbt.TbtData, options: DotDict, file: Path
426-
) -> Generator[tuple[tbt.TbtData, Path], None, None]:
427-
"""Add suffix to output files and iterate over bunches."""
428-
dir_name: Path = file.parent
429-
file_name: str = file.name
439+
tbt_data: tbt.TbtData, options: DotDict, file_name: str
440+
) -> Generator[tuple[tbt.TbtData, str], None, None]:
441+
"""Add the additional suffix (if given by user) to output files and
442+
split the TbT data into bunches to analyse them individually."""
430443
suffix: str = options.suffix or ""
431444

432445
# Single bunch
433446
if tbt_data.nbunches == 1:
434-
if suffix:
435-
file = dir_name / f"{file_name}{suffix}"
436-
yield tbt_data, file
447+
file_name_out = f"{file_name}{suffix}"
448+
yield tbt_data, file_name_out
437449
return
438450

439451
# Multibunch
440452
if options.bunch_ids is not None:
441453
unknown_bunches = set(options.bunch_ids) - set(tbt_data.bunch_ids)
442454
if unknown_bunches:
443-
LOGGER.warning(f"Bunch IDs {unknown_bunches} not present in multi-bunch file {file}.")
455+
LOGGER.warning(f"Bunch IDs {unknown_bunches} not present in multi-bunch file {file_name}.")
444456

445457
for index in range(tbt_data.nbunches):
446458
bunch_id = tbt_data.bunch_ids[index]
447459
if options.bunch_ids is not None and bunch_id not in options.bunch_ids:
448460
continue
449461

450462
bunch_id_str = f"_bunchID{bunch_id}"
451-
file = dir_name / f"{file_name}{bunch_id_str}{suffix}"
463+
file_name_out = f"{file_name}{bunch_id_str}{suffix}"
452464
yield (
453465
tbt.TbtData(
454466
matrices=[tbt_data.matrices[index]],
455467
nturns=tbt_data.nturns,
456468
bunch_ids=[bunch_id],
457469
meta=tbt_data.meta,
458470
),
459-
file,
471+
file_name_out,
460472
)
461473

462474

@@ -499,18 +511,34 @@ def _harpy_entrypoint(params: list[str]) -> tuple[DotDict, list[str]]:
499511
raise AttributeError(
500512
"The magnet order for resonance lines calculation should be between 2 and 8 (inclusive)."
501513
)
502-
options.files = [Path(file) for file in options.files]
503514
options.outputdir = Path(options.outputdir)
504515
return options, rest
505516

506517

507518
def harpy_params() -> EntryPointParameters:
508519
"""Create the entry point parameters for harpy."""
520+
# fmt: off
509521
params = EntryPointParameters()
510-
params.add_parameter(name="files", required=True, nargs="+", help="TbT files to analyse")
511-
params.add_parameter(name="outputdir", required=True, help="Output directory.")
512-
params.add_parameter(name="suffix", type=str, help="User-defined suffix for output filenames.")
513-
params.add_parameter(name="model", help="Model for BPM locations")
522+
params.add_parameter(
523+
name="files",
524+
required=True,
525+
nargs="+",
526+
help="TbT files to analyse. Can also be the TbtData objects directly, if 'tbt_data' is chosen as datatype."
527+
)
528+
params.add_parameter(
529+
name="outputdir",
530+
required=True,
531+
help="Output directory."
532+
)
533+
params.add_parameter(
534+
name="suffix",
535+
type=str,
536+
help="User-defined suffix for output filenames."
537+
)
538+
params.add_parameter(
539+
name="model",
540+
help="Model for BPM locations"
541+
)
514542
params.add_parameter(
515543
name="unit",
516544
type=str,
@@ -541,7 +569,7 @@ def harpy_params() -> EntryPointParameters:
541569
params.add_parameter(
542570
name="tbt_datatype",
543571
default=HARPY_DEFAULTS["tbt_datatype"],
544-
choices=list(tbt.io.TBT_MODULES.keys()),
572+
choices=list(tbt.io.TBT_MODULES.keys()) + [DATATYPE_TBT],
545573
help="Choose the datatype from which to import. ",
546574
)
547575

@@ -584,7 +612,11 @@ def harpy_params() -> EntryPointParameters:
584612
"and renormalisation in iterative SVD cleaning of dominant BPMs."
585613
" This is also equal to maximal number of BPMs removed per SVD mode.",
586614
)
587-
params.add_parameter(name="bad_bpms", nargs="*", help="Bad BPMs to clean.")
615+
params.add_parameter(
616+
name="bad_bpms",
617+
nargs="*",
618+
help="Bad BPMs to clean."
619+
)
588620
params.add_parameter(
589621
name="wrong_polarity_bpms",
590622
nargs="*",
@@ -691,6 +723,7 @@ def harpy_params() -> EntryPointParameters:
691723
default=HARPY_DEFAULTS["resonances"],
692724
help="Maximum magnet order of resonance lines to calculate.",
693725
)
726+
# fmt: on
694727
return params
695728

696729

@@ -708,11 +741,23 @@ def _optics_entrypoint(params: list[str]) -> tuple[DotDict, list[str]]:
708741

709742
def optics_params() -> EntryPointParameters:
710743
"""Create the entry point parameters for optics."""
744+
# fmt: off
711745
params = EntryPointParameters()
712-
params.add_parameter(name="files", required=True, nargs="+", help="Files for analysis")
713-
params.add_parameter(name="outputdir", required=True, help="Output directory")
714746
params.add_parameter(
715-
name="calibrationdir", type=str, help="Path to calibration files directory."
747+
name="files",
748+
required=True,
749+
nargs="+",
750+
help="Files for analysis"
751+
)
752+
params.add_parameter(
753+
name="outputdir",
754+
required=True,
755+
help="Output directory"
756+
)
757+
params.add_parameter(
758+
name="calibrationdir",
759+
type=str,
760+
help="Path to calibration files directory."
716761
)
717762
params.add_parameter(
718763
name="coupling_method",
@@ -761,7 +806,9 @@ def optics_params() -> EntryPointParameters:
761806
help="Use 3 BPM method in beta from phase",
762807
)
763808
params.add_parameter(
764-
name="only_coupling", action="store_true", help="Calculate only coupling. "
809+
name="only_coupling",
810+
action="store_true",
811+
help="Calculate only coupling. "
765812
)
766813
params.add_parameter(
767814
name="compensation",
@@ -797,6 +844,7 @@ def optics_params() -> EntryPointParameters:
797844
help="Filter files to analyse by this value (in analysis for tune, phase, rdt and crdt). "
798845
"Use `None` for no filtering",
799846
)
847+
# fmt: on
800848
return params
801849

802850

pyproject.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,6 @@ markers = [
136136

137137
addopts = [
138138
"--import-mode=importlib",
139-
"--cov-report=term-missing",
140-
"--cov-config=pyproject.toml",
141-
"--cov=omc3",
142139
]
143140

144141
# Helpful for pytest-debugging (leave commented out on commit):

tests/unit/test_harpy.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
def test_input_suffix_and_single_bunch(suffix):
1818
""" Tests the function :func:`omc3.hole_in_one._add_suffix_and_loop_over_bunches`
1919
by checking that the suffix is attached to single-bunch files."""
20-
file_name = Path("input_file.sdds")
20+
input_name = "input_file.sdds"
2121
options = DotDict(
2222
suffix=suffix,
2323
bunch_ids=None,
@@ -28,10 +28,10 @@ def test_input_suffix_and_single_bunch(suffix):
2828
bunch_ids=[0],
2929
)
3030
n_data = 0
31-
for data, file in _add_suffix_and_iter_bunches(tbt_data, options, file_name):
31+
for data, file_name in _add_suffix_and_iter_bunches(tbt_data, options, input_name):
3232
suffix_str = suffix or ""
33-
assert file.name.endswith(f"{file_name}{suffix_str}")
34-
assert "bunchID" not in str(file)
33+
assert file_name == f"{input_name}{suffix_str}"
34+
assert "bunchID" not in input_name
3535
assert data is tbt_data
3636
n_data += 1
3737

@@ -45,7 +45,7 @@ def test_input_suffix_and_multibunch(suffix, bunches):
4545
""" Tests the function :func:`omc3.hole_in_one._add_suffix_and_loop_over_bunches`
4646
by checking that the suffixes are attached to multi-bunch files and they are
4747
split up into single-bunch files correctly."""
48-
file_name = Path("input_file.sdds")
48+
input_name = "input_file.sdds"
4949
options = DotDict(
5050
suffix=suffix,
5151
bunch_ids=None if bunches is None else list(bunches),
@@ -58,10 +58,10 @@ def test_input_suffix_and_multibunch(suffix, bunches):
5858
n_data = 0
5959
bunch_ids = bunches or tbt_data.bunch_ids
6060
matrices = [tbt_data.matrices[tbt_data.bunch_ids.index(id_)] for id_ in bunch_ids]
61-
for (data, file), bunch_id, matrix in zip(_add_suffix_and_iter_bunches(tbt_data, options, file_name), bunch_ids, matrices):
61+
for (data, filename_with_suffix), bunch_id, matrix in zip(_add_suffix_and_iter_bunches(tbt_data, options, input_name), bunch_ids, matrices):
6262
bunch_str = f"_bunchID{bunch_id}"
6363
suffix_str = suffix or ""
64-
assert file.name.endswith(f"{file_name}{bunch_str}{suffix_str}")
64+
assert filename_with_suffix == f"{input_name}{bunch_str}{suffix_str}"
6565

6666
assert len(data.matrices) == 1
6767
assert data.matrices[0] == matrix

tests/unit/test_hole_in_one.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
from typing import TYPE_CHECKING
2626

2727
import pytest
28+
from turn_by_turn import TbtData
29+
from tests.accuracy.test_harpy import _get_model_dataframe
30+
from tests.unit.test_harpy import create_tbt_data
2831

2932
from omc3.harpy.constants import FILE_AMPS_EXT, FILE_FREQS_EXT, FILE_LIN_EXT
3033
from omc3.hole_in_one import (
@@ -63,7 +66,6 @@
6366
"-50Hz": ["Beam1@BunchTurn@2024_03_08@18_24_02_100_250turns.sdds", "Beam1@BunchTurn@2024_03_08@18_25_23_729_250turns.sdds", "Beam1@BunchTurn@2024_03_08@18_26_41_811_250turns.sdds"],
6467
}
6568

66-
6769
@pytest.mark.extended
6870
@pytest.mark.parametrize("which_files", ("SINGLE", "0Hz", "all"))
6971
@pytest.mark.parametrize("clean", (True, False), ids=ids_str("clean={}"))
@@ -173,6 +175,46 @@ def test_hole_in_one(tmp_path, clean, which_files, caplog):
173175
)
174176

175177

178+
@pytest.mark.basic
179+
def test_harpy_tbtdata_ok(tmp_path):
180+
""" Tests the harpy entrypoint with `tbt_datatype == 'tbt_data'`."""
181+
# Mock some TbT data
182+
model = _get_model_dataframe()
183+
tbt_data = create_tbt_data(model=model, bunch_ids=[0])
184+
tbt_data.meta = {"file": "test.sdds"}
185+
186+
hole_in_one_entrypoint(
187+
harpy=True,
188+
files=[tbt_data],
189+
tbt_datatype="tbt_data",
190+
unit='m',
191+
autotunes="transverse",
192+
clean=False,
193+
outputdir=tmp_path,
194+
)
195+
196+
197+
@pytest.mark.basic
198+
def test_harpy_tbtdata_no_name(tmp_path):
199+
""" Tests the harpy entrypoint by checking that meta-field `file` is required
200+
when using `tbt_datatype == 'tbt_data'`."""
201+
# Mock some TbT data
202+
model = _get_model_dataframe()
203+
tbt_data = create_tbt_data(model=model, bunch_ids=[0])
204+
tbt_data.meta = {} # in case someone adds a meta in the create_tbt_data helper
205+
206+
with pytest.raises(KeyError) as e:
207+
hole_in_one_entrypoint(
208+
harpy=True,
209+
files=[tbt_data],
210+
tbt_datatype="tbt_data",
211+
unit='m',
212+
autotunes="transverse",
213+
clean=False,
214+
outputdir=tmp_path,
215+
)
216+
assert "must contain a 'file' entry" in str(e.value)
217+
176218
# Helper -----------------------------------------------------------------------
177219

178220
def _check_all_harpy_files(outputdir: Path, sdds_file: Path):

0 commit comments

Comments
 (0)