Skip to content

Commit 1f5c5d5

Browse files
authored
Enhance memory management features and IStorage/IStream tests. (#901)
* test: Validate `pwcsName` in `IStorage.Stat` for created files Ensures that the `pwcsName` field of the `tagSTATSTG` structure, returned by `IStorage.Stat`, correctly reflects the file path used during the creation of a compound file. This enhances the `test_Stat` in `test_storage.py` by verifying the consistency of the storage's reported name with its physical location. * test: Verify `tagSTATSTG.pwcsName` memory management in `IStorage` and `IStream` tests. * test: Validate `FILETIME` in `IStorage.Stat` test. Added validation for `ctime`, `atime`, and `mtime` `FILETIME` timestamps within the `tagSTATSTG` structure in `test_storage.py`. * test: Add `IStorage.SetElementTimes` functionality test. Introduced `test_SetElementTimes` in `test_storage.py` to verify the functionality of the `IStorage.SetElementTimes` method. * refactor: Rename `statstg` to `stat` in `test_stream.py`. * refactor: Improve `SIZE_T` import in `malloc.py`. Directly import `c_size_t` as `SIZE_T` from `ctypes` to remove an unnecessary alias assignment, streamlining type definition. * refactor: Centralize `IMalloc` retrieval with `CoGetMalloc` utility function. Introduced a `CoGetMalloc` utility function in `comtypes/malloc.py` to encapsulate the low-level `_CoGetMalloc` API call. This new function streamlines the acquisition of the OLE task memory allocator by standardizing the `dwMemContext` parameter to `1`, which is the only supported value for `CoGetMalloc`. * fix: Remove `DidAlloc` checks on freed memory in `STATSTG` tests. Removed assertions for `IMalloc.DidAlloc` after `del stat` in `test_storage.py` and `test_stream.py`.
1 parent 8c3bdec commit 1f5c5d5

File tree

5 files changed

+157
-32
lines changed

5 files changed

+157
-32
lines changed

comtypes/malloc.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from ctypes import HRESULT, POINTER, OleDLL, WinDLL, c_int, c_size_t, c_ulong, c_void_p
1+
from ctypes import HRESULT, POINTER, OleDLL, WinDLL, byref, c_int, c_ulong, c_void_p
2+
from ctypes import c_size_t as SIZE_T
23
from ctypes.wintypes import DWORD, LPVOID
34
from typing import TYPE_CHECKING, Any, Optional
45

@@ -34,7 +35,19 @@ def HeapMinimize(self) -> None: ...
3435

3536
_ole32_nohresult = WinDLL("ole32")
3637

37-
SIZE_T = c_size_t
3838
_CoTaskMemAlloc = _ole32_nohresult.CoTaskMemAlloc
3939
_CoTaskMemAlloc.argtypes = [SIZE_T]
4040
_CoTaskMemAlloc.restype = LPVOID
41+
42+
43+
def CoGetMalloc(dwMemContext: int = 1) -> IMalloc:
44+
"""Retrieves a pointer to the default OLE task memory allocator.
45+
46+
https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cogetmalloc
47+
"""
48+
malloc = POINTER(IMalloc)()
49+
_CoGetMalloc(
50+
dwMemContext, # This parameter must be 1.
51+
byref(malloc),
52+
)
53+
return malloc # type: ignore

comtypes/test/test_malloc.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pathlib import Path
55

66
from comtypes import GUID, hresult
7-
from comtypes.malloc import IMalloc, _CoGetMalloc, _CoTaskMemFree
7+
from comtypes.malloc import CoGetMalloc, _CoTaskMemFree
88

99
# Constants
1010
# KNOWNFOLDERID
@@ -24,16 +24,9 @@
2424
_SHGetKnownFolderPath.restype = HRESULT
2525

2626

27-
def _get_malloc() -> IMalloc:
28-
malloc = POINTER(IMalloc)()
29-
_CoGetMalloc(1, byref(malloc))
30-
assert bool(malloc)
31-
return malloc # type: ignore
32-
33-
3427
class Test(ut.TestCase):
3528
def test_Realloc(self):
36-
malloc = _get_malloc()
29+
malloc = CoGetMalloc()
3730
size1 = 4
3831
ptr1 = malloc.Alloc(size1)
3932
self.assertEqual(malloc.DidAlloc(ptr1), 1)
@@ -59,7 +52,7 @@ def test_SHGetKnownFolderPath(self):
5952
self.assertEqual(hr, hresult.S_OK)
6053
self.assertIsInstance(ptr.value, str)
6154
self.assertTrue(Path(ptr.value).exists()) # type: ignore
62-
malloc = _get_malloc()
55+
malloc = CoGetMalloc()
6356
self.assertEqual(malloc.DidAlloc(ptr), 1)
6457
self.assertGreater(malloc.GetSize(ptr), 0)
6558
_CoTaskMemFree(ptr)

comtypes/test/test_outparam.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import logging
22
import unittest
3-
from ctypes import POINTER, byref, c_wchar, c_wchar_p, cast, memmove, sizeof, wstring_at
3+
from ctypes import c_wchar, c_wchar_p, cast, memmove, sizeof, wstring_at
44
from unittest.mock import patch
55

6-
from comtypes.malloc import IMalloc, _CoGetMalloc, _CoTaskMemAlloc, _CoTaskMemFree
6+
from comtypes.malloc import CoGetMalloc, _CoTaskMemAlloc, _CoTaskMemFree
77

88
logger = logging.getLogger(__name__)
99

1010

11-
malloc = POINTER(IMalloc)()
12-
_CoGetMalloc(1, byref(malloc))
11+
malloc = CoGetMalloc()
1312
assert bool(malloc)
1413

1514

comtypes/test/test_storage.py

Lines changed: 106 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,45 @@
1+
import ctypes
2+
import os
13
import tempfile
24
import unittest
35
from _ctypes import COMError
4-
from ctypes import HRESULT, POINTER, OleDLL, byref, c_ubyte
5-
from ctypes.wintypes import DWORD, PWCHAR
6+
from ctypes import HRESULT, POINTER, OleDLL, Structure, WinDLL, byref, c_ubyte
7+
from ctypes.wintypes import BOOL, DWORD, FILETIME, LONG, PWCHAR, WORD
68
from pathlib import Path
79
from typing import Optional
810

911
import comtypes
1012
import comtypes.client
13+
from comtypes.malloc import CoGetMalloc
1114

1215
comtypes.client.GetModule("portabledeviceapi.dll")
13-
from comtypes.gen.PortableDeviceApiLib import IStorage, tagSTATSTG
16+
from comtypes.gen.PortableDeviceApiLib import WSTRING, IStorage, tagSTATSTG
17+
18+
19+
class SYSTEMTIME(Structure):
20+
_fields_ = [
21+
("wYear", WORD),
22+
("wMonth", WORD),
23+
("wDayOfWeek", WORD),
24+
("wDay", WORD),
25+
("wHour", WORD),
26+
("wMinute", WORD),
27+
("wSecond", WORD),
28+
("wMilliseconds", WORD),
29+
]
30+
31+
32+
_kernel32 = WinDLL("kernel32")
33+
34+
# https://learn.microsoft.com/en-us/windows/win32/api/timezoneapi/nf-timezoneapi-systemtimetofiletime
35+
_SystemTimeToFileTime = _kernel32.SystemTimeToFileTime
36+
_SystemTimeToFileTime.argtypes = [POINTER(SYSTEMTIME), POINTER(FILETIME)]
37+
_SystemTimeToFileTime.restype = BOOL
38+
39+
# https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-comparefiletime
40+
_CompareFileTime = _kernel32.CompareFileTime
41+
_CompareFileTime.argtypes = [POINTER(FILETIME), POINTER(FILETIME)]
42+
_CompareFileTime.restype = LONG
1443

1544
STGTY_STORAGE = 1
1645

@@ -36,6 +65,20 @@
3665
_StgCreateDocfile.restype = HRESULT
3766

3867

68+
def _systemtime_to_filetime(st: SYSTEMTIME) -> FILETIME:
69+
ft = FILETIME()
70+
_SystemTimeToFileTime(byref(st), byref(ft))
71+
return ft
72+
73+
74+
def _compare_filetime(ft1: FILETIME, ft2: FILETIME) -> int:
75+
return _CompareFileTime(byref(ft1), byref(ft2))
76+
77+
78+
def _get_pwcsname(stat: tagSTATSTG) -> WSTRING:
79+
return WSTRING.from_address(ctypes.addressof(stat) + tagSTATSTG.pwcsName.offset)
80+
81+
3982
class Test_IStorage(unittest.TestCase):
4083
RW_EXCLUSIVE = STGM_READWRITE | STGM_SHARE_EXCLUSIVE
4184
RW_EXCLUSIVE_TX = RW_EXCLUSIVE | STGM_TRANSACTED
@@ -48,12 +91,17 @@ def _create_docfile(self, mode: int, name: Optional[str] = None) -> IStorage:
4891
_StgCreateDocfile(name, mode, 0, byref(stg))
4992
return stg # type: ignore
5093

94+
FIXED_TEST_FILETIME = _systemtime_to_filetime(
95+
SYSTEMTIME(wYear=2000, wMonth=1, wDay=1)
96+
)
97+
5198
def test_CreateStream(self):
5299
storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
53100
# When created with `StgCreateDocfile(NULL, ...)`, `pwcsName` is a
54101
# temporary filename. The file really exists on disk because Windows
55102
# creates an actual temporary file for the compound storage.
56-
filepath = Path(storage.Stat(STATFLAG_DEFAULT).pwcsName)
103+
stat = storage.Stat(STATFLAG_DEFAULT)
104+
filepath = Path(stat.pwcsName)
57105
self.assertTrue(filepath.exists())
58106
stream = storage.CreateStream("example", self.RW_EXCLUSIVE_CREATE, 0, 0)
59107
test_data = b"Some data"
@@ -67,6 +115,12 @@ def test_CreateStream(self):
67115
self.assertTrue(filepath.exists())
68116
del storage
69117
self.assertFalse(filepath.exists())
118+
name_ptr = _get_pwcsname(stat)
119+
self.assertEqual(name_ptr.value, stat.pwcsName)
120+
malloc = CoGetMalloc()
121+
self.assertEqual(malloc.DidAlloc(name_ptr), 1)
122+
del stat # `pwcsName` is expected to be freed here.
123+
# `DidAlloc` checks are skipped to avoid using a dangling pointer.
70124

71125
# TODO: Auto-generated methods based on type info are remote-side and hard
72126
# to call from the client.
@@ -148,6 +202,32 @@ def test_RenameElement(self):
148202
storage.OpenStorage("example", None, self.RW_EXCLUSIVE_TX, None, 0)
149203
self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND)
150204

205+
def test_SetElementTimes(self):
206+
storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
207+
sub_name = "SubStorageElement"
208+
orig_stat = storage.CreateStorage(sub_name, self.CREATE_TESTDOC, 0, 0).Stat(
209+
STATFLAG_DEFAULT
210+
)
211+
storage.SetElementTimes(
212+
sub_name,
213+
None, # pctime (creation time)
214+
None, # patime (access time)
215+
self.FIXED_TEST_FILETIME, # pmtime (modification time)
216+
)
217+
storage.Commit(STGC_DEFAULT)
218+
modified_stat = storage.OpenStorage(
219+
sub_name, None, self.RW_EXCLUSIVE_TX, None, 0
220+
).Stat(STATFLAG_DEFAULT)
221+
self.assertEqual(_compare_filetime(orig_stat.ctime, modified_stat.ctime), 0)
222+
self.assertEqual(_compare_filetime(orig_stat.atime, modified_stat.atime), 0)
223+
self.assertNotEqual(_compare_filetime(orig_stat.mtime, modified_stat.mtime), 0)
224+
self.assertEqual(
225+
_compare_filetime(self.FIXED_TEST_FILETIME, modified_stat.mtime), 0
226+
)
227+
with self.assertRaises(COMError) as cm:
228+
storage.SetElementTimes("NonExistent", None, None, self.FIXED_TEST_FILETIME)
229+
self.assertEqual(cm.exception.hresult, STG_E_PATHNOTFOUND)
230+
151231
def test_SetClass(self):
152232
storage = self._create_docfile(mode=self.CREATE_TEMP_TESTDOC)
153233
# Initial value is CLSID_NULL.
@@ -176,7 +256,23 @@ def test_Stat(self):
176256
stat = storage.Stat(STATFLAG_DEFAULT)
177257
self.assertIsInstance(stat, tagSTATSTG)
178258
del storage # Release the storage to prevent 'cannot access the file ...'
259+
# Validate each field:
260+
self.assertEqual(
261+
os.path.normcase(os.path.normpath(Path(stat.pwcsName))),
262+
os.path.normcase(os.path.normpath(tmpfile)),
263+
)
179264
self.assertEqual(stat.type, STGTY_STORAGE)
265+
# Timestamps (`mtime`, `ctime`, `atime`) are set by the underlying
266+
# compound file implementation.
267+
# In many cases (especially on modern Windows with NTFS), all three
268+
# timestamps are set to the same value at creation time. However, this
269+
# is not guaranteed by the OLE32 specification.
270+
# Therefore, we only verify that each timestamp is a valid `FILETIME`
271+
# (non-zero is sufficient for a newly created file).
272+
zero_ft = FILETIME()
273+
self.assertNotEqual(_compare_filetime(stat.ctime, zero_ft), 0)
274+
self.assertNotEqual(_compare_filetime(stat.atime, zero_ft), 0)
275+
self.assertNotEqual(_compare_filetime(stat.mtime, zero_ft), 0)
180276
# Due to header overhead and file system allocation, the size may be
181277
# greater than 0 bytes.
182278
self.assertGreaterEqual(stat.cbSize, 0)
@@ -185,3 +281,9 @@ def test_Stat(self):
185281
self.assertEqual(stat.grfLocksSupported, 0)
186282
self.assertEqual(stat.clsid, comtypes.GUID()) # CLSID_NULL for new creation.
187283
self.assertEqual(stat.grfStateBits, 0)
284+
name_ptr = _get_pwcsname(stat)
285+
self.assertEqual(name_ptr.value, stat.pwcsName)
286+
malloc = CoGetMalloc()
287+
self.assertEqual(malloc.DidAlloc(name_ptr), 1)
288+
del stat # `pwcsName` is expected to be freed here.
289+
# `DidAlloc` checks are skipped to avoid using a dangling pointer.

comtypes/test/test_stream.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,13 @@
4040

4141
import comtypes.client
4242
from comtypes import hresult
43+
from comtypes.malloc import CoGetMalloc
4344

4445
comtypes.client.GetModule("portabledeviceapi.dll")
4546
# The stdole module is generated automatically during the portabledeviceapi
4647
# module generation.
4748
import comtypes.gen.stdole as stdole
48-
from comtypes.gen.PortableDeviceApiLib import IStream
49+
from comtypes.gen.PortableDeviceApiLib import WSTRING, IStream, tagSTATSTG
4950

5051
SIZE_T = c_size_t
5152

@@ -110,6 +111,10 @@ def _create_stream_on_file(
110111
return stream # type: ignore
111112

112113

114+
def _get_pwcsname(stat: tagSTATSTG) -> WSTRING:
115+
return WSTRING.from_address(ctypes.addressof(stat) + tagSTATSTG.pwcsName.offset)
116+
117+
113118
class Test_RemoteWrite(ut.TestCase):
114119
def test_RemoteWrite(self):
115120
stream = _create_stream_on_hglobal()
@@ -206,19 +211,25 @@ def test_RemoteCopyTo(self):
206211
class Test_Stat(ut.TestCase):
207212
# https://learn.microsoft.com/en-us/windows/win32/api/objidl/nf-objidl-istream-stat
208213
# https://learn.microsoft.com/en-us/windows/win32/api/objidl/ns-objidl-statstg
209-
def test_returns_statstg_from_no_modified_stream(self):
214+
def test_returns_stat_from_no_modified_stream(self):
210215
stream = _create_stream_on_hglobal()
211-
statstg = stream.Stat(STATFLAG_DEFAULT)
212-
self.assertIsNone(statstg.pwcsName)
213-
self.assertEqual(statstg.type, STGTY_STREAM)
214-
self.assertEqual(statstg.cbSize, 0)
215-
mt, ct, at = statstg.mtime, statstg.ctime, statstg.atime
216+
stat = stream.Stat(STATFLAG_DEFAULT)
217+
self.assertIsNone(stat.pwcsName)
218+
self.assertEqual(stat.type, STGTY_STREAM)
219+
self.assertEqual(stat.cbSize, 0)
220+
mt, ct, at = stat.mtime, stat.ctime, stat.atime
216221
self.assertTrue(mt.dwLowDateTime == ct.dwLowDateTime == at.dwLowDateTime)
217222
self.assertTrue(mt.dwHighDateTime == ct.dwHighDateTime == at.dwHighDateTime)
218-
self.assertEqual(statstg.grfMode, 0)
219-
self.assertEqual(statstg.grfLocksSupported, 0)
220-
self.assertEqual(statstg.clsid, comtypes.GUID())
221-
self.assertEqual(statstg.grfStateBits, 0)
223+
self.assertEqual(stat.grfMode, 0)
224+
self.assertEqual(stat.grfLocksSupported, 0)
225+
self.assertEqual(stat.clsid, comtypes.GUID())
226+
self.assertEqual(stat.grfStateBits, 0)
227+
name_ptr = _get_pwcsname(stat)
228+
self.assertIsNone(name_ptr.value)
229+
malloc = CoGetMalloc()
230+
self.assertEqual(malloc.DidAlloc(name_ptr), -1)
231+
del stat # `pwcsName` is expected to be freed here.
232+
# `DidAlloc` checks are skipped to avoid using a dangling pointer.
222233

223234

224235
class Test_Clone(ut.TestCase):
@@ -274,11 +285,18 @@ def test_can_lock_file_based_stream(self):
274285
# Cleanup: Close descriptors and release the lock
275286
os.close(fd)
276287
stm.UnlockRegion(0, 5, LOCK_EXCLUSIVE)
277-
buf, read = stm.RemoteRead(stm.Stat(STATFLAG_DEFAULT).cbSize)
288+
stat = stm.Stat(STATFLAG_DEFAULT)
289+
buf, read = stm.RemoteRead(stat.cbSize)
278290
# Verify that COM stream content reflects the successful out-of-lock write
279291
self.assertEqual(bytearray(buf)[0:read], b"\x00\x00\x00\x00\x00ABCDE")
280292
# Verify that the actual file content on disk matches the expected data
281293
self.assertEqual(tmpfile.read_bytes(), b"\x00\x00\x00\x00\x00ABCDE")
294+
name_ptr = _get_pwcsname(stat)
295+
self.assertEqual(name_ptr.value, stat.pwcsName)
296+
malloc = CoGetMalloc()
297+
self.assertEqual(malloc.DidAlloc(name_ptr), 1)
298+
del stat # `pwcsName` is expected to be freed here.
299+
# `DidAlloc` checks are skipped to avoid using a dangling pointer.
282300

283301

284302
# TODO: If there is a standard Windows `IStream` implementation that supports

0 commit comments

Comments
 (0)