Variational Auto-Encoder with Discrete Latent Variables

Data Distribution

$$ q(\mathbf{x}) $$

def normalize(image):
    # Add a trailing channel axis and rescale pixel intensities to [0, 1] as float32.
    image = tf.expand_dims(image, axis=-1)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    return image
>>> (x_train, y_train), (x_test, y_test) = mnist.load_data()
>>> x_train_dataset = tf.data.Dataset.from_tensor_slices(x_train) \
...                                  .map(normalize) \
...                                  .shuffle(buffer_size=buffer_size) \
...                                  .batch(batch_size) \
...                                  .repeat(num_epochs)
>>> x_train_iterator = x_train_dataset.make_one_shot_iterator()
>>> x_data_sample = x_train_iterator.get_next()

Prior Distribution

$$ p(\mathbf{y}) = \mathrm{BinConcrete} \left (\frac{1}{2}, \tau \right ) $$

$$ s(\mathbf{z}) = \mathrm{Logistic} \left (0, \frac{1}{\tau} \right ) $$
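
The two expressions above describe the same prior: a Binary Concrete sample is obtained by scaling logit-plus-Logistic noise by the inverse temperature and passing the result through a sigmoid, and for probability one half the logit is zero, leaving the zero-location Logistic density. As a rough sketch of this construction (the general form, with the logit written out, is added here for clarity):

$$ L \sim \mathrm{Logistic}(0, 1), \qquad \mathbf{z} = \frac{\mathrm{logit}(\alpha) + L}{\tau} \sim \mathrm{Logistic} \left ( \frac{\mathrm{logit}(\alpha)}{\tau}, \frac{1}{\tau} \right ), \qquad \mathbf{y} = \sigma(\mathbf{z}) $$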

prior = tfp.distributions.Independent(
    tfp.distributions.Logistic(loc=tf.zeros(latent_dim), scale=1.0/temperature),
    reinterpreted_batch_ndims=1
)

Conditional Distributions

Likelihood

$$ p_{\theta}(\mathbf{x} \mid \mathbf{z}) = \mathrm{Bern}(\mathbf{x} \mid \mathcal{F}_{\theta}(\mathbf{z})) $$

def make_bernoulli(fn):

    def bernoulli(u):

        logits = fn(u)

        # Pixelwise Bernoulli likelihood; the (height, width, channel) dims are
        # treated as a single event so that log-probabilities sum over pixels.
        dist = tfp.distributions.Bernoulli(logits=logits)
        return tfp.distributions.Independent(dist, reinterpreted_batch_ndims=3)

    return bernoulli
>>> generative_network = Sequential([
...     Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
...     Dense(np.prod(observed_shape)),
...     Reshape(observed_shape),
... ])
>>> likelihood = make_bernoulli(generative_network)

Variational Posterior

$$ q_{\phi}(\mathbf{z} \mid \mathbf{x}) = \mathrm{BernRelaxed}(\mathbf{z} \mid \mathcal{G}_{\phi}(\mathbf{x}), \tau) $$
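
As with the prior, the relaxed Bernoulli posterior over the Concrete variable corresponds to a Logistic density over the pre-sigmoid variable, which is what the helper below constructs. Spelled out (a clarifying restatement, with the inference network producing the logits):

$$ q_{\phi}(\mathbf{z} \mid \mathbf{x}) = \mathrm{Logistic} \left ( \frac{\mathcal{G}_{\phi}(\mathbf{x})}{\tau}, \frac{1}{\tau} \right ), \qquad \mathbf{y} = \sigma(\mathbf{z}) $$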

def make_logistic(fn, temperature=0.5):

    def logistic(u):

        logits = fn(u)

        # Logistic density over the pre-sigmoid variable z; applying a sigmoid
        # to a sample of z yields the relaxed Bernoulli (Binary Concrete) variable y.
        dist = tfp.distributions.Logistic(loc=logits/temperature,
                                          scale=1.0/temperature)
        return tfp.distributions.Independent(dist, reinterpreted_batch_ndims=1)

    return logistic
>>> inference_network = Sequential([
...     Flatten(input_shape=observed_shape),
...     Dense(intermediate_dim, activation='relu'),
...     Dense(latent_dim)
... ])
>>> posterior = make_logistic(inference_network)

$$ \mathbf{z}^{(1)}, \dotsc, \mathbf{z}^{(M)} \sim q_{\phi}(\mathbf{z} \mid \mathbf{x}) $$

z_posterior_sample = posterior(x_data_sample).sample()
# The relaxed Bernoulli (Concrete) sample is the sigmoid of the Logistic sample.
y_posterior_sample = tf.sigmoid(z_posterior_sample)

$$ \log p_{\theta}(\mathbf{x}^{(j)} \mid \mathbf{z}^{(i)}) $$

# Per-example log-likelihood log p(x | z), with the decoder evaluated at the
# relaxed sample y = sigmoid(z).
ell_local = likelihood(y_posterior_sample).log_prob(x_data_sample)

$$\log q_{\phi}(\mathbf{z}^{(i)} \mid \mathbf{x}^{(j)}) - \log p(\mathbf{z}^{(i)})$$

# Single-sample Monte Carlo estimate of the KL term, log q(z | x) - log p(z).
kl_local = (posterior(x_data_sample).log_prob(z_posterior_sample) -
            prior.log_prob(z_posterior_sample))

$$ \log \sum_{i=1}^M e^{\log p_{\theta}(\mathbf{x}^{(j)} \mid \mathbf{z}^{(i)}) + \log p(\mathbf{z}^{(i)}) - \log q_{\phi}(\mathbf{z}^{(i)} \mid \mathbf{x}^{(j)})} - \log M = \log \frac{1}{M} \sum_{i=1}^M \frac{p_{\theta}(\mathbf{x}^{(j)}, \mathbf{z}^{(i)})}{q_{\phi}(\mathbf{z}^{(i)} \mid \mathbf{x}^{(j)})} $$

# Single-sample (M = 1) estimate of the ELBO, averaged over the mini-batch.
elbo_local = ell_local - kl_local
elbo = tf.reduce_mean(elbo_local)
loss = -elbo
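
Training then amounts to minimizing this loss with a stochastic gradient optimizer. Here is a minimal TF1-style training loop, added as a sketch; the choice of Adam and its learning rate are assumptions, not from the original post.

optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)  # optimizer choice is an assumption
train_op = optimizer.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    try:
        # Runs until the one-shot iterator (num_epochs epochs) is exhausted.
        while True:
            _, loss_value = sess.run([train_op, loss])
    except tf.errors.OutOfRangeError:
        pass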

$$ \begin{align} \mathcal{L}_M(\theta, \phi) & := \frac{1}{N} \sum_{j=1}^N \log \frac{1}{M} \sum_{i=1}^M \frac{p_{\theta}(\mathbf{x}^{(j)}, \mathbf{z}^{(i)})}{q_{\phi}(\mathbf{z}^{(i)} \mid \mathbf{x}^{(j)})} \newline & \approx \mathbb{E}_{q(\mathbf{x})} \left [ \log \frac{1}{M} \sum_{i=1}^M \frac{p_{\theta}(\mathbf{x}, \mathbf{z}^{(i)})}{q_{\phi}(\mathbf{z}^{(i)} \mid \mathbf{x})} \right ] \end{align} $$
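
The snippet above uses a single posterior sample, that is, M = 1. For the general M-sample estimator, the log-mean importance weight can be computed with a numerically stable log-sum-exp. The following sketch is added for illustration and is not part of the original code; num_samples stands in for M.

num_samples = 8  # M; an arbitrary illustrative choice
q = posterior(x_data_sample)

log_weights = []
for _ in range(num_samples):
    z_i = q.sample()
    y_i = tf.sigmoid(z_i)
    # log p(x | z) + log p(z) - log q(z | x) for a single posterior sample
    log_w_i = (likelihood(y_i).log_prob(x_data_sample) +
               prior.log_prob(z_i) -
               q.log_prob(z_i))
    log_weights.append(log_w_i)

log_weights = tf.stack(log_weights, axis=0)  # shape: [num_samples, batch_size]

# log (1/M) sum_i exp(log w_i), averaged over the mini-batch
iw_elbo = tf.reduce_mean(
    tf.reduce_logsumexp(log_weights, axis=0) - tf.log(float(num_samples)))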

# Discrete counterpart of the relaxed BinConcrete(1/2, tau) prior: a Bernoulli
# with probability one half on each latent dimension.
prior = tfp.distributions.Bernoulli(logits=tf.zeros(latent_dim))
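
One natural use of this discrete prior, added here as a sketch rather than taken from the original code, is test-time generation: draw binary codes from the Bernoulli prior and decode them with the generative network.

# Draw discrete latent codes and map them to per-pixel Bernoulli means.
y_prior_sample = tf.cast(prior.sample(batch_size), tf.float32)  # values in {0, 1}
x_generated = likelihood(y_prior_sample).mean()  # shape: (batch_size,) + observed_shape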