@@ -188,6 +188,64 @@ def get_function_name(function):
188188 return function .__name__
189189
190190
191+ def get_default_value_for_repr (value ):
192+ """Return a substitute for rendering the default value of a funciton arg.
193+
194+ Function and object instances are rendered as <Foo object at 0x00000000>
195+ which can't be parsed by black. We substitute functions with the function
196+ name and objects with a rendered version of the constructor like
197+ `Foo(a=2, b="bar")`.
198+
199+ Args:
200+ value: The value to find a better rendering of.
201+
202+ Returns:
203+ Another value or `None` if no substitution is needed.
204+ """
205+
206+ class ReprWrapper :
207+ def __init__ (self , representation ):
208+ self .representation = representation
209+
210+ def __repr__ (self ):
211+ return self .representation
212+
213+ if value is inspect ._empty :
214+ return None
215+
216+ if inspect .isfunction (value ):
217+ # Render the function name instead
218+ return ReprWrapper (value .__name__ )
219+
220+ if (
221+ repr (value ).startswith ("<" ) # <Foo object at 0x00000000>
222+ and hasattr (value , "__class__" ) # it is an object
223+ and hasattr (value , "get_config" ) # it is a Keras object
224+ ):
225+ config = value .get_config ()
226+ init_args = [] # The __init__ arguments to render
227+ for p in inspect .signature (value .__class__ .__init__ ).parameters .values ():
228+ if p .name == "self" :
229+ continue
230+ if p .kind == inspect .Parameter .POSITIONAL_ONLY :
231+ # Required positional, render without a name
232+ init_args .append (repr (config [p .name ]))
233+ elif p .default is inspect ._empty or p .default != config [p .name ]:
234+ # Keyword arg with non-default value, render
235+ init_args .append (p .name + "=" + repr (config [p .name ]))
236+ # else don't render that argument
237+ return ReprWrapper (
238+ value .__class__ .__module__
239+ + "."
240+ + value .__class__ .__name__
241+ + "("
242+ + ", " .join (init_args )
243+ + ")"
244+ )
245+
246+ return None
247+
248+
191249def get_signature_start (function ):
192250 """For the Dense layer, it should return the string 'keras.layers.Dense'"""
193251 if ismethod (function ):
@@ -209,9 +267,12 @@ def get_signature_end(function):
209267
210268 formatted_params = []
211269 for x in params :
270+ default = get_default_value_for_repr (x .default )
271+ if default :
272+ x = inspect .Parameter (
273+ x .name , x .kind , default = default , annotation = x .annotation
274+ )
212275 str_x = str (x )
213- if "<function" in str_x :
214- str_x = re .sub (r'<function (.*?) at 0x[0-9a-fA-F]+>' , r'\1' , str_x )
215276 formatted_params .append (str_x )
216277 signature_end = "(" + ", " .join (formatted_params ) + ")"
217278
@@ -382,10 +443,8 @@ def get_class_from_method(meth):
382443 return cls
383444 meth = meth .__func__ # fallback to __qualname__ parsing
384445 if inspect .isfunction (meth ):
385- cls = getattr (
386- inspect .getmodule (meth ),
387- meth .__qualname__ .split (".<locals>" , 1 )[0 ].rsplit ("." , 1 )[0 ],
388- )
446+ cls_name = meth .__qualname__ .split (".<locals>" , 1 )[0 ].rsplit ("." , 1 )[0 ]
447+ cls = getattr (inspect .getmodule (meth ), cls_name , None )
389448 if isinstance (cls , type ):
390449 return cls
391450 return getattr (meth , "__objclass__" , None ) # handle special descriptor objects
0 commit comments