
Commit b26a22d

DOC/ENH: Better AutoEncoder examples & support (#200)
1 parent 302dd7c commit b26a22d

File tree

7 files changed: +170 -86 lines changed

  docs/source/notebooks/AutoEncoders.md
  docs/source/notebooks/Basic_Usage.md
  docs/source/notebooks/Benchmarks.md
  docs/source/notebooks/DataTransformers.md
  docs/source/notebooks/MLPClassifier_MLPRegressor.md
  docs/source/notebooks/Meta_Estimators.md
  scikeras/wrappers.py

docs/source/notebooks/AutoEncoders.md

Lines changed: 153 additions & 70 deletions
@@ -5,8 +5,8 @@ jupyter:
     text_representation:
       extension: .md
       format_name: markdown
-      format_version: '1.2'
-      jupytext_version: 1.9.1
+      format_version: '1.3'
+      jupytext_version: 1.10.2
   kernelspec:
     display_name: Python 3
     language: python
@@ -19,7 +19,9 @@ jupyter:
 
 # Autoencoders in SciKeras
 
-Autencoders are an approach to use nearual networks to distill data into it's most important features, thereby compressing the data. We will be following the [Keras tutorial](https://blog.keras.io/building-autoencoders-in-keras.html) on the topic, which goes much more in depth and breadth than we will here. You are highly encouraged to check out that tutorial if you want to learn about autoencoders in the general sense.
+Autoencoders are an approach that uses neural networks to distill data into its most important features, thereby compressing the data.
+We will be following the [Keras tutorial](https://blog.keras.io/building-autoencoders-in-keras.html) on the topic, which goes into much more depth and breadth than we will here.
+You are highly encouraged to check out that tutorial if you want to learn about autoencoders in the general sense.
 
 ## Table of contents
 
@@ -28,6 +30,7 @@ Autencoders are an approach to use nearual networks to distill data into it's mo
 * [3. Define Keras Model](#3.-Define-Keras-Model)
 * [4. Training](#4.-Training)
 * [5. Explore Results](#5.-Explore-Results)
+* [6. Deep AutoEncoder](#6.-Deep-AutoEncoder)
 
 ## 1. Setup
 
@@ -73,99 +76,110 @@ print(x_test.shape)
 
 ## 3. Define Keras Model
 
-We will be defining a very simple autencoder. We define _three_ model building methods:
+We will be defining a very simple autoencoder. We define _three_ model architectures:
 
-1. One to build a full end-to-end autoencoder.
-2. One to create a model that includes only the encoder portion.
-3. One that creates a model that includes only the decoder portion.
+1. An encoder: a series of densely connected layers culminating in an "output" layer that determines the encoding dimensions.
+2. A decoder: takes the output of the encoder as its input and reconstructs the original data.
+3. An autoencoder: a chain of the encoder and decoder that directly connects them for training purposes.
 
 The only variable we give our model is the encoding dimension, which will be a hyperparameter of our final transformer.
 
-```python
-from tensorflow import keras
+The encoder and decoder are views onto the first/last layers of the autoencoder model.
+They'll be directly used in `transform` and `inverse_transform`, so we'll create some SciKeras models with those layers
+and save them as `encoder_model_` and `decoder_model_`. All three models are created within `_keras_build_fn`.
 
+For background on chaining Functional Models like this, see [All models are callable](https://keras.io/guides/functional_api/#all-models-are-callable-just-like-layers) in the Keras docs.
 
-def get_fit_model(encoding_dim: int) -> keras.Model:
-    """Get an autoencoder.
+```python
+from typing import Dict, Any
 
-    This autoencoder compresses a 28x28 image (784 pixels) down to a feature of length
-    `encoding_dim`, and tries to reconstruct the input image from that vector.
-    """
-    input_img = keras.Input(shape=(784,), name="input")
-    encoded = keras.layers.Dense(encoding_dim, activation='relu', name="encoded")(input_img)
-    decoded = keras.layers.Dense(784, activation='sigmoid', name="output")(encoded)
-    autoencoder_model = keras.Model(input_img, decoded)
-    return autoencoder_model
+from sklearn.base import TransformerMixin
+from sklearn.metrics import mean_squared_error
+from scikeras.wrappers import BaseWrapper
 
-def get_tf_model(fit_model: keras.Model) -> keras.Model:
-    """Get an encoder model.
 
-    We do this by extracting the encoding layer from the fitted autoencoder model.
+class AutoEncoder(BaseWrapper, TransformerMixin):
+    """A class that enables transform and fit_transform.
     """
-    return keras.Model(fit_model.get_layer("input").input, fit_model.get_layer("encoded").output)
 
-def get_inverse_tf_model(fit_model: keras.Model, encoding_dim: int) -> keras.Model:
-    """Get an deencoder model.
+    encoder_model_: BaseWrapper
+    decoder_model_: BaseWrapper
+
+    def _keras_build_fn(self, encoding_dim: int, meta: Dict[str, Any]):
+        n_features_in = meta["n_features_in_"]
 
-    We do this by extracting the deencoding layer from the fitted autoencoder model
-    and adding a new Keras input layer.
-    """
-    encoded_input = keras.Input(shape=(encoding_dim,))
-    output = fit_model.get_layer("output")(encoded_input)
-    return keras.Model(encoded_input, output)
-```
+        encoder_input = keras.Input(shape=(n_features_in,))
+        encoder_output = keras.layers.Dense(encoding_dim, activation='relu')(encoder_input)
+        encoder_model = keras.Model(encoder_input, encoder_output)
 
-Next we create a class that that will enable the `transform` and `fit_transform` methods, as well as integrating all three of our models into a single estimator.
+        decoder_input = keras.Input(shape=(encoding_dim,))
+        decoder_output = keras.layers.Dense(n_features_in, activation='sigmoid', name="decoder")(decoder_input)
+        decoder_model = keras.Model(decoder_input, decoder_output)
+
+        autoencoder_input = keras.Input(shape=(n_features_in,))
+        encoded_img = encoder_model(autoencoder_input)
+        reconstructed_img = decoder_model(encoded_img)
 
-```python
-from sklearn.base import TransformerMixin, clone
-from scikeras.wrappers import BaseWrapper
+        autoencoder_model = keras.Model(autoencoder_input, reconstructed_img)
 
+        self.encoder_model_ = BaseWrapper(encoder_model, verbose=self.verbose)
+        self.decoder_model_ = BaseWrapper(decoder_model, verbose=self.verbose)
 
-class KerasTransformer(BaseWrapper, TransformerMixin):
-    """A class that enables transform and fit_transform.
-    """
-
-    def __init__(self, *args, tf_est: BaseWrapper = None, inv_tf_est: BaseWrapper = None, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.tf_est = tf_est
-        self.inv_tf_est = inv_tf_est
-
+        return autoencoder_model
+
+    def _initialize(self, X, y=None):
+        X, _ = super()._initialize(X=X, y=y)
+        # since encoder_model_ and decoder_model_ share layers (and their weights),
+        # X_tf here comes from random weights, but we only use it to initialize our models
+        X_tf = self.encoder_model_.initialize(X).predict(X)
+        self.decoder_model_.initialize(X_tf)
+        return X, X
+
+    def initialize(self, X):
+        self._initialize(X=X, y=X)
+        return self
 
-    def fit(self, X, sample_weight=None):
+    def fit(self, X, *, sample_weight=None) -> "AutoEncoder":
         super().fit(X=X, y=X, sample_weight=sample_weight)
-        self.tf_est_ = clone(self.tf_est)
-        self.inv_tf_est_ = clone(self.inv_tf_est)
-        self.tf_est_.set_params(fit_model=self.model_)
-        self.inv_tf_est_.set_params(fit_model=self.model_, encoding_dim=self.encoding_dim)
-        X = self.feature_encoder_.transform(X)
-        self.tf_est_.initialize(X=X)
-        X_tf = self.tf_est_.predict(X=X)
-        self.inv_tf_est_.initialize(X_tf)
+        # at this point, encoder_model_ and decoder_model_
+        # are both "fitted" because they share layers w/ model_,
+        # which is fit in the above call
         return self
 
-    def transform(self, X):
-        X = self.feature_encoder_.transform(X)
-        X_tf = self.tf_est_.predict(X)
-        return X_tf
-
-    def inverse_transform(self, X_tf):
-        X = self.inv_tf_est_.predict(X_tf)
-        X = self.feature_encoder_.inverse_transform(X)
-        return X
+    def score(self, X) -> float:
+        # Note: we use 1 - MSE as the score.
+        # With MSE, smaller is better, but Scikit-Learn
+        # always maximizes the score (e.g. in GridSearch)
+        return 1 - mean_squared_error(self.predict(X), X)
+
+    def transform(self, X) -> np.ndarray:
+        X: np.ndarray = self.feature_encoder_.transform(X)
+        return self.encoder_model_.predict(X)
+
+    def inverse_transform(self, X_tf: np.ndarray):
+        X: np.ndarray = self.decoder_model_.predict(X_tf)
+        return self.feature_encoder_.inverse_transform(X)
 ```
 
-Next, we wrap the Keras Model with Scikeras. Note that for our encoder/decoder estimators, we do not need to provide a loss function since no training will be done. We do however need to have the `fit_model` and `encoding_dim` so that these will be settable by `BaseWrapper.set_params`.
+Next, we wrap the Keras Model with SciKeras. Note that for our encoder/decoder estimators, we do not need to provide a loss function, since no training will be done.
+We do, however, need `encoding_dim` to be a constructor parameter so that it is settable by `BaseWrapper.set_params`.
 
 ```python
-tf_est = BaseWrapper(model=get_tf_model, fit_model=None, verbose=0)
-inv_tf_est = BaseWrapper(model=get_inverse_tf_model, fit_model=None, encoding_dim=None, verbose=0)
-autoencoder = KerasTransformer(model=get_fit_model, tf_est=tf_est, inv_tf_est=inv_tf_est, loss="binary_crossentropy", encoding_dim=32, epochs=5)
+autoencoder = AutoEncoder(
+    loss="binary_crossentropy",
+    encoding_dim=32,
+    random_state=0,
+    epochs=5,
+    verbose=False,
+    optimizer="adam",
+)
 ```
 
 ## 4. Training
 
-To train the model, we pass the input images as both the features and the target. This will train the layers to compress the data as accurately as possible between the encoder and decoder. Note that we only pass the `X` parameter, since we defined the mapping `y=X` in `KerasTransformer.fit` above.
+To train the model, we pass the input images as both the features and the target.
+This will train the layers to compress the data as accurately as possible between the encoder and decoder.
+Note that we only pass the `X` parameter, since we defined the mapping `y=X` in `AutoEncoder.fit` above.
 
 ```python
 _ = autoencoder.fit(X=x_train)
@@ -208,8 +222,77 @@ What about the compression? Let's check the sizes of the arrays.
 
 ```python
 encoded_imgs = autoencoder.transform(x_test)
-print(f"x_test.shape[1]: {x_test.shape[1]}")
-print(f"encoded_imgs.shape[1]: {encoded_imgs.shape[1]}")
+print(f"x_test size (in MB): {x_test.nbytes/1024**2:.2f}")
+print(f"encoded_imgs size (in MB): {encoded_imgs.nbytes/1024**2:.2f}")
 cr = round((encoded_imgs.nbytes/x_test.nbytes), 2)
 print(f"Compression ratio: 1/{1/cr:.0f}")
 ```
+
+## 6. Deep AutoEncoder
+
+We can easily expand our model into a deep autoencoder by adding some hidden layers. All we have to do is add a parameter `hidden_layer_sizes` and use it in `_keras_build_fn` to build the hidden layers.
+For simplicity, we use a single `hidden_layer_sizes` parameter and mirror it across the encoding layers and decoding layers, but there is nothing forcing us to build symmetrical models.
+
+```python
+from typing import List
+
+
+class DeepAutoEncoder(AutoEncoder):
+    """A class that enables transform and fit_transform.
+    """
+
+    def _keras_build_fn(self, encoding_dim: int, hidden_layer_sizes: List[int], meta: Dict[str, Any]):
+        n_features_in = meta["n_features_in_"]
+
+        encoder_input = keras.Input(shape=(n_features_in,))
+        x = encoder_input
+        for layer_size in hidden_layer_sizes:
+            x = keras.layers.Dense(layer_size, activation='relu')(x)
+        encoder_output = keras.layers.Dense(encoding_dim, activation='relu')(x)
+        encoder_model = keras.Model(encoder_input, encoder_output)
+
+        decoder_input = keras.Input(shape=(encoding_dim,))
+        x = decoder_input
+        for layer_size in reversed(hidden_layer_sizes):
+            x = keras.layers.Dense(layer_size, activation='relu')(x)
+        decoder_output = keras.layers.Dense(n_features_in, activation='sigmoid', name="decoder")(x)
+        decoder_model = keras.Model(decoder_input, decoder_output)
+
+        autoencoder_input = keras.Input(shape=(n_features_in,))
+        encoded_img = encoder_model(autoencoder_input)
+        reconstructed_img = decoder_model(encoded_img)
+
+        autoencoder_model = keras.Model(autoencoder_input, reconstructed_img)
+
+        self.encoder_model_ = BaseWrapper(encoder_model, verbose=self.verbose)
+        self.decoder_model_ = BaseWrapper(decoder_model, verbose=self.verbose)
+
+        return autoencoder_model
+```
+
+```python
+deep = DeepAutoEncoder(
+    loss="binary_crossentropy",
+    encoding_dim=32,
+    hidden_layer_sizes=[128],
+    random_state=0,
+    epochs=5,
+    verbose=False,
+    optimizer="adam",
+)
+_ = deep.fit(X=x_train)
+```
+
+```python
+print("1-MSE for the test set (higher is better)\n")
+score = autoencoder.score(X=x_test)
+print(f"AutoEncoder: {score:.4f}")
+
+score = deep.score(X=x_test)
+print(f"Deep AutoEncoder: {score:.4f}")
+```
+
+Surprisingly, our score got worse. It's possible that, because of the extra trainable variables, our deep model trains more slowly than our simple model.
+
+Check out the [Keras tutorial](https://blog.keras.io/building-autoencoders-in-keras.html) to see the difference after 100 epochs of training, as well as more architectures and applications for AutoEncoders!
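
As a quick sanity check of the new API, here is a minimal round-trip sketch (not part of the commit): it assumes the `AutoEncoder` class from the diff above is in scope, and uses a random array as a hypothetical stand-in for the notebook's flattened MNIST data.

```python
import numpy as np

# Hypothetical stand-in for the notebook's x_train: flattened 28x28 images in [0, 1].
rng = np.random.default_rng(0)
X = rng.random((256, 784)).astype("float32")

autoencoder = AutoEncoder(
    loss="binary_crossentropy",
    encoding_dim=32,
    random_state=0,
    epochs=1,
    verbose=False,
    optimizer="adam",
)

# TransformerMixin derives fit_transform from fit + transform.
encoded = autoencoder.fit_transform(X)                  # shape (256, 32)
reconstructed = autoencoder.inverse_transform(encoded)  # shape (256, 784)
assert reconstructed.shape == X.shape
```

Since `fit` internally maps `y=X`, the estimator drops into scikit-learn pipelines and searches like any other transformer.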

docs/source/notebooks/Basic_Usage.md

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@ jupyter:
     text_representation:
       extension: .md
       format_name: markdown
-      format_version: '1.2'
-      jupytext_version: 1.9.1
+      format_version: '1.3'
+      jupytext_version: 1.10.2
   kernelspec:
     display_name: Python 3
     language: python

docs/source/notebooks/Benchmarks.md

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@ jupyter:
     text_representation:
       extension: .md
       format_name: markdown
-      format_version: '1.2'
-      jupytext_version: 1.9.1
+      format_version: '1.3'
+      jupytext_version: 1.10.2
   kernelspec:
     display_name: Python 3
     language: python

docs/source/notebooks/DataTransformers.md

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@ jupyter:
     text_representation:
       extension: .md
       format_name: markdown
-      format_version: '1.2'
-      jupytext_version: 1.9.1
+      format_version: '1.3'
+      jupytext_version: 1.10.2
   kernelspec:
     display_name: Python 3
     language: python

docs/source/notebooks/MLPClassifier_MLPRegressor.md

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@ jupyter:
     text_representation:
      extension: .md
       format_name: markdown
-      format_version: '1.2'
-      jupytext_version: 1.9.1
+      format_version: '1.3'
+      jupytext_version: 1.10.2
   kernelspec:
     display_name: Python 3
     language: python

docs/source/notebooks/Meta_Estimators.md

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@ jupyter:
     text_representation:
       extension: .md
       format_name: markdown
-      format_version: '1.2'
-      jupytext_version: 1.9.1
+      format_version: '1.3'
+      jupytext_version: 1.10.2
   kernelspec:
     display_name: Python 3
     language: python

scikeras/wrappers.py

Lines changed: 7 additions & 6 deletions
@@ -403,14 +403,13 @@ def _build_keras_model(self):
         else:
             model = final_build_fn(**build_params)
 
-        # compile model if user gave us an un-compiled model
-        if not (hasattr(model, "loss") and hasattr(model, "optimizer")):
-            if compile_kwargs is None:
-                compile_kwargs = self._get_compile_kwargs()
-            model.compile(**compile_kwargs)
-
         return model
 
+    def _ensure_compiled_model(self) -> None:
+        # compile model if user gave us an un-compiled model
+        if not (hasattr(self.model_, "loss") and hasattr(self.model_, "optimizer")):
+            self.model_.compile(**self._get_compile_kwargs())
+
     def _fit_keras_model(
         self,
         X: Union[np.ndarray, List[np.ndarray], Dict[str, np.ndarray]],
@@ -447,6 +446,7 @@ def _fit_keras_model(
             A reference to the instance that can be chain called
             (ex: instance.fit(X,y).transform(X) )
         """
+
         # Make sure model has a loss function
         loss = self.model_.loss
         no_loss = False
@@ -828,6 +828,7 @@ def _fit(
             X, y = self._initialize(X, y)
         else:
             X, y = self._validate_data(X, y)
+        self._ensure_compiled_model()
 
         if sample_weight is not None:
            X, sample_weight = self._validate_sample_weight(X, sample_weight)
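
The net effect of the wrappers.py change is that compilation moves from model build time to fit time: `_build_keras_model` returns the model as-is, and `_fit` invokes the new `_ensure_compiled_model` hook after initialization/validation. Below is a minimal sketch of the behavior this preserves, assuming current SciKeras semantics; the build function and data are illustrative, not from the commit.

```python
import numpy as np
from tensorflow import keras
from scikeras.wrappers import BaseWrapper

def build_uncompiled(meta):
    # SciKeras passes `meta` (shape info, etc.) to build functions that request it.
    inp = keras.Input(shape=(meta["n_features_in_"],))
    out = keras.layers.Dense(meta["n_features_in_"], activation="sigmoid")(inp)
    return keras.Model(inp, out)  # intentionally *not* compiled

est = BaseWrapper(model=build_uncompiled, loss="mse", optimizer="adam", verbose=0)
X = np.random.default_rng(0).random((8, 4)).astype("float32")
# fit() now compiles the model (via _ensure_compiled_model) using the
# wrapper-level loss/optimizer before training begins.
est.fit(X, X)
```

This is also what lets the notebook's `AutoEncoder` build and stash its encoder/decoder sub-models inside `_keras_build_fn` before anything has been compiled.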

0 commit comments