@@ -99,16 +99,22 @@ def get_class_name(wrapper_name):
9999 kwargs = wrapper_dict [wrapper_name ]
100100 else :
101101 kwargs = {}
102- wrapper_module = importlib .import_module (get_module_name (wrapper_name ))
103- wrapper_class = getattr (wrapper_module , get_class_name (wrapper_name ))
102+
103+ if isinstance (wrapper_name , str ):
104+ wrapper_module = importlib .import_module (get_module_name (wrapper_name ))
105+ wrapper_class = getattr (wrapper_module , get_class_name (wrapper_name ))
106+ elif isinstance (wrapper_name , type ):
107+ # No conversion needed
108+ wrapper_class = wrapper_name
109+ else :
110+ raise ValueError (
111+ f"Unexpected value { wrapper_name } for a { key } , must a str and a class, not { type (wrapper_name )} "
112+ )
113+
104114 wrapper_classes .append (wrapper_class )
105115 wrapper_kwargs .append (kwargs )
106116
107117 def wrap_env (env : gym .Env ) -> gym .Env :
108- """
109- :param env:
110- :return:
111- """
112118 for wrapper_class , kwargs in zip (wrapper_classes , wrapper_kwargs ):
113119 env = wrapper_class (env , ** kwargs )
114120 return env
@@ -183,8 +189,12 @@ def get_callback_list(hyperparams: dict[str, Any]) -> list[BaseCallback]:
183189 else :
184190 kwargs = {}
185191
186- callback_class = get_class_by_name (callback_name )
187- callbacks .append (callback_class (** kwargs ))
192+ if isinstance (callback_name , BaseCallback ):
193+ # No conversion needed
194+ callbacks .append (callback_name )
195+ else :
196+ callback_class = get_class_by_name (callback_name )
197+ callbacks .append (callback_class (** kwargs ))
188198
189199 return callbacks
190200
0 commit comments