<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>TensorFlow Probability |</title><link>https://tiao.io/tags/tensorflow-probability/</link><atom:link href="https://tiao.io/tags/tensorflow-probability/index.xml" rel="self" type="application/rss+xml"/><description>TensorFlow Probability</description><generator>HugoBlox Kit (https://hugoblox.com)</generator><language>en-us</language><lastBuildDate>Sun, 16 Apr 2023 11:16:03 +0000</lastBuildDate><image><url>https://tiao.io/media/icon_hu_9c2a75fde2335590.png</url><title>TensorFlow Probability</title><link>https://tiao.io/tags/tensorflow-probability/</link></image><item><title>Efficient Cholesky decomposition of low-rank updates</title><link>https://tiao.io/posts/efficient-cholesky-decomposition-of-low-rank-updates/</link><pubDate>Sun, 16 Apr 2023 11:16:03 +0000</pubDate><guid>https://tiao.io/posts/efficient-cholesky-decomposition-of-low-rank-updates/</guid><description>&lt;p&gt;Suppose we&amp;rsquo;re given a positive semidefinite (PSD)
matrix $\mathbf{A} \in \mathbb{R}^{N \times N}$
that we wish to update by some low-rank
matrix $\mathbf{U} \mathbf{U}^\top \in \mathbb{R}^{N \times N}$,
$$\mathbf{B} \triangleq \mathbf{A} + \mathbf{U} \mathbf{U}^\top,$$
where the update factor matrix $\mathbf{U} \in \mathbb{R}^{N \times M}$.
To be more precise, the update is rank-$M$ for some $M \ll N$.&lt;/p&gt;
&lt;p&gt;&lt;em&gt;What is the best way to calculate the Cholesky decomposition of $\mathbf{B}$?&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;Given no additional information, the obvious way is to calculate it directly,
which incurs a cost of $\mathcal{O}(N^3)$.
But suppose we&amp;rsquo;ve already calculated the lower-triangular Cholesky factor
$\mathbf{L} \in \mathbb{R}^{N \times N}$ of $\mathbf{A}$
(i.e., $\mathbf{LL}^\top = \mathbf{A}$).
Then, we can use it to calculate the Cholesky decomposition of $\mathbf{B}$
at a reduced cost of $\mathcal{O}(N^2 M)$.
For instance, with $N = 10^4$ and $M = 10$, this amounts to roughly a
thousandfold reduction in the number of operations.
Here&amp;rsquo;s how.&lt;/p&gt;
&lt;h2 id="rank-1-updates"&gt;Rank-1 Updates&lt;/h2&gt;
&lt;p&gt;First, let&amp;rsquo;s consider the simpler case involving just &lt;em&gt;rank-1 updates&lt;/em&gt;,
$$\mathbf{B} \triangleq \mathbf{A} + \mathbf{u} \mathbf{u}^\top,$$
where the update factor vector $\mathbf{u} \in \mathbb{R}^{N}$.
With some clever manipulations&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;, the details of which we won&amp;rsquo;t
get into in this post, we can leverage $\mathbf{L}$ to
calculate the Cholesky decomposition of $\mathbf{B}$
at a reduced cost of $\mathcal{O}(N^2)$.
Such a procedure for rank-1 updates is implemented in LINPACK, the old-school
Fortran linear algebra software library (but unfortunately not in its
successor, LAPACK), and also in modern libraries like TensorFlow Probability
(TFP).&lt;/p&gt;
&lt;p&gt;In TFP, this is implemented in the function named &lt;code&gt;tfp.math.cholesky_update&lt;/code&gt;.
For example,&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow_probability&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tfp&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;update_factor_vector&lt;/span&gt; &lt;span class="c1"&gt;# Tensor; shape [..., N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="c1"&gt;# Tensor; shape [..., N, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;update&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_factor_vector&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;newaxis&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_factor_vector&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;newaxis&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;transpose_b&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;update&lt;/span&gt; &lt;span class="c1"&gt;# Tensor; shape [..., N, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;a_factor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# O(N^3); suppose this is pre-computed and stored&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;b_factor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# O(N^3), ignores `a_factor`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;b_factor_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky_update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a_factor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_vector&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# O(N^2), uses `a_factor`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;testing&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;assert_array_almost_equal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b_factor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b_factor_1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Here &lt;code&gt;cholesky_update&lt;/code&gt; takes as arguments &lt;code&gt;chol&lt;/code&gt; with shape &lt;code&gt;[B1, ..., Bn, N, N]&lt;/code&gt;
and &lt;code&gt;u&lt;/code&gt; with shape &lt;code&gt;[B1, ..., Bn, N]&lt;/code&gt;, and returns a lower triangular Cholesky
factor of the rank-1 updated matrix &lt;code&gt;chol @ chol.T + u @ u.T&lt;/code&gt; in $\mathcal{O}(N^2)$
time.&lt;/p&gt;
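&lt;p&gt;Since the snippet above uses placeholder tensors, here is a small
self-contained check with concrete (randomly generated) values; the particular
matrix construction below is just one way to produce a PSD test case:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

N = 4
rng = np.random.RandomState(42)

w = rng.randn(N, N)
a = tf.constant(w @ w.T + N * np.eye(N))  # a random PSD matrix
u = tf.constant(rng.randn(N))  # a random update factor vector

a_factor = tf.linalg.cholesky(a)
b = a + u[:, tf.newaxis] * u[tf.newaxis, :]  # rank-1 update

b_factor = tf.linalg.cholesky(b)  # direct, O(N^3)
b_factor_1 = tfp.math.cholesky_update(a_factor, u)  # O(N^2), uses `a_factor`

np.testing.assert_array_almost_equal(b_factor.numpy(), b_factor_1.numpy())
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;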
&lt;h2 id="low-rank-updates"&gt;Low-Rank Updates&lt;/h2&gt;
&lt;p&gt;Now let&amp;rsquo;s return to rank-$M$ updates.
First let&amp;rsquo;s write the update factor matrix $\mathbf{U}$ in terms of column
vectors $\mathbf{u}_m \in \mathbb{R}^{N}$,
$$
\mathbf{U} \triangleq
\begin{bmatrix}
\mathbf{u}_1 &amp; \cdots &amp; \mathbf{u}_M
\end{bmatrix}.
$$
&lt;/p&gt;
&lt;p&gt;Now we can write the rank-$M$ update matrix as a sum of $M$ rank-1 matrices,
$$
\mathbf{U} \mathbf{U}^\top =
\begin{bmatrix} \mathbf{u}_1 &amp; \cdots &amp; \mathbf{u}_M \end{bmatrix}
\begin{bmatrix} \mathbf{u}_1^\top \\ \vdots \\ \mathbf{u}_M^\top \end{bmatrix} =
\sum_{m=1}^{M} \mathbf{u}_m \mathbf{u}_m^\top.
$$
&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt; &lt;span class="c1"&gt;# Tensor; shape [..., N, M]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# [..., N, 1, M] [..., 1, N, M] -&amp;gt; [..., N, N, M] -&amp;gt; [..., N, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;update1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reduce_sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;newaxis&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;newaxis&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:,&lt;/span&gt; &lt;span class="p"&gt;:],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# [..., N, M] [..., M, N] -&amp;gt; [..., N, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;update2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;transpose_b&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# not exactly equal due to finite precision, but still equal up to high precision&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;testing&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;assert_array_almost_equal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decimal&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;14&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Seen this way, a low-rank update is nothing more than a repeated application of
rank-1 updates,
$$
\begin{align}
\mathbf{B} &amp; = \mathbf{A} + \mathbf{U} \mathbf{U}^\top \\ &amp; =
\mathbf{A} + \sum_{m=1}^{M} \mathbf{u}_m \mathbf{u}_m^\top \\ &amp; =
((\mathbf{A} + \mathbf{u}_1 \mathbf{u}_1^\top) + \cdots ) + \mathbf{u}_M \mathbf{u}_M^{\top}.
\end{align}
$$
&lt;/p&gt;
&lt;p&gt;Therefore, we can simply leverage the $\mathcal{O}(N^2)$ procedure for Cholesky
decompositions of rank-1 updates and apply it recursively $M$ times to obtain
an $\mathcal{O}(N^2M)$ procedure for rank-$M$ updates.&lt;/p&gt;
&lt;p&gt;Hence, we have:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# [..., N, M] [..., M, N] -&amp;gt; [..., N, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;update&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;transpose_b&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;update&lt;/span&gt; &lt;span class="c1"&gt;# Tensor; shape [..., N, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;b_factor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# O(N^3), ignores `a_factor`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;b_factor_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cholesky_update_iterated&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a_factor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# O(N^2M), uses `a_factor`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;testing&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;assert_array_almost_equal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b_factor_1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b_factor&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;where function &lt;code&gt;cholesky_update_iterated&lt;/code&gt; is implemented as follows:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;cholesky_update_iterated&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chol&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# base case&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;chol&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;prev&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cholesky_update_iterated&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chol&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky_update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prev&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We can also implement this iteratively.
First we&amp;rsquo;d use &lt;code&gt;tf.unstack&lt;/code&gt; to turn the update factor matrix $\mathbf{U}$
into a list of update factor vectors $\mathbf{u}_m$:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;update_factor_vectors&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nb"&gt;isinstance&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_vectors&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# `update_factor_vectors` is a list&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_vectors&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;M&lt;/span&gt; &lt;span class="c1"&gt;# ... the list contains M vectors&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;update_factor_vectors&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;Bs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# ... and each vector has shape [B1, ..., Bn, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Then, we have:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;cholesky_update_iterated&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chol&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;new_chol&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;chol&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;update_factor_vector&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;new_chol&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky_update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;new_chol&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_vector&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;new_chol&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The astute reader will recognize that this is simply a special case of
the &lt;em&gt;fold&lt;/em&gt; or &lt;em&gt;reduce&lt;/em&gt; patterns, where
the &lt;em&gt;binary operator&lt;/em&gt; is &lt;code&gt;tfp.math.cholesky_update&lt;/code&gt;,
the &lt;em&gt;iterable&lt;/em&gt; is &lt;code&gt;tf.unstack(update_factor_matrix, axis=-1)&lt;/code&gt;, and
the &lt;em&gt;initial value&lt;/em&gt; is &lt;code&gt;chol&lt;/code&gt;.&lt;/p&gt;
&lt;p&gt;Therefore, we can also implement it neatly using the one-liner:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;functools&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;reduce&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;cholesky_update_iterated&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chol&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;reduce&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky_update&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;chol&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="summary"&gt;Summary&lt;/h2&gt;
&lt;p&gt;In summary, we showed that to efficiently calculate the Cholesky decomposition
of a matrix perturbed by a low-rank update, one just needs to iteratively
calculate that of the same matrix perturbed by a series of rank-1 updates.
Better yet, all of this can be done with a simple one-liner!&lt;/p&gt;
&lt;p&gt;To receive updates on more posts like this, follow me on social media!&lt;/p&gt;
&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;Seeger, M. (2004). Low rank updates for the Cholesky decomposition.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>An Illustrated Guide to the Knowledge Gradient Acquisition Function</title><link>https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/</link><pubDate>Thu, 18 Feb 2021 19:13:23 +0100</pubDate><guid>https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/</guid><description>
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-blue-100 dark:bg-blue-900 border-blue-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-blue-600 dark:text-blue-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m16.862 4.487l1.687-1.688a1.875 1.875 0 1 1 2.652 2.652L6.832 19.82a4.5 4.5 0 0 1-1.897 1.13l-2.685.8l.8-2.685a4.5 4.5 0 0 1 1.13-1.897zm0 0L19.5 7.125"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Note&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;Draft &amp;ndash; work in progress.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;We provide a short guide to the knowledge-gradient (KG) acquisition
function (Frazier et al., 2009)&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt; for Bayesian
optimization (BO).
Rather than being a self-contained tutorial, this post is intended to serve as
an illustrated compendium to the paper of Frazier et al., 2009&lt;sup id="fnref1:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;
and the subsequent tutorial by Frazier, 2018&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;, authored
nearly a decade later.&lt;/p&gt;
&lt;p&gt;This post assumes a basic level of familiarity with BO and Gaussian processes (GPs),
to the extent provided by the literature survey of Shahriari et al.,
2015&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;, and the acclaimed textbook of Rasmussen and Williams, 2006,
respectively.&lt;/p&gt;
&lt;h2 id="knowledge-gradient"&gt;Knowledge-gradient&lt;/h2&gt;
&lt;p&gt;First, we set up the notation and terminology.
Let $f: \mathcal{X} \to \mathbb{R}$ be the blackbox function we wish to
minimize.
We denote the GP posterior predictive distribution, or &lt;em&gt;predictive&lt;/em&gt; for short,
by $p(y | \mathbf{x}, \mathcal{D})$.
The mean of the predictive, or the &lt;em&gt;predictive mean&lt;/em&gt; for short, is denoted by
&lt;/p&gt;
$$
\mu(\mathbf{x}; \mathcal{D}) = \mathbb{E}[y | \mathbf{x}, \mathcal{D}]
$$&lt;p&gt;
Let $\mathcal{D}_n$ be the set of $n$ input-output
observations $\mathcal{D}_n = \{ (\mathbf{x}_i, y_i) \}_{i=1}^n$, where
output $y_i = f(\mathbf{x}_i) + \epsilon$ is assumed to be observed with noise
$\epsilon \sim \mathcal{N}(0, \sigma^2)$.
We make the following abbreviation
&lt;/p&gt;
$$
\mu_n(\mathbf{x}) = \mu(\mathbf{x}; \mathcal{D}_n)
$$&lt;p&gt;
Next, we define the minimum of the predictive mean, or &lt;em&gt;predictive minimum&lt;/em&gt; for short,
as
&lt;/p&gt;
$$
\tau(\mathcal{D}) = \min_{\mathbf{x}' \in \mathcal{X}} \mu(\mathbf{x}'; \mathcal{D})
$$&lt;p&gt;
If we view $\mu(\mathbf{x}; \mathcal{D})$ as our fit to the underlying
function $f(\mathbf{x})$ from which the observations $\mathcal{D}$ were
generated, then $\tau(\mathcal{D})$ is our estimate of the minimum of $f(\mathbf{x})$,
given observations $\mathcal{D}$.&lt;/p&gt;
&lt;p&gt;Further, we make the following abbreviations
&lt;/p&gt;
$$
\tau_n = \tau(\mathcal{D}_n),
\qquad
\text{and}
\qquad
\tau_{n+1} = \tau(\mathcal{D}_{n+1}),
$$&lt;p&gt;
where $\mathcal{D}_{n+1} = \mathcal{D}_n \cup \{ (\mathbf{x}, y) \}$ is the
set of existing observations, augmented by some input-output pair $(\mathbf{x}, y)$.
Then, the knowledge-gradient is defined as
&lt;/p&gt;
$$
\alpha(\mathbf{x}; \mathcal{D}_n) =
\mathbb{E}_{p(y | \mathbf{x}, \mathcal{D}_n)} [ \tau_n - \tau_{n+1} ]
$$&lt;p&gt;
Crucially, note that $\tau_{n+1}$ is implicitly a function of $(\mathbf{x}, y)$,
and that this expression integrates over all possible outcomes $y$ at the
given $\mathbf{x}$ under the
predictive $p(y | \mathbf{x}, \mathcal{D}_n)$.&lt;/p&gt;
&lt;h3 id="monte-carlo-estimation"&gt;Monte Carlo estimation&lt;/h3&gt;
&lt;p&gt;Not surprisingly, the knowledge-gradient function is analytically intractable.
Therefore, in practice, we compute it using Monte Carlo estimation,
&lt;/p&gt;
$$
\alpha(\mathbf{x}; \mathcal{D}_n) \approx
\frac{1}{M} \sum_{m=1}^M \left ( \tau_n - \tau_{n+1}^{(m)} \right ),
\qquad
y^{(m)} \sim p(y | \mathbf{x}, \mathcal{D}_n),
$$&lt;p&gt;
where $\tau_{n+1}^{(m)} = \tau(\mathcal{D}_{n+1}^{(m)})$
and $\mathcal{D}_{n+1}^{(m)} = \mathcal{D}_n \cup \{ (\mathbf{x}, y^{(m)}) \}$.&lt;/p&gt;
&lt;p&gt;We refer to $y^{(m)}$ as the $m$th simulated outcome, or the $m$th &lt;em&gt;simulation&lt;/em&gt;
for short.
Then, $\mathcal{D}_{n+1}^{(m)}$ is the $m$th simulation-augmented dataset and,
accordingly, $\tau_{n+1}^{(m)}$ is the $m$th simulation-augmented predictive minimum.&lt;/p&gt;
&lt;p&gt;We see that this approximation to the knowledge-gradient is simply the average
difference between the predictive minimum values &lt;em&gt;based on simulation-augmented
data&lt;/em&gt; $\tau_{n+1}^{(m)}$, and that &lt;em&gt;based on observed data&lt;/em&gt; $\tau_n$,
across $M$ simulations.&lt;/p&gt;
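&lt;p&gt;In code, the estimator might look something like the following minimal
sketch. Here &lt;code&gt;predictive&lt;/code&gt; (a sampler for $p(y | \mathbf{x}, \mathcal{D}_n)$) and
&lt;code&gt;tau&lt;/code&gt; (a routine that computes the predictive minimum for the model fitted to
a given dataset) are hypothetical helpers; Step 2 below sketches the inner
minimization behind &lt;code&gt;tau&lt;/code&gt;.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import numpy as np

def knowledge_gradient(x, dataset, predictive, tau, num_samples=5):
    # tau_n: predictive minimum based on the observed data alone
    tau_n = tau(dataset)
    diffs = []
    for _ in range(num_samples):
        # simulate an outcome y^(m) ~ p(y | x, D_n)
        y_sim = predictive(x)
        # augment the observed data with the simulated pair (x, y^(m))
        dataset_sim = dataset + [(x, y_sim)]
        # tau_{n+1}^(m): predictive minimum based on simulation-augmented data
        diffs.append(tau_n - tau(dataset_sim))
    # average the differences across the M simulations
    return np.mean(diffs)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;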
&lt;p&gt;This might take a moment to digest, as there are quite a number of moving parts
to keep track of. To help visualize these parts, we provide an illustration of
each of the steps required to compute KG on a simple one-dimensional synthetic
problem.&lt;/p&gt;
&lt;h2 id="one-dimensional-example"&gt;One-dimensional example&lt;/h2&gt;
&lt;p&gt;As the running example throughout this post, we use a synthetic function
defined as
&lt;/p&gt;
$$
f(x) = \sin(3x) + x^2 - 0.7 x.
$$&lt;p&gt;
We generate $n=10$ observations at locations sampled uniformly at random.
The true function, and the set of noisy observations $\mathcal{D}_n$ are
visualized in the figure below:&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/observations_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Latent blackbox function and $n=10$ observations.&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
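&lt;p&gt;For concreteness, a dataset like this can be generated as follows. The noise
scale and input bounds below are assumed values for illustration; the exact
settings behind the figures are not stated.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import numpy as np

def f(x):
    # latent blackbox function
    return np.sin(3 * x) + x**2 - 0.7 * x

n = 10
noise_scale = 0.2  # assumed noise level
low, high = -1.0, 2.0  # assumed input bounds

rng = np.random.RandomState(42)
X = rng.uniform(low, high, size=(n, 1))  # locations sampled uniformly at random
Y = f(X) + noise_scale * rng.randn(n, 1)  # noisy observations
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;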
&lt;p&gt;Using the observations $\mathcal{D}_n$ we have collected so far, we wish to
use KG to score a candidate location $x_c$ at which to evaluate next.&lt;/p&gt;
&lt;h2 id="posterior-predictive-distribution"&gt;Posterior predictive distribution&lt;/h2&gt;
&lt;p&gt;The posterior predictive $p(y | \mathbf{x}, \mathcal{D}_n)$ is visualized in
the figure below. In particular, the predictive mean $\mu_n(\mathbf{x})$ is
represented by the solid orange curve.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/predictive_mean_before_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Posterior predictive distribution (*before* hyperparameter estimation).&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;Clearly, this is a poor fit to the data and an uncalibrated estimate of the
predictive uncertainty.&lt;/p&gt;
&lt;h3 id="step-1-hyperparameter-estimation"&gt;Step 1: Hyperparameter estimation&lt;/h3&gt;
&lt;p&gt;Therefore, the first step is to optimize the hyperparameters of the GP regression
model, i.e. the kernel lengthscale, amplitude, and the observation noise variance.
We do this using type-II maximum likelihood estimation (MLE), or &lt;em&gt;empirical Bayes&lt;/em&gt;.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/predictive_mean_after_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Posterior predictive distribution (*after* hyperparameter estimation).&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
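&lt;p&gt;As a rough sketch of this step, here is how type-II MLE might be carried out
with GPflow, one of several libraries that support it (the figures in this
post are not necessarily produced with it):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import gpflow

# kernel with lengthscale and amplitude; the model adds an observation noise variance
kernel = gpflow.kernels.SquaredExponential()
model = gpflow.models.GPR((X, Y), kernel=kernel)

# maximize the log marginal likelihood with respect to the hyperparameters
opt = gpflow.optimizers.Scipy()
opt.minimize(model.training_loss, model.trainable_variables)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;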
&lt;h3 id="step-2-determine-the-predictive-minimum"&gt;Step 2: Determine the predictive minimum&lt;/h3&gt;
&lt;p&gt;Next, we compute the predictive minimum $\tau_n = \min_{\mathbf{x}' \in \mathcal{X}} \mu_n(\mathbf{x}')$.
Since $\mu_n$ is end-to-end differentiable with respect to the input $\mathbf{x}$, we can
simply use a multi-started quasi-Newton hill-climber such as L-BFGS.
We visualize this in the figure below, where the value of the predictive
minimum is represented by the orange horizontal dashed line, and its location is
denoted by the orange star and triangle.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/predictive_minimum_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Predictive minimum $\tau_n$.&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
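&lt;p&gt;A minimal sketch of this minimization uses SciPy&amp;rsquo;s L-BFGS-B with random
restarts. Here &lt;code&gt;predictive_mean&lt;/code&gt; is an assumed callable mapping an input array
to the scalar $\mu_n(\mathbf{x})$ (e.g. a thin wrapper around
&lt;code&gt;model.predict_f&lt;/code&gt; in the GPflow sketch above); for simplicity we let SciPy
estimate gradients numerically, though exact gradients could be supplied via
&lt;code&gt;jac&lt;/code&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import numpy as np
from scipy.optimize import minimize

def tau(predictive_mean, bounds=(-1.0, 2.0), num_restarts=10, seed=0):
    # multi-started L-BFGS minimization of the predictive mean
    rng = np.random.RandomState(seed)
    results = [
        minimize(predictive_mean, x0=[x0], bounds=[bounds], method="L-BFGS-B")
        for x0 in rng.uniform(*bounds, size=num_restarts)
    ]
    return min(res.fun for res in results)  # value of the predictive minimum
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;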
&lt;h3 id="step-3-compute-simulation-augmented-predictive-means"&gt;Step 3: Compute simulation-augmented predictive means&lt;/h3&gt;
&lt;p&gt;Suppose we are scoring the candidate location $x_c = 0.1$.
For illustrative purposes, let us draw just $M=1$ sample $y_c^{(1)} \sim p(y | x_c, \mathcal{D}_n)$.
In the figure below, the candidate location $x_c$ is represented by the
vertical solid gray line, and the single simulated outcome $y_c^{(1)}$ is
represented by the filled blue dot.&lt;/p&gt;
&lt;p&gt;In general, we denote the simulation-augmented predictive mean as
&lt;/p&gt;
$$
\mu_{n+1}^{(m)}(\mathbf{x}) = \mu(\mathbf{x}; \mathcal{D}_{n+1}^{(m)}),
$$&lt;p&gt;
where
$\mathcal{D}_{n+1}^{(m)} = \mathcal{D}_n \cup \{ (\mathbf{x}, y^{(m)}) \}$
as defined earlier.&lt;/p&gt;
&lt;p&gt;Here, the simulation-augmented dataset $\mathcal{D}_{n+1}^{(1)}$ is the set
of existing observations $\mathcal{D}_n$, augmented by the simulated
input-output pair $(x_c, y_c^{(1)})$,
&lt;/p&gt;
$$
\mathcal{D}_{n+1}^{(1)} = \mathcal{D}_n \cup \{ (x_c, y_c^{(1)}) \},
$$&lt;p&gt;
and the corresponding simulation-augmented predictive mean $\mu_{n+1}^{(1)}(x)$
is represented in the figure below by the solid blue curve.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/simulated_predictive_mean_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Simulation-augmented predictive mean $\mu_{n&amp;#43;1}^{(1)}(x)$ at location $x_c = 0.1$&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
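&lt;p&gt;Continuing the GPflow-based sketch from Step 1, simulating an outcome and
conditioning on it might look as follows; note that we keep the optimized
hyperparameters fixed and simply refit on the augmented data:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;x_c = np.array([[0.1]])  # candidate location

# draw a simulated outcome y_c^(1) ~ p(y | x_c, D_n) from the Gaussian predictive
y_mean, y_var = model.predict_y(x_c)
y_c = y_mean.numpy() + np.sqrt(y_var.numpy()) * np.random.randn(1, 1)

# simulation-augmented dataset D_{n+1}^(1) ...
X_aug = np.vstack([X, x_c])
Y_aug = np.vstack([Y, y_c])

# ... and the model whose posterior mean is the simulation-augmented
# predictive mean mu_{n+1}^(1)
model_aug = gpflow.models.GPR(
    (X_aug, Y_aug), kernel=model.kernel,
    noise_variance=model.likelihood.variance.numpy())
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;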
&lt;h3 id="step-4-compute-simulation-augmented-predictive-minimums"&gt;Step 4: Compute simulation-augmented predictive minimums&lt;/h3&gt;
&lt;p&gt;Next, we compute the simulation-augmented predictive minimum
&lt;/p&gt;
$$
\tau_{n+1}^{(1)} = \min_{\mathbf{x}' \in \mathcal{X}} \mu_{n+1}^{(1)}(\mathbf{x}')
$$&lt;p&gt;
It may not be immediately obvious, but $\mu_{n+1}^{(1)}$ is in fact also
end-to-end differentiable wrt to input $\mathbf{x}$. Therefore, we can again
appeal to an method such as L-BFGS.
We visualize this in the figure below, where the value of the simulation-augmented
predictive minimum is represented by the blue horizontal dashed line, and its
location is denoted by the blue star and triangle.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/simulated_predictive_minimum_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Simulation-augmented predictive minimum $\tau_{n&amp;#43;1}^{(1)}$ at location $x_c = 0.1$&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;Taking the difference between the orange and blue horizontal dashed lines will
give us an unbiased estimate of the knowledge-gradient.
However, this is likely to be a crude one, since it is based on just a single
MC sample.
To obtain a more accurate estimate, one needs to increase $M$, the number of
MC samples.&lt;/p&gt;
&lt;h4 id="samples"&gt;Samples $M &gt; 1$&lt;/h4&gt;
&lt;p&gt;Let us now consider $M=5$ samples. We draw $y_c^{(m)} \sim p(y | x_c, \mathcal{D}_n)$,
for $m = 1, \dotsc, 5$.
As before, the input location $x_c$ is represented by the vertical solid
gray line, and the corresponding simulated outcomes are represented by the
filled dots below, with varying hues from a perceptually uniform color palette
to distinguish between samples.&lt;/p&gt;
&lt;p&gt;Accordingly, the simulation-augmented predictive means
$\mu_{n+1}^{(m)}(x)$ at location $x_c = 0.1$, for $m = 1, \dotsc, 5$ are
represented by the colored curves, with hues set to that of the simulated
outcome on which the predictive distribution is based.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/bar_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Simulation-augmented predictive mean $\mu_{n&amp;#43;1}^{(m)}(x)$ at location $x_c = 0.1$, for $m = 1, \dotsc, 5$&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;Next we compute the simulation-augmented predictive
minimum $\tau_{n+1}^{(m)}$, which requires minimizing
$\mu_{n+1}^{(m)}(x)$ for $m = 1, \dotsc, 5$.
These values are represented below by the horizontal dashed lines, and their
location is denoted by the stars and triangles.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/baz_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Simulation-augmented predictive minimum $\tau_{n&amp;#43;1}^{(m)}$ at location $x_c = 0.1$, for $m = 1, \dotsc, 5$&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;Finally, taking the average difference between the orange dashed line and every
other dashed line gives us the estimate of the knowledge gradient at
input $x_c$.&lt;/p&gt;
&lt;h2 id="links-and-further-readings"&gt;Links and Further Readings&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;In this post, we only showed a (naïve) approach to calculating the KG at a
given location.
Suffice it to say, there is still quite a gap between this and being able to
efficiently maximize KG within a sequential decision-making algorithm.
For a guide on incorporating KG in a modular and fully-fledged framework for
BO, see the tutorials that accompany such frameworks.
&lt;/li&gt;
&lt;/ul&gt;
&lt;hr&gt;
&lt;p&gt;Cite as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{tiao2021knowledge,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = &amp;#34;{A}n {I}llustrated {G}uide to the {K}nowledge {G}radient {A}cquisition {F}unction&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = &amp;#34;Tiao, Louis C&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; journal = &amp;#34;tiao.io&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = &amp;#34;2021&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = &amp;#34;https://tiao.io/post/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;To receive updates on more posts like this, follow me on social media!&lt;/p&gt;
&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;Frazier, P., Powell, W., &amp;amp; Dayanik, S. (2009).
The Knowledge-Gradient Policy for Correlated Normal Beliefs. INFORMS Journal on Computing, 21(4), 599-613.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;Frazier, P. I. (2018).
A Tutorial on Bayesian Optimization. arXiv preprint arXiv:1807.02811.&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;Shahriari, B., Swersky, K., Wang, Z., Adams, R. P., &amp;amp; De Freitas, N. (2015).
Taking the Human Out of the Loop: A Review of Bayesian Optimization. Proceedings of the IEEE, 104(1), 148-175.&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Building Probability Distributions with the TensorFlow Probability Bijector API</title><link>https://tiao.io/posts/building-probability-distributions-with-tensorflow-probability-bijector-api/</link><pubDate>Mon, 30 Jul 2018 00:00:00 +0000</pubDate><guid>https://tiao.io/posts/building-probability-distributions-with-tensorflow-probability-bijector-api/</guid><description>&lt;p&gt;TensorFlow Distributions, now under the broader umbrella of
TensorFlow Probability, is a fantastic TensorFlow library for efficient and
composable manipulation of probability distributions&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;Among the many features it has to offer, one of the most powerful in my opinion
is the &lt;code&gt;Bijector&lt;/code&gt; API, which provides the modular building blocks necessary to
construct a broad class of probability distributions.
Instead of describing it any further in the abstract, let&amp;rsquo;s dive right in with
a simple example.&lt;/p&gt;
&lt;h2 id="example-banana-shaped-distribution"&gt;Example: Banana-shaped distribution&lt;/h2&gt;
&lt;p&gt;Consider the &lt;em&gt;banana-shaped distribution&lt;/em&gt;, a commonly-used testbed for adaptive
MCMC methods&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;.
Denote the density of this distribution as $p_{Y}(\mathbf{y})$.
To illustrate, 1k samples randomly drawn from this distribution are shown below:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Banana distribution samples"
src="https://tiao.io/posts/building-probability-distributions-with-tensorflow-probability-bijector-api/banana_samples.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;The underlying process that generates samples
$\tilde{\mathbf{y}} \sim p_{Y}(\mathbf{y})$ is simple to describe,
and is of the general form,&lt;/p&gt;
$$
\tilde{\mathbf{y}} \sim p_{Y}(\mathbf{y}) \quad
\Leftrightarrow \quad
\tilde{\mathbf{y}} = G(\tilde{\mathbf{x}}),
\quad \tilde{\mathbf{x}} \sim p_{X}(\mathbf{x}).
$$&lt;p&gt;In other words, a sample $\tilde{\mathbf{y}}$ is the output of a transformation
$G$, given a sample $\tilde{\mathbf{x}}$ drawn from some underlying
base distribution $p_{X}(\mathbf{x})$.&lt;/p&gt;
&lt;p&gt;However, it is not as straightforward to compute an analytical expression for
the density $p_{Y}(\mathbf{y})$.
In fact, this is only possible if $G$ is a &lt;em&gt;differentiable&lt;/em&gt; and &lt;em&gt;invertible&lt;/em&gt;
transformation (a &lt;em&gt;diffeomorphism&lt;/em&gt;&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;), and if there is an analytical
expression for $p_{X}(\mathbf{x})$.&lt;/p&gt;
&lt;p&gt;Transformations that fail to satisfy these conditions (which includes something
as simple as a multi-layer perceptron with non-linear activations) give rise to
&lt;em&gt;implicit distributions&lt;/em&gt;, and will be the subject of many posts to come.
But for now, we will restrict our attention to diffeomorphisms.&lt;/p&gt;
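&lt;p&gt;For completeness: when $G$ is such a diffeomorphism, the density of
$\mathbf{y}$ is given by the standard change-of-variables formula,
&lt;/p&gt;
$$
p_{Y}(\mathbf{y}) = p_{X}\left(G^{-1}(\mathbf{y})\right)
\left| \det \frac{\partial G^{-1}(\mathbf{y})}{\partial \mathbf{y}} \right|,
$$&lt;p&gt;
and this is precisely the bookkeeping that the &lt;code&gt;Bijector&lt;/code&gt; API automates for us.&lt;/p&gt;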
&lt;h3 id="base-distribution"&gt;Base distribution&lt;/h3&gt;
&lt;p&gt;Following on with our example, the base distribution $p_{X}(\mathbf{x})$ is
given by a two-dimensional Gaussian with unit variances and covariance
$\rho = 0.95$:&lt;/p&gt;
$$
p_{X}(\mathbf{x}) = \mathcal{N}(\mathbf{x} | \mathbf{0}, \mathbf{\Sigma}),
\qquad
\mathbf{\Sigma} =
\begin{bmatrix}
1 &amp; 0.95 \newline
0.95 &amp; 1
\end{bmatrix}
$$&lt;p&gt;This can be encapsulated by an instance of &lt;code&gt;tfd.MultivariateNormalTriL&lt;/code&gt;,
which is parameterized by a lower-triangular matrix.
First let&amp;rsquo;s import TensorFlow Distributions:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow.contrib.distributions&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tfd&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Then we create the covariance matrix, take its Cholesky factor, and instantiate the distribution:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;rho&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.95&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;rho&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)[::&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([[&lt;/span&gt;&lt;span class="mf"&gt;1.&lt;/span&gt; &lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.&lt;/span&gt; &lt;span class="p"&gt;]],&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;p_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;MultivariateNormalTriL&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scale_tril&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Sigma&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;As with all subclasses of &lt;code&gt;tfd.Distribution&lt;/code&gt;, we can evaluate the probability
density function of this distribution by calling the &lt;code&gt;p_x.prob&lt;/code&gt; method.
Evaluating this on a uniformly-spaced grid yields the equiprobability contour
plot below:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Base density"
src="https://tiao.io/posts/building-probability-distributions-with-tensorflow-probability-bijector-api/banana_base_density.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h3 id="forward-transformation"&gt;Forward Transformation&lt;/h3&gt;
&lt;p&gt;The required transformation $G$ is defined as:&lt;/p&gt;
$$
G(\mathbf{x}) =
\begin{bmatrix}
x_1 \newline
x_2 - x_1^2 - 1 \newline
\end{bmatrix}
$$&lt;p&gt;We implement this in the &lt;code&gt;_forward&lt;/code&gt; function below&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;y_0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_tail&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;y_0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_tail&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We can now use this to generate samples from $p_{Y}(\mathbf{y})$.
To do this we first sample from the base distribution $p_{X}(\mathbf{x})$ by
calling &lt;code&gt;p_x.sample&lt;/code&gt;. For this illustration, we generate 1k samples, which is
specified through the &lt;code&gt;sample_shape&lt;/code&gt; argument. We then transform these samples
through $G$ by calling &lt;code&gt;_forward&lt;/code&gt; on them.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;x_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p_x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;y_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The figure below contains scatterplots of the 1,000 samples &lt;code&gt;x_samples&lt;/code&gt; (left)
and the transformed &lt;code&gt;y_samples&lt;/code&gt; (right):&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Banana and base samples"
src="https://tiao.io/posts/building-probability-distributions-with-tensorflow-probability-bijector-api/banana_base_samples.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
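&lt;p&gt;The plotting code is not shown here, but a minimal matplotlib sketch along
the following lines could produce such scatterplots (the figure size and
styling are arbitrary choices):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import matplotlib.pyplot as plt

# side-by-side scatterplots of the base samples and the transformed samples
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 4))
ax1.scatter(x_samples[:, 0], x_samples[:, 1], alpha=0.4)
ax1.set_title(&amp;#34;base samples&amp;#34;)
ax2.scatter(y_samples[:, 0], y_samples[:, 1], alpha=0.4)
ax2.set_title(&amp;#34;transformed samples&amp;#34;)
plt.show()
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;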
&lt;h3 id="instantiating-a-transformeddistribution-with-a-bijector"&gt;Instantiating a &lt;code&gt;TransformedDistribution&lt;/code&gt; with a &lt;code&gt;Bijector&lt;/code&gt;&lt;/h3&gt;
&lt;p&gt;Having specified the forward transformation and the underlying distribution, we
have now fully described the sample generation process, which is the bare
minimum necessary to define a probability distribution.&lt;/p&gt;
&lt;p&gt;The forward transformation is also the &lt;em&gt;first&lt;/em&gt; of &lt;strong&gt;three&lt;/strong&gt; operations needed to
fully specify a &lt;code&gt;Bijector&lt;/code&gt;, which can be used to instantiate a
&lt;code&gt;TransformedDistribution&lt;/code&gt; that encapsulates the banana-shaped distribution.&lt;/p&gt;
&lt;h4 id="creating-a-bijector"&gt;Creating a &lt;code&gt;Bijector&lt;/code&gt;&lt;/h4&gt;
&lt;p&gt;First, let&amp;rsquo;s subclass &lt;code&gt;Bijector&lt;/code&gt; to define the &lt;code&gt;Banana&lt;/code&gt; bijector and implement
the forward transformation as an instance method:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Banana&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tfd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bijectors&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Bijector&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;banana&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Banana&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inverse_min_event_ndims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;y_0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_tail&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;y_0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_tail&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Note that we need to specify either &lt;code&gt;forward_min_event_ndims&lt;/code&gt; or
&lt;code&gt;inverse_min_event_ndims&lt;/code&gt;, the minimum number of dimensions the forward or
inverse transformation operates on (the two can sometimes differ).
In our example, both the forward and inverse transformations operate on vectors
(rank-1 tensors), so we set &lt;code&gt;inverse_min_event_ndims=1&lt;/code&gt;.&lt;/p&gt;
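&lt;p&gt;To make the event-dimension semantics concrete, here is a small sketch of the
shapes involved (the sample and batch sizes below are arbitrary illustrative
choices): only the last dimension is treated as the event being transformed,
while any leading dimensions pass through unchanged.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;bijector = Banana()

# leading dimensions [7, 1000] are batch/sample dimensions;
# the trailing dimension [2] is the event being transformed
x = tf.zeros([7, 1000, 2])
y = bijector.forward(x)

print(y.shape)  # (7, 1000, 2)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;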
&lt;p&gt;With an instance of the &lt;code&gt;Banana&lt;/code&gt; bijector, we can call the &lt;code&gt;forward&lt;/code&gt; method on
&lt;code&gt;x_samples&lt;/code&gt; to produce &lt;code&gt;y_samples&lt;/code&gt; as before:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;y_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Banana&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h4 id="instantiating-a-transformeddistribution"&gt;Instantiating a &lt;code&gt;TransformedDistribution&lt;/code&gt;&lt;/h4&gt;
&lt;p&gt;More importantly, we can now create a &lt;code&gt;TransformedDistribution&lt;/code&gt; with the base
distribution &lt;code&gt;p_x&lt;/code&gt; and an instance of the &lt;code&gt;Banana&lt;/code&gt; bijector:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;p_y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;TransformedDistribution&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;distribution&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bijector&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;Banana&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This now allows us to sample directly from &lt;code&gt;p_y&lt;/code&gt;, just as we can with &lt;code&gt;p_x&lt;/code&gt;
and any other TensorFlow Probability &lt;code&gt;Distribution&lt;/code&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;y_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p_y&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Neat!&lt;/p&gt;
&lt;h3 id="probability-density-function"&gt;Probability Density Function&lt;/h3&gt;
&lt;p&gt;Although we can now sample from this distribution, we have yet to define the
operations necessary to evaluate its probability density function&amp;mdash;the
remaining &lt;em&gt;two&lt;/em&gt; of &lt;strong&gt;three&lt;/strong&gt; operations needed to fully specify a &lt;code&gt;Bijector&lt;/code&gt;.&lt;/p&gt;
&lt;p&gt;Indeed, calling &lt;code&gt;p_y.prob&lt;/code&gt; at this stage would simply raise a
&lt;code&gt;NotImplementedError&lt;/code&gt; exception. So what else do we need to define?&lt;/p&gt;
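&lt;p&gt;To see this concretely, here is a minimal sketch using the partially
specified bijector from above:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;# the Banana bijector so far only implements _forward
p_y = tfd.TransformedDistribution(distribution=p_x, bijector=Banana())

try:
    p_y.prob(y_samples)
except NotImplementedError:
    print(&amp;#34;inverse and log det Jacobian are not yet implemented&amp;#34;)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;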
&lt;p&gt;Recall the probability density of $p_{Y}(\mathbf{y})$ is given by:&lt;/p&gt;
$$
p_{Y}(\mathbf{y}) = p_{X}(G^{-1}(\mathbf{y})) \mathrm{det}
\left ( \frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) \right )
$$&lt;p&gt;Hence, we need to specify the inverse transformation $G^{-1}(\mathbf{y})$ and its
Jacobian determinant
$\mathrm{det} \left ( \frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) \right )$.&lt;/p&gt;
&lt;p&gt;For numerical stability, the &lt;code&gt;Bijector&lt;/code&gt; API requires that this be defined in
log-space. Hence, it is useful to recall that the forward and inverse log
determinant Jacobians differ only in their signs&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;,&lt;/p&gt;
$$
\begin{align}
\log \mathrm{det} \left ( \frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) \right )
&amp; = - \log \mathrm{det} \left ( \frac{\partial}{\partial\mathbf{x}} G(\mathbf{x}) \right ),
\end{align}
$$&lt;p&gt;which gives us the option of implementing either (or both).
However, do note the following from the official
API docs:&lt;/p&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;Generally its preferable to directly implement the inverse Jacobian
determinant. This should have superior numerical stability and will often share
subgraphs with the &lt;code&gt;_inverse&lt;/code&gt; implementation.&lt;/p&gt;
&lt;/blockquote&gt;
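&lt;p&gt;In other words, either one determines the other by negation; as a sketch
(this is not part of the final class below, which implements the inverse
version directly), a forward log det Jacobian could in principle be derived
from an inverse one like so:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;def _forward_log_det_jacobian(self, x):
    # hypothetical sketch: negate the inverse log det Jacobian,
    # evaluated at y = G(x)
    return -self._inverse_log_det_jacobian(self._forward(x))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;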
&lt;h3 id="inverse-transformation"&gt;Inverse Transformation&lt;/h3&gt;
&lt;p&gt;So let&amp;rsquo;s implement the inverse transform $G^{-1}$, which is given by:&lt;/p&gt;
$$
G^{-1}(\mathbf{y}) =
\begin{bmatrix}
y_1 \newline
y_2 + y_1^2 + 1 \newline
\end{bmatrix}
$$&lt;p&gt;We define this in the &lt;code&gt;_inverse&lt;/code&gt; function below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_inverse&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;x_0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_tail&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;x_0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_tail&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id="jacobian-determinant"&gt;Jacobian determinant&lt;/h3&gt;
&lt;p&gt;Now we compute the log determinant of the Jacobian of the &lt;em&gt;inverse&lt;/em&gt;
transformation.
In this simple example, the transformation is &lt;em&gt;volume-preserving&lt;/em&gt;, meaning its
Jacobian determinant is equal to 1.&lt;/p&gt;
&lt;p&gt;This is easy to verify:&lt;/p&gt;
$$
\begin{align}
\mathrm{det} \left ( \frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) \right )
&amp; = \mathrm{det}
\begin{pmatrix}
\frac{\partial}{\partial y_1} y_1 &amp; \frac{\partial}{\partial y_2} y_1 \newline
\frac{\partial}{\partial y_1} \left( y_2 + y_1^2 + 1 \right) &amp; \frac{\partial}{\partial y_2} \left( y_2 + y_1^2 + 1 \right) \newline
\end{pmatrix} \newline
&amp; = \mathrm{det}
\begin{pmatrix}
1 &amp; 0 \newline
2 y_1 &amp; 1 \newline
\end{pmatrix}
= 1
\end{align}
$$&lt;p&gt;Hence, the log determinant Jacobian is given by zeros with the shape of the
input &lt;code&gt;y&lt;/code&gt;, excluding its last &lt;code&gt;inverse_min_event_ndims=1&lt;/code&gt; dimensions:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_inverse_log_det_jacobian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Since the log determinant Jacobian is constant, i.e. independent of the input,
we can simply return a scalar and set the flag &lt;code&gt;is_constant_jacobian=True&lt;/code&gt;&lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt;,
and the &lt;code&gt;Bijector&lt;/code&gt; class will handle the necessary shape inference for us.&lt;/p&gt;
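&lt;p&gt;As a quick numerical sanity check, TensorFlow&amp;rsquo;s automatic differentiation
can confirm the unit determinant on an arbitrary batch of two-dimensional
inputs (a sketch; the batch size here is incidental):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;# compute the Jacobian of the forward transformation with autodiff
x = tf.random.normal([5, 2])

with tf.GradientTape() as tape:
    tape.watch(x)
    y = _forward(x)

jacobian = tape.batch_jacobian(y, x)  # shape (5, 2, 2)
print(tf.linalg.det(jacobian))        # approximately 1.0 in every entry
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;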
&lt;p&gt;Putting it all together in the &lt;code&gt;Banana&lt;/code&gt; bijector subclass, we have:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Banana&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tfd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bijectors&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Bijector&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;banana&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Banana&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inverse_min_event_ndims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;is_constant_jacobian&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;y_0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_tail&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;y_0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_tail&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_inverse&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;x_0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_tail&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;x_0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_tail&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_inverse_log_det_jacobian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Finally, we can instantiate the distribution &lt;code&gt;p_y&lt;/code&gt; by calling
&lt;code&gt;tfd.TransformedDistribution&lt;/code&gt; as we did before and, &lt;em&gt;et voilà&lt;/em&gt;,
we can now simply call &lt;code&gt;p_y.prob&lt;/code&gt; to evaluate the probability density function.&lt;/p&gt;
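&lt;p&gt;Concretely, as a minimal sketch (the variable name is illustrative):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&amp;gt;&amp;gt;&amp;gt; p_y = tfd.TransformedDistribution(distribution=p_x, bijector=Banana())
&amp;gt;&amp;gt;&amp;gt; densities = p_y.prob(y_samples)  # no longer raises NotImplementedError
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;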
&lt;p&gt;Evaluating this on the same uniformly-spaced grid as before yields the following
equiprobability contour plot:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Banana density"
src="https://tiao.io/posts/building-probability-distributions-with-tensorflow-probability-bijector-api/banana_density.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
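&lt;p&gt;As with the scatterplots, the plotting code is not shown here; a rough sketch
of the grid evaluation might look like the following, where the grid bounds and
resolution are arbitrary illustrative choices:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import matplotlib.pyplot as plt

# evaluate the density on a uniformly-spaced grid and draw contours
y1, y2 = np.meshgrid(np.linspace(-4.0, 4.0, 200),
                     np.linspace(-8.0, 2.0, 200))
grid = np.stack([y1, y2], axis=-1).astype(np.float32)  # shape (200, 200, 2)

density = p_y.prob(grid)  # leading grid dimensions are treated as batch dimensions

plt.contour(y1, y2, density)
plt.show()
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;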
&lt;h4 id="inline-bijector"&gt;Inline Bijector&lt;/h4&gt;
&lt;p&gt;Before we conclude, we note that instead of creating a subclass, one can also
opt for a more lightweight and functional approach by creating an
&lt;code&gt;Inline&lt;/code&gt; bijector:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;banana&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bijectors&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Inline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;forward_fn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;_forward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inverse_fn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;_inverse&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inverse_log_det_jacobian_fn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;_inverse_log_det_jacobian&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inverse_min_event_ndims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;is_constant_jacobian&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;p_y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;TransformedDistribution&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;distribution&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bijector&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;banana&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h1 id="summary"&gt;Summary&lt;/h1&gt;
&lt;p&gt;In this post, we showed that, using diffeomorphisms (mappings that are
differentiable and invertible), it is possible to transform standard distributions
into interesting and complicated distributions, while still being able to
compute their densities analytically.&lt;/p&gt;
&lt;p&gt;The &lt;code&gt;Bijector&lt;/code&gt; API provides an interface that encapsulates the basic properties
of a diffeomorphism needed to transform a distribution. These are: the
forward transform itself, its inverse, and the determinant of their Jacobians.&lt;/p&gt;
&lt;p&gt;Using this, &lt;code&gt;TransformedDistribution&lt;/code&gt; &lt;em&gt;automatically&lt;/em&gt; implements perhaps the two
most important methods of a probability distribution: sampling (&lt;code&gt;sample&lt;/code&gt;), and
density evaluation (&lt;code&gt;prob&lt;/code&gt;).&lt;/p&gt;
&lt;p&gt;Needless to say, this is a very powerful combination.
Through the &lt;code&gt;Bijector&lt;/code&gt; API, the number of possible distributions that can be
implemented and used directly with other functionalities in the TensorFlow
Probability ecosystem effectively becomes &lt;em&gt;endless&lt;/em&gt;.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;Cite as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{tiao2018bijector,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = &amp;#34;{B}uilding {P}robability {D}istributions with the {T}ensor{F}low {P}robability {B}ijector {API}&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = &amp;#34;Tiao, Louis C&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; journal = &amp;#34;tiao.io&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = &amp;#34;2018&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = &amp;#34;https://tiao.io/post/building-probability-distributions-with-tensorflow-probability-bijector-api/&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h2 id="links--resources"&gt;Links &amp;amp; Resources&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Paper: see footnote&lt;sup id="fnref1:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;Dillon, J.V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., Patton, B., Alemi, A., Hoffman, M. and Saurous, R.A., 2017. &lt;em&gt;TensorFlow Distributions.&lt;/em&gt;
arXiv preprint arXiv:1711.10604.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;Haario, H., Saksman, E., &amp;amp; Tamminen, J. (1999).
Adaptive Proposal Distribution for Random Walk Metropolis Algorithm. &lt;em&gt;Computational Statistics&lt;/em&gt;, 14(3), 375-396.&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;for the transformation to be a diffeomorphism, it also needs to be &lt;em&gt;smooth&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;we implement this for the general case of $K \geq 2$ dimensional inputs since this actually turns out to be easier and cleaner (a phenomenon known as
the inventor&amp;rsquo;s paradox).&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;this is a straightforward consequence of the inverse function theorem,
which says the matrix inverse of the Jacobian of $G$ is the Jacobian of
its inverse $G^{-1}$,
&lt;/p&gt;
$$
\frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) =
\left ( \frac{\partial}{\partial\mathbf{x}} G(\mathbf{x}) \right )^{-1}
$$&lt;p&gt;
Taking the determinant of both sides, we get:
&lt;/p&gt;
$$
\begin{align}
\mathrm{det} \left ( \frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) \right )
&amp; = \mathrm{det} \left ( \left ( \frac{\partial}{\partial\mathbf{x}} G(\mathbf{x}) \right )^{-1} \right ) \newline
&amp; = \mathrm{det} \left ( \frac{\partial}{\partial\mathbf{x}} G(\mathbf{x}) \right )^{-1}
\end{align}
$$&lt;p&gt;
as required.&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;See the description of the &lt;code&gt;is_constant_jacobian&lt;/code&gt;
argument for further details.&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item></channel></rss>