Skip to content

Commit 323f6fd

Browse files
committed
MutableMapping -> Mapping as modifications are handled by merge().
1 parent 1eb859a commit 323f6fd

File tree

2 files changed

+51
-21
lines changed

2 files changed

+51
-21
lines changed

lbry/schema/attrs.py

+17-20
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import os.path
44
import hashlib
5-
from collections.abc import MutableMapping, Iterable
5+
from collections.abc import Mapping, Iterable
66
from typing import Tuple, List
77
from string import ascii_letters
88
from decimal import Decimal, ROUND_UP
@@ -24,7 +24,6 @@
2424
Location as LocationMessage,
2525
Language as LanguageMessage,
2626
)
27-
from google.protobuf.struct_pb2 import Struct as StructMessage
2827
from lbry.schema.types.v2.extension_pb2 import Extension as ExtensionMessage
2928

3029
log = logging.getLogger(__name__)
@@ -714,7 +713,7 @@ def merge(self, ext: 'StreamExtension', delete: bool = False) -> 'StreamExtensio
714713
self.unpacked.merge(ext.unpacked, delete=delete)
715714
return self
716715

717-
class Struct(Metadata, MutableMapping, Iterable):
716+
class Struct(Metadata, Mapping, Iterable):
718717
__slots__ = ()
719718

720719
def to_dict(self) -> dict:
@@ -756,23 +755,28 @@ def merge(self, other: 'Struct', delete: bool = False) -> 'Struct':
756755
return self
757756

758757
def __getitem__(self, key):
758+
def extract(val):
759+
if not isinstance(val, ProtobufMessage):
760+
return val
761+
kind = val.WhichOneof('kind')
762+
if kind == 'struct_value':
763+
return dict(Struct(val.struct_value))
764+
elif kind == 'list_value':
765+
return list(map(extract, val.list_value.values))
766+
else:
767+
return getattr(val, kind)
759768
if key in self.message.fields:
760-
return self.message.fields[key]
769+
val = self.message.fields[key]
770+
return extract(val)
761771
raise KeyError(key)
762772

763-
def __setitem__(self, key, value):
764-
self.message.fields[key].CopyFrom(value.message)
765-
766-
def __delitem__(self, key):
767-
del self.message.fields[key]
768-
769773
def __iter__(self):
770774
return iter(self.message.fields)
771775

772776
def __len__(self):
773777
return len(self.message.fields)
774778

775-
class StreamExtensionMap(Metadata, MutableMapping, Iterable):
779+
class StreamExtensionMap(Metadata, Mapping, Iterable):
776780
__slots__ = ()
777781
item_class = StreamExtension
778782

@@ -791,7 +795,7 @@ def merge(self, exts, delete: bool = False) -> 'StreamExtensionMap':
791795
else:
792796
obj.from_value({schema: ext})
793797
if delete and not len(obj.unpacked):
794-
del self[schema]
798+
del self.message[schema]
795799
continue
796800
existing = StreamExtension(schema, self.message[schema])
797801
existing.merge(obj, delete=delete)
@@ -802,15 +806,8 @@ def __getitem__(self, key):
802806
return StreamExtension(key, self.message[key])
803807
raise KeyError(key)
804808

805-
def __setitem__(self, key, value):
806-
del self.message[key]
807-
self.message[key].CopyFrom(value.message)
808-
809-
def __delitem__(self, key):
810-
del self.message[key]
811-
812809
def __iter__(self):
813-
return self.message.__iter__()
810+
return iter(self.message)
814811

815812
def __len__(self):
816813
return len(self.message)

tests/unit/schema/test_models.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44

55
from lbry.schema.claim import Claim, Stream, Collection
6-
from lbry.schema.attrs import StreamExtension
6+
from lbry.schema.attrs import StreamExtension, Struct
77
from google.protobuf.struct_pb2 import Struct as StructMessage
88
from lbry.schema.types.v2.extension_pb2 import Extension as ExtensionMessage
99
from lbry.error import InputValueIsNoneError
@@ -253,14 +253,17 @@ def setUp(self):
253253

254254
def test_extension_properties(self):
255255
self.maxDiff = None
256+
256257
# Verify schema.
257258
self.assertEqual(self.ext1.schema, 'cad')
258259
self.assertEqual(self.ext2.schema, 'music')
259260
self.assertEqual(self.ext3.schema, 'lit')
261+
260262
# Verify to_dict().
261263
self.assertEqual(self.ext1.to_dict(), self.ext1_dict)
262264
self.assertEqual(self.ext2.to_dict(), self.ext2_dict)
263265
self.assertEqual(self.ext3.to_dict(), self.ext3_dict)
266+
264267
# Decode from dict.
265268
parsed1 = StreamExtension(None, ExtensionMessage())
266269
parsed1.from_value(self.ext1_dict)
@@ -271,6 +274,7 @@ def test_extension_properties(self):
271274
parsed3 = StreamExtension(None, ExtensionMessage())
272275
parsed3.from_value(self.ext3_dict)
273276
self.assertEqual(parsed3.to_dict(), self.ext3_dict)
277+
274278
# Decode from str (JSON).
275279
parsed1 = StreamExtension(None, ExtensionMessage())
276280
parsed1.from_value(self.ext1_json)
@@ -282,6 +286,35 @@ def test_extension_properties(self):
282286
parsed3.from_value(self.ext3_json)
283287
self.assertEqual(parsed3.to_dict(), self.ext3_dict)
284288

289+
# Verify Mapping functionality.
290+
self.assertEqual(self.ext1.unpacked['material'], ['PLA1', 'PLA2'])
291+
self.assertEqual(self.ext1.unpacked['cubic_cm'], 5)
292+
self.assertEqual(self.ext2.unpacked['venue'], 'studio')
293+
self.assertEqual(self.ext2.unpacked['genre'], ['metal'])
294+
self.assertEqual(self.ext2.unpacked['instrument'], ['drum', 'cymbal', 'guitar'])
295+
self.assertEqual(self.ext3.unpacked['pages'], 185)
296+
self.assertEqual(self.ext3.unpacked['genre'], ['fiction', 'mystery'])
297+
self.assertEqual(self.ext3.unpacked['format'], 'epub')
298+
299+
# Verify Iterable functionality.
300+
self.assertEqual(len(self.ext1.unpacked), 2)
301+
for k, v in self.ext1.unpacked.items():
302+
self.assertIn(k, self.ext1.unpacked)
303+
self.assertTrue(isinstance(v, (str, list, float)), type(v))
304+
self.assertEqual(v, self.ext1.unpacked[k])
305+
self.assertEqual(len(self.ext2.unpacked), 3)
306+
for k, v in self.ext2.unpacked.items():
307+
self.assertIn(k, self.ext2.unpacked)
308+
self.assertTrue(isinstance(v, (str, list, float)), type(v))
309+
self.assertEqual(v, self.ext2.unpacked[k])
310+
self.assertEqual(len(self.ext3.unpacked), 3)
311+
for k, v in self.ext3.unpacked.items():
312+
self.assertIn(k, self.ext3.unpacked)
313+
self.assertTrue(isinstance(v, (str, list, float)), type(v))
314+
self.assertEqual(v, self.ext3.unpacked[k])
315+
316+
317+
285318
def test_extension_clear_field(self):
286319
self.maxDiff = None
287320
ext = StreamExtension(None, ExtensionMessage())

0 commit comments

Comments
 (0)