Skip to content

Commit 5706cfe

Browse files
Add unit test for saving extra modules (#8290)
* add test * better test
1 parent bc61653 commit 5706cfe

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

docs/docs/tutorials/saving/index.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,15 @@ You can pick the suitable saving approach based on your needs.
9898

9999
### Serializing Imported Modules
100100

101-
When saving a program with `save_program=True`, you might need to include custom modules that your program depends on.
101+
When saving a program with `save_program=True`, you might need to include custom modules that your program depends on. This is
102+
necessary if your program depends on these modules, but at loading time these modules are not imported before calling `dspy.load`.
102103

103104
You can specify which custom modules should be serialized with your program by passing them to the `modules_to_serialize`
104105
parameter when calling `save`. This ensures that any dependencies your program relies on are included during serialization and
105106
available when loading the program later.
106107

107-
This uses cloudpickle's `cloudpickle.register_pickle_by_value` function in order to register a module as picklable by value. When
108-
a module is registered this way, cloudpickle will serialize the module by value rather than by reference, ensuring that the
108+
Under the hood this uses cloudpickle's `cloudpickle.register_pickle_by_value` function to register a module as picklable by value.
109+
When a module is registered this way, cloudpickle will serialize the module by value rather than by reference, ensuring that the
109110
module contents are preserved with the saved program.
110111

111112
For example, if your program uses custom modules:

tests/primitives/test_module.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,65 @@ def dummy_metric(example, pred, trace=None):
115115
assert new_cot.predict.demos == compiled_cot.predict.demos
116116

117117

118+
def test_save_with_extra_modules(tmp_path):
119+
import sys
120+
121+
# Create a temporary Python file with our custom module
122+
custom_module_path = tmp_path / "custom_module.py"
123+
with open(custom_module_path, "w") as f:
124+
f.write("""
125+
import dspy
126+
127+
class MyModule(dspy.Module):
128+
def __init__(self):
129+
self.cot = dspy.ChainOfThought(dspy.Signature("q -> a"))
130+
131+
def forward(self, q):
132+
return self.cot(q=q)
133+
""")
134+
135+
# Add the tmp_path to Python path so we can import the module
136+
sys.path.insert(0, str(tmp_path))
137+
try:
138+
import custom_module
139+
140+
cot = custom_module.MyModule()
141+
142+
cot.save(tmp_path, save_program=True)
143+
# Remove the custom module from sys.modules to simulate it not being available
144+
sys.modules.pop("custom_module", None)
145+
# Also remove it from sys.path
146+
sys.path.remove(str(tmp_path))
147+
del custom_module
148+
149+
# Test the loading fails without using `modules_to_serialize`
150+
with pytest.raises(ModuleNotFoundError):
151+
dspy.load(tmp_path)
152+
153+
sys.path.insert(0, str(tmp_path))
154+
import custom_module
155+
156+
cot.save(
157+
tmp_path,
158+
modules_to_serialize=[custom_module],
159+
save_program=True,
160+
)
161+
162+
# Remove the custom module from sys.modules to simulate it not being available
163+
sys.modules.pop("custom_module", None)
164+
# Also remove it from sys.path
165+
sys.path.remove(str(tmp_path))
166+
del custom_module
167+
168+
loaded_module = dspy.load(tmp_path)
169+
assert loaded_module.cot.predict.signature == cot.cot.predict.signature
170+
171+
finally:
172+
# Only need to clean up sys.path
173+
if str(tmp_path) in sys.path:
174+
sys.path.remove(str(tmp_path))
175+
176+
118177
def test_load_with_version_mismatch(tmp_path):
119178
from dspy.primitives.module import logger
120179

0 commit comments

Comments
 (0)