<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Implicit Distributions |</title><link>https://tiao.io/tags/implicit-distributions/</link><atom:link href="https://tiao.io/tags/implicit-distributions/index.xml" rel="self" type="application/rss+xml"/><description>Implicit Distributions</description><generator>HugoBlox Kit (https://hugoblox.com)</generator><language>en-us</language><lastBuildDate>Mon, 27 Aug 2018 00:00:00 +0000</lastBuildDate><image><url>https://tiao.io/media/icon_hu_9c2a75fde2335590.png</url><title>Implicit Distributions</title><link>https://tiao.io/tags/implicit-distributions/</link></image><item><title>Density Ratio Estimation for KL Divergence Minimization between Implicit Distributions</title><link>https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/</link><pubDate>Mon, 27 Aug 2018 00:00:00 +0000</pubDate><guid>https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/</guid><description>&lt;!-- TODO: Clarify that optimal classifier refers to the classifier that minimizes the Bayes risk --&gt;
&lt;p&gt;The Kullback-Leibler (KL) divergence between distributions $p$ and $q$ is
defined as&lt;/p&gt;
$$
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)] :=
\mathbb{E}_{p(x)} \left [ \log \left ( \frac{p(x)}{q(x)} \right ) \right ].
$$&lt;p&gt;It can be expressed more succinctly as&lt;/p&gt;
$$
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)] = \mathbb{E}_{p(x)} [ \log r^{*}(x) ],
$$&lt;p&gt;where $r^{*}(x)$ is defined to be the ratio between the densities $p(x)$ and
$q(x)$,&lt;/p&gt;
$$
r^{*}(x) := \frac{p(x)}{q(x)}.
$$&lt;p&gt;This density ratio is crucial for computing not only the KL divergence but for
all $f$-divergences, defined as&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;
$$
\mathcal{D}_f[p(x) || q(x)] :=
\mathbb{E}_{q(x)} \left [ f \left ( \frac{p(x)}{q(x)} \right ) \right ].
$$&lt;p&gt;Rarely can this expectation (i.e. integral) be calculated analytically&amp;mdash;in
most cases, we must resort to Monte Carlo approximation methods, which
explicitly require the density ratio.
In the more severe case where this density ratio is itself unavailable, because
one or both of $p(x)$ and $q(x)$ cannot be evaluated, we must resort to methods for
&lt;em&gt;density ratio estimation&lt;/em&gt;.
In this post, we illustrate how to perform density ratio estimation by
exploiting its tight correspondence to &lt;em&gt;probabilistic classification&lt;/em&gt;.&lt;/p&gt;
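&lt;p&gt;As a concrete instance of the $f$-divergence family above, the choice $f(t) = t \log t$ recovers the forward KL divergence, since&lt;/p&gt;
$$
\mathbb{E}_{q(x)} \left [ r^{*}(x) \log r^{*}(x) \right ]
= \mathbb{E}_{p(x)} [ \log r^{*}(x) ]
= \mathcal{D}_{\mathrm{KL}}[p(x) || q(x)].
$$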
&lt;h3 id="example-univariate-gaussians"&gt;Example: Univariate Gaussians&lt;/h3&gt;
&lt;p&gt;Let us consider the following univariate Gaussian distributions as the running
example for this post,&lt;/p&gt;
$$
p(x) = \mathcal{N}(x \mid 1, 1^2),
\qquad
\text{and}
\qquad
q(x) = \mathcal{N}(x \mid 0, 2^2).
$$&lt;p&gt;We will be using &lt;em&gt;TensorFlow&lt;/em&gt;, &lt;em&gt;TensorFlow Probability&lt;/em&gt;, and &lt;em&gt;Keras&lt;/em&gt; in the
code snippets throughout this post.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow_probability&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tfp&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# Keras imports used later in this post (assuming the tf.keras API)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;tensorflow.keras.layers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;tensorflow.keras.models&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Sequential&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We first instantiate the distributions:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;distributions&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;distributions&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;2.&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Their densities are shown below:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Univariate Gaussian densities"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/gaussian_1d_densities.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;For any pair of distributions, we can implement their density ratio function $r$
as follows:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;log_density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;ratio&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Let&amp;rsquo;s create the density ratio function for the Gaussian distributions we just
instantiated:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This density ratio function is plotted as the orange dotted line below,
alongside the individual densities shown in the previous plot:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Ratio of Gaussian densities"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/gaussian_1d_density_ratios.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h2 id="analytical-form"&gt;Analytical Form&lt;/h2&gt;
&lt;p&gt;For our running example, we picked $p(x)$ and $q(x)$ to be Gaussians so that
it is possible to integrate out $x$ and compute the KL divergence &lt;em&gt;analytically&lt;/em&gt;.
When we introduce the approximate methods later, this will provide us with a &amp;ldquo;gold
standard&amp;rdquo; to benchmark against.&lt;/p&gt;
&lt;p&gt;In general, for Gaussian distributions&lt;/p&gt;
$$
p(x) = \mathcal{N}(x \mid \mu_p, \sigma_p^2),
\qquad
\text{and}
\qquad
q(x) = \mathcal{N}(x \mid \mu_q, \sigma_q^2),
$$&lt;p&gt;
it is easy to verify that
&lt;/p&gt;
$$
\mathcal{D}_{\mathrm{KL}}[ p(x) || q(x) ]
= \log \sigma_q - \log \sigma_p - \frac{1}{2}
\left [
1 - \left ( \frac{\sigma_p^2 + (\mu_p - \mu_q)^2}{\sigma_q^2} \right )
\right ].
$$&lt;p&gt;This is implemented below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_kl_divergence_gaussians&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;r&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We can use this to compute the KL divergence between $p(x)$ and $q(x)$
&lt;em&gt;exactly&lt;/em&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;_kl_divergence_gaussians&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.44314718&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Equivalently, we could also use &lt;code&gt;kl_divergence&lt;/code&gt; from &lt;em&gt;TensorFlow
Probability&amp;ndash;Distributions&lt;/em&gt; (&lt;code&gt;tfp.distributions&lt;/code&gt;), which implements the
analytical closed-form expression of the KL divergence between distributions
when one exists.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;distributions&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;kl_divergence&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.44314718&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="monte-carlo-estimation--prescribed-distributions"&gt;Monte Carlo Estimation &amp;mdash; prescribed distributions&lt;/h2&gt;
&lt;p&gt;For distributions where their KL divergence is not analytically tractable, we
may appeal to Monte Carlo (MC) estimation:&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)]
&amp; = \mathbb{E}_{p(x)} [ \log r^{*}(x) ] \newline
&amp; \approx \frac{1}{M} \sum_{i=1}^{M} \log r^{*}(x_p^{(i)}),
\quad x_p^{(i)} \sim p(x).
\end{align*}
$$&lt;p&gt;Clearly, this requires the density ratio $r^{*}(x)$ and, in turn, the densities
$p(x)$ and $q(x)$ to be analytically tractable. Distributions for which the
density function can be readily evaluated are sometimes referred to as
&lt;strong&gt;prescribed distributions&lt;/strong&gt;. As before, we &lt;em&gt;prescribed&lt;/em&gt; Gaussian distributions
in our running example so that the Monte Carlo estimate can later be compared against the analytical value.
We approximate their KL divergence using $M = 5000$ Monte Carlo samples as
follows:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;p_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;true_log_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reduce_mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;true_log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.44670376&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Or equivalently, using the &lt;code&gt;expectation&lt;/code&gt; function from &lt;em&gt;TensorFlow
Probability&amp;ndash;Monte Carlo&lt;/em&gt; (&lt;code&gt;tfp.monte_carlo&lt;/code&gt;):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;monte_carlo&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;expectation&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;true_log_ratio&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;samples&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.4581419&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;More generally, we can approximate any $f$-divergence with MC estimation:&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_f[p(x) || q(x)]
&amp; = \mathbb{E}_{q(x)} [ f(r^{*}(x)) ] \newline
&amp; \approx \frac{1}{M} \sum_{i=1}^{M} f(r^{*}(x_q^{(i)})),
\quad x_q^{(i)} \sim q(x).
\end{align*}
$$&lt;p&gt;This can be done using the &lt;code&gt;monte_carlo_csiszar_f_divergence&lt;/code&gt; function from
&lt;em&gt;TensorFlow Probability&amp;ndash;Variational Inference&lt;/em&gt; (&lt;code&gt;tfp.vi&lt;/code&gt;).
One simply needs to specify the appropriate convex function $f$.
The convex function that instantiates the (forward) KL divergence is provided
in &lt;code&gt;tfp.vi&lt;/code&gt; as &lt;code&gt;kl_forward&lt;/code&gt;, alongside many other common $f$-divergences.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;monte_carlo_csiszar_f_divergence&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;kl_forward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="n"&gt;p_log_prob&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="n"&gt;num_draws&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.4430853&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="density-ratio-estimation--implicit-distributions"&gt;Density Ratio Estimation &amp;mdash; implicit distributions&lt;/h2&gt;
&lt;p&gt;When either density $p(x)$ or $q(x)$ is unavailable, things become trickier,
which brings us to the topic of this post. Suppose we only have samples from
$p(x)$ and $q(x)$&amp;mdash;these could be natural images, outputs from a neural
network with stochastic inputs, or in the case of our running example, i.i.d.
samples drawn from Gaussians, etc.
Distributions for which we are only able to observe their samples are known as
&lt;strong&gt;implicit distributions&lt;/strong&gt;, since their samples &lt;em&gt;imply&lt;/em&gt; some underlying true
density which we may not have direct access to.&lt;/p&gt;
&lt;p&gt;Density ratio estimation is concerned with estimating the ratio of densities
$r^{*}(x) = p(x) / q(x)$ given access only to samples from $p(x)$ and $q(x)$.
Moreover, density ratio estimation usually encompasses methods that achieve this
without resorting to direct &lt;em&gt;density estimation&lt;/em&gt; of the individual densities
$p(x)$ or $q(x)$, since any error in the estimation of the denominator $q(x)$
is magnified exponentially.&lt;/p&gt;
&lt;p&gt;Of the many density ratio estimation methods that now
flourish&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;, the classical approach of &lt;em&gt;probabilistic
classification&lt;/em&gt; remains dominant, due in no small part to its simplicity.&lt;/p&gt;
&lt;h3 id="reducing-density-ratio-estimation-to-probabilistic-classification"&gt;Reducing Density Ratio Estimation to Probabilistic Classification&lt;/h3&gt;
&lt;p&gt;We now demonstrate that density ratio estimation can be reduced to probabilistic
classification. We shall do this by highlighting the one-to-one correspondence
between the density ratio of $p(x)$ and $q(x)$ and the optimal probabilistic
classifier that discriminates between their samples (&amp;ldquo;optimal&amp;rdquo; here refers to the classifier that minimizes the Bayes risk).
Specifically, suppose we have a collection of samples from both $p(x)$ and $q(x)$,
where each sample is assigned a class label indicating which distribution it was
drawn from. Then, from an estimator of the class-membership probabilities, it is
straightforward to recover an estimator of the density ratio.&lt;/p&gt;
&lt;p&gt;Suppose we have $N_p$ and $N_q$ samples drawn from $p(x)$ and $q(x)$,
respectively,&lt;/p&gt;
$$
x_p^{(1)}, \dotsc, x_p^{(N_p)} \sim p(x),
\qquad \text{and} \qquad
x_q^{(1)}, \dotsc, x_q^{(N_q)} \sim q(x).
$$&lt;p&gt;Then, we form the dataset $\{ (x_n, y_n) \}_{n=1}^N$, where $N = N_p + N_q$
and&lt;/p&gt;
$$
\begin{align*}
(x_1, \dotsc, x_N) &amp; = (x_p^{(1)}, \dotsc, x_p^{(N_p)},
x_q^{(1)}, \dotsc, x_q^{(N_q)}), \newline
(y_1, \dotsc, y_N) &amp; = (\underbrace{1, \dotsc, 1}_{N_p},
\underbrace{0, \dotsc, 0}_{N_q}).
\end{align*}
$$&lt;p&gt;In other words, we label samples drawn from $p(x)$ as 1 and those drawn from
$q(x)$ as 0. In code, this looks like:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;p_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sample_shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;q_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sample_shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q_samples&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q_samples&lt;/span&gt;&lt;span class="p"&gt;)],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This dataset is visualized below. The blue squares in the top row are samples
$x_p^{(i)} \sim p(x)$ with label 1; red squares in the bottom row are samples
$x_q^{(j)} \sim q(x)$ with label 0.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Classification dataset"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/dataset.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Now, by construction, we have&lt;/p&gt;
$$
p(x) = \mathcal{P}(x \mid y = 1),
\qquad
\text{and}
\qquad
q(x) = \mathcal{P}(x \mid y = 0).
$$&lt;p&gt;Using Bayes&amp;rsquo; rule, we can write&lt;/p&gt;
$$
\mathcal{P}(x \mid y) =
\frac{\mathcal{P}(y \mid x) \mathcal{P}(x)}
{\mathcal{P}(y)}.
$$&lt;p&gt;Hence, we can express the density ratio $r^{*}(x)$ as&lt;/p&gt;
$$
\begin{align*}
r^{*}(x) &amp; = \frac{p(x)}{q(x)}
= \frac{\mathcal{P}(x \mid y = 1)}
{\mathcal{P}(x \mid y = 0)} \newline
&amp; = \left ( \frac{\mathcal{P}(y = 1 \mid x) \mathcal{P}(x)}
{\mathcal{P}(y = 1)} \right )
\left ( \frac{\mathcal{P}(y = 0 \mid x) \mathcal{P}(x)}
{\mathcal{P}(y = 0)} \right ) ^ {-1} \newline
&amp; = \frac{\mathcal{P}(y = 0)}{\mathcal{P}(y = 1)}
\frac{\mathcal{P}(y = 1 \mid x)}
{\mathcal{P}(y = 0 \mid x)}.
\end{align*}
$$&lt;p&gt;Let us approximate the ratio of marginal class probabilities by the ratio of sample sizes,&lt;/p&gt;
$$
\frac{\mathcal{P}(y = 0)}
{\mathcal{P}(y = 1)}
\approx
\frac{N_q}{N_p + N_q}
\left ( \frac{N_p}{N_p + N_q} \right )^{-1}
= \frac{N_q}{N_p}.
$$&lt;p&gt;To avoid notational clutter, let us assume from now on that $N_q = N_p$.
We can then write $r^{*}(x)$ in terms of class-posterior probabilities,&lt;/p&gt;
$$
\begin{align*}
r^{*}(x) = \frac{\mathcal{P}(y = 1 \mid x)}
{\mathcal{P}(y = 0 \mid x)}.
\end{align*}
$$&lt;h4 id="recovering-the-density-ratio-from-the-class-probability"&gt;Recovering the Density Ratio from the Class Probability&lt;/h4&gt;
&lt;p&gt;This yields a one-to-one correspondence between the density ratio $r^{*}(x)$
and the class-posterior probability $\mathcal{P}(y = 1 \mid x)$.
Namely,&lt;/p&gt;
$$
\begin{align*}
r^{*}(x) = \frac{\mathcal{P}(y = 1 \mid x)}
{\mathcal{P}(y = 0 \mid x)}
&amp; = \frac{\mathcal{P}(y = 1 \mid x)}
{1 - \mathcal{P}(y = 1 \mid x)} \newline
&amp; = \exp
\left [
\log \frac{\mathcal{P}(y = 1 \mid x)}
{1 - \mathcal{P}(y = 1 \mid x)} \right ] \newline
&amp; = \exp[ \sigma^{-1}(\mathcal{P}(y = 1 \mid x)) ],
\end{align*}
$$&lt;p&gt;where $\sigma^{-1}$ is the &lt;em&gt;logit&lt;/em&gt; function, or inverse sigmoid function, given
by $\sigma^{-1}(\rho) = \log \left ( \frac{\rho}{1-\rho} \right )$.&lt;/p&gt;
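&lt;p&gt;To make this mapping concrete, here is a minimal sketch (the helper name &lt;code&gt;density_ratio_from_classifier&lt;/code&gt; is ours, purely for illustration) of how a density ratio estimator can be recovered from any function that outputs class-membership probabilities:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;def density_ratio_from_classifier(classifier):
    # `classifier(x)` is assumed to return P(y=1|x); its logit is
    # log r(x), so exponentiating the logit recovers the density ratio.
    def ratio(x):
        prob = classifier(x)
        logit = tf.log(prob) - tf.log1p(-prob)
        return tf.exp(logit)

    return ratio
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Applied to the exact class-posterior probability, this recovers exactly $r^{*}(x) = p(x) / q(x)$.&lt;/p&gt;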
&lt;h4 id="recovering-the-class-probability-from-the-density-ratio"&gt;Recovering the Class Probability from the Density Ratio&lt;/h4&gt;
&lt;p&gt;By rearranging this relationship, we can also recover
the exact class-posterior probability as a function of the density ratio,&lt;/p&gt;
$$
\mathcal{P}(y=1 \mid x) = \sigma(\log r^{*}(x)) = \frac{p(x)}{p(x) + q(x)}.
$$
&lt;p&gt;This is implemented below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;optimal_classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;truediv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;classifier&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;In the figure below, the class-posterior probability $\mathcal{P}(y=1 \mid x)$
is plotted against the dataset visualized earlier.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Optimal classifier&amp;mdash;class-posterior probabilities"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/optimal_classifier.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h3 id="probabilistic-classification-with-logistic-regression"&gt;Probabilistic Classification with Logistic Regression&lt;/h3&gt;
&lt;p&gt;The class-posterior probability $\mathcal{P}(y = 1 \mid x)$ can be approximated
using a parameterized function $D_{\theta}(x)$ with parameters $\theta$. This
function takes as input samples from $p(x)$ and $q(x)$ and outputs a &lt;em&gt;score&lt;/em&gt;,
or probability, in the range $[0, 1]$, that the input was drawn from $p(x)$.
Hence, we refer to $D_{\theta}(x)$ as the probabilistic classifier.&lt;/p&gt;
&lt;p&gt;From before, it is easy to see how an estimator of the density ratio
$r_{\theta}(x)$ might be constructed as a function of the probabilistic classifier
$D_{\theta}(x)$. Namely,&lt;/p&gt;
$$
\begin{align*}
r_{\theta}(x) &amp; = \exp[ \sigma^{-1}(D_{\theta}(x)) ] \newline
&amp; \approx \exp[ \sigma^{-1}(\mathcal{P}(y = 1 \mid x)) ] = r^{*}(x),
\end{align*}
$$&lt;p&gt;
and &lt;em&gt;vice versa&lt;/em&gt;,
&lt;/p&gt;
$$
\begin{align*}
D_{\theta}(x) &amp; = \sigma(\log r_{\theta}(x)) \newline
&amp; \approx \sigma(\log r^{*}(x)) = \mathcal{P}(y = 1 \mid x).
\end{align*}
$$&lt;p&gt;Instead of $D_{\theta}(x)$, we usually specify the parameterized function
$\log r_{\theta}(x)$. This is also referred to as the &lt;em&gt;log-odds&lt;/em&gt;, or &lt;em&gt;logits&lt;/em&gt;,
since it is equivalent to the unnormalized output of the classifier before being
fed through the logistic sigmoid function.&lt;/p&gt;
&lt;p&gt;We define a small fully-connected neural network with two hidden layers and ReLU
activations:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;log_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This simple architecture is visualized in the diagram below:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Log Density Ratio Architecture"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/log_ratio_architecture.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
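&lt;p&gt;Under this parameterization, the probabilistic classifier itself is simply the logistic sigmoid of the network&amp;rsquo;s output. A minimal sketch (the helper name is ours, purely for illustration):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;def classifier_from_log_ratio(log_ratio):
    # D_theta(x) = sigma(log r_theta(x))
    def classifier(x):
        return tf.sigmoid(log_ratio(x))

    return classifier
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;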
&lt;p&gt;We learn the optimal class probability estimator by optimizing it with respect
to a &lt;em&gt;proper scoring rule&lt;/em&gt;&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt; that yields well-calibrated probabilistic predictions, such as the &lt;em&gt;binary cross-entropy loss&lt;/em&gt;,&lt;/p&gt;
$$
\begin{align*}
\mathcal{L}(\theta) &amp; :=
-\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ]
-\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ] \newline
&amp; =
-\mathbb{E}_{p(x)} [ \log \sigma ( \log r_{\theta} (x) ) ]
-\mathbb{E}_{q(x)} [ \log(1 - \sigma ( \log r_{\theta} (x) )) ].
\end{align*}
$$&lt;p&gt;An implementation optimized for numerical stability is given below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_binary_crossentropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;loss_p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sigmoid_cross_entropy_with_logits&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;loss_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sigmoid_cross_entropy_with_logits&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reduce_mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss_p&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;loss_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Now we can build a Keras &lt;code&gt;Model&lt;/code&gt;, where the two inputs are the tensors
&lt;code&gt;p_samples&lt;/code&gt; and &lt;code&gt;q_samples&lt;/code&gt;&amp;mdash;samples from
$p(x)$ and $q(x)$, respectively.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;x_p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;x_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;q_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;log_ratio_p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_p&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;log_ratio_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The model can now be compiled and finalized. Since we&amp;rsquo;re using a custom loss
that takes the two sets of log-ratios as input, we specify &lt;code&gt;loss=None&lt;/code&gt; and
define it instead through the &lt;code&gt;add_loss&lt;/code&gt; method.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;x_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_q&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_binary_crossentropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;compile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;rmsprop&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;As a sanity-check, the loss evaluated on a random batch can be obtained like so:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;evaluate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;1.3765026330947876&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We can now fit our estimator, recording the loss at the end of each epoch:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;hist&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps_per_epoch&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;500&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The following animation shows how the predictions of the probabilistic
classifier, the density ratio, and the log density ratio evolve after every epoch:&lt;/p&gt;
&lt;p&gt;&lt;video controls autoplay src="https://giant.gfycat.com/FrighteningThunderousFlicker.webm"&gt;&lt;/video&gt;&lt;/p&gt;
&lt;p&gt;It is overlaid on top of their exact, analytical counterparts, which are only
available since we prescribed them to be Gaussian distributions.
For implicit distributions, these won&amp;rsquo;t be accessible at all.&lt;/p&gt;
&lt;p&gt;Below is the final plot of how the binary cross-entropy loss converges:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Binary Cross-entropy Loss"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/binary_crossentropy.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Below is a plot of the probabilistic classifier $D_{\theta}(x)$ (&lt;em&gt;dotted green&lt;/em&gt;),
plotted against the optimal classifier, which is the class-posterior probability
$\mathcal{P}(y=1 \mid x) = \frac{p(x)}{p(x) + q(x)}$ (&lt;em&gt;solid blue&lt;/em&gt;):&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Class Probability Estimator"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/class_probability_estimation.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Below is a plot of the density ratio estimator $r_{\theta}(x)$
(&lt;em&gt;dotted green&lt;/em&gt;), plotted against the exact density ratio function
$r^{*}(x) = \frac{p(x)}{q(x)}$ (&lt;em&gt;solid blue&lt;/em&gt;):&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Density Ratio Estimator"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/density_ratio_estimation.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;And finally, the previous plot in logarithmic scale:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Log Density Ratio Estimator"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/log_density_ratio_estimation.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;While it may appear that we are simply performing regression on the latent
function $r^{*}(x)$ (which is not wrong&amp;mdash;we are), it is important to emphasize that
we do this without ever having observed values of $r^{*}(x)$.
Instead, we only ever observe samples from $p(x)$ and $q(x)$.
This has profound implications and potential for a great number of applications
that we shall explore later on.&lt;/p&gt;
&lt;h3 id="back-to-monte-carlo-estimation"&gt;Back to Monte Carlo estimation&lt;/h3&gt;
&lt;p&gt;Having obtained an estimate of the log density ratio, it is now feasible to
perform Monte Carlo estimation:&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)]
&amp; = \mathbb{E}_{p(x)} [ \log r^{*}(x) ] \newline
&amp; \approx \frac{1}{M} \sum_{i=1}^{M} \log r^{*}(x_p^{(i)}),
\quad x_p^{(i)} \sim p(x) \newline
&amp; \approx \frac{1}{M} \sum_{i=1}^{M} \log r_{\theta}(x_p^{(i)}),
\quad x_p^{(i)} \sim p(x).
\end{align*}
$$&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;monte_carlo&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;expectation&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;samples&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.4570999&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;In other words, we draw MC samples from $p(x)$ as before. But instead of taking
the mean of the function $\log r^{*}(x)$ evaluated on these samples (which is
unavailable for implicit distributions), we do so on a proxy function
$\log r_{\theta}(x)$ that is estimated through probabilistic classification as
described above.&lt;/p&gt;
&lt;h2 id="learning-in-implicit-generative-models"&gt;Learning in Implicit Generative Models&lt;/h2&gt;
&lt;p&gt;Now let&amp;rsquo;s take a look at where these ideas are being used in practice.
Consider a collection of natural images, such as the MNIST handwritten
digits shown below, which are assumed to be samples drawn from some implicit
distribution $q(\mathbf{x})$:&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/MnistExamples.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;MNIST hand-written digits&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;Directly estimating the density of $q(\mathbf{x})$ may not always be feasible&amp;mdash;in
some cases, it may not even exist.
Instead, consider defining a parametric function $G_{\phi}: \mathbf{z} \mapsto
\mathbf{x}$ with parameters $\phi$, that takes as input $\mathbf{z}$ drawn from
some fixed distribution $p(\mathbf{z})$.
The outputs $\mathbf{x}$ of this generative process are assumed to be samples
following some implicit distribution $p_{\phi}(\mathbf{x})$. In other words,
we can write&lt;/p&gt;
$$
\mathbf{x} \sim p_{\phi}(\mathbf{x}) \quad
\Leftrightarrow \quad
\mathbf{x} = G_{\phi}(\mathbf{z}),
\quad \mathbf{z} \sim p(\mathbf{z}).
$$&lt;p&gt;By optimizing parameters $\phi$, we can make $p_{\phi}(\mathbf{x})$ close to
the real data distribution $q(\mathbf{x})$. This is a compelling alternative to
density estimation since there are many situations where being able to generate
samples is more important than being able to calculate the numerical value of
the density. Some examples of these include &lt;em&gt;image super-resolution&lt;/em&gt; and
&lt;em&gt;semantic segmentation&lt;/em&gt;.&lt;/p&gt;
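&lt;p&gt;As a rough illustration only (the latent dimensionality and layer sizes below are arbitrary choices for the sake of the sketch, not taken from any particular model), such a generator could itself be a small Keras network applied to Gaussian noise:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;latent_dim = 64

# fixed noise distribution p(z)
p_z = tfp.distributions.Normal(loc=tf.zeros(latent_dim), scale=tf.ones(latent_dim))

# parametric generator G_phi: z -&gt; x
generator = Sequential([
    Dense(128, input_dim=latent_dim, activation='relu'),
    Dense(784, activation='sigmoid'),  # e.g. flattened 28x28 pixel intensities
])

# samples x ~ p_phi(x) are obtained by pushing z ~ p(z) through G_phi
x_synthetic = generator(p_z.sample(100))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;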
&lt;p&gt;One approach is to introduce a classifier $D_{\theta}$ that discriminates
between real and synthetic samples.
We then optimize $G_{\phi}$ to synthesize samples that are indistinguishable
from the real samples, as far as the classifier $D_{\theta}$ is concerned.
This can be achieved by having $D_{\theta}$ and $G_{\phi}$ simultaneously
optimize the binary cross-entropy loss in opposite directions, resulting in the
saddle-point objective,&lt;/p&gt;
$$
\begin{align*}
&amp; \min_{\phi} \max_{\theta}
\mathbb{E}_{q(\mathbf{x})} [ \log D_{\theta} (\mathbf{x}) ] +
\mathbb{E}_{p_{\phi}(\mathbf{x})} [ \log(1-D_{\theta} (\mathbf{x})) ] \newline =
&amp; \min_{\phi} \max_{\theta}
\mathbb{E}_{q(\mathbf{x})} [ \log D_{\theta} (\mathbf{x}) ] +
\mathbb{E}_{p(\mathbf{z})} [ \log(1-D_{\theta} (G_{\phi}(\mathbf{z}))) ].
\end{align*}
$$&lt;p&gt;This is, of course, none other than the groundbreaking &lt;em&gt;generative adversarial
network (GAN)&lt;/em&gt;&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;.
You can read more about the density ratio estimation perspective of GANs in
the paper by Uehara et al. 2016&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;. For an even more general and complete treatment of learning in implicit models, I recommend the paper
from Mohamed and Lakshminarayanan, 2016&lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt;, which partially inspired this post.&lt;/p&gt;
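&lt;p&gt;To make the saddle-point objective above more tangible, here is a minimal
sketch of the two loss terms, written for Keras models &lt;code&gt;D&lt;/code&gt; (ending in a
sigmoid) and &lt;code&gt;G&lt;/code&gt;. All names are illustrative assumptions; in practice the two
losses are minimized in alternating gradient steps.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import tensorflow as tf

def gan_losses(D, G, x_real, z, eps=1e-7):
    """Sketch of the GAN saddle-point objective.

    D, G are Keras models (D ends in a sigmoid); x_real are samples from
    q(x) and z are samples from p(z). `eps` guards against log(0).
    """
    d_real = D(x_real)  # D_theta(x) for real samples
    d_fake = D(G(z))    # D_theta(G_phi(z)) for generated samples

    # Discriminator: maximize E_q[log D] + E_{p_phi}[log(1 - D)],
    # i.e. minimize the binary cross-entropy below (real=1, fake=0).
    d_loss = -tf.reduce_mean(tf.log(d_real + eps) + tf.log(1.0 - d_fake + eps))

    # Generator: minimize E_{p(z)}[log(1 - D(G(z)))].
    g_loss = tf.reduce_mean(tf.log(1.0 - d_fake + eps))

    return d_loss, g_loss
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;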
&lt;p&gt;For the remainder of this section, I want to highlight a variant of this
approach that specifically aims to minimize the KL divergence w.r.t. parameters
$\phi$,&lt;/p&gt;
$$
\min_{\phi} \mathcal{D}_{\mathrm{KL}}[p_{\phi}(\mathbf{x}) || q(\mathbf{x})].
$$&lt;p&gt;To overcome the fact that the densities of both $p_{\phi}(\mathbf{x})$ and
$q(\mathbf{x})$ are unknown, we can readily adopt the density ratio estimation
approach outlined in this post.
Namely, we maximize the following objective w.r.t. $\theta$,&lt;/p&gt;
$$
\begin{align*}
&amp; \max_{\theta}
\mathbb{E}_{q(\mathbf{x})} [ \log D_{\theta} (\mathbf{x}) ] +
\mathbb{E}_{p(\mathbf{z})} [ \log(1-D_{\theta} (G_{\phi}(\mathbf{z}))) ] \newline
= &amp; \max_{\theta}
\mathbb{E}_{q(\mathbf{x})} [ \log \sigma ( \log r_{\theta} (\mathbf{x}) ) ] +
\mathbb{E}_{p(\mathbf{z})} [ \log(1 - \sigma ( \log r_{\theta} (G_{\phi}(\mathbf{z})) )) ],
\end{align*}
$$&lt;p&gt;which is maximized when&lt;/p&gt;
$$
r_{\theta}(\mathbf{x}) = \frac{q(\mathbf{x})}{p_{\phi}(\mathbf{x})}.
$$&lt;p&gt;Concurrently, we also minimize the current best estimate of the KL divergence,&lt;/p&gt;
$$
\begin{align*}
\min_{\phi} \mathcal{D}_{\mathrm{KL}}[p_{\phi}(\mathbf{x}) || q(\mathbf{x})]
&amp; =
\min_{\phi} \mathbb{E}_{p_{\phi}(\mathbf{x})} \left [ \log \frac{p_{\phi}(\mathbf{x})}{q(\mathbf{x})} \right ] \newline
&amp; \approx
\min_{\phi} \mathbb{E}_{p_{\phi}(\mathbf{x})} [ - \log r_{\theta}(\mathbf{x}) ] \newline
&amp; =
\min_{\phi} \mathbb{E}_{p(\mathbf{z})} [ - \log r_{\theta}(G_{\phi}(\mathbf{z})) ].
\end{align*}
$$&lt;p&gt;In addition to being more stable than the vanilla GAN generator loss (it
alleviates saturating gradients), this approach is especially important in
contexts where there is a specific need to minimize the KL divergence, such as
in &lt;em&gt;variational inference (VI)&lt;/em&gt;.&lt;/p&gt;
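&lt;p&gt;In code, the only change relative to the vanilla GAN sketch above is the
generator loss. Assuming the discriminator exposes its pre-sigmoid output
(so that it directly estimates $\log r_{\theta}(\mathbf{x})$), a minimal sketch
might look as follows; the names are again illustrative.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import tensorflow as tf

def kl_generator_loss(D_logit, G, z):
    """Sketch of the KL-minimizing generator loss.

    D_logit is assumed to be the discriminator *before* its sigmoid, trained
    with real samples labelled 1 and generated samples labelled 0, so that
    D_logit(x) estimates log r_theta(x) = log q(x) - log p_phi(x).
    """
    # min_phi E_{p(z)}[ -log r_theta(G_phi(z)) ]
    return -tf.reduce_mean(D_logit(G(z)))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;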
&lt;p&gt;This was first used in &lt;em&gt;AffGAN&lt;/em&gt; by Sønderby et al. 2016&lt;sup id="fnref:7"&gt;&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref"&gt;7&lt;/a&gt;&lt;/sup&gt;,
and has since been incorporated in many papers that deal with implicit
distributions in variational inference, such as
(Mescheder et al. 2017&lt;sup id="fnref:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;,
Huszar 2017&lt;sup id="fnref:9"&gt;&lt;a href="#fn:9" class="footnote-ref" role="doc-noteref"&gt;9&lt;/a&gt;&lt;/sup&gt;,
Tran et al. 2017&lt;sup id="fnref:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt;,
Pu et al. 2017&lt;sup id="fnref:11"&gt;&lt;a href="#fn:11" class="footnote-ref" role="doc-noteref"&gt;11&lt;/a&gt;&lt;/sup&gt;,
Chen et al. 2018&lt;sup id="fnref:12"&gt;&lt;a href="#fn:12" class="footnote-ref" role="doc-noteref"&gt;12&lt;/a&gt;&lt;/sup&gt;,
Tiao et al. 2018&lt;sup id="fnref:13"&gt;&lt;a href="#fn:13" class="footnote-ref" role="doc-noteref"&gt;13&lt;/a&gt;&lt;/sup&gt;), and many others.&lt;/p&gt;
&lt;h2 id="bound-on-the-jensen-shannon-divergence"&gt;Bound on the Jensen-Shannon Divergence&lt;/h2&gt;
&lt;p&gt;Before we wrap things up, let us take another look at the plot of the
binary cross-entropy loss recorded at the end of each epoch.
We see that it converges quickly to some value.
It is natural to wonder: what is the significance, if any, of this value?&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Binary cross-entropy loss converges to Jensen Shannon divergence (up to constants)"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/binary_crossentropy_vs_jensen_shannon.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;It is in fact the (negative) Jensen-Shannon (JS) divergence, up to constants,&lt;/p&gt;
$$
-2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] + \log 4.
$$&lt;p&gt;Recall the Jensen-Shannon divergence is defined as&lt;/p&gt;
$$
\mathcal{D}_{\mathrm{JS}}[p(x) || q(x)]
= \frac{1}{2} \mathcal{D}_{\mathrm{KL}}[p(x) || m(x)] +
\frac{1}{2} \mathcal{D}_{\mathrm{KL}}[q(x) || m(x)],
$$&lt;p&gt;where $m$ is the mixture density&lt;/p&gt;
$$
m(x) = \frac{p(x) + q(x)}{2}.
$$&lt;p&gt;With our running example, this cannot be evaluated exactly since the KL
divergence between a Gaussian and a mixture of Gaussians is analytically
intractable.
However, as with the KL, we can still estimate the JS divergence by Monte
Carlo&lt;sup id="fnref:14"&gt;&lt;a href="#fn:14" class="footnote-ref" role="doc-noteref"&gt;14&lt;/a&gt;&lt;/sup&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;js&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;monte_carlo_csiszar_f_divergence&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;jensen_shannon&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="n"&gt;p_log_prob&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_draws&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This value is shown as the horizontal black line in the plot above. Along the
right margin, we also plot a histogram of the binary cross-entropy loss
values over epochs. We can see that this value indeed coincides with the mode of
the histogram.&lt;/p&gt;
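&lt;p&gt;Since the densities of $p(x)$ and $q(x)$ in our running example are in fact
available, we can also sanity-check this value by estimating the JS divergence
directly from its definition via the mixture density $m(x)$. The sketch below
assumes &lt;code&gt;p&lt;/code&gt; and &lt;code&gt;q&lt;/code&gt; are the &lt;code&gt;tfd.Normal&lt;/code&gt; distributions defined at the start of
the post.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# m(x) = (p(x) + q(x)) / 2 as an equal-weight two-component mixture.
m = tfd.Mixture(cat=tfd.Categorical(probs=[0.5, 0.5]), components=[p, q])

x_p, x_q = p.sample(5000), q.sample(5000)

# D_JS = 0.5 * KL[p || m] + 0.5 * KL[q || m], estimated by Monte Carlo.
js_direct = (0.5 * tf.reduce_mean(p.log_prob(x_p) - m.log_prob(x_p)) +
             0.5 * tf.reduce_mean(q.log_prob(x_q) - m.log_prob(x_q)))

# The value plotted as the horizontal line corresponds to -2 * D_JS + log 4.
loss_floor = -2.0 * js_direct + tf.log(4.0)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;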
&lt;p&gt;It is straightforward to show that this value is in fact a lower bound on the optimal loss,&lt;/p&gt;
$$
\inf_{\theta} \mathcal{L}(\theta) \geq - 2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] + \log 4.
$$&lt;p&gt;Firstly, note that the supremum over the parametric family $\{ D_{\theta} \}$ is
at most the supremum over all classifiers, which is attained by the
Bayes-optimal classifier $D(x) = \mathcal{P}(y=1 \mid x)$. Hence,&lt;/p&gt;
$$
\begin{align*}
\sup_{\theta} &amp;
\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ] +
\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ] \newline
&amp; =
\mathbb{E}_{p(x)} [ \log \mathcal{P}(y=1 \mid x) ] +
\mathbb{E}_{q(x)} [ \log \mathcal{P}(y=0 \mid x) ] \newline
&amp; =
\mathbb{E}_{p(x)} \left [ \log \frac{p(x)}{p(x) + q(x)} \right ] +
\mathbb{E}_{q(x)} \left [ \log \frac{q(x)}{p(x) + q(x)} \right ] \newline
&amp; =
\mathbb{E}_{p(x)} \left [ \log \frac{1}{2} \frac{p(x)}{m(x)} \right ] +
\mathbb{E}_{q(x)} \left [ \log \frac{1}{2} \frac{q(x)}{m(x)} \right ] \newline
&amp; =
\mathbb{E}_{p(x)} \left [ \log \frac{p(x)}{m(x)} \right ] +
\mathbb{E}_{q(x)} \left [ \log \frac{q(x)}{m(x)} \right ] - 2 \log 2 \newline
&amp; = 2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4.
\end{align*}
$$&lt;p&gt;Therefore,&lt;/p&gt;
$$
2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4
\geq
\sup_{\theta}
\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ] +
\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ].
$$&lt;p&gt;Negating both sides, we get&lt;/p&gt;
$$
\begin{align*}
-2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] + \log 4
\leq &amp;
-\sup_{\theta}
\left ( \mathbb{E}_{p(x)} [ \log D_{\theta} (x) ] +
\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ] \right ) \newline
= &amp; \inf_{\theta}
\left ( -\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ]
-\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ] \right ) \newline
= &amp; \inf_{\theta} \mathcal{L}(\theta),
\end{align*}
$$&lt;p&gt;as required.&lt;/p&gt;
$$&lt;p&gt;In short, this tells us that the binary cross-entropy loss is &lt;em&gt;itself&lt;/em&gt; an
approximation (up to sign and additive constants) of the Jensen-Shannon divergence.
This raises the question: is it possible to construct a more general loss that bounds any given $f$-divergence?&lt;/p&gt;
&lt;h2 id="teaser-lower-bound-on-any--divergence"&gt;Teaser: Lower Bound on any $f$-divergence&lt;/h2&gt;
&lt;p&gt;Using convex analysis, one can actually show that for any $f$-divergence, we
have the lower bound&lt;sup id="fnref:15"&gt;&lt;a href="#fn:15" class="footnote-ref" role="doc-noteref"&gt;15&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;
$$
\mathcal{D}_f[p(x) || q(x)]
\geq
\sup_{\theta}
\mathbb{E}_{p(x)} [ f'(r_{\theta}(x)) ] -
\mathbb{E}_{q(x)} [ f^{\star}(f'(r_{\theta}(x))) ],
$$&lt;p&gt;with equality exactly when $r_{\theta}(x) = r^{*}(x)$.
Importantly, this lower bound can be computed without requiring the densities of
$p(x)$ or $q(x)$&amp;mdash;only their samples are needed.&lt;/p&gt;
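&lt;p&gt;As a rough sketch, this bound can be estimated generically given the
derivative $f'$ and the convex conjugate $f^{\star}$ of $f$, together with an
estimate of the log ratio; the helper names below are illustrative assumptions.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import tensorflow as tf

def f_divergence_lower_bound(f_prime, f_star, log_ratio, p_samples, q_samples):
    """Monte Carlo estimate of the variational lower bound
        E_p[ f'(r_theta(x)) ] - E_q[ f*(f'(r_theta(x))) ],
    where r_theta(x) = exp(log_ratio(x)). Only samples are required."""
    t_p = f_prime(tf.exp(log_ratio(p_samples)))
    t_q = f_prime(tf.exp(log_ratio(q_samples)))
    return tf.reduce_mean(t_p) - tf.reduce_mean(f_star(t_q))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The two special cases that follow correspond to particular choices of $f'$ and
$f^{\star}$.&lt;/p&gt;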
&lt;p&gt;In the special case of $f(u) = u \log u - (u + 1) \log (u + 1)$, we recover the
binary cross-entropy loss and the previous result, as expected,&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_f[p(x) || q(x)]
&amp; = 2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4 \newline
&amp; \geq \sup_{\theta}
\mathbb{E}_{p(x)} [ \log \sigma ( \log r_{\theta} (x) ) ] +
\mathbb{E}_{q(x)} [ \log(1 - \sigma ( \log r_{\theta} (x) )) ] \newline
&amp; = \sup_{\theta}
\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ] +
\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ].
\end{align*}
$$&lt;p&gt;Alternatively, in the special case of $f(u) = u \log u$, we get&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_f[p(x) || q(x)]
&amp; = \mathcal{D}_{\mathrm{KL}}[p(x) || q(x)] \newline
&amp; \geq \sup_{\theta}
\mathbb{E}_{p(x)} [ \log r_{\theta} (x) ] -
\mathbb{E}_{q(x)} [ r_{\theta} (x) - 1 ].
\end{align*}
$$&lt;p&gt;This gives us &lt;em&gt;yet&lt;/em&gt; another way to estimate the KL divergence between
implicit distributions, in the form of a direct lower bound on the KL divergence
itself.
As it turns out, this lower bound is closely related to the objective of the
&lt;em&gt;KL Importance Estimation Procedure (KLIEP)&lt;/em&gt;&lt;sup id="fnref:16"&gt;&lt;a href="#fn:16" class="footnote-ref" role="doc-noteref"&gt;16&lt;/a&gt;&lt;/sup&gt;, and will be
the topic of our next post in this series.&lt;/p&gt;
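&lt;p&gt;For reference, this lower bound is also straightforward to estimate from
samples alone; a minimal sketch (with illustrative names, reusing a log-ratio
estimate such as the classifier logits from earlier):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import tensorflow as tf

def kl_lower_bound(log_ratio, p_samples, q_samples):
    """Monte Carlo estimate of E_p[ log r_theta(x) ] - E_q[ r_theta(x) - 1 ],
    a direct lower bound on KL[p || q]."""
    log_r_p = log_ratio(p_samples)       # log r_theta on x ~ p(x)
    r_q = tf.exp(log_ratio(q_samples))   # r_theta on x ~ q(x)
    return tf.reduce_mean(log_r_p) - tf.reduce_mean(r_q - 1.0)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;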
&lt;h1 id="summary"&gt;Summary&lt;/h1&gt;
&lt;p&gt;This post covered how to evaluate the KL divergence, or any $f$-divergence,
between implicit distributions&amp;mdash;distributions which we can only sample from.
First, we underscored the crucial role of the density ratio in the estimation of
$f$-divergences.
Next, we showed the correspondence between the density ratio and the optimal
classifier.
By exploiting this link, we demonstrated how one can use a trained probabilistic classifier to construct a proxy for the exact density ratio, and use this to
enable estimation of any $f$-divergence.
Finally, we provided some context on where this method is used, touching upon
some recent advances in implicit generative models and variational inference.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;Cite as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{tiao2018dre,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = &amp;#34;{D}ensity {R}atio {E}stimation for {KL} {D}ivergence {M}inimization between {I}mplicit {D}istributions&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = &amp;#34;Tiao, Louis C&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; journal = &amp;#34;tiao.io&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = &amp;#34;2018&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = &amp;#34;https://tiao.io/post/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h1 id="acknowledgements"&gt;Acknowledgements&lt;/h1&gt;
&lt;p&gt;I am grateful to
for providing
extensive feedback and insightful discussions. I would also like to thank
Alistair Reid and
for their comments and suggestions.&lt;/p&gt;
&lt;h1 id="links-and-resources"&gt;Links and Resources&lt;/h1&gt;
&lt;ul&gt;
&lt;li&gt;The code used to generate the figures in this post.&lt;/li&gt;
&lt;li&gt;The very readable textbook on density ratio estimation&lt;sup id="fnref1:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;, which I highly recommend. (Note: the Gaussian distributions example was borrowed from this book.)&lt;/li&gt;
&lt;li&gt;Shakir Mohamed&amp;rsquo;s blog post.&lt;/li&gt;
&lt;li&gt;The paper by Menon and Ong, 2016&lt;sup id="fnref:17"&gt;&lt;a href="#fn:17" class="footnote-ref" role="doc-noteref"&gt;17&lt;/a&gt;&lt;/sup&gt;, which gives a generalized treatment of the theoretical link between density ratio estimation and probabilistic classification.&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;The (forward) KL divergence can be recovered with
&lt;/p&gt;
$$
f_{\mathrm{KL}}(u) := u \log u.
$$&lt;p&gt;
This is easy to verify,
&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)] &amp; :=
\mathbb{E}_{p(x)} \left [ \log \left ( \frac{p(x)}{q(x)} \right ) \right ] \newline
&amp; = \mathbb{E}_{q(x)} \left [ \frac{p(x)}{q(x)} \log \left ( \frac{p(x)}{q(x)} \right ) \right ] \newline
&amp; = \mathbb{E}_{q(x)} \left [ f_{\mathrm{KL}} \left ( \frac{p(x)}{q(x)} \right ) \right ].
\end{align*}
$$&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;Sugiyama, M., Suzuki, T., &amp;amp; Kanamori, T. (2012). &lt;em&gt;Density Ratio Estimation in Machine Learning&lt;/em&gt;. Cambridge University Press.&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;Gneiting, T., &amp;amp; Raftery, A. E. (2007). Strictly Proper Scoring Rules, Prediction, and Estimation. &lt;em&gt;Journal of the American Statistical Association&lt;/em&gt;, 102(477), (pp. 359-378).&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., &amp;hellip; &amp;amp; Bengio, Y. (2014). Generative Adversarial Nets. In Advances in &lt;em&gt;Neural Information Processing Systems&lt;/em&gt; (pp. 2672-2680).&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;Uehara, M., Sato, I., Suzuki, M., Nakayama, K., &amp;amp; Matsuo, Y. (2016). Generative Adversarial Nets from a Density Ratio Estimation Perspective. &lt;em&gt;arXiv preprint arXiv:1610.02920&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;Mohamed, S., &amp;amp; Lakshminarayanan, B. (2016). Learning in Implicit Generative Models. &lt;em&gt;arXiv preprint arXiv:1610.03483&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:7"&gt;
&lt;p&gt;Sønderby, C. K., Caballero, J., Theis, L., Shi, W., &amp;amp; Huszár, F. (2016). Amortised map inference for image super-resolution. &lt;em&gt;arXiv preprint arXiv:1610.04490&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:8"&gt;
&lt;p&gt;Mescheder, L., Nowozin, S., &amp;amp; Geiger, A. (2017). Adversarial Variational Bayes: Unifying Variational Autoencoders and Generative Adversarial Networks. In &lt;em&gt;International Conference on Machine learning (ICML)&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:9"&gt;
&lt;p&gt;Huszár, F. (2017). Variational inference using implicit distributions. &lt;em&gt;arXiv preprint arXiv:1702.08235&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:9" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:10"&gt;
&lt;p&gt;Tran, D., Ranganath, R., &amp;amp; Blei, D. (2017). Hierarchical implicit models and likelihood-free variational inference. In &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt; (pp. 5523-5533).&amp;#160;&lt;a href="#fnref:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:11"&gt;
&lt;p&gt;Pu, Y., Wang, W., Henao, R., Chen, L., Gan, Z., Li, C., &amp;amp; Carin, L. (2017). Adversarial symmetric variational autoencoder. In &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt; (pp. 4330-4339).&amp;#160;&lt;a href="#fnref:11" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:12"&gt;
&lt;p&gt;Chen, L., Dai, S., Pu, Y., Zhou, E., Li, C., Su, Q., &amp;hellip; &amp;amp; Carin, L. (2018, March). Symmetric variational autoencoder and connections to adversarial learning. In &lt;em&gt;International Conference on Artificial Intelligence and Statistics&lt;/em&gt; (pp. 661-669).&amp;#160;&lt;a href="#fnref:12" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:13"&gt;
&lt;p&gt;Tiao, L. C., Bonilla, E. V., &amp;amp; Ramos, F. (2018). Cycle-Consistent Adversarial Learning as Approximate Bayesian Inference. &lt;em&gt;arXiv preprint arXiv:1806.01771&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:13" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:14"&gt;
&lt;p&gt;Note that &lt;code&gt;jensen_shannon&lt;/code&gt; with &lt;code&gt;self_normalized=False&lt;/code&gt; (default), corresponds to $2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4$, while &lt;code&gt;self_normalized=True&lt;/code&gt; corresponds to $\mathcal{D}_{\mathrm{JS}}[p(x) || q(x)]$.&amp;#160;&lt;a href="#fnref:14" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:15"&gt;
&lt;p&gt;Nguyen, X., Wainwright, M. J., &amp;amp; Jordan, M. I. (2010). Estimating divergence functionals and the likelihood ratio by convex risk minimization. &lt;em&gt;IEEE Transactions on Information Theory&lt;/em&gt;, 56(11), 5847-5861.&amp;#160;&lt;a href="#fnref:15" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:16"&gt;
&lt;p&gt;Sugiyama, M., Nakajima, S., Kashima, H., Buenau, P. V., &amp;amp; Kawanabe, M. (2008). Direct importance estimation with model selection and its application to covariate shift adaptation. In Advances in neural information processing systems (pp. 1433-1440).&amp;#160;&lt;a href="#fnref:16" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:17"&gt;
&lt;p&gt;Menon, A., &amp;amp; Ong, C. S. (2016, June). Linking Losses for Density Ratio and Class-Probability Estimation. In &lt;em&gt;International Conference on Machine Learning&lt;/em&gt; (pp. 304-313).&amp;#160;&lt;a href="#fnref:17" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item></channel></rss>