Skip to content

Commit 8414089

Browse files
authored
Merge pull request #16 from hz-b/dev/feature/lattice-representation
Refacture of data model: separate lattice elements from integration attributes
2 parents 5b7fc0b + 37ed112 commit 8414089

31 files changed

+368
-239
lines changed

lat2db/bl/create_machine.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import logging
2+
from typing import Dict, Callable, Union
3+
4+
from ..model.geometric_info import GeometricInfo
5+
from ..model.lattice_elements.beam_position_monitor import BeamPositionMonitor
6+
from ..model.lattice_elements.bending import Bending
7+
from ..model.lattice_elements.cavity import Cavity
8+
from ..model.lattice_elements.drift import Drift
9+
from ..model.lattice_elements.element import Element, ElementTypeNames
10+
from ..model.lattice_elements.marker import Marker
11+
from ..model.lattice_elements.quadrupole import Quadrupole
12+
from ..model.lattice_elements.sextupole import Sextupole
13+
from ..model.lattice_elements.steerer import Steerer
14+
from ..model.machine import (
15+
extact_version_from_lattice,
16+
extract_physics_info_from_lattice,
17+
Machine,
18+
)
19+
20+
21+
logger = logging.getLogger("lat2db")
22+
23+
_elm_t_nams = ElementTypeNames
24+
25+
factory_dict_default = {
26+
_elm_t_nams.marker.value: Marker,
27+
_elm_t_nams.bpm.value: BeamPositionMonitor,
28+
_elm_t_nams.drift.value: Drift,
29+
_elm_t_nams.bending.value: Bending,
30+
_elm_t_nams.quadrupole.value: Quadrupole,
31+
_elm_t_nams.sextupole.value: Sextupole,
32+
_elm_t_nams.steerer.value: Steerer,
33+
"Horizontalsteerer": None,
34+
"Verticalsteerer": None,
35+
_elm_t_nams.cavity.value: Cavity,
36+
}
37+
38+
39+
def create_machine(lat, factory_dict: Dict[str, Callable] = None) -> Machine:
40+
factory_dict = factory_dict or factory_dict_default
41+
elms = [
42+
process_element(standardise_element_info(row_), factory_dict=factory_dict)
43+
for row_ in lat.elements
44+
]
45+
# Todo: how to hande elements we are not processing further ...
46+
# shall one add a place holder
47+
elms = [elm for elm in elms if elm is not None]
48+
machine = Machine(
49+
sequences=[elms],
50+
version=extact_version_from_lattice(
51+
lat.lattice_standard_metadata.lattice_version
52+
),
53+
phyics_info=extract_physics_info_from_lattice(lat.properties.physics.energy),
54+
geometric_info=GeometricInfo(is_ring=lat.properties.geometric.is_ring),
55+
name=lat.lattice_standard_metadata.machine_name,
56+
closed=lat.lattice_standard_metadata.closed,
57+
)
58+
return machine
59+
60+
61+
def process_element(elem_info: Dict, *, factory_dict) -> Element:
62+
type_class = factory_dict[elem_info["type"]]
63+
if not type_class:
64+
return None
65+
try:
66+
r = type_class(**elem_info)
67+
except Exception as exc:
68+
pass
69+
raise exc
70+
return r
71+
72+
73+
def standardise_element_info(
74+
elem_info: Dict[str, Union[str, int, float]], *, copy: bool = True
75+
) -> Dict[str, Union[str, int, float, Dict[str, int]]]:
76+
77+
if copy:
78+
res = elem_info.copy()
79+
logger.debug("elem_info= %s", res)
80+
81+
# revamp parameters as required for the dataclasses
82+
res["length"] = res.pop("L", 0e0)
83+
84+
type_name = res["type"]
85+
# address main strength
86+
if type_name in [_elm_t_nams.quadrupole.value, _elm_t_nams.sextupole.value]:
87+
if "K" in res.keys():
88+
res["main_multipole_strength"] = res.pop("K")
89+
elif type_name == _elm_t_nams.bending.value:
90+
if "K" in res.keys():
91+
res["quadrupole_strength"] = res.pop("K")
92+
elif type_name == _elm_t_nams.steerer.value:
93+
if "K" in res.keys():
94+
raise AssertionError("Did not expect K in steerer data")
95+
96+
if type_name in [
97+
_elm_t_nams.bending.value,
98+
_elm_t_nams.quadrupole.value,
99+
_elm_t_nams.sextupole.value,
100+
_elm_t_nams.steerer.value,
101+
]:
102+
# Integration info now in a separate dataclass
103+
if "N" in res.keys():
104+
res["integration_parameters"] = dict(
105+
n_slices=int(res.pop("N")),
106+
symplectic_order=int(res.pop("Method")),
107+
)
108+
109+
# renaming for specific parts
110+
if type_name == _elm_t_nams.bending.value:
111+
res["bending_angle"] = res.pop("T", 0e0)
112+
res["entry_angle"] = res.pop("T1", 0e0)
113+
res["exit_angle"] = res.pop("T2", 0e0)
114+
115+
if type_name == _elm_t_nams.cavity.value:
116+
res["frequency"] = res.pop("Frequency", 0e0)
117+
res["voltage"] = res.pop("Voltage", 0e0) # rename "Voltage" to "voltage"
118+
res["harmonic_number"] = res.pop("Harmonicnumber", 0e0)
119+
else:
120+
pass # do nothing if type_name is not recognized
121+
122+
return res
123+
124+
125+
__all__ = ["create_machine"]

lat2db/bl/set_machine.py

Lines changed: 10 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,99 +1,27 @@
1-
import copy
2-
import logging
31
import os
42

53
import jsons
4+
from starlette.testclient import TestClient
5+
66
from fastapi import FastAPI
77
from pymongo import MongoClient
88

9-
from lat2db import mongodb_url
10-
from lat2db.controller import machine_controller
11-
from lat2db.model.beam_position_monitor import BeamPositionMonitor
12-
from lat2db.model.bending import Bending
13-
from lat2db.model.cavity import Cavity
14-
from lat2db.model.drift import Drift
15-
from lat2db.model.machine import Machine
16-
from lat2db.model.marker import Marker
17-
from lat2db.model.quadrupole import Quadrupole
18-
from lat2db.model.sequencer import Sequencer
19-
from lat2db.model.sextupole import Sextupole
20-
from lat2db.model.steerer import Steerer
9+
from .. import mongodb_url
10+
from ..model.machine import Machine
11+
from ..controller import machine_controller
12+
2113

2214
app = FastAPI()
2315
app.include_router(machine_controller.router, tags=["machines"], prefix="/machine")
2416
app.mongodb_client = MongoClient(mongodb_url)
2517
DB_NAME = os.environ.get("MONGODB_DB", "bessyii")
2618
app.database = app.mongodb_client[DB_NAME]
2719

28-
logger = logging.getLogger("tools")
29-
30-
31-
def create_machine(lat):
32-
machine = Machine()
33-
machine.set_base_parameters( lat )
34-
35-
# iterate through each row in lat.elements
36-
for row_ in lat.elements:
37-
38-
# make a copy of the row so that changes don't affect original data
39-
row = copy.copy(row_)
40-
41-
# print the row
42-
print(f'{row} ----')
20+
def set_machine(machine: Machine):
21+
# return machine
4322

44-
# revamp parameters as required for the dataclasses
45-
row.setdefault("length", row.pop("L", 0e0)) # rename "L" to "length"
46-
type_name = row['type']
47-
# Ensure 'passmethod' and 'tags' keys exist in the row dictionary
48-
row.setdefault("passmethod", None)
49-
row.setdefault("tags", None)
50-
if type_name in ["Bending", "Quadrupole", "Sextupole", "Steerer"]:
51-
row.setdefault("main_multipole_strength", row.pop("K", 0e0)) # rename "K" to "main_multipole_strength"
52-
row.setdefault("number_of_integration_steps",
53-
row.pop("N", 1)) # rename "N" to "number_of_integration_steps"
54-
row.setdefault("method", row.pop("Method", 4)) # rename "Method" to "method"
55-
elif type_name == "Cavity":
56-
row.setdefault("frequency", row.pop("Frequency", 0e0)) # rename "Frequency" to "frequency"
57-
row.setdefault("voltage", row.pop("Voltage", 0e0)) # rename "Voltage" to "voltage"
58-
row.setdefault("harmonic_number",
59-
row.pop("Harmonicnumber", 0e0)) # rename "HarmonicNumber" to "harmonic_number"
60-
else:
61-
pass # do nothing if type_name is not recognized
62-
63-
if type_name == "Bending":
64-
row.setdefault("bending_angle", row.pop("T", 0e0)) # rename "T" to "bending_angle"
65-
row.setdefault("entry_angle", row.pop("T1", 0e0)) # rename "T1" to "entry_angle"
66-
row.setdefault("exit_angle", row.pop("T2", 0e0)) # rename "T2" to "exit_angle"
67-
68-
# create a Sequencer instance using the modified row
69-
sequence_item = Sequencer(**row)
70-
71-
# add the sequence_item to the appropriate list in the machine instance based on its type_name
72-
type_dict = {
73-
"Drift": (Drift, machine.add_drift),
74-
"Marker": (Marker, machine.add_marker),
75-
"Sextupole": (Sextupole, machine.add_sextupole),
76-
"Steerer": (Steerer, machine.add_steerer()),
77-
"Bending": (Bending, machine.add_bending),
78-
"Quadrupole": (Quadrupole, machine.add_quadrupole),
79-
"Bpm": (BeamPositionMonitor, machine.add_beam_position_monitor),
80-
"Cavity": (Cavity, machine.add_cavity),
81-
}
82-
83-
type_class, type_method = type_dict.get(type_name, (None, None))
84-
if type_class is None or type_method is None:
85-
if type_name in ['Horizontalsteerer','Verticalsteerer']:
86-
continue #ignore the two steerers
87-
else:
88-
raise KeyError(f"Don't know type {type_name}")
89-
sequence_item.set_properties(row)
90-
type_instance = type_class(**row)
91-
machine.add_to_sequence(sequence_item)
92-
type_method(type_instance)
93-
94-
# return machine
95-
from starlette.testclient import TestClient
9623
with TestClient(app) as client:
97-
response = client.post("/machine/machine", json=jsons.dump(machine))
24+
response = client.post("/machine/machine", json=jsons.dump(machine.to_dict()))
9825
if response.status_code != 201:
9926
raise AssertionError(f"Got response {response}")
27+

lat2db/controller/machine_controller.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def update_quadrupole_details(id: str, quad_name: str, request_body: Quadrupole_
161161
if "sequences" in machine:
162162
sequences_list = machine.get("sequences", [])
163163
for item_index, item in enumerate(sequences_list):
164-
if item.get("name") == request_body.updated_data.name and item.get("type") == "Quadrupole":
164+
if item.get("name") == request_body.updated_data.name and item.get("type") == ElmTNams.quadrupole:
165165
removed_quadrupole = sequences_list.pop(item_index)
166166
# affected_drift=item.get("index")
167167
print("******* affected drif index is ", affected_drift)
@@ -297,7 +297,7 @@ def update_quadrupole_details_copy(id: str, quad_name: str, request_body: Quadru
297297
if "sequences" in machine_copy:
298298
sequences_list = machine_copy.get("sequences", [])
299299
for item_index, item in enumerate(sequences_list):
300-
if item.get("name") == request_body.updated_data.name and item.get("type") == "Quadrupole":
300+
if item.get("name") == request_body.updated_data.name and item.get("type") == ElmTNams.quadrupole:
301301
removed_quadrupole = sequences_list.pop(item_index)
302302
# affected_drift=item.get("index")
303303
print("******* affected drif index is ", affected_drift)
@@ -809,7 +809,7 @@ def update_quadrupole_from_sequence(id: str, target_drift: str, quad_name: str,
809809
print("inside the drift")
810810
try:
811811
get_qud_prev = database.find_one(
812-
{"id": str(id), "sequences": {"$elemMatch": {"name": quad_name, "type": "Quadrupole"}}},
812+
{"id": str(id), "sequences": {"$elemMatch": {"name": quad_name, "type": ElmTNams.quadrupole}}},
813813
projection={"sequences.$": 1}
814814
)
815815

lat2db/model/element.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

lat2db/model/energy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class Energy:
1010
#: energy or particle_energy
1111
name: str
1212
#: specify the unit
13-
#: eV to be consitent with the beam energy
13+
#: eV to be consistent with the beam energy
1414
#: deviates from SI, but that would be an acceptable compromise to
15-
#: the comumuity
15+
#: the community
1616
value: float

lat2db/model/integrator_configuration/__init__.py

Whitespace-only changes.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from pydantic import BaseModel
2+
3+
4+
class ATSpecificInfo(BaseModel):
5+
passmethod: str
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from pydantic import BaseModel
2+
3+
4+
class IntegrationParameters(BaseModel):
5+
n_slices: int
6+
symplectic_order : int

lat2db/model/lattice_elements/__init__.py

Whitespace-only changes.

lat2db/model/beam_position_monitor.py renamed to lat2db/model/lattice_elements/beam_position_monitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from lat2db.model.element import Element
1+
from .element import Element
22

33

44
class BeamPositionMonitor(Element):

0 commit comments

Comments
 (0)