|
2 | 2 |
|
3 | 3 | import os |
4 | 4 | from enum import StrEnum |
5 | | -from typing import ClassVar, Literal |
| 5 | +from typing import ClassVar, Literal, get_args, get_origin |
6 | 6 |
|
| 7 | +import pytest |
7 | 8 | from pydantic import Field |
8 | 9 |
|
9 | 10 | from swc.aeon.io.api import load |
|
18 | 19 | class DummyHarpDevice(HarpDevice): |
19 | 20 | """A dummy Harp device.""" |
20 | 21 |
|
21 | | - device_type: Literal["DummyHarpDevice"] = "DummyHarpDevice" |
22 | 22 | who_am_i: ClassVar[int] = 0000 |
23 | 23 |
|
24 | 24 | @data_reader |
@@ -106,3 +106,19 @@ def test_dataset_read_metadata(test_data_dir): |
106 | 106 | """Test that dataset Metadata is loaded successfully.""" |
107 | 107 | metadata = load(test_data_dir, Metadata(DummyDataset)) |
108 | 108 | assert len(metadata) > 0 |
| 109 | + |
| 110 | + |
| 111 | +@pytest.mark.parametrize( |
| 112 | + ("device_instance", "expected_device_type"), |
| 113 | + [ |
| 114 | + (DummyHarpDevice(port_name="COM3"), "DummyHarpDevice"), |
| 115 | + (DummyCamera(serial_number="12345"), "DummyCamera"), |
| 116 | + ], |
| 117 | +) |
| 118 | +def test_device_type_mixin(device_instance, expected_device_type): |
| 119 | + """Test that DeviceTypeMixin correctly sets device_type to the subclass name.""" |
| 120 | + assert device_instance.device_type == expected_device_type |
| 121 | + # Check that the device_type annotation is Literal[expected_device_type] |
| 122 | + annotation = type(device_instance).__annotations__["device_type"] |
| 123 | + assert get_origin(annotation) == Literal |
| 124 | + assert get_args(annotation) == (expected_device_type,) |
0 commit comments