@@ -74,25 +74,26 @@ def register_callback(self, callback: Callback):
7474 self .callbacks .append (callback )
7575
7676 def _register_callback_from_config (self ):
77- local_dir = self .config .local_dir
78- for _callback in self .config .callbacks :
79- if _callback .endswith ('.py' ):
80- if not self .trust_remote_code :
81- raise AssertionError (f'Your config file contains external code, '
82- f'instantiate the code may be UNSAFE, if you trust the code, '
83- f'please pass `trust_remote_code=True` or `--trust_remote_code true`' )
84- if sys .path [0 ] != local_dir :
85- assert local_dir is not None , 'Using external py files, but local_dir cannot be found.'
86- sys .path .insert (0 , local_dir )
87- callback_file = importlib .import_module (_callback [:- 3 ])
88- module_classes = {name : cls for name , cls in inspect .getmembers (callback_file , inspect .isclass )}
89- for name , cls in module_classes .items ():
90- # Find cls which base class is `Callback`
91- if cls .__base__ [0 ] is Callback :
92- self .callbacks .append (cls ())
93- else :
94- assert _callback in callbacks_mapping
95- self .callbacks .append (callbacks_mapping [_callback ]())
77+ local_dir = self .config .local_dir if hasattr (self .config , 'local_dir' ) else None
78+ if hasattr (self .config , 'callbacks' ):
79+ for _callback in self .config .callbacks :
80+ if _callback .endswith ('.py' ):
81+ if not self .trust_remote_code :
82+ raise AssertionError (f'Your config file contains external code, '
83+ f'instantiate the code may be UNSAFE, if you trust the code, '
84+ f'please pass `trust_remote_code=True` or `--trust_remote_code true`' )
85+ if sys .path [0 ] != local_dir :
86+ assert local_dir is not None , 'Using external py files, but local_dir cannot be found.'
87+ sys .path .insert (0 , local_dir )
88+ callback_file = importlib .import_module (_callback [:- 3 ])
89+ module_classes = {name : cls for name , cls in inspect .getmembers (callback_file , inspect .isclass )}
90+ for name , cls in module_classes .items ():
91+ # Find cls which base class is `Callback`
92+ if cls .__base__ [0 ] is Callback :
93+ self .callbacks .append (cls ())
94+ else :
95+ assert _callback in callbacks_mapping
96+ self .callbacks .append (callbacks_mapping [_callback ]())
9697
9798 def _loop_callback (self , point , messages : List [Message ]):
9899 for callback in self .callbacks :
@@ -120,18 +121,18 @@ def _prepare_messages(self, prompt):
120121 {'role' : 'system' , 'content' : self .config .prompt .system or self .DEFAULT_SYSTEM_EN },
121122 {'role' : 'user' , 'content' : prompt or self .config .prompt .query },
122123 ]
123- messages ['query ' ] = self ._query_documents (messages [1 ]['content' ])
124+ messages [1 ][ 'content ' ] = self ._query_documents (messages [1 ]['content' ])
124125 return messages
125126
126127 def _prepare_memory (self ):
127- if self .config .memory :
128+ if hasattr ( self . config , 'memory' ) and self .config .memory :
128129 for _memory in self .config .memory :
129130 assert _memory in memory_mapping , (f'{ _memory } not in memory_mapping, '
130131 f'which supports: { list (memory_mapping .keys ())} ' )
131132 self .memory_tools .append (memory_mapping [_memory ]())
132133
133134 def _prepare_rag (self ):
134- if self .config .rag :
135+ if hasattr ( self . config , 'rag' ) and self .config .rag :
135136 assert self .config .rag in rag_mapping
136137 self .rag : Rag = rag_mapping (self .config .rag )()
137138
0 commit comments