@@ -115,6 +115,65 @@ def dummy_metric(example, pred, trace=None):
115
115
assert new_cot .predict .demos == compiled_cot .predict .demos
116
116
117
117
118
+ def test_save_with_extra_modules (tmp_path ):
119
+ import sys
120
+
121
+ # Create a temporary Python file with our custom module
122
+ custom_module_path = tmp_path / "custom_module.py"
123
+ with open (custom_module_path , "w" ) as f :
124
+ f .write ("""
125
+ import dspy
126
+
127
+ class MyModule(dspy.Module):
128
+ def __init__(self):
129
+ self.cot = dspy.ChainOfThought(dspy.Signature("q -> a"))
130
+
131
+ def forward(self, q):
132
+ return self.cot(q=q)
133
+ """ )
134
+
135
+ # Add the tmp_path to Python path so we can import the module
136
+ sys .path .insert (0 , str (tmp_path ))
137
+ try :
138
+ import custom_module
139
+
140
+ cot = custom_module .MyModule ()
141
+
142
+ cot .save (tmp_path , save_program = True )
143
+ # Remove the custom module from sys.modules to simulate it not being available
144
+ sys .modules .pop ("custom_module" , None )
145
+ # Also remove it from sys.path
146
+ sys .path .remove (str (tmp_path ))
147
+ del custom_module
148
+
149
+ # Test the loading fails without using `modules_to_serialize`
150
+ with pytest .raises (ModuleNotFoundError ):
151
+ dspy .load (tmp_path )
152
+
153
+ sys .path .insert (0 , str (tmp_path ))
154
+ import custom_module
155
+
156
+ cot .save (
157
+ tmp_path ,
158
+ modules_to_serialize = [custom_module ],
159
+ save_program = True ,
160
+ )
161
+
162
+ # Remove the custom module from sys.modules to simulate it not being available
163
+ sys .modules .pop ("custom_module" , None )
164
+ # Also remove it from sys.path
165
+ sys .path .remove (str (tmp_path ))
166
+ del custom_module
167
+
168
+ loaded_module = dspy .load (tmp_path )
169
+ assert loaded_module .cot .predict .signature == cot .cot .predict .signature
170
+
171
+ finally :
172
+ # Only need to clean up sys.path
173
+ if str (tmp_path ) in sys .path :
174
+ sys .path .remove (str (tmp_path ))
175
+
176
+
118
177
def test_load_with_version_mismatch (tmp_path ):
119
178
from dspy .primitives .module import logger
120
179
0 commit comments