Skip to content

Commit 5e241d8

Browse files
authored
Revive util and add comprehensive tests. (#865)
* Add initial test for `util.py`. Adds a basic import test for the `comtypes.util` module to ensure it can be imported without errors. This serves as a foundational test for future enhancements. * Correct `PyCArgObject.tag` comparison to use bytes. Fixes a `RuntimeError` in `comtypes.util._calc_offset` that occurred due to a type mismatch when validating `PyCArgObject`. The `argobj.tag` field, which is a `c_char`, is represented as a `bytes` object in Python (e.g., `b'P'`). However, the validation logic was comparing it to a `str` literal (`"P"`). This comparison always failed, leading to an incorrect `RuntimeError` being raised. This fix changes the literal to `b"P"`, ensuring the comparison is performed correctly between two `bytes` objects. This resolves the underlying issue and allows `_calc_offset` to function as intended. * Add a simple test for `byref_at`. * Add comprehensive tests for `byref_at`. * Add comprehensive tests for `cast_field`.
1 parent 61c7cd2 commit 5e241d8

File tree

2 files changed

+228
-1
lines changed

2 files changed

+228
-1
lines changed

comtypes/test/test_util.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
import unittest
2+
from ctypes import (
3+
POINTER,
4+
Structure,
5+
Union,
6+
addressof,
7+
c_byte,
8+
c_char,
9+
c_double,
10+
c_int,
11+
c_void_p,
12+
cast,
13+
sizeof,
14+
)
15+
16+
import comtypes.util
17+
from comtypes import GUID, CoCreateInstance, IUnknown, shelllink
18+
19+
20+
class ByrefAtTest(unittest.TestCase):
21+
def test_ctypes(self):
22+
for ctype, value in [
23+
(c_int, 42),
24+
(c_double, 3.14),
25+
(c_char, b"A"),
26+
]:
27+
with self.subTest(ctype=ctype, value=value):
28+
obj = ctype(value)
29+
# Test with zero offset - should point to the same location
30+
ref = comtypes.util.byref_at(obj, 0)
31+
ptr = cast(ref, POINTER(ctype))
32+
# byref objects don't have contents, but we can cast them to pointers
33+
self.assertEqual(addressof(ptr.contents), addressof(obj))
34+
self.assertEqual(ptr.contents.value, value)
35+
36+
def test_array_offsets(self):
37+
elms = [10, 20, 30, 40]
38+
arr = (c_int * 4)(*elms) # Create an array
39+
# Test accessing different elements via offset
40+
for i, expected in enumerate(elms):
41+
with self.subTest(index=i, expected=expected):
42+
ref = comtypes.util.byref_at(arr, offset=sizeof(c_int) * i)
43+
ptr = cast(ref, POINTER(c_int))
44+
self.assertEqual(ptr.contents.value, expected)
45+
46+
def test_pointer_arithmetic(self):
47+
# Test that byref_at behaves like C pointer arithmetic
48+
49+
class TestStruct(Structure):
50+
_fields_ = [
51+
("field1", c_int),
52+
("field2", c_double),
53+
("field3", c_char),
54+
]
55+
56+
struct = TestStruct(123, 3.14, b"X")
57+
for fname, ftype, expected in [
58+
("field1", c_int, 123),
59+
("field2", c_double, 3.14),
60+
("field3", c_char, b"X"),
61+
]:
62+
with self.subTest(field=fname, type=ftype, expected=expected):
63+
offset = getattr(TestStruct, fname).offset
64+
ref = comtypes.util.byref_at(struct, offset)
65+
ptr = cast(ref, POINTER(ftype))
66+
self.assertEqual(ptr.contents.value, expected)
67+
68+
def test_com_interface(self):
69+
CLSID_ShellLink = GUID("{00021401-0000-0000-C000-000000000046}")
70+
sc = CoCreateInstance(CLSID_ShellLink, interface=shelllink.IShellLinkA)
71+
ref = comtypes.util.byref_at(sc, 0)
72+
ptr = cast(ref, POINTER(POINTER(IUnknown)))
73+
self.assertEqual(addressof(ptr.contents), addressof(sc))
74+
75+
def test_large_offset(self):
76+
# Create a large array to test with large offsets
77+
arr = (c_int * 100)(*range(100))
78+
# Test accessing element at index 50 (offset = 50 * sizeof(c_int))
79+
offset = 50 * sizeof(c_int)
80+
ref = comtypes.util.byref_at(arr, offset)
81+
ptr = cast(ref, POINTER(c_int))
82+
self.assertEqual(ptr.contents.value, 50)
83+
84+
def test_memory_safety(self):
85+
for initial in [111, 222, 333, 444]:
86+
with self.subTest(initial=initial):
87+
obj = c_int(initial)
88+
ref = comtypes.util.byref_at(obj, 0)
89+
ptr = cast(ref, POINTER(c_int))
90+
# Verify initial value
91+
self.assertEqual(ptr.contents.value, initial)
92+
# Modify original objects and verify references still work
93+
obj.value = 333
94+
# Verify reference still works after modification
95+
self.assertEqual(ptr.contents.value, 333)
96+
97+
98+
class CastFieldTest(unittest.TestCase):
99+
def test_ctypes(self):
100+
class TestStruct(Structure):
101+
_fields_ = [
102+
("int_field", c_int),
103+
("double_field", c_double),
104+
("char_field", c_char),
105+
]
106+
107+
struct = TestStruct(42, 3.14, b"X")
108+
for fname, ftype, expected in [
109+
("int_field", c_int, 42),
110+
("double_field", c_double, 3.14),
111+
("char_field", c_char, b"X"),
112+
]:
113+
with self.subTest(fname=fname, ftype=ftype):
114+
actual = comtypes.util.cast_field(struct, fname, ftype)
115+
self.assertEqual(actual, expected)
116+
117+
def test_type_reinterpretation(self):
118+
class TestStruct(Structure):
119+
_fields_ = [
120+
("data", c_int),
121+
]
122+
123+
# Create struct with known bit pattern
124+
struct = TestStruct(0x41424344) # ASCII "ABCD" in little-endian
125+
# Cast the int field as a char array to see individual bytes
126+
char_value = comtypes.util.cast_field(struct, "data", c_char)
127+
# This should give us the first byte of the int
128+
self.assertIsInstance(char_value, bytes)
129+
130+
def test_pointers(self):
131+
class TestStruct(Structure):
132+
_fields_ = [
133+
("ptr_field", c_void_p),
134+
("int_field", c_int),
135+
]
136+
137+
target_int = c_int(99)
138+
struct = TestStruct(addressof(target_int), 123)
139+
for fname, ftype, expected in [
140+
("ptr_field", c_void_p, addressof(target_int)),
141+
("int_field", c_int, 123),
142+
]:
143+
with self.subTest(fname=fname, ftype=ftype, expected=expected):
144+
actual_value = comtypes.util.cast_field(struct, fname, ftype)
145+
self.assertEqual(actual_value, expected)
146+
147+
def test_nested_structures(self):
148+
class InnerStruct(Structure):
149+
_fields_ = [
150+
("inner_int", c_int),
151+
("inner_char", c_char),
152+
]
153+
154+
class OuterStruct(Structure):
155+
_fields_ = [
156+
("outer_int", c_int),
157+
("inner", InnerStruct),
158+
]
159+
160+
inner = InnerStruct(456, b"Y")
161+
outer = OuterStruct(789, inner)
162+
# Cast the nested structure field
163+
inner_value = comtypes.util.cast_field(outer, "inner", InnerStruct)
164+
self.assertEqual(inner_value.inner_int, 456)
165+
self.assertEqual(inner_value.inner_char, b"Y")
166+
# Cast outer int field
167+
outer_int = comtypes.util.cast_field(outer, "outer_int", c_int)
168+
self.assertEqual(outer_int, 789)
169+
170+
def test_arrays(self):
171+
class TestStruct(Structure):
172+
_fields_ = [
173+
("int_array", c_int * 3),
174+
("single_int", c_int),
175+
]
176+
177+
arr = (c_int * 3)(10, 20, 30)
178+
struct = TestStruct(arr, 40)
179+
# Cast array field as array type
180+
array_value = comtypes.util.cast_field(struct, "int_array", c_int * 3)
181+
self.assertEqual(list(array_value), [10, 20, 30])
182+
# Cast single int
183+
int_value = comtypes.util.cast_field(struct, "single_int", c_int)
184+
self.assertEqual(int_value, 40)
185+
186+
def test_union(self):
187+
class TestUnion(Union):
188+
_fields_ = [
189+
("as_int", c_int),
190+
("as_bytes", c_byte * 4),
191+
]
192+
193+
class TestStruct(Structure):
194+
_fields_ = [
195+
("union_field", TestUnion),
196+
("regular_field", c_int),
197+
]
198+
199+
union_val = TestUnion()
200+
union_val.as_int = 0x41424344 # "ABCD" in ASCII
201+
struct = TestStruct(union_val, 999)
202+
union_result = comtypes.util.cast_field(struct, "union_field", TestUnion)
203+
self.assertEqual(union_result.as_int, 0x41424344)
204+
int_result = comtypes.util.cast_field(struct, "regular_field", c_int)
205+
self.assertEqual(int_result, 999)
206+
207+
def test_void_p(self):
208+
class VTableLikeStruct(Structure):
209+
_fields_ = [
210+
("QueryInterface", c_void_p),
211+
("AddRef", c_void_p),
212+
("Release", c_void_p),
213+
("custom_method", c_void_p),
214+
]
215+
216+
# Initialize with some dummy pointers
217+
struct = VTableLikeStruct(0x1000, 0x2000, 0x3000, 0x4000)
218+
# Test accessing different entries
219+
for fname, expected in [
220+
("QueryInterface", 0x1000),
221+
("AddRef", 0x2000),
222+
("Release", 0x3000),
223+
("custom_method", 0x4000),
224+
]:
225+
with self.subTest(fname=fname, expected=expected):
226+
ptr_value = comtypes.util.cast_field(struct, fname, c_void_p)
227+
self.assertEqual(ptr_value, expected)

comtypes/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class value(Union):
4949

5050
argobj = PyCArgObject.from_address(id(ref))
5151

52-
if argobj.obj != id(obj) or argobj.p != addressof(obj) or argobj.tag != "P":
52+
if argobj.obj != id(obj) or argobj.p != addressof(obj) or argobj.tag != b"P":
5353
raise RuntimeError("PyCArgObject field definitions incorrect")
5454

5555
return PyCArgObject.p.offset # offset of the pointer field

0 commit comments

Comments
 (0)