1111}
1212"""
1313
14+ from hub_master import MODELS_MASTER
15+
1416try :
1517 import keras_hub
1618except Exception as e :
@@ -46,10 +48,18 @@ def format_param_count(metadata):
4648
4749def format_path (metadata ):
4850 """Returns Path for the given preset"""
49- try :
50- return f"[{ metadata ['official_name' ]} ]({ metadata ['path' ]} )"
51- except KeyError :
52- return "Unknown"
51+ for child in MODELS_MASTER ["children" ]:
52+ path = child ["path" ].strip ("/" )
53+ if metadata ["path" ] == path :
54+ text = child ["title" ]
55+ link = f"/keras_hub/api/models/{ path } "
56+ return f"[{ text } ]({ link } )"
57+ return "-"
58+
59+
60+ def format_preset_link (preset , handle ):
61+ url = handle .replace ("kaggle://" , "https://www.kaggle.com/models/" )
62+ return f"[{ preset } ]({ url } )"
5363
5464
5565def is_base_class (symbol ):
@@ -61,35 +71,38 @@ def is_base_class(symbol):
6171 )
6272
6373
64- def render_all_presets (symbols ):
65- """Renders the markdown table for backbone presets as a string."""
74+ def sort_presets (presets ):
75+ # Sort by path and then by parameter count.
76+ return sorted (
77+ presets .keys (),
78+ key = lambda x : (
79+ presets [x ]["metadata" ]["path" ],
80+ presets [x ]["metadata" ]["params" ],
81+ )
82+ )
83+
84+
85+ def render_row (preset , data , add_doc_link = False ):
86+ """Renders a row for a preset in a markdown table."""
87+ metadata = data ["metadata" ]
88+ url = data ["kaggle_handle" ]
89+ url = url .replace ("kaggle://" , "https://www.kaggle.com/models/" )
90+ cols = []
91+ cols .append (format_preset_link (preset , data ["kaggle_handle" ]))
92+ if add_doc_link :
93+ cols .append (format_path (metadata ))
94+ cols .append (format_param_count (metadata ))
95+ cols .append (metadata ["description" ])
96+ return " | " .join (cols ) + "\n "
6697
67- table = TABLE_HEADER
6898
69- # Backbones has alias, which duplicates some presets.
70- # Use a set to keep them unique.
71- added_presets = set ()
72- # Bakcbone presets
73- for name , symbol in symbols :
74- if is_base_class (symbol ) or "Backbone" not in name :
75- continue
76- presets = symbol .presets
77- # Only keep the ones with pretrained weights for KerasCV Backbones.
78- for preset in presets :
79- if preset in added_presets :
80- continue
81- else :
82- added_presets .add (preset )
83- metadata = presets [preset ]["metadata" ]
84- url = presets [preset ]["kaggle_handle" ]
85- url = url .replace ("kaggle://" , "https://www.kaggle.com/models/" )
86- table += (
87- f"[{ preset } ]({ url } ) | "
88- f"{ format_path (metadata )} | "
89- f"{ format_param_count (metadata )} | "
90- f"{ metadata ['description' ]} "
91- )
92- table += "\n "
99+ def render_all_presets ():
100+ """Renders the markdown table for backbone presets as a string."""
101+ table = TABLE_HEADER
102+ symbol = keras_hub .models .Backbone
103+ for preset in sort_presets (symbol .presets ):
104+ data = symbol .presets [preset ]
105+ table += render_row (preset , data , add_doc_link = True )
93106 return table
94107
95108
@@ -100,15 +113,9 @@ def render_table(symbol):
100113 table = TABLE_HEADER_PER_MODEL
101114 if is_base_class (symbol ) or len (symbol .presets ) == 0 :
102115 return None
103- for preset in symbol .presets :
104- metadata = symbol .presets [preset ]["metadata" ]
105- url = symbol .presets [preset ]["kaggle_handle" ]
106- url = url .replace ("kaggle://" , "https://www.kaggle.com/models/" )
107- table += (
108- f"[{ preset } ]({ url } ) | "
109- f"{ format_param_count (metadata )} | "
110- f"{ metadata ['description' ]} \n "
111- )
116+ for preset in sort_presets (symbol .presets ):
117+ data = symbol .presets [preset ]
118+ table += render_row (preset , data )
112119 return table
113120
114121
@@ -117,9 +124,6 @@ def render_tags(template):
117124 if keras_hub is None :
118125 return template
119126
120- symbols = keras_hub .models .__dict__ .items ()
121127 if "{{presets_table}}" in template :
122- template = template .replace (
123- "{{presets_table}}" , render_all_presets (symbols )
124- )
128+ template = template .replace ("{{presets_table}}" , render_all_presets ())
125129 return template
0 commit comments