Skip to content

Commit 96e4c0e

Browse files
authored
Merge pull request #1 from dreoporto/load-or-fit-model
Load or fit model
2 parents 456f062 + ed7e417 commit 96e4c0e

5 files changed

Lines changed: 214 additions & 4 deletions

File tree

README.md

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22

33
## Summary
44

5-
**PTMLib** is a set of utilities that I have used with Machine Learning frameworks such as Scikit-Learn and TensorFlow. The purpose is to eliminate code that I kept repeating in multiple projects.
5+
**PTMLib** is a set of utilities for use with Machine Learning frameworks such as Scikit-Learn and TensorFlow.
66

77
- **ptmlib.time.Stopwatch** - measure the time it takes to complete a long-running task, with an audio alert for task completion
88
- **ptmlib.cpu.CpuCount** - get info on CPUs available, with options to adjust/exclude based on a specific number/percentage. Useful for setting `n_jobs` in Scikit-Learn tools that support multiple CPUs, such as `RandomForestClassifier`
9-
- **ptmlib.charts** - render separate line charts for TensorFlow accuracy and loss, with corresponding validation data if available
10-
11-
*This code and documentation were created by Andre Oporto at [Pendragon AI](https://www.pendragonai.com)*
9+
- **ptmlib.charts** - render separate line charts for TensorFlow metrics such as accuracy and loss, with corresponding validation data if available
10+
- **ptmlib.model_tools.load_or_fit_model()** - train, save, and reload TensorFlow models and metric charts automatically, making it easier to pick up where you left off
1211

1312
## ptmlib.time.Stopwatch
1413

@@ -133,6 +132,44 @@ TensorFlow History Loss Chart: *loss-20210201-111545.png*
133132

134133
![TF History Loss Chart](ptmlib/examples/loss-20210201-111545.png)
135134

135+
The default file name format for these images is *searchstring-timestamp.png*. The `file_name_suffix` parameter lets you replace the timestamp with another value, for more predictable filenames to simplify reuse of images in your code.
136+
137+
## ptmlib.model_tools.load_or_fit_model()
138+
139+
The `ptmlib.model_tools.load_or_fit_model()` function makes it easy to train and save a model for later use, in cases where you may need to stop and restart work in Jupyter or your IDE *after* model training has completed. This can be very helpful when working through a long and detailed notebook with multiple example models, where some models take significant time to train. You can avoid repeatedly training models you are satisfied with and have completed, and still close and reopen your notebook as needed.
140+
141+
### Example Usage:
142+
143+
```python
144+
# from examples/computer_vision_caching.py
145+
146+
import ptmlib.model_tools as modt
147+
148+
...
149+
150+
model_file_name = "computer_vision_1"
151+
152+
...
153+
154+
fit_model_function_with_callback = lambda my_model, x, y, validation_data, epochs: my_model.fit(
155+
x, y, validation_data=validation_data, epochs=epochs, callbacks=[early_callback],
156+
validation_split=hp_validation_split)
157+
158+
# if this has previously been executed, we will load the trained/saved model
159+
model, history = modt.load_or_fit_model(model, model_file_name, x=training_images, y=training_labels,
160+
epochs=hp_epochs, fit_model_function=fit_model_function_with_callback, metrics=["accuracy"])
161+
162+
model.evaluate(test_images, test_labels)
163+
```
164+
165+
### Example Output:
166+
167+
You will see output similar to the following if you re-run a previously saved notebook where `load_or_fit_model` was used.
168+
169+
![Sample load_or_fit_model Screenshot](ptmlib/media/load_or_fit_model_screenshot.png)
170+
171+
If you wish to retrain a model that has previously been saved, simply delete the model file and related images, which are stored as `h5` and `png` files respectively.
172+
136173
## Installation
137174

138175
To install `ptmlib` in a virtualenv or conda environment:

ptmlib/charts.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ def show_history_chart(history: Any, search_string: str, fig_size: (int, int) =
2828

2929
filtered_hist = {k: v for (k, v) in history.history.items() if search_string in k}
3030

31+
if len(filtered_hist.keys()) == 0:
32+
print('No data to plot for search_string:', search_string)
33+
return
34+
3135
pd.DataFrame(filtered_hist).plot(figsize=fig_size)
3236
plt.grid(True, which='major')
3337
plt.grid(True, which='minor', alpha=0.3, linestyle='--')
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# THIS CODE IS A MODULARIZED AND UPDATED VERSION OF CODE FROM THE "DEEPLEARNING.AI TENSORFLOW DEVELOPER" COURSE
2+
# SOURCE:
3+
# https://github.com/lmoroney/dlaicourse/blob/master/Course%201%20-%20Part%204%20-%20Lesson%202%20-%20Notebook.ipynb
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
import tensorflow as tf
8+
from tensorflow import keras
9+
from tensorflow.keras import layers
10+
11+
import ptmlib.model_tools as modt
12+
13+
14+
class MyCallback(keras.callbacks.Callback):
    """Keras callback that stops training once the logged accuracy exceeds a target.

    Args:
        target: accuracy threshold expressed as a fraction (the example passes 0.91).
            Training is stopped at the end of the first epoch whose "accuracy" log
            value is strictly greater than this threshold.
    """

    def __init__(self, target):
        super().__init__()
        self.target = target

    def on_epoch_end(self, _, logs=None):
        """Inspect the epoch's logs and request early stopping when the target is exceeded.

        Args:
            _: epoch index supplied by Keras; unused.
            logs: metric values for the finished epoch, or None.
        """
        accuracy = (logs or {}).get("accuracy")

        # "accuracy" is missing when logs are empty or the metric is not being tracked;
        # comparing None against a float would raise TypeError, so skip the check.
        if accuracy is None:
            return

        if accuracy > self.target:
            print(f"\nReached {self.target * 100}% accuracy so cancelling training!")
            self.model.stop_training = True
27+
28+
29+
def print_diagnostics() -> None:
    """Print the TensorFlow and Keras version strings of the imported modules."""
    for label, module in (('TF VERSION:', tf), ('KERAS VERSION:', keras)):
        print(label, module.__version__)
32+
33+
34+
def get_data():
    """Load Fashion MNIST, preview the first training sample, and normalize the images.

    Side effects: sets the NumPy print line width to 200, displays the first
    training image with matplotlib, and prints its label.

    Returns:
        ((training_images, training_labels), (test_images, test_labels)), where both
        image arrays have been divided by 255.0 and the labels are returned as loaded.
    """
    train_set, test_set = keras.datasets.fashion_mnist.load_data()
    train_x, train_y = train_set
    test_x, test_y = test_set

    np.set_printoptions(linewidth=200)

    # preview the first training sample: image first, then its label
    plt.imshow(train_x[0])
    plt.show()
    print(train_y[0])

    # normalize image data to values between 0 and 1
    return (train_x / 255.0, train_y), (test_x / 255.0, test_y)
49+
50+
51+
def get_model() -> keras.models.Sequential:
    """Build, summarize, and compile the example classifier.

    The network is Flatten(28x28) -> Dropout(0.2) -> Dense(512, relu) ->
    Dense(10, softmax), compiled with Adam, sparse categorical cross-entropy
    loss, and the "accuracy" metric.

    Returns:
        The compiled (untrained) Sequential model.
    """
    network_layers = [
        layers.Flatten(input_shape=(28, 28)),
        layers.Dropout(0.2),
        layers.Dense(512, activation=tf.nn.relu),
        layers.Dense(10, activation=tf.nn.softmax),
    ]
    classifier = keras.models.Sequential(network_layers)

    # print the layer summary before compiling
    classifier.summary()

    compile_options = dict(
        optimizer=tf.optimizers.Adam(),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    classifier.compile(**compile_options)

    return classifier
68+
69+
70+
def main():
    """Train (or reload) the Fashion MNIST example model, evaluate it, and print one prediction."""

    # HYPER PARAMS, CONSTANTS, ETC
    hp_epochs = 50
    hp_target = 0.91
    hp_validation_split = 0.2
    model_file_name = "computer_vision_1"

    print_diagnostics()

    (training_images, training_labels), (test_images, test_labels) = get_data()

    model = get_model()

    early_callback = MyCallback(target=hp_target)

    def fit_model_function_with_callback(my_model, x, y, validation_data, epochs):
        # validation_data must be passed by keyword: the third positional parameter
        # of Keras Model.fit() is batch_size, so passing it positionally would
        # silently set the batch size instead of the validation set.
        return my_model.fit(
            x, y, validation_data=validation_data, epochs=epochs, callbacks=[early_callback],
            validation_split=hp_validation_split)

    # if this has previously been executed, we will load the trained/saved model
    # (history is None in that case, because no training took place)
    model, history = modt.load_or_fit_model(model, model_file_name, x=training_images, y=training_labels,
                                            epochs=hp_epochs, fit_model_function=fit_model_function_with_callback,
                                            metrics=["accuracy"])

    model.evaluate(test_images, test_labels)

    # show the class scores for the first test image, its true label, and the top score
    classifications = model.predict(test_images)
    print(classifications[0])
    print(test_labels[0])
    print(max(classifications[0]))
99+
100+
101+
if __name__ == '__main__':
102+
main()
42.1 KB
Loading

ptmlib/model_tools.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import os
2+
from typing import Any, List
3+
4+
import matplotlib.image as mpimg
5+
import matplotlib.pyplot as plt
6+
from tensorflow import keras
7+
8+
import ptmlib.charts as pch
9+
from ptmlib.time import Stopwatch
10+
11+
12+
def default_load_model_function(model_file_name: str):
    """Load a saved Keras model from ``<model_file_name>.h5``.

    Args:
        model_file_name: model file path WITHOUT the ``.h5`` extension.

    Returns:
        Whatever ``keras.models.load_model`` returns for that file.
    """
    model_path = f'{model_file_name}.h5'
    return keras.models.load_model(model_path)
14+
15+
16+
def _default_fit_model_function(model: Any, x: Any, y: Any = None, validation_data: Any = None, epochs: int = 1):
    """Default training routine: call ``model.fit`` with the given data and epoch count.

    Args:
        model: object exposing a ``fit`` method.
        x: passed to ``fit`` as the first positional argument.
        y: passed to ``fit`` as the second positional argument.
        validation_data: forwarded to ``fit`` by keyword.
        epochs: forwarded to ``fit`` by keyword.

    Returns:
        The value returned by ``model.fit``.
    """
    fit_options = {'validation_data': validation_data, 'epochs': epochs}
    return model.fit(x, y, **fit_options)
18+
19+
20+
def load_or_fit_model(model: Any, model_file_name: str, x: Any, y: Any = None, validation_data: Any = None,
                      epochs: int = 1, metrics: List[str] = None, images_enabled=True, fig_size: (int, int) = (10, 6),
                      load_model_function=default_load_model_function,
                      fit_model_function=_default_fit_model_function):
    """Reload a previously saved model if ``<model_file_name>.h5`` exists, otherwise train and save it.

    Args:
        model: model to train when no saved file exists; replaced by the loaded model otherwise.
        model_file_name: file path WITHOUT the ``.h5`` extension; also used to name the chart images.
        x, y, validation_data, epochs: passed positionally, in that order after ``model``,
            to ``fit_model_function``.
        metrics: metric names whose charts are shown; a "loss" chart is always included.
        images_enabled: when False, no charts/images are shown.
        fig_size: figure size used when re-displaying saved chart images. It is not
            forwarded when charts are rendered right after training.
        load_model_function: called as ``load_model_function(model_file_name)``.
        fit_model_function: called as ``fit_model_function(model, x, y, validation_data, epochs)``.

    Returns:
        ``(model, history)``. ``history`` is None when the model was loaded from disk.
    """
    if not os.path.exists(f'{model_file_name}.h5'):
        # no saved model yet: train (timed), save, and chart the fresh history
        timer = Stopwatch()
        timer.start()
        history = fit_model_function(model, x, y, validation_data, epochs)
        timer.stop()

        print(f'Saving new model file: {model_file_name}.h5')
        model.save(f'{model_file_name}.h5')
        if images_enabled:
            _show_new_images(history, model_file_name, metrics)
        return model, history

    # saved model found: reload it and re-display any chart images saved with it
    print(f'Loading existing model file: {model_file_name}.h5')
    model = load_model_function(model_file_name)
    if images_enabled:
        _show_saved_images(metrics, model_file_name, fig_size)

    return model, None
42+
43+
44+
def _show_new_images(history: Any, model_file_name: str, metrics: List[str]):
    """Render and save a history chart for each requested metric, plus one for "loss".

    Args:
        history: training history object passed through to ``pch.show_history_chart``.
        model_file_name: used as the chart ``file_name_suffix``.
        metrics: metric search strings to chart; may be None.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so "loss" is charted
    # (and its image written) exactly once even when the caller also lists it in metrics
    for search_string in dict.fromkeys([*(metrics or []), "loss"]):
        pch.show_history_chart(history, search_string, save_fig_enabled=True, file_name_suffix=model_file_name)
49+
50+
51+
def _show_saved_images(metrics: List[str], model_file_name: str, fig_size: (int, int) = (10, 6)):
    """Display the saved chart image for each requested metric, plus "loss", when the file exists.

    Image files are looked up as ``<metric>-<model_file_name>.png``; missing files are skipped.

    Args:
        metrics: metric names whose images should be shown; may be None.
        model_file_name: suffix used in the image file names.
        fig_size: figure size passed to ``_show_saved_image``.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the loss image is
    # shown exactly once even when the caller also lists "loss" in metrics
    for prefix in dict.fromkeys([*(metrics or []), "loss"]):
        image_file = f'{prefix}-{model_file_name}.png'
        if os.path.exists(image_file):
            _show_saved_image(image_file, fig_size)
58+
59+
60+
def _show_saved_image(filename: str, fig_size: (int, int) = (10, 6)):
    """Display an image file in a new matplotlib figure with the axes hidden.

    Args:
        filename: path of the image to read with ``matplotlib.image.imread``.
        fig_size: size of the figure, passed as ``figsize``.
    """
    pixels = mpimg.imread(filename)

    figure = plt.figure(figsize=fig_size)

    # a single axes spanning the whole figure (rect = [left, bottom, width, height])
    full_axes = plt.Axes(figure, [0., 0., 1., 1.])
    figure.add_axes(full_axes)

    full_axes.axis('off')
    full_axes.imshow(pixels)
    plt.show()

0 commit comments

Comments
 (0)