Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions subiquity/common/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def __str__(self):
return f"processing {self.obj}: at {p}, {self.message}"


E = typing.TypeVar("E")


class NonExhaustive(typing.Generic[E]):
pass


@attr.s(auto_attribs=True)
class SerializationContext:
obj: typing.Any
Expand Down Expand Up @@ -71,6 +78,8 @@ def assert_type(self, typ):
# This is basically a half-assed version of # https://pypi.org/project/cattrs/
# but that's not packaged and this is enough for our needs.

_enum_has_str_values = {}


class Serializer:
def __init__(
Expand All @@ -86,6 +95,7 @@ def __init__(
typing.List: self._walk_List,
dict: self._walk_Dict,
typing.Dict: self._walk_Dict,
NonExhaustive: self._walk_NonExhaustive,
}
self.type_serializers = {}
self.type_deserializers = {}
Expand All @@ -97,6 +107,24 @@ def __init__(
self.type_serializers[datetime.datetime] = self._serialize_datetime
self.type_deserializers[datetime.datetime] = self._deserialize_datetime

def _ann_ok_as_dict_key(self, annotation):
if annotation is str:
return True
origin = getattr(annotation, "__origin__", None)
if origin is NonExhaustive:
annotation = annotation.__args__[0]
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
if self.serialize_enums_by == "name":
return True
else:
if annotation in _enum_has_str_values:
return _enum_has_str_values[annotation]
ok = set(type(v.value) for v in annotation) == {str}
_enum_has_str_values[annotation] = ok
return ok
else:
return False

def _scalar(self, annotation, context):
context.assert_type(annotation)
return context.cur
Expand Down Expand Up @@ -139,20 +167,38 @@ def _walk_List(self, meth, args, context):

def _walk_Dict(self, meth, args, context):
k_ann, v_ann = args
if not context.serializing and k_ann is not str:
input_items = context.cur
else:
if self._ann_ok_as_dict_key(k_ann):
input_items = context.cur.items()
elif context.serializing:
input_items = context.cur.items()
else:
input_items = context.cur
output_items = [
[
meth(k_ann, context.child(f"/{k}", k)),
meth(v_ann, context.child(f"[{k}]", v)),
]
for k, v in input_items
]
if context.serializing and k_ann is not str:
if self._ann_ok_as_dict_key(k_ann):
return dict(output_items)
elif context.serializing:
return output_items
return dict(output_items)
else:
return dict(output_items)

def _walk_NonExhaustive(self, meth, args, context):
[enum_cls] = args
if context.serializing:
if isinstance(context.cur, enum_cls):
return meth(enum_cls, context)
else:
return context.cur
else:
if context.cur in (getattr(m, self.serialize_enums_by) for m in enum_cls):
return meth(enum_cls, context)
else:
return context.cur

def _serialize_dict(self, annotation, context):
context.assert_type(annotation)
Expand Down
46 changes: 45 additions & 1 deletion subiquity/common/tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@

import attr

from subiquity.common.serialize import SerializationError, Serializer, named_field
from subiquity.common.serialize import (
NonExhaustive,
SerializationError,
Serializer,
named_field,
)


@attr.s(auto_attribs=True)
Expand Down Expand Up @@ -61,6 +66,10 @@ class MyEnum(enum.Enum):
name = "value"


class MyIntEnum(enum.Enum):
name = 1


class CommonSerializerTests:
simple_examples = [
(int, 1),
Expand Down Expand Up @@ -129,12 +138,24 @@ def test_rountrip_union(self):
def test_enums(self):
self.assertSerialization(MyEnum, MyEnum.name, "name")

def test_non_exhaustive_enums(self):
self.serializer = type(self.serializer)(compact=self.serializer.compact)
self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "name")
self.assertSerialization(NonExhaustive[MyEnum], "name2", "name2")

def test_enums_by_value(self):
self.serializer = type(self.serializer)(
compact=self.serializer.compact, serialize_enums_by="value"
)
self.assertSerialization(MyEnum, MyEnum.name, "value")

def test_non_exhaustive_enums_by_value(self):
self.serializer = type(self.serializer)(
compact=self.serializer.compact, serialize_enums_by="value"
)
self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "value")
self.assertSerialization(NonExhaustive[MyEnum], "value2", "value2")

def test_serialize_any(self):
o = object()
self.assertSerialization(typing.Any, o, o)
Expand Down Expand Up @@ -259,6 +280,29 @@ class Type:
self.serializer.deserialize(Type, {"field-1": 1, "field2": 2})
self.assertEqual(catcher.exception.path, "['field-1']")

def test_serialize_dict_enumkeys_name(self):
self.assertSerialization(
typing.Dict[MyEnum, str], {MyEnum.name: "b"}, {"name": "b"}
)

def test_serialize_dict_enumkeys_str_value(self):
self.serializer = type(self.serializer)(
compact=self.serializer.compact, serialize_enums_by="value"
)
self.assertSerialization(
typing.Dict[MyEnum, str], {MyEnum.name: "b"}, {"value": "b"}
)

def test_serialize_dict_enumkeys_notstr_value(self):
self.serializer = type(self.serializer)(
compact=self.serializer.compact, serialize_enums_by="value"
)
self.assertSerialization(
typing.Dict[MyIntEnum, str],
{MyIntEnum.name: "b"},
[[1, "b"]],
)


class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
serializer = Serializer(compact=True)
Expand Down
6 changes: 3 additions & 3 deletions subiquity/server/controllers/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def __init__(self, app):
self._on_volume: Optional[snapdapi.OnVolume] = None
self._source_handler: Optional[AbstractSourceHandler] = None
self._system_mounter: Optional[Mounter] = None
self._role_to_device: Dict[str, _Device] = {}
self._role_to_device: Dict[Union[str, snapdapi.Role], _Device] = {}
self._device_to_structure: Dict[_Device, snapdapi.OnVolume] = {}
self._pyudev_context: Optional[pyudev.Context] = None
self.use_tpm: bool = False
Expand Down Expand Up @@ -949,9 +949,9 @@ async def setup_encryption(self, context):
step=snapdapi.SystemActionStep.SETUP_STORAGE_ENCRYPTION,
on_volumes=self._on_volumes(),
),
ann=snapdapi.SystemActionResponse,
)
role_to_encrypted_device = result["encrypted-devices"]
for role, enc_path in role_to_encrypted_device.items():
for role, enc_path in result.encrypted_devices.items():
arb_device = ArbitraryDevice(m=self.model, path=enc_path)
self.model._actions.append(arb_device)
part = self._role_to_device[role]
Expand Down
6 changes: 3 additions & 3 deletions subiquity/server/controllers/tests/test_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1537,11 +1537,11 @@ async def test_from_sample_data(self):
with mock.patch.object(
snapdapi, "post_and_wait", new_callable=mock.AsyncMock
) as mocked:
mocked.return_value = {
"encrypted-devices": {
mocked.return_value = snapdapi.SystemActionResponse(
encrypted_devices={
snapdapi.Role.SYSTEM_DATA: "enc-system-data",
},
}
)
await self.fsc.setup_encryption(context=self.fsc.context)

# setup_encryption mutates the filesystem model objects to
Expand Down
26 changes: 15 additions & 11 deletions subiquity/server/snapdapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from subiquity.common.api.client import make_client
from subiquity.common.api.defs import Payload, api, path_parameter
from subiquity.common.serialize import Serializer, named_field
from subiquity.common.serialize import NonExhaustive, Serializer, named_field
from subiquity.common.types import Change, TaskStatus

log = logging.getLogger("subiquity.server.snapdapi")
Expand Down Expand Up @@ -90,17 +90,11 @@ class Response:
status: str


class Role:
class Role(enum.Enum):
NONE = ""
MBR = "mbr"
SYSTEM_BOOT = "system-boot"
SYSTEM_BOOT_IMAGE = "system-boot-image"
SYSTEM_BOOT_SELECT = "system-boot-select"
SYSTEM_DATA = "system-data"
SYSTEM_RECOVERY_SELECT = "system-recovery-select"
SYSTEM_SAVE = "system-save"
SYSTEM_SEED = "system-seed"
SYSTEM_SEED_NULL = "system-seed-null"


@attr.s(auto_attribs=True)
Expand Down Expand Up @@ -134,7 +128,7 @@ class VolumeStructure:
offset_write: Optional[RelativeOffset] = named_field("offset-write", None)
size: int = 0
type: str = ""
role: str = Role.NONE
role: NonExhaustive[Role] = Role.NONE
id: Optional[str] = None
filesystem: str = ""
content: Optional[List[VolumeContent]] = None
Expand Down Expand Up @@ -232,6 +226,13 @@ class SystemActionRequest:
on_volumes: Dict[str, OnVolume] = named_field("on-volumes")


@attr.s(auto_attribs=True)
class SystemActionResponse:
encrypted_devices: Dict[NonExhaustive[Role], str] = named_field(
"encrypted-devices", default=attr.Factory(dict)
)


@api
class SnapdAPI:
serialize_query_args = False
Expand Down Expand Up @@ -313,14 +314,17 @@ async def make_request(method, path, *, params, json):
snapd_serializer = Serializer(ignore_unknown_fields=True, serialize_enums_by="value")


async def post_and_wait(client, meth, *args, **kw):
async def post_and_wait(client, meth, *args, ann=None, **kw):
change_id = await meth(*args, **kw)
log.debug("post_and_wait %s", change_id)

while True:
result = await client.v2.changes[change_id].GET()
if result.status == TaskStatus.DONE:
return result.data
data = result.data
if ann is not None:
data = snapd_serializer.deserialize(ann, data)
return data
elif result.status == TaskStatus.ERROR:
raise aiohttp.ClientError(result.err)
await asyncio.sleep(0.1)