Skip to content

Commit 78df6f5

Browse files
authored
New tools for model preset admin (#2025)
A tool to update all json files for a preset (by running them through Keras' serialize and deserialize routines). A tool to update all preset versions in the library to the latest version on Kaggle.
1 parent 9b024bd commit 78df6f5

File tree

3 files changed

+187
-1
lines changed

3 files changed

+187
-1
lines changed

tools/admin/mirror_weights_on_hf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
kagglehub = None
1717

1818
HF_BASE_URI = "hf://keras"
19-
JSON_FILE_PATH = "tools/hf_uploaded_presets.json"
19+
JSON_FILE_PATH = "tools/admin/hf_uploaded_presets.json"
2020
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
2121

2222

tools/admin/update_all_json.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Update all json files for all models on Kaggle.
2+
3+
Run tools/admin/update_all_versions.py before running this tool to make sure
4+
all our kaggle links point to the latest version!
5+
6+
This script downloads all models from KaggleHub, loads and re-serializes all
7+
json files, and reuploads them. This can be useful when changing our metadata or
8+
updating our saved configs.
9+
10+
This script relies on private imports from preset_utils and may need updates
11+
when it is re-run.
12+
13+
Usage:
14+
```
15+
# Preview changes.
16+
python tools/admin/update_all_json.py
17+
# Upload changes.
18+
python tools/admin/update_all_json.py --upload
19+
# Resume after a failure.
20+
python tools/admin/update_all_json.py --upload --start_at=gemma_2b_en
21+
```
22+
"""
23+
24+
import difflib
25+
import os
26+
import pathlib
27+
import shutil
28+
29+
import kagglehub
30+
import torch
31+
from absl import app
32+
from absl import flags
33+
34+
os.environ["KERAS_BACKEND"] = "torch"
35+
36+
import keras_hub
37+
from keras_hub.src.utils import preset_utils
38+
39+
# Command line flags, read through `FLAGS` inside `main`:
#   --upload    boolean, defaults to False; when set, updated models are
#               re-uploaded instead of only previewed.
#   --start_at  string, defaults to ""; when non-empty, processing resumes
#               at the given preset name.
FLAGS = flags.FLAGS
40+
flags.DEFINE_boolean("upload", False, "Upload updated models.")
41+
flags.DEFINE_string("start_at", "", "Resume at given preset.")
42+
43+
44+
# ANSI escape sequences used to style the text printed to the terminal.
# `RESET` is emitted after each styled span to restore default rendering.
BOLD = "\033[1m"
45+
GREEN = "\033[92m"
46+
RED = "\033[91m"
47+
RESET = "\033[0m"
48+
49+
50+
def diff(in_path, out_path):
    """Print a colorized unified diff of two text files.

    Added lines are printed in green and removed lines in red. When the
    input file is a `metadata.json`, every line containing the substring
    "date" is dropped from both files before comparing, so that a changed
    upload date alone never counts as a difference.

    Args:
        in_path: Path of the original file (`str` or `os.PathLike`).
        out_path: Path of the regenerated file (`str` or `os.PathLike`).

    Returns:
        `True` if the files differ (the diff has been printed), `False` if
        they are identical after filtering (nothing is printed).
    """
    # Normalize so callers may pass plain strings as well as Path objects.
    in_path = pathlib.Path(in_path)
    out_path = pathlib.Path(out_path)
    # Read with an explicit encoding instead of the locale default.
    with open(in_path, encoding="utf-8") as in_file:
        in_lines = in_file.readlines()
    with open(out_path, encoding="utf-8") as out_file:
        out_lines = out_file.readlines()
    # Ignore updates to upload_date.
    if "metadata.json" in in_path.name:
        in_lines = [line for line in in_lines if "date" not in line]
        out_lines = [line for line in out_lines if "date" not in line]
    # Materialize the generator so emptiness can be checked before printing.
    delta = list(difflib.unified_diff(in_lines, out_lines))
    if not delta:
        return False
    for line in delta:
        if line.startswith("+"):
            print(" " + GREEN + line + RESET, end="")
        elif line.startswith("-"):
            print(" " + RED + line + RESET, end="")
        else:
            print(" " + line, end="")
    print()
    return True
74+
75+
76+
def main(argv):
    """Re-serialize the json files of every preset and optionally re-upload.

    For each preset in `keras_hub.models.Backbone.presets` (in sorted order,
    optionally starting at `--start_at`), the model is downloaded with
    KaggleHub and copied to `updates/<name>`. Every json file is then loaded
    and saved again through the `preset_utils` loader/saver, and a diff
    against the downloaded copy is printed. Presets with at least one changed
    file are uploaded when `--upload` is set; otherwise they are only
    reported.

    Args:
        argv: Positional command line arguments passed by `app.run` (unused).

    Raises:
        ValueError: If `--start_at` is not the name of a known preset.
    """
    presets = keras_hub.models.Backbone.presets
    output_parent = pathlib.Path("updates")
    output_parent.mkdir(parents=True, exist_ok=True)

    remaining = sorted(presets.keys())
    if FLAGS.start_at:
        # Fail with an actionable message rather than the bare
        # "x is not in list" error from `list.index`.
        if FLAGS.start_at not in presets:
            raise ValueError(
                f"Unknown preset passed to --start_at: {FLAGS.start_at!r}. "
                "Expected a key of `keras_hub.models.Backbone.presets`."
            )
        remaining = remaining[remaining.index(FLAGS.start_at) :]

    for preset in remaining:
        # The handle ends in a version component; the component before it
        # is used as the output directory name.
        handle = presets[preset]["kaggle_handle"].removeprefix("kaggle://")
        handle_no_version = os.path.dirname(handle)
        builtin_name = os.path.basename(handle_no_version)

        # Download the full model with KaggleHub.
        input_dir = kagglehub.model_download(handle)
        input_dir = pathlib.Path(input_dir)
        output_dir = output_parent / builtin_name
        # Start from a fresh copy so stale files from an earlier run cannot
        # end up in the upload.
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        shutil.copytree(input_dir, output_dir)

        # Manually create saver/loader objects.
        config = preset_utils.load_json(preset, preset_utils.CONFIG_FILE)
        loader = preset_utils.KerasPresetLoader(preset, config)
        saver = preset_utils.KerasPresetSaver(output_dir)

        # Update all json files.
        print(BOLD + handle + RESET)
        updated = False
        for file in input_dir.glob("*.json"):
            if file.name == preset_utils.METADATA_FILE:
                # metadata.json is handled together with config.json.
                continue
            file_config = preset_utils.load_json(preset, file.name)
            layer = loader._load_serialized_object(file_config)
            saver._save_serialized_object(layer, file.name)
            if file.name == preset_utils.CONFIG_FILE:
                # Handle metadata.json with config.json.
                print(" " + BOLD + preset_utils.METADATA_FILE + RESET)
                saver._save_metadata(layer)
                name = preset_utils.METADATA_FILE
                if diff(input_dir / name, output_dir / name):
                    updated = True
            print(" " + BOLD + file.name + RESET)
            if diff(input_dir / file.name, output_dir / file.name):
                updated = True
            # Drop the reference before the next file is deserialized.
            del layer

        if not updated:
            continue

        # Reupload the model if any json files were updated.
        if FLAGS.upload:
            print(BOLD + "Uploading " + handle_no_version + RESET)
            kagglehub.model_upload(
                handle_no_version,
                output_dir,
                version_notes="updated json files",
            )
        else:
            print(BOLD + "Preview. Not uploading " + handle_no_version + RESET)

    if FLAGS.upload:
        print(BOLD + "Wait a few hours (for kaggle to process models)." + RESET)
        # The companion script lives under tools/admin, not tasks/admin.
        print(BOLD + "Then run tools/admin/update_all_versions.py" + RESET)
142+
143+
144+
# Script entry point. `app.run(main)` is invoked inside a
# `torch.device("meta")` context.
# NOTE(review): presumably this keeps model weights from being materialized
# while configs are deserialized; confirm against the PyTorch meta device
# documentation.
if __name__ == "__main__":
145+
with torch.device("meta"):
146+
app.run(main)

tools/admin/update_all_versions.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Update all preset files to use the latest version on kaggle.
2+
3+
Run from the base of the repo.
4+
5+
Usage:
6+
```
7+
python tools/admin/update_all_versions.py
8+
```
9+
"""
10+
11+
import os
12+
import pathlib
13+
14+
import kagglehub
15+
16+
import keras_hub
17+
18+
19+
def update():
    """Point every preset at the latest version available on Kaggle.

    For each preset in `keras_hub.models.Backbone.presets`, `metadata.json`
    is downloaded from the unversioned Kaggle handle, and the latest version
    is read from the name of the directory that contains the downloaded
    file. When it differs from the version in the preset's `kaggle_handle`,
    the quoted uri is replaced in every `keras_hub/**/*_presets.py` file
    found relative to the current working directory.

    Files whose contents do not change are left untouched.
    """
    presets = keras_hub.models.Backbone.presets
    # The set of preset files does not change during the run; list it once.
    preset_files = list(pathlib.Path(".").glob("keras_hub/**/*_presets.py"))
    for preset in sorted(presets.keys()):
        uri = presets[preset]["kaggle_handle"]
        kaggle_handle = uri.removeprefix("kaggle://")
        old_version = os.path.basename(kaggle_handle)
        unversioned_handle = os.path.dirname(kaggle_handle)
        hub_path = kagglehub.model_download(
            unversioned_handle, path="metadata.json"
        )
        new_version = os.path.basename(os.path.dirname(hub_path))
        if old_version == new_version:
            continue
        print(f"Updating {preset} from {old_version} to {new_version}")
        new_uri = os.path.dirname(uri) + f"/{new_version}"
        for path in preset_files:
            with open(path, "r", encoding="utf-8") as file:
                contents = file.read()
            # Match the quoted uri so a handle that is a prefix of another
            # handle is not rewritten by accident.
            new_contents = contents.replace(f'"{uri}"', f'"{new_uri}"')
            if new_contents == contents:
                # Nothing to replace here; avoid rewriting the file.
                continue
            with open(path, "w", encoding="utf-8") as file:
                file.write(new_contents)
37+
38+
39+
# Script entry point: run the version update when executed directly.
if __name__ == "__main__":
40+
update()

0 commit comments

Comments
 (0)