Basic Usage
This project is designed to fit directly into a basic TensorFlow setup. It's easy to get started: create a functional or sequential model and call create_model on it, as demonstrated at the bottom of the README.
Some miscellaneous points
- You can compile and train an mlvtk model as you would a standard model. This includes using callbacks and any other built-in TensorFlow features. If you find something that does not work, please open an issue.
- Currently, subclassed models are not supported; however, this is something I want to support in the future.
- mlag.surface_plot() has an argument called return_traces. If you wish to customize your plots but don't want to work directly with the pandas DataFrame (mlag.loss_df), set this argument to True. It will return the plotly traces for the surface plot and for the Scatter3d used to draw the optimizer path.
- The ext argument for mlag.surface_plot() is used to extend or calculate the bounds of alphas/betas. For example, mlag.surface_plot(ext=1) will create alphas/betas like so:
self.alphas = np.linspace(
self.xdir.min() - 1, self.xdir.max() + 1, num=50, dtype=np.float32
)
self.betas = np.linspace(
self.ydir.min() - 1, self.ydir.max() + 1, num=50, dtype=np.float32
)
ext can take a numeric value, the string 'std' to indicate the standard deviation, or an arbitrary expression.
- Currently, the surface plot is created from the model's final location, i.e. the weight values once training is complete. If you wish to visualize the surface at any other point, you can load a saved checkpoint as in this example. This will be made easier in the future:
import tensorflow as tf
# create_model comes from mlvtk; import it as shown in the example at the bottom of the README

m = tf.keras.models.load_model('vwd/model_10.h5')  # load the checkpoint
new_m = create_model(m)  # create the mlvtk model
new_m.compile(optimizer=tf.keras.optimizers.SGD(),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])  # compile the mlvtk model
# SET TEST DATA: needed because this model is not being trained here (i.e. new_m.fit(...) is never called)
new_m.testdat = mnist_test_data  # mnist_test_data is a placeholder for your own test data
new_m.gen_path()  # manually call gen_path to calculate xdir and ydir
new_m.xdir = new_m.xdir[:10]  # trim xdir and ydir to show ONLY up to the loaded epoch (here, 10);
new_m.ydir = new_m.ydir[:10]  # otherwise points for ALL epochs saved in 'vwd/' are plotted
new_m.surface_plot()  # draw the surface plot
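To make the ext bounds for mlag.surface_plot() concrete, here is a minimal pure-Python sketch of how one axis of the grid (alphas from xdir, or betas from ydir) could be derived, mirroring the np.linspace snippet above. The helpers linspace and axis_grid are hypothetical and not part of mlvtk. Treating 'std' as padding by the population standard deviation of the path coordinates is an assumption, and the arbitrary-expression form of ext is not modeled.

```python
import statistics
from typing import List, Sequence, Union


def linspace(start: float, stop: float, num: int = 50) -> List[float]:
    """Return `num` evenly spaced floats from start to stop inclusive (like np.linspace)."""
    if num == 1:
        return [float(start)]
    step = (stop - start) / (num - 1)
    # Build the first num-1 points, then append stop exactly to avoid rounding drift.
    return [start + i * step for i in range(num - 1)] + [float(stop)]


def axis_grid(direction: Sequence[float], ext: Union[int, float, str] = 1,
              num: int = 50) -> List[float]:
    """Hypothetical helper: grid for one plot axis, padded on both ends.

    ext is either a number (pad by that amount, as in surface_plot(ext=1))
    or 'std' (ASSUMED here to mean pad by the population standard deviation
    of the path coordinates). Other forms of ext are not handled.
    """
    if ext == "std":
        pad = statistics.pstdev(direction)
    elif isinstance(ext, (int, float)):
        pad = float(ext)
    else:
        raise ValueError("ext must be a number or 'std' in this sketch")
    return linspace(min(direction) - pad, max(direction) + pad, num=num)
```

For a path with xdir = [0.0, 2.0], ext=1 and ext='std' both give alphas running from -1.0 to 3.0, because the population standard deviation of those two points is 1.0.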
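The [:10] slices in the checkpoint example tie the number of plotted path points to the loaded epoch. A small sketch of deriving that count from the checkpoint file name instead of hard-coding it is below. It assumes checkpoints are named model_<epoch>.h5 and that xdir/ydir hold one point per saved epoch, as the example's trimming suggests; epoch_from_checkpoint and trim_path are hypothetical helpers, not mlvtk functions.

```python
import re
from pathlib import Path
from typing import Sequence, Tuple


def epoch_from_checkpoint(path: str) -> int:
    """Parse the epoch number from a file named like 'model_10.h5' (assumed naming scheme)."""
    match = re.fullmatch(r"model_(\d+)\.h5", Path(path).name)
    if match is None:
        raise ValueError(f"unrecognized checkpoint name: {path!r}")
    return int(match.group(1))


def trim_path(xdir: Sequence, ydir: Sequence, checkpoint: str) -> Tuple[Sequence, Sequence]:
    """Keep only the first N path points, where N is the checkpoint's epoch number."""
    n = epoch_from_checkpoint(checkpoint)
    return xdir[:n], ydir[:n]
```

With a loaded mlvtk model, the two trimming lines would become new_m.xdir, new_m.ydir = trim_path(new_m.xdir, new_m.ydir, 'vwd/model_10.h5'). Slicing behaves the same on Python lists and NumPy arrays, so the helper does not care which one xdir and ydir are.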