Skip to content

Commit 593f062

Browse files
authored
PdoBase: Forbid getting item with zero index (fixes #581) (#609)
* Raise KeyError for zero index. * Properly test different access type return values. Augment test_pdo_getitem() to not only check the mapped object values. What's more important is the type of object returned, and whether it is the correct object (identical to other access method results).
1 parent 54ac5c2 commit 593f062

File tree

2 files changed

+51
-21
lines changed

2 files changed

+51
-21
lines changed

canopen/pdo/base.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,22 @@ def __init__(self, node: Union[LocalNode, RemoteNode]):
3939
def __iter__(self):
4040
return iter(self.map)
4141

42-
def __getitem__(self, key):
43-
if isinstance(key, int) and (0x1A00 <= key <= 0x1BFF or # By TPDO ID (512)
44-
0x1600 <= key <= 0x17FF or # By RPDO ID (512)
45-
0 < key <= 512): # By PDO Index
46-
return self.map[key]
47-
else:
48-
for pdo_map in self.map.values():
49-
try:
50-
return pdo_map[key]
51-
except KeyError:
52-
# ignore if one specific PDO does not have the key and try the next one
53-
continue
42+
def __getitem__(self, key: Union[int, str]):
43+
if isinstance(key, int):
44+
if key == 0:
45+
raise KeyError("PDO index zero requested for 1-based sequence")
46+
if (
47+
0 < key <= 512 # By PDO Index
48+
or 0x1600 <= key <= 0x17FF # By RPDO ID (512)
49+
or 0x1A00 <= key <= 0x1BFF # By TPDO ID (512)
50+
):
51+
return self.map[key]
52+
for pdo_map in self.map.values():
53+
try:
54+
return pdo_map[key]
55+
except KeyError:
56+
# ignore if one specific PDO does not have the key and try the next one
57+
continue
5458
raise KeyError(f"PDO: {key} was not found in any map")
5559

5660
def __len__(self):

test/test_pdo.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,41 @@ def test_pdo_getitem(self):
5050
self.assertEqual(node.tpdo[1]['BOOLEAN value 2'].raw, True)
5151

5252
# Test different types of access
53-
self.assertEqual(node.pdo[0x1600]['INTEGER16 value'].raw, -3)
54-
self.assertEqual(node.pdo['INTEGER16 value'].raw, -3)
55-
self.assertEqual(node.pdo.tx[1]['INTEGER16 value'].raw, -3)
56-
self.assertEqual(node.pdo[0x2001].raw, -3)
57-
self.assertEqual(node.tpdo[0x2001].raw, -3)
58-
self.assertEqual(node.pdo[0x2002].raw, 0xf)
59-
self.assertEqual(node.pdo['0x2002'].raw, 0xf)
60-
self.assertEqual(node.tpdo[0x2002].raw, 0xf)
61-
self.assertEqual(node.pdo[0x1600][0x2002].raw, 0xf)
53+
by_mapping_record = node.pdo[0x1600]
54+
self.assertIsInstance(by_mapping_record, canopen.pdo.PdoMap)
55+
self.assertEqual(by_mapping_record['INTEGER16 value'].raw, -3)
56+
by_object_name = node.pdo['INTEGER16 value']
57+
self.assertIsInstance(by_object_name, canopen.pdo.PdoVariable)
58+
self.assertIs(by_object_name.od, node.object_dictionary['INTEGER16 value'])
59+
self.assertEqual(by_object_name.raw, -3)
60+
by_pdo_index = node.pdo.tx[1]
61+
self.assertIs(by_pdo_index, by_mapping_record)
62+
by_object_index = node.pdo[0x2001]
63+
self.assertIsInstance(by_object_index, canopen.pdo.PdoVariable)
64+
self.assertIs(by_object_index, by_object_name)
65+
by_object_index_tpdo = node.tpdo[0x2001]
66+
self.assertIs(by_object_index_tpdo, by_object_name)
67+
by_object_index = node.pdo[0x2002]
68+
self.assertEqual(by_object_index.raw, 0xf)
69+
self.assertIs(node.pdo['0x2002'], by_object_index)
70+
self.assertIs(node.tpdo[0x2002], by_object_index)
71+
self.assertIs(node.pdo[0x1600][0x2002], by_object_index)
72+
73+
self.assertRaises(KeyError, lambda: node.pdo[0])
74+
self.assertRaises(KeyError, lambda: node.tpdo[0])
75+
self.assertRaises(KeyError, lambda: node.pdo['DOES NOT EXIST'])
76+
self.assertRaises(KeyError, lambda: node.pdo[0x1BFF])
77+
self.assertRaises(KeyError, lambda: node.tpdo[0x1BFF])
78+
79+
def test_pdo_maps_iterate(self):
80+
node = self.node
81+
self.assertEqual(len(node.pdo), sum(1 for _ in node.pdo))
82+
self.assertEqual(len(node.tpdo), sum(1 for _ in node.tpdo))
83+
self.assertEqual(len(node.rpdo), sum(1 for _ in node.rpdo))
84+
self.assertEqual(len(node.rpdo) + len(node.tpdo), len(node.pdo))
85+
86+
pdo = node.tpdo[1]
87+
self.assertEqual(len(pdo), sum(1 for _ in pdo))
6288

6389
def test_pdo_save(self):
6490
self.node.tpdo.save()

0 commit comments

Comments
 (0)