<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Generative Models |</title><link>https://tiao.io/tags/generative-models/</link><atom:link href="https://tiao.io/tags/generative-models/index.xml" rel="self" type="application/rss+xml"/><description>Generative Models</description><generator>HugoBlox Kit (https://hugoblox.com)</generator><language>en-us</language><lastBuildDate>Sat, 15 Jun 2019 13:00:00 +0000</lastBuildDate><image><url>https://tiao.io/media/icon_hu_9c2a75fde2335590.png</url><title>Generative Models</title><link>https://tiao.io/tags/generative-models/</link></image><item><title>Tech Talk: Cycle-Consistent Adversarial Learning as Approximate Bayesian Inference</title><link>https://tiao.io/events/amazon-ml-tech-talk-2019/</link><pubDate>Sat, 15 Jun 2019 13:00:00 +0000</pubDate><guid>https://tiao.io/events/amazon-ml-tech-talk-2019/</guid><description/></item><item><title>Density Ratio Estimation for KL Divergence Minimization between Implicit Distributions</title><link>https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/</link><pubDate>Mon, 27 Aug 2018 00:00:00 +0000</pubDate><guid>https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/</guid><description>
&lt;p&gt;The Kullback-Leibler (KL) divergence between distributions $p$ and $q$ is
defined as&lt;/p&gt;
$$
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)] :=
\mathbb{E}_{p(x)} \left [ \log \left ( \frac{p(x)}{q(x)} \right ) \right ].
$$&lt;p&gt;It can be expressed more succinctly as&lt;/p&gt;
$$
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)] = \mathbb{E}_{p(x)} [ \log r^{*}(x) ],
$$&lt;p&gt;where $r^{*}(x)$ is defined to be the ratio between the densities $p(x)$ and
$q(x)$,&lt;/p&gt;
$$
r^{*}(x) := \frac{p(x)}{q(x)}.
$$&lt;p&gt;This density ratio is crucial for computing not only the KL divergence but also
all $f$-divergences, defined as&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;
$$
\mathcal{D}_f[p(x) || q(x)] :=
\mathbb{E}_{q(x)} \left [ f \left ( \frac{p(x)}{q(x)} \right ) \right ].
$$&lt;p&gt;Rarely can this expectation (i.e. integral) be calculated analytically; in
most cases, we must resort to Monte Carlo approximation methods, which
explicitly require the density ratio.
In the more severe case where this density ratio is unavailable, because one
or both of $p(x)$ and $q(x)$ cannot be evaluated, we must resort to methods for
&lt;em&gt;density ratio estimation&lt;/em&gt;.
In this post, we illustrate how to perform density ratio estimation by
exploiting its tight correspondence to &lt;em&gt;probabilistic classification&lt;/em&gt;.&lt;/p&gt;
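&lt;p&gt;Before turning to the running example, here is a minimal pure-Python sketch (standard library only, not part of the TensorFlow code used in the rest of this post) that checks these definitions on two small discrete distributions whose probabilities are chosen arbitrarily for illustration. Because the distributions are discrete, the expectations become finite sums. The sketch computes the KL divergence directly, as $\mathbb{E}_{p(x)}[\log r^{*}(x)]$, and as the $f$-divergence with $f(u) = u \log u$, the standard choice of $f$ that recovers the KL divergence.&lt;/p&gt;

```python
import math

# Two discrete distributions on the same three-point support
# (illustrative values, chosen arbitrarily for this sketch).
p = [0.5, 0.3, 0.2]
q = [0.2, 0.3, 0.5]

# Density ratio r*(x) = p(x) / q(x) at each support point.
r = [p_i / q_i for p_i, q_i in zip(p, q)]

# KL divergence straight from its definition: E_p[log(p / q)].
kl_direct = sum(p_i * math.log(p_i / q_i) for p_i, q_i in zip(p, q))

# The same quantity written as an expectation of log r* under p.
kl_via_ratio = sum(p_i * math.log(r_i) for p_i, r_i in zip(p, r))


def f_kl(u):
    """Convex function f(u) = u log u, for which D_f is the forward KL."""
    return u * math.log(u)


# f-divergence D_f[p || q] = E_q[f(r*)]: only q and the ratio are needed.
kl_via_f = sum(q_i * f_kl(r_i) for q_i, r_i in zip(q, r))

print(kl_direct, kl_via_ratio, kl_via_f)
```

&lt;p&gt;All three values agree up to floating-point rounding; each equals $0.3 \log 2.5 \approx 0.275$. Swapping in a different convex $f$ yields a different $f$-divergence, still computed only through $r^{*}$.&lt;/p&gt;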
&lt;h3 id="example-univariate-gaussians"&gt;Example: Univariate Gaussians&lt;/h3&gt;
&lt;p&gt;Let us consider the following univariate Gaussian distributions as the running
example for this post,&lt;/p&gt;
$$
p(x) = \mathcal{N}(x \mid 1, 1^2),
\qquad
\text{and}
\qquad
q(x) = \mathcal{N}(x \mid 0, 2^2).
$$&lt;p&gt;We will be using &lt;em&gt;TensorFlow&lt;/em&gt;, &lt;em&gt;TensorFlow Probability&lt;/em&gt;, and &lt;em&gt;Keras&lt;/em&gt; in the
code snippets throughout this post.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow_probability&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tfp&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We first instantiate the distributions:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;distributions&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;distributions&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;2.&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Their densities are shown below:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Univariate Gaussian densities"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/gaussian_1d_densities.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;For any pair of distributions with tractable densities, we can implement their density ratio function $r$
as follows:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;log_density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;ratio&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Let&amp;rsquo;s create the density ratio function for the Gaussian distributions we just
instantiated:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This density ratio function is plotted as the orange dotted line below,
alongside the individual densities shown in the previous plot:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Ratio of Gaussian densities"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/gaussian_1d_density_ratios.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
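&lt;p&gt;For this particular pair of Gaussians the log density ratio has a simple closed form. Expanding the two log-densities gives $\log r^{*}(x) = \log 2 - \frac{1}{2}(x - 1)^2 + \frac{x^2}{8} = (\log 2 - \frac{1}{2}) + x - \frac{3}{8} x^2$, a quadratic in $x$. The pure-Python sketch below (standard library only, independent of the TensorFlow code above) checks this expansion numerically at a few points.&lt;/p&gt;

```python
import math


def normal_log_pdf(x, loc, scale):
    """Log-density of a univariate Gaussian N(x | loc, scale^2)."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def gaussian_log_ratio(x):
    """log r*(x) = log p(x) - log q(x) for p = N(1, 1^2) and q = N(0, 2^2)."""
    return normal_log_pdf(x, 1.0, 1.0) - normal_log_pdf(x, 0.0, 2.0)


def log_ratio_quadratic(x):
    """The same log ratio, expanded as (log 2 - 1/2) + x - (3/8) x^2."""
    return (math.log(2.0) - 0.5) + x - 0.375 * x * x


for x in [-4.0, -1.0, 0.0, 4.0 / 3.0, 2.5, 6.0]:
    agrees = math.isclose(gaussian_log_ratio(x), log_ratio_quadratic(x), abs_tol=1e-9)
    ratio = math.exp(gaussian_log_ratio(x))
    print(f"x = {x:6.3f}   r*(x) = {ratio:.4f}   matches quadratic: {agrees}")
```

&lt;p&gt;Since the quadratic has a negative leading coefficient and its maximum at $x = 4/3$, the ratio $r^{*}(x)$ peaks there and decays to zero in both tails, where $q(x)$, with its larger scale, dominates $p(x)$.&lt;/p&gt;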
&lt;h2 id="analytical-form"&gt;Analytical Form&lt;/h2&gt;
&lt;p&gt;For our running example, we picked $p(x)$ and $q(x)$ to be Gaussians so that
it is possible to integrate out $x$ and compute the KL divergence &lt;em&gt;analytically&lt;/em&gt;.
When we introduce the approximate methods later, this will provide us with a &amp;ldquo;gold
standard&amp;rdquo; to benchmark against.&lt;/p&gt;
&lt;p&gt;In general, for Gaussian distributions&lt;/p&gt;
$$
p(x) = \mathcal{N}(x \mid \mu_p, \sigma_p^2),
\qquad
\text{and}
\qquad
q(x) = \mathcal{N}(x \mid \mu_q, \sigma_q^2),
$$&lt;p&gt;
it is easy to verify that
&lt;/p&gt;
$$
\mathcal{D}_{\mathrm{KL}}[ p(x) || q(x) ]
= \log \sigma_q - \log \sigma_p - \frac{1}{2}
\left [
1 - \left ( \frac{\sigma_p^2 + (\mu_p - \mu_q)^2}{\sigma_q^2} \right )
\right ].
$$&lt;p&gt;This is implemented below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_kl_divergence_gaussians&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;r&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We can use this to compute the KL divergence between $p(x)$ and $q(x)$
&lt;em&gt;exactly&lt;/em&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;_kl_divergence_gaussians&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.44314718&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Equivalently, we could also use &lt;code&gt;kl_divergence&lt;/code&gt; from &lt;em&gt;TensorFlow
Probability&amp;ndash;Distributions&lt;/em&gt; (&lt;code&gt;tfp.distributions&lt;/code&gt;), which implements the
analytical closed-form expression of the KL divergence between distributions
when one exists.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;distributions&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;kl_divergence&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.44314718&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="monte-carlo-estimation--prescribed-distributions"&gt;Monte Carlo Estimation &amp;mdash; prescribed distributions&lt;/h2&gt;
&lt;p&gt;For distributions whose KL divergence is not analytically tractable, we
may appeal to Monte Carlo (MC) estimation:&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)]
&amp; = \mathbb{E}_{p(x)} [ \log r^{*}(x) ] \newline
&amp; \approx \frac{1}{M} \sum_{i=1}^{M} \log r^{*}(x_p^{(i)}),
\quad x_p^{(i)} \sim p(x).
\end{align*}
$$&lt;p&gt;Clearly, this requires the density ratio $r^{*}(x)$ and, in turn, the densities
$p(x)$ and $q(x)$ to be analytically tractable. Distributions for which the
density function can be readily evaluated are sometimes referred to as
&lt;strong&gt;prescribed distributions&lt;/strong&gt;. As before, we &lt;em&gt;prescribed&lt;/em&gt; Gaussian distributions
in our running example so that the Monte Carlo estimate can later be compared against the analytical value.
We approximate their KL divergence using $M = 5000$ Monte Carlo samples as
follows:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;p_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;true_log_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reduce_mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;true_log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.44670376&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Or equivalently, using the &lt;code&gt;expectation&lt;/code&gt; function from &lt;em&gt;TensorFlow
Probability&amp;ndash;Monte Carlo&lt;/em&gt; (&lt;code&gt;tfp.monte_carlo&lt;/code&gt;):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;monte_carlo&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;expectation&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;true_log_ratio&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;samples&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.4581419&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;More generally, we can approximate any $f$-divergence with MC estimation:&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_f[p(x) || q(x)]
&amp; = \mathbb{E}_{q(x)} [ f(r^{*}(x)) ] \newline
&amp; \approx \frac{1}{M} \sum_{i=1}^{M} f(r^{*}(x_q^{(i)})),
\quad x_q^{(i)} \sim q(x).
\end{align*}
$$&lt;p&gt;This can be done using the &lt;code&gt;monte_carlo_csiszar_f_divergence&lt;/code&gt; function from
&lt;em&gt;TensorFlow Probability&amp;ndash;Variational Inference&lt;/em&gt; (&lt;code&gt;tfp.vi&lt;/code&gt;).
One simply needs to specify the appropriate convex function $f$.
The convex function that instantiates the (forward) KL divergence is provided
in &lt;code&gt;tfp.vi&lt;/code&gt; as &lt;code&gt;kl_forward&lt;/code&gt;, alongside many other common $f$-divergences.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;monte_carlo_csiszar_f_divergence&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;kl_forward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="n"&gt;p_log_prob&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="n"&gt;num_draws&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.4430853&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="density-ratio-estimation--implicit-distributions"&gt;Density Ratio Estimation &amp;mdash; implicit distributions&lt;/h2&gt;
&lt;p&gt;When either density $p(x)$ or $q(x)$ is unavailable, things become trickier,
which brings us to the topic of this post. Suppose we only have samples from
$p(x)$ and $q(x)$: these could be natural images, outputs from a neural
network with stochastic inputs, or, in the case of our running example, i.i.d.
samples drawn from Gaussians.
Distributions from which we can only observe samples are known as
&lt;strong&gt;implicit distributions&lt;/strong&gt;, since their samples &lt;em&gt;imply&lt;/em&gt; some underlying true
density which we may not have direct access to.&lt;/p&gt;
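&lt;p&gt;As a toy illustration of an implicit distribution, the following sketch draws samples by pushing Gaussian noise through a small fixed nonlinear map, loosely mimicking a neural network with stochastic inputs. The map, its weights, and the function name &lt;code&gt;sample_implicit&lt;/code&gt; are hypothetical and chosen arbitrarily; the point is only that sampling is easy, while no closed-form expression for the density of the output is readily available.&lt;/p&gt;

```python
import math
import random


def sample_implicit(n, rng):
    """Draw n samples from a toy implicit distribution.

    Gaussian noise (z1, z2) is pushed through a fixed nonlinear map with
    hypothetical, arbitrarily chosen weights. Drawing samples is easy, but
    the density of the output is never written down.
    """
    samples = []
    for _ in range(n):
        z1 = rng.gauss(0.0, 1.0)
        z2 = rng.gauss(0.0, 1.0)
        hidden = math.tanh(1.5 * z1 - 0.5) + 0.3 * z2  # a single "hidden unit"
        samples.append(2.0 * hidden + 0.1 * hidden ** 3)  # nonlinear readout
    return samples


rng = random.Random(0)
xs = sample_implicit(1000, rng)
print(len(xs), min(xs), max(xs))
```

&lt;p&gt;Methods for implicit distributions may only ever call something like &lt;code&gt;sample_implicit&lt;/code&gt;; they never get to call a &lt;code&gt;log_prob&lt;/code&gt; method as we did with the prescribed Gaussians above.&lt;/p&gt;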
&lt;p&gt;Density ratio estimation is concerned with estimating the ratio of densities
$r^{*}(x) = p(x) / q(x)$ given access only to samples from $p(x)$ and $q(x)$.
Moreover, density ratio estimation usually encompasses methods that achieve this
without resorting to direct &lt;em&gt;density estimation&lt;/em&gt; of the individual densities
$p(x)$ or $q(x)$, since errors in the estimate of the denominator $q(x)$ can be
greatly magnified in the ratio, especially in regions where $q(x)$ is small.&lt;/p&gt;
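&lt;p&gt;To illustrate this point, the sketch below (a hypothetical example, not from the original post) replaces $q(x)$ in the running example by an estimate $\hat{q}(x) = \mathcal{N}(x \mid 0, 1.9^2)$ whose scale is only 5% too small, and measures the relative error of the plug-in ratio $p(x) / \hat{q}(x)$. By hand, $q(x) / \hat{q}(x) = 0.95 \exp \left ( \frac{x^2}{2} \left ( \frac{1}{1.9^2} - \frac{1}{2^2} \right ) \right ) \approx 0.95 \exp(0.0135 x^2)$, so the relative error is 5% at $x = 0$ but grows exponentially in $x^2$ in the tails, exceeding 100% by $x = 8$.&lt;/p&gt;

```python
import math


def normal_pdf(x, loc, scale):
    """Density of a univariate Gaussian N(x | loc, scale^2)."""
    z = (x - loc) / scale
    return math.exp(-0.5 * z * z) / (scale * math.sqrt(2.0 * math.pi))


def p(x):
    return normal_pdf(x, 1.0, 1.0)


def q(x):
    return normal_pdf(x, 0.0, 2.0)


def q_hat(x):
    # Hypothetical density estimate of q(x) whose scale is 5% too small.
    return normal_pdf(x, 0.0, 1.9)


def plugin_relative_error(x):
    """Relative error of the plug-in ratio p(x) / q_hat(x) against p(x) / q(x)."""
    true_ratio = p(x) / q(x)
    plugin_ratio = p(x) / q_hat(x)
    return abs(plugin_ratio - true_ratio) / true_ratio


for x in [0.0, 2.0, 4.0, 6.0, 8.0]:
    print(f"x = {x:3.1f}   relative error = {plugin_relative_error(x):.3f}")
```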
&lt;p&gt;Of the many density ratio estimation methods that now
flourish&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;, the classical approach of &lt;em&gt;probabilistic
classification&lt;/em&gt; remains dominant, due in no small part to its simplicity.&lt;/p&gt;
&lt;h3 id="reducing-density-ratio-estimation-to-probabilistic-classification"&gt;Reducing Density Ratio Estimation to Probabilistic Classification&lt;/h3&gt;
&lt;p&gt;We now demonstrate that density ratio estimation can be reduced to probabilistic
classification. We shall do this by highlighting the one-to-one correspondence
between the density ratio of $p(x)$ and $q(x)$ and the optimal probabilistic
classifier that discriminates between their samples.
Specifically, suppose we have a collection of samples from both $p(x)$ and $q(x)$,
where each sample is assigned a class label indicating which distribution it was
drawn from. Then, from an estimator of the class-membership probabilities, it is
straightforward to recover an estimator of the density ratio.&lt;/p&gt;
&lt;p&gt;Suppose we have $N_p$ and $N_q$ samples drawn from $p(x)$ and $q(x)$,
respectively,&lt;/p&gt;
$$
x_p^{(1)}, \dotsc, x_p^{(N_p)} \sim p(x),
\qquad \text{and} \qquad
x_q^{(1)}, \dotsc, x_q^{(N_q)} \sim q(x).
$$&lt;p&gt;Then, we form the dataset $\{ (x_n, y_n) \}_{n=1}^N$, where $N = N_p + N_q$
and&lt;/p&gt;
$$
\begin{align*}
(x_1, \dotsc, x_N) &amp; = (x_p^{(1)}, \dotsc, x_p^{(N_p)},
x_q^{(1)}, \dotsc, x_q^{(N_q)}), \newline
(y_1, \dotsc, y_N) &amp; = (\underbrace{1, \dotsc, 1}_{N_p},
\underbrace{0, \dotsc, 0}_{N_q}).
\end{align*}
$$&lt;p&gt;In other words, we label samples drawn from $p(x)$ as 1 and those drawn from
$q(x)$ as 0. In code, this looks like:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;p_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sample_shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;q_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sample_shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q_samples&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q_samples&lt;/span&gt;&lt;span class="p"&gt;)],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This dataset is visualized below. The blue squares in the top row are samples
$x_p^{(i)} \sim p(x)$ with label 1; red squares in the bottom row are samples
$x_q^{(j)} \sim q(x)$ with label 0.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Classification dataset"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/dataset.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Now, by construction, we have&lt;/p&gt;
$$
p(x) = \mathcal{P}(x \mid y = 1),
\qquad
\text{and}
\qquad
q(x) = \mathcal{P}(x \mid y = 0).
$$&lt;p&gt;Using Bayes&amp;rsquo; rule, we can write&lt;/p&gt;
$$
\mathcal{P}(x \mid y) =
\frac{\mathcal{P}(y \mid x) \mathcal{P}(x)}
{\mathcal{P}(y)}.
$$&lt;p&gt;Hence, we can express the density ratio $r^{*}(x)$ as&lt;/p&gt;
$$
\begin{align*}
r^{*}(x) &amp; = \frac{p(x)}{q(x)}
= \frac{\mathcal{P}(x \mid y = 1)}
{\mathcal{P}(x \mid y = 0)} \newline
&amp; = \left ( \frac{\mathcal{P}(y = 1 \mid x) \mathcal{P}(x)}
{\mathcal{P}(y = 1)} \right )
\left ( \frac{\mathcal{P}(y = 0 \mid x) \mathcal{P}(x)}
{\mathcal{P}(y = 0)} \right ) ^ {-1} \newline
&amp; = \frac{\mathcal{P}(y = 0)}{\mathcal{P}(y = 1)}
\frac{\mathcal{P}(y = 1 \mid x)}
{\mathcal{P}(y = 0 \mid x)}.
\end{align*}
$$&lt;p&gt;Let us approximate the ratio of marginal class probabilities by the ratio of sample sizes,&lt;/p&gt;
$$
\frac{\mathcal{P}(y = 0)}
{\mathcal{P}(y = 1)}
\approx
\frac{N_q}{N_p + N_q}
\left ( \frac{N_p}{N_p + N_q} \right )^{-1}
= \frac{N_q}{N_p}.
$$&lt;p&gt;To avoid notational clutter, let us assume from now on that $N_q = N_p$.
We can then write $r^{*}(x)$ in terms of class-posterior probabilities,&lt;/p&gt;
$$
\begin{align*}
r^{*}(x) = \frac{\mathcal{P}(y = 1 \mid x)}
{\mathcal{P}(y = 0 \mid x)}.
\end{align*}
$$&lt;h4 id="recovering-the-density-ratio-from-the-class-probability"&gt;Recovering the Density Ratio from the Class Probability&lt;/h4&gt;
&lt;p&gt;This yields a one-to-one correspondence between the density ratio $r^{*}(x)$
and the class-posterior probability $\mathcal{P}(y = 1 \mid x)$.
Namely,&lt;/p&gt;
$$
\begin{align*}
r^{*}(x) = \frac{\mathcal{P}(y = 1 \mid x)}
{\mathcal{P}(y = 0 \mid x)}
&amp; = \frac{\mathcal{P}(y = 1 \mid x)}
{1 - \mathcal{P}(y = 1 \mid x)} \newline
&amp; = \exp
\left [
\log \frac{\mathcal{P}(y = 1 \mid x)}
{1 - \mathcal{P}(y = 1 \mid x)} \right ] \newline
&amp; = \exp[ \sigma^{-1}(\mathcal{P}(y = 1 \mid x)) ],
\end{align*}
$$&lt;p&gt;where $\sigma^{-1}$ is the &lt;em&gt;logit&lt;/em&gt; function, or inverse sigmoid function, given
by $\sigma^{-1}(\rho) = \log \left ( \frac{\rho}{1-\rho} \right )$.&lt;/p&gt;
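&lt;p&gt;The following NumPy sketch checks these identities numerically. It also covers the prior correction from the earlier derivation, where $N_p \neq N_q$. The Gaussians $p = \mathcal{N}(0, 1)$ and $q = \mathcal{N}(1, 2^2)$ and the sample sizes are hypothetical choices for illustration. The sketch forms the class posterior by Bayes&amp;rsquo; rule, passes it through the logit, exponentiates, and finally rescales by $N_q / N_p$ to recover $r^{*}(x)$.&lt;/p&gt;

```python
import numpy as np

def normal_pdf(x, loc, scale):
    """Density of a univariate Gaussian N(loc, scale**2)."""
    return np.exp(-0.5 * ((x - loc) / scale) ** 2) / (scale * np.sqrt(2.0 * np.pi))

def logit(rho):
    """Inverse of the logistic sigmoid: log(rho / (1 - rho))."""
    return np.log(rho) - np.log1p(-rho)

# Hypothetical setup: p = N(0, 1), q = N(1, 2**2), with unequal sample sizes.
n_p, n_q = 300, 700
x = np.linspace(-3.0, 3.0, 7)
p_x, q_x = normal_pdf(x, 0.0, 1.0), normal_pdf(x, 1.0, 2.0)

# Class priors estimated from the sample sizes, then Bayes' rule for P(y=1 | x).
prior_1 = n_p / (n_p + n_q)
prior_0 = n_q / (n_p + n_q)
posterior_1 = prior_1 * p_x / (prior_1 * p_x + prior_0 * q_x)

# exp(logit(P(y=1 | x))) is the posterior odds; rescaling by N_q / N_p
# removes the prior ratio and leaves the density ratio p(x) / q(x).
ratio_from_classifier = (n_q / n_p) * np.exp(logit(posterior_1))
ratio_exact = p_x / q_x

print(np.allclose(ratio_from_classifier, ratio_exact))  # True
```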
&lt;h4 id="recovering-the-class-probability-from-the-density-ratio"&gt;Recovering the Class Probability from the Density Ratio&lt;/h4&gt;
&lt;p&gt;Inverting this relationship, we can also recover
the exact class-posterior probability as a function of the density ratio,&lt;/p&gt;
$$
\mathcal{P}(y=1 \mid x) = \sigma(\log r^{*}(x)) = \frac{p(x)}{p(x) + q(x)}.
$$
&lt;p&gt;This is implemented below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;optimal_classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;truediv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;classifier&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;In the figure below, the class-posterior probability $\mathcal{P}(y=1 \mid x)$
is plotted against the dataset visualized earlier.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Optimal classifier&amp;mdash;class-posterior probabilities"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/optimal_classifier.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
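&lt;p&gt;Below is a NumPy sketch that parallels &lt;code&gt;optimal_classifier&lt;/code&gt;. The hypothetical Gaussians $p = \mathcal{N}(0, 1)$ and $q = \mathcal{N}(1, 1)$ stand in for the running example. The sketch checks that $\sigma(\log r^{*}(x))$ and $p(x) / (p(x) + q(x))$ agree. It also checks that the optimal classifier outputs exactly $1/2$ where the two densities are equal, which for these parameters is at $x = 0.5$.&lt;/p&gt;

```python
import numpy as np

def normal_pdf(x, loc, scale):
    """Density of a univariate Gaussian N(loc, scale**2)."""
    return np.exp(-0.5 * ((x - loc) / scale) ** 2) / (scale * np.sqrt(2.0 * np.pi))

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def optimal_classifier(p_pdf, q_pdf):
    """Return x -> P(y=1 | x) = p(x) / (p(x) + q(x)), assuming N_p = N_q."""
    def classifier(x):
        return p_pdf(x) / (p_pdf(x) + q_pdf(x))
    return classifier

# Hypothetical densities: p = N(0, 1), q = N(1, 1).
def p_pdf(x):
    return normal_pdf(x, 0.0, 1.0)

def q_pdf(x):
    return normal_pdf(x, 1.0, 1.0)

classifier = optimal_classifier(p_pdf, q_pdf)

x = np.linspace(-4.0, 5.0, 10)
log_ratio = np.log(p_pdf(x)) - np.log(q_pdf(x))

print(np.allclose(classifier(x), sigmoid(log_ratio)))  # True
print(classifier(np.array([0.5])))  # [0.5], where p(x) = q(x)
```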
&lt;h3 id="probabilistic-classification-with-logistic-regression"&gt;Probabilistic Classification with Logistic Regression&lt;/h3&gt;
&lt;p&gt;The class-posterior probability $\mathcal{P}(y = 1 \mid x)$ can be approximated
using a parameterized function $D_{\theta}(x)$ with parameters $\theta$. This
function takes as input samples from $p(x)$ and $q(x)$ and outputs a &lt;em&gt;score&lt;/em&gt;,
or probability, in the range $[0, 1]$ that the sample was drawn from $p(x)$.
Hence, we refer to $D_{\theta}(x)$ as the probabilistic classifier.&lt;/p&gt;
&lt;p&gt;From the correspondence above, an estimator $r_{\theta}(x)$ of the density
ratio can be constructed as a function of the probabilistic classifier
$D_{\theta}(x)$. Namely,&lt;/p&gt;
$$
\begin{align*}
r_{\theta}(x) &amp; = \exp[ \sigma^{-1}(D_{\theta}(x)) ] \newline
&amp; \approx \exp[ \sigma^{-1}(\mathcal{P}(y = 1 \mid x)) ] = r^{*}(x),
\end{align*}
$$&lt;p&gt;
and &lt;em&gt;vice versa&lt;/em&gt;,
&lt;/p&gt;
$$
\begin{align*}
D_{\theta}(x) &amp; = \sigma(\log r_{\theta}(x)) \newline
&amp; \approx \sigma(\log r^{*}(x)) = \mathcal{P}(y = 1 \mid x).
\end{align*}
$$&lt;p&gt;Instead of $D_{\theta}(x)$, we usually specify the parameterized function
$\log r_{\theta}(x)$. This is also referred to as the &lt;em&gt;log-odds&lt;/em&gt;, or &lt;em&gt;logits&lt;/em&gt;,
since it is equivalent to the unnormalized output of the classifier before being
fed through the logistic sigmoid function.&lt;/p&gt;
&lt;p&gt;We define a small fully-connected neural network with two hidden layers and ReLU
activations:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;log_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This simple architecture is visualized in the diagram below:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Log Density Ratio Architecture"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/log_ratio_architecture.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
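&lt;p&gt;The NumPy sketch below implements only the forward pass of the same 1-16-32-1 architecture, without any deep learning framework, to show the role of the logits. The weights are random placeholders, not trained values. The network outputs $\log r_{\theta}(x)$, and the other two quantities follow from it: $D_{\theta}(x) = \sigma(\log r_{\theta}(x))$ and $r_{\theta}(x) = \exp(\log r_{\theta}(x))$. Mapping $D_{\theta}(x)$ back through $\exp(\sigma^{-1}(\cdot))$ returns $r_{\theta}(x)$.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical random parameters theta for the layers Dense(16), Dense(32), Dense(1).
W1, b1 = rng.normal(size=(1, 16)), np.zeros(16)
W2, b2 = rng.normal(size=(16, 32)) / 4.0, np.zeros(32)
W3, b3 = rng.normal(size=(32, 1)) / 6.0, np.zeros(1)

def relu(a):
    return np.maximum(a, 0.0)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def log_ratio(x):
    """Unnormalized classifier output (logits), interpreted as log r_theta(x)."""
    h1 = relu(x @ W1 + b1)
    h2 = relu(h1 @ W2 + b2)
    return h2 @ W3 + b3  # linear output layer: no activation

x = np.linspace(-2.0, 2.0, 5).reshape(-1, 1)  # batch of 5 scalar inputs
logits = log_ratio(x)
D = sigmoid(logits)   # probabilistic classifier D_theta(x)
r = np.exp(logits)    # density ratio estimate r_theta(x)

# Going back from D to r through the logit recovers the same ratio.
print(np.allclose(np.exp(np.log(D) - np.log1p(-D)), r))  # True
```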
&lt;p&gt;We fit the class probability estimator $D_{\theta}(x)$ by minimizing a
&lt;em&gt;proper scoring rule&lt;/em&gt;&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt; that yields well-calibrated probabilistic predictions, such as the &lt;em&gt;binary cross-entropy loss&lt;/em&gt;,&lt;/p&gt;
$$
\begin{align*}
\mathcal{L}(\theta) &amp; :=
-\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ]
-\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ] \newline
&amp; =
-\mathbb{E}_{p(x)} [ \log \sigma ( \log r_{\theta} (x) ) ]
-\mathbb{E}_{q(x)} [ \log(1 - \sigma ( \log r_{\theta} (x) )) ].
\end{align*}
$$&lt;p&gt;An implementation optimized for numerical stability is given below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_binary_crossentropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;loss_p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sigmoid_cross_entropy_with_logits&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;loss_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sigmoid_cross_entropy_with_logits&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reduce_mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss_p&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;loss_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
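&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Now we can build a multi-input, multi-output model whose two inputs wrap
the tensors &lt;code&gt;p_samples&lt;/code&gt; and &lt;code&gt;q_samples&lt;/code&gt;, which hold samples from
$p(x)$ and $q(x)$, respectively.&lt;/p&gt;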
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;x_p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;x_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;q_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;log_ratio_p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_p&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;log_ratio_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The model can now be compiled and finalized. Since we&amp;rsquo;re using a custom loss
that takes the two sets of log-ratios as input, we specify &lt;code&gt;loss=None&lt;/code&gt; and
define it instead through the &lt;code&gt;add_loss&lt;/code&gt; method.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;x_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_q&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_binary_crossentropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;compile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;rmsprop&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;As a sanity check, the loss evaluated on a random batch can be obtained like so:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;evaluate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;1.3765026330947876&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We can now fit our estimator, recording the loss at the end of each epoch:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;hist&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps_per_epoch&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;500&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The following animation shows how the predictions of the probabilistic
classifier, the density ratio, and the log density ratio evolve after every epoch:&lt;/p&gt;
&lt;p&gt;&lt;video controls autoplay src="https://giant.gfycat.com/FrighteningThunderousFlicker.webm"&gt;&lt;/video&gt;&lt;/p&gt;
&lt;p&gt;The predictions are overlaid on their exact analytical counterparts, which are
available only because we prescribed $p(x)$ and $q(x)$ to be Gaussian distributions.
For implicit distributions, these would not be accessible at all.&lt;/p&gt;
&lt;p&gt;Below is the final plot of how the binary cross-entropy loss converges:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Binary Cross-entropy Loss"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/binary_crossentropy.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
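&lt;p&gt;The loss plotted above is computed from logits rather than probabilities. Two identities make this possible: $-\log \sigma(a) = \operatorname{softplus}(-a)$ and $-\log(1 - \sigma(a)) = \operatorname{softplus}(a)$. The per-sample terms are therefore $\operatorname{softplus}(-\log r_{\theta}(x))$ for samples from $p(x)$ and $\operatorname{softplus}(\log r_{\theta}(x))$ for samples from $q(x)$. This is mathematically equivalent to the loss that &lt;code&gt;tf.nn.sigmoid_cross_entropy_with_logits&lt;/code&gt; computes. The NumPy sketch below uses &lt;code&gt;np.logaddexp(0, a)&lt;/code&gt; as a stable softplus and contrasts it with a naive version that goes through probabilities.&lt;/p&gt;

```python
import numpy as np

def softplus(a):
    """log(1 + exp(a)), computed stably as logaddexp(0, a)."""
    return np.logaddexp(0.0, a)

def binary_crossentropy_from_logits(log_ratio_p, log_ratio_q):
    """Mean BCE with label 1 for samples from p and label 0 for samples from q.

    Uses -log sigmoid(a) = softplus(-a) and -log(1 - sigmoid(a)) = softplus(a).
    Assumes equal batch sizes, mirroring the elementwise sum in the TF version.
    """
    return np.mean(softplus(-log_ratio_p) + softplus(log_ratio_q))

def naive_binary_crossentropy(log_ratio_p, log_ratio_q):
    """Same loss computed through probabilities; can overflow to inf."""
    d_p = 1.0 / (1.0 + np.exp(-log_ratio_p))
    d_q = 1.0 / (1.0 + np.exp(-log_ratio_q))
    return np.mean(-np.log(d_p) - np.log(1.0 - d_q))

moderate_p, moderate_q = np.array([0.3, -1.2]), np.array([-0.5, 2.0])
print(np.isclose(binary_crossentropy_from_logits(moderate_p, moderate_q),
                 naive_binary_crossentropy(moderate_p, moderate_q)))  # True

# A confidently wrong logit of 40 on a q-sample: sigmoid(40) rounds to 1.0,
# so the naive version takes log(0) and returns inf.
extreme_p, extreme_q = np.array([0.0]), np.array([40.0])
print(binary_crossentropy_from_logits(extreme_p, extreme_q))  # approx. 40.69
with np.errstate(divide="ignore"):
    print(naive_binary_crossentropy(extreme_p, extreme_q))  # inf
```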
&lt;p&gt;Below is a plot of the probabilistic classifier $D_{\theta}(x)$ (&lt;em&gt;dotted green&lt;/em&gt;),
plotted against the optimal classifier, which is the class-posterior probability
$\mathcal{P}(y=1 \mid x) = \frac{p(x)}{p(x) + q(x)}$ (&lt;em&gt;solid blue&lt;/em&gt;):&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Class Probability Estimator"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/class_probability_estimation.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Below is a plot of the density ratio estimator $r_{\theta}(x)$
(&lt;em&gt;dotted green&lt;/em&gt;), plotted against the exact density ratio function
$r^{*}(x) = \frac{p(x)}{q(x)}$ (&lt;em&gt;solid blue&lt;/em&gt;):&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Density Ratio Estimator"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/density_ratio_estimation.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;And finally, the previous plot in logarithmic scale:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Log Density Ratio Estimator"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/log_density_ratio_estimation.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;While it may appear that we are simply performing regression on the latent
function $r^{*}(x)$ (which is not wrong; we are), it is important to emphasize that
we do this without ever having observed values of $r^{*}(x)$.
Instead, we only ever observed samples from $p(x)$ and $q(x)$.
This has profound implications and potential for a great number of applications
that we shall explore later on.&lt;/p&gt;
&lt;h3 id="back-to-monte-carlo-estimation"&gt;Back to Monte Carlo estimation&lt;/h3&gt;
&lt;p&gt;Having obtained an estimate of the log density ratio, it is now feasible to
perform Monte Carlo estimation:&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)]
&amp; = \mathbb{E}_{p(x)} [ \log r^{*}(x) ] \newline
&amp; \approx \frac{1}{M} \sum_{i=1}^{M} \log r^{*}(x_p^{(i)}),
\quad x_p^{(i)} \sim p(x) \newline
&amp; \approx \frac{1}{M} \sum_{i=1}^{M} \log r_{\theta}(x_p^{(i)}),
\quad x_p^{(i)} \sim p(x).
\end{align*}
$$&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;monte_carlo&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;expectation&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;samples&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.4570999&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;In other words, we draw MC samples from $p(x)$ as before. But instead of taking
the mean of the function $\log r^{*}(x)$ evaluated on these samples (which is
unavailable for implicit distributions), we do so on a proxy function
$\log r_{\theta}(x)$ that is estimated through probabilistic classification as
described above.&lt;/p&gt;
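&lt;p&gt;The following self-contained NumPy sketch runs the whole pipeline on hypothetical Gaussians $p = \mathcal{N}(0, 1)$ and $q = \mathcal{N}(1, 2^2)$, using only samples from each. As a deliberate simplification, it replaces the neural network with logistic regression on the features $(1, x, x^2)$, fitted by Newton&amp;rsquo;s method. This family contains the exact log density ratio, because the log-ratio of two Gaussians is quadratic in $x$. The sketch then averages the fitted $\log r_{\theta}(x)$ over fresh samples from $p(x)$. The closed-form Gaussian KL divergence is printed only as a reference.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(42)
n = 20_000

# Hypothetical example: p = N(0, 1), q = N(1, 2**2).
# Only samples are used for fitting; no density is evaluated below.
x_p = rng.normal(0.0, 1.0, size=n)
x_q = rng.normal(1.0, 2.0, size=n)

def features(x):
    """Quadratic features; log p(x)/q(x) is exactly quadratic for two Gaussians."""
    return np.stack([np.ones_like(x), x, x ** 2], axis=1)

def sigmoid(a):
    return 0.5 * (1.0 + np.tanh(0.5 * a))  # overflow-free logistic function

X = np.concatenate([features(x_p), features(x_q)])
y = np.concatenate([np.ones(n), np.zeros(n)])  # label 1: p, label 0: q

# Logistic regression with logits log r_theta(x) = features(x) @ w,
# fitted by Newton's method on the mean binary cross-entropy.
w = np.zeros(3)
for _ in range(25):
    s = sigmoid(X @ w)
    grad = X.T @ (s - y) / len(y)
    hess = (X * (s * (1.0 - s))[:, None]).T @ X / len(y)
    w -= np.linalg.solve(hess, grad)

# Monte Carlo KL estimate with the proxy log r_theta on fresh samples from p.
x_mc = rng.normal(0.0, 1.0, size=n)
kl_estimate = np.mean(features(x_mc) @ w)

# Closed form KL[N(0,1) || N(1,4)] = log 2 + (1 + 1) / 8 - 1/2, for reference only.
kl_exact = np.log(2.0) + 2.0 / 8.0 - 0.5
print(w)  # compare with the exact coefficients [log 2 + 1/8, -1/4, -3/8]
print(kl_estimate, kl_exact)
```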
&lt;h2 id="learning-in-implicit-generative-models"&gt;Learning in Implicit Generative Models&lt;/h2&gt;
&lt;p&gt;Now let&amp;rsquo;s take a look at where these ideas are being used in practice.
Consider a collection of natural images, such as the MNIST handwritten
digits shown below, which are assumed to be samples drawn from some implicit
distribution $q(\mathbf{x})$:&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/MnistExamples.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;MNIST hand-written digits&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;Directly estimating the density of $q(\mathbf{x})$ may not always be feasible&amp;mdash;in
some cases, it may not even exist.
Instead, consider defining a parametric function $G_{\phi}: \mathbf{z} \mapsto
\mathbf{x}$ with parameters $\phi$ that takes as input $\mathbf{z}$ drawn from
some fixed distribution $p(\mathbf{z})$.
The outputs $\mathbf{x}$ of this generative process are assumed to be samples
following some implicit distribution $p_{\phi}(\mathbf{x})$. In other words,
we can write&lt;/p&gt;
$$
\mathbf{x} \sim p_{\phi}(\mathbf{x}) \quad
\Leftrightarrow \quad
\mathbf{x} = G_{\phi}(\mathbf{z}),
\quad \mathbf{z} \sim p(\mathbf{z}).
$$&lt;p&gt;By optimizing parameters $\phi$, we can make $p_{\phi}(\mathbf{x})$ close to
the real data distribution $q(\mathbf{x})$. This is a compelling alternative to
density estimation since there are many situations where being able to generate
samples is more important than being able to calculate the numerical value of
the density. Some examples of these include &lt;em&gt;image super-resolution&lt;/em&gt; and
&lt;em&gt;semantic segmentation&lt;/em&gt;.&lt;/p&gt;
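&lt;p&gt;Simulating such an implicit model only requires a forward pass through $G_{\phi}$. In the NumPy sketch below, $G_{\phi}$ is a hypothetical small network with random parameters that maps 2-dimensional Gaussian noise $\mathbf{z}$ to scalar outputs. Drawing samples is cheap. Evaluating $p_{\phi}(x)$, by contrast, would require accounting for every $\mathbf{z}$ that maps to $x$, and the code never attempts it.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical generator parameters phi: a 2-8-1 network with tanh hidden units.
phi = {
    "W1": rng.normal(size=(2, 8)), "b1": rng.normal(size=8),
    "W2": rng.normal(size=(8, 1)), "b2": np.zeros(1),
}

def generator(z, phi):
    """G_phi: maps noise z of shape (n, 2) to samples x of shape (n,)."""
    h = np.tanh(z @ phi["W1"] + phi["b1"])
    return (h @ phi["W2"] + phi["b2"])[:, 0]

def sample(n, phi):
    """Draw x ~ p_phi(x) by pushing z ~ N(0, I) through the generator."""
    z = rng.normal(size=(n, 2))
    return generator(z, phi)

x = sample(1000, phi)
print(x.shape)  # (1000,)
```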
&lt;p&gt;One approach might be to introduce a classifier $D_{\theta}$ that discriminates
between real and synthetic samples.
Then we optimize $G_{\phi}$ to synthesize samples that the classifier $D_{\theta}$
cannot distinguish from the real samples. This can be achieved by optimizing the
binary cross-entropy loss for both parties simultaneously: it is minimized with respect
to $\theta$ and maximized with respect to $\phi$, resulting in the saddle-point objective,&lt;/p&gt;
$$
\begin{align*}
&amp; \min_{\phi} \max_{\theta}
\mathbb{E}_{q(\mathbf{x})} [ \log D_{\theta} (\mathbf{x}) ] +
\mathbb{E}_{p_{\phi}(\mathbf{x})} [ \log(1-D_{\theta} (\mathbf{x})) ] \newline =
&amp; \min_{\phi} \max_{\theta}
\mathbb{E}_{q(\mathbf{x})} [ \log D_{\theta} (\mathbf{x}) ] +
\mathbb{E}_{p(\mathbf{z})} [ \log(1-D_{\theta} (G_{\phi}(\mathbf{z}))) ].
\end{align*}
$$&lt;p&gt;This is, of course, none other than the groundbreaking &lt;em&gt;generative adversarial
network (GAN)&lt;/em&gt;&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;.
You can read more about the density ratio estimation perspective of GANs in
the paper by Uehara et al. 2016&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;. For an even more general and complete treatment of learning in implicit models, I recommend the paper
from Mohamed and Lakshminarayanan, 2016&lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt;, which partially inspired this post.&lt;/p&gt;
&lt;p&gt;For the remainder of this section, I want to highlight a variant of this
approach that specifically aims to minimize the KL divergence w.r.t. parameters
$\phi$,&lt;/p&gt;
$$
\min_{\phi} \mathcal{D}_{\mathrm{KL}}[p_{\phi}(\mathbf{x}) || q(\mathbf{x})].
$$&lt;p&gt;To overcome the fact that the densities of both $p_{\phi}(\mathbf{x})$ and
$q(\mathbf{x})$ are unknown, we can readily adopt the density ratio estimation
approach outlined in this post.
Namely, by maximizing the following objective,&lt;/p&gt;
$$
\begin{align*}
&amp; \max_{\theta}
\mathbb{E}_{q(\mathbf{x})} [ \log D_{\theta} (\mathbf{x}) ] +
\mathbb{E}_{p(\mathbf{z})} [ \log(1-D_{\theta} (G_{\phi}(\mathbf{z}))) ] \newline
= &amp; \max_{\theta}
\mathbb{E}_{q(\mathbf{x})} [ \log \sigma ( \log r_{\theta} (\mathbf{x}) ) ] +
\mathbb{E}_{p(\mathbf{z})} [ \log(1 - \sigma ( \log r_{\theta} (G_{\phi}(\mathbf{z})) )) ],
\end{align*}
$$&lt;p&gt;which attains its maximum at&lt;/p&gt;
$$
r_{\theta}(\mathbf{x}) = \frac{q(\mathbf{x})}{p_{\phi}(\mathbf{x})}.
$$&lt;p&gt;Concurrently, we also minimize the current best estimate of the KL divergence,&lt;/p&gt;
$$
\begin{align*}
\min_{\phi} \mathcal{D}_{\mathrm{KL}}[p_{\phi}(\mathbf{x}) || q(\mathbf{x})]
&amp; =
\min_{\phi} \mathbb{E}_{p_{\phi}(\mathbf{x})} \left [ \log \frac{p_{\phi}(\mathbf{x})}{q(\mathbf{x})} \right ] \newline
&amp; \approx
\min_{\phi} \mathbb{E}_{p_{\phi}(\mathbf{x})} [ - \log r_{\theta}(\mathbf{x}) ] \newline
&amp; =
\min_{\phi} \mathbb{E}_{p(\mathbf{z})} [ - \log r_{\theta}(G_{\phi}(\mathbf{z})) ].
\end{align*}
$$&lt;p&gt;In addition to being more stable than the vanilla GAN approach (it alleviates
saturating gradients), this objective is especially important in contexts where there is
a specific need to minimize the KL divergence, such as in &lt;em&gt;variational inference
(VI)&lt;/em&gt;.&lt;/p&gt;
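&lt;p&gt;Below is a toy NumPy sketch of this alternating scheme. It rests on simplifying assumptions chosen purely for illustration:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;the data distribution is $q = \mathcal{N}(0, 1)$;&lt;/li&gt;
&lt;li&gt;the generator is $G_{\phi}(z) = \phi + z$ with $z \sim \mathcal{N}(0, 1)$, so that $\mathcal{D}_{\mathrm{KL}}[p_{\phi}(x) || q(x)] = \phi^2 / 2$;&lt;/li&gt;
&lt;li&gt;the classifier is a logistic regression with logits $\log r_{\theta}(x) = \theta_0 + \theta_1 x$, which can represent the exact ratio in this case.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Each round first fits $\theta$ by Newton&amp;rsquo;s method on the binary cross-entropy, with $\phi$ fixed. It then takes one gradient step in $\phi$ on the Monte Carlo estimate of $\mathbb{E}_{p(z)} [ - \log r_{\theta}(G_{\phi}(z)) ]$, with $\theta$ fixed.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(a):
    return 0.5 * (1.0 + np.tanh(0.5 * a))  # overflow-free logistic function

def fit_log_ratio(x_real, x_fake, iters=20):
    """Fit log r_theta(x) = theta[0] + theta[1] * x by logistic regression.

    Label 1 for real samples (from q) and 0 for generated ones (from p_phi),
    so the fitted logits estimate log q(x) / p_phi(x).
    """
    x = np.concatenate([x_real, x_fake])
    X = np.stack([np.ones_like(x), x], axis=1)
    y = np.concatenate([np.ones(len(x_real)), np.zeros(len(x_fake))])
    theta = np.zeros(2)
    for _ in range(iters):  # Newton's method on the mean binary cross-entropy
        s = sigmoid(X @ theta)
        grad = X.T @ (s - y) / len(y)
        hess = (X * (s * (1.0 - s))[:, None]).T @ X / len(y)
        theta -= np.linalg.solve(hess, grad)
    return theta

# Hypothetical toy problem: q = N(0, 1) and G_phi(z) = phi + z with z ~ N(0, 1),
# so p_phi = N(phi, 1) and KL[p_phi || q] = phi**2 / 2.
phi, learning_rate, n = 3.0, 0.5, 5_000
for _ in range(30):
    x_real = rng.normal(size=n)            # samples from q(x)
    x_fake = phi + rng.normal(size=n)      # samples G_phi(z) from p_phi(x)
    theta = fit_log_ratio(x_real, x_fake)  # theta-step, phi held fixed
    # phi-step, theta held fixed: the gradient of mean(-log r_theta(phi + z))
    # with respect to phi is -theta[1], since d(phi + z)/dphi = 1.
    phi -= learning_rate * (-theta[1])

print(phi)  # close to 0, where KL[p_phi || q] vanishes
```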
&lt;p&gt;This was first used in &lt;em&gt;AffGAN&lt;/em&gt; by Sønderby et al. 2016&lt;sup id="fnref:7"&gt;&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref"&gt;7&lt;/a&gt;&lt;/sup&gt;,
and has since been incorporated in many papers that deal with implicit
distributions in variational inference, such as
(Mescheder et al. 2017&lt;sup id="fnref:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;,
Huszár 2017&lt;sup id="fnref:9"&gt;&lt;a href="#fn:9" class="footnote-ref" role="doc-noteref"&gt;9&lt;/a&gt;&lt;/sup&gt;,
Tran et al. 2017&lt;sup id="fnref:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt;,
Pu et al. 2017&lt;sup id="fnref:11"&gt;&lt;a href="#fn:11" class="footnote-ref" role="doc-noteref"&gt;11&lt;/a&gt;&lt;/sup&gt;,
Chen et al. 2018&lt;sup id="fnref:12"&gt;&lt;a href="#fn:12" class="footnote-ref" role="doc-noteref"&gt;12&lt;/a&gt;&lt;/sup&gt;,
Tiao et al. 2018&lt;sup id="fnref:13"&gt;&lt;a href="#fn:13" class="footnote-ref" role="doc-noteref"&gt;13&lt;/a&gt;&lt;/sup&gt;), and many others.&lt;/p&gt;
&lt;h2 id="bound-on-the-jensen-shannon-divergence"&gt;Bound on the Jensen-Shannon Divergence&lt;/h2&gt;
&lt;p&gt;Before we wrap things up, let us take another look at the plot of the
binary cross-entropy loss recorded at the end of each epoch.
We see that it converges quickly to some value.
It is natural to wonder: what is the significance, if any, of this value?&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Binary cross-entropy loss converges to Jensen Shannon divergence (up to constants)"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/binary_crossentropy_vs_jensen_shannon.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;It is in fact the (negative) Jensen-Shannon (JS) divergence, up to constants,&lt;/p&gt;
$$
-2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] + \log 4.
$$&lt;p&gt;Recall that the Jensen-Shannon divergence is defined as&lt;/p&gt;
$$
\mathcal{D}_{\mathrm{JS}}[p(x) || q(x)]
= \frac{1}{2} \mathcal{D}_{\mathrm{KL}}[p(x) || m(x)] +
\frac{1}{2} \mathcal{D}_{\mathrm{KL}}[q(x) || m(x)],
$$&lt;p&gt;where $m$ is the mixture density&lt;/p&gt;
$$
m(x) = \frac{p(x) + q(x)}{2}.
$$&lt;p&gt;With our running example, this cannot be evaluated exactly since the KL
divergence between a Gaussian and a mixture of Gaussians is analytically
intractable.
However, like the KL, we can still estimate their JS divergence with Monte
Carlo estimation&lt;sup id="fnref:14"&gt;&lt;a href="#fn:14" class="footnote-ref" role="doc-noteref"&gt;14&lt;/a&gt;&lt;/sup&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;js&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;monte_carlo_csiszar_f_divergence&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;jensen_shannon&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="n"&gt;p_log_prob&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_draws&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This value is shown as the horizontal black line in the plot above. Along the
right margin, we also plot a histogram of the binary cross-entropy loss
values over epochs. We can see that this value indeed coincides with the mode of
this histogram.&lt;/p&gt;
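&lt;p&gt;For readers without TensorFlow Probability at hand, the same kind of Monte Carlo estimate can be sketched in plain NumPy. This is an illustration under assumptions: the Gaussian $p$ and the two-component mixture $q$ below are hypothetical and are not the parameters of the running example, and only the definition of $\mathcal{D}_{\mathrm{JS}}$ via the mixture $m$ given above is used. The argument name &lt;code&gt;num_draws&lt;/code&gt; merely mirrors the TFP call.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(42)


def normal_log_pdf(x, loc, scale):
    """Log-density of a univariate Gaussian N(loc, scale^2)."""
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


# Hypothetical p (Gaussian) and q (two-component Gaussian mixture).
P_LOC, P_SCALE = 0.0, 2.0
Q_WEIGHTS = np.array([0.5, 0.5])
Q_LOCS = np.array([-1.0, 3.0])
Q_SCALES = np.array([1.0, 0.5])


def p_log_prob(x):
    return normal_log_pdf(x, P_LOC, P_SCALE)


def q_log_prob(x):
    # log sum_k w_k N(x; mu_k, s_k), computed stably in log space.
    comps = np.log(Q_WEIGHTS) + normal_log_pdf(x[..., None], Q_LOCS, Q_SCALES)
    return np.logaddexp.reduce(comps, axis=-1)


def m_log_prob(x):
    # Mixture density m(x) = (p(x) + q(x)) / 2.
    return np.logaddexp(p_log_prob(x), q_log_prob(x)) - np.log(2.0)


def sample_p(n):
    return rng.normal(P_LOC, P_SCALE, size=n)


def sample_q(n):
    # Pick a component for each draw, then sample from that component.
    k = rng.choice(len(Q_WEIGHTS), size=n, p=Q_WEIGHTS)
    return rng.normal(Q_LOCS[k], Q_SCALES[k])


def js_divergence_mc(num_draws=5000):
    """D_JS = 0.5 KL[p || m] + 0.5 KL[q || m], each KL estimated by Monte Carlo."""
    x_p, x_q = sample_p(num_draws), sample_q(num_draws)
    kl_pm = np.mean(p_log_prob(x_p) - m_log_prob(x_p))
    kl_qm = np.mean(q_log_prob(x_q) - m_log_prob(x_q))
    return 0.5 * kl_pm + 0.5 * kl_qm


js = js_divergence_mc()
print(-2.0 * js + np.log(4.0))  # the constant-shifted quantity compared against the loss
```

&lt;p&gt;Evaluating $m(x)$ requires both densities, so this check is only possible because $p$ and $q$ are known here; with implicit distributions one would instead rely on a trained classifier.&lt;/p&gt;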
&lt;p&gt;It is straightforward to show that the minimized binary cross-entropy loss is bounded below by the negative JS divergence, up to constants,&lt;/p&gt;
$$
\inf_{\theta} \mathcal{L}(\theta) \geq - 2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] + \log 4.
$$&lt;p&gt;First, consider the supremum over &lt;em&gt;all&lt;/em&gt; functions $D$ taking values in $(0, 1)$, not only those in the parametric family $\{ D_{\theta} \}$. Assuming equal class priors, it is attained by the Bayes-optimal classifier $D(x) = \mathcal{P}(y=1 \mid x)$, so that&lt;/p&gt;
$$
\begin{align*}
\sup_{D} &amp;
\mathbb{E}_{p(x)} [ \log D (x) ] +
\mathbb{E}_{q(x)} [ \log(1-D (x)) ] \newline
&amp; =
\mathbb{E}_{p(x)} [ \log \mathcal{P}(y=1 \mid x) ] +
\mathbb{E}_{q(x)} [ \log \mathcal{P}(y=0 \mid x) ] \newline
&amp; =
\mathbb{E}_{p(x)} \left [ \log \frac{p(x)}{p(x) + q(x)} \right ] +
\mathbb{E}_{q(x)} \left [ \log \frac{q(x)}{p(x) + q(x)} \right ] \newline
&amp; =
\mathbb{E}_{p(x)} \left [ \log \frac{1}{2} \frac{p(x)}{m(x)} \right ] +
\mathbb{E}_{q(x)} \left [ \log \frac{1}{2} \frac{q(x)}{m(x)} \right ] \newline
&amp; =
\mathbb{E}_{p(x)} \left [ \log \frac{p(x)}{m(x)} \right ] +
\mathbb{E}_{q(x)} \left [ \log \frac{q(x)}{m(x)} \right ] - 2 \log 2 \newline
&amp; = 2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4.
\end{align*}
$$&lt;p&gt;Since the parametric family is a subset of all such functions, the supremum over $\theta$ can be no larger. Therefore,&lt;/p&gt;
$$
2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4
\geq
\sup_{\theta}
\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ] +
\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ].
$$&lt;p&gt;Negating both sides, we get&lt;/p&gt;
$$
\begin{align*}
-2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] + \log 4
\leq &amp;
-\sup_{\theta} \left \{
\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ] +
\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ] \right \} \newline
= &amp; \inf_{\theta}
-\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ]
-\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ] \newline
= &amp; \inf_{\theta} \mathcal{L}(\theta),
\end{align*}
$$&lt;p&gt;as required.&lt;/p&gt;
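&lt;p&gt;The argument above can be checked numerically. The following NumPy sketch uses two hypothetical Gaussians, $p = \mathcal{N}(0, 1)$ and $q = \mathcal{N}(1.5, 1)$, with Riemann-sum quadrature on a grid standing in for the expectations. The Bayes-optimal discriminator $D^{*}(x) = p(x) / (p(x) + q(x))$ attains $2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4$, while a discriminator with shifted logits scores strictly lower.&lt;/p&gt;

```python
import numpy as np

# Hypothetical densities evaluated on a fine grid (illustration only).
x = np.linspace(-12.0, 12.0, 20001)
dx = x[1] - x[0]


def normal_pdf(x, loc, scale):
    return np.exp(-0.5 * ((x - loc) / scale) ** 2) / (scale * np.sqrt(2.0 * np.pi))


p = normal_pdf(x, 0.0, 1.0)
q = normal_pdf(x, 1.5, 1.0)


def gan_objective(d):
    """E_p[log D(x)] + E_q[log(1 - D(x))], approximated by a Riemann sum."""
    return np.sum(p * np.log(d) + q * np.log1p(-d)) * dx


# Bayes-optimal discriminator (equal class priors): D*(x) = p(x) / (p(x) + q(x)).
d_star = p / (p + q)

# 2 * JS - log 4, computed from the definition with the mixture m = (p + q) / 2.
m = 0.5 * (p + q)
js = 0.5 * np.sum(p * np.log(p / m)) * dx + 0.5 * np.sum(q * np.log(q / m)) * dx

print(gan_objective(d_star), 2.0 * js - np.log(4.0))  # agree up to quadrature error

# Any other discriminator, e.g. one whose logit is shifted by 0.5, does worse.
logit_star = np.log(p) - np.log(q)
d_shifted = 1.0 / (1.0 + np.exp(-(logit_star + 0.5)))
print(gan_objective(d_shifted))  # strictly smaller than gan_objective(d_star)
```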
&lt;p&gt;In short, this tells us that the binary cross-entropy loss is &lt;em&gt;itself&lt;/em&gt; an
approximation (up to constants) to the negative Jensen-Shannon divergence: at its minimum it bounds the latter from above, with equality when the classifier family contains the Bayes-optimal classifier.
This raises the question: is it possible to construct a more general loss that bounds any given $f$-divergence?&lt;/p&gt;
&lt;h2 id="teaser-lower-bound-on-any--divergence"&gt;Teaser: Lower Bound on any $f$-divergence&lt;/h2&gt;
&lt;p&gt;Using convex analysis, one can actually show that for any $f$-divergence, we
have the lower bound&lt;sup id="fnref:15"&gt;&lt;a href="#fn:15" class="footnote-ref" role="doc-noteref"&gt;15&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;
$$
\mathcal{D}_f[p(x) || q(x)]
\geq
\sup_{\theta}
\mathbb{E}_{p(x)} [ f'(r_{\theta}(x)) ] -
\mathbb{E}_{q(x)} [ f^{\star}(f'(r_{\theta}(x))) ],
$$&lt;p&gt;with equality exactly when $r_{\theta}(x) = r^{*}(x)$.
Importantly, this lower bound can be computed without requiring the densities of
$p(x)$ or $q(x)$&amp;mdash;only their samples are needed.&lt;/p&gt;
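&lt;p&gt;It can help to make the conjugate term explicit. Recall that the convex conjugate of $f$ is $f^{\star}(t) = \sup_{u} \{ t u - f(u) \}$. For differentiable convex $f$, the supremum at $t = f'(u)$ is attained at $u$ itself, which gives the standard identity&lt;/p&gt;
$$
f^{\star}(f'(u)) = u f'(u) - f(u),
$$&lt;p&gt;so the lower bound can equivalently be written as&lt;/p&gt;
$$
\mathcal{D}_f[p(x) || q(x)]
\geq
\sup_{\theta}
\mathbb{E}_{p(x)} [ f'(r_{\theta}(x)) ] -
\mathbb{E}_{q(x)} [ r_{\theta}(x) f'(r_{\theta}(x)) - f(r_{\theta}(x)) ].
$$&lt;p&gt;As a sanity check, at $r_{\theta}(x) = r^{*}(x)$ we have $\mathbb{E}_{q(x)} [ r^{*}(x) f'(r^{*}(x)) ] = \mathbb{E}_{p(x)} [ f'(r^{*}(x)) ]$, so the right-hand side reduces to $\mathbb{E}_{q(x)} [ f(r^{*}(x)) ] = \mathcal{D}_f[p(x) || q(x)]$, as expected.&lt;/p&gt;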
&lt;p&gt;In the special case of $f(u) = u \log u - (u + 1) \log (u + 1)$, we recover the
binary cross-entropy loss and the previous result, as expected,&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_f[p(x) || q(x)]
&amp; = 2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4 \newline
&amp; \geq \sup_{\theta}
\mathbb{E}_{p(x)} [ \log \sigma ( \log r_{\theta} (x) ) ] +
\mathbb{E}_{q(x)} [ \log(1 - \sigma ( \log r_{\theta} (x) )) ] \newline
&amp; = \sup_{\theta}
\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ] +
\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ].
\end{align*}
$$&lt;p&gt;Alternatively, in the special case of $f(u) = u \log u$, we get&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_f[p(x) || q(x)]
&amp; = \mathcal{D}_{\mathrm{KL}}[p(x) || q(x)] \newline
&amp; \geq \sup_{\theta}
\mathbb{E}_{p(x)} [ \log r_{\theta} (x) ] -
\mathbb{E}_{q(x)} [ r_{\theta} (x) - 1 ].
\end{align*}
$$&lt;p&gt;This gives us &lt;em&gt;yet&lt;/em&gt; another way to estimate the KL divergence between
implicit distributions, in the form of a direct lower bound on the KL divergence
itself.
As it turns out, this lower bound is closely related to the objective of the
&lt;em&gt;KL Importance Estimation Procedure (KLIEP)&lt;/em&gt;&lt;sup id="fnref:16"&gt;&lt;a href="#fn:16" class="footnote-ref" role="doc-noteref"&gt;16&lt;/a&gt;&lt;/sup&gt;, and will be
the topic of our next post in this series.&lt;/p&gt;
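&lt;p&gt;A minimal NumPy sketch of the general bound, instantiated for the two special cases above. It assumes the log density ratio is supplied as a function (in practice it would come from a trained classifier or ratio model); here the exact log ratio of two hypothetical Gaussians, $p = \mathcal{N}(0, 1)$ and $q = \mathcal{N}(1, 1)$, is used so that the estimate can be compared with their known KL divergence of $0.5$. The helper names are illustrative only.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)


def f_divergence_lower_bound(f_prime, f_conj, log_r, x_p, x_q):
    """Sample estimate of E_p[f'(r(x))] - E_q[f*(f'(r(x)))].

    Only samples from p and q are used, together with a (possibly
    estimated) log density ratio log_r; the densities are never evaluated.
    """
    r_p = np.exp(log_r(x_p))
    r_q = np.exp(log_r(x_q))
    return np.mean(f_prime(r_p)) - np.mean(f_conj(f_prime(r_q)))


# KL divergence: f(u) = u log u, so f'(u) = log u + 1 and f*(t) = exp(t - 1).
def kl_f_prime(u):
    return np.log(u) + 1.0


def kl_f_conj(t):
    return np.exp(t - 1.0)


# GAN-type divergence: f(u) = u log u - (u + 1) log(u + 1), so
# f'(u) = log(u / (u + 1)) and f*(t) = -log(1 - exp(t)) for negative t.
def gan_f_prime(u):
    return np.log(u) - np.log1p(u)


def gan_f_conj(t):
    return -np.log(-np.expm1(t))


# Hypothetical example: p = N(0, 1), q = N(1, 1), whose true log ratio is
# log r*(x) = 0.5 - x and whose KL divergence is exactly 0.5.
x_p = rng.normal(0.0, 1.0, size=200_000)
x_q = rng.normal(1.0, 1.0, size=200_000)


def true_log_r(x):
    return 0.5 - x


def wrong_log_r(x):
    return 0.8 - x  # misspecified by a constant offset


kl_true = f_divergence_lower_bound(kl_f_prime, kl_f_conj, true_log_r, x_p, x_q)
kl_wrong = f_divergence_lower_bound(kl_f_prime, kl_f_conj, wrong_log_r, x_p, x_q)
gan_true = f_divergence_lower_bound(gan_f_prime, gan_f_conj, true_log_r, x_p, x_q)

print(kl_true)   # close to the exact KL divergence of 0.5
print(kl_wrong)  # smaller: a misspecified ratio only loosens the bound
print(gan_true)  # estimate of 2 * JS - log 4, which lies in [-log 4, 0]
```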
&lt;h1 id="summary"&gt;Summary&lt;/h1&gt;
&lt;p&gt;This post covered how to estimate the KL divergence, or any $f$-divergence,
between implicit distributions, that is, distributions which we can only sample from.
First, we underscored the crucial role of the density ratio in the estimation of
$f$-divergences.
Next, we showed the correspondence between the density ratio and the optimal
classifier.
By exploiting this link, we demonstrated how one can use a trained probabilistic
classifier as a proxy for the exact density ratio, and use this proxy to estimate
any $f$-divergence.
Finally, we provided some context on where this method is used, touching upon
some recent advances in implicit generative models and variational inference.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;Cite as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{tiao2018dre,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = &amp;#34;{D}ensity {R}atio {E}stimation for {KL} {D}ivergence {M}inimization between {I}mplicit {D}istributions&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = &amp;#34;Tiao, Louis C&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; journal = &amp;#34;tiao.io&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = &amp;#34;2018&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = &amp;#34;https://tiao.io/post/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h1 id="acknowledgements"&gt;Acknowledgements&lt;/h1&gt;
&lt;p&gt;I am grateful to
for providing
extensive feedback and insightful discussions. I would also like to thank
Alistair Reid and
for their comments and suggestions.&lt;/p&gt;
&lt;h1 id="links-and-resources"&gt;Links and Resources&lt;/h1&gt;
&lt;ul&gt;
&lt;li&gt;The
used to generate the figures in this post, which you can
.&lt;/li&gt;
&lt;li&gt;The very readable textbook &lt;em&gt;Density Ratio Estimation in Machine Learning&lt;/em&gt;
&lt;sup id="fnref1:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;, which I highly recommend. (Note: the Gaussian distributions example was borrowed from this book.) &lt;a target="_blank" href="https://www.amazon.com/gp/product/0521190177/ref=as_li_tl?ie=UTF8&amp;camp=1789&amp;creative=9325&amp;creativeASIN=0521190177&amp;linkCode=as2&amp;tag=tiao03-20&amp;linkId=0907c42c1a834ffa68ca2f27c2bdb92f"&gt;&lt;img border="0" src="//ws-na.amazon-adsystem.com/widgets/q?_encoding=UTF8&amp;MarketPlace=US&amp;ASIN=0521190177&amp;ServiceVersion=20070822&amp;ID=AsinImage&amp;WS=1&amp;Format=_SL250_&amp;tag=tiao03-20" &gt;&lt;/a&gt;&lt;img src="//ir-na.amazon-adsystem.com/e/ir?t=tiao03-20&amp;l=am2&amp;o=1&amp;a=0521190177" width="1" height="1" border="0" alt="" style="border:none !important; margin:0px !important;" /&gt;&lt;/li&gt;
&lt;li&gt;Shakir Mohamed&amp;rsquo;s blog post
.&lt;/li&gt;
&lt;li&gt;The paper by Menon and Ong, 2016&lt;sup id="fnref:17"&gt;&lt;a href="#fn:17" class="footnote-ref" role="doc-noteref"&gt;17&lt;/a&gt;&lt;/sup&gt;, which gives a generalized treatment of the theoretical link between density ratio estimation and probabilistic classification.&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;The (forward) KL divergence can be recovered with
&lt;/p&gt;
$$
f_{\mathrm{KL}}(u) := u \log u.
$$&lt;p&gt;
This is easy to verify,
&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)] &amp; :=
\mathbb{E}_{p(x)} \left [ \log \left ( \frac{p(x)}{q(x)} \right ) \right ] \newline
&amp; = \mathbb{E}_{q(x)} \left [ \frac{p(x)}{q(x)} \log \left ( \frac{p(x)}{q(x)} \right ) \right ] \newline
&amp; = \mathbb{E}_{q(x)} \left [ f_{\mathrm{KL}} \left ( \frac{p(x)}{q(x)} \right ) \right ].
\end{align*}
$$&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;Sugiyama, M., Suzuki, T., &amp;amp; Kanamori, T. (2012). &lt;em&gt;Density Ratio Estimation in Machine Learning&lt;/em&gt;. Cambridge University Press.&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;Gneiting, T., &amp;amp; Raftery, A. E. (2007). Strictly Proper Scoring Rules, Prediction, and Estimation. &lt;em&gt;Journal of the American Statistical Association&lt;/em&gt;, 102(477), (pp. 359-378).&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., &amp;hellip; &amp;amp; Bengio, Y. (2014). Generative Adversarial Nets. In &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt; (pp. 2672-2680).&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;Uehara, M., Sato, I., Suzuki, M., Nakayama, K., &amp;amp; Matsuo, Y. (2016). Generative Adversarial Nets from a Density Ratio Estimation Perspective. &lt;em&gt;arXiv preprint arXiv:1610.02920&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;Mohamed, S., &amp;amp; Lakshminarayanan, B. (2016). Learning in Implicit Generative Models. &lt;em&gt;arXiv preprint arXiv:1610.03483&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:7"&gt;
&lt;p&gt;Sønderby, C. K., Caballero, J., Theis, L., Shi, W., &amp;amp; Huszár, F. (2016). Amortised MAP inference for image super-resolution. &lt;em&gt;arXiv preprint arXiv:1610.04490&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:8"&gt;
&lt;p&gt;Mescheder, L., Nowozin, S., &amp;amp; Geiger, A. (2017). Adversarial Variational Bayes: Unifying Variational Autoencoders and Generative Adversarial Networks. In &lt;em&gt;International Conference on Machine Learning (ICML)&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:9"&gt;
&lt;p&gt;Huszár, F. (2017). Variational inference using implicit distributions. &lt;em&gt;arXiv preprint arXiv:1702.08235&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:9" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:10"&gt;
&lt;p&gt;Tran, D., Ranganath, R., &amp;amp; Blei, D. (2017). Hierarchical implicit models and likelihood-free variational inference. In &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt; (pp. 5523-5533).&amp;#160;&lt;a href="#fnref:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:11"&gt;
&lt;p&gt;Pu, Y., Wang, W., Henao, R., Chen, L., Gan, Z., Li, C., &amp;amp; Carin, L. (2017). Adversarial symmetric variational autoencoder. In &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt; (pp. 4330-4339).&amp;#160;&lt;a href="#fnref:11" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:12"&gt;
&lt;p&gt;Chen, L., Dai, S., Pu, Y., Zhou, E., Li, C., Su, Q., &amp;hellip; &amp;amp; Carin, L. (2018, March). Symmetric variational autoencoder and connections to adversarial learning. In &lt;em&gt;International Conference on Artificial Intelligence and Statistics&lt;/em&gt; (pp. 661-669).&amp;#160;&lt;a href="#fnref:12" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:13"&gt;
&lt;p&gt;Tiao, L. C., Bonilla, E. V., &amp;amp; Ramos, F. (2018). Cycle-Consistent Adversarial Learning as Approximate Bayesian Inference. &lt;em&gt;arXiv preprint arXiv:1806.01771&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:13" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:14"&gt;
&lt;p&gt;Note that &lt;code&gt;jensen_shannon&lt;/code&gt; with &lt;code&gt;self_normalized=False&lt;/code&gt; (the default) corresponds to $2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4$, while &lt;code&gt;self_normalized=True&lt;/code&gt; corresponds to $2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)]$.&amp;#160;&lt;a href="#fnref:14" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:15"&gt;
&lt;p&gt;Nguyen, X., Wainwright, M. J., &amp;amp; Jordan, M. I. (2010). Estimating divergence functionals and the likelihood ratio by convex risk minimization. &lt;em&gt;IEEE Transactions on Information Theory&lt;/em&gt;, 56(11), 5847-5861.&amp;#160;&lt;a href="#fnref:15" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:16"&gt;
&lt;p&gt;Sugiyama, M., Nakajima, S., Kashima, H., Buenau, P. V., &amp;amp; Kawanabe, M. (2008). Direct importance estimation with model selection and its application to covariate shift adaptation. In &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt; (pp. 1433-1440).&amp;#160;&lt;a href="#fnref:16" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:17"&gt;
&lt;p&gt;Menon, A., &amp;amp; Ong, C. S. (2016, June). Linking Losses for Density Ratio and Class-Probability Estimation. In &lt;em&gt;International Conference on Machine Learning&lt;/em&gt; (pp. 304-313).&amp;#160;&lt;a href="#fnref:17" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Contributed Talk: Cycle-Consistent Adversarial Learning as Approximate Bayesian Inference</title><link>https://tiao.io/events/icml2018-tagdm/</link><pubDate>Sat, 14 Jul 2018 15:20:00 +0000</pubDate><guid>https://tiao.io/events/icml2018-tagdm/</guid><description/></item><item><title>Cycle-Consistent Adversarial Learning as Approximate Bayesian Inference</title><link>https://tiao.io/publications/cycle-bayes/</link><pubDate>Sun, 01 Jul 2018 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/cycle-bayes/</guid><description/></item><item><title>A Tutorial on Variational Autoencoders with a Concise Keras Implementation</title><link>https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/</link><pubDate>Wed, 20 Apr 2016 00:00:00 +0000</pubDate><guid>https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/</guid><description>&lt;p&gt;
Keras is awesome. It is a very well-designed library that clearly abides by
its guiding principles
of modularity and extensibility, enabling us to
easily assemble powerful, complex models from primitive building blocks.
This has been demonstrated in numerous blog posts and tutorials, in particular,
the excellent tutorial on
.
As the name suggests, that tutorial provides examples of how to implement
various kinds of autoencoders in Keras, including the variational autoencoder
(VAE)&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;Like all autoencoders, the variational autoencoder is primarily used for
unsupervised learning of hidden representations.
However, it is fundamentally different from the usual neural network-based
autoencoder in that it approaches the problem from a probabilistic perspective.
It specifies a joint distribution over the observed and latent variables, and
approximates the intractable posterior conditional density over latent
variables with variational inference, using an &lt;em&gt;inference network&lt;/em&gt;
&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt; (or more classically, a &lt;em&gt;recognition model&lt;/em&gt;
&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;) to amortize the cost of inference.&lt;/p&gt;
&lt;p&gt;While the examples in the aforementioned tutorial do well to showcase the
versatility of Keras on a wide range of autoencoder model architectures,
its implementation of the variational autoencoder doesn&amp;rsquo;t properly take
advantage of Keras&amp;rsquo; modular design, making it difficult to generalize and
extend in important ways. As we will see, it relies on implementing custom
layers and constructs that are restricted to a specific instance of
variational autoencoders. This is a shame because when combined, Keras&amp;rsquo;
building blocks are powerful enough to encapsulate most variants of the
variational autoencoder and more generally, recognition-generative model
combinations for which the generative model belongs to a large family of
&lt;em&gt;deep latent Gaussian models&lt;/em&gt; (DLGMs)&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;The goal of this post is to propose a clean and elegant alternative
implementation that takes better advantage of Keras&amp;rsquo; modular design.
It is not intended as a tutorial on variational autoencoders &lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt;.
Rather, we study variational autoencoders as a special case of variational
inference in deep latent Gaussian models using inference networks, and
demonstrate how we can use Keras to implement them in a modular fashion such
that they can be easily adapted to approximate inference in tasks beyond
unsupervised learning, and with complicated (non-Gaussian) likelihoods.&lt;/p&gt;
&lt;p&gt;This first post will lay the groundwork for a series of future posts that
explore ways to extend this basic modular framework to implement the
cutting-edge methods proposed in the latest research, such as normalizing
flows for building richer posterior approximations &lt;sup id="fnref:7"&gt;&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref"&gt;7&lt;/a&gt;&lt;/sup&gt;, importance
weighted autoencoders &lt;sup id="fnref:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;, the Gumbel-softmax trick for inference in
discrete latent variables &lt;sup id="fnref:9"&gt;&lt;a href="#fn:9" class="footnote-ref" role="doc-noteref"&gt;9&lt;/a&gt;&lt;/sup&gt;, and even the most recent GAN-based
density-ratio estimation techniques for likelihood-free inference
&lt;sup id="fnref:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:11"&gt;&lt;a href="#fn:11" class="footnote-ref" role="doc-noteref"&gt;11&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;h1 id="model-specification"&gt;Model specification&lt;/h1&gt;
&lt;p&gt;First, it is important to understand that the variational autoencoder
.
Rather, the generative model is a component of the variational autoencoder and
is, in general, a deep latent Gaussian model.
In particular, let $\mathbf{x}$ be a local observed variable and
$\mathbf{z}$ its corresponding local latent variable, with joint
distribution&lt;/p&gt;
$$
p_{\theta}(\mathbf{x}, \mathbf{z})
= p_{\theta}(\mathbf{x} | \mathbf{z}) p(\mathbf{z}).
$$&lt;p&gt;In Bayesian modelling, we assume the distribution of observed variables to be
governed by the latent variables. Latent variables are drawn from a prior
density $p(\mathbf{z})$ and related to the observations through the
likelihood $p_{\theta}(\mathbf{x} | \mathbf{z})$.
Deep latent Gaussian models (DLGMs) are a general class of models where the
observed variable is governed by a &lt;em&gt;hierarchy&lt;/em&gt; of latent variables, and the
latent variables at each level of the hierarchy are Gaussian &lt;em&gt;a priori&lt;/em&gt;
&lt;sup id="fnref1:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;In a typical instance of the variational autoencoder, we have only a single
layer of latent variables with a standard Normal prior distribution,&lt;/p&gt;
$$
p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I}).
$$&lt;p&gt;Now, each local latent variable is related to its corresponding observation
through the likelihood $p_{\theta}(\mathbf{x} | \mathbf{z})$, which can
be viewed as a &lt;em&gt;probabilistic&lt;/em&gt; decoder. Given a hidden lower-dimensional
representation (or &amp;ldquo;code&amp;rdquo;) $\mathbf{z}$, it &amp;ldquo;decodes&amp;rdquo; it into a
&lt;em&gt;distribution&lt;/em&gt; over the observation $\mathbf{x}$.&lt;/p&gt;
&lt;h2 id="decoder"&gt;Decoder&lt;/h2&gt;
&lt;p&gt;In this example, we define $p_{\theta}(\mathbf{x} | \mathbf{z})$ to be a
multivariate Bernoulli whose probabilities are computed from $\mathbf{z}$ using
a fully-connected neural network with a single hidden layer,&lt;/p&gt;
$$
\begin{align*}
p_{\theta}(\mathbf{x} | \mathbf{z})
&amp; = \mathrm{Bern}( \sigma( \mathbf{W}_2 \mathbf{h} + \mathbf{b}_2 ) ), \newline
\mathbf{h}
&amp; = h(\mathbf{W}_1 \mathbf{z} + \mathbf{b}_1),
\end{align*}
$$&lt;p&gt;where $\sigma$ is the logistic sigmoid function, $h$ is some non-linearity, and
the model parameters
$\theta = \{ \mathbf{W}_1, \mathbf{W}_2, \mathbf{b}_1, \mathbf{b}_2 \}$
consist of the weights and biases of this neural network.&lt;/p&gt;
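&lt;p&gt;The decoder equations can be sketched outside Keras as a plain NumPy forward pass, followed by ancestral sampling from the prior. This is a hedged illustration: $h$ is taken to be ReLU, and the layer sizes (suggestive of 28x28 binary images) and random initialization are hypothetical.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical sizes, e.g. 28x28 binary images and a 2-D latent space.
latent_dim, intermediate_dim, original_dim = 2, 256, 784

# Randomly initialised parameters theta = {W1, b1, W2, b2}.
W1 = rng.normal(scale=0.1, size=(intermediate_dim, latent_dim))
b1 = np.zeros(intermediate_dim)
W2 = rng.normal(scale=0.1, size=(original_dim, intermediate_dim))
b2 = np.zeros(original_dim)


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def decode(z):
    """Bernoulli probabilities sigma(W2 h + b2), with h = relu(W1 z + b1).

    Rows of z are individual latent codes, so a whole batch is decoded at once.
    """
    h = np.maximum(0.0, z @ W1.T + b1)
    return sigmoid(h @ W2.T + b2)


# Ancestral sampling: z ~ N(0, I), then x ~ Bern(decode(z)).
z = rng.standard_normal((5, latent_dim))
probs = decode(z)
x = rng.binomial(1, probs)
print(probs.shape, x.shape)  # (5, 784) (5, 784)
```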
&lt;p&gt;It is straightforward to implement this in Keras with the &lt;code&gt;Sequential&lt;/code&gt; model API:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;decoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;sigmoid&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;You can view a summary of the model parameters $\theta$ by calling
&lt;code&gt;decoder.summary()&lt;/code&gt;. Additionally, you can produce a high-level diagram of
the network architecture, and optionally the input and output shapes of each
layer using &lt;code&gt;plot_model&lt;/code&gt;
from the
&lt;code&gt;keras.utils.vis_utils&lt;/code&gt; module. Although our architecture is about as
simple as it gets, it is included in the figure below as an example of what
the diagrams look like.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Decoder architecture"
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/decoder.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Note that by fixing $\mathbf{W}_1$, $\mathbf{b}_1$ and $h$ to be the identity
matrix, the zero vector, and the identity function, respectively (or
equivalently dropping the first &lt;code&gt;Dense&lt;/code&gt; layer in the snippet above
altogether), we recover &lt;em&gt;logistic factor analysis&lt;/em&gt;.
With similarly minor modifications, we can recover other members from the
family of DLGMs, which include &lt;em&gt;non-linear factor analysis&lt;/em&gt;,
&lt;em&gt;non-linear Gaussian belief networks&lt;/em&gt;, &lt;em&gt;sigmoid belief networks&lt;/em&gt;, and many
others &lt;sup id="fnref2:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;Having specified how the probabilities are computed, we can now define the
Bernoulli negative log likelihood $- \log p_{\theta}(\mathbf{x}|\mathbf{z})$, which is in fact equivalent to the
binary cross-entropy loss:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;nll&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34; Negative log likelihood (Bernoulli). &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# keras.losses.binary_crossentropy gives the mean&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# over the last axis. we require the sum&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;binary_crossentropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;As we discuss later, this will not be the loss we ultimately minimize, but will
constitute the data-fitting term of our final loss.&lt;/p&gt;
&lt;p&gt;Note that this is a valid definition of a Keras loss function,
which is required to compile and optimize a model. It is a symbolic function
that returns a scalar for each data-point in &lt;code&gt;y_true&lt;/code&gt; and &lt;code&gt;y_pred&lt;/code&gt;.
In our example, &lt;code&gt;y_pred&lt;/code&gt; will be the output of our &lt;code&gt;decoder&lt;/code&gt; network, namely
the predicted probabilities, and &lt;code&gt;y_true&lt;/code&gt; will be the observed data $\mathbf{x}$.&lt;/p&gt;
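&lt;p&gt;For reference, a NumPy sketch of the same per-data-point quantity, independent of any Keras backend. It checks that the summed binary cross-entropy equals the negative log of the Bernoulli probability mass function; the clipping constant and the small batch are hypothetical choices for illustration.&lt;/p&gt;

```python
import numpy as np


def bernoulli_nll(y_true, y_pred, eps=1e-7):
    """-log p(x | z) under a factorised Bernoulli, summed over the last axis.

    y_pred is clipped to [eps, 1 - eps] to avoid log(0); eps is an
    illustrative choice, not necessarily the value Keras uses.
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    bce = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log1p(-y_pred))
    return np.sum(bce, axis=-1)  # one scalar per data point


# Hypothetical batch of two 4-pixel binary "images" and predicted probabilities.
y_true = np.array([[1.0, 0.0, 1.0, 1.0],
                   [0.0, 0.0, 1.0, 0.0]])
y_pred = np.array([[0.9, 0.2, 0.7, 0.6],
                   [0.1, 0.3, 0.8, 0.4]])

nll = bernoulli_nll(y_true, y_pred)
# The same quantity computed directly from the Bernoulli probability mass function.
pmf = np.where(y_true == 1.0, y_pred, 1.0 - y_pred)
direct = -np.sum(np.log(pmf), axis=-1)
print(nll.shape, np.allclose(nll, direct))  # (2,) True
```

&lt;p&gt;Averaging instead of summing over pixels would rescale the data-fitting term by the reciprocal of the number of pixels, which is why the sum is taken.&lt;/p&gt;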
&lt;hr&gt;
&lt;h4 id="side-note-using-tensorflow-distributions-in-loss"&gt;Side note: Using TensorFlow Distributions in loss&lt;/h4&gt;
&lt;p&gt;If you are using the TensorFlow backend, you can directly use the (negative)
log probability of &lt;code&gt;Bernoulli&lt;/code&gt; from TensorFlow Distributions as a Keras
loss, as I demonstrate in my post on
.&lt;/p&gt;
&lt;p&gt;Specifically we can define the loss as,&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;nll&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34; Negative log likelihood (Bernoulli). &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;lh&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;distributions&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Bernoulli&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;probs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;lh&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_true&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This is mathematically equivalent to the previous definition, but does not call
&lt;code&gt;K.binary_crossentropy&lt;/code&gt; directly.&lt;/p&gt;
&lt;hr&gt;
&lt;h1 id="inference"&gt;Inference&lt;/h1&gt;
&lt;p&gt;Having specified the generative process, we would now like to perform inference
on the latent variables and model parameters $\mathbf{z}$ and $\theta$,
respectively.
In particular, our goal is to compute the posterior
$p_{\theta}(\mathbf{z} | \mathbf{x})$, the conditional density of the latent
variable $\mathbf{z}$ given observed variable $\mathbf{x}$.
Additionally, we wish to optimize the model parameters $\theta$ with respect to
the marginal likelihood $p_{\theta}(\mathbf{x})$.
Both depend on the marginal likelihood, whose calculation requires marginalizing
out the latent variables $\mathbf{z}$. In general, this is either computationally
intractable, requiring exponential time to compute, or analytically intractable,
with no closed-form expression. In our case, we suffer from the latter
intractability, since our Gaussian prior is not conjugate to the
Bernoulli likelihood.&lt;/p&gt;
&lt;p&gt;To circumvent this intractability we turn to &lt;em&gt;variational inference&lt;/em&gt;, which
formulates inference as an optimization problem. It seeks an approximate
posterior $q_{\phi}(\mathbf{z} | \mathbf{x})$ closest in Kullback-Leibler
(KL) divergence to the true posterior. More precisely, the approximate posterior
is parameterized by &lt;em&gt;variational parameters&lt;/em&gt; $\phi$, and we seek a setting
of these parameters that minimizes the aforementioned KL divergence,&lt;/p&gt;
$$
\phi^* = \mathrm{argmin}_{\phi}
\mathrm{KL} [q_{\phi}(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z} | \mathbf{x}) ]
$$&lt;p&gt;With the luck we&amp;rsquo;ve had so far, it shouldn&amp;rsquo;t come as a surprise anymore that
&lt;em&gt;this too&lt;/em&gt; is intractable. It also depends on the log marginal likelihood,
whose intractability is the reason we appealed to approximate inference in the
first place. Instead, we &lt;em&gt;maximize&lt;/em&gt; an alternative objective function, the
&lt;em&gt;evidence lower bound&lt;/em&gt; (ELBO), which is expressed as&lt;/p&gt;
$$
\begin{align*}
\mathrm{ELBO}(q)
&amp; =
\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})} [
\log p_{\theta}(\mathbf{x} | \mathbf{z}) +
\log p(\mathbf{z}) -
\log q_{\phi}(\mathbf{z} | \mathbf{x})
] \newline
&amp; =
\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}
[ \log p_{\theta}(\mathbf{x} | \mathbf{z}) ]
-\mathrm{KL} [ q_{\phi}(\mathbf{z} | \mathbf{x}) || p(\mathbf{z}) ].
\end{align*}
$$&lt;p&gt;Importantly, the ELBO is a lower bound to the log marginal likelihood.
Therefore, maximizing it with respect to the model parameters $\theta$
approximately maximizes the log marginal likelihood.
Additionally, maximizing it with respect to variational parameters $\phi$ can
be shown to minimize
$\mathrm{KL} [q_{\phi}(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z} | \mathbf{x}) ]$.
Also, it turns out that this KL divergence determines the tightness of the lower
bound: we have equality iff the KL divergence is zero, which happens iff
$q_{\phi}(\mathbf{z} | \mathbf{x}) = p_{\theta}(\mathbf{z} | \mathbf{x})$.
Hence, simultaneously maximizing the ELBO with respect to $\theta$ and $\phi$
kills two birds with one stone.&lt;/p&gt;
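&lt;p&gt;To make the bound and its gap concrete, here is a small sketch on a hypothetical one-dimensional conjugate model, $z \sim \mathcal{N}(0, 1)$ and $x | z \sim \mathcal{N}(z, s^2)$, chosen only because its evidence and exact posterior have closed forms (it is not the Bernoulli model of this post). It checks numerically that $\log p(x) - \mathrm{ELBO}(q)$ equals $\mathrm{KL}[q(z) || p(z | x)]$ for a few Gaussian choices of $q$, and that the gap vanishes when $q$ is the exact posterior. The function names are illustrative.&lt;/p&gt;

```python
import math

# Hypothetical 1-D conjugate model for illustration only (not the VAE in this
# post): prior z ~ N(0, 1) and likelihood x | z ~ N(z, s2).

def log_evidence(x, s2):
    """Exact log p(x); marginally x ~ N(0, 1 + s2)."""
    var = 1.0 + s2
    return -0.5 * math.log(2 * math.pi * var) - x ** 2 / (2 * var)

def elbo(x, s2, m, v):
    """ELBO for q(z) = N(m, v): expected log likelihood minus KL[q || prior]."""
    ell = -0.5 * math.log(2 * math.pi * s2) - ((x - m) ** 2 + v) / (2 * s2)
    kl_prior = 0.5 * (v + m ** 2 - 1.0 - math.log(v))
    return ell - kl_prior

def kl_to_posterior(x, s2, m, v):
    """KL[q(z) || p(z | x)]; the exact posterior is N(x / (1 + s2), s2 / (1 + s2))."""
    a, b = x / (1.0 + s2), s2 / (1.0 + s2)
    return 0.5 * (math.log(b / v) + (v + (m - a) ** 2) / b - 1.0)

x, s2 = 1.3, 0.5
# Two arbitrary choices of q, then q equal to the exact posterior.
for m, v in [(0.0, 1.0), (0.5, 0.2), (x / (1 + s2), s2 / (1 + s2))]:
    gap = log_evidence(x, s2) - elbo(x, s2, m, v)
    print(m, v, gap, kl_to_posterior(x, s2, m, v))  # gap equals the KL
```

&lt;p&gt;The identity holds for any $q$, since $\log p(x) - \mathrm{ELBO}(q) = \mathbb{E}_{q}[\log q(z) - \log p(z | x)]$; the toy model only makes both sides computable exactly.&lt;/p&gt;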
&lt;p&gt;Next we discuss the form of the approximate posterior
$q_{\phi}(\mathbf{z} | \mathbf{x})$, which can be viewed as a
&lt;em&gt;probabilistic&lt;/em&gt; encoder. Its role is opposite to that of the decoder.
Given an observation $\mathbf{x}$, it &amp;ldquo;encodes&amp;rdquo; it into a &lt;em&gt;distribution&lt;/em&gt;
over its hidden lower-dimensional representations.&lt;/p&gt;
&lt;h2 id="encoder"&gt;Encoder&lt;/h2&gt;
&lt;p&gt;For each local observed variable $\mathbf{x}_n$, we wish to approximate
the true posterior distribution $p(\mathbf{z}_n|\mathbf{x}_n)$ over its
corresponding local latent variables $\mathbf{z}_n$. A common approach is to
approximate it using a &lt;em&gt;variational distribution&lt;/em&gt;
$q_{\lambda_n}(\mathbf{z}_n)$, specified as a diagonal
Gaussian, where the &lt;em&gt;local&lt;/em&gt; variational parameters
$\lambda_n = \{ \boldsymbol{\mu}_n, \boldsymbol{\sigma}_n \}$ are the mean and
standard deviation of this approximating distribution,
&lt;/p&gt;
$$
q_{\lambda_n}(\mathbf{z}_n) =
\mathcal{N}(
\mathbf{z}_n |
\boldsymbol{\mu}_n,
\mathrm{diag}(\boldsymbol{\sigma}_n^2)
).
$$&lt;p&gt;
This approach has a number of shortcomings. First, the number of local
variational parameters we need to optimize grows with the size of the dataset.
Second, a new set of local variational parameters needs to be optimized for each
new, unseen test point. This is not to mention the strong factorization assumption
we make by specifying diagonal Gaussian distributions as the family of
approximations. The last of these is still an active area of research, while the
first two can be addressed by introducing a further approximation using an
inference network.&lt;/p&gt;
&lt;h3 id="inference-network"&gt;Inference network&lt;/h3&gt;
&lt;p&gt;We &lt;em&gt;amortize&lt;/em&gt; the cost of inference by introducing an &lt;em&gt;inference network&lt;/em&gt; which
approximates the local variational parameters $\lambda_n$ for a given local
observed variable $\mathbf{x}_n$.
For our approximating distribution in particular, given $\mathbf{x}_n$ the
inference network yields two vector-valued outputs $\boldsymbol{\mu}_{\phi}(\mathbf{x}_n)$ and
$\boldsymbol{\sigma}_{\phi}(\mathbf{x}_n)$, which we use to approximate its local
variational parameters $\boldsymbol{\mu}_n$ and $\boldsymbol{\sigma}_n$, respectively.
Our approximate posterior distribution now becomes&lt;/p&gt;
$$
q_{\phi}(\mathbf{z}_n | \mathbf{x}_n) =
\mathcal{N}(
\mathbf{z}_n |
\boldsymbol{\mu}_{\phi}(\mathbf{x}_n),
\mathrm{diag}(\boldsymbol{\sigma}_{\phi}^2(\mathbf{x}_n))
).
$$&lt;p&gt;Instead of learning &lt;em&gt;local&lt;/em&gt; variational parameters $\lambda_n$ for each data-point,
we now learn a fixed number of &lt;em&gt;global&lt;/em&gt; variational parameters $\phi$ which
constitute the parameters (i.e. weights) of the inference network.
Moreover, this approximation allows statistical strength to be shared across
observed data-points and lets the encoder generalize to unseen test points.&lt;/p&gt;
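&lt;p&gt;A minimal NumPy sketch of amortization, assuming illustrative MNIST-like sizes for &lt;code&gt;original_dim&lt;/code&gt;, &lt;code&gt;intermediate_dim&lt;/code&gt; and &lt;code&gt;latent_dim&lt;/code&gt; and a hypothetical &lt;code&gt;encode&lt;/code&gt; helper with the single-hidden-layer architecture used in this post: one fixed set of global weights &lt;code&gt;phi&lt;/code&gt; maps any batch to per-point means and standard deviations, so the number of parameters to learn does not change with the number of data-points, whereas storing local $\lambda_n$ would need $2K$ numbers per point.&lt;/p&gt;

```python
import numpy as np

# Illustrative sizes (assumptions, MNIST-like).
original_dim, intermediate_dim, latent_dim = 784, 256, 2

rng = np.random.default_rng(0)
# Global variational parameters phi: the weights of the inference network.
phi = {
    "W_h": rng.normal(scale=0.01, size=(original_dim, intermediate_dim)),
    "b_h": np.zeros(intermediate_dim),
    "W_mu": rng.normal(scale=0.01, size=(intermediate_dim, latent_dim)),
    "b_mu": np.zeros(latent_dim),
    "W_lv": rng.normal(scale=0.01, size=(intermediate_dim, latent_dim)),
    "b_lv": np.zeros(latent_dim),
}

def encode(x, phi):
    """Map a batch x of shape (N, original_dim) to local (mu_n, sigma_n)."""
    h = np.maximum(0.0, x @ phi["W_h"] + phi["b_h"])  # ReLU hidden layer
    mu = h @ phi["W_mu"] + phi["b_mu"]                 # mean head
    log_var = h @ phi["W_lv"] + phi["b_lv"]            # log variance head
    return mu, np.exp(0.5 * log_var)

n_phi = sum(p.size for p in phi.values())
for n in (100, 5000):
    x = rng.random((n, original_dim))
    mu, sigma = encode(x, phi)
    # Local parameters would need 2 * latent_dim numbers per data-point.
    print(n, mu.shape, sigma.shape, "local:", 2 * latent_dim * n, "global:", n_phi)
```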
&lt;p&gt;We specify the mean $\boldsymbol{\mu}_{\phi}(\mathbf{x})$ and log variance
$\log \boldsymbol{\sigma}_{\phi}^2(\mathbf{x})$ of this distribution as the output of
an inference network. For this post, we keep the architecture of the network
simple, with only a single hidden layer and two fully-connected output layers.
Again, this is simple to define in Keras:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# input layer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# hidden layer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# output layer for mean and log variance&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Since this network has multiple outputs, we cannot use the Sequential model
API as we did for the decoder. Instead, we resort to the more powerful
Keras functional API,
which allows us to implement complex models with shared layers, multiple
inputs, multiple outputs, and so on.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Inference network"
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/inference_network.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
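&lt;p&gt;Note that we output the log variance instead of the standard deviation because
the log variance can take any real value, whereas the standard deviation must be
positive; this is not only more convenient to work with, but also helps with
numerical stability. However, we still require the standard deviation later. To
recover it, we apply the transformation
$\boldsymbol{\sigma} = \exp(\frac{1}{2} \log \boldsymbol{\sigma}^2)$ and encapsulate it in a
&lt;code&gt;Lambda&lt;/code&gt; layer.&lt;/p&gt;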
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# convert log variance to standard deviation: sigma = exp(log_var / 2)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Lambda&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;.5&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;))(&lt;/span&gt;&lt;span class="n"&gt;z_log_var&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Before moving on, we give a few words on nomenclature and context.
In the prelude and title of this section, we characterized the approximate
posterior distribution with an inference network as a probabilistic encoder
(analogously to its counterpart, the probabilistic decoder).
Although this is an accurate interpretation, it is a limited one.
Classically, inference networks are known as &lt;em&gt;recognition models&lt;/em&gt;, and have now
been used for decades in a wide variety of probabilistic methods.
When composed end-to-end, the recognition-generative model combination can be
seen as having an autoencoder structure. Indeed, this structure contains the
variational autoencoder as a special case, and also the now less fashionable
&lt;em&gt;Helmholtz machine&lt;/em&gt; &lt;sup id="fnref1:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;.
Even more generally, this recognition-generative model combination constitutes
a widely-applicable approach currently known as &lt;em&gt;amortized variational inference&lt;/em&gt;,
which can be used to perform approximate inference in models that lie beyond
even the large class of deep latent Gaussian models.&lt;/p&gt;
&lt;p&gt;Having specified all the ingredients necessary to carry out variational
inference (namely, the prior, likelihood and approximate posterior), we next
focus on finalizing the definition of the (negative) ELBO as our loss function
in Keras. As written earlier, the ELBO can be decomposed into two terms:
$\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})} [ \log p_{\theta}(\mathbf{x} | \mathbf{z}) ]$,
the expected log likelihood (ELL) under $q_{\phi}(\mathbf{z} | \mathbf{x})$,
and $- \mathrm{KL} [q_{\phi}(\mathbf{z} | \mathbf{x}) || p(\mathbf{z}) ]$,
the negative KL divergence between the approximate posterior
$q_{\phi}(\mathbf{z} | \mathbf{x})$ and the prior $p(\mathbf{z})$. We first turn
our attention to the KL divergence term.&lt;/p&gt;
&lt;h3 id="kl-divergence"&gt;KL Divergence&lt;/h3&gt;
&lt;p&gt;Intuitively, maximizing the negative KL divergence term encourages approximate
posterior densities that stay close to the prior, penalizing those that place
their mass on latent configurations far from it. Effectively, this regularizes
the complexity of the latent space. Now, since both the prior $p(\mathbf{z})$ and
approximate posterior $q_{\phi}(\mathbf{z} | \mathbf{x})$ are Gaussian,
the KL divergence can be calculated with the closed-form expression,&lt;/p&gt;
$$
\mathrm{KL} [ q_{\phi}(\mathbf{z} | \mathbf{x}) || p(\mathbf{z}) ]
= - \frac{1}{2} \sum_{k=1}^K \{ 1 + \log \sigma_k^2 - \mu_k^2 - \sigma_k^2 \}
$$&lt;p&gt;where $\mu_k$ and $\sigma_k$ are the $k$-th components of output vectors
$\mu_{\phi}(\mathbf{x})$ and $\sigma_{\phi}(\mathbf{x})$, respectively.
This is not too difficult to derive, and I would recommend verifying this as an
exercise. You can also find a derivation in the appendix of Kingma and Welling&amp;rsquo;s
(2014) paper &lt;sup id="fnref1:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
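&lt;p&gt;As a numerical sanity check alongside the derivation, the sketch below compares the closed-form expression with a Monte Carlo estimate of $\mathbb{E}_{q}[\log q_{\phi}(\mathbf{z} | \mathbf{x}) - \log p(\mathbf{z})]$. It assumes, as the formula does, a standard Normal prior $p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I})$; the values of &lt;code&gt;mu&lt;/code&gt; and &lt;code&gt;log_var&lt;/code&gt; are arbitrary stand-ins for encoder outputs.&lt;/p&gt;

```python
import numpy as np

def kl_closed_form(mu, log_var):
    """KL[N(mu, diag(exp(log_var))) || N(0, I)], summed over latent dims."""
    return -0.5 * np.sum(1.0 + log_var - mu ** 2 - np.exp(log_var), axis=-1)

def kl_monte_carlo(mu, log_var, n_samples=200000, seed=0):
    """Monte Carlo estimate of E_q[log q(z) - log p(z)] from samples of q."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((n_samples,) + mu.shape)
    z = mu + np.exp(0.5 * log_var) * eps
    # log N(z | mu, sigma^2), written in terms of eps = (z - mu) / sigma
    log_q = np.sum(-0.5 * (np.log(2 * np.pi) + log_var) - 0.5 * eps ** 2, axis=-1)
    log_p = np.sum(-0.5 * np.log(2 * np.pi) - 0.5 * z ** 2, axis=-1)
    return np.mean(log_q - log_p, axis=0)

mu = np.array([0.5, -1.0, 0.0])
log_var = np.array([-0.3, 0.2, 0.0])
print(kl_closed_form(mu, log_var), kl_monte_carlo(mu, log_var))
```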
&lt;p&gt;Recall that earlier, we defined the expected log likelihood term of the ELBO as
a Keras loss. We were able to do this since the log likelihood is a function of
the network&amp;rsquo;s final output (the predicted probabilities), so it maps nicely to a
Keras loss. Unfortunately, the same does not apply for the KL divergence term,
which is a function of the network&amp;rsquo;s intermediate layer outputs, the mean &lt;code&gt;mu&lt;/code&gt;
and log variance &lt;code&gt;log_var&lt;/code&gt;.&lt;/p&gt;
&lt;p&gt;We define an auxiliary custom Keras layer
which takes &lt;code&gt;mu&lt;/code&gt; and &lt;code&gt;log_var&lt;/code&gt; as input and simply returns them as output
without modification. We do, however, explicitly introduce the side effect of
calculating the KL divergence and adding it to a collection of losses, by
calling the method &lt;code&gt;add_loss&lt;/code&gt; &lt;sup id="fnref:12"&gt;&lt;a href="#fn:12" class="footnote-ref" role="doc-noteref"&gt;12&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Layer&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34; Identity transform layer that adds KL divergence
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; to the final model loss.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_placeholder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;call&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;kl_batch&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;log_var&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;square&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_var&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kl_batch&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Next we feed &lt;code&gt;z_mu&lt;/code&gt; and &lt;code&gt;z_log_var&lt;/code&gt; through this layer (this needs to take
place before feeding &lt;code&gt;z_log_var&lt;/code&gt; through the Lambda layer to recover &lt;code&gt;z_sigma&lt;/code&gt;).&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_log_var&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Now when the Keras model is finally compiled, the collection of losses will be
aggregated and added to the specified Keras loss function to form the loss we
ultimately minimize. If we specify the loss as the negative log-likelihood we
defined earlier (&lt;code&gt;nll&lt;/code&gt;), we recover the negative ELBO as the final loss we
minimize, as intended.&lt;/p&gt;
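&lt;p&gt;The aggregation can be mirrored outside Keras. The NumPy sketch below assumes that the compiled loss is averaged over the batch (the default reduction in Keras) and that the added loss is the batch mean of the per-example KL, as in &lt;code&gt;KLDivergenceLayer&lt;/code&gt;; the sum of the two is then the batch-averaged negative ELBO, with the ELL estimated from one sample per data-point. The helper names and the random inputs are hypothetical.&lt;/p&gt;

```python
import numpy as np

def bernoulli_nll(x, probs, eps=1e-7):
    """Per-example negative Bernoulli log likelihood, summed over pixels."""
    probs = np.clip(probs, eps, 1.0 - eps)  # guard against log(0)
    return -np.sum(x * np.log(probs) + (1.0 - x) * np.log(1.0 - probs), axis=-1)

def kl_batch(mu, log_var):
    """Per-example KL to a standard Normal prior (as in KLDivergenceLayer)."""
    return -0.5 * np.sum(1.0 + log_var - mu ** 2 - np.exp(log_var), axis=-1)

def negative_elbo(x, probs, mu, log_var):
    """Batch mean of the compiled loss plus the added (batch-mean) KL loss."""
    return np.mean(bernoulli_nll(x, probs)) + np.mean(kl_batch(mu, log_var))

rng = np.random.default_rng(1)
x = (rng.random((4, 6)) > 0.5).astype(float)    # binary "images"
probs = rng.uniform(0.1, 0.9, size=(4, 6))      # decoder outputs
mu, log_var = rng.normal(size=(4, 2)), rng.normal(scale=0.1, size=(4, 2))
print(negative_elbo(x, probs, mu, log_var))
```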
&lt;hr&gt;
&lt;h4 id="side-note-alternative-divergences"&gt;Side note: Alternative divergences&lt;/h4&gt;
&lt;p&gt;A key benefit of encapsulating the divergence in an auxiliary layer is that we
can easily implement and swap in other divergences, such as the
$\chi$-divergence or the $\alpha$-divergence.
Using alternative divergences for variational inference is an active research
topic &lt;sup id="fnref:13"&gt;&lt;a href="#fn:13" class="footnote-ref" role="doc-noteref"&gt;13&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:14"&gt;&lt;a href="#fn:14" class="footnote-ref" role="doc-noteref"&gt;14&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;hr&gt;
&lt;h4 id="side-note-implicit-models-and-adversarial-learning"&gt;Side note: Implicit models and adversarial learning&lt;/h4&gt;
&lt;p&gt;Additionally, we could also extend the divergence layer to use an auxiliary
density ratio estimator function, instead of evaluating the KL divergence in
the analytical form above.
This relaxes the requirement on approximate posterior
$q_{\phi}(\mathbf{z}|\mathbf{x})$ (and incidentally also prior $p(\mathbf{z})$)
to yield tractable densities, at the cost of maximizing a cruder estimate of the
ELBO.
This is known as Adversarial Variational Bayes&lt;sup id="fnref1:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt;, and is an
important line of recent research that, when taken to its logical conclusion,
can extend the applicability of variational inference to arbitrarily expressive
implicit probabilistic models with intractable likelihoods&lt;sup id="fnref1:11"&gt;&lt;a href="#fn:11" class="footnote-ref" role="doc-noteref"&gt;11&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
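&lt;p&gt;To illustrate the density ratio idea in isolation, here is a toy sketch in one dimension: a logistic regression classifier is trained to tell samples of $q$ (label 1) from samples of $p$ (label 0). With balanced classes, its optimal logit is $\log q(z) / p(z)$, so averaging the logit over samples of $q$ estimates $\mathrm{KL}[q || p]$. The Gaussians, the hand-picked quadratic features and the Newton solver are all assumptions made so the example is small and checkable against the exact KL; Adversarial Variational Bayes instead trains a neural network discriminator jointly with the model.&lt;/p&gt;

```python
import numpy as np

def features(z):
    # The log ratio of two 1-D Gaussians is quadratic in z, so [1, z, z^2]
    # makes logistic regression well specified in this toy setting.
    return np.stack([np.ones_like(z), z, z ** 2], axis=-1)

def fit_logistic(X, y, n_iter=25):
    """Newton's method (IRLS) for unregularized logistic regression."""
    w = np.zeros(X.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-X @ w))
        grad = X.T @ (y - p)
        hess = (X * (p * (1 - p))[:, None]).T @ X
        w += np.linalg.solve(hess, grad)
    return w

rng = np.random.default_rng(0)
n = 100000
z_q = rng.normal(1.0, 0.8, size=n)  # samples from q (label 1)
z_p = rng.normal(0.0, 1.0, size=n)  # samples from p (label 0)
X = features(np.concatenate([z_q, z_p]))
y = np.concatenate([np.ones(n), np.zeros(n)])
w = fit_logistic(X, y)

# With balanced classes, the classifier logit estimates log q(z) / p(z).
kl_estimate = np.mean(features(z_q) @ w)
kl_exact = np.log(1.0 / 0.8) + (0.8 ** 2 + 1.0 ** 2) / 2.0 - 0.5
print(kl_estimate, kl_exact)
```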
&lt;hr&gt;
&lt;h3 id="reparameterization-using-merge-layers"&gt;Reparameterization using Merge Layers&lt;/h3&gt;
&lt;p&gt;To perform gradient-based optimization of the ELBO with respect to model parameters
$\theta$ and variational parameters $\phi$, we require its gradients with
respect to these parameters, which are generally intractable.
Currently, the dominant approach for circumventing this is by Monte Carlo (MC)
estimation of the gradients. The basic idea is to write the gradient of the
ELBO as an expectation of the gradient, approximate it with MC estimates, then
perform stochastic gradient descent with the repeated MC gradient estimates.&lt;/p&gt;
&lt;p&gt;There exist a number of estimators based on different variance reduction
techniques. However, MC gradient estimates based on the reparameterization trick,
known as the &lt;em&gt;reparameterization gradients&lt;/em&gt;, have been shown to have the lowest
variance among competing estimators for continuous latent variables&lt;sup id="fnref3:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;.
The reparameterization trick is a straightforward change of variables that
expresses the random variable $\mathbf{z} \sim q_{\phi}(\mathbf{z} | \mathbf{x})$
as a deterministic transformation $g_{\phi}$ of another random variable
$\boldsymbol{\epsilon}$ and input $\mathbf{x}$, with parameters $\phi$,&lt;/p&gt;
$$
\mathbf{z} = g_{\phi}(\mathbf{x}, \boldsymbol{\epsilon}), \quad
\boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon}).
$$&lt;p&gt;Note that $p(\boldsymbol{\epsilon})$ is a simpler base distribution which is
parameter-free and independent of $\mathbf{x}$ or $\phi$.
To prevent clutter, we write the ELBO as an expectation of the function
$f(\mathbf{x}, \mathbf{z}) = \log p_{\theta}(\mathbf{x} , \mathbf{z}) -
\log q_{\phi}(\mathbf{z} | \mathbf{x})$ over distribution
$q_{\phi}(\mathbf{z} | \mathbf{x})$.
Now, for any function $f(\mathbf{x}, \mathbf{z})$, taking the gradient of the
expectation with respect to $\phi$, and substituting all occurrences of
$\mathbf{z}$ with $g_{\phi}(\mathbf{x}, \boldsymbol{\epsilon})$, we have&lt;/p&gt;
$$
\begin{align*}
\nabla_{\phi} \mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}
[ f(\mathbf{x}, \mathbf{z}) ]
&amp; = \nabla_{\phi} \mathbb{E}_{p(\boldsymbol{\epsilon})}
[ f(\mathbf{x}, g_{\phi}(\mathbf{x}, \boldsymbol{\epsilon})) ] \newline
&amp; = \mathbb{E}_{p(\boldsymbol{\epsilon})}
[ \nabla_{\phi} f(\mathbf{x}, g_{\phi}(\mathbf{x}, \boldsymbol{\epsilon})) ].
\end{align*}
$$&lt;p&gt;In other words, this simple reparameterization allows the gradient and the
expectation to commute, thereby allowing us to compute unbiased stochastic
estimates of the ELBO gradients by drawing noise samples $\boldsymbol{\epsilon}$
from $p(\boldsymbol{\epsilon})$.&lt;/p&gt;
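&lt;p&gt;A minimal sketch of a reparameterization gradient, using the toy integrand $f(z) = z^2$ in place of the ELBO integrand so that the exact answer is known: for $q = \mathcal{N}(\mu, \sigma^2)$ we have $\mathbb{E}_q[z^2] = \mu^2 + \sigma^2$, with gradient $(2\mu, 2\sigma)$. Writing $z = \mu + \sigma \epsilon$ and differentiating inside the expectation gives $\mathbb{E}[2z]$ and $\mathbb{E}[2z\epsilon]$, which the code estimates from standard Normal noise. The function name and values are illustrative.&lt;/p&gt;

```python
import numpy as np

def reparam_grad(mu, sigma, n_samples=100000, seed=0):
    """Reparameterization estimate of d/d(mu, sigma) E_q[z^2], q = N(mu, sigma^2)."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal(n_samples)  # parameter-free base noise
    z = mu + sigma * eps                  # z = g(eps; mu, sigma)
    # Chain rule through g: df/dz = 2z, dz/dmu = 1, dz/dsigma = eps.
    return np.mean(2 * z), np.mean(2 * z * eps)

mu, sigma = 1.5, 0.7
print(reparam_grad(mu, sigma), (2 * mu, 2 * sigma))  # estimate vs exact gradient
```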
&lt;hr&gt;
&lt;p&gt;To recover the diagonal Gaussian approximation we specified earlier
$q_{\phi}(\mathbf{z}_n | \mathbf{x}_n) = \mathcal{N}(\mathbf{z}_n |
\boldsymbol{\mu}_{\phi}(\mathbf{x}_n), \mathrm{diag}(\boldsymbol{\sigma}_{\phi}^2(\mathbf{x}_n)))$,
we draw noise from the standard Normal base distribution, and specify a simple
location-scale transformation&lt;/p&gt;
$$
\mathbf{z}
= g_{\phi}(\mathbf{x}, \boldsymbol{\epsilon})
= \mu_{\phi}(\mathbf{x}) +
\sigma_{\phi}(\mathbf{x}) \odot
\boldsymbol{\epsilon}, \quad
\boldsymbol{\epsilon}
\sim \mathcal{N}(\mathbf{0}, \mathbf{I}),
$$&lt;p&gt;where $\mu_{\phi}(\mathbf{x})$ and $\sigma_{\phi}(\mathbf{x})$ are the outputs
of the inference network defined earlier with parameters $\phi$, and $\odot$
denotes the elementwise product. In Keras, we explicitly make the noise vector
an input to the model by defining an Input layer for it. We then implement the
above location-scale transformation using Keras merge layers,
namely &lt;code&gt;Add&lt;/code&gt; and &lt;code&gt;Multiply&lt;/code&gt;.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Multiply&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Add&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_eps&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Reparameterization with simple location-scale transformation using Keras merge layers.
"
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/reparameterization.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
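&lt;p&gt;The same location-scale transformation can be mirrored in NumPy to check that it produces the intended distribution. In this sketch, the constant &lt;code&gt;z_mu&lt;/code&gt; and &lt;code&gt;z_log_var&lt;/code&gt; arrays are hypothetical stand-ins for encoder outputs, and the three steps correspond to the &lt;code&gt;Lambda&lt;/code&gt;, &lt;code&gt;Multiply&lt;/code&gt; and &lt;code&gt;Add&lt;/code&gt; layers; the empirical mean and standard deviation of &lt;code&gt;z&lt;/code&gt; should be close to $(0.5, -2)$ and $(0.5, 2)$.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, latent_dim = 50000, 2  # illustrative sizes

# Stand-ins for the encoder outputs (identical for every row here).
z_mu = np.tile([0.5, -2.0], (batch_size, 1))
z_log_var = np.tile([np.log(0.25), np.log(4.0)], (batch_size, 1))

z_sigma = np.exp(0.5 * z_log_var)                    # Lambda layer
eps = rng.standard_normal((batch_size, latent_dim))  # noise Input
z = z_mu + z_sigma * eps                             # Multiply, then Add

print(z.mean(axis=0), z.std(axis=0))
```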
&lt;hr&gt;
&lt;h4 id="side-note-monte-carlo-sample-size"&gt;Side note: Monte Carlo sample size&lt;/h4&gt;
&lt;p&gt;Note that both the inputs for observed variables and noise (&lt;code&gt;x&lt;/code&gt; and &lt;code&gt;eps&lt;/code&gt;) need to be
specified explicitly as inputs to our final model.
Furthermore, the size of their first dimension (i.e. batch size) is required
to be the same.
This corresponds to using exactly one Monte Carlo sample to approximate the
expected log likelihood, drawing a single sample $\mathbf{z}_n$ from
$q_{\phi}(\mathbf{z}_n | \mathbf{x}_n)$ for each data-point $\mathbf{x}_n$ in
the batch. Although you might find an MC sample size of 1 surprisingly small,
it is actually adequate for a sufficiently large batch size (~100) &lt;sup id="fnref2:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;.
In a
,
I demonstrate how to extend our approach to support larger MC sample sizes using
just a few minor tweaks. This extension is crucial for implementing the
&lt;em&gt;importance weighted autoencoder&lt;/em&gt; &lt;sup id="fnref1:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
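&lt;p&gt;One way to see why a single sample can suffice: the loss is already averaged over the batch, so the noise of the batch-level estimate shrinks with the batch size as well as with the number of MC samples. The sketch below uses a hypothetical stand-in integrand $z_n^2$ under $q = \mathcal{N}(\mu_n, \sigma_n^2)$, whose expectation $\mu_n^2 + \sigma_n^2$ is known exactly, and compares 1 versus 10 samples per data-point on a batch of 100 by repeating the estimate many times.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)

def batch_estimate(mu, sigma, n_mc):
    """Estimate the batch average of E_q[z_n^2] with n_mc samples per point.
    z_n^2 is a stand-in for the per-point log likelihood term."""
    eps = rng.standard_normal((n_mc,) + mu.shape)
    z = mu + sigma * eps
    return np.mean(z ** 2)  # average over MC samples and over the batch

batch_size = 100
mu = rng.normal(size=batch_size)
sigma = np.full(batch_size, 0.5)
exact = np.mean(mu ** 2 + sigma ** 2)

results = {}
for n_mc in (1, 10):
    estimates = np.array([batch_estimate(mu, sigma, n_mc) for _ in range(2000)])
    results[n_mc] = (estimates.mean() - exact, estimates.std())  # (bias, spread)
    print(n_mc, *results[n_mc])
```

&lt;p&gt;Both estimates are unbiased; ten samples per point cut the standard deviation by roughly $\sqrt{10}$, about the same reduction that a batch ten times larger would give.&lt;/p&gt;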
&lt;hr&gt;
&lt;p&gt;Now, since the noise input is drawn from the Normal distribution, we can avoid
having to feed in values for this input from outside the computation graph
by binding a tensor to this Input layer. Specifically, we bind a tensor created
using &lt;code&gt;K.random_normal&lt;/code&gt; with the required shape,&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random_normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;While &lt;code&gt;eps&lt;/code&gt; must still be listed explicitly as an input when defining the
model, methods such as &lt;code&gt;fit&lt;/code&gt; and &lt;code&gt;predict&lt;/code&gt; will no longer expect
values for this input. Instead, samples from this distribution are lazily
generated inside the computation graph whenever they are required. See my
separate notes on this topic for more details.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Encoder architecture."
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/encoder.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;In some other implementations, all of this logic is encapsulated in a single
&lt;code&gt;Lambda&lt;/code&gt; layer, which both draws samples from a hard-coded base
distribution and performs the location-scale transformation.
In contrast, by decoupling the random noise vector from the layer&amp;rsquo;s internal logic and
explicitly making it a model input, we emphasize the fact that all sources of
stochasticity emanate from this input. It thereby becomes clear that a random
sample drawn from a particular approximating distribution is obtained by feeding
this source of stochasticity through a number of successive deterministic
transformations.&lt;/p&gt;
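&lt;p&gt;The same chain of deterministic transformations can be reproduced outside the computation graph. The NumPy sketch below is an illustration, not the Keras model itself, and the function name &lt;code&gt;reparameterize&lt;/code&gt; is hypothetical. It first converts the log variance to a standard deviation, as the &lt;code&gt;Lambda&lt;/code&gt; layer does. It then computes $\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}$, as the &lt;code&gt;Multiply&lt;/code&gt; and &lt;code&gt;Add&lt;/code&gt; layers do. All randomness enters through &lt;code&gt;eps&lt;/code&gt;: for a fixed &lt;code&gt;eps&lt;/code&gt;, &lt;code&gt;z&lt;/code&gt; is a deterministic function of &lt;code&gt;z_mu&lt;/code&gt; and &lt;code&gt;z_log_var&lt;/code&gt;.&lt;/p&gt;

```python
import numpy as np


def reparameterize(z_mu, z_log_var, eps):
    """Location-scale transform of standard Normal noise eps.

    Mirrors the Lambda (exp), Multiply and Add layers of the model:
    sigma = exp(0.5 * log_var), z = mu + sigma * eps.
    """
    z_sigma = np.exp(0.5 * z_log_var)  # log variance to standard deviation
    return z_mu + z_sigma * eps        # elementwise scale, then shift


rng = np.random.default_rng(42)
n_samples, latent_dim = 100000, 2  # illustrative sizes

z_mu = np.array([1.0, -2.0])
z_log_var = np.log(np.array([0.25, 4.0]))  # variances 0.25 and 4.0

# The only source of stochasticity: noise drawn independently of the parameters
eps = rng.standard_normal((n_samples, latent_dim))
z = reparameterize(z_mu, z_log_var, eps)

print(z.mean(axis=0))  # close to [1, -2]
print(z.std(axis=0))   # close to [0.5, 2]
```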
&lt;hr&gt;
&lt;h4 id="side-notes-gumbel-softmax-trick-for-discrete-latent-variables"&gt;Side notes: Gumbel-softmax trick for discrete latent variables&lt;/h4&gt;
&lt;p&gt;As an example, we could provide samples drawn from the Uniform distribution
as noise input. By applying a number of deterministic transformations that
constitute the &lt;em&gt;Gumbel-softmax reparameterization trick&lt;/em&gt; &lt;sup id="fnref1:9"&gt;&lt;a href="#fn:9" class="footnote-ref" role="doc-noteref"&gt;9&lt;/a&gt;&lt;/sup&gt;, we
are able to obtain differentiable relaxations of samples from the Categorical distribution. This allows us
to perform approximate inference on &lt;em&gt;discrete&lt;/em&gt; latent variables, and can be
implemented in this framework by adding a dozen or so lines of code!&lt;/p&gt;
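&lt;p&gt;A minimal NumPy sketch of the Gumbel-softmax transformation follows. The names &lt;code&gt;gumbel_softmax&lt;/code&gt;, &lt;code&gt;logits&lt;/code&gt; and &lt;code&gt;tau&lt;/code&gt; are illustrative. Uniform noise $u$ is mapped to standard Gumbel noise $g = -\log(-\log u)$. This noise is added to the unnormalized log-probabilities, and the result is passed through a softmax with temperature $\tau$. The argmax of each output is an exact Categorical sample (the Gumbel-max trick). The softmax output itself is a differentiable relaxation that approaches a one-hot vector as $\tau \to 0$.&lt;/p&gt;

```python
import numpy as np


def gumbel_softmax(logits, u, tau):
    """Relaxed one-hot samples from Uniform(0, 1) noise u.

    logits: unnormalized log-probabilities, shape (..., num_categories)
    u:      Uniform(0, 1) noise with the same shape as logits
    tau:    temperature; smaller values give outputs closer to one-hot
    """
    g = -np.log(-np.log(u))                 # Uniform noise to standard Gumbel noise
    y = (logits + g) / tau
    y = y - y.max(axis=-1, keepdims=True)   # for numerical stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)  # softmax


rng = np.random.default_rng(0)
probs = np.array([0.1, 0.3, 0.6])
logits = np.log(probs)

# Uniform noise, clipped away from 0 and 1 so that the logs stay finite
u = np.clip(rng.uniform(size=(100000, 3)), 1e-10, 1 - 1e-10)
samples = gumbel_softmax(logits, u, tau=0.1)

# Frequencies of the argmax should match the categorical probabilities
print(np.bincount(samples.argmax(axis=-1), minlength=3) / len(samples))
```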
&lt;h1 id="putting-it-all-together"&gt;Putting it all together&lt;/h1&gt;
&lt;p&gt;So far, we&amp;rsquo;ve dissected the variational autoencoder into modular components and
discussed the role and implementation of each one at some length.
Now let&amp;rsquo;s compose these components together end-to-end to form the final
autoencoder architecture.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_log_var&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Lambda&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;.5&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;))(&lt;/span&gt;&lt;span class="n"&gt;z_log_var&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random_normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Multiply&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Add&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_eps&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;decoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;sigmoid&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_pred&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;It&amp;rsquo;s surprisingly concise, taking up around 20 lines of code.
The full model architecture is visualized below.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Variational autoencoder architecture."
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/vae_full.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
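&lt;p&gt;To trace the data flow through this architecture, here is a NumPy sketch of a single forward pass with randomly initialized weights. It is not the Keras model. The sizes &lt;code&gt;original_dim = 784&lt;/code&gt;, &lt;code&gt;intermediate_dim = 256&lt;/code&gt; and &lt;code&gt;latent_dim = 2&lt;/code&gt; are assumptions chosen to suit flattened MNIST digits. The helper &lt;code&gt;dense&lt;/code&gt; is a hypothetical stand-in for a &lt;code&gt;Dense&lt;/code&gt; layer that draws fresh random weights on every call. Only the tensor shapes are meaningful.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
original_dim, intermediate_dim, latent_dim = 784, 256, 2  # assumed sizes
batch_size = 100


def dense(inputs, units, activation=None):
    """Hypothetical stand-in for a Dense layer with fresh random weights."""
    W = rng.normal(scale=0.01, size=(inputs.shape[-1], units))
    b = np.zeros(units)
    out = inputs @ W + b
    if activation == "relu":
        out = np.maximum(out, 0.0)
    elif activation == "sigmoid":
        out = 1.0 / (1.0 + np.exp(-out))
    return out


x = rng.uniform(size=(batch_size, original_dim))  # stand-in for pixel data

# Encoder: x to h, then h to (z_mu, z_log_var)
h = dense(x, intermediate_dim, activation="relu")
z_mu = dense(h, latent_dim)
z_log_var = dense(h, latent_dim)

# Reparameterization: z = mu + sigma * eps
eps = rng.standard_normal((batch_size, latent_dim))
z = z_mu + np.exp(0.5 * z_log_var) * eps

# Decoder: z to a hidden layer, then to Bernoulli means in (0, 1)
x_pred = dense(dense(z, intermediate_dim, activation="relu"),
               original_dim, activation="sigmoid")

for name, t in [("h", h), ("z_mu", z_mu), ("z", z), ("x_pred", x_pred)]:
    print(name, t.shape)
```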
&lt;p&gt;Finally, we specify and compile the model, using the negative log likelihood
&lt;code&gt;nll&lt;/code&gt; defined earlier as the loss.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vae&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;x_pred&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vae&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;compile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;rmsprop&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;nll&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h1 id="model-fitting"&gt;Model fitting&lt;/h1&gt;
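&lt;p&gt;The quantity minimized during fitting is the negative ELBO (NELBO). It is the reconstruction loss &lt;code&gt;nll&lt;/code&gt; plus the KL term that &lt;code&gt;KLDivergenceLayer&lt;/code&gt; adds to the model loss. The NumPy sketch below computes it per example under two assumptions. First, &lt;code&gt;nll&lt;/code&gt; is taken to be the Bernoulli negative log likelihood summed over pixels, which is consistent with the sigmoid output of the decoder. Second, the KL term is taken to be the closed form for a diagonal Gaussian approximate posterior against a standard Normal prior. The function names are hypothetical.&lt;/p&gt;

```python
import numpy as np


def bernoulli_nll(x_true, x_pred, eps=1e-7):
    """Bernoulli negative log likelihood summed over pixels (per example)."""
    x_pred = np.clip(x_pred, eps, 1 - eps)  # avoid log(0)
    return -np.sum(x_true * np.log(x_pred)
                   + (1 - x_true) * np.log(1 - x_pred), axis=-1)


def gaussian_kl(z_mu, z_log_var):
    """KL[N(mu, diag(exp(log_var))) || N(0, I)] in closed form (per example)."""
    return -0.5 * np.sum(1 + z_log_var - z_mu ** 2 - np.exp(z_log_var), axis=-1)


def nelbo(x_true, x_pred, z_mu, z_log_var):
    """Per-example negative ELBO: reconstruction loss plus KL regularizer."""
    return bernoulli_nll(x_true, x_pred) + gaussian_kl(z_mu, z_log_var)


x_true = np.array([[0.0, 1.0, 1.0, 0.0]])
x_pred = np.full((1, 4), 0.5)                         # uninformative decoder output
z_mu, z_log_var = np.zeros((1, 2)), np.zeros((1, 2))  # posterior equals prior

print(nelbo(x_true, x_pred, z_mu, z_log_var))  # 4 * log(2), about 2.77
```

&lt;p&gt;Here the approximate posterior equals the prior, so the KL term vanishes. The NELBO then reduces to $4 \log 2$ nats for this four-pixel example.&lt;/p&gt;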
&lt;h2 id="dataset-mnist-digits"&gt;Dataset: MNIST digits&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mnist&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load_data&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_train&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mf"&gt;255.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mf"&gt;255.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Variational autoencoder architecture for the MNIST digits dataset."
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/vae_full_shapes.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;hist&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;vae&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;validation_data&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="loss-nelbo-convergence"&gt;Loss (NELBO) convergence&lt;/h2&gt;
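&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;DataFrame&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hist&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;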
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt=""
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/nelbo.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h1 id="model-evaluation"&gt;Model evaluation&lt;/h1&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# display a 2D plot of the digit classes in the latent space&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scatter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z_test&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;z_test&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cmap&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;viridis&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;colorbar&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt=""
srcset="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_latent_space_hu_8bb4eb676623e380.webp 320w, https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_latent_space_hu_64a98c5233df2a9d.webp 480w, https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_latent_space_hu_f97f8af14c434d9d.webp 600w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_latent_space_hu_8bb4eb676623e380.webp"
width="600"
height="500"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# display a 2D manifold of the digits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt; &lt;span class="c1"&gt;# figure with 15x15 digits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;digit_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;28&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# linearly spaced coordinates on the unit square are transformed&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# through the inverse CDF (ppf) of the Gaussian to produce values&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# of the latent variables z, since the prior of the latent space&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# is Gaussian&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ppf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.99&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ppf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.99&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;meshgrid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_pred_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z_grid&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; \
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;digit_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;digit_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;block&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_pred_grid&lt;/span&gt;&lt;span class="p"&gt;))),&lt;/span&gt; &lt;span class="n"&gt;cmap&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt=""
srcset="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_manifold_hu_e23f379b58eda1c7.webp 320w, https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_manifold_hu_e0d7ef0dff27fb2e.webp 480w, https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_manifold_hu_8a0316a94df89cca.webp 500w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_manifold_hu_e23f379b58eda1c7.webp"
width="500"
height="500"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
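&lt;p&gt;The grid construction can be checked without a trained model. The sketch below swaps &lt;code&gt;scipy.stats.norm.ppf&lt;/code&gt; for &lt;code&gt;statistics.NormalDist().inv_cdf&lt;/code&gt; from the Python standard library. Both compute the inverse CDF of the standard Normal. It also replaces the decoder with a hypothetical &lt;code&gt;fake_decoder&lt;/code&gt; that returns blank images. The sketch confirms the array shapes at each step and checks that &lt;code&gt;np.block&lt;/code&gt; tiles the $n \times n$ grid of digits into a single image.&lt;/p&gt;

```python
import numpy as np
from statistics import NormalDist

n, digit_size, latent_dim = 15, 28, 2

# Quantiles of the standard Normal prior at evenly spaced probabilities
inv_cdf = np.vectorize(NormalDist().inv_cdf)
z1 = inv_cdf(np.linspace(0.01, 0.99, n))
z2 = inv_cdf(np.linspace(0.01, 0.99, n))
z_grid = np.dstack(np.meshgrid(z1, z2))  # shape (n, n, 2)


def fake_decoder(z):
    """Hypothetical stand-in for decoder.predict: blank flattened images."""
    return np.zeros((z.shape[0], digit_size * digit_size))


x_pred_grid = fake_decoder(z_grid.reshape(n * n, latent_dim)) \
    .reshape(n, n, digit_size, digit_size)

# Nested list of n x n image tiles, assembled into one (n * 28, n * 28) image
canvas = np.block(list(map(list, x_pred_grid)))
print(z_grid.shape, x_pred_grid.shape, canvas.shape)
```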
&lt;h1 id="recap"&gt;Recap&lt;/h1&gt;
&lt;p&gt;In this post, we covered the basics of amortized variational inference, looking
at variational autoencoders as a specific example. In particular, we&lt;/p&gt;
&lt;ul&gt;
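&lt;li&gt;Implemented the decoder and encoder using the &lt;code&gt;Sequential&lt;/code&gt; and
functional &lt;code&gt;Model&lt;/code&gt; APIs, respectively.&lt;/li&gt;
&lt;li&gt;Augmented the final loss with the KL divergence term by writing an auxiliary
custom layer, &lt;code&gt;KLDivergenceLayer&lt;/code&gt;.&lt;/li&gt;
&lt;li&gt;Worked with the log variance for numerical stability, and used a
&lt;code&gt;Lambda&lt;/code&gt; layer to transform it to the
standard deviation when necessary.&lt;/li&gt;
&lt;li&gt;Explicitly made the noise an Input layer, and implemented the
reparameterization trick using the &lt;code&gt;Multiply&lt;/code&gt; and &lt;code&gt;Add&lt;/code&gt; merge layers.&lt;/li&gt;
&lt;li&gt;Bound the noise Input layer to a tensor created with &lt;code&gt;K.random_normal&lt;/code&gt;,
so random samples are generated &lt;em&gt;within&lt;/em&gt; the computation graph.&lt;/li&gt;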
&lt;/ul&gt;
&lt;h1 id="whats-next"&gt;What&amp;rsquo;s next&lt;/h1&gt;
&lt;p&gt;Next, we will extend the divergence layer to use an auxiliary density ratio
estimator function, instead of evaluating the KL divergence in the analytical
form above.
This relaxes the requirement that the approximate posterior
$q_{\phi}(\mathbf{z}|\mathbf{x})$ (and, incidentally, also the prior $p(\mathbf{z})$)
have tractable densities, at the cost of maximizing a cruder estimate of the
ELBO.
This is known as Adversarial Variational Bayes&lt;sup id="fnref2:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt;, and is an
important line of recent research that, when taken to its logical conclusion,
can extend the applicability of variational inference to arbitrarily expressive
implicit probabilistic models with intractable likelihoods&lt;sup id="fnref2:11"&gt;&lt;a href="#fn:11" class="footnote-ref" role="doc-noteref"&gt;11&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
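&lt;p&gt;The density-ratio idea can be illustrated in a setting simple enough to check by hand. The NumPy sketch below estimates $\mathcal{D}_{\mathrm{KL}}[q(x) || p(x)]$ between two univariate Gaussians using only samples from each. It assumes a logistic regression classifier on the features $(1, x, x^2)$, trained to distinguish samples of $q$ from samples of $p$. With balanced classes, the optimal logit equals $\log q(x) - \log p(x)$, so averaging the learned logit over samples from $q$ estimates the KL divergence. The Gaussians, the quadratic features and the Newton (IRLS) fitting are illustrative choices, not the method of the cited work. The helper names are hypothetical.&lt;/p&gt;

```python
import numpy as np


def fit_logistic_irls(X, y, n_iter=25):
    """Logistic regression fitted by Newton's method (IRLS)."""
    w = np.zeros(X.shape[1])
    for _ in range(n_iter):
        p = 0.5 * (1.0 + np.tanh(0.5 * (X @ w)))    # numerically stable sigmoid
        grad = X.T @ (y - p)                        # gradient of log likelihood
        hess = (X * (p * (1 - p))[:, None]).T @ X   # negative Hessian
        w = w + np.linalg.solve(hess, grad)
    return w


def features(x):
    # The log ratio of two Gaussians is quadratic in x, so these suffice
    return np.stack([np.ones_like(x), x, x ** 2], axis=1)


rng = np.random.default_rng(0)
n = 20000
m_q, s_q = 1.0, 0.5   # q = N(1, 0.5^2)
m_p, s_p = 0.0, 1.0   # p = N(0, 1)
x_q = rng.normal(m_q, s_q, size=n)
x_p = rng.normal(m_p, s_p, size=n)

# Label samples from q as 1 and samples from p as 0 (balanced classes)
X = features(np.concatenate([x_q, x_p]))
y = np.concatenate([np.ones(n), np.zeros(n)])
w = fit_logistic_irls(X, y)

# Learned logit approximates log q(x) - log p(x); average it over samples of q
kl_estimate = np.mean(features(x_q) @ w)
kl_exact = np.log(s_p / s_q) + (s_q ** 2 + (m_q - m_p) ** 2) / (2 * s_p ** 2) - 0.5
print(kl_estimate, kl_exact)
```

&lt;p&gt;The log ratio of two Gaussians is exactly quadratic, so this classifier is well specified. For the implicit distributions that motivate the approach, a more flexible classifier such as a neural network would take its place.&lt;/p&gt;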
&lt;hr&gt;
&lt;p&gt;Cite as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-gdscript3" data-lang="gdscript3"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="err"&gt;@&lt;/span&gt;&lt;span class="n"&gt;article&lt;/span&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tiao2017vae&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;title&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;{A} {T}utorial on {V}ariational {A}utoencoders with a {C}oncise {K}eras {I}mplementation&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;author&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;Tiao, Louis C&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;journal&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;tiao.io&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;year&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;2017&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;url&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;https://tiao.io/post/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h2 id="links--resources"&gt;Links &amp;amp; Resources&lt;/h2&gt;
&lt;p&gt;Below, you can find:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;The code used to generate the diagrams and plots in this post.&lt;/li&gt;
&lt;li&gt;The above snippets combined in a single executable Python file:&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;plt&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;scipy.stats&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;keras&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;backend&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;keras.layers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Lambda&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Layer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Add&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Multiply&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;keras.models&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Sequential&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;keras.datasets&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;mnist&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;784&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;256&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;epochs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;50&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;epsilon_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;nll&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34; Negative log likelihood (Bernoulli). &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# keras.losses.binary_crossentropy gives the mean&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# over the last axis. we require the sum&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;binary_crossentropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Layer&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34; Identity transform layer that adds KL divergence
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; to the final model loss.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_placeholder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;call&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;kl_batch&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;log_var&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;square&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_var&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kl_batch&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;decoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;sigmoid&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_log_var&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Lambda&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;.5&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;))(&lt;/span&gt;&lt;span class="n"&gt;z_log_var&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random_normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stddev&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;epsilon_std&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Multiply&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Add&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_eps&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_pred&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vae&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;x_pred&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vae&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;compile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;rmsprop&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;nll&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# train the VAE on MNIST digits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mnist&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load_data&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_train&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mf"&gt;255.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mf"&gt;255.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vae&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;validation_data&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# display a 2D plot of the digit classes in the latent space&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scatter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z_test&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;z_test&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cmap&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;viridis&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;colorbar&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# display a 2D manifold of the digits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt; &lt;span class="c1"&gt;# figure with 15x15 digits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;digit_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;28&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# linearly spaced coordinates on the unit square were transformed&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# through the inverse CDF (ppf) of the Gaussian to produce values&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# of the latent variables z, since the prior of the latent space&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# is Gaussian&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;u_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;meshgrid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.05&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.05&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ppf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;u_grid&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_decoded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z_grid&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_decoded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x_decoded&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;digit_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;digit_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;block&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_decoded&lt;/span&gt;&lt;span class="p"&gt;))),&lt;/span&gt; &lt;span class="n"&gt;cmap&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;D. P. Kingma and M. Welling, &amp;ldquo;Auto-Encoding Variational Bayes,&amp;rdquo; in Proceedings of the 2nd International Conference on Learning Representations (ICLR), 2014.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref2:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;
&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;Section &amp;ldquo;Recognition models and amortised inference&amp;rdquo; in
&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;P. Dayan, G. E. Hinton, R. M. Neal, and R. S. Zemel, &amp;ldquo;The Helmholtz Machine,&amp;rdquo; Neural Computation, vol. 7, no. 5, pp. 889–904, 1995.
&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;D. J. Rezende, S. Mohamed, and D. Wierstra, &amp;ldquo;Stochastic Backpropagation and Approximate Inference in Deep Generative Models,&amp;rdquo; in Proceedings of the 31st International Conference on Machine Learning, 2014, vol. 32, pp. 1278–1286. Beijing, China: PMLR.
&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref2:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref3:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;For a complete treatment of variational autoencoders, and variational
inference in general, I highly recommend:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Jaan Altosaar&amp;rsquo;s blog post,
&lt;/li&gt;
&lt;li&gt;Diederik P. Kingma&amp;rsquo;s PhD Thesis,
.&lt;/li&gt;
&lt;/ul&gt;
&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/li&gt;
&lt;li id="fn:7"&gt;
&lt;p&gt;D. Rezende and S. Mohamed, &amp;ldquo;Variational Inference with Normalizing Flows,&amp;rdquo; in Proceedings of the 32nd International Conference on Machine Learning, 2015, vol. 37, pp. 1530–1538.&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:8"&gt;
&lt;p&gt;Y. Burda, R. Grosse, and R. Salakhutdinov, &amp;ldquo;Importance Weighted Autoencoders,&amp;rdquo; in Proceedings of the 4th International Conference on Learning Representations (ICLR), 2016.&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:9"&gt;
&lt;p&gt;E. Jang, S. Gu, and B. Poole, &amp;ldquo;Categorical Reparameterization with Gumbel-Softmax,&amp;rdquo; in Proceedings of the 5th International Conference on Learning Representations (ICLR), 2017.&amp;#160;&lt;a href="#fnref:9" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:9" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:10"&gt;
&lt;p&gt;L. Mescheder, S. Nowozin, and A. Geiger, &amp;ldquo;Adversarial Variational Bayes: Unifying Variational Autoencoders and Generative Adversarial Networks,&amp;rdquo; in Proceedings of the 34th International Conference on Machine Learning, 2017, vol. 70, pp. 2391–2400.&amp;#160;&lt;a href="#fnref:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref2:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:11"&gt;
&lt;p&gt;D. Tran, R. Ranganath, and D. Blei, &amp;ldquo;Hierarchical Implicit Models and Likelihood-Free Variational Inference,&amp;rdquo; in Advances in Neural Information Processing Systems 30, 2017.&amp;#160;&lt;a href="#fnref:11" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:11" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref2:11" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:12"&gt;
&lt;p&gt;To support sample weighting (fine-tuning how much each data-point
contributes to the loss), Keras losses are expected to return a scalar for each
data-point in the batch. In contrast, losses added with the &lt;code&gt;add_loss&lt;/code&gt;
method don&amp;rsquo;t support this and are expected to be a single scalar.
Hence, we calculate the KL divergence for each data-point in the batch and
take the mean before passing it to &lt;code&gt;add_loss&lt;/code&gt;.&amp;#160;&lt;a href="#fnref:12" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:13"&gt;
&lt;p&gt;Y. Li and R. E. Turner, &amp;ldquo;Rényi Divergence Variational Inference,&amp;rdquo; in Advances in Neural Information Processing Systems 29, 2016.&amp;#160;&lt;a href="#fnref:13" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:14"&gt;
&lt;p&gt;A. B. Dieng, D. Tran, R. Ranganath, J. Paisley, and D. Blei, &amp;ldquo;Variational Inference via χ Upper Bound Minimization,&amp;rdquo; in Advances in Neural Information Processing Systems 30, 2017.&amp;#160;&lt;a href="#fnref:14" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
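&lt;p&gt;To make the reduction described in footnote 12 concrete, here is a minimal NumPy sketch. It assumes a diagonal Gaussian approximate posterior and a standard normal prior, so the KL divergence has a closed form. The function name &lt;code&gt;gaussian_kl_per_example&lt;/code&gt; and the batch values are hypothetical. The per-example values correspond to what a standard Keras loss would return, while their mean is the single scalar one would pass to &lt;code&gt;add_loss&lt;/code&gt;.&lt;/p&gt;

```python
import numpy as np


def gaussian_kl_per_example(mu, log_var):
    """Closed-form KL[N(mu, diag(exp(log_var))) || N(0, I)] for each row.

    Assumes a diagonal Gaussian encoder and a standard normal prior.
    Returns an array of shape (batch_size,), one value per data-point.
    """
    return -0.5 * np.sum(1.0 + log_var - np.square(mu) - np.exp(log_var), axis=-1)


# Hypothetical encoder outputs: a batch of 3 data-points, 2-D latent space.
mu = np.array([[0.0, 0.0],
               [1.0, -1.0],
               [0.5, 0.5]])
log_var = np.zeros_like(mu)  # log-variance 0, i.e. unit variances

# One KL value per data-point, as a per-sample Keras loss would return.
kl_batch = gaussian_kl_per_example(mu, log_var)

# A single scalar: the form expected by add_loss.
kl_scalar = np.mean(kl_batch)
print(kl_batch, kl_scalar)
```

&lt;p&gt;Taking the mean rather than the sum keeps the KL term on the same per-data-point scale as a Keras loss whose per-example values are averaged over the batch.&lt;/p&gt;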
&lt;/div&gt;</description></item></channel></rss>