# Variational Auto-Encoder with Discrete Latent Variables

In this tutorial, I’ll share my top 10 tips for getting started with Academic:

## Data Distribution

$$q(\mathbf{x})$$

def normalize(image):
    """Prepare a raw MNIST image tensor for the model.

    Args:
        image: uint8 image tensor without a channel axis — presumably a
            (28, 28) MNIST digit from ``mnist.load_data()`` (TODO confirm).

    Returns:
        float32 tensor with a trailing channel axis appended and values
        rescaled to [0, 1] by ``convert_image_dtype``.
    """
    # Append a channel dimension so downstream layers see (H, W, C).
    image = tf.expand_dims(image, axis=-1)
    # convert_image_dtype rescales integer dtypes to float32 in [0, 1].
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    return image

>>> # Build the TF1-style input pipeline: normalize each image, shuffle,
>>> # batch, and repeat for num_epochs; the one-shot iterator yields a
>>> # fresh batch tensor each time the graph is run.
>>> (x_train, y_train), (x_test, y_test) = mnist.load_data()
>>> x_train_dataset = tf.data.Dataset.from_tensor_slices(x_train) \
...                                  .map(normalize) \
...                                  .shuffle(buffer_size=buffer_size) \
...                                  .batch(batch_size) \
...                                  .repeat(num_epochs)
>>> x_train_iterator = x_train_dataset.make_one_shot_iterator()
>>> x_data_sample = x_train_iterator.get_next()


## Prior Distribution

$$p(\mathbf{y}) = \mathrm{BinConcrete} \left (\frac{1}{2}, \tau \right )$$

$$s(\mathbf{z}) = \mathrm{Logistic} \left (0, \frac{1}{\tau} \right )$$

# Prior s(z): a zero-location Logistic with scale 1/tau over each latent
# coordinate; applying sigmoid to a sample gives the relaxed-Bernoulli
# prior draw. Independent treats the latent_dim axis as one event, so
# log_prob sums over latent coordinates.
prior = tfp.distributions.Independent(
tfp.distributions.Logistic(loc=tf.zeros(latent_dim), scale=1.0/temperature),
reinterpreted_batch_ndims=1
)


## Conditional Distributions

### Likelihood

$$p_{\theta}(\mathbf{x} \mid \mathbf{z}) = \mathrm{Bern}(\mathbf{x} \mid \mathcal{F}_{\theta}(\mathbf{z}))$$

def make_bernoulli(fn):
    """Wrap a logit-producing network into a Bernoulli likelihood factory.

    Args:
        fn: callable mapping a latent sample to per-pixel logits — here the
            generative (decoder) network, whose output has 3 trailing event
            dims (height, width, channels).

    Returns:
        A callable that maps a latent sample ``u`` to an ``Independent``
        Bernoulli distribution p(x | u) over whole images.
    """
    def bernoulli(u):
        logits = fn(u)

        dist = tfp.distributions.Bernoulli(logits=logits)
        # Fold the last 3 dims into one event so log_prob sums over pixels.
        return tfp.distributions.Independent(dist, reinterpreted_batch_ndims=3)

    return bernoulli

>>> # Decoder F_theta: maps a latent code of size latent_dim to pixel
>>> # logits reshaped to observed_shape.
>>> generative_network = Sequential([
...     Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
...     Dense(np.prod(observed_shape)),
...     Reshape(observed_shape),
... ])

>>> # Wrap the decoder so calling likelihood(y) yields p(x | y).
>>> likelihood = make_bernoulli(generative_network)


### Variational Posterior

$$q_{\phi}(\mathbf{z} \mid \mathbf{x}) = \mathrm{BernRelaxed}(\mathbf{z} \mid \mathcal{G}_{\phi}(\mathbf{x}), \tau)$$

def make_logistic(fn, temperature=0.5):
    """Wrap a logit-producing network into the relaxed posterior factory.

    Samples z ~ Logistic(logits / tau, 1 / tau); pushing z through a sigmoid
    yields a BinConcrete (relaxed Bernoulli) sample, matching the posterior
    q_phi(z | x) defined above.

    Args:
        fn: callable mapping an observed batch to latent logits — here the
            inference (encoder) network with one trailing latent_dim axis.
        temperature: relaxation temperature tau; smaller values concentrate
            sigmoid(z) nearer to {0, 1}.

    Returns:
        A callable that maps observations ``u`` to an ``Independent``
        Logistic distribution over the latent coordinates.
    """
    def logistic(u):
        logits = fn(u)

        dist = tfp.distributions.Logistic(loc=logits/temperature,
                                          scale=1.0/temperature)
        # Fold latent_dim into one event so log_prob sums over latents.
        return tfp.distributions.Independent(dist, reinterpreted_batch_ndims=1)

    return logistic

>>> # Encoder G_phi: maps a flattened observation to latent logits of
>>> # size latent_dim.
>>> inference_network = Sequential([
...     Flatten(input_shape=observed_shape),
...     Dense(intermediate_dim, activation='relu'),
...     Dense(latent_dim)
... ])

>>> # Wrap the encoder so calling posterior(x) yields q_phi(z | x).
>>> posterior = make_logistic(inference_network)


$$\mathbf{z}^{(1)}, \dotsc, \mathbf{z}^{(M)} \sim q_{\phi}(\mathbf{z} \mid \mathbf{x})$$

# Draw z ~ q_phi(z | x): a reparameterized sample from the Logistic posterior.
z_posterior_sample = posterior(x_data_sample).sample()

# y = sigmoid(z) is the relaxed-Bernoulli sample fed to the decoder.
y_posterior_sample = tf.sigmoid(z_posterior_sample)


$$\log p_{\theta}(\mathbf{x}^{(j)} \mid \mathbf{z}^{(i)})$$

# Per-example expected log-likelihood: log p_theta(x | y) under the decoder.
ell_local = likelihood(y_posterior_sample).log_prob(x_data_sample)


$$\log q_{\phi}(\mathbf{z}^{(i)} \mid \mathbf{x}^{(j)}) - \log p(\mathbf{z}^{(i)})$$

# Single-sample Monte Carlo estimate of the KL term: log q(z|x) - log p(z),
# evaluated at the posterior sample drawn above.
kl_local = (posterior(x_data_sample).log_prob(z_posterior_sample) -
prior.log_prob(z_posterior_sample))


$$\log \sum_{i=1}^M e^{\log p_{\theta}(\mathbf{x}^{(j)} \mid \mathbf{z}^{(i)}) + \log p(\mathbf{z}^{(i)}) - \log q_{\phi}(\mathbf{z}^{(i)} \mid \mathbf{x}^{(j)})} - \log M = \log \frac{1}{M} \sum_{i=1}^M \frac{p_{\theta}(\mathbf{x}^{(j)}, \mathbf{z}^{(i)})}{q_{\phi}(\mathbf{z}^{(i)} \mid \mathbf{x}^{(j)})}$$

# Per-example ELBO = E[log p(x|z)] - KL; average over the batch and
# minimize its negation.
elbo_local = ell_local - kl_local
elbo = tf.reduce_mean(elbo_local)
loss = - elbo


\begin{align} \mathcal{L}_M(\theta, \phi) & := \frac{1}{N} \sum_{j=1}^N \log \frac{1}{M} \sum_{i=1}^M \frac{p_{\theta}(\mathbf{x}^{(j)}, \mathbf{z}^{(i)})}{q_{\phi}(\mathbf{z}^{(i)} \mid \mathbf{x}^{(j)})} \newline & \approx \mathbb{E}_{q(\mathbf{x})} \left [ \log \frac{1}{M} \sum_{i=1}^M \frac{p_{\theta}(\mathbf{x}, \mathbf{z}^{(i)})}{q_{\phi}(\mathbf{z}^{(i)} \mid \mathbf{x})} \right ] \end{align}

# NOTE(review): this rebinds `prior` from the relaxed Logistic prior to the
# exact Bernoulli(1/2) prior (zero logits) — presumably for evaluating the
# discrete model after training; confirm this ordering is intended.
prior = tfp.distributions.Bernoulli(logits=tf.zeros(latent_dim))