-
Notifications
You must be signed in to change notification settings - Fork 37
Open
Labels
Description
At the moment we only need to know about the driver, the writer, and any plugins that contribute NDAttributes. We rely on the detector's constructor to pass dimensionality information down. We could scan the chain and work out the dimensions, but this would require creating an exhaustive set of plugins, which is too much for the moment.
In the future we could do something like:
@dataclass
class NDArrayInfo:
    """Description of the NDArray that the file writer plugin will receive."""

    # numpy dtype string of the array data, e.g. "<i2" or "<f8"
    dtype_numpy: str
    # Shape of the NDArray, e.g. (768, 1024)
    shape: tuple[int, ...]
    # Maps the name of each NDAttribute stamped on the NDArray to its dtype_numpy
    attribute_dtypes: dict[str, str]
async def get_chain(
    driver: ADBaseIO, plugins: set[NDPluginBaseIO], writer: NDPluginBaseIO
) -> Sequence[NDArrayBaseIO]:
    """Return the elements an NDArray passes through, from driver to writer.

    Starting at ``writer``, each element's ``nd_array_port`` (its input) is
    matched against the ``port_name`` (output) of the driver and the supplied
    plugins, walking upstream until ``driver`` is reached. Plugins that are not
    on that path are left out of the result.

    Args:
        driver: The driver at the head of the chain.
        plugins: Candidate plugins that may sit between driver and writer.
        writer: The plugin at the end of the chain.

    Returns:
        A tuple whose first element is ``driver`` and last is ``writer``.

    Raises:
        ValueError: If an element's ``nd_array_port`` does not name the port of
            the driver or of any supplied plugin, or if the port wiring loops
            back on itself without ever reaching the driver.
    """
    sources: tuple[NDArrayBaseIO, ...] = (driver, writer, *plugins)
    sinks: tuple[NDArrayBaseIO, ...] = (writer, *plugins)
    # gather_dict awaits the *values* of its mapping, so both lookups are keyed
    # by element; the port_name one is then inverted to map port -> element
    port_names, sink_port = await asyncio.gather(
        gather_dict({x: x.port_name.get_value() for x in sources}),
        gather_dict({x: x.nd_array_port.get_value() for x in sinks}),
    )
    port_lookup = {port: x for x, port in port_names.items()}
    # Follow the chain back from the writer until we get to the driver,
    # collecting downstream-first and reversing once at the end
    upstream_order: list[NDArrayBaseIO] = [writer]
    visited: set[NDArrayBaseIO] = {writer}
    current: NDArrayBaseIO = writer
    while current is not driver:
        # current is the writer or a plugin here, so it has a sink_port entry
        port = sink_port[current]
        upstream = port_lookup.get(port)
        if upstream is None:
            raise ValueError(
                f"{current.nd_array_port.source} is {port!r}, which is not the "
                "port_name of the driver or any of the supplied plugins"
            )
        if upstream in visited:
            # Without this check a wiring loop would make this coroutine spin
            # forever instead of failing
            raise ValueError(
                f"NDArray ports loop back on themselves at {port!r} "
                "without reaching the driver"
            )
        visited.add(upstream)
        upstream_order.append(upstream)
        current = upstream
    return tuple(reversed(upstream_order))
async def get_attribute_dtypes(chain: Sequence[NDArrayBaseIO]) -> dict[str, str]:
    """Return the dtypes of NDAttributes declared inline by elements of ``chain``.

    The ``nd_attributes_file`` of every element is read concurrently. Only
    values containing ``<Attributes>`` are parsed as XML; anything else is
    taken to be a filename and is not read here. For each child of the XML
    root:

    * if its ``type`` (default ``"EPICS_PV"``) is ``"EPICS_PV"``, the dtype
      comes from ``NDAttributePvDbrType(dbrtype)`` (default ``"DBR_NATIVE"``);
    * otherwise it comes from ``NDAttributeDataType(datatype)`` (default
      ``"INT"``).

    The ``.value`` of that enum member is stored under the child's ``name``.
    A later declaration of the same name overwrites an earlier one.

    Raises:
        xml.etree.ElementTree.ParseError: If an inline value is malformed XML.
        KeyError: If a child element has no ``name`` attribute.
        ValueError: Presumably, if ``dbrtype``/``datatype`` is not a member of
            the corresponding enum (standard Enum lookup behaviour; confirm).
    """
    raw_values = await asyncio.gather(
        *(element.nd_attributes_file.get_value() for element in chain)
    )
    # This mirrors the check ADCore uses to decide whether the value is an
    # XML string rather than a filename to parse (kept lazy so values are
    # examined in the same order they are parsed)
    inline_xmls = (text for text in raw_values if "<Attributes>" in text)
    dtypes: dict[str, str] = {}
    for xml_text in inline_xmls:
        for attr in ET.fromstring(xml_text):
            props = attr.attrib
            if props.get("type", "EPICS_PV") != "EPICS_PV":
                found = NDAttributeDataType(props.get("datatype", "INT")).value
            else:
                found = NDAttributePvDbrType(props.get("dbrtype", "DBR_NATIVE")).value
            dtypes[props["name"]] = found
    return dtypes
async def get_ndarray_info(
    driver: ADBaseIO, plugins: set[NDPluginBaseIO], writer: NDPluginBaseIO
) -> NDArrayInfo:
    """Scan the chain and return a description of the NDArray the writer will get.

    The driver-to-writer chain is built concurrently with reading the driver's
    ``array_size_x``, ``array_size_y`` and ``data_type``. The NDAttribute dtypes
    of every element in the chain are then collected.

    The shape is currently always ``(array_size_y, array_size_x)`` as reported
    by the driver. Plugins that would alter the shape are not taken into account.

    Raises:
        ValueError: If the driver's data type is ``ADBaseDataType.UNDEFINED``.
    """
    gathered = await asyncio.gather(
        get_chain(driver, plugins, writer),
        driver.array_size_x.get_value(),
        driver.array_size_y.get_value(),
        driver.data_type.get_value(),
    )
    chain, width, height, data_type = gathered
    if data_type is ADBaseDataType.UNDEFINED:
        raise ValueError(f"{driver.data_type.source} is blank, this is not supported")
    # NDAttribute dtypes are read before the array dtype is converted,
    # matching the original order of operations
    attr_dtypes = await get_attribute_dtypes(chain)
    # e.g. an enum value "Int16" -> np.dtype("int16").str
    array_dtype = np.dtype(data_type.value.lower()).str
    return NDArrayInfo(
        dtype_numpy=array_dtype,
        shape=(height, width),
        attribute_dtypes=attr_dtypes,
    )


# Acceptance Criteria
- Decision made and implemented if it is a good idea
Reactions are currently unavailable