Skip to content

checkerboard artifacts by using tfpl.IndependentNormal in decoder #593

Open
@mastaer

Description

@mastaer

TF: 2.0.0
TFP version: 0.8.0

Hi,
I want to write a VAE with Tensorflow-Probability. If I use tfpl.IndependentNormal at the end of the decoder, I get checkerboard artifacts. If I use instead tfd.Independent(tfd.Normal(...)) it works fine.

To show you what I mean, you can find the code here:

import tensorflow as tf
from tensorflow.keras import layers as tfl
import numpy as np
from tensorflow_probability import layers as tfpl
from tensorflow_probability import distributions as tfd
import matplotlib.pyplot as plt

# basic model
decoder = tf.keras.models.Sequential()
decoder.add(tfl.InputLayer(input_shape=[10]))
decoder.add(tfl.Reshape([1, 1, 10]))
decoder.add(tfl.UpSampling2D((2,2)))
decoder.add(tfl.Conv2D(100,(3,3),activation='selu',padding='same'))
decoder.add(tfl.UpSampling2D((2,2)))
decoder.add(tfl.Conv2D(100,(3,3),activation='selu',padding='same'))
decoder.add(tfl.UpSampling2D((2,2)))
decoder.add(tfl.Conv2D(100,(3,3),activation='selu',padding='same'))
decoder.add(tfl.UpSampling2D((2,2)))
decoder.add(tfl.Conv2D(100,(3,3),activation='selu',padding='same'))
decoder.add(tfl.UpSampling2D((2,2)))

plt.figure(figsize=(17,17))

# test input
input_values = np.array(np.random.random((1,10)),dtype=np.float32)

# 1. version: Pure TF.Conv-Layer
decoder1 = tf.keras.models.Sequential(decoder)
decoder1.add(tfl.Conv2D(1,(3,3),activation='selu',padding='same'))
plt.subplot(1,4,1)
plt.imshow(decoder1(input_values)[0,:,:,0])
plt.title('Pure TF.Conv-Layer')

# 2. version: Using tfpl.IndependentNormal
decoder2 = tf.keras.models.Sequential(decoder)
decoder2.add(tfl.Conv2D(2,(3,3),padding='same'))
decoder2.add(tfl.Flatten())
decoder2.add(tfpl.IndependentNormal((32,32,1)))
plt.subplot(1,4,2)
plt.imshow(decoder2(input_values).mean()[0,:,:,0])
plt.title('tfpl.IndependentNormal')

# 3. version: Using tfd.Independent(tfd.Normal(...))
plt.subplot(1,4,3)
plt.imshow(tfd.Independent(tfd.Normal(decoder1(input_values),decoder1(input_values)), 2).mean()[0,:,:,0])
plt.title('tfd.Independent(tfd.Normal(...))')

# 4. version: Using tfd.Independent(tfd.Normal(...)) in tfpl.DistributionLambda
def IndependentConvNormal():
    return tfpl.DistributionLambda(
            make_distribution_fn=lambda t:
                tfd.Independent(
                    tfd.Normal(
                            loc=t[...,:1],
                            scale=tf.exp(t[...,1:]))))
decoder3 = tf.keras.models.Sequential(decoder)
decoder3.add(tfl.Conv2D(2,(3,3),padding='same'))
decoder3.add(IndependentConvNormal())
plt.subplot(1,4,4)
plt.imshow(decoder3(input_values).mean()[0,:,:,0])
plt.title('tfd.Independent(tfd.Normal(...))\nin tfpl.DistributionLambda')

plt.show()

Screenshot from 2019-10-09 14-42-50

Thanks for your help! :)

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions