<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Machine Learning |</title><link>https://tiao.io/tags/machine-learning/</link><atom:link href="https://tiao.io/tags/machine-learning/index.xml" rel="self" type="application/rss+xml"/><description>Machine Learning</description><generator>HugoBlox Kit (https://hugoblox.com)</generator><language>en-us</language><lastBuildDate>Sun, 01 Feb 2026 00:00:00 +0000</lastBuildDate><image><url>https://tiao.io/media/icon_hu_9c2a75fde2335590.png</url><title>Machine Learning</title><link>https://tiao.io/tags/machine-learning/</link></image><item><title>Empirical Gaussian Processes</title><link>https://tiao.io/publications/empirical-gaussian-processes/</link><pubDate>Sun, 01 Feb 2026 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/empirical-gaussian-processes/</guid><description/></item><item><title>Ax: A Platform for Adaptive Experimentation</title><link>https://tiao.io/publications/ax-platform/</link><pubDate>Mon, 01 Sep 2025 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/ax-platform/</guid><description/></item><item><title>Probabilistic Machine Learning in the Age of Deep Learning: New Perspectives for Gaussian Processes, Bayesian Optimization and Beyond (PhD Thesis)</title><link>https://tiao.io/publications/phd-thesis/</link><pubDate>Fri, 01 Sep 2023 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/phd-thesis/</guid><description>&lt;p&gt;The full text is available as a single PDF file &lt;a href="phd-thesis-louis-tiao.pdf" target="_blank" rel="noopener"&gt;
&lt;span class="inline-block pr-1"&gt;
&lt;svg style="height: 1em; transform: translateY(0.1em);" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M3 16.5v2.25A2.25 2.25 0 0 0 5.25 21h13.5A2.25 2.25 0 0 0 21 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3"/&gt;&lt;/svg&gt;
&lt;/span&gt;&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;You can also find a list of contents and PDFs corresponding to each individual chapter below:&lt;/p&gt;
&lt;h3 id="table-of-contents"&gt;Table of Contents&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;Chapter 1: Introduction &lt;a href="contents/1 Introduction.pdf" target="_blank" rel="noopener"&gt;
&lt;span class="inline-block pr-1"&gt;
&lt;svg style="height: 1em; transform: translateY(0.1em);" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M3 16.5v2.25A2.25 2.25 0 0 0 5.25 21h13.5A2.25 2.25 0 0 0 21 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3"/&gt;&lt;/svg&gt;
&lt;/span&gt;&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Chapter 2: Background &lt;a href="contents/2 Background.pdf" target="_blank" rel="noopener"&gt;
&lt;span class="inline-block pr-1"&gt;
&lt;svg style="height: 1em; transform: translateY(0.1em);" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M3 16.5v2.25A2.25 2.25 0 0 0 5.25 21h13.5A2.25 2.25 0 0 0 21 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3"/&gt;&lt;/svg&gt;
&lt;/span&gt;&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Chapter 3: Orthogonally-Decoupled Sparse Gaussian Processes with Spherical Neural Network Activation Features &lt;a href="contents/3 Orthogonally-Decoupled Sparse Gaussian Processes with Spherical Neural Network Activation Features.pdf" target="_blank" rel="noopener"&gt;
&lt;span class="inline-block pr-1"&gt;
&lt;svg style="height: 1em; transform: translateY(0.1em);" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M3 16.5v2.25A2.25 2.25 0 0 0 5.25 21h13.5A2.25 2.25 0 0 0 21 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3"/&gt;&lt;/svg&gt;
&lt;/span&gt;&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Chapter 4: Cycle-Consistent Generative Adversarial Networks as a Bayesian Approximation &lt;a href="contents/4 Cycle-Consistent Generative Adversarial Networks as a Bayesian Approximation.pdf" target="_blank" rel="noopener"&gt;
&lt;span class="inline-block pr-1"&gt;
&lt;svg style="height: 1em; transform: translateY(0.1em);" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M3 16.5v2.25A2.25 2.25 0 0 0 5.25 21h13.5A2.25 2.25 0 0 0 21 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3"/&gt;&lt;/svg&gt;
&lt;/span&gt;&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Chapter 5: Bayesian Optimisation by Classification with Deep Learning and Beyond &lt;a href="contents/5 Bayesian Optimisation by Classification with Deep Learning and Beyond.pdf" target="_blank" rel="noopener"&gt;
&lt;span class="inline-block pr-1"&gt;
&lt;svg style="height: 1em; transform: translateY(0.1em);" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M3 16.5v2.25A2.25 2.25 0 0 0 5.25 21h13.5A2.25 2.25 0 0 0 21 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3"/&gt;&lt;/svg&gt;
&lt;/span&gt;&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Chapter 6: Conclusion &lt;a href="contents/6 Conclusion.pdf" target="_blank" rel="noopener"&gt;
&lt;span class="inline-block pr-1"&gt;
&lt;svg style="height: 1em; transform: translateY(0.1em);" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M3 16.5v2.25A2.25 2.25 0 0 0 5.25 21h13.5A2.25 2.25 0 0 0 21 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3"/&gt;&lt;/svg&gt;
&lt;/span&gt;&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Appendix A: Numerical Methods for Improved Decoupled Sampling of Gaussian Processes &lt;a href="contents/A Numerical Methods for Improved Decoupled Sampling of Gaussian Processes.pdf" target="_blank" rel="noopener"&gt;
&lt;span class="inline-block pr-1"&gt;
&lt;svg style="height: 1em; transform: translateY(0.1em);" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M3 16.5v2.25A2.25 2.25 0 0 0 5.25 21h13.5A2.25 2.25 0 0 0 21 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3"/&gt;&lt;/svg&gt;
&lt;/span&gt;&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Bibliography &lt;a href="contents/Bibliography.pdf" target="_blank" rel="noopener"&gt;
&lt;span class="inline-block pr-1"&gt;
&lt;svg style="height: 1em; transform: translateY(0.1em);" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M3 16.5v2.25A2.25 2.25 0 0 0 5.25 21h13.5A2.25 2.25 0 0 0 21 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3"/&gt;&lt;/svg&gt;
&lt;/span&gt;&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Please find &lt;em&gt;Chapter 1: Introduction&lt;/em&gt; reproduced in full below:&lt;/p&gt;
&lt;h3 id="introduction"&gt;Introduction&lt;/h3&gt;
&lt;p&gt;Artificial intelligence (AI) stands poised to be among the most disruptive technologies of our era. The breakneck pace of recent AI advancements has been spearheaded by machine learning (ML), particularly the resurgence of &lt;em&gt;deep learning&lt;/em&gt;. Deep learning is as old as the first general-purpose electronic computer, with roots tracing back to the 1940s and ’50s. Its revival, beginning in the early 2010s, was catalysed by a series of breakthroughs that shattered previously perceived limitations and captivated the collective imagination. These breakthroughs span various domains, including computer vision, speech recognition, natural language processing, protein folding, generative art and artificial creativity, as well as reinforcement learning for robotics control and for achieving superhuman-level gameplay.&lt;/p&gt;
&lt;p&gt;Nevertheless, it is crucial to view these developments as means to an ultimate end rather than an end in themselves. Arguably, the true pinnacle of AI’s capabilities lies in optimal &lt;em&gt;decision-making&lt;/em&gt;, whether that entails offering analyses and insights to aid humans in making better decisions or completely automating the decision-making process altogether. Practically any task directed towards a well-defined objective can be boiled down to a cascade of decisions. At a fundamental level, operating a vehicle involves a continuous stream of decisions involving accelerating, braking, and turning. Financial trading revolves around decisions to buy, sell, or hold various assets. Even complex engineering tasks, such as designing an aerofoil, involve a sequence of decisions about adjusting design variables to achieve desirable aerodynamic characteristics.&lt;/p&gt;
&lt;p&gt;Yet, the intricacies of decision-making surpass what any single advancement in deep learning can address. While convolutional neural networks (CNNs) can facilitate object detection tasks in autonomous vehicles, recurrent neural networks (RNNs) can aid in forecasting market dynamics for systematic trading, and physics-informed NNs can assist in predicting aerodynamic effects, it remains the case that no target or quantity of interest can be entirely known or predictable (indeed, if they were, the pursuit of predictive modelling and ML would be superfluous). Instead, predictions often prove unreliable, or at best, &lt;em&gt;uncertain&lt;/em&gt;, due to the limitations of our knowledge and the complexity and variability inherent in the underlying real-world processes. The impressive power of deep learning models often overshadows their ignorance of the limits of their own knowledge and the extent of uncertainty in their predictions. When these predictions are integrated into a sequential decision-making framework, such uncertainty can amplify, compound, and lead to catastrophic consequences. In the context of aeronautical engineering, this could result in inefficient designs; in quantitative finance, it can lead to devastating capital losses; and in autonomous driving, it can even cost lives.&lt;/p&gt;
&lt;h4 id="probabilistic-machine-learning"&gt;Probabilistic Machine Learning&lt;/h4&gt;
&lt;p&gt;Grounded in the laws of probability and Bayesian statistics, &lt;em&gt;probabilistic&lt;/em&gt; ML provides a consistent framework for systematically reasoning about the unknown. The probabilistic approach to ML acknowledges that the real world is fraught with uncertainty and embraces this uncertainty as an inherent part of decision-making. Unlike traditional methods, including those of deep learning, it recognises model predictions not as absolute truths that can be represented as single &lt;em&gt;point estimates&lt;/em&gt; produced from a deterministic mapping, but as full &lt;em&gt;probability distributions&lt;/em&gt; that capture the potential outcomes of a random variable as it propagates through some underlying data-generating process. In a &lt;em&gt;probabilistic model&lt;/em&gt;, all quantities are treated as random variables governed by probability distributions – the data are treated as observed variables, which are influenced by some underlying hidden variables, e.g., the model parameters. A prior distribution is used to express reasonable values for these hidden variables and to eliminate implausible ones. The relationship between observed and hidden variables is described using the likelihood, and the process of Bayesian inference amounts to calculating, using basic laws of probability, a posterior distribution over the hidden factors conditioned on the observed data, which can be seen as a refinement of the prior beliefs in light of new evidence. While the posterior distribution can be useful in and of itself, its primary role lies in facilitating subsequent prediction and decision-making by providing full probability distributions over predicted outcomes. This capability allows the decision-maker to assess the range of possible scenarios and their associated probabilities, enabling a more nuanced understanding of uncertainty and risk, which is indispensable in complex, dynamic environments where the repercussions of incorrect decisions can be severe.
In essence, probabilistic ML equips autonomous decision-making systems with a probabilistic worldview, enabling them to navigate ambiguity and make sound decisions in the face of imperfect information.&lt;/p&gt;
&lt;h4 id="probabilistic-ml-vs-deep-learning"&gt;Probabilistic ML vs. Deep Learning&lt;/h4&gt;
&lt;p&gt;While deep learning has dominated recent AI advances, probabilistic ML remains as important as ever and continues to offer valuable tools for addressing AI challenges that cannot be fully resolved by deep learning alone. Although both approaches can be combined to create hybrid methods that leverage their respective strengths, some defining characteristics have traditionally set deep learning apart from probabilistic ML. Perhaps most notably, probabilistic ML approaches can achieve remarkable predictive performance even when data is scarce. In contrast, deep learning models tend to be data-intensive by nature, often demanding datasets of a scale proportional to their size (i.e., their parameter count), which has seen explosive growth in recent years. That said, inference in many probabilistic models poses computational problems that are difficult to scale. On the other hand, deep learning approaches have excelled in scalability, a key factor contributing to their widespread success. This scalability is bolstered by their compatibility with various speed-enhancing mechanisms such as stochastic optimisation, specialised hardware accelerators (GPUs and TPUs), as well as distributed and/or cloud-based computing infrastructure. To bridge this gap, substantial research effort has been devoted to enabling probabilistic ML to benefit from these advantages through optimisation-based approximations to Bayesian inference.&lt;/p&gt;
&lt;p&gt;Moreover, as mentioned earlier, these paradigms are by no means mutually exclusive. Indeed, it is often possible to directly extend existing models with a Bayesian treatment of their parameters, adding a layer of probabilistic reasoning to the model, and allowing it to not only make predictions but also estimate the uncertainty associated with those predictions. An excellent example is the Bayesian neural network (BNN), which treats the weights as hidden variables and leverages posterior inference to provide predictions while estimating associated uncertainties, delivering a more robust and principled approach to deep learning.&lt;/p&gt;
&lt;p&gt;The Bayesian formalism naturally gives rise to many popular methods and paradigms, often in the form of point estimates or other kinds of approximations. The quintessential example of this is found in linear regression, in particular, in ridge and lasso regression, which correspond variously to maximum &lt;em&gt;a posteriori&lt;/em&gt; (MAP) estimates in Bayesian linear regression (BLR) models with prior distributions possessing different sparsity-inducing characteristics. More broadly, mitigations against over-fitting tend to arise organically in Bayesian methods, which is why they are frequently characterised as being fundamentally more robust against over-fitting. Likewise, the once &lt;em&gt;à la mode&lt;/em&gt; support vector machines (SVMs) can be seen as MAP estimates for a class of nonparametric Bayesian models, dropout in NNs can be seen as a variational approximation to exact inference in BNNs, and unsupervised learning methods such as factor analysis (FA) and principal component analysis (PCA) are instances of a class of latent variable models (LVMs) known as linear-Gaussian factor models, to name just a few examples. Time and again, classical approaches have not only benefitted from being viewed through the Bayesian perspective but have also been enriched and redefined by the depth of insights this framework provides.&lt;/p&gt;
&lt;h3 id="thesis-goals"&gt;Thesis Goals&lt;/h3&gt;
&lt;p&gt;The over-arching goal of this thesis is to continue advancing the integration and cross-pollination between deep learning and probabilistic ML. We aim to further the interplay between these two fields, both by incorporating probabilistic interpretations and uncertainty quantification into popular deep learning frameworks, and by leveraging the representational power of deep NNs to improve established Bayesian methods. This dual-pronged approach provides fresh perspectives and taps the complementary strengths of both paradigms, advancing the foundations of AI and facilitating the development of more capable and dependable decision support frameworks. Ultimately, we strive to unlock the potential of deep learning within high-impact probabilistic ML methodologies, and to lend useful Bayesian perspectives on current deep learning techniques.&lt;/p&gt;
&lt;h4 id="gaussian-process-models"&gt;Gaussian Process Models&lt;/h4&gt;
&lt;p&gt;Arguably, no family of probabilistic models embodies the ethos of probabilistic ML and illustrates its nuances and parallels with deep learning quite like Gaussian processes (GPs). Accordingly, they shall occupy a prominent place in our thesis. In particular, GPs stand out as the ideal choice when dealing with limited data, offer the flexibility to encode prior beliefs through the covariance function, and provide predictive uncertainty estimates with a fine calibration that is second to none. Conversely, they are challenging to scale to large datasets, a limitation that has spurred extensive research and development efforts. Furthermore, in contrast to deep learning models, which are often lauded for their ability to automatically uncover valuable patterns and features in data, GPs have at times been dismissed as unsophisticated smoothing mechanisms. Despite these apparent disparities, GPs are intricately connected to NNs in numerous ways. Among these, one of the most classical and well-known relationships is the convergence of single-layer NNs with randomly initialised weights toward GPs in the infinite-width limit. Similar links have also been identified between GPs and infinitely wide &lt;em&gt;deep&lt;/em&gt; NNs.&lt;/p&gt;
&lt;p&gt;In an effort to elevate the representational capabilities of GPs to a level comparable with deep NNs, deep GPs (DGPs) stack together multiple layers of GPs. Additional efforts to construct efficient sparse GP approximations have leveraged the advantageous properties of computations on the hypersphere, which has led to DGP models in which the propagation of posterior predictive means is equivalent to a forward pass through a deep NN. Notably, as a side effect, this model effectively provides uncertainty estimates for deep NNs through its predictive variance. Among the contributions of our thesis is the further development of this framework, integrating cutting-edge techniques to address some of its practical limitations, thereby narrowing the performance gap between GPs and deep NNs.&lt;/p&gt;
&lt;p&gt;Probabilistic models serve a crucial role as decision support tools: they routinely aid scientific discovery in fields such as physics and astronomy, and guide advancements in areas of medicine and healthcare encompassing bioinformatics, epidemiology, and medical diagnosis. Beyond that, these models have wide-ranging applications in economics, econometrics, and the social sciences. Moreover, they are indispensable in various engineering disciplines, such as robotics and environmental engineering. Among the many probabilistic models, GPs stand out as a powerful driving force behind a number of important sequential decision-making frameworks, including active learning and reinforcement learning, as well as the broader area of probabilistic numerics. Notably, Bayesian optimisation (BO) is one major area that relies heavily on GPs and will feature extensively in our thesis.&lt;/p&gt;
&lt;h4 id="bayesian-optimisation"&gt;Bayesian Optimisation&lt;/h4&gt;
&lt;p&gt;BO is a powerful methodology dedicated to the global optimisation of complex and resource-intensive objective functions. In contrast to classical optimisation methods, BO excels even when few assumptions or guarantees can be made about the function. Such functions may be non-convex, possess no gradients, lack a well-defined mathematical form, and be observable only indirectly through noisy measurements.&lt;/p&gt;
&lt;p&gt;At its core, BO is a sequential decision-making algorithm. It relies on observations from past function evaluations to determine the next candidate location for evaluation in pursuit of optimal solutions. BO leverages a probabilistic model, often a GP, to represent its knowledge and beliefs about the unknown function. This model is continuously updated with the acquisition of each new observation, enabling the algorithm to adapt its behaviour and make sound decisions based on the evolving information.&lt;/p&gt;
&lt;p&gt;BO effectively manages uncertainty inherent in such sequential decision-making processes by making use of the probabilistic model to the fullest, harnessing the entire predictive distribution, particularly, the predictive uncertainty, to select promising candidate solutions that bring the most value to the optimisation process. This generally consists not merely of those most likely to optimise the objective function (i.e., &lt;em&gt;exploiting&lt;/em&gt; that which is known), but also those likely to reveal the most knowledge and information about the function itself (i.e., &lt;em&gt;exploring&lt;/em&gt; that which remains unknown).&lt;/p&gt;
&lt;p&gt;This pronounced emphasis on well-calibrated uncertainty distinguishes BO as one of the standout “killer apps” for GPs and a jewel in the crown of probabilistic ML applications. In practice, BO has proven instrumental across science, engineering, and industry, where efficiency and cost-effectiveness are paramount. Its applications include protein engineering, material discovery, experimental physics (e.g., experiments involving ultra-cold atoms and free-electron lasers), environmental monitoring (sensor placement), and the design of aerofoils, integrated circuits, broadband high-efficiency power amplifiers, and fast-charging protocols for lithium-ion batteries. Notably, it has played a crucial role in automating the hyperparameter tuning of various ML models, especially deep learning models, thus representing yet another way in which probabilistic ML has contributed to the advancement of deep learning.&lt;/p&gt;
&lt;p&gt;However, GPs are not universally suitable for all BO problem scenarios. They are most effective when dealing with smooth, stationary functions with homoscedastic noise and a relatively modest input dimensionality. Additionally, GPs are easiest to work with for functions with a single output and purely continuous inputs. While a surprisingly wide array of real-world challenges satisfy these conditions, many high-impact problems clearly fall outside of this scope. Examples include gene and protein design, which involves sequential inputs; neural architecture search (NAS), which involves structured inputs with intricate conditional dependencies; and automotive safety engineering, which involves numerous constraints and multiple objectives. This is not to say that GPs cannot be extended to such challenging scenarios. However, such extensions almost always come at a cost. Consequently, it makes sense to appeal to alternative modelling paradigms more naturally suited to specific tasks, e.g., employing random forests (RFs) to handle discrete and structured inputs, or deep NNs for capturing nonstationary behaviour and dealing with multiple objectives. A major contribution of this thesis is the introduction of a new formulation of BO that seamlessly accommodates virtually any modelling paradigm, including deep learning, without any compromise.&lt;/p&gt;
&lt;h3 id="thesis-overview"&gt;Thesis Overview&lt;/h3&gt;
&lt;p&gt;The core contributions of our thesis are summarised as follows:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;&lt;span id="item:contrib-orthogonal-sparse-spherical-gp" label="item:contrib-orthogonal-sparse-spherical-gp"&gt;&lt;/span&gt; We improve upon the framework for sparse hyperspherical GP approximations that employ nonlinear activations as inter-domain inducing features. This framework serves as a bridge between GPs and NNs, with the posterior predictive mean taking the form of a single-layer feedforward NN. Our thesis examines some practical issues associated with this approach and proposes an extension that takes advantage of the orthogonal decoupling of GPs to mitigate them. In particular, we introduce spherical inter-domain features to construct more flexible data-dependent basis functions for both the principal and orthogonal components of the GP approximation. We demonstrate that incorporating orthogonal inducing variables under this framework not only alleviates these shortcomings but also offers superior scalability compared to alternative strategies.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;span id="item:contrib-cycle-bayes" label="item:contrib-cycle-bayes"&gt;&lt;/span&gt; We provide a probabilistic perspective on cycle-consistent adversarial networks (CycleGANs), a cutting-edge deep generative model for style transfer and image-to-image translation. Specifically, we frame the problem of learning cross-domain correspondences without paired data as Bayesian inference in a latent variable model (LVM), in which the goal is to uncover the hidden representations of entities from one domain as entities in another. First, we introduce implicit LVMs, which allow flexible prior specification over latent representations as implicit distributions. Next, we develop a new variational inference (VI) framework that minimises a symmetrised statistical divergence between the variational and true joint distributions. Finally, we show that CycleGANs emerge as a closely-related variant of our framework, providing a useful interpretation as a Bayesian approximation.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;span id="item:contrib-bore" label="item:contrib-bore"&gt;&lt;/span&gt; We introduce a model-agnostic formulation of BO based on classification. Building on the established links between class-probability estimation (CPE), density-ratio estimation (DRE), and the improvement-based acquisition functions, we reformulate the acquisition function as a binary classifier over candidate solutions. This approach eliminates the need for an explicit probabilistic model of the objective function and casts aside the limitations of tractability constraints. As a result, our model-agnostic BO approach substantially broadens its applicability across diverse problem scenarios, accommodating flexible and scalable modelling paradigms such as deep learning without necessitating approximations or sacrificing expressive and representational capacity.&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;Accordingly, our thesis is organised as follows:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;Chapter 2 (Background) lays the necessary groundwork for our thesis. We begin by outlining the fundamental principles of probability and Bayesian statistics, which form the basis of probabilistic ML. Additionally, we introduce the widely-adopted method of approximate Bayesian inference known as VI. Our discussion underscores the central role played by statistical divergences, prompting us to delve into a larger family of divergences and motivating our discussion of DRE. With a solid foundation in place, we shift our focus to GPs, providing an introductory overview and highlighting the most commonly-used sparse approximations. Finally, we conclude this background chapter by introducing the basic concepts behind BO.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Chapter 3 (Orthogonally-Decoupled Sparse GPs with Spherical Inducing Features) examines orthogonally-decoupled sparse GPs with spherical NN activation features, as summarised in the corresponding item above.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Chapter 4 (Cycle-Consistent Adversarial Learning as Bayesian Inference) examines CycleGANs from the perspective of approximate Bayesian inference, as summarised in the corresponding item above.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Chapter 5 (Bayesian Optimisation by Density-Ratio Estimation) examines our model-agnostic approach to BO based on binary classification and DRE, as summarised in the corresponding item above.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Chapter 6 (Conclusion) brings this thesis to a close by reflecting on our main contributions and situating them in the broader landscape of probabilistic methods in ML. Finally, we conclude by presenting our outlook on the avenues for future research and development in this rapidly evolving field.&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="references"&gt;References&lt;/h3&gt;
&lt;div id="refs" class="references csl-bib-body hanging-indent" entry-spacing="0" line-spacing="2"&gt;
&lt;div id="ref-anil2023palm" class="csl-entry"&gt;
&lt;p&gt;Anil, R., Dai, A. M., Firat, O., Johnson, M., Lepikhin, D., Passos, A., Shakeri, S., Taropa, E., Bailey, P., Chen, Z., et al. (2023). PaLM 2 technical report. &lt;em&gt;arXiv Preprint arXiv:2305.10403&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-attia2020closed" class="csl-entry"&gt;
&lt;p&gt;Attia, P. M., Grover, A., Jin, N., Severson, K. A., Markov, T. M., Liao, Y.-H., Chen, M. H., Cheong, B., Perkins, N., Yang, Z., et al. (2020). Closed-loop optimization of fast-charging protocols for batteries with machine learning. &lt;em&gt;Nature&lt;/em&gt;, &lt;em&gt;578&lt;/em&gt;(7795), 397–402.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-bartholomew2011latent" class="csl-entry"&gt;
&lt;p&gt;Bartholomew, D. J., Knott, M., &amp;amp; Moustaki, I. (2011). &lt;em&gt;Latent variable models and factor analysis: A unified approach&lt;/em&gt;. John Wiley &amp;amp; Sons.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-bayes1763lii" class="csl-entry"&gt;
&lt;p&gt;Bayes, T. (1763). LII. An essay towards solving a problem in the doctrine of chances. By the late Rev. Mr. Bayes, F.R.S. Communicated by Mr. Price, in a letter to John Canton, A.M.F.R.S. &lt;em&gt;Philosophical Transactions of the Royal Society of London&lt;/em&gt;, &lt;em&gt;53&lt;/em&gt;, 370–418.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-blundell2015weight" class="csl-entry"&gt;
&lt;p&gt;Blundell, C., Cornebise, J., Kavukcuoglu, K., &amp;amp; Wierstra, D. (2015). Weight uncertainty in neural network. &lt;em&gt;International Conference on Machine Learning&lt;/em&gt;, 1613–1622.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-brochu2010tutorial" class="csl-entry"&gt;
&lt;p&gt;Brochu, E., Cora, V. M., &amp;amp; De Freitas, N. (2010). A tutorial on Bayesian optimization of expensive cost functions, with application to active user modeling and hierarchical reinforcement learning. &lt;em&gt;arXiv Preprint arXiv:1012.2599&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-brown2020language" class="csl-entry"&gt;
&lt;p&gt;Brown, T., Mann, B., Ryder, N., Subbiah, M., Kaplan, J. D., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., et al. (2020). Language models are few-shot learners. &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt;, &lt;em&gt;33&lt;/em&gt;, 1877–1901.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-chen2015bayesian" class="csl-entry"&gt;
&lt;p&gt;Chen, P., Merrick, B. M., &amp;amp; Brazil, T. J. (2015). Bayesian optimization for broadband high-efficiency power amplifier designs. &lt;em&gt;IEEE Transactions on Microwave Theory and Techniques&lt;/em&gt;, &lt;em&gt;63&lt;/em&gt;(12), 4263–4272.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-damianou2013deep" class="csl-entry"&gt;
&lt;p&gt;Damianou, A., &amp;amp; Lawrence, N. D. (2013). Deep Gaussian processes. &lt;em&gt;Artificial Intelligence and Statistics&lt;/em&gt;, 207–215.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-deisenroth2011pilco" class="csl-entry"&gt;
&lt;p&gt;Deisenroth, M., &amp;amp; Rasmussen, C. E. (2011). PILCO: A model-based and data-efficient approach to policy search. &lt;em&gt;Proceedings of the 28th International Conference on Machine Learning (ICML-11)&lt;/em&gt;, 465–472.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-duris2020bayesian" class="csl-entry"&gt;
&lt;p&gt;Duris, J., Kennedy, D., Hanuka, A., Shtalenkova, J., Edelen, A., Baxevanis, P., Egger, A., Cope, T., McIntire, M., Ermon, S., et al. (2020). Bayesian optimization of a free-electron laser. &lt;em&gt;Physical Review Letters&lt;/em&gt;, &lt;em&gt;124&lt;/em&gt;(12), 124801.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-dutordoir2020sparse" class="csl-entry"&gt;
&lt;p&gt;Dutordoir, V., Durrande, N., &amp;amp; Hensman, J. (2020). Sparse Gaussian processes with spherical harmonic features. &lt;em&gt;International Conference on Machine Learning&lt;/em&gt;, 2793–2802.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-dutordoir2021deep" class="csl-entry"&gt;
&lt;p&gt;Dutordoir, V., Hensman, J., Wilk, M. van der, Ek, C. H., Ghahramani, Z., &amp;amp; Durrande, N. (2021). Deep neural networks as point estimates for deep Gaussian processes. &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt;, &lt;em&gt;34&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-forrester2009recent" class="csl-entry"&gt;
&lt;p&gt;Forrester, A. I., &amp;amp; Keane, A. J. (2009). Recent advances in surrogate-based optimization. &lt;em&gt;Progress in Aerospace Sciences&lt;/em&gt;, &lt;em&gt;45&lt;/em&gt;(1-3), 50–79.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-gal2016dropout" class="csl-entry"&gt;
&lt;p&gt;Gal, Y., &amp;amp; Ghahramani, Z. (2016). Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. &lt;em&gt;International Conference on Machine Learning&lt;/em&gt;, 1050–1059.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-garnett_bayesoptbook_2023" class="csl-entry"&gt;
&lt;p&gt;Garnett, R. (2023). &lt;em&gt;&lt;span class="nocase"&gt;Bayesian Optimization&lt;/span&gt;&lt;/em&gt;. Cambridge University Press.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-garnett2010bayesian" class="csl-entry"&gt;
&lt;p&gt;Garnett, R., Osborne, M. A., &amp;amp; Roberts, S. J. (2010). Bayesian optimization for sensor set selection. &lt;em&gt;Proceedings of the 9th ACM/IEEE International Conference on Information Processing in Sensor Networks&lt;/em&gt;, 209–219.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-gelman2013bayesian" class="csl-entry"&gt;
&lt;p&gt;Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A., &amp;amp; Rubin, D. B. (2013). &lt;em&gt;Bayesian data analysis&lt;/em&gt;. CRC Press.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-girshick2014rich" class="csl-entry"&gt;
&lt;p&gt;Girshick, R., Donahue, J., Darrell, T., &amp;amp; Malik, J. (2014). Rich feature hierarchies for accurate object detection and semantic segmentation. &lt;em&gt;Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition&lt;/em&gt;, 580–587.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-gonzalez2015bayesian" class="csl-entry"&gt;
&lt;p&gt;Gonzalez, J., Longworth, J., James, D. C., &amp;amp; Lawrence, N. D. (2015). Bayesian optimization for synthetic gene design. &lt;em&gt;arXiv Preprint arXiv:1505.01627&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-goodfellow2014generative" class="csl-entry"&gt;
&lt;p&gt;Goodfellow, I. J., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., &amp;amp; Bengio, Y. (2014). Generative adversarial networks. &lt;em&gt;arXiv Preprint arXiv:1406.2661&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-graves2013speech" class="csl-entry"&gt;
&lt;p&gt;Graves, A., Mohamed, A., &amp;amp; Hinton, G. (2013). Speech recognition with deep recurrent neural networks. &lt;em&gt;2013 IEEE International Conference on Acoustics, Speech and Signal Processing&lt;/em&gt;, 6645–6649.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-hennig2022probabilistic" class="csl-entry"&gt;
&lt;p&gt;Hennig, P., Osborne, M. A., &amp;amp; Kersting, H. P. (2022). &lt;em&gt;Probabilistic numerics&lt;/em&gt;. Cambridge University Press.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-hie2022adaptive" class="csl-entry"&gt;
&lt;p&gt;Hie, B. L., &amp;amp; Yang, K. K. (2022). Adaptive machine learning for protein engineering. &lt;em&gt;Current Opinion in Structural Biology&lt;/em&gt;, &lt;em&gt;72&lt;/em&gt;, 145–152.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-hinton2012deep" class="csl-entry"&gt;
&lt;p&gt;Hinton, G., Deng, L., Yu, D., Dahl, G. E., Mohamed, A., Jaitly, N., Senior, A., Vanhoucke, V., Nguyen, P., Sainath, T. N., et al. (2012). Deep neural networks for acoustic modeling in speech recognition: The shared views of four research groups. &lt;em&gt;IEEE Signal Processing Magazine&lt;/em&gt;, &lt;em&gt;29&lt;/em&gt;(6), 82–97.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-ho2020denoising" class="csl-entry"&gt;
&lt;p&gt;Ho, J., Jain, A., &amp;amp; Abbeel, P. (2020). Denoising diffusion probabilistic models. &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt;, &lt;em&gt;33&lt;/em&gt;, 6840–6851.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-hoffmann2022training" class="csl-entry"&gt;
&lt;p&gt;Hoffmann, J., Borgeaud, S., Mensch, A., Buchatskaya, E., Cai, T., Rutherford, E., Casas, D. de L., Hendricks, L. A., Welbl, J., Clark, A., et al. (2022). Training compute-optimal large language models. &lt;em&gt;arXiv Preprint arXiv:2203.15556&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-houlsby2011bayesian" class="csl-entry"&gt;
&lt;p&gt;Houlsby, N., Huszár, F., Ghahramani, Z., &amp;amp; Lengyel, M. (2011). Bayesian active learning for classification and preference learning. &lt;em&gt;arXiv Preprint arXiv:1112.5745&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-jordan1998introduction" class="csl-entry"&gt;
&lt;p&gt;Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., &amp;amp; Saul, L. K. (1998). An introduction to variational methods for graphical models. &lt;em&gt;Learning in Graphical Models&lt;/em&gt;, 105–161.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-jumper2021highly" class="csl-entry"&gt;
&lt;p&gt;Jumper, J., Evans, R., Pritzel, A., Green, T., Figurnov, M., Ronneberger, O., Tunyasuvunakool, K., Bates, R., Žídek, A., Potapenko, A., et al. (2021). Highly accurate protein structure prediction with AlphaFold. &lt;em&gt;Nature&lt;/em&gt;, &lt;em&gt;596&lt;/em&gt;(7873), 583–589.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-krizhevsky2012imagenet" class="csl-entry"&gt;
&lt;p&gt;Krizhevsky, A., Sutskever, I., &amp;amp; Hinton, G. E. (2012). ImageNet classification with deep convolutional neural networks. &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt;, &lt;em&gt;25&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-lam2018advances" class="csl-entry"&gt;
&lt;p&gt;Lam, R., Poloczek, M., Frazier, P., &amp;amp; Willcox, K. E. (2018). Advances in Bayesian optimization with applications in aerospace engineering. &lt;em&gt;2018 AIAA Non-Deterministic Approaches Conference&lt;/em&gt;, 1656.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-laplace1814theorie" class="csl-entry"&gt;
&lt;p&gt;Laplace, P. S. (1814). &lt;em&gt;Théorie analytique des probabilités&lt;/em&gt;. Courcier.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-lee2017deep" class="csl-entry"&gt;
&lt;p&gt;Lee, J., Bahri, Y., Novak, R., Schoenholz, S. S., Pennington, J., &amp;amp; Sohl-Dickstein, J. (2017). Deep neural networks as Gaussian processes. &lt;em&gt;arXiv Preprint arXiv:1711.00165&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-lillicrap2015continuous" class="csl-entry"&gt;
&lt;p&gt;Lillicrap, T. P., Hunt, J. J., Pritzel, A., Heess, N., Erez, T., Tassa, Y., Silver, D., &amp;amp; Wierstra, D. (2015). Continuous control with deep reinforcement learning. &lt;em&gt;arXiv Preprint arXiv:1509.02971&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-lyu2017efficient" class="csl-entry"&gt;
&lt;p&gt;Lyu, W., Xue, P., Yang, F., Yan, C., Hong, Z., Zeng, X., &amp;amp; Zhou, D. (2017). An efficient Bayesian optimization approach for automated optimization of analog circuits. &lt;em&gt;IEEE Transactions on Circuits and Systems I: Regular Papers&lt;/em&gt;, &lt;em&gt;65&lt;/em&gt;(6), 1954–1967.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-mackay1992practical" class="csl-entry"&gt;
&lt;p&gt;MacKay, D. J. (1992). A practical Bayesian framework for backpropagation networks. &lt;em&gt;Neural Computation&lt;/em&gt;, &lt;em&gt;4&lt;/em&gt;(3), 448–472.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-mackay2003information" class="csl-entry"&gt;
&lt;p&gt;MacKay, D. J. (2003). &lt;em&gt;Information theory, inference and learning algorithms&lt;/em&gt;. Cambridge University Press.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-marchant2012bayesian" class="csl-entry"&gt;
&lt;p&gt;Marchant, R., &amp;amp; Ramos, F. (2012). Bayesian optimisation for intelligent environmental monitoring. &lt;em&gt;2012 IEEE/RSJ International Conference on Intelligent Robots and Systems&lt;/em&gt;, 2242–2249.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-matthews2018gaussian" class="csl-entry"&gt;
&lt;p&gt;Matthews, A. G. de G., Rowland, M., Hron, J., Turner, R. E., &amp;amp; Ghahramani, Z. (2018). Gaussian process behaviour in wide deep neural networks. &lt;em&gt;arXiv Preprint arXiv:1804.11271&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-mcculloch1943logical" class="csl-entry"&gt;
&lt;p&gt;McCulloch, W. S., &amp;amp; Pitts, W. (1943). A logical calculus of the ideas immanent in nervous activity. &lt;em&gt;The Bulletin of Mathematical Biophysics&lt;/em&gt;, &lt;em&gt;5&lt;/em&gt;, 115–133.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-mnih2013playing" class="csl-entry"&gt;
&lt;p&gt;Mnih, V., Kavukcuoglu, K., Silver, D., Graves, A., Antonoglou, I., Wierstra, D., &amp;amp; Riedmiller, M. (2013). Playing Atari with deep reinforcement learning. &lt;em&gt;arXiv Preprint arXiv:1312.5602&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-mnih2015human" class="csl-entry"&gt;
&lt;p&gt;Mnih, V., Kavukcuoglu, K., Silver, D., Rusu, A. A., Veness, J., Bellemare, M. G., Graves, A., Riedmiller, M., Fidjeland, A. K., Ostrovski, G., et al. (2015). Human-level control through deep reinforcement learning. &lt;em&gt;Nature&lt;/em&gt;, &lt;em&gt;518&lt;/em&gt;(7540), 529–533.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-moss2020boss" class="csl-entry"&gt;
&lt;p&gt;Moss, H. B., Beck, D., González, J., Leslie, D. S., &amp;amp; Rayson, P. (2020). BOSS: Bayesian optimization over string spaces. &lt;em&gt;arXiv Preprint arXiv:2010.00979&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-neal1995bayesian" class="csl-entry"&gt;
&lt;p&gt;Neal, R. M. (1995). &lt;em&gt;Bayesian learning for neural networks&lt;/em&gt; [PhD thesis]. University of Toronto.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-openai2023gpt" class="csl-entry"&gt;
&lt;p&gt;OpenAI. (2023). GPT-4 technical report. &lt;em&gt;arXiv Preprint arXiv:2303.08774&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-opper2000gaussian" class="csl-entry"&gt;
&lt;p&gt;Opper, M., &amp;amp; Winther, O. (2000). &lt;em&gt;Gaussian processes and SVM: Mean field results and leave-one-out&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-pearson1901liii" class="csl-entry"&gt;
&lt;p&gt;Pearson, K. (1901). LIII. On lines and planes of closest fit to systems of points in space. &lt;em&gt;The London, Edinburgh, and Dublin Philosophical Magazine and Journal of Science&lt;/em&gt;, &lt;em&gt;2&lt;/em&gt;(11), 559–572.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-rae2021scaling" class="csl-entry"&gt;
&lt;p&gt;Rae, J. W., Borgeaud, S., Cai, T., Millican, K., Hoffmann, J., Song, F., Aslanides, J., Henderson, S., Ring, R., Young, S., et al. (2021). Scaling language models: Methods, analysis &amp;amp; insights from training Gopher. &lt;em&gt;arXiv Preprint arXiv:2112.11446&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-ramesh2022hierarchical" class="csl-entry"&gt;
&lt;p&gt;Ramesh, A., Dhariwal, P., Nichol, A., Chu, C., &amp;amp; Chen, M. (2022). Hierarchical text-conditional image generation with CLIP latents. &lt;em&gt;arXiv Preprint arXiv:2204.06125&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-10.7551/mitpress/3206.001.0001" class="csl-entry"&gt;
&lt;p&gt;Rasmussen, C. E., &amp;amp; Williams, C. K. I. (2005). &lt;em&gt;&lt;span class="nocase"&gt;Gaussian Processes for Machine Learning&lt;/span&gt;&lt;/em&gt;. The MIT Press.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-redmon2016you" class="csl-entry"&gt;
&lt;p&gt;Redmon, J., Divvala, S., Girshick, R., &amp;amp; Farhadi, A. (2016). You only look once: Unified, real-time object detection. &lt;em&gt;Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition&lt;/em&gt;, 779–788.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-rombach2022high" class="csl-entry"&gt;
&lt;p&gt;Rombach, R., Blattmann, A., Lorenz, D., Esser, P., &amp;amp; Ommer, B. (2022). High-resolution image synthesis with latent diffusion models. &lt;em&gt;Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition&lt;/em&gt;, 10684–10695.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-romero2013navigating" class="csl-entry"&gt;
&lt;p&gt;Romero, P. A., Krause, A., &amp;amp; Arnold, F. H. (2013). Navigating the protein fitness landscape with Gaussian processes. &lt;em&gt;Proceedings of the National Academy of Sciences&lt;/em&gt;, &lt;em&gt;110&lt;/em&gt;(3), E193–E201.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-ronneberger2015u" class="csl-entry"&gt;
&lt;p&gt;Ronneberger, O., Fischer, P., &amp;amp; Brox, T. (2015). U-Net: Convolutional networks for biomedical image segmentation. &lt;em&gt;Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18&lt;/em&gt;, 234–241.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-rosenblatt1958perceptron" class="csl-entry"&gt;
&lt;p&gt;Rosenblatt, F. (1958). The perceptron: A probabilistic model for information storage and organization in the brain. &lt;em&gt;Psychological Review&lt;/em&gt;, &lt;em&gt;65&lt;/em&gt;(6), 386.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-roweis1999unifying" class="csl-entry"&gt;
&lt;p&gt;Roweis, S., &amp;amp; Ghahramani, Z. (1999). A unifying review of linear Gaussian models. &lt;em&gt;Neural Computation&lt;/em&gt;, &lt;em&gt;11&lt;/em&gt;(2), 305–345.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-salimbeni2018orthogonally" class="csl-entry"&gt;
&lt;p&gt;Salimbeni, H., Cheng, C.-A., Boots, B., &amp;amp; Deisenroth, M. (2018). Orthogonally decoupled variational Gaussian processes. &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt;, &lt;em&gt;31&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-seko2015prediction" class="csl-entry"&gt;
&lt;p&gt;Seko, A., Togo, A., Hayashi, H., Tsuda, K., Chaput, L., &amp;amp; Tanaka, I. (2015). Prediction of low-thermal-conductivity compounds with first-principles anharmonic lattice-dynamics calculations and Bayesian optimization. &lt;em&gt;Physical Review Letters&lt;/em&gt;, &lt;em&gt;115&lt;/em&gt;(20), 205901.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-shahriari2015taking" class="csl-entry"&gt;
&lt;p&gt;Shahriari, B., Swersky, K., Wang, Z., Adams, R. P., &amp;amp; De Freitas, N. (2015). Taking the human out of the loop: A review of Bayesian optimization. &lt;em&gt;Proceedings of the IEEE&lt;/em&gt;, &lt;em&gt;104&lt;/em&gt;(1), 148–175.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-shi2020sparse" class="csl-entry"&gt;
&lt;p&gt;Shi, J., Titsias, M., &amp;amp; Mnih, A. (2020). Sparse orthogonal variational inference for Gaussian processes. &lt;em&gt;International Conference on Artificial Intelligence and Statistics&lt;/em&gt;, 1932–1942.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-shoeybi2019megatron" class="csl-entry"&gt;
&lt;p&gt;Shoeybi, M., Patwary, M., Puri, R., LeGresley, P., Casper, J., &amp;amp; Catanzaro, B. (2019). Megatron-LM: Training multi-billion parameter language models using model parallelism. &lt;em&gt;arXiv Preprint arXiv:1909.08053&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-silver2016mastering" class="csl-entry"&gt;
&lt;p&gt;Silver, D., Huang, A., Maddison, C. J., Guez, A., Sifre, L., Van Den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., et al. (2016). Mastering the game of Go with deep neural networks and tree search. &lt;em&gt;Nature&lt;/em&gt;, &lt;em&gt;529&lt;/em&gt;(7587), 484–489.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-snoek2012practical" class="csl-entry"&gt;
&lt;p&gt;Snoek, J., Larochelle, H., &amp;amp; Adams, R. P. (2012). Practical Bayesian optimization of machine learning algorithms. &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt;, &lt;em&gt;25&lt;/em&gt;, 2951–2959.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-spearman1904general" class="csl-entry"&gt;
&lt;p&gt;Spearman, C. (1904). &amp;quot;General intelligence,&amp;quot; objectively determined and measured. &lt;em&gt;The American Journal of Psychology&lt;/em&gt;, &lt;em&gt;15&lt;/em&gt;(2), 201–292.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-srivastava2014dropout" class="csl-entry"&gt;
&lt;p&gt;Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., &amp;amp; Salakhutdinov, R. (2014). Dropout: A simple way to prevent neural networks from overfitting. &lt;em&gt;The Journal of Machine Learning Research&lt;/em&gt;, &lt;em&gt;15&lt;/em&gt;(1), 1929–1958.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-sun2020neural" class="csl-entry"&gt;
&lt;p&gt;Sun, S., Shi, J., &amp;amp; Grosse, R. B. (2020). Neural networks as inter-domain inducing points. &lt;em&gt;Third Symposium on Advances in Approximate Bayesian Inference&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-tibshirani1996regression" class="csl-entry"&gt;
&lt;p&gt;Tibshirani, R. (1996). Regression shrinkage and selection via the lasso. &lt;em&gt;Journal of the Royal Statistical Society Series B: Statistical Methodology&lt;/em&gt;, &lt;em&gt;58&lt;/em&gt;(1), 267–288.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-tipping1999probabilistic" class="csl-entry"&gt;
&lt;p&gt;Tipping, M. E., &amp;amp; Bishop, C. M. (1999). Probabilistic principal component analysis. &lt;em&gt;Journal of the Royal Statistical Society: Series B (Statistical Methodology)&lt;/em&gt;, &lt;em&gt;61&lt;/em&gt;(3), 611–622.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-torun2018global" class="csl-entry"&gt;
&lt;p&gt;Torun, H. M., Swaminathan, M., Davis, A. K., &amp;amp; Bellaredj, M. L. F. (2018). A global Bayesian optimization algorithm and its application to integrated system design. &lt;em&gt;IEEE Transactions on Very Large Scale Integration (VLSI) Systems&lt;/em&gt;, &lt;em&gt;26&lt;/em&gt;(4), 792–802.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-touvron2023llama" class="csl-entry"&gt;
&lt;p&gt;Touvron, H., Martin, L., Stone, K., Albert, P., Almahairi, A., Babaei, Y., Bashlykov, N., Batra, S., Bhargava, P., Bhosale, S., et al. (2023). Llama 2: Open foundation and fine-tuned chat models. &lt;em&gt;arXiv Preprint arXiv:2307.09288&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-turner2021bayesian" class="csl-entry"&gt;
&lt;p&gt;Turner, R., Eriksson, D., McCourt, M., Kiili, J., Laaksonen, E., Xu, Z., &amp;amp; Guyon, I. (2021). Bayesian optimization is superior to random search for machine learning hyperparameter tuning: Analysis of the black-box optimization challenge 2020. &lt;em&gt;NeurIPS 2020 Competition and Demonstration Track&lt;/em&gt;, 3–26.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-vaswani2017attention" class="csl-entry"&gt;
&lt;p&gt;Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., &amp;amp; Polosukhin, I. (2017). Attention is all you need. &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt;, &lt;em&gt;30&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-wigley2016fast" class="csl-entry"&gt;
&lt;p&gt;Wigley, P. B., Everitt, P. J., Hengel, A. van den, Bastian, J. W., Sooriyabandara, M. A., McDonald, G. D., Hardman, K. S., Quinlivan, C. D., Manju, P., Kuhn, C. C., et al. (2016). Fast machine-learning online optimization of ultra-cold-atom experiments. &lt;em&gt;Scientific Reports&lt;/em&gt;, &lt;em&gt;6&lt;/em&gt;(1), 25890.&lt;/p&gt;
&lt;/div&gt;
&lt;div id="ref-yang2019machine" class="csl-entry"&gt;
&lt;p&gt;Yang, K. K., Wu, Z., &amp;amp; Arnold, F. H. (2019). Machine-learning-guided directed evolution for protein engineering. &lt;em&gt;Nature Methods&lt;/em&gt;, &lt;em&gt;16&lt;/em&gt;(8), 687–694.&lt;/p&gt;
&lt;/div&gt;
&lt;/div&gt;</description></item><item><title>📄 One paper accepted to ICML 2023</title><link>https://tiao.io/posts/one-paper-accepted-to-icml2023/</link><pubDate>Tue, 25 Apr 2023 20:37:43 +0000</pubDate><guid>https://tiao.io/posts/one-paper-accepted-to-icml2023/</guid><description>&lt;p&gt;Our paper
&lt;em&gt;Spherical Inducing Features for Orthogonally-Decoupled Gaussian Processes&lt;/em&gt;
was accepted to ICML 2023 as an Oral Presentation!
This work was largely done during my time at Secondmind Labs as a
Student Researcher, in collaboration with Vincent Dutordoir and Victor Picheny.&lt;/p&gt;</description></item><item><title>Spherical Inducing Features for Orthogonally-Decoupled Gaussian Processes</title><link>https://tiao.io/publications/spherical-features-gaussian-process/</link><pubDate>Tue, 25 Apr 2023 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/spherical-features-gaussian-process/</guid><description/></item><item><title>Efficient Cholesky decomposition of low-rank updates</title><link>https://tiao.io/posts/efficient-cholesky-decomposition-of-low-rank-updates/</link><pubDate>Sun, 16 Apr 2023 11:16:03 +0000</pubDate><guid>https://tiao.io/posts/efficient-cholesky-decomposition-of-low-rank-updates/</guid><description>&lt;p&gt;Suppose we&amp;rsquo;re given a positive semidefinite (PSD)
matrix $\mathbf{A} \in \mathbb{R}^{N \times N}$,
which we wish to update by some low-rank
matrix $\mathbf{U} \mathbf{U}^\top \in \mathbb{R}^{N \times N}$,
$$\mathbf{B} \triangleq \mathbf{A} + \mathbf{U} \mathbf{U}^\top,$$
where $\mathbf{U} \in \mathbb{R}^{N \times M}$ is the update factor matrix.
To be more precise, the low-rank update is rank-$M$ for some $M \ll N$.&lt;/p&gt;
&lt;p&gt;&lt;em&gt;What is the best way to calculate the Cholesky decomposition of $\mathbf{B}$?&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;Given no additional information, the obvious way is to calculate it directly,
which incurs a cost of $\mathcal{O}(N^3)$.
But suppose we&amp;rsquo;ve already calculated the lower-triangular Cholesky factor
$\mathbf{L} \in \mathbb{R}^{N \times N}$ of $\mathbf{A}$
(i.e., $\mathbf{LL}^\top = \mathbf{A}$).
Then, we can use it to calculate the Cholesky decomposition
of $\mathbf{B}$ at a reduced cost of $\mathcal{O}(N^2M)$.
Here&amp;rsquo;s how.&lt;/p&gt;
&lt;h2 id="rank-1-updates"&gt;Rank-1 Updates&lt;/h2&gt;
&lt;p&gt;First, let&amp;rsquo;s consider the simpler case involving just &lt;em&gt;rank-1 updates&lt;/em&gt;,
$$\mathbf{B} \triangleq \mathbf{A} + \mathbf{u} \mathbf{u}^\top,$$
where $\mathbf{u} \in \mathbb{R}^{N}$ is the update factor vector.
With some clever manipulations&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;, the details of which we won&amp;rsquo;t
get into in this post, we can leverage $\mathbf{L}$ to
calculate the Cholesky decomposition of $\mathbf{B}$
at a reduced cost of $\mathcal{O}(N^2)$.
Such a procedure for rank-1 updates is implemented in the old-school Fortran
linear algebra software library LINPACK
(but unfortunately not in its successor, LAPACK),
and also in modern libraries like TensorFlow Probability (TFP).&lt;/p&gt;
&lt;p&gt;In TFP, this is implemented in the function
&lt;code&gt;tfp.math.cholesky_update&lt;/code&gt;.
For example,&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow_probability&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tfp&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;update_factor_vector&lt;/span&gt; &lt;span class="c1"&gt;# Tensor; shape [..., N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="c1"&gt;# Tensor; shape [..., N, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;update&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_factor_vector&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;newaxis&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_factor_vector&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;newaxis&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;transpose_b&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;update&lt;/span&gt; &lt;span class="c1"&gt;# Tensor; shape [..., N, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;a_factor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# O(N^3); suppose this is pre-computed and stored&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;b_factor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# O(N^3), ignores `a_factor`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;b_factor_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky_update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a_factor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_vector&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# O(N^2), uses `a_factor`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;testing&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;assert_array_almost_equal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b_factor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b_factor_1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Here &lt;code&gt;cholesky_update&lt;/code&gt; takes as arguments &lt;code&gt;chol&lt;/code&gt; with shape &lt;code&gt;[B1, ..., Bn, N, N]&lt;/code&gt;
and &lt;code&gt;u&lt;/code&gt; with shape &lt;code&gt;[B1, ..., Bn, N]&lt;/code&gt;, and returns a lower triangular Cholesky
factor of the rank-1 updated matrix &lt;code&gt;chol @ chol.T + u @ u.T&lt;/code&gt; in $\mathcal{O}(N^2)$
time.&lt;/p&gt;
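&lt;p&gt;The details of the rank-1 procedure are skipped above. Purely as an illustration, here is a minimal NumPy sketch of one textbook formulation, which sweeps over the columns once and rescales each in turn. The function name &lt;code&gt;cholesky_rank1_update&lt;/code&gt; is hypothetical, the sketch handles only a single unbatched matrix, and it is not claimed to match how TFP or any other library implements the update.&lt;/p&gt;

```python
import numpy as np


def cholesky_rank1_update(chol, u):
    """Return the lower Cholesky factor of chol @ chol.T + outer(u, u)."""
    chol = np.array(chol, dtype=float)  # work on copies, leave inputs untouched
    u = np.array(u, dtype=float)
    n = u.shape[0]
    for k in range(n):
        r = np.hypot(chol[k, k], u[k])  # new (positive) diagonal entry
        c = r / chol[k, k]
        s = u[k] / chol[k, k]
        chol[k, k] = r
        # Update the rest of column k, then fold the change back into u.
        # Both slices are empty on the last column, so no special case is needed.
        chol[k + 1:, k] = (chol[k + 1:, k] + s * u[k + 1:]) / c
        u[k + 1:] = c * u[k + 1:] - s * chol[k + 1:, k]
    return chol


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 6))
a = x @ x.T + 6.0 * np.eye(6)  # a positive definite test matrix (assumed input)
u = rng.standard_normal(6)

expected = np.linalg.cholesky(a + np.outer(u, u))  # direct factorization, O(N^3)
actual = cholesky_rank1_update(np.linalg.cholesky(a), u)  # single sweep, O(N^2)
np.testing.assert_allclose(actual, expected)
```

&lt;p&gt;Each of the $N$ columns costs $\mathcal{O}(N)$ work, which is where the overall $\mathcal{O}(N^2)$ cost comes from. The final assertion checks the result against a direct factorization of the updated matrix.&lt;/p&gt;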
&lt;h2 id="low-rank-updates"&gt;Low-Rank Updates&lt;/h2&gt;
&lt;p&gt;Now let&amp;rsquo;s return to rank-$M$ updates.
First let&amp;rsquo;s write the update factor matrix $\mathbf{U}$ in terms of column
vectors $\mathbf{u}_m \in \mathbb{R}^{N}$,
$$
\mathbf{U} \triangleq
\begin{bmatrix}
\mathbf{u}_1 &amp; \cdots &amp; \mathbf{u}_M
\end{bmatrix}.
$$
&lt;/p&gt;
&lt;p&gt;Now we can write the rank-$M$ update matrix as a sum of $M$ rank-1 matrices,
$$
\mathbf{U} \mathbf{U}^\top =
\begin{bmatrix} \mathbf{u}_1 &amp; \cdots &amp; \mathbf{u}_M \end{bmatrix}
\begin{bmatrix} \mathbf{u}_1^\top \\ \vdots \\ \mathbf{u}_M^\top \end{bmatrix} =
\sum_{m=1}^{M} \mathbf{u}_m \mathbf{u}_m^\top.
$$
&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt; &lt;span class="c1"&gt;# Tensor; shape [..., N, M]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# [..., N, 1, M] [..., 1, N, M] -&amp;gt; [..., N, N, M] -&amp;gt; [..., N, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;update1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reduce_sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;newaxis&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;newaxis&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:,&lt;/span&gt; &lt;span class="p"&gt;:],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# [..., N, M] [..., M, N] -&amp;gt; [..., N, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;update2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;transpose_b&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# not exactly equal due to finite precision, but still equal up to high precision&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;testing&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;assert_array_almost_equal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decimal&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;14&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Thus seen, a low-rank update is nothing more than a repeated application of
rank-1 updates,
$$
\begin{align}
\mathbf{B} &amp; = \mathbf{A} + \mathbf{U} \mathbf{U}^\top \\ &amp; =
\mathbf{A} + \sum_{m=1}^{M} \mathbf{u}_m \mathbf{u}_m^\top \\ &amp; =
((\mathbf{A} + \mathbf{u}_1 \mathbf{u}_1^\top) + \cdots ) + \mathbf{u}_M \mathbf{u}_M^{\top}.
\end{align}
$$
&lt;/p&gt;
&lt;p&gt;Therefore, we can simply leverage the $O(N^2)$ procedure for Cholesky
decompositions of rank-1 updates and apply it recursively $M$ times to obtain
an $O(N^2M)$ procedure for rank-$M$ updates.&lt;/p&gt;
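&lt;p&gt;To make the cost argument concrete, here is a minimal NumPy sketch of the idea. It is not the TensorFlow Probability implementation: the helper &lt;code&gt;cholesky_update_rank1&lt;/code&gt; is a hypothetical name for the textbook $O(N^2)$ rank-1 update, it handles a single unbatched matrix only, and the test matrices are random placeholders.&lt;/p&gt;
&lt;pre&gt;&lt;code class="language-python" data-lang="python"&gt;
```python
import numpy as np


def cholesky_update_rank1(chol, u):
    """Return the lower Cholesky factor of chol @ chol.T + outer(u, u).

    Hypothetical helper for illustration; unbatched, O(N^2) per call.
    """
    L = chol.copy()
    x = np.array(u, dtype=float)  # working copy, modified in place below
    n = x.shape[0]
    for k in range(n):
        r = np.hypot(L[k, k], x[k])  # new diagonal entry
        c = r / L[k, k]
        s = x[k] / L[k, k]
        L[k, k] = r
        # Update the rest of column k, then the remaining part of x.
        L[k + 1:, k] = (L[k + 1:, k] + s * x[k + 1:]) / c
        x[k + 1:] = c * x[k + 1:] - s * L[k + 1:, k]
    return L


rng = np.random.default_rng(0)
N, M = 5, 3
G = rng.standard_normal((N, N))
A = G @ G.T + N * np.eye(N)       # placeholder positive-definite matrix
U = rng.standard_normal((N, M))   # placeholder update factor matrix

chol = np.linalg.cholesky(A)
for m in range(M):                # M rank-1 updates, O(N^2) each
    chol = cholesky_update_rank1(chol, U[:, m])

# Agrees with factorizing the updated matrix from scratch in O(N^3).
np.testing.assert_allclose(chol, np.linalg.cholesky(A + U @ U.T))
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Each call to the helper costs $O(N^2)$, so the $M$ calls in the loop together cost $O(N^2M)$.&lt;/p&gt;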
&lt;p&gt;Hence, we have:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# [..., N, M] [..., M, N] -&amp;gt; [..., N, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;update&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;transpose_b&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;update&lt;/span&gt; &lt;span class="c1"&gt;# Tensor; shape [..., N, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;b_factor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# O(N^3), ignores `a_factor`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;b_factor_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cholesky_update_iterated&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a_factor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# O(N^2M), uses `a_factor`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;testing&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;assert_array_almost_equal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b_factor_1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b_factor&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;where the function &lt;code&gt;cholesky_update_iterated&lt;/code&gt; is implemented as follows:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;cholesky_update_iterated&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chol&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# base case&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;chol&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;prev&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cholesky_update_iterated&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chol&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky_update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prev&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We can also implement this iteratively.
First we&amp;rsquo;d use &lt;code&gt;tf.unstack&lt;/code&gt; to turn the update factor matrix $\mathbf{U}$
into a list of update factor vectors $\mathbf{u}_m$:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;update_factor_vectors&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nb"&gt;isinstance&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_vectors&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# `update_factor_vectors` is a list&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_vectors&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;M&lt;/span&gt; &lt;span class="c1"&gt;# ... the list contains M vectors&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;update_factor_vectors&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;Bs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# ... and each vector has shape [B1, ..., Bn, N]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Then, we have:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;cholesky_update_iterated&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chol&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;new_chol&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;chol&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;update_factor_vector&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;new_chol&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky_update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;new_chol&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_vector&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;new_chol&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The astute reader will recognize that this is simply a special case of
the &lt;em&gt;reduce&lt;/em&gt; (or &lt;em&gt;fold&lt;/em&gt;) pattern, where
the &lt;em&gt;binary operator&lt;/em&gt; is &lt;code&gt;tfp.math.cholesky_update&lt;/code&gt;,
the &lt;em&gt;iterable&lt;/em&gt; is &lt;code&gt;tf.unstack(update_factor_matrix, axis=-1)&lt;/code&gt; and
the &lt;em&gt;initial value&lt;/em&gt; is &lt;code&gt;chol&lt;/code&gt;.&lt;/p&gt;
&lt;p&gt;Therefore, we can also implement it neatly using the one-liner:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;functools&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;reduce&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;cholesky_update_iterated&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chol&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;reduce&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky_update&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;update_factor_matrix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;chol&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="summary"&gt;Summary&lt;/h2&gt;
&lt;p&gt;In summary, we showed that to efficiently calculate the Cholesky decomposition
of a matrix perturbed by a low-rank update, one just needs to iteratively
calculate that of the same matrix perturbed by a series of rank-1 updates.
Better yet, all of this can be done with a simple one-liner!&lt;/p&gt;
&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;Seeger, M. (2004). Low rank updates for the Cholesky decomposition.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Batch Bayesian Optimisation via Density-ratio Estimation with Guarantees</title><link>https://tiao.io/publications/batch-bore-guarantees/</link><pubDate>Thu, 01 Dec 2022 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/batch-bore-guarantees/</guid><description/></item><item><title>Long Talk: BORE — Bayesian Optimization by Density-Ratio Estimation</title><link>https://tiao.io/events/icml2021-bore/</link><pubDate>Wed, 21 Jul 2021 14:00:00 +0000</pubDate><guid>https://tiao.io/events/icml2021-bore/</guid><description/></item><item><title>Invited Talk: BORE — Bayesian Optimization by Density-Ratio Estimation</title><link>https://tiao.io/events/ellis-automl-seminars-2021/</link><pubDate>Wed, 12 May 2021 16:00:00 +0000</pubDate><guid>https://tiao.io/events/ellis-automl-seminars-2021/</guid><description/></item><item><title>BORE: Bayesian Optimization by Density-Ratio Estimation</title><link>https://tiao.io/publications/bore-2/</link><pubDate>Sat, 08 May 2021 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/bore-2/</guid><description>&lt;p&gt;&lt;strong&gt;B&lt;/strong&gt;ayesian &lt;strong&gt;O&lt;/strong&gt;ptimization (BO) by Density-&lt;strong&gt;R&lt;/strong&gt;atio &lt;strong&gt;E&lt;/strong&gt;stimation (DRE),
or &lt;strong&gt;BORE&lt;/strong&gt;, is a simple, yet effective framework for the optimization of
blackbox functions.
BORE is built upon the correspondence between &lt;em&gt;expected improvement (EI)&lt;/em&gt;&amp;mdash;arguably
the predominant &lt;em&gt;acquisition function&lt;/em&gt; used in BO&amp;mdash;and the &lt;em&gt;density-ratio&lt;/em&gt;
between two unknown distributions.&lt;/p&gt;
&lt;p&gt;One of the far-reaching consequences of this correspondence is that we can
reduce the computation of EI to a &lt;em&gt;probabilistic classification&lt;/em&gt; problem&amp;mdash;a
problem we are well-equipped to tackle, as evidenced by the broad range of
streamlined, easy-to-use and, perhaps most importantly, battle-tested
tools and frameworks available at our disposal for applying a variety of approaches.
Notable among these are the established frameworks for Deep Learning and
Gradient Tree Boosting, not to mention general-purpose libraries for just about
everything else.
The BORE framework lets us take direct advantage of these tools.&lt;/p&gt;
&lt;h2 id="code-example"&gt;Code Example&lt;/h2&gt;
&lt;p&gt;We provide a simple example with Keras to give you a taste of how BORE can
be implemented using a feed-forward &lt;em&gt;neural network (NN)&lt;/em&gt; classifier.
A useful class that the &lt;code&gt;bore&lt;/code&gt; package provides is
&lt;code&gt;MaximizableSequential&lt;/code&gt;, a subclass of &lt;code&gt;Sequential&lt;/code&gt; from
Keras that inherits all of its existing functionality and provides just
one additional method.
We can build and compile a feed-forward NN classifier as usual:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;bore.models&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;MaximizableSequential&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;tensorflow.keras.layers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# build model&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;classifier&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MaximizableSequential&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;sigmoid&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# compile model&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;compile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;adam&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;binary_crossentropy&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;See the Keras documentation if this seems unfamiliar to you.&lt;/p&gt;
&lt;p&gt;The additional method provided is &lt;code&gt;argmax&lt;/code&gt;, which returns the &lt;em&gt;maximizer&lt;/em&gt; of
the network, i.e. the input $\mathbf{x}$ that maximizes the final output of
the network:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_argmax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;bounds&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;bounds&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;method&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;L-BFGS-B&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_start_points&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Since the network is differentiable end-to-end with respect to the input $\mathbf{x}$, this
method can be implemented efficiently using a &lt;em&gt;multi-started quasi-Newton
hill-climber&lt;/em&gt; such as L-BFGS-B.
We will see the pivotal role this method plays in the next section.&lt;/p&gt;
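&lt;p&gt;For intuition, a multi-started maximizer of this kind could look roughly like the sketch below. This is an illustration built on SciPy, not the code inside the package: &lt;code&gt;surrogate&lt;/code&gt; and &lt;code&gt;argmax_multistart&lt;/code&gt; are hypothetical names, the objective is a toy stand-in for the classifier output, and gradients are approximated by finite differences, whereas a neural network would supply exact gradients through automatic differentiation.&lt;/p&gt;
&lt;pre&gt;&lt;code class="language-python" data-lang="python"&gt;
```python
import numpy as np
from scipy.optimize import minimize


def surrogate(x):
    # Hypothetical smooth stand-in for the classifier output; peaks at x = 0.3.
    return float(np.exp(-np.sum((x - 0.3) ** 2)))


def argmax_multistart(func, bounds, num_start_points=3, seed=0):
    """Maximize `func` inside a box by restarting L-BFGS-B from random points."""
    rng = np.random.default_rng(seed)
    low, high = np.asarray(bounds, dtype=float).T
    results = []
    for _ in range(num_start_points):
        x0 = rng.uniform(low, high)  # random starting point inside the box
        # Maximize func by minimizing its negation. No `jac` is passed, so
        # SciPy falls back to finite-difference gradients.
        results.append(minimize(lambda x: -func(x), x0,
                                method="L-BFGS-B", bounds=bounds))
    # The smallest negated value corresponds to the largest value of func.
    best = min(results, key=lambda res: res.fun)
    return best.x


bounds = [(-1.0, 1.0), (-1.0, 1.0)]
x_argmax = argmax_multistart(surrogate, bounds)
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Restarting from several points guards against local optima, which matters because the output of a trained network is generally not concave in its input.&lt;/p&gt;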
&lt;hr&gt;
&lt;p&gt;Using this classifier, the BO loop in BORE looks as follows:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;features&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;targets&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# initialize design&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;features&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;extend&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;features_initial_design&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;extend&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;targets_initial_design&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_iterations&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# construct classification problem&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;features&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tau&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;quantile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.25&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;z&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;less&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tau&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# update classifier&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# suggest new candidate&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_next&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;bounds&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;bounds&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;method&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;L-BFGS-B&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_start_points&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# evaluate blackbox&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_next&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;blackbox&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;evaluate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_next&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# update dataset&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;features&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_next&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_next&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;hr&gt;
&lt;p&gt;Let&amp;rsquo;s break this down a bit:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;At the start of the loop, we construct the classification problem by labeling
instances $\mathbf{x}$ whose corresponding target value $y$ falls below the
&lt;code&gt;q=0.25&lt;/code&gt; quantile of all target values as &lt;em&gt;positive&lt;/em&gt;, and the rest as &lt;em&gt;negative&lt;/em&gt;.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Next, we train the classifier to discriminate between these instances. This
classifier should converge towards
&lt;/p&gt;
$$
\pi^{*}(\mathbf{x}) = \frac{\gamma \ell(\mathbf{x})}{\gamma \ell(\mathbf{x}) + (1-\gamma) g(\mathbf{x})},
$$&lt;p&gt;
where $\ell(\mathbf{x})$ and $g(\mathbf{x})$ are the unknown distributions of
instances belonging to the positive and negative classes, respectively, and
$\gamma$ is the class balance-rate and, by construction, simply the quantile
we specified (i.e. $\gamma=0.25$).&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Once the classifier is a decent approximation to $\pi^{*}(\mathbf{x})$, we
propose the maximizer of this classifier as the next input to evaluate.
In other words, we are now using the classifier &lt;em&gt;itself&lt;/em&gt; as the acquisition
function.&lt;/p&gt;
&lt;p&gt;How is it justifiable to use this in lieu of EI, or some other acquisition
function we&amp;rsquo;re used to?
And what is so special about $\pi^{*}(\mathbf{x})$?&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Well, as it turns out, $\pi^{*}(\mathbf{x})$ is equivalent to EI, up to some
constant factors.&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;The remainder of the loop should now be self-explanatory. Namely, we&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;evaluate the blackbox function at the suggested point, and&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;update the dataset.&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
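&lt;p&gt;As a minimal sketch of the labeling step (item 1), assuming NumPy and a made-up array of observed &lt;code&gt;targets&lt;/code&gt;, the positive class can be built by thresholding at the empirical &lt;code&gt;q=0.25&lt;/code&gt; quantile. The values and variable names below are illustrative only, and smaller targets are treated as better.&lt;/p&gt;

```python
import numpy as np

# Hypothetical observed target values (illustrative only).
targets = np.array([3.0, -1.0, 0.5, 2.0, -2.5, 1.5, 0.0, 4.0])

gamma = 0.25
# Threshold tau: the gamma-quantile of the observed targets.
tau = np.quantile(targets, q=gamma)
# Positive label (z = 1) for targets strictly below tau, negative otherwise.
labels = np.less(targets, tau)

# The fraction of positives is, by construction, close to gamma.
positive_rate = labels.mean()
```

&lt;p&gt;Any probabilistic classifier can then be fit to the resulting (input, label) pairs.&lt;/p&gt;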
&lt;h3 id="step-by-step-illustration"&gt;Step-by-step Illustration&lt;/h3&gt;
&lt;p&gt;Here is a step-by-step animation of six iterations of this loop in action,
using the &lt;em&gt;Forrester&lt;/em&gt; synthetic function as an example.
The noise-free function is shown as the solid gray curve in the main pane.
This procedure is warm-started with four random initial designs.&lt;/p&gt;
&lt;p&gt;The right pane shows the empirical CDF (ECDF) of the observed $y$ values.
The vertical dashed black line in this pane is located at $\Phi(y) = \gamma$,
where $\gamma = 0.25$.
The horizontal dashed black line is located at $\tau$, the value of $y$ such
that $\Phi(y) = 0.25$, i.e. $\tau = \Phi^{-1}(0.25)$.&lt;/p&gt;
&lt;p&gt;The instances below this horizontal line are assigned binary label $z=1$, while
those above are assigned $z=0$. This is visualized in the bottom pane,
alongside the probabilistic classifier $\pi_{\boldsymbol{\theta}}(\mathbf{x})$
represented by the solid gray curve, which is trained to discriminate between
these instances.&lt;/p&gt;
&lt;p&gt;Finally, the maximizer of the classifier is represented by the vertical solid
green line.
This is the location that the BO procedure suggests evaluating next.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Animation"
srcset="https://tiao.io/publications/bore-2/paper_1500x5562_hu_bf54a19b8bc6fbf5.webp 205w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://tiao.io/publications/bore-2/paper_1500x5562_hu_bf54a19b8bc6fbf5.webp"
width="205"
height="760"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;We see that the procedure converges toward the global minimum of the blackbox
function after half a dozen iterations.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;To understand how and why this works in more detail, please read our paper!
If you only have 15 minutes to spare, please watch the video recording of our
talk!&lt;/p&gt;
&lt;h2 id="video"&gt;Video&lt;/h2&gt;
&lt;div id="presentation-embed-38942425"&gt;&lt;/div&gt;
&lt;script src='https://slideslive.com/embed_presentation.js'&gt;&lt;/script&gt;
&lt;script&gt;
embed = new SlidesLiveEmbed('presentation-embed-38942425', {
presentationId: '38942425',
autoPlay: false, // change to true to autoplay the embedded presentation
verticalEnabled: true
});
&lt;/script&gt;</description></item><item><title>Simulation-based Scoring for Model-based Asynchronous Hyperparameter and Neural Architecture Search</title><link>https://tiao.io/publications/simulation-based-scoring/</link><pubDate>Sat, 01 May 2021 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/simulation-based-scoring/</guid><description/></item><item><title>A Primer on Pólya-gamma Random Variables - Part II: Bayesian Logistic Regression</title><link>https://tiao.io/posts/polya-gamma-bayesian-logistic-regression/</link><pubDate>Tue, 20 Apr 2021 17:20:53 +0100</pubDate><guid>https://tiao.io/posts/polya-gamma-bayesian-logistic-regression/</guid><description>
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-blue-100 dark:bg-blue-900 border-blue-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-blue-600 dark:text-blue-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m16.862 4.487l1.687-1.688a1.875 1.875 0 1 1 2.652 2.652L6.832 19.82a4.5 4.5 0 0 1-1.897 1.13l-2.685.8l.8-2.685a4.5 4.5 0 0 1 1.13-1.897zm0 0L19.5 7.125"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Note&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;This is &lt;strong&gt;Part II&lt;/strong&gt; of a three-part series on Pólya-Gamma random variables.
Part I (Basic Relationships) and Part III (Local Variational Methods) are
in preparation.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;details class="print:hidden xl:hidden" &gt;
&lt;summary&gt;Table of Contents&lt;/summary&gt;
&lt;div class="text-sm"&gt;
&lt;nav id="TableOfContents"&gt;
&lt;ul&gt;
&lt;li&gt;&lt;a href="#binary-classification"&gt;Binary Classification&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#model--bayesian-logistic-regression"&gt;Model &amp;ndash; Bayesian Logistic Regression&lt;/a&gt;
&lt;ul&gt;
&lt;li&gt;&lt;a href="#likelihood"&gt;Likelihood&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#prior"&gt;Prior&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#inference-and-prediction"&gt;Inference and Prediction&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;a href="#augmented-model"&gt;Augmented Model&lt;/a&gt;
&lt;ul&gt;
&lt;li&gt;&lt;a href="#likelihood-conditioned-on-auxiliary-variables"&gt;Likelihood conditioned on auxiliary variables&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#prior-over-auxiliary-variables"&gt;Prior over auxiliary variables&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#inference-gibbs-sampling"&gt;Inference (Gibbs sampling)&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;a href="#implementation-weight-space-view"&gt;Implementation (Weight-space view)&lt;/a&gt;
&lt;ul&gt;
&lt;li&gt;&lt;a href="#synthetic-one-dimensional-classification-problem"&gt;Synthetic one-dimensional classification problem&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#prior-1"&gt;Prior&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#conditional-likelihood"&gt;Conditional likelihood&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#inference-and-prediction-1"&gt;Inference and Prediction&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#code"&gt;Code&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#bonus-gibbs-sampling-with-mutual-recursion-and-generator-delegation"&gt;Bonus: Gibbs sampling with mutual recursion and generator delegation&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;a href="#links-and-further-readings"&gt;Links and Further Readings&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#appendix"&gt;Appendix&lt;/a&gt;
&lt;ul&gt;
&lt;li&gt;&lt;a href="#i"&gt;I&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#ii"&gt;II&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#iii"&gt;III&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/nav&gt;
&lt;/div&gt;
&lt;/details&gt;
&lt;h2 id="binary-classification"&gt;Binary Classification&lt;/h2&gt;
&lt;p&gt;Consider the usual set-up for a binary classification problem:
for some input $\mathbf{x} \in \mathbb{R}^{D}$,
predict its binary label $y \in \{ 0, 1 \}$ given observations consisting of a
feature matrix $\mathbf{X} = [ \mathbf{x}_1 \cdots \mathbf{x}_N ]^{\top} \in \mathbb{R}^{N \times D}$
and a target vector $\mathbf{y} = [ y_1 \cdots y_N ]^{\top} \in \{ 0, 1 \}^N$.&lt;/p&gt;
&lt;h2 id="model--bayesian-logistic-regression"&gt;Model &amp;ndash; Bayesian Logistic Regression&lt;/h2&gt;
&lt;p&gt;Recall the standard &lt;em&gt;Bayesian logistic regression&lt;/em&gt; model:&lt;/p&gt;
&lt;h3 id="likelihood"&gt;Likelihood&lt;/h3&gt;
&lt;p&gt;Let $f: \mathbb{R}^{D} \to \mathbb{R}$ denote the real-valued latent function,
sometimes referred to as the &lt;em&gt;nuisance function&lt;/em&gt;, and let $f_n = f(\mathbf{x}_n)$
be the function value corresponding to observed input $\mathbf{x}_n$.
The distribution over the observed variable $y_n$ is assumed to be governed
by the latent variable $f_n$.
In particular, the observed target vector $\mathbf{y}$
are related to $\mathbf{f}$, the column vector of latent
variables $\mathbf{f} = [f_1, \dotsc, f_N]^{\top}$,
through the likelihood, or observation model, defined as
&lt;/p&gt;
$$
p(\mathbf{y} | \mathbf{f}) \doteq \prod_{n=1}^N p(y_n | f_n),
$$&lt;p&gt;
where
&lt;/p&gt;
$$
p(y_n | f_n) = \mathrm{Bern}(y_n | \sigma(f_n)) =
\sigma(f_n)^{y_n} \left (1 - \sigma(f_n) \right )^{1 - y_n},
$$&lt;p&gt;
and $\sigma(u) = \left ( 1 + \exp(-u) \right )^{-1}$ is the logistic sigmoid
function.&lt;/p&gt;
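&lt;p&gt;The observation model above can be sketched in a few lines, assuming NumPy and SciPy. The labels and latent values are made up for illustration, and &lt;code&gt;bernoulli_likelihood&lt;/code&gt; is a hypothetical helper name.&lt;/p&gt;

```python
import numpy as np
from scipy.special import expit  # logistic sigmoid

def bernoulli_likelihood(y, f):
    """Per-datapoint likelihood sigma(f_n)^y_n * (1 - sigma(f_n))^(1 - y_n)."""
    p = expit(f)
    return p**y * (1.0 - p)**(1 - y)

# Illustrative labels and latent function values.
y = np.array([0, 1, 1, 0])
f = np.array([-2.0, 0.5, 3.0, 1.0])

per_point = bernoulli_likelihood(y, f)
# The full likelihood p(y | f) is the product over datapoints.
likelihood = per_point.prod()
```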
&lt;h3 id="prior"&gt;Prior&lt;/h3&gt;
&lt;p&gt;For the sake of generality we discuss both the &lt;em&gt;weight-space&lt;/em&gt;
and &lt;em&gt;function-space&lt;/em&gt; views of Bayesian logistic regression.
In both cases, we consider a prior distribution in the form of a
multivariate Gaussian $\mathcal{N}(\mathbf{m}, \mathbf{S}^{-1})$,
whether it be over the weights or the function values themselves.&lt;/p&gt;
&lt;h4 id="weight-space"&gt;Weight-space&lt;/h4&gt;
&lt;p&gt;In the weight-space view, sometimes referred to as &lt;em&gt;linear&lt;/em&gt; logistic
regression, we assume &lt;em&gt;a priori&lt;/em&gt; that the latent function takes the form&lt;br&gt;
&lt;/p&gt;
$$
f(\mathbf{x}) = \boldsymbol{\beta}^{\top} \mathbf{x},
\qquad
\boldsymbol{\beta} \sim \mathcal{N}(\mathbf{m}, \mathbf{S}^{-1}).
$$&lt;p&gt;
In this case, we express the vector of latent function
values as $\mathbf{f} = \mathbf{X} \boldsymbol{\beta}$
and the prior over the weights
as $p(\boldsymbol{\beta}) = \mathcal{N}(\mathbf{m}, \mathbf{S}^{-1})$.&lt;/p&gt;
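&lt;p&gt;To illustrate the weight-space view, here is a small sketch assuming NumPy, a random stand-in feature matrix, and an arbitrary prior precision. It draws weights from the prior, maps them to latent function values, and also forms the Gaussian over $\mathbf{f}$ that the linear map implies.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 5, 3
X = rng.normal(size=(N, D))          # hypothetical feature matrix

m = np.zeros(D)                      # prior mean
S = 2.0 * np.eye(D)                  # prior precision, so the covariance is S^{-1}
S_inv = np.linalg.inv(S)

# Draw weights from the prior and map them to latent function values.
beta = rng.multivariate_normal(mean=m, cov=S_inv)
f = X @ beta

# Implied Gaussian over f, by linearity: N(X m, X S^{-1} X^T).
f_mean = X @ m
f_cov = X @ S_inv @ X.T
```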
&lt;h4 id="function-space"&gt;Function-space&lt;/h4&gt;
&lt;p&gt;In the function-space view, we assume the function is distributed according
to a Gaussian process (GP) with mean function $m(\mathbf{x})$ and covariance
function $k(\mathbf{x}, \mathbf{x}')$
&lt;/p&gt;
$$
f(\mathbf{x}) \sim \mathcal{GP}\left(m(\mathbf{x}), k(\mathbf{x}, \mathbf{x}')\right)
$$&lt;p&gt;
In this case, we express the prior over latent function values
as $p(\mathbf{f} | \mathbf{X}) = \mathcal{N}(\mathbf{m}, \mathbf{K}_X)$,
where $\mathbf{m} = m(\mathbf{X})$ and $\mathbf{K}_X = k(\mathbf{X}, \mathbf{X})$.&lt;/p&gt;
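&lt;p&gt;For the function-space view, the following sketch builds $\mathbf{m}$ and $\mathbf{K}_X$ and draws one sample of $\mathbf{f}$. The zero mean function, the squared exponential kernel, and the helper name &lt;code&gt;squared_exponential&lt;/code&gt; are assumptions made for illustration; the text above does not fix a particular kernel.&lt;/p&gt;

```python
import numpy as np

def squared_exponential(A, B, lengthscale=1.0, amplitude=1.0):
    """k(x, x') = amplitude^2 * exp(-||x - x'||^2 / (2 * lengthscale^2))."""
    sq_dists = np.sum((A[:, None, :] - B[None, :, :])**2, axis=-1)
    return amplitude**2 * np.exp(-0.5 * sq_dists / lengthscale**2)

rng = np.random.default_rng(0)
X = np.linspace(-3.0, 3.0, 20).reshape(-1, 1)   # hypothetical inputs

m = np.zeros(len(X))                 # zero mean function evaluated at X
K_X = squared_exponential(X, X)      # covariance matrix k(X, X)

# A small diagonal jitter keeps the covariance numerically positive definite.
jitter = 1e-8 * np.eye(len(X))
f = rng.multivariate_normal(mean=m, cov=K_X + jitter)
```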
&lt;h3 id="inference-and-prediction"&gt;Inference and Prediction&lt;/h3&gt;
&lt;p&gt;Given some test input $\mathbf{x}_*$, we are interested in producing a probability
distribution over predictions $p(y_* | \mathbf{X}, \mathbf{y}, \mathbf{x}_*)$.
As we shall see, the procedure for computing this distribution is rife with
intractabilities.&lt;/p&gt;
&lt;p&gt;Specifically, we first marginalize out the uncertainty about the associated
latent function value $f_*$,
&lt;/p&gt;
$$
p(y_* | \mathbf{X}, \mathbf{y}, \mathbf{x}_*) =
\int \sigma(f_*) p(f_* | \mathbf{X}, \mathbf{y}, \mathbf{x}_*) \mathrm{d}f_*
$$&lt;p&gt;
where $p(f_* | \mathbf{X}, \mathbf{y}, \mathbf{x}_*)$ is the posterior
predictive distribution.
This integral is analytically intractable, but since it is one-dimensional, it can
be approximated efficiently using numerical quadrature,
assuming $p(f_* | \mathbf{X}, \mathbf{y}, \mathbf{x}_*)$
is Gaussian.&lt;/p&gt;
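&lt;p&gt;One common way to approximate this one-dimensional integral when the predictive over $f_*$ is Gaussian is Gauss-Hermite quadrature. The sketch below assumes NumPy and SciPy; the function name &lt;code&gt;predictive_probability&lt;/code&gt;, the mean and variance values, and the number of nodes are all illustrative choices.&lt;/p&gt;

```python
import numpy as np
from scipy.special import expit

def predictive_probability(mu, var, degree=32):
    """Approximate the integral of sigma(f) N(f | mu, var) df by Gauss-Hermite quadrature."""
    # Nodes and weights for integrals of the form: int exp(-t^2) h(t) dt.
    nodes, weights = np.polynomial.hermite.hermgauss(degree)
    # The change of variables f = mu + sqrt(2 var) t maps the Gaussian to exp(-t^2).
    f = mu + np.sqrt(2.0 * var) * nodes
    return np.sum(weights * expit(f)) / np.sqrt(np.pi)

prob = predictive_probability(mu=0.7, var=1.5)
```

&lt;p&gt;Averaging the sigmoid over the uncertainty in $f_*$ pulls the predicted probability toward $\frac{1}{2}$ relative to plugging in the mean alone.&lt;/p&gt;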
&lt;p&gt;But herein lies the real difficulty: the predictive $p(f_* | \mathbf{X}, \mathbf{y}, \mathbf{x}_*)$ is
computed as
&lt;/p&gt;
$$
p(f_* | \mathbf{X}, \mathbf{y}, \mathbf{x}_*) =
\int p(f_* | \mathbf{X}, \mathbf{x}_*, \mathbf{f}) p(\mathbf{f} | \mathbf{X}, \mathbf{y}) \mathrm{d}\mathbf{f},
$$&lt;p&gt;
where $p(\mathbf{f} | \mathbf{X}, \mathbf{y}) = \frac{p(\mathbf{y} | \mathbf{f}) p(\mathbf{f} | \mathbf{X})}{p(\mathbf{y} | \mathbf{X})} \propto p(\mathbf{y} | \mathbf{f}) p(\mathbf{f} | \mathbf{X})$ is the posterior over latent function values at the
observed points, which is analytically intractable because a Gaussian prior is
not conjugate to the Bernoulli likelihood.&lt;/p&gt;
&lt;p&gt;To overcome this intractability, one must typically resort to approximate
inference methods such as the
Laplace approximation&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;,
variational inference (VI)&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;,
expectation propagation (EP)&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;
and sampling-based approximations such as Markov Chain Monte Carlo (MCMC).&lt;/p&gt;
&lt;h2 id="augmented-model"&gt;Augmented Model&lt;/h2&gt;
&lt;p&gt;Instead of appealing to approximate inference methods, let us consider
an augmentation strategy that works by introducing &lt;em&gt;auxiliary variables&lt;/em&gt; to
the model&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;In particular, we introduce auxiliary variables $\boldsymbol{\omega}$ and
define the &lt;em&gt;augmented&lt;/em&gt; or &lt;em&gt;joint likelihood&lt;/em&gt; that factorizes as
&lt;/p&gt;
$$
p(\mathbf{y}, \boldsymbol{\omega} | \mathbf{f}) = p(\mathbf{y} | \mathbf{f}, \boldsymbol{\omega}) p(\boldsymbol{\omega}),
$$&lt;p&gt;
where $p(\mathbf{y} | \mathbf{f}, \boldsymbol{\omega})$ is the &lt;em&gt;conditional likelihood&lt;/em&gt;,
a likelihood that is conditioned on the auxiliary variables $\boldsymbol{\omega}$,
and $p(\boldsymbol{\omega})$ is the prior.
Specifically, we wish to define $p(\boldsymbol{\omega})$
and $p(\mathbf{y} | \mathbf{f}, \boldsymbol{\omega})$ for which the following
two properties hold:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;Marginalizing out $\boldsymbol{\omega}$ recovers the original observation model
$$
\int \underbrace{p(\mathbf{y}, \boldsymbol{\omega} | \mathbf{f})}_\text{joint likelihood} d\boldsymbol{\omega} =
\int \underbrace{p(\mathbf{y} | \mathbf{f}, \boldsymbol{\omega})}_\text{conditional likelihood} p(\boldsymbol{\omega}) d\boldsymbol{\omega} =
\underbrace{p(\mathbf{y} | \mathbf{f})}_\text{original likelihood}
$$&lt;/li&gt;
&lt;li&gt;A Gaussian prior $p(\mathbf{f})$ is conjugate to the conditional likelihood $p(\mathbf{y} | \mathbf{f}, \boldsymbol{\omega})$.&lt;/li&gt;
&lt;/ol&gt;
&lt;h3 id="likelihood-conditioned-on-auxiliary-variables"&gt;Likelihood conditioned on auxiliary variables&lt;/h3&gt;
&lt;p&gt;First, let us define a conditional likelihood that factorizes as
&lt;/p&gt;
$$
p(\mathbf{y} | \mathbf{f}, \boldsymbol{\omega}) =
\prod_{n=1}^N p(y_n | f_n, \omega_n),
$$&lt;p&gt;
where each factor is defined as
&lt;/p&gt;
$$
p(y_n | f_n, \omega_n) \doteq
\frac{1}{2} \exp{\left \{ - \frac{\omega_n}{2} \left ( f_n^2 -
2 f_n \frac{\kappa_n}{\omega_n} \right ) \right \}}
$$&lt;p&gt;
for $\kappa_n = y_n - \frac{1}{2}$.&lt;/p&gt;
&lt;h3 id="prior-over-auxiliary-variables"&gt;Prior over auxiliary variables&lt;/h3&gt;
&lt;p&gt;Second, let us define a prior over auxiliary variables $\boldsymbol{\omega}$ that
factorizes as
&lt;/p&gt;
$$
p(\boldsymbol{\omega}) = \prod_{n=1}^N p(\omega_n)
$$&lt;p&gt;
where each factor $p(\omega_n)$ is a Pólya-gamma density
&lt;/p&gt;
$$
p(\omega_n) = \mathrm{PG}(\omega_n | 1, 0),
$$&lt;p&gt;
defined as an infinite convolution
of gamma distributions:&lt;/p&gt;
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-blue-100 dark:bg-blue-900 border-blue-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-blue-600 dark:text-blue-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m16.862 4.487l1.687-1.688a1.875 1.875 0 1 1 2.652 2.652L6.832 19.82a4.5 4.5 0 0 1-1.897 1.13l-2.685.8l.8-2.685a4.5 4.5 0 0 1 1.13-1.897zm0 0L19.5 7.125"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Note&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;h4 id="pólya-gamma-density-polson-et-al-2013"&gt;Pólya-gamma density (Polson et al. 2013)&lt;/h4&gt;
&lt;p&gt;A random variable $\omega$ has a Pólya-gamma distribution with parameters $b &gt; 0$
and $c \in \mathbb{R}$, denoted $\omega \sim \mathrm{PG}(b, c)$, if
&lt;/p&gt;
$$
\omega \overset{D}{=} \frac{1}{2 \pi^2} \sum_{k=1}^{\infty}
\frac{g_k}{\left (k - \frac{1}{2} \right )^2 + \left ( \frac{c}{2\pi} \right )^2}
$$&lt;p&gt;
where the $g_k \sim \mathrm{Ga}(b, 1)$ are independent gamma random variables
(and where $\overset{D}{=}$ denotes equality in distribution).&lt;/p&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
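&lt;p&gt;The sum-of-gammas definition suggests a simple, if crude, way to draw approximate Pólya-gamma samples: truncate the infinite sum after a finite number of terms. The sketch below assumes NumPy; &lt;code&gt;sample_polya_gamma&lt;/code&gt; and &lt;code&gt;num_terms&lt;/code&gt; are hypothetical names, and truncation introduces a small downward bias, so this is for illustration rather than production use.&lt;/p&gt;

```python
import numpy as np

def sample_polya_gamma(b, c, size, num_terms=200, rng=None):
    """Approximate draws from PG(b, c) by truncating the infinite sum at num_terms."""
    rng = np.random.default_rng() if rng is None else rng
    k = np.arange(1, num_terms + 1)
    # Independent gamma variables g_k ~ Ga(b, 1), one row per draw.
    g = rng.gamma(shape=b, scale=1.0, size=(size, num_terms))
    denom = (k - 0.5)**2 + (c / (2.0 * np.pi))**2
    return np.sum(g / denom, axis=1) / (2.0 * np.pi**2)

rng = np.random.default_rng(42)
omega = sample_polya_gamma(b=1.0, c=0.0, size=20000, rng=rng)
# The mean of PG(b, 0) is b / 4, which gives a quick sanity check.
sample_mean = omega.mean()
```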
&lt;h4 id="property-i-recovering-the-original-model"&gt;Property I: Recovering the original model&lt;/h4&gt;
&lt;p&gt;First we show that we can recover the original likelihood $p(y_n | f_n)$
by integrating out $\boldsymbol{\omega}$.
Before we proceed, note that the $p(y_n | f_n)$ can be expressed more
succinctly as
&lt;/p&gt;
$$
p(y_n | f_n) = \frac{e^{y_n f_n}}{1 + e^{f_n}}.
$$&lt;p&gt;
Refer to &lt;a href="#appendix"&gt;the Appendix&lt;/a&gt;
for derivations.
Next, note the following property of Pólya-gamma variables:&lt;/p&gt;
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-blue-100 dark:bg-blue-900 border-blue-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-blue-600 dark:text-blue-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m16.862 4.487l1.687-1.688a1.875 1.875 0 1 1 2.652 2.652L6.832 19.82a4.5 4.5 0 0 1-1.897 1.13l-2.685.8l.8-2.685a4.5 4.5 0 0 1 1.13-1.897zm0 0L19.5 7.125"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Note&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;h4 id="laplace-transform-of-the-pólya-gamma-density-polson-et-al-2013"&gt;Laplace transform of the Pólya-gamma density (Polson et al. 2013)&lt;/h4&gt;
&lt;p&gt;Based on the Laplace transform
of the Pólya-gamma density function, we can derive the following relationship:
&lt;/p&gt;
$$
\frac{\left (e^{u} \right )^a}{\left (1 + e^{u} \right )^b} =
\frac{1}{2^b} \exp{(\kappa u)} \
\int_0^\infty \exp{\left ( - \frac{u^2}{2} \omega \right )}
p(\omega) d\omega,
$$&lt;p&gt;
where $\kappa = a - \frac{b}{2}$ and $p(\omega) = \mathrm{PG}(\omega | b, 0)$.&lt;/p&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;Therefore, by substituting $\kappa = \kappa_n, a = y_n, b = 1$ and $u = f_n$
we get
&lt;/p&gt;
$$
\begin{align*}
\int p(y_n, \omega_n | f_n) d\omega_n &amp;=
\int p(y_n | f_n, \omega_n) p(\omega_n) d\omega_n \newline &amp;=
\frac{1}{2} \int \exp{\left \{ - \frac{\omega_n}{2} \left (f_n^2 -
2 f_n \frac{\kappa_n}{\omega_n} \right ) \right \}} p(\omega_n) d\omega_n \newline &amp;=
\frac{1}{2} \exp{(\kappa_n f_n)}
\int \exp{\left ( - \frac{f_n^2}{2} \omega_n \right )} p(\omega_n) d\omega_n \newline &amp;=
\frac{\left (e^{f_n} \right )^{y_n}}{1 + e^{f_n}} = p(y_n | f_n)
\end{align*}
$$&lt;p&gt;
as required.&lt;/p&gt;
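&lt;p&gt;This identity can also be checked numerically. The sketch below, assuming NumPy, estimates the integral by Monte Carlo using approximate $\mathrm{PG}(1, 0)$ draws obtained from a truncated sum of gamma variables (an approximation chosen for illustration) and compares it with the original Bernoulli likelihood. The values of &lt;code&gt;f_n&lt;/code&gt; and &lt;code&gt;y_n&lt;/code&gt; are arbitrary.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
num_draws, num_terms = 20000, 200

# Approximate PG(1, 0) draws via the truncated sum-of-gammas representation.
k = np.arange(1, num_terms + 1)
g = rng.gamma(shape=1.0, scale=1.0, size=(num_draws, num_terms))
omega = np.sum(g / (k - 0.5)**2, axis=1) / (2.0 * np.pi**2)

f_n, y_n = 1.3, 1
kappa_n = y_n - 0.5

# Monte Carlo estimate of the joint likelihood with omega_n marginalized out.
conditional = 0.5 * np.exp(kappa_n * f_n - 0.5 * omega * f_n**2)
marginal_estimate = conditional.mean()

# Original Bernoulli likelihood in its succinct form.
original = np.exp(y_n * f_n) / (1.0 + np.exp(f_n))
```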
&lt;h4 id="property-ii-gaussian-gaussian-conjugacy"&gt;Property II: Gaussian-Gaussian conjugacy&lt;/h4&gt;
&lt;p&gt;Let us define the diagonal matrix $\boldsymbol{\Omega} = \mathrm{diag}(\omega_1 \cdots \omega_N)$ and vector $\mathbf{z} = \boldsymbol{\Omega}^{-1} \boldsymbol{\kappa}$.
More simply, $\mathbf{z}$ is the vector with $n$th element $z_n = {\kappa_n} / {\omega_n}$.
Hence, by completing the square,
the per-datapoint conditional likelihood $p(y_n | f_n, \omega_n)$ above can be written as
&lt;/p&gt;
$$
\begin{align*}
p(y_n | f_n, \omega_n) &amp; \propto
\exp{\left \{ - \frac{\omega_n}{2} \left (f_n - \frac{\kappa_n}{\omega_n} \right )^2 \right \}} \newline &amp; = \exp{\left \{ - \frac{\omega_n}{2} \left (f_n - z_n \right )^2 \right \}}
\end{align*}
$$&lt;p&gt;
Importantly, this implies that the conditional likelihood over all
variables $p(\mathbf{y} | \mathbf{f}, \boldsymbol{\omega})$ is simply a
multivariate Gaussian distribution up to a constant factor
&lt;/p&gt;
$$
p(\mathbf{y} | \mathbf{f}, \boldsymbol{\omega}) \propto \mathcal{N}\left (\boldsymbol{\Omega}^{-1} \boldsymbol{\kappa} | \mathbf{f}, \boldsymbol{\Omega}^{-1} \right ).
$$&lt;p&gt;
Refer to &lt;a href="#appendix"&gt;the Appendix&lt;/a&gt;
for derivations.
Therefore, a Gaussian prior $p(\mathbf{f})$ is conjugate to the
conditional likelihood $p(\mathbf{y} | \mathbf{f}, \boldsymbol{\omega})$, which
leads to $p(\mathbf{f} | \mathbf{y}, \boldsymbol{\omega})$, the posterior
over $\mathbf{f}$ conditioned on the auxiliary latent
variables $\boldsymbol{\omega}$, also being a Gaussian&amp;mdash;a property that will
prove crucial to us in the next section.&lt;/p&gt;
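&lt;p&gt;The proportionality between the conditional likelihood and a Gaussian density can be verified numerically: their ratio should not depend on $f_n$. The following sketch assumes NumPy and SciPy, with arbitrary values for &lt;code&gt;omega_n&lt;/code&gt; and &lt;code&gt;y_n&lt;/code&gt;.&lt;/p&gt;

```python
import numpy as np
from scipy.stats import norm

omega_n, y_n = 0.8, 1
kappa_n = y_n - 0.5
z_n = kappa_n / omega_n

f_grid = np.linspace(-4.0, 4.0, 9)

# Conditional likelihood as originally defined.
conditional = 0.5 * np.exp(-0.5 * omega_n * (f_grid**2 - 2.0 * f_grid * z_n))
# Gaussian density N(z_n | f_n, 1 / omega_n), evaluated at each f_n on the grid.
gaussian = norm.pdf(z_n, loc=f_grid, scale=1.0 / np.sqrt(omega_n))

# If the two are proportional, this ratio does not depend on f_n.
ratio = conditional / gaussian
```

&lt;p&gt;The constant ratio works out to $\frac{1}{2} \sqrt{2 \pi / \omega_n} \exp{\left ( \kappa_n^2 / (2 \omega_n) \right )}$, which depends on $\omega_n$ but not on $f_n$.&lt;/p&gt;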
&lt;h3 id="inference-gibbs-sampling"&gt;Inference (Gibbs sampling)&lt;/h3&gt;
&lt;p&gt;We wish to compute the posterior
distribution $p(\mathbf{f}, \boldsymbol{\omega} | \mathbf{y})$, the
distribution over the hidden variables $(\mathbf{f}, \boldsymbol{\omega})$
conditioned on the observed variables $\mathbf{y}$.
To produce samples from this distribution
&lt;/p&gt;
$$
(\mathbf{f}^{(t)}, \boldsymbol{\omega}^{(t)}) \sim p(\mathbf{f}, \boldsymbol{\omega} | \mathbf{y}),
$$&lt;p&gt;
we can readily apply Gibbs sampling&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;, an MCMC
algorithm that can be seen as a special case of the Metropolis-Hastings algorithm.&lt;/p&gt;
&lt;p&gt;Each step of the Gibbs sampling procedure involves replacing the value of one
of the variables by a value drawn from the distribution of that variable
conditioned on the values of the remaining variables.
Specifically, we proceed as follows.
At step $t$, we have values $\mathbf{f}^{(t-1)}, \boldsymbol{\omega}^{(t-1)}$
sampled from the previous step.&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;We first replace $\mathbf{f}^{(t-1)}$ by a new
value $\mathbf{f}^{(t)}$ by sampling from the conditional distribution $p(\mathbf{f} | \mathbf{y}, \boldsymbol{\omega}^{(t-1)})$,
$$
\mathbf{f}^{(t)} \sim p(\mathbf{f} | \mathbf{y}, \boldsymbol{\omega}^{(t-1)}).
$$&lt;/li&gt;
&lt;li&gt;Then we replace $\boldsymbol{\omega}^{(t-1)}$ by $\boldsymbol{\omega}^{(t)}$ by sampling
from the conditional distribution $p(\boldsymbol{\omega}| \mathbf{f}^{(t)})$,
$$
\boldsymbol{\omega}^{(t)} \sim p(\boldsymbol{\omega}| \mathbf{f}^{(t)}),
$$
where we&amp;rsquo;ve used $\mathbf{f}^{(t)}$, the new value for $\mathbf{f}$ from step 1,
straight away in the current step. Note that we&amp;rsquo;ve dropped the conditioning
on $\mathbf{y}$, since $\boldsymbol{\omega}$ does not &lt;em&gt;a posteriori&lt;/em&gt; depend
on this variable.&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;We then proceed in like manner, cycling between the two variables in turn until
some convergence criterion is met.&lt;/p&gt;
&lt;p&gt;Suffice it to say, this requires us to first compute the conditional
posteriors $p(\mathbf{f} | \mathbf{y}, \boldsymbol{\omega})$
and $p(\boldsymbol{\omega}| \mathbf{f})$, the calculation of which will be the
subject of the next two subsections.&lt;/p&gt;
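&lt;p&gt;The alternating scheme above can be sketched generically. The code below assumes NumPy and uses a toy target, a standard bivariate Gaussian with correlation &lt;code&gt;rho&lt;/code&gt;, as a stand-in for the $(\mathbf{f}, \boldsymbol{\omega})$ model, because its two conditionals are known in closed form. The function &lt;code&gt;gibbs&lt;/code&gt; and its arguments are hypothetical names, with &lt;code&gt;a&lt;/code&gt; playing the role of $\mathbf{f}$ and &lt;code&gt;b&lt;/code&gt; the role of $\boldsymbol{\omega}$.&lt;/p&gt;

```python
import numpy as np

def gibbs(sample_a_given_b, sample_b_given_a, b_init, num_steps, rng):
    """Generic two-block Gibbs sampler: alternate draws from the two conditionals."""
    a_samples, b_samples = [], []
    b = b_init
    for _ in range(num_steps):
        a = sample_a_given_b(b, rng)   # step 1: refresh the first block
        b = sample_b_given_a(a, rng)   # step 2: refresh the second block, using the new a
        a_samples.append(a)
        b_samples.append(b)
    return np.array(a_samples), np.array(b_samples)

# Toy target: standard bivariate Gaussian with correlation rho, whose
# conditionals are N(rho * other, 1 - rho^2).
rho = 0.8
cond_std = np.sqrt(1.0 - rho**2)

def conditional(other, rng):
    return rng.normal(loc=rho * other, scale=cond_std)

rng = np.random.default_rng(0)
a, b = gibbs(conditional, conditional, b_init=0.0, num_steps=20000, rng=rng)

# Discard an initial burn-in period before summarizing the chain.
burn_in = 1000
sample_corr = np.corrcoef(a[burn_in:], b[burn_in:])[0, 1]
```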
&lt;h4 id="posterior-over-latent-function-values"&gt;Posterior over latent function values&lt;/h4&gt;
&lt;p&gt;The posterior over the latent function values $\mathbf{f}$ conditioned on the
auxiliary latent variables $\boldsymbol{\omega}$ is
&lt;/p&gt;
$$
p(\mathbf{f} | \mathbf{y}, \boldsymbol{\omega}) = \mathcal{N}(\mathbf{f} | \boldsymbol{\mu}, \boldsymbol{\Sigma}),
$$&lt;p&gt;
where
&lt;/p&gt;
$$
\boldsymbol{\mu} = \boldsymbol{\Sigma} \left ( \mathbf{S} \mathbf{m} + \boldsymbol{\kappa} \right )
\quad
\text{and}
\quad
\boldsymbol{\Sigma} = \left (\mathbf{S} + \boldsymbol{\Omega} \right )^{-1}.
$$&lt;p&gt;We readily obtain $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$ by noting,
as alluded to earlier, that
&lt;/p&gt;
$$
p(\mathbf{f}) = \mathcal{N}(\mathbf{m}, \mathbf{S}^{-1}),
\qquad
\text{and}
\qquad
p(\mathbf{y} | \mathbf{f}, \boldsymbol{\omega}) \propto \mathcal{N}\left (\boldsymbol{\Omega}^{-1} \boldsymbol{\kappa} | \mathbf{f}, \boldsymbol{\Omega}^{-1} \right ).
$$&lt;p&gt;
Thereafter, we can appeal to the following elementary properties of Gaussian
conditioning and perform some pattern-matching substitutions:&lt;/p&gt;
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-blue-100 dark:bg-blue-900 border-blue-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-blue-600 dark:text-blue-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m16.862 4.487l1.687-1.688a1.875 1.875 0 1 1 2.652 2.652L6.832 19.82a4.5 4.5 0 0 1-1.897 1.13l-2.685.8l.8-2.685a4.5 4.5 0 0 1 1.13-1.897zm0 0L19.5 7.125"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Note&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;h4 id="marginal-and-conditional-gaussians-bishop-section-233-pg-93"&gt;Marginal and Conditional Gaussians (Bishop, Section 2.3.3, pg. 93)&lt;/h4&gt;
&lt;p&gt;Given a marginal Gaussian distribution for $\mathbf{b}$ and a conditional Gaussian
distribution for $\mathbf{a}$ given $\mathbf{b}$ in the form&lt;/p&gt;
$$
\begin{align*}
p(\mathbf{b}) &amp; =
\mathcal{N}(\mathbf{b} | \mathbf{m}, \mathbf{S}^{-1}) \newline
p(\mathbf{a} | \mathbf{b}) &amp; =
\mathcal{N}(\mathbf{a} | \mathbf{W} \mathbf{b}, \boldsymbol{\Psi}^{-1})
\end{align*}
$$&lt;p&gt;
the marginal distribution of $\mathbf{a}$ and the conditional distribution
of $\mathbf{b}$ given $\mathbf{a}$ are given by
&lt;/p&gt;
$$
\begin{align*}
p(\mathbf{a}) &amp; =
\mathcal{N}(\mathbf{a} | \mathbf{W} \mathbf{m}, \boldsymbol{\Psi}^{-1} + \mathbf{W} \mathbf{S}^{-1} \mathbf{W}^{\top}) \newline
p(\mathbf{b} | \mathbf{a}) &amp; =
\mathcal{N}(\mathbf{b} | \boldsymbol{\mu}, \boldsymbol{\Sigma})
\end{align*}
$$&lt;p&gt;
where
&lt;/p&gt;
$$
\boldsymbol{\mu} = \boldsymbol{\Sigma} \left ( \mathbf{W}^{\top} \boldsymbol{\Psi} \mathbf{a} + \mathbf{S} \mathbf{m} \right ),
\quad
\text{and}
\quad
\boldsymbol{\Sigma} = \left (\mathbf{S} + \mathbf{W}^{\top} \boldsymbol{\Psi} \mathbf{W}\right )^{-1}.
$$&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;Note that we also could have derived this directly without resorting to
the formulae above by reducing the product of two exponential-quadratic
functions in $p(\mathbf{f} | \mathbf{y}, \boldsymbol{\omega}) \propto p(\mathbf{y} | \mathbf{f}, \boldsymbol{\omega}) p(\mathbf{f})$ into a single exponential-quadratic function
up to a constant factor.
It would, however, have been rather tedious and mundane.&lt;/p&gt;
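&lt;p&gt;The pattern-matching substitution can be checked numerically. The sketch below assumes NumPy and uses randomly generated stand-ins for the prior and the auxiliary variables. It computes $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$ both from the direct expressions and from the general conditioning formula with $\mathbf{W} = \mathbf{I}$, $\boldsymbol{\Psi} = \boldsymbol{\Omega}$ and $\mathbf{a} = \boldsymbol{\Omega}^{-1} \boldsymbol{\kappa}$.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
N = 4

m = rng.normal(size=N)                       # prior mean
A = rng.normal(size=(N, N))
S = A @ A.T + N * np.eye(N)                  # prior precision (positive definite)
omega = rng.uniform(0.1, 1.0, size=N)        # hypothetical auxiliary variables
y = np.array([0, 1, 1, 0])
kappa = y - 0.5

Omega = np.diag(omega)

# Direct form from the text.
Sigma = np.linalg.inv(S + Omega)
mu = Sigma @ (S @ m + kappa)

# General formula with W = I, Psi = Omega and a = Omega^{-1} kappa.
W, Psi, a = np.eye(N), Omega, kappa / omega
Sigma_general = np.linalg.inv(S + W.T @ Psi @ W)
mu_general = Sigma_general @ (W.T @ Psi @ a + S @ m)
```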
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-blue-100 dark:bg-blue-900 border-blue-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-blue-600 dark:text-blue-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m16.862 4.487l1.687-1.688a1.875 1.875 0 1 1 2.652 2.652L6.832 19.82a4.5 4.5 0 0 1-1.897 1.13l-2.685.8l.8-2.685a4.5 4.5 0 0 1 1.13-1.897zm0 0L19.5 7.125"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Note&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;h4 id="example-gaussian-process-prior"&gt;Example: Gaussian process prior&lt;/h4&gt;
&lt;p&gt;To make this more concrete, let us revisit the Gaussian process prior we
discussed earlier, namely,
&lt;/p&gt;
$$
p(\mathbf{f} | \mathbf{X}) = \mathcal{N}(\mathbf{m}, \mathbf{K}_X).
$$&lt;p&gt;
By substituting $\mathbf{S}^{-1} = \mathbf{K}_X$ from before, we obtain
&lt;/p&gt;
$$
p(\mathbf{f} | \mathbf{y}, \boldsymbol{\omega}) =
\mathcal{N}(\mathbf{f} | \boldsymbol{\Sigma} \left ( \mathbf{K}_X^{-1} \mathbf{m} + \boldsymbol{\kappa} \right ), \boldsymbol{\Sigma}),
$$&lt;p&gt;
where $\boldsymbol{\Sigma} = \left (\mathbf{K}_X^{-1} + \boldsymbol{\Omega} \right )^{-1}.$&lt;/p&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
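&lt;p&gt;In practice one may prefer not to invert $\mathbf{K}_X$ explicitly. A standard algebraic rearrangement, which is not part of the derivation above, gives $\boldsymbol{\Sigma} = (\mathbf{I} + \mathbf{K}_X \boldsymbol{\Omega})^{-1} \mathbf{K}_X$ and $\boldsymbol{\mu} = (\mathbf{I} + \mathbf{K}_X \boldsymbol{\Omega})^{-1} (\mathbf{m} + \mathbf{K}_X \boldsymbol{\kappa})$. The sketch below assumes NumPy, an illustrative squared exponential kernel, and made-up values for the labels and auxiliary variables, and confirms that both forms agree.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(1)
N = 6
X = np.linspace(-3.0, 3.0, N)

# Squared exponential kernel matrix (an illustrative choice) with a little jitter.
K_X = np.exp(-0.5 * (X[:, None] - X[None, :])**2) + 1e-6 * np.eye(N)
m = np.full(N, 0.5)
omega = rng.uniform(0.1, 1.0, size=N)        # hypothetical auxiliary variables
kappa = np.array([0, 1, 1, 0, 1, 0]) - 0.5
Omega = np.diag(omega)

# Direct form: Sigma = (K^{-1} + Omega)^{-1}, mu = Sigma (K^{-1} m + kappa).
K_inv = np.linalg.inv(K_X)
Sigma_direct = np.linalg.inv(K_inv + Omega)
mu_direct = Sigma_direct @ (K_inv @ m + kappa)

# Rearranged form: Sigma = (I + K Omega)^{-1} K, mu = (I + K Omega)^{-1} (m + K kappa).
B = np.eye(N) + K_X @ Omega
Sigma = np.linalg.solve(B, K_X)
mu = np.linalg.solve(B, m + K_X @ kappa)
```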
&lt;h4 id="posterior-over-auxiliary-variables"&gt;Posterior over auxiliary variables&lt;/h4&gt;
&lt;p&gt;The posterior over the auxiliary latent variables $\boldsymbol{\omega}$
conditioned on the latent function values $\mathbf{f}$ factorizes as
&lt;/p&gt;
$$
p(\boldsymbol{\omega}| \mathbf{f}) = \prod_{n=1}^{N} p(\omega_n | f_n),
$$&lt;p&gt;
where each factor
&lt;/p&gt;
$$
p(\omega_n | f_n) =
\frac{p(f_n, \omega_n)}{\int p(f_n, \omega_n) d\omega_n}.
$$&lt;p&gt;
Now, the joint factorizes as $p(f_n, \omega_n) = p(f_n | \omega_n) p(\omega_n)$ where
&lt;/p&gt;
$$
p(f_n | \omega_n) = \exp{\left (-\frac{f_n^2}{2}\omega_n \right )},
\quad
\text{and}
\quad
p(\omega_n) = \mathrm{PG}(\omega_n | 1, 0).
$$&lt;p&gt;
Hence, by the exponential tilting property
of the Pólya-gamma distribution, we have
&lt;/p&gt;
$$
p(\omega_n | f_n) = \mathrm{PG}(\omega_n | 1, f_n) \propto
\mathrm{PG}(\omega_n | 1, 0) \times
\exp{\left (-\frac{f_n^2}{2}\omega_n \right )} =
p(f_n, \omega_n).
$$&lt;p&gt;
We have omitted the normalizing constant $\int p(f_n, \omega_n) d\omega_n$
from our discussion for the sake of brevity.
If you&amp;rsquo;re interested in calculating it, refer to
&lt;a href="#appendix"&gt;the Appendix&lt;/a&gt;.&lt;/p&gt;
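&lt;p&gt;As a numerical sanity check of this conditional, the sketch below compares two routes to the mean of $p(\omega_n | f_n)$: drawing directly from $\mathrm{PG}(1, f_n)$, and reweighting $\mathrm{PG}(1, 0)$ draws by $\exp{\left (-\frac{f_n^2}{2}\omega_n \right )}$. It assumes NumPy and approximates Pólya-gamma draws with a truncated sum of gamma variables; &lt;code&gt;sample_pg&lt;/code&gt; is a hypothetical helper, and the value of &lt;code&gt;f_n&lt;/code&gt; is arbitrary.&lt;/p&gt;

```python
import numpy as np

def sample_pg(b, c, size, rng, num_terms=200):
    """Approximate PG(b, c) draws via a truncated sum of gamma variables."""
    k = np.arange(1, num_terms + 1)
    g = rng.gamma(shape=b, scale=1.0, size=(size, num_terms))
    denom = (k - 0.5)**2 + (c / (2.0 * np.pi))**2
    return np.sum(g / denom, axis=1) / (2.0 * np.pi**2)

rng = np.random.default_rng(0)
f_n = 2.0

# Route 1: draw directly from the conditional PG(1, f_n).
direct = sample_pg(1.0, f_n, size=20000, rng=rng)
direct_mean = direct.mean()

# Route 2: draw from the prior PG(1, 0) and reweight by exp(-f_n^2 omega / 2).
prior = sample_pg(1.0, 0.0, size=20000, rng=rng)
weights = np.exp(-0.5 * f_n**2 * prior)
tilted_mean = np.sum(weights * prior) / np.sum(weights)

# Known closed form for the mean of PG(1, c): tanh(c / 2) / (2 c).
exact_mean = np.tanh(f_n / 2.0) / (2.0 * f_n)
```

&lt;p&gt;Both estimates should agree with the closed-form mean up to Monte Carlo and truncation error.&lt;/p&gt;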
&lt;h2 id="implementation-weight-space-view"&gt;Implementation (Weight-space view)&lt;/h2&gt;
&lt;p&gt;Having presented the general form of an augmented model for Bayesian logistic
regression, we now derive a simple instance of this model to tackle a synthetic
one-dimensional classification problem.
In this particular implementation, we make the following choices:
(a) we incorporate a basis function to project inputs into a higher-dimensional feature space, and
(b) we consider an isotropic Gaussian prior on the weights.&lt;/p&gt;
&lt;h3 id="synthetic-one-dimensional-classification-problem"&gt;Synthetic one-dimensional classification problem&lt;/h3&gt;
&lt;p&gt;First we synthesize a one-dimensional classification problem for which
the &lt;em&gt;true&lt;/em&gt; class-membership probability $p(y = 1 | x)$ is both known and easy
to compute.
To this end, let us introduce the following one-dimensional Gaussians,
&lt;/p&gt;
$$
p(x) = \mathcal{N}(1, 1^2),
\qquad
\text{and}
\qquad
q(x) = \mathcal{N}(0, 2^2).
$$&lt;p&gt;In code we can specify these as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;scipy.stats&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We draw a total of $N$ samples, split evenly between the two distributions:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;X_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;draw_samples&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;where the function &lt;code&gt;draw_samples&lt;/code&gt; is defined as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;draw_samples&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_samples&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_top&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_samples&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;rate&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_bot&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_samples&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;num_top&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;X_top&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rvs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_top&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;X_bot&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rvs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_bot&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;X_top&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_bot&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The densities of both distributions and the samples drawn from each are shown in the
figure below.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/polya-gamma-bayesian-logistic-regression/figures/density_paper_1500x927.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Densities of two Gaussians and samples drawn from each.&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;From these samples, let us now construct a classification dataset by assigning
label $y = 1$ to inputs $x \sim p(x)$, and $y = 0$ to inputs $x \sim q(x)$.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;make_dataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;where the function &lt;code&gt;make_dataset&lt;/code&gt; is defined as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;make_dataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_pos&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_neg&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;expand_dims&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hstack&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;X_pos&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_neg&lt;/span&gt;&lt;span class="p"&gt;]),&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hstack&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_pos&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_neg&lt;/span&gt;&lt;span class="p"&gt;)])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Crucially, the true class-membership probability is given exactly by
&lt;/p&gt;
$$
p(y = 1 | x) = \frac{p(x)}{p(x) + q(x)},
$$&lt;p&gt;
thus providing a ground-truth yardstick by which to measure the quality of our
resulting predictions.&lt;/p&gt;
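&lt;p&gt;As a minimal sketch of this ground truth, the ratio can be evaluated directly from the two Gaussian densities. The helper names &lt;code&gt;gaussian_pdf&lt;/code&gt; and &lt;code&gt;class_probability_true&lt;/code&gt; are illustrative and not part of the original implementation. The standard library is used here in place of &lt;code&gt;scipy.stats.norm&lt;/code&gt; so that the snippet runs on its own, and equal class proportions (&lt;code&gt;rate=0.5&lt;/code&gt;) are assumed.&lt;/p&gt;

```python
import math


def gaussian_pdf(x, loc, scale):
    """Density of N(loc, scale^2) evaluated at x."""
    z = (x - loc) / scale
    return math.exp(-0.5 * z * z) / (scale * math.sqrt(2.0 * math.pi))


def class_probability_true(x):
    """True p(y = 1 | x) = p(x) / (p(x) + q(x)), assuming balanced classes."""
    p_x = gaussian_pdf(x, loc=1.0, scale=1.0)  # p(x) = N(1, 1^2)
    q_x = gaussian_pdf(x, loc=0.0, scale=2.0)  # q(x) = N(0, 2^2)
    return p_x / (p_x + q_x)


# Evaluate the ground-truth probability at a few illustrative inputs.
for x in (-4.0, 0.0, 1.0, 4.0):
    print(f"x = {x:+.1f}, p(y = 1 | x) = {class_probability_true(x):.3f}")
```

&lt;p&gt;Because $q(x)$ has heavier tails than $p(x)$, the probability falls toward zero on both sides, away from the mean of $p(x)$.&lt;/p&gt;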
&lt;p&gt;The class-membership probability $p(y = 1 | x)$ is shown in the figure below as
the black curve, along with the dataset $\mathcal{D}_N = \{(\mathbf{x}_n, y_n)\}_{n=1}^N$
where positive instances are colored red and negative instances are colored blue.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/polya-gamma-bayesian-logistic-regression/figures/class_prob_true_paper_1500x927.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Classification dataset $\mathcal{D}_N = \{(\mathbf{x}_n, y_n)\}_{n=1}^N$ and the true class-posterior probability.&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;h3 id="prior-1"&gt;Prior&lt;/h3&gt;
&lt;p&gt;To increase the flexibility of our model, we introduce a basis
function $\phi: \mathbb{R}^{D} \to \mathbb{R}^{K}$ that projects
$D$-dimensional input vectors into a $K$-dimensional vector space.
Accordingly, we introduce matrix $\boldsymbol{\Phi} \in \mathbb{R}^{N \times K}$
such that the $n$th column of $\boldsymbol{\Phi}^{\top}$ consists of the
vector $\phi(\mathbf{x}_n)$.
Hence, we assume &lt;em&gt;a priori&lt;/em&gt; that the latent function is of the form
&lt;/p&gt;
$$
f(\mathbf{x}) = \boldsymbol{\beta}^{\top} \phi(\mathbf{x}),
$$&lt;p&gt;
and express the vector of latent function values
as $\mathbf{f} = \boldsymbol{\Phi} \boldsymbol{\beta}$.
In this example, we consider a simple polynomial basis function,
&lt;/p&gt;
$$
\phi(x) = \begin{bmatrix} 1 &amp; x &amp; x^2 &amp; \cdots &amp; x^{K-1} \end{bmatrix}^{\top}.
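$$&lt;p&gt;In code, we compute $\boldsymbol{\Phi}$ by calling the following, where &lt;code&gt;degree&lt;/code&gt; plays the role of $K$:&lt;/p&gt;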
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;basis_function&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;degree&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;degree&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;where the function &lt;code&gt;basis_function&lt;/code&gt; is defined as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;basis_function&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;degree&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;power&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;degree&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Let us define
the prior over weights as a simple isotropic Gaussian with
precision $\alpha &gt; 0$,
&lt;/p&gt;
$$
p(\boldsymbol{\beta}) = \mathcal{N}(\mathbf{0}, \alpha^{-1} \mathbf{I}),
$$&lt;p&gt;
and the prior over each local auxiliary latent variable as before,
&lt;/p&gt;
$$
p(\omega_n) = \mathrm{PG}(\omega_n | 1, 0).
$$&lt;p&gt;
Since we have analytic forms for the conditional posteriors, we don&amp;rsquo;t need to
implement the priors explicitly.
However, in order to initialize the Gibbs sampler, we may want to be able to
sample from the prior.
Let us do this using the prior over weights:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;m&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;2.0&lt;/span&gt; &lt;span class="c1"&gt;# prior precision&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;S_inv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="c1"&gt;# prior covariance, i.e. the inverse of the precision matrix alpha * I&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# initialize `beta`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;multivariate_normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cov&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;S_inv&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;or more simply:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;2.0&lt;/span&gt; &lt;span class="c1"&gt;# prior precision&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# initialize `beta`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id="conditional-likelihood"&gt;Conditional likelihood&lt;/h3&gt;
&lt;p&gt;The conditional likelihood is defined as before, except we instead
condition on weights $\boldsymbol{\beta}$ and substitute occurrences
of $\mathbf{f}$ with $\boldsymbol{\Phi} \boldsymbol{\beta}$,
&lt;/p&gt;
$$
p(\mathbf{y} | \boldsymbol{\beta}, \boldsymbol{\omega}) \propto \mathcal{N}\left (\boldsymbol{\Omega}^{-1} \boldsymbol{\kappa} | \boldsymbol{\Phi} \boldsymbol{\beta}, \boldsymbol{\Omega}^{-1} \right ).
$$&lt;h3 id="inference-and-prediction-1"&gt;Inference and Prediction&lt;/h3&gt;
&lt;h4 id="posterior-over-latent-function-values-1"&gt;Posterior over latent weights&lt;/h4&gt;
&lt;p&gt;The posterior over the latent weights $\boldsymbol{\beta}$ conditioned on the
auxiliary latent variables $\boldsymbol{\omega}$ is
&lt;/p&gt;
$$
p(\boldsymbol{\beta} | \mathbf{y}, \boldsymbol{\omega}) = \mathcal{N}(\boldsymbol{\beta} | \boldsymbol{\Sigma} \boldsymbol{\Phi}^{\top} \boldsymbol{\kappa}, \boldsymbol{\Sigma}),
$$&lt;p&gt;
where
&lt;/p&gt;
$$
\boldsymbol{\Sigma} = \left (\boldsymbol{\Phi}^{\top} \boldsymbol{\Omega} \boldsymbol{\Phi} + \alpha \mathbf{I} \right )^{-1}.
$$&lt;p&gt;Let us implement the function that computes the mean and covariance
of $p(\boldsymbol{\beta} | \mathbf{y}, \boldsymbol{\omega})$:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;conditional_posterior_weights&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;omega&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Sigma_inv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;omega&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;solve&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Sigma_inv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;solve&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Sigma_inv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;and a function to return samples from the multivariate Gaussian parameterized
by this mean and covariance:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;gassian_sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cov&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;random_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;check_random_state&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;multivariate_normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cov&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;cov&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h4 id="posterior-over-auxiliary-variables-1"&gt;Posterior over auxiliary variables&lt;/h4&gt;
&lt;p&gt;The conditional posterior over the local auxiliary variable $\omega_n$ is
defined as before, except we instead condition on weights $\boldsymbol{\beta}$
and substitute occurrences of $f_n$ with $\boldsymbol{\beta}^{\top} \phi(\mathbf{x}_n)$,
&lt;/p&gt;
$$
p(\omega_n | \boldsymbol{\beta}) \propto
\mathrm{PG}(\omega_n | 1, \boldsymbol{\beta}^{\top} \phi(\mathbf{x}_n)).
$$&lt;p&gt;Let us implement a function to compute the parameters of the posterior
Pólya-gamma distribution:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;conditional_posterior_auxiliary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;c&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;and accordingly a function to return samples from this distribution:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;polya_gamma_sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;PyPolyaGamma&lt;/span&gt;&lt;span class="p"&gt;()):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;shape mismatch&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;omega&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;empty_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pgdrawv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;omega&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;omega&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;where we have imported the &lt;code&gt;PyPolyaGamma&lt;/code&gt; object from
the &lt;code&gt;pypolyagamma&lt;/code&gt; package:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;from pypolyagamma import PyPolyaGamma
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The &lt;code&gt;pypolyagamma&lt;/code&gt; package can be installed via &lt;code&gt;pip&lt;/code&gt; as usual:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;$ pip install pypolyagamma
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;For context, this package is a
port, created by S. Linderman, of the original R package authored by J. Windle,
which implements the method described in their paper on the efficient sampling
of Pólya-gamma variables&lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
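&lt;p&gt;One simple way to sanity-check draws from a Pólya-gamma sampler is to compare their empirical average against the known closed-form mean, $\mathbb{E}[\omega] = \frac{b}{2c} \tanh(c / 2)$ for $\omega \sim \mathrm{PG}(b, c)$, which tends to $b / 4$ as $c \to 0$. The sketch below computes only this analytic mean. The helper name &lt;code&gt;polya_gamma_mean&lt;/code&gt; is illustrative and is not part of &lt;code&gt;pypolyagamma&lt;/code&gt;.&lt;/p&gt;

```python
import math


def polya_gamma_mean(b, c):
    """Analytic mean of PG(b, c): b / (2c) * tanh(c / 2), with limit b / 4 at c = 0."""
    if c == 0.0:
        # Limiting value, since tanh(c / 2) / (2c) tends to 1 / 4 as c tends to 0.
        return 0.25 * b
    return b * math.tanh(0.5 * c) / (2.0 * c)


# Mean under the prior PG(1, 0) and under PG(1, c) for a few tilting parameters.
for c in (0.0, 0.5, 2.0, 8.0):
    print(f"c = {c:.1f}, E[omega] = {polya_gamma_mean(1.0, c):.4f}")
```

&lt;p&gt;For a fixed value of &lt;code&gt;c&lt;/code&gt;, the average of many samples returned by &lt;code&gt;polya_gamma_sample&lt;/code&gt; should approach this value as the number of draws grows.&lt;/p&gt;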
&lt;h4 id="gibbs-sampling"&gt;Gibbs sampling&lt;/h4&gt;
&lt;p&gt;With these functions defined, we can define the Gibbs sampling procedure by the
simple for-loop below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# preprocessing&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;kappa&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;Phi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;basis_function&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;degree&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;degree&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# initialize `beta`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_iterations&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;conditional_posterior_auxiliary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;omega&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;polya_gamma_sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;conditional_posterior_weights&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;omega&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gassian_sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We now visualize the samples $(\boldsymbol{\beta}^{(t)}, \boldsymbol{\omega}^{(t)})$
produced by this procedure.
In the figures that follow, the hue of each sample encodes its step
counter $t$ along a perceptually uniform colormap.&lt;/p&gt;
&lt;p&gt;First, we show the sampled weight vector $\boldsymbol{\beta}^{(t)} \in \mathbb{R}^K$
where we have set $K = 3$.
We plot the $i$th entry $\beta_i^{(t)}$ against the $j$th entry $\beta_j^{(t)}$
for all pairs of zero-based indices $0 \leq i &lt; j &lt; K$.
&lt;figure&gt;&lt;img src="https://tiao.io/posts/polya-gamma-bayesian-logistic-regression/figures/beta_paper_600x600.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Parameter $\boldsymbol{\beta}^{(t)}$ samples as Gibbs sampling iteration $t$ increases.&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
We find a strong correlation between $\beta_1$ and $\beta_2$, the
coefficients associated with the linear and quadratic terms of our augmented
feature representation, respectively.
Furthermore, we find $\beta_1$ to consistently have a relatively large
magnitude.&lt;/p&gt;
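&lt;p&gt;To put a number on the correlation seen in such a scatter plot, one option is to collect the draws $\boldsymbol{\beta}^{(t)}$ into an array of shape $(T, K)$ and compute its empirical correlation matrix. The sketch below is only illustrative: the helper &lt;code&gt;pairwise_correlations&lt;/code&gt; and the synthetic array &lt;code&gt;fake_betas&lt;/code&gt; are assumptions made for this example, and the numbers it prints are not the samples shown in the figure.&lt;/p&gt;

```python
import numpy as np


def pairwise_correlations(betas):
    """Empirical K x K correlation matrix of draws stored with shape (T, K)."""
    betas = np.asarray(betas, dtype=float)
    # rowvar=False: each column is one coefficient, each row is one draw.
    return np.corrcoef(betas, rowvar=False)


# Synthetic stand-in for the Gibbs draws (an assumption, NOT the post's samples).
# The covariance below is arbitrary and only makes entries 1 and 2 correlated.
rng = np.random.default_rng(0)
cov = np.array([[1.0, 0.0, 0.0],
                [0.0, 1.0, -0.8],
                [0.0, -0.8, 1.0]])
fake_betas = rng.multivariate_normal(mean=np.zeros(3), cov=cov, size=2000)

corr = pairwise_correlations(fake_betas)
print(np.round(corr, 2))
```

&lt;p&gt;With the sampler from this post, the same function could be applied to the values of &lt;code&gt;beta&lt;/code&gt; stacked across iterations, reading off the entry in row 1, column 2.&lt;/p&gt;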
&lt;p&gt;Second, we show the sampled auxiliary latent variables $\boldsymbol{\omega}^{(t)}$ by
plotting the pairs $(x_n, \omega_n^{(t)})$.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/polya-gamma-bayesian-logistic-regression/figures/omega_paper_1500x927.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Auxiliary variable $\omega_n^{(t)}$ samples as Gibbs sampling iteration $t$ increases. For visualization purposes, each $\omega_n^{(t)}$ is placed at its corresponding input location $x_n$ along the horizontal axis.&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;As expected, we find longer-tailed distributions in the variables $\omega_n$
that are associated with negative examples.&lt;/p&gt;
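&lt;p&gt;When reading this figure, it may help to recall that each $\omega_n$ is drawn from $\mathrm{PG}(1, c_n)$ with $c_n = \boldsymbol{\beta}^{\top} \phi(\mathbf{x}_n)$, as computed in &lt;code&gt;conditional_posterior_auxiliary&lt;/code&gt;. The sketch below evaluates the standard closed-form mean of a $\mathrm{PG}(b, c)$ variable, $\frac{b}{2c} \tanh(c / 2)$, which tends to $b / 4$ as $c \to 0$. The helper name &lt;code&gt;polya_gamma_mean&lt;/code&gt; and the grid &lt;code&gt;c_values&lt;/code&gt; are introduced here for illustration and do not appear in the code of this post.&lt;/p&gt;

```python
import numpy as np


def polya_gamma_mean(b, c):
    """Mean of PG(b, c): b * tanh(c / 2) / (2 * c), with limit b / 4 at c = 0."""
    c = np.asarray(c, dtype=float)
    near_zero = np.isclose(c, 0.0)
    # Placeholder value so the division below never sees an exact zero.
    safe_c = np.where(near_zero, 1.0, c)
    return np.where(near_zero, b / 4.0, b * np.tanh(safe_c / 2.0) / (2.0 * safe_c))


# Illustrative values of c = phi(x)^T beta (an assumption, not taken from the post).
c_values = np.array([-4.0, -1.0, 0.0, 1.0, 4.0])
means = polya_gamma_mean(1.0, c_values)
print(means)
```

&lt;p&gt;The mean depends on $c$ only through its magnitude and shrinks as $|c|$ grows, so the typical scale of $\omega_n$ varies with how far $f(\mathbf{x}_n)$ is from zero.&lt;/p&gt;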
&lt;p&gt;Finally, we plot the sampled class-membership probability predictions
&lt;/p&gt;
$$
\pi^{(t)}(\mathbf{x}) = \sigma(f^{(t)}(\mathbf{x})),
\quad
\text{where}
\quad
f^{(t)}(\mathbf{x}) = {\boldsymbol{\beta}^{(t)}}^{\top} \phi(\mathbf{x}),
$$&lt;p&gt;
in the figure below:&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/polya-gamma-bayesian-logistic-regression/figures/class_prob_pred_paper_1500x927.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Predicted class-membership probability $\pi^{(t)}(\mathbf{x})$ as Gibbs sampling iteration $t$ increases.&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;At least qualitatively, we find that the sampling procedure produces
predictions that fit the true class-membership probability reasonably well.&lt;/p&gt;
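&lt;p&gt;As a minimal sketch of how curves like these could be computed, the snippet below evaluates $\pi^{(t)}(\mathbf{x}) = \sigma({\boldsymbol{\beta}^{(t)}}^{\top} \phi(\mathbf{x}))$ on a grid of inputs for a stack of weight vectors. Here &lt;code&gt;basis_function&lt;/code&gt; mirrors the helper of the same name in this post, while &lt;code&gt;predict_proba&lt;/code&gt;, &lt;code&gt;x_grid&lt;/code&gt; and the placeholder array &lt;code&gt;betas&lt;/code&gt; are assumptions made for illustration rather than actual Gibbs draws.&lt;/p&gt;

```python
import numpy as np


def basis_function(x, degree=3):
    # Polynomial features with powers 0, ..., degree - 1, as in the post.
    return np.power(x, np.arange(degree))


def predict_proba(betas, x, degree=3):
    """Class-membership probabilities of shape (T, N) for T weight vectors."""
    Phi = basis_function(np.reshape(x, (-1, 1)), degree=degree)  # (N, K)
    logits = np.atleast_2d(betas) @ Phi.T  # (T, N)
    return 1.0 / (1.0 + np.exp(-logits))  # logistic sigmoid


x_grid = np.linspace(-5.0, 5.0, 11)
# Placeholder weight vectors (hypothetical values, not samples from the post).
betas = np.array([[0.0, 0.0, 0.0],
                  [0.5, 1.0, -0.25]])
probs = predict_proba(betas, x_grid)
print(probs.shape)
```

&lt;p&gt;Each row of &lt;code&gt;probs&lt;/code&gt; corresponds to one weight vector, so plotting the rows against &lt;code&gt;x_grid&lt;/code&gt; with a hue that follows $t$ would give one curve per iteration.&lt;/p&gt;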
&lt;h3 id="code"&gt;Code&lt;/h3&gt;
&lt;p&gt;The full code is reproduced below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;scipy.stats&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;pypolyagamma&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;PyPolyaGamma&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;.utils&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;draw_samples&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;make_dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;basis_function&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;conditional_posterior_auxiliary&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;polya_gamma_sample&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;conditional_posterior_weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gassian_sample&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# constants&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;num_train&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;num_iterations&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1000&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;degree&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;2.0&lt;/span&gt; &lt;span class="c1"&gt;# prior precision&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;seed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;8888&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;random_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;RandomState&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;pg&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;PyPolyaGamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# generate dataset&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;X_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;draw_samples&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;make_dataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# preprocessing&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;kappa&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;Phi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;basis_function&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;degree&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;degree&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# initialize `beta`&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_iterations&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;conditional_posterior_auxiliary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;omega&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;polya_gamma_sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;conditional_posterior_weights&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;omega&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gassian_sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;where the module &lt;code&gt;utils.py&lt;/code&gt; contains:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;sklearn.utils&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;check_random_state&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;pypolyagamma&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;PyPolyaGamma&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;draw_samples&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_samples&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_top&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_samples&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;rate&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_bot&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_samples&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;num_top&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;X_top&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rvs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_top&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;X_bot&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rvs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_bot&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;X_top&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_bot&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;make_dataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_pos&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_neg&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;expand_dims&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hstack&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;X_pos&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_neg&lt;/span&gt;&lt;span class="p"&gt;]),&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hstack&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_pos&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_neg&lt;/span&gt;&lt;span class="p"&gt;)])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;basis_function&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;degree&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# powers 0, ..., degree - 1: `degree=3` gives the K = 3 features [1, x, x^2]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;power&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;degree&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;polya_gamma_sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;PyPolyaGamma&lt;/span&gt;&lt;span class="p"&gt;()):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;shape mismatch&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;omega&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;empty_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pgdrawv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;omega&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;omega&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;gassian_sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cov&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;random_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;check_random_state&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;multivariate_normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cov&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;cov&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;conditional_posterior_weights&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;omega&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;eye&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# posterior precision: Phi^T diag(omega) Phi + alpha * I&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Sigma_inv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;omega&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;eye&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# posterior mean Sigma Phi^T kappa and covariance Sigma, via linear solves&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;solve&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Sigma_inv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;solve&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Sigma_inv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;conditional_posterior_auxiliary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;c&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id="bonus-gibbs-sampling-with-mutual-recursion-and-generator-delegation"&gt;Bonus: Gibbs sampling with mutual recursion and generator delegation&lt;/h3&gt;
&lt;p&gt;The Gibbs sampling procedure naturally lends itself to implementations based
on mutual recursion.
Combining this with the &lt;code&gt;yield from&lt;/code&gt; expression for generator
delegation, we can succinctly replace the for-loop with the following mutually
recursive functions:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;gibbs_sampler&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;conditional_posterior_auxiliary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;omega&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;polya_gamma_sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;yield from&lt;/span&gt; &lt;span class="n"&gt;gibbs_sampler_helper&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;omega&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;gibbs_sampler_helper&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;omega&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;conditional_posterior_weights&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;omega&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gaussian_sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;yield&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;omega&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;yield from&lt;/span&gt; &lt;span class="n"&gt;gibbs_sampler&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Now you can use &lt;code&gt;gibbs_sampler&lt;/code&gt; as a
generator,
for example, to explicitly iterate over it in a for-loop:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;omega&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;gibbs_sampler&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;stop_predicate&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;break&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# do something&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;pass&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;or by making use of
&lt;code&gt;islice&lt;/code&gt; and other &lt;code&gt;itertools&lt;/code&gt;
primitives:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;itertools&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;islice&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# example: collect beta and omega samples into respective lists&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;betas&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;omegas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;islice&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;gibbs_sampler&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kappa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;num_iterations&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;There are a few obvious drawbacks to this implementation.
First, while it may be a lot of fun to write, it will probably not be as fun to
read when you revisit it later on.
Second, every iteration nests another pair of generator frames, so you may find
yourself hitting the maximum recursion depth before you have reached a
sufficient number of iterations for the warm-up or &amp;ldquo;burn-in&amp;rdquo; phase to
have completed.
Needless to say, the latter can make this implementation a non-starter.&lt;/p&gt;
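&lt;p&gt;One way around the recursion limit is to keep the generator interface but
replace the mutual recursion with a plain loop. The following is a minimal
sketch rather than part of the implementation above:
&lt;code&gt;gibbs_sampler_iterative&lt;/code&gt;, &lt;code&gt;sample_omega&lt;/code&gt; and
&lt;code&gt;sample_beta&lt;/code&gt; are hypothetical names, and the two callables stand in
for the conditional sampling steps. Toy deterministic updates are used here so
that the snippet runs on its own.&lt;/p&gt;

```python
import sys
from itertools import islice


def gibbs_sampler_iterative(beta, sample_omega, sample_beta):
    """Yield (beta, omega) pairs indefinitely, without recursion.

    sample_omega and sample_beta are hypothetical callables standing in for
    draws from the two conditional posteriors.
    """
    while True:
        omega = sample_omega(beta)  # draw omega given the current beta
        beta = sample_beta(omega)   # draw beta given the new omega
        yield beta, omega


# Toy deterministic stand-ins for the two conditional draws (assumptions).
def toy_sample_omega(beta):
    return beta + 1.0


def toy_sample_beta(omega):
    return 0.5 * omega


# Request more iterations than the recursion limit would permit.
num_iterations = sys.getrecursionlimit() + 100
sampler = gibbs_sampler_iterative(0.0, toy_sample_omega, toy_sample_beta)
betas, omegas = zip(*islice(sampler, num_iterations))
```

&lt;p&gt;Because each iteration reuses the same stack frame, the number of iterations
is no longer bounded by the recursion depth, and the generator can still be
consumed with a for-loop or &lt;code&gt;islice&lt;/code&gt; exactly as before.&lt;/p&gt;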
&lt;h2 id="links-and-further-readings"&gt;Links and Further Readings&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Papers:
&lt;ul&gt;
&lt;li&gt;Original paper (Polson et al., 2013)&lt;sup id="fnref1:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
&lt;li&gt;Extended to GP classification (Wenzel et al., 2019)&lt;sup id="fnref:7"&gt;&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref"&gt;7&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
&lt;li&gt;Few-shot classification with GPs and the one-vs-each likelihood (Snell et al., 2020)&lt;sup id="fnref:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;Blog posts:
&lt;ul&gt;
&lt;li&gt;
by G. Gundersen&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;Code:
&lt;ul&gt;
&lt;li&gt;
: A Python package by S. Linderman&lt;/li&gt;
&lt;li&gt;
: An R package by J. Windle&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;hr&gt;
&lt;p&gt;Cite as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{tiao2021polyagamma,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = &amp;#34;{A} {P}rimer on {P}ólya-gamma {R}andom {V}ariables - {P}art II: {B}ayesian {L}ogistic {R}egression&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = &amp;#34;Tiao, Louis C&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; journal = &amp;#34;tiao.io&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = &amp;#34;2021&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = &amp;#34;https://tiao.io/post/polya-gamma-bayesian-logistic-regression/&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h2 id="appendix"&gt;Appendix&lt;/h2&gt;
&lt;h3 id="i"&gt;I&lt;/h3&gt;
&lt;p&gt;First, note that the logistic function can be written as
&lt;/p&gt;
$$
\sigma(u) = \frac{1}{1+e^{-u}} = \frac{e^u}{1+e^u}
$$&lt;p&gt;
Therefore, we have
&lt;/p&gt;
$$
\begin{align*}
p(y_n | f_n) &amp;=
\left ( \frac{e^{f_n}}{1+e^{f_n}} \right )^{y_n}
\left ( \frac{\left (1+e^{f_n} \right ) - e^{f_n}}{1+e^{f_n}} \right )^{1-y_n} \newline &amp;=
\left ( \frac{e^{f_n}}{1+e^{f_n}} \right )^{y_n}
\left ( \frac{1}{1+e^{f_n}} \right )^{1-y_n} \newline &amp;=
\left (e^{f_n} \right )^{y_n} \left ( \frac{1}{1+e^{f_n}} \right )^{y_n}
\left ( \frac{1}{1+e^{f_n}} \right )^{1-y_n} \newline &amp;=
\frac{e^{y_n f_n}}{1 + e^{f_n}}
\end{align*}
$$&lt;h3 id="ii"&gt;II&lt;/h3&gt;
&lt;p&gt;The conditional likelihood factorizes as
&lt;/p&gt;
$$
\begin{align*}
p(\mathbf{y} | \mathbf{f}, \boldsymbol{\omega}) &amp;=
\prod_{n=1}^N p(y_n | f_n, \omega_n) \newline &amp;\propto
\prod_{n=1}^N \exp{\left ( - \frac{\omega_n}{2} \left (f_n - z_n \right )^2 \right )} \newline &amp;=
\exp{\left ( - \frac{1}{2} \sum_{n=1}^N \omega_n \left (f_n - z_n \right )^2 \right )} \newline &amp;=
\exp{\left \{ - \frac{1}{2} (\mathbf{f} - \mathbf{z})^{\top} \boldsymbol{\Omega} (\mathbf{f} - \mathbf{z}) \right \}} \newline &amp;\propto
\mathcal{N}\left (\boldsymbol{\Omega}^{-1} \boldsymbol{\kappa} | \mathbf{f}, \boldsymbol{\Omega}^{-1} \right )
\end{align*}
$$&lt;h3 id="iii"&gt;III&lt;/h3&gt;
&lt;p&gt;We have omitted the normalizing constant $\int p(f_n, \omega_n) d\omega_n$
from our discussion for the sake of brevity since it is not required to carry
out inference using Gibbs sampling.
However, this is easy to compute, simply by referring to the Laplace transform
of the $\mathrm{PG}(1, 0)$ distribution:&lt;/p&gt;
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-blue-100 dark:bg-blue-900 border-blue-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-blue-600 dark:text-blue-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m16.862 4.487l1.687-1.688a1.875 1.875 0 1 1 2.652 2.652L6.832 19.82a4.5 4.5 0 0 1-1.897 1.13l-2.685.8l.8-2.685a4.5 4.5 0 0 1 1.13-1.897zm0 0L19.5 7.125"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Note&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;h4 id="laplace-transform-of-the--distribution-polson-et-al-2013"&gt;Laplace transform of the $\mathrm{PG}(1, 0)$ distribution (Polson et al. 2013)&lt;/h4&gt;
&lt;p&gt;The Laplace transform
of the $\mathrm{PG}(1, 0)$ distribution is
&lt;/p&gt;
$$
\mathbb{E}_{\omega \sim \mathrm{PG}(1, 0)}[\exp(-\omega t)] =
\frac{1}{\cosh{\left(\sqrt{\frac{t}{2}}\right)}}.
$$&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;Hence, by making the substitution $t = \frac{f_n^2}{2}$, we obtain
&lt;/p&gt;
$$
\int p(f_n, \omega_n) d\omega_n =
\int \exp{\left (-\frac{f_n^2}{2}\omega_n \right )} \mathrm{PG}(\omega_n | 1, 0) d\omega_n =
\frac{1}{\cosh{\left(\frac{f_n}{2}\right)}}.
$$&lt;p&gt;
Therefore, we have
&lt;/p&gt;
$$
\begin{align*}
p(\omega_n | f_n) &amp; =
\frac{p(f_n, \omega_n)}{\int p(f_n, \omega_n) d\omega_n} \newline &amp; =
\cosh{\left(\frac{f_n}{2}\right)} \exp{\left (-\frac{f_n^2}{2}\omega_n \right )}
\mathrm{PG}(\omega_n | 1, 0) \newline &amp; = \mathrm{PG}(\omega_n | 1, f_n).
\end{align*}
$$&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;MacKay, D. J. (1992).
The evidence framework applied to classification networks. Neural Computation, 4(5), 720-736.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;Jaakkola, T. S., &amp;amp; Jordan, M. I. (2000).
Bayesian parameter estimation via variational methods. Statistics and Computing, 10(1), 25-37.&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;Minka, T. P. (2001, August).
Expectation propagation for approximate Bayesian inference. In Proceedings of the Seventeenth Conference on Uncertainty in Artificial Intelligence (pp. 362-369).&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;Polson, N. G., Scott, J. G., &amp;amp; Windle, J. (2013).
Bayesian inference for logistic models using Pólya-Gamma latent variables. Journal of the American Statistical Association, 108(504), 1339-1349.&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;Geman, S., &amp;amp; Geman, D. (1984).
Stochastic relaxation, Gibbs distributions, and the Bayesian restoration of images. IEEE Transactions on Pattern Analysis and Machine Intelligence, (6), 721-741.&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;Windle, J., Polson, N. G., &amp;amp; Scott, J. G. (2014).
Sampling Pólya-Gamma random variates: alternate and approximate techniques. arXiv preprint arXiv:1405.0506.&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:7"&gt;
&lt;p&gt;Wenzel, F., Galy-Fajou, T., Donner, C., Kloft, M., &amp;amp; Opper, M. (2019, July).
Efficient Gaussian process classification using Pólya-Gamma data augmentation. In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 33, No. 01, pp. 5417-5424).&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:8"&gt;
&lt;p&gt;Snell, J., &amp;amp; Zemel, R. (2020).
Bayesian few-shot classification with one-vs-each Pólya-Gamma augmented Gaussian processes. arXiv preprint arXiv:2007.10417.&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>An Illustrated Guide to the Knowledge Gradient Acquisition Function</title><link>https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/</link><pubDate>Thu, 18 Feb 2021 19:13:23 +0100</pubDate><guid>https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/</guid><description>
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-blue-100 dark:bg-blue-900 border-blue-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-blue-600 dark:text-blue-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m16.862 4.487l1.687-1.688a1.875 1.875 0 1 1 2.652 2.652L6.832 19.82a4.5 4.5 0 0 1-1.897 1.13l-2.685.8l.8-2.685a4.5 4.5 0 0 1 1.13-1.897zm0 0L19.5 7.125"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Note&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;Draft &amp;ndash; work in progress.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;We provide a short guide to the knowledge-gradient (KG) acquisition
function (Frazier et al., 2009)&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt; for Bayesian
optimization (BO).
Rather than being a self-contained tutorial, this post is intended to serve as
an illustrated companion to the paper of Frazier et al., 2009&lt;sup id="fnref1:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;
and the subsequent tutorial by Frazier, 2018&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;, authored
nearly a decade later.&lt;/p&gt;
&lt;p&gt;This post assumes a basic level of familiarity with BO and Gaussian processes (GPs),
to the extent provided by the literature survey of Shahriari et al.,
2015&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;, and the acclaimed textbook of Rasmussen and Williams, 2006,
respectively.&lt;/p&gt;
&lt;h2 id="knowledge-gradient"&gt;Knowledge-gradient&lt;/h2&gt;
&lt;p&gt;First, we set up the notation and terminology.
Let $f: \mathcal{X} \to \mathbb{R}$ be the blackbox function we wish to
minimize.
We denote the GP posterior predictive distribution, or &lt;em&gt;predictive&lt;/em&gt; for short,
by $p(y | \mathbf{x}, \mathcal{D})$.
The mean of the predictive, or the &lt;em&gt;predictive mean&lt;/em&gt; for short, is denoted by
&lt;/p&gt;
$$
\mu(\mathbf{x}; \mathcal{D}) = \mathbb{E}[y | \mathbf{x}, \mathcal{D}]
$$&lt;p&gt;
Let $\mathcal{D}_n$ be the set of $n$ input-output
observations $\mathcal{D}_n = \{ (\mathbf{x}_i, y_i) \}_{i=1}^n$, where
output $y_i = f(\mathbf{x}_i) + \epsilon$ is assumed to be observed with noise
$\epsilon \sim \mathcal{N}(0, \sigma^2)$.
We make the following abbreviation
&lt;/p&gt;
$$
\mu_n(\mathbf{x}) = \mu(\mathbf{x}; \mathcal{D}_n)
$$&lt;p&gt;
Next, we define the minimum of the predictive mean, or &lt;em&gt;predictive minimum&lt;/em&gt; for short,
as
&lt;/p&gt;
$$
\tau(\mathcal{D}) = \min_{\mathbf{x}' \in \mathcal{X}} \mu(\mathbf{x}'; \mathcal{D})
$$&lt;p&gt;
If we view $\mu(\mathbf{x}; \mathcal{D})$ as our fit to the underlying
function $f(\mathbf{x})$ from which the observations $\mathcal{D}$ were
generated, then $\tau(\mathcal{D})$ is our estimate of the minimum of $f(\mathbf{x})$,
given observations $\mathcal{D}$.&lt;/p&gt;
&lt;p&gt;Further, we make the following abbreviations
&lt;/p&gt;
$$
\tau_n = \tau(\mathcal{D}_n),
\qquad
\text{and}
\qquad
\tau_{n+1} = \tau(\mathcal{D}_{n+1}),
$$&lt;p&gt;
where $\mathcal{D}_{n+1} = \mathcal{D}_n \cup \{ (\mathbf{x}, y) \}$ is the
set of existing observations, augmented by some input-output pair $(\mathbf{x}, y)$.
Then, the knowledge-gradient is defined as
&lt;/p&gt;
$$
\alpha(\mathbf{x}; \mathcal{D}_n) =
\mathbb{E}_{p(y | \mathbf{x}, \mathcal{D}_n)} [ \tau_n - \tau_{n+1} ]
$$&lt;p&gt;
Crucially, note that $\tau_{n+1}$ is implicitly a function of $(\mathbf{x}, y)$,
and that this expression integrates over all possible input-output observation
pairs $(\mathbf{x}, y)$ for the given $\mathbf{x}$ under the
predictive $p(y | \mathbf{x}, \mathcal{D}_n)$.&lt;/p&gt;
&lt;h3 id="monte-carlo-estimation"&gt;Monte Carlo estimation&lt;/h3&gt;
&lt;p&gt;Not surprisingly, the knowledge-gradient function is analytically intractable.
Therefore, in practice, we compute it using Monte Carlo estimation,
&lt;/p&gt;
$$
\alpha(\mathbf{x}; \mathcal{D}_n) \approx
\frac{1}{M} \sum_{m=1}^M \left ( \tau_n - \tau_{n+1}^{(m)} \right ),
\qquad
y^{(m)} \sim p(y | \mathbf{x}, \mathcal{D}_n),
$$&lt;p&gt;
where $\tau_{n+1}^{(m)} = \tau(\mathcal{D}_{n+1}^{(m)})$
and $\mathcal{D}_{n+1}^{(m)} = \mathcal{D}_n \cup \{ (\mathbf{x}, y^{(m)}) \}$.&lt;/p&gt;
&lt;p&gt;We refer to $y^{(m)}$ as the $m$th simulated outcome, or the $m$th &lt;em&gt;simulation&lt;/em&gt;
for short.
Then, $\mathcal{D}_{n+1}^{(m)}$ is the $m$th simulation-augmented dataset and,
accordingly, $\tau_{n+1}^{(m)}$ is the $m$th simulation-augmented predictive minimum.&lt;/p&gt;
&lt;p&gt;We see that this approximation to the knowledge-gradient is simply the average
difference between the predictive minimum values &lt;em&gt;based on simulation-augmented
data&lt;/em&gt; $\tau_{n+1}^{(m)}$, and that &lt;em&gt;based on observed data&lt;/em&gt; $\tau_n$,
across $M$ simulations.&lt;/p&gt;
&lt;p&gt;This might take a moment to digest, as there are quite a number of moving parts
to keep track of. To help visualize these parts, we provide an illustration of
each of the steps required to compute KG on a simple one-dimensional synthetic
problem.&lt;/p&gt;
&lt;h2 id="one-dimensional-example"&gt;One-dimensional example&lt;/h2&gt;
&lt;p&gt;As the running example throughout this post, we use a synthetic function
defined as
&lt;/p&gt;
$$
f(x) = \sin(3x) + x^2 - 0.7 x.
$$&lt;p&gt;
We generate $n=10$ observations at locations sampled uniformly at random.
The true function, and the set of noisy observations $\mathcal{D}_n$ are
visualized in the figure below:&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/observations_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Latent blackbox function and $n=10$ observations.&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;Using the observations $\mathcal{D}_n$ we have collected so far, we wish to
use KG to score a candidate location $x_c$ at which to evaluate next.&lt;/p&gt;
&lt;h2 id="posterior-predictive-distribution"&gt;Posterior predictive distribution&lt;/h2&gt;
&lt;p&gt;The posterior predictive $p(y | \mathbf{x}, \mathcal{D}_n)$ is visualized in
the figure below. In particular, the predictive mean $\mu_n(\mathbf{x})$ is
represented by the solid orange curve.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/predictive_mean_before_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Posterior predictive distribution (*before* hyperparameter estimation).&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;Clearly, this is a poor fit to the data and an uncalibrated estimate of the
predictive uncertainty.&lt;/p&gt;
&lt;h3 id="step-1-hyperparameter-estimation"&gt;Step 1: Hyperparameter estimation&lt;/h3&gt;
&lt;p&gt;Therefore, the first step is to optimize the hyperparameters of the GP regression
model, i.e. the kernel lengthscale, amplitude, and the observation noise variance.
We do this using type-II maximum likelihood estimation (MLE), or &lt;em&gt;empirical Bayes&lt;/em&gt;.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/predictive_mean_after_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Posterior predictive distribution (*after* hyperparameter estimation).&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
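&lt;p&gt;To make the objective concrete, the following is a minimal NumPy sketch of
type-II MLE for a zero-mean GP with a squared-exponential kernel. It is an
illustration under stated assumptions, not the code used to produce the figures:
the function names, the input domain $[-1, 2]$, the noise level and the
candidate hyperparameter values are all assumed, and a coarse grid search stands
in for the gradient-based optimizer one would normally use.&lt;/p&gt;

```python
import numpy as np


def rbf_kernel(xa, xb, lengthscale, amplitude):
    # Squared-exponential kernel between two sets of 1-D inputs.
    sq_dist = (xa[:, None] - xb[None, :]) ** 2
    return amplitude ** 2 * np.exp(-0.5 * sq_dist / lengthscale ** 2)


def log_marginal_likelihood(x, y, lengthscale, amplitude, noise_variance):
    # log N(y | 0, K + noise_variance * I), computed via a Cholesky factor.
    K = rbf_kernel(x, x, lengthscale, amplitude) + noise_variance * np.eye(len(x))
    L = np.linalg.cholesky(K)
    a = np.linalg.solve(L.T, np.linalg.solve(L, y))
    return -0.5 * y @ a - np.log(np.diag(L)).sum() - 0.5 * len(x) * np.log(2.0 * np.pi)


# Synthetic data from the running example (domain and noise scale are assumed).
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 2.0, size=10)
y = np.sin(3.0 * x) + x ** 2 - 0.7 * x + 0.1 * rng.standard_normal(10)

# Coarse grid over (lengthscale, amplitude, noise variance); values are assumed.
candidates = [(ls, amp, nv)
              for ls in (0.1, 0.3, 1.0, 3.0)
              for amp in (0.5, 1.0, 2.0)
              for nv in (1e-3, 1e-2, 1e-1)]
best = max(candidates, key=lambda theta: log_marginal_likelihood(x, y, *theta))
```

&lt;p&gt;The tuple &lt;code&gt;best&lt;/code&gt; holds the hyperparameters with the highest log
marginal likelihood among the candidates; in practice one would maximize the
same quantity with a continuous optimizer.&lt;/p&gt;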
&lt;h3 id="step-2-determine-the-predictive-minimum"&gt;Step 2: Determine the predictive minimum&lt;/h3&gt;
&lt;p&gt;Next, we compute the predictive minimum $\tau_n = \min_{\mathbf{x}' \in \mathcal{X}} \mu_n(\mathbf{x}')$.
Since $\mu_n$ is end-to-end differentiable with respect to the input $\mathbf{x}$, we can
simply use a multi-start quasi-Newton method such as L-BFGS.
We visualize this in the figure below, where the value of the predictive
minimum is represented by the orange horizontal dashed line, and its location is
denoted by the orange star and triangle.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/predictive_minimum_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Predictive minimum $\tau_n$.&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;h3 id="step-3-compute-simulation-augmented-predictive-means"&gt;Step 3: Compute simulation-augmented predictive means&lt;/h3&gt;
&lt;p&gt;Suppose we are scoring the candidate location $x_c = 0.1$.
For illustrative purposes, let us draw just $M=1$ sample $y_c^{(1)} \sim p(y | x_c, \mathcal{D}_n)$.
In the figure below, the candidate location $x_c$ is represented by the
vertical solid gray line, and the single simulated outcome $y_c^{(1)}$ is
represented by the filled blue dot.&lt;/p&gt;
&lt;p&gt;In general, we denote the simulation-augmented predictive mean as
&lt;/p&gt;
$$
\mu_{n+1}^{(m)}(\mathbf{x}) = \mu(\mathbf{x}; \mathcal{D}_{n+1}^{(m)}),
$$&lt;p&gt;
where
$\mathcal{D}_{n+1}^{(m)} = \mathcal{D}_n \cup \{ (\mathbf{x}, y^{(m)}) \}$
as defined earlier.&lt;/p&gt;
&lt;p&gt;Here, the simulation-augmented dataset $\mathcal{D}_{n+1}^{(1)}$ is the set
of existing observations $\mathcal{D}_n$, augmented by the simulated
input-output pair $(x_c, y_c^{(1)})$,
&lt;/p&gt;
$$
\mathcal{D}_{n+1}^{(1)} = \mathcal{D}_n \cup \{ (x_c, y_c^{(1)}) \},
$$&lt;p&gt;
and the corresponding simulation-augmented predictive mean $\mu_{n+1}^{(1)}(x)$
is represented in the figure below by the solid blue curve.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/simulated_predictive_mean_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Simulation-augmented predictive mean $\mu_{n&amp;#43;1}^{(1)}(x)$ at location $x_c = 0.1$&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;h3 id="step-4-compute-simulation-augmented-predictive-minimums"&gt;Step 4: Compute simulation-augmented predictive minimums&lt;/h3&gt;
&lt;p&gt;Next, we compute the simulation-augmented predictive minimum
&lt;/p&gt;
$$
\tau_{n+1}^{(1)} = \min_{\mathbf{x}' \in \mathcal{X}} \mu_{n+1}^{(1)}(\mathbf{x}')
$$&lt;p&gt;
It may not be immediately obvious, but $\mu_{n+1}^{(1)}$ is in fact also
end-to-end differentiable with respect to the input $\mathbf{x}$. Therefore, we can again
appeal to a method such as L-BFGS.
We visualize this in the figure below, where the value of the simulation-augmented
predictive minimum is represented by the blue horizontal dashed line, and its
location is denoted by the blue star and triangle.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/simulated_predictive_minimum_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Simulation-augmented predictive minimum $\tau_{n&amp;#43;1}^{(1)}$ at location $x_c = 0.1$&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;Taking the difference between the orange and blue horizontal dashed lines will
give us an unbiased estimate of the knowledge-gradient.
However, this is likely to be a crude one, since it is based on just a single
MC sample.
To obtain a more accurate estimate, one needs to increase $M$, the number of
MC samples.&lt;/p&gt;
&lt;h4 id="samples"&gt;Samples $M &gt; 1$&lt;/h4&gt;
&lt;p&gt;Let us now consider $M=5$ samples. We draw $y_c^{(m)} \sim p(y | x_c, \mathcal{D}_n)$,
for $m = 1, \dotsc, 5$.
As before, the input location $x_c$ is represented by the vertical solid
gray line, and the corresponding simulated outcomes are represented by the
filled dots below, with varying hues from a perceptually uniform color palette
to distinguish between samples.&lt;/p&gt;
&lt;p&gt;Accordingly, the simulation-augmented predictive means
$\mu_{n+1}^{(m)}(x)$ at location $x_c = 0.1$, for $m = 1, \dotsc, 5$ are
represented by the colored curves, with hues set to that of the simulated
outcome on which the predictive distribution is based.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/bar_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Simulation-augmented predictive mean $\mu_{n&amp;#43;1}^{(m)}(x)$ at location $x_c = 0.1$, for $m = 1, \dotsc, 5$&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;Next, we compute the simulation-augmented predictive
minimum $\tau_{n+1}^{(m)}$ for each sample, which requires minimizing
$\mu_{n+1}^{(m)}(x)$ for $m = 1, \dotsc, 5$.
These values are represented below by the horizontal dashed lines, and their
locations are denoted by the stars and triangles.&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/figures/baz_paper_1800x1112.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;Simulation-augmented predictive minimum $\tau_{n&amp;#43;1}^{(m)}$ at location $x_c = 0.1$, for $m = 1, \dotsc, 5$&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;Finally, taking the average difference between the orange dashed line and every
other dashed line gives us the estimate of the knowledge gradient at
input $x_c$.&lt;/p&gt;
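&lt;p&gt;The four steps above can be sketched in a few lines of NumPy. This is a
minimal illustration under stated assumptions, not the code used to produce
the figures: the toy data, the squared-exponential kernel and its
hyperparameters, and the function names are all hypothetical, the orange line
is assumed to be the current predictive minimum $\tau_n$, and a dense grid
search stands in for L-BFGS when minimizing the posterior means.&lt;/p&gt;

```python
import numpy as np


def rbf_kernel(A, B, lengthscale=0.3):
    """Squared-exponential kernel between 1-D inputs stored as column vectors."""
    sq_dists = (A[:, None, 0] - B[None, :, 0]) ** 2
    return np.exp(-0.5 * sq_dists / lengthscale**2)


def posterior_mean(X, y, X_test, noise_var):
    """GP posterior mean at X_test given noisy observations (zero prior mean)."""
    K = rbf_kernel(X, X) + noise_var * np.eye(len(X))
    return rbf_kernel(X_test, X) @ np.linalg.solve(K, y)


def posterior_var(X, X_test, noise_var):
    """GP posterior variance of the latent function at X_test."""
    K = rbf_kernel(X, X) + noise_var * np.eye(len(X))
    K_s = rbf_kernel(X, X_test)
    return 1.0 - np.sum(K_s * np.linalg.solve(K, K_s), axis=0)


def knowledge_gradient(X, y, x_c, grid, num_samples=5, noise_var=0.01, seed=0):
    """Monte Carlo estimate of the knowledge gradient at candidate x_c."""
    rng = np.random.default_rng(seed)
    x_c = np.atleast_2d(x_c)
    # Predictive minimum tau_n under the current posterior (grid search).
    tau_n = posterior_mean(X, y, grid, noise_var).min()
    # Simulate outcomes y_c^(m) from the posterior predictive at x_c.
    mean_c = posterior_mean(X, y, x_c, noise_var)[0]
    var_c = posterior_var(X, x_c, noise_var)[0] + noise_var
    y_c = mean_c + np.sqrt(var_c) * rng.standard_normal(num_samples)
    # Simulation-augmented predictive minimums tau_{n+1}^(m), one per sample.
    X_aug = np.vstack([X, x_c])
    tau_next = np.array([
        posterior_mean(X_aug, np.append(y, y_m), grid, noise_var).min()
        for y_m in y_c
    ])
    # Average difference between tau_n and every tau_{n+1}^(m).
    return tau_n - tau_next.mean()


# Hypothetical toy data on the interval [-1, 1].
X = np.array([[-0.8], [-0.3], [0.4], [0.9]])
y = np.sin(3.0 * X[:, 0])
grid = np.linspace(-1.0, 1.0, 201)[:, None]
kg = knowledge_gradient(X, y, x_c=0.1, grid=grid)
```

&lt;p&gt;Increasing &lt;code&gt;num_samples&lt;/code&gt; reduces the variance of the estimate
at the cost of one extra posterior minimization per sample.&lt;/p&gt;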
&lt;h2 id="links-and-further-readings"&gt;Links and Further Readings&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;In this post, we only showed a (naïve) approach to calculating the KG at a
given location.
Suffice it to say, there is still quite a gap between this and being able to
efficiently maximize KG within a sequential decision-making algorithm.
For a guide on incorporating KG in a modular and fully-fledged framework for
BO (namely
) see
&lt;/li&gt;
&lt;li&gt;Another introduction to KG:
&lt;/li&gt;
&lt;/ul&gt;
&lt;hr&gt;
&lt;p&gt;Cite as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{tiao2021knowledge,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = &amp;#34;{A}n {I}llustrated {G}uide to the {K}nowledge {G}radient {A}cquisition {F}unction&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = &amp;#34;Tiao, Louis C&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; journal = &amp;#34;tiao.io&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = &amp;#34;2021&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = &amp;#34;https://tiao.io/post/an-illustrated-guide-to-the-knowledge-gradient-acquisition-function/&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;Frazier, P., Powell, W., &amp;amp; Dayanik, S. (2009).
. INFORMS Journal on Computing, 21(4), 599-613.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;Frazier, P. I. (2018).
. arXiv preprint arXiv:1807.02811.&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;Shahriari, B., Swersky, K., Wang, Z., Adams, R. P., &amp;amp; De Freitas, N. (2015).
. Proceedings of the IEEE, 104(1), 148-175.&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Contributed Talk: BORE — Bayesian Optimization by Density-Ratio Estimation</title><link>https://tiao.io/events/neurips2020-meta-learning/</link><pubDate>Fri, 11 Dec 2020 15:00:00 +0000</pubDate><guid>https://tiao.io/events/neurips2020-meta-learning/</guid><description/></item><item><title>Bayesian Optimization by Density Ratio Estimation</title><link>https://tiao.io/publications/bore-1/</link><pubDate>Tue, 01 Dec 2020 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/bore-1/</guid><description/></item><item><title>📄 One paper accepted to NeurIPS 2020</title><link>https://tiao.io/posts/one-paper-accepted-to-neurips2020/</link><pubDate>Fri, 25 Sep 2020 00:00:00 +0000</pubDate><guid>https://tiao.io/posts/one-paper-accepted-to-neurips2020/</guid><description>&lt;p&gt;Our paper
was accepted to NeurIPS 2020 as a
&lt;strong&gt;Spotlight Presentation&lt;/strong&gt; (awarded to the top 3% of submissions). This is
joint work with Pantelis Elinas and Edwin Bonilla.&lt;/p&gt;</description></item><item><title>Variational Inference for Graph Convolutional Networks in the Absence of Graph Data and Adversarial Settings</title><link>https://tiao.io/publications/vi-gcn-2/</link><pubDate>Mon, 01 Jun 2020 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/vi-gcn-2/</guid><description>&lt;p&gt;This paper is a follow-up to our
, previously
presented at the NeurIPS2019 Graph Representation Learning Workshop, now with
significantly expanded experimental analyses.&lt;/p&gt;
&lt;div id="presentation-embed-38937946"&gt;&lt;/div&gt;
&lt;script src='https://slideslive.com/embed_presentation.js'&gt;&lt;/script&gt;
&lt;script&gt;
embed = new SlidesLiveEmbed('presentation-embed-38937946', {
presentationId: '38937946',
autoPlay: false, // change to true to autoplay the embedded presentation
verticalEnabled: true
});
&lt;/script&gt;</description></item><item><title>Model-based Asynchronous Hyperparameter and Neural Architecture Search</title><link>https://tiao.io/publications/async-multi-fidelity-hpo/</link><pubDate>Sun, 01 Mar 2020 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/async-multi-fidelity-hpo/</guid><description/></item><item><title>Variational Graph Convolutional Networks</title><link>https://tiao.io/publications/vi-gcn-1/</link><pubDate>Sun, 01 Dec 2019 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/vi-gcn-1/</guid><description/></item><item><title>A Handbook for Sparse Variational Gaussian Processes</title><link>https://tiao.io/posts/sparse-variational-gaussian-processes/</link><pubDate>Fri, 13 Sep 2019 00:00:00 +0000</pubDate><guid>https://tiao.io/posts/sparse-variational-gaussian-processes/</guid><description>
&lt;details class="print:hidden xl:hidden" &gt;
&lt;summary&gt;Table of Contents&lt;/summary&gt;
&lt;div class="text-sm"&gt;
&lt;nav id="TableOfContents"&gt;
&lt;ul&gt;
&lt;li&gt;&lt;a href="#prior"&gt;Prior&lt;/a&gt;
&lt;ul&gt;
&lt;li&gt;&lt;a href="#marginal-prior-over-inducing-variables"&gt;Marginal prior over inducing variables&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#conditional-prior"&gt;Conditional prior&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;a href="#variational-distribution"&gt;Variational Distribution&lt;/a&gt;
&lt;ul&gt;
&lt;li&gt;&lt;a href="#whitened-parameterization"&gt;Whitened parameterization&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;a href="#inference"&gt;Inference&lt;/a&gt;
&lt;ul&gt;
&lt;li&gt;&lt;a href="#preliminaries"&gt;Preliminaries&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#gaussian-likelihoods--sparse-gaussian-process-regression-sgpr"&gt;Gaussian Likelihoods &amp;ndash; Sparse Gaussian Process Regression (SGPR)&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#non-gaussian-likelihoods"&gt;Non-Gaussian Likelihoods&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#large-scale-data-with-stochastic-optimization"&gt;Large-Scale Data with Stochastic Optimization&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;a href="#links-and-further-readings"&gt;Links and Further Readings&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#appendix"&gt;Appendix&lt;/a&gt;
&lt;ul&gt;
&lt;li&gt;&lt;a href="#i"&gt;I&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#ii"&gt;II&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#iii"&gt;III&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#iv"&gt;IV&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#v"&gt;V&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#vi"&gt;VI&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="#vii"&gt;VII&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/nav&gt;
&lt;/div&gt;
&lt;/details&gt;
&lt;p&gt;In the sparse variational Gaussian process (SVGP) framework (Titsias, 2009)&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;,
one augments the joint distribution $p(\mathbf{y}, \mathbf{f})$ with auxiliary
variables $\mathbf{u}$ so that the joint becomes
&lt;/p&gt;
$$
p(\mathbf{y}, \mathbf{f}, \mathbf{u}) = p(\mathbf{y} | \mathbf{f}) p(\mathbf{f}, \mathbf{u}).
$$&lt;p&gt;
The vector $\mathbf{u} = \begin{bmatrix} u(\mathbf{z}_1) \cdots u(\mathbf{z}_M)\end{bmatrix}^{\top} \in \mathbb{R}^M$
consists of &lt;em&gt;inducing variables&lt;/em&gt;, the latent function values corresponding
to the &lt;em&gt;inducing input&lt;/em&gt; locations contained in the matrix
$\mathbf{Z} = \begin{bmatrix} \mathbf{z}_1 \cdots \mathbf{z}_M \end{bmatrix}^{\top} \in \mathbb{R}^{M \times D}$.&lt;/p&gt;
&lt;h2 id="prior"&gt;Prior&lt;/h2&gt;
&lt;p&gt;Under the prior, the joint distribution of the latent function values
$\mathbf{f}$ and the inducing variables $\mathbf{u}$ is
&lt;/p&gt;
$$
p(\mathbf{f}, \mathbf{u}) =
\mathcal{N} \left (
\begin{bmatrix}
\mathbf{f} \newline
\mathbf{u}
\end{bmatrix}
;
\begin{bmatrix}
\mathbf{0} \newline
\mathbf{0}
\end{bmatrix},
\begin{bmatrix}
\mathbf{K}_\mathbf{ff} &amp; \mathbf{K}_\mathbf{uf}^\top \newline
\mathbf{K}_\mathbf{uf} &amp; \mathbf{K}_\mathbf{uu}
\end{bmatrix}
\right ).
$$&lt;p&gt;
If we let the joint prior factorize as
&lt;/p&gt;
$$
p(\mathbf{f}, \mathbf{u}) = p(\mathbf{f} | \mathbf{u}) p(\mathbf{u}),
$$&lt;p&gt;
we can apply the rules of Gaussian conditioning to derive the marginal prior
$p(\mathbf{u})$ and conditional prior $p(\mathbf{f} | \mathbf{u})$.&lt;/p&gt;
&lt;h3 id="marginal-prior-over-inducing-variables"&gt;Marginal prior over inducing variables&lt;/h3&gt;
&lt;p&gt;The marginal prior over inducing variables is simply given by
&lt;/p&gt;
$$
p(\mathbf{u}) = \mathcal{N}(\mathbf{u} | \mathbf{0}, \mathbf{K}_\mathbf{uu}).
$$
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-blue-100 dark:bg-blue-900 border-blue-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-blue-600 dark:text-blue-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m16.862 4.487l1.687-1.688a1.875 1.875 0 1 1 2.652 2.652L6.832 19.82a4.5 4.5 0 0 1-1.897 1.13l-2.685.8l.8-2.685a4.5 4.5 0 0 1 1.13-1.897zm0 0L19.5 7.125"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Note&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;h4 id="gaussian-process-notation"&gt;Gaussian process notation&lt;/h4&gt;
&lt;p&gt;We can express the prior over the inducing variable $u(\mathbf{z})$ at
inducing input $\mathbf{z}$ as
&lt;/p&gt;
$$
p(u(\mathbf{z})) = \mathcal{GP}(0, k_{\theta}(\mathbf{z}, \mathbf{z}')).
$$&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;h3 id="conditional-prior"&gt;Conditional prior&lt;/h3&gt;
&lt;p&gt;First, let us define the vector-valued function $\boldsymbol{\psi}_\mathbf{u}: \mathbb{R}^{D} \to \mathbb{R}^{M}$ as
&lt;/p&gt;
$$
\boldsymbol{\psi}_\mathbf{u}(\mathbf{x}) \triangleq \mathbf{K}_\mathbf{uu}^{-1} \mathbf{k}_\mathbf{u}(\mathbf{x}),
$$&lt;p&gt;
where $\mathbf{k}_\mathbf{u}(\mathbf{x}) = k_{\theta}(\mathbf{Z}, \mathbf{x})$ denotes the
vector of covariances between $\mathbf{x}$ and the inducing inputs $\mathbf{Z}$.
Further, let $\boldsymbol{\Psi} \in \mathbb{R}^{M \times N}$ be the matrix
whose columns are the values of $\boldsymbol{\psi}_\mathbf{u}$ evaluated at each row of the matrix of inputs
$\mathbf{X} = \begin{bmatrix} \mathbf{x}_1 \cdots \mathbf{x}_N \end{bmatrix}^{\top} \in \mathbb{R}^{N \times D}$,
&lt;/p&gt;
$$
\boldsymbol{\Psi} \triangleq
\begin{bmatrix}
\boldsymbol{\psi}_\mathbf{u}(\mathbf{x}_1)
\cdots
\boldsymbol{\psi}_\mathbf{u}(\mathbf{x}_N)
\end{bmatrix} = \mathbf{K}_\mathbf{uu}^{-1} \mathbf{K}_\mathbf{uf}.
$$&lt;p&gt;
Then, we can condition the joint prior distribution on the inducing
variables to give
&lt;/p&gt;
$$
p(\mathbf{f} | \mathbf{u}) = \mathcal{N}(\mathbf{f} | \mathbf{m}, \mathbf{S}),
$$&lt;p&gt;
where the mean vector and covariance matrix are
&lt;/p&gt;
$$
\mathbf{m} = \boldsymbol{\Psi}^{\top} \mathbf{u},
\quad
\text{and}
\quad
\mathbf{S} = \mathbf{K}_\mathbf{ff} - \boldsymbol{\Psi}^{\top} \mathbf{K}_\mathbf{uu} \boldsymbol{\Psi}.
$$
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-blue-100 dark:bg-blue-900 border-blue-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-blue-600 dark:text-blue-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m16.862 4.487l1.687-1.688a1.875 1.875 0 1 1 2.652 2.652L6.832 19.82a4.5 4.5 0 0 1-1.897 1.13l-2.685.8l.8-2.685a4.5 4.5 0 0 1 1.13-1.897zm0 0L19.5 7.125"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Note&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;h4 id="gaussian-process-notation"&gt;Gaussian process notation&lt;/h4&gt;
&lt;p&gt;We can express the distribution over the function value $f(\mathbf{x})$ at
input $\mathbf{x}$, given $\mathbf{u}$, that is, the conditional
$p(f(\mathbf{x}) | \mathbf{u})$, as a Gaussian process:
&lt;/p&gt;
$$
p(f(\mathbf{x}) | \mathbf{u}) = \mathcal{GP}(m(\mathbf{x}), s(\mathbf{x}, \mathbf{x}')),
$$&lt;p&gt;
with mean and covariance functions,
&lt;/p&gt;
$$
m(\mathbf{x}) = \boldsymbol{\psi}_\mathbf{u}^\top(\mathbf{x}) \mathbf{u},
\quad
\text{and}
\quad
s(\mathbf{x}, \mathbf{x}') = k_{\theta}(\mathbf{x}, \mathbf{x}') - \boldsymbol{\psi}_\mathbf{u}^\top(\mathbf{x}) \mathbf{K}_\mathbf{uu} \boldsymbol{\psi}_\mathbf{u}(\mathbf{x}').
$$&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
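&lt;p&gt;To make the shapes concrete, the conditional prior can be computed
numerically as follows. This is a minimal sketch: the squared-exponential
kernel, the toy inputs, the jitter term and all variable names are assumptions
made for illustration only.&lt;/p&gt;

```python
import numpy as np


def rbf_kernel(A, B, lengthscale=1.0):
    """Squared-exponential kernel matrix between the rows of A and B."""
    sq_dists = (
        np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2.0 * A @ B.T
    )
    return np.exp(-0.5 * sq_dists / lengthscale**2)


rng = np.random.default_rng(0)
X = rng.uniform(-2.0, 2.0, size=(20, 1))  # N = 20 inputs
Z = np.linspace(-1.5, 1.5, 4)[:, None]    # M = 4 inducing inputs

K_ff = rbf_kernel(X, X)
K_uf = rbf_kernel(Z, X)
K_uu = rbf_kernel(Z, Z) + 1e-8 * np.eye(len(Z))  # small jitter for stability

# Psi = K_uu^{-1} K_uf has shape (M, N); column n is psi_u(x_n).
Psi = np.linalg.solve(K_uu, K_uf)

# Draw inducing variables u from the marginal prior N(0, K_uu).
u = np.linalg.cholesky(K_uu) @ rng.standard_normal(len(Z))

m = Psi.T @ u                   # conditional mean, shape (N,)
S = K_ff - Psi.T @ K_uu @ Psi   # conditional covariance, shape (N, N)
S_alt = K_ff - K_uf.T @ Psi     # the same matrix, since K_uu Psi = K_uf
```

&lt;p&gt;Solving the linear system with &lt;code&gt;np.linalg.solve&lt;/code&gt; avoids forming
$\mathbf{K}_\mathbf{uu}^{-1}$ explicitly.&lt;/p&gt;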
&lt;p&gt;Before moving on, we briefly highlight the important
quantity,
&lt;/p&gt;
$$
\mathbf{Q}_\mathbf{ff} \triangleq \boldsymbol{\Psi}^{\top} \mathbf{K}_\mathbf{uu} \boldsymbol{\Psi},
$$&lt;p&gt;
which is sometimes referred to as the &lt;em&gt;Nyström approximation&lt;/em&gt; of $\mathbf{K}_\mathbf{ff}$.
It can be written as
&lt;/p&gt;
$$
\mathbf{Q}_\mathbf{ff} = \mathbf{K}_\mathbf{fu} \mathbf{K}_\mathbf{uu}^{-1} \mathbf{K}_\mathbf{uf}.
$$&lt;h2 id="variational-distribution"&gt;Variational Distribution&lt;/h2&gt;
&lt;p&gt;We specify a joint variational distribution $q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u})$
which factorizes as
&lt;/p&gt;
$$
q_{\boldsymbol{\phi}}(\mathbf{f}, \mathbf{u}) \triangleq p(\mathbf{f} | \mathbf{u}) q_{\boldsymbol{\phi}}(\mathbf{u}).
$$&lt;p&gt;
For convenience, let us specify a variational distribution that is also Gaussian,
&lt;/p&gt;
$$
q_{\boldsymbol{\phi}}(\mathbf{u}) \triangleq \mathcal{N}(\mathbf{u} | \mathbf{b}, \mathbf{W}\mathbf{W}^{\top}),
$$&lt;p&gt;
with variational parameters $\boldsymbol{\phi} = \{ \mathbf{W}, \mathbf{b} \}$.
To obtain the corresponding marginal variational distribution over $\mathbf{f}$,
we marginalize out the inducing variables $\mathbf{u}$, leading to
&lt;/p&gt;
$$
q_{\boldsymbol{\phi}}(\mathbf{f}) =
\int q_{\boldsymbol{\phi}}(\mathbf{f}, \mathbf{u}) \, \mathrm{d}\mathbf{u} =
\mathcal{N}(\mathbf{f} | \boldsymbol{\mu}, \mathbf{\Sigma}),
$$&lt;p&gt;
where
&lt;/p&gt;
$$
\boldsymbol{\mu} = \boldsymbol{\Psi}^\top \mathbf{b},
\quad
\text{and}
\quad
\mathbf{\Sigma} = \mathbf{K}_\mathbf{ff} - \boldsymbol{\Psi}^\top (\mathbf{K}_\mathbf{uu} - \mathbf{W}\mathbf{W}^{\top}) \boldsymbol{\Psi}.
$$
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-blue-100 dark:bg-blue-900 border-blue-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-blue-600 dark:text-blue-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m16.862 4.487l1.687-1.688a1.875 1.875 0 1 1 2.652 2.652L6.832 19.82a4.5 4.5 0 0 1-1.897 1.13l-2.685.8l.8-2.685a4.5 4.5 0 0 1 1.13-1.897zm0 0L19.5 7.125"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Note&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;h4 id="gaussian-process-notation"&gt;Gaussian process notation&lt;/h4&gt;
&lt;p&gt;We can express the variational distribution over the function value $f(\mathbf{x})$ at
input $\mathbf{x}$, that is, the marginal $q_{\boldsymbol{\phi}}(f(\mathbf{x}))$,
as a Gaussian process:
&lt;/p&gt;
$$
q_{\boldsymbol{\phi}}(f(\mathbf{x})) = \mathcal{GP}(\mu(\mathbf{x}), \sigma(\mathbf{x}, \mathbf{x}')),
$$&lt;p&gt;
with mean and covariance functions,
&lt;/p&gt;
$$
\begin{aligned}
\mu(\mathbf{x}) &amp;= \boldsymbol{\psi}_\mathbf{u}^\top(\mathbf{x}) \mathbf{b}, \\
\sigma(\mathbf{x}, \mathbf{x}') &amp;= k_{\theta}(\mathbf{x}, \mathbf{x}') - \boldsymbol{\psi}_\mathbf{u}^\top(\mathbf{x}) (\mathbf{K}_\mathbf{uu} - \mathbf{W}\mathbf{W}^{\top}) \boldsymbol{\psi}_\mathbf{u}(\mathbf{x}').
\end{aligned}
$$&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;h3 id="whitened-parameterization"&gt;Whitened parameterization&lt;/h3&gt;
&lt;p&gt;Whitening is a powerful trick for stabilizing the learning of variational
parameters that works by reducing correlations in the variational distribution (Murray &amp;amp; Adams, 2010; Hensman et al., 2015)&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;.
Let $\mathbf{L}$ be the Cholesky factor of $\mathbf{K}_\mathbf{uu}$, i.e. the
lower triangular matrix such that $\mathbf{L} \mathbf{L}^{\top} = \mathbf{K}_\mathbf{uu}$.
Then, the whitened variational parameters are given by
&lt;/p&gt;
$$
\mathbf{W} \triangleq \mathbf{L} \mathbf{W}',
\quad
\text{and}
\quad
\mathbf{b} \triangleq \mathbf{L} \mathbf{b}',
$$&lt;p&gt;
with free parameters $\{ \mathbf{W}', \mathbf{b}' \}$.
This leads to mean and covariance
&lt;/p&gt;
$$
\boldsymbol{\mu} = \boldsymbol{\Lambda}^\top \mathbf{b}',
\quad
\text{and}
\quad
\mathbf{\Sigma} = \mathbf{K}_\mathbf{ff} - \boldsymbol{\Lambda}^\top (\mathbf{I}_M - {\mathbf{W}'} {\mathbf{W}'}^{\top}) \boldsymbol{\Lambda},
$$&lt;p&gt;
where
&lt;/p&gt;
$$
\boldsymbol{\Lambda} \triangleq \mathbf{L}^\top \boldsymbol{\Psi} = \mathbf{L}^{-1} \mathbf{K}_\mathbf{uf}.
$$&lt;p&gt;
Refer to
for derivations.&lt;/p&gt;
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-blue-100 dark:bg-blue-900 border-blue-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-blue-600 dark:text-blue-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m16.862 4.487l1.687-1.688a1.875 1.875 0 1 1 2.652 2.652L6.832 19.82a4.5 4.5 0 0 1-1.897 1.13l-2.685.8l.8-2.685a4.5 4.5 0 0 1 1.13-1.897zm0 0L19.5 7.125"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Note&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;h4 id="gaussian-process-notation"&gt;Gaussian process notation&lt;/h4&gt;
&lt;p&gt;The mean and covariance functions are now
&lt;/p&gt;
$$
\begin{aligned}
\mu(\mathbf{x}) &amp;= \boldsymbol{\lambda}^\top(\mathbf{x}) \mathbf{b}', \\
\sigma(\mathbf{x}, \mathbf{x}') &amp;= k_{\theta}(\mathbf{x}, \mathbf{x}') - \boldsymbol{\lambda}^\top(\mathbf{x}) (\mathbf{I}_M - \mathbf{W}' {\mathbf{W}'}^{\top}) \boldsymbol{\lambda}(\mathbf{x}'),
\end{aligned}
$$&lt;p&gt;
where
&lt;/p&gt;
$$
\begin{aligned}
\boldsymbol{\lambda}(\mathbf{x}) &amp;\triangleq \mathbf{L}^{\top} \boldsymbol{\psi}_\mathbf{u}(\mathbf{x}) \\
&amp;= \mathbf{L}^{-1} \mathbf{k}_\mathbf{u}(\mathbf{x}).
\end{aligned}
$$&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
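&lt;p&gt;As a quick numerical sanity check, the whitened and unwhitened expressions
for $\boldsymbol{\mu}$ and $\mathbf{\Sigma}$ can be compared directly. The
sketch below assumes a squared-exponential kernel, toy inputs and randomly
drawn free parameters; none of these choices come from the text.&lt;/p&gt;

```python
import numpy as np


def rbf_kernel(A, B, lengthscale=1.0):
    """Squared-exponential kernel matrix between the rows of A and B."""
    sq_dists = (
        np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2.0 * A @ B.T
    )
    return np.exp(-0.5 * sq_dists / lengthscale**2)


rng = np.random.default_rng(0)
X = rng.uniform(-2.0, 2.0, size=(15, 1))
Z = np.linspace(-1.5, 1.5, 4)[:, None]
num_inducing = len(Z)

K_ff = rbf_kernel(X, X)
K_uf = rbf_kernel(Z, X)
K_uu = rbf_kernel(Z, Z) + 1e-8 * np.eye(num_inducing)
L = np.linalg.cholesky(K_uu)  # lower triangular, L L^T = K_uu

# Free (whitened) parameters b' and W'.
b_white = rng.standard_normal(num_inducing)
W_white = np.tril(rng.standard_normal((num_inducing, num_inducing)))

# Unwhitened parameterization: b = L b', W = L W'.
b = L @ b_white
W = L @ W_white
Psi = np.linalg.solve(K_uu, K_uf)
mu = Psi.T @ b
Sigma = K_ff - Psi.T @ (K_uu - W @ W.T) @ Psi

# Whitened parameterization: Lambda = L^{-1} K_uf.
# (A dedicated triangular solver would be cheaper than a general solve.)
Lam = np.linalg.solve(L, K_uf)
mu_white = Lam.T @ b_white
Sigma_white = K_ff - Lam.T @ (np.eye(num_inducing) - W_white @ W_white.T) @ Lam
```

&lt;p&gt;Both routes give the same mean and covariance up to floating-point error,
but the whitened one never multiplies by $\mathbf{K}_\mathbf{uu}^{-1}$.&lt;/p&gt;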
&lt;p&gt;For an efficient and numerically stable way to compute and evaluate the
variational distribution $q_{\boldsymbol{\phi}}(\mathbf{f})$ at an arbitrary
set of inputs, see
.&lt;/p&gt;
&lt;h2 id="inference"&gt;Inference&lt;/h2&gt;
&lt;h3 id="preliminaries"&gt;Preliminaries&lt;/h3&gt;
&lt;p&gt;We seek to approximate the exact posterior $p(\mathbf{f},\mathbf{u} \mid \mathbf{y})$
by a variational distribution $q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u})$.
To this end, we minimize the Kullback-Leibler (KL) divergence
between $q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u})$
and $p(\mathbf{f},\mathbf{u} \mid \mathbf{y})$, which is given by
&lt;/p&gt;
$$
\begin{align*}
\mathrm{KL}[q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u}) \mid\mid p(\mathbf{f},\mathbf{u} \mid \mathbf{y})] &amp; =
\mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u})}\left[\log{\frac{q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u})}{p(\mathbf{f},\mathbf{u} \mid \mathbf{y})}}\right] \newline &amp; =
\log{p(\mathbf{y})} + \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u})}\left[\log{\frac{q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u})}{p(\mathbf{f},\mathbf{u}, \mathbf{y})}}\right] \newline &amp; =
\log{p(\mathbf{y})} - \mathrm{ELBO}(\boldsymbol{\phi}, \mathbf{Z}),
\end{align*}
$$&lt;p&gt;
where we&amp;rsquo;ve defined the &lt;em&gt;evidence lower bound (ELBO)&lt;/em&gt; as
&lt;/p&gt;
$$
\mathrm{ELBO}(\boldsymbol{\phi}, \mathbf{Z}) \triangleq \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u})}\left[\log{\frac{p(\mathbf{f},\mathbf{u}, \mathbf{y})}{q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u})}}\right].
$$&lt;p&gt;
Notice that minimizing the KL divergence above is equivalent to maximizing the ELBO.
Furthermore, the ELBO is a lower bound on the log marginal likelihood, since
&lt;/p&gt;
$$
\log{p(\mathbf{y})} = \mathrm{ELBO}(\boldsymbol{\phi}, \mathbf{Z}) + \mathrm{KL}[q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u}) \mid\mid p(\mathbf{f},\mathbf{u} \mid \mathbf{y})],
$$&lt;p&gt;
and the KL divergence is nonnegative.
Therefore, we have $\log{p(\mathbf{y})} \geq \mathrm{ELBO}(\boldsymbol{\phi}, \mathbf{Z})$
with equality at $\mathrm{KL}[q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u}) \mid\mid p(\mathbf{f},\mathbf{u} \mid \mathbf{y})] = 0 \Leftrightarrow q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u}) = p(\mathbf{f},\mathbf{u} \mid \mathbf{y})$.&lt;/p&gt;
&lt;p&gt;Let us now focus our attention on the ELBO, which can be written as
&lt;/p&gt;
$$
\begin{align*}
\mathrm{ELBO}(\boldsymbol{\phi}, \mathbf{Z}) &amp; = \iint \log{\frac{p(\mathbf{f},\mathbf{u}, \mathbf{y})}{q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u})}} q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u}) \,\mathrm{d}\mathbf{f} \mathrm{d}\mathbf{u} \newline &amp; =
\iint \log{\frac{p(\mathbf{y} | \mathbf{f}) \bcancel{p(\mathbf{f} | \mathbf{u})} p(\mathbf{u})}{\bcancel{p(\mathbf{f} | \mathbf{u})} q_{\boldsymbol{\phi}}(\mathbf{u})}} q_{\boldsymbol{\phi}}(\mathbf{f},\mathbf{u}) \,\mathrm{d}\mathbf{f} \mathrm{d}\mathbf{u} \newline &amp; =
\int \log{\frac{\Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u})}{q_{\boldsymbol{\phi}}(\mathbf{u})}} q_{\boldsymbol{\phi}}(\mathbf{u}) \,\mathrm{d}\mathbf{u},
\end{align*}
$$&lt;p&gt;
where we have made use of the previous
definition $q_{\boldsymbol{\phi}}(\mathbf{f}, \mathbf{u}) = p(\mathbf{f} | \mathbf{u}) q_{\boldsymbol{\phi}}(\mathbf{u})$
and also introduced the definition
&lt;/p&gt;
$$
\Phi(\mathbf{y}, \mathbf{u}) \triangleq \exp{ \left ( \int \log{p(\mathbf{y} | \mathbf{f})} p(\mathbf{f} | \mathbf{u}) \,\mathrm{d}\mathbf{f} \right ) }.
$$&lt;p&gt;
It is straightforward to verify that the optimal variational distribution, that
is, the distribution $q_{\boldsymbol{\phi}^{\star}}(\mathbf{u})$ at which the
ELBO is maximized, satisfies
&lt;/p&gt;
$$
q_{\boldsymbol{\phi}^{\star}}(\mathbf{u}) \propto \Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u}).
$$&lt;p&gt;
Refer to
for details.
Specifically, after normalization, we have
&lt;/p&gt;
$$
q_{\boldsymbol{\phi}^{\star}}(\mathbf{u}) = \frac{\Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u})}{\mathcal{Z}},
$$&lt;p&gt;
where $\mathcal{Z} \triangleq \int \Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u}) \,\mathrm{d}\mathbf{u}$.
Plugging this back into the ELBO, we get
&lt;/p&gt;
$$
\begin{aligned}
\mathrm{ELBO}(\boldsymbol{\phi}^{\star}, \mathbf{Z})
&amp;= \int \log{\left (\bcancel{\Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u})} \frac{\mathcal{Z}}{\bcancel{\Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u})}} \right )} q_{\boldsymbol{\phi}^{\star}}(\mathbf{u}) \,\mathrm{d}\mathbf{u} \\
&amp;= \log{\mathcal{Z}}.
\end{aligned}
$$&lt;h3 id="gaussian-likelihoods--sparse-gaussian-process-regression-sgpr"&gt;Gaussian Likelihoods &amp;ndash; Sparse Gaussian Process Regression (SGPR)&lt;/h3&gt;
&lt;p&gt;Let us assume we have a Gaussian likelihood of the form
&lt;/p&gt;
$$
p(\mathbf{y} | \mathbf{f}) = \mathcal{N}(\mathbf{y} | \mathbf{f}, \beta^{-1} \mathbf{I}).
$$&lt;p&gt;
Then it is straightforward to show that
&lt;/p&gt;
$$
\log{\Phi(\mathbf{y}, \mathbf{u})} =
\log{\mathcal{N}(\mathbf{y} | \mathbf{m}, \beta^{-1} \mathbf{I} )} - \frac{\beta}{2} \mathrm{tr}(\mathbf{S}),
$$&lt;p&gt;
where $\mathbf{m}$ and $\mathbf{S}$ are defined as before, i.e. $\mathbf{m} = \boldsymbol{\Psi}^{\top} \mathbf{u}$ and
$\mathbf{S} = \mathbf{K}_\mathbf{ff} - \boldsymbol{\Psi}^{\top} \mathbf{K}_\mathbf{uu} \boldsymbol{\Psi}$.
Refer to
for derivations.&lt;/p&gt;
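&lt;p&gt;As a brief sketch of why this holds, recall that $\log{\Phi(\mathbf{y}, \mathbf{u})}$
is the expectation of the Gaussian log-likelihood under $p(\mathbf{f} | \mathbf{u}) = \mathcal{N}(\mathbf{f} | \mathbf{m}, \mathbf{S})$,
and that $\mathbb{E}\left[\|\mathbf{y} - \mathbf{f}\|^2\right] = \|\mathbf{y} - \mathbf{m}\|^2 + \mathrm{tr}(\mathbf{S})$
for such an $\mathbf{f}$. With $N$ denoting the number of observations, this gives
&lt;/p&gt;
$$
\begin{aligned}
\log{\Phi(\mathbf{y}, \mathbf{u})}
&amp;= \mathbb{E}_{p(\mathbf{f} | \mathbf{u})}\left[\log{\mathcal{N}(\mathbf{y} | \mathbf{f}, \beta^{-1} \mathbf{I})}\right] \\
&amp;= -\frac{N}{2} \log{(2 \pi \beta^{-1})} - \frac{\beta}{2} \mathbb{E}_{p(\mathbf{f} | \mathbf{u})}\left[\|\mathbf{y} - \mathbf{f}\|^2\right] \\
&amp;= -\frac{N}{2} \log{(2 \pi \beta^{-1})} - \frac{\beta}{2} \|\mathbf{y} - \mathbf{m}\|^2 - \frac{\beta}{2} \mathrm{tr}(\mathbf{S}) \\
&amp;= \log{\mathcal{N}(\mathbf{y} | \mathbf{m}, \beta^{-1} \mathbf{I})} - \frac{\beta}{2} \mathrm{tr}(\mathbf{S}).
\end{aligned}
$$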
&lt;p&gt;Two objects are of particular interest.
First, the
optimal variational distribution $q_{\boldsymbol{\phi}^{\star}}(\mathbf{u})$,
which is required to compute the predictive distribution $q_{\boldsymbol{\phi}^{\star}}(\mathbf{f}) = \int p(\mathbf{f}|\mathbf{u}) q_{\boldsymbol{\phi}^{\star}}(\mathbf{u}) \, \mathrm{d}\mathbf{u}$,
but which may also be of independent interest.
Second, the ELBO, the objective with respect to which the inducing input
locations $\mathbf{Z}$ are optimized.&lt;/p&gt;
&lt;p&gt;The optimal variational distribution is given by
&lt;/p&gt;
$$
q_{\boldsymbol{\phi}^{\star}}(\mathbf{u}) =
\mathcal{N}(\mathbf{u} \mid \beta \mathbf{K}_\mathbf{uu} \mathbf{M}^{-1} \mathbf{K}_\mathbf{uf} \mathbf{y}, \mathbf{K}_\mathbf{uu} \mathbf{M}^{-1} \mathbf{K}_\mathbf{uu}),
$$&lt;p&gt;
where
&lt;/p&gt;
$$
\mathbf{M} \triangleq \mathbf{K}_\mathbf{uu} + \beta \mathbf{K}_\mathbf{uf} \mathbf{K}_\mathbf{fu}.
$$&lt;p&gt;
This can be verified by reducing the product of two exponential-quadratic
functions in $\Phi(\mathbf{y}, \mathbf{u})$ and $p(\mathbf{u})$ into a single
exponential-quadratic function up to a constant factor,
an operation also known as &amp;ldquo;completing the square&amp;rdquo;.
Refer to
for complete derivations.&lt;/p&gt;
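&lt;p&gt;The completing-the-square argument can also be checked numerically. In the
sketch below, the quadratic and linear terms in $\mathbf{u}$ are collected
directly from $\log{\Phi(\mathbf{y}, \mathbf{u})} + \log{p(\mathbf{u})}$ and
compared against the closed form above. The kernel, the toy data and the value
of $\beta$ are assumptions made for illustration.&lt;/p&gt;

```python
import numpy as np


def rbf_kernel(A, B, lengthscale=1.0):
    """Squared-exponential kernel matrix between the rows of A and B."""
    sq_dists = (
        np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2.0 * A @ B.T
    )
    return np.exp(-0.5 * sq_dists / lengthscale**2)


rng = np.random.default_rng(0)
X = rng.uniform(-2.0, 2.0, size=(30, 1))
y = np.sin(3.0 * X[:, 0]) + 0.1 * rng.standard_normal(30)
Z = np.linspace(-1.5, 1.5, 4)[:, None]
beta = 100.0  # noise precision (hypothetical value)

K_uf = rbf_kernel(Z, X)
K_uu = rbf_kernel(Z, Z) + 1e-8 * np.eye(len(Z))
Psi = np.linalg.solve(K_uu, K_uf)

# Completing the square: the exponent is
#   -0.5 u^T (K_uu^{-1} + beta Psi Psi^T) u + beta (Psi y)^T u + const.
precision = np.linalg.inv(K_uu) + beta * Psi @ Psi.T
cov_square = np.linalg.inv(precision)
mean_square = cov_square @ (beta * Psi @ y)

# Closed form from the text, with M = K_uu + beta K_uf K_fu.
M_mat = K_uu + beta * K_uf @ K_uf.T
cov_closed = K_uu @ np.linalg.solve(M_mat, K_uu)
mean_closed = beta * K_uu @ np.linalg.solve(M_mat, K_uf @ y)
```

&lt;p&gt;The explicit inverses are used here only for clarity; they are not
recommended for production code.&lt;/p&gt;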
&lt;p&gt;This leads to the predictive distribution
&lt;/p&gt;
$$
\begin{aligned}
q_{\boldsymbol{\phi}^{\star}}(\mathbf{f})
&amp;= \mathcal{N}\bigl(\beta \boldsymbol{\Psi}^\top \mathbf{K}_\mathbf{uu} \mathbf{M}^{-1} \mathbf{K}_\mathbf{uu} \boldsymbol{\Psi} \mathbf{y}, \\
&amp;\qquad\qquad \mathbf{K}_\mathbf{ff} - \boldsymbol{\Psi}^\top (\mathbf{K}_\mathbf{uu} - \mathbf{K}_\mathbf{uu} \mathbf{M}^{-1} \mathbf{K}_\mathbf{uu} ) \boldsymbol{\Psi} \bigr) \\
&amp;= \mathcal{N}\bigl(\beta \mathbf{K}_\mathbf{fu} \mathbf{M}^{-1} \mathbf{K}_\mathbf{uf} \mathbf{y}, \\
&amp;\qquad\qquad \mathbf{K}_\mathbf{ff} - \mathbf{K}_\mathbf{fu} (\mathbf{K}_\mathbf{uu}^{-1} - \mathbf{M}^{-1}) \mathbf{K}_\mathbf{uf} \bigr).
\end{aligned}
$$&lt;p&gt;The ELBO is given by
&lt;/p&gt;
$$
\mathrm{ELBO}(\boldsymbol{\phi}^{\star}, \mathbf{Z}) =
\log \mathcal{Z} =
\log \mathcal{N}(\mathbf{y} | \mathbf{0}, \mathbf{Q}_\mathbf{ff} + \beta^{-1} \mathbf{I}) - \frac{\beta}{2} \mathrm{tr}(\mathbf{S}).
$$&lt;p&gt;
This can be verified by applying simple rules for marginalizing Gaussians.
Again, refer to
for complete derivations.
Refer to
for a numerically efficient and
robust method for computing these quantities.&lt;/p&gt;
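&lt;p&gt;For intuition, the collapsed bound can be evaluated naively and compared
against the exact log marginal likelihood
$\log{\mathcal{N}(\mathbf{y} | \mathbf{0}, \mathbf{K}_\mathbf{ff} + \beta^{-1} \mathbf{I})}$,
which it should never exceed. The following is a direct, unoptimized sketch
that forms the $N \times N$ matrices explicitly; the kernel, the data and the
helper names are assumptions, and this is not the numerically robust method
referred to above.&lt;/p&gt;

```python
import numpy as np


def rbf_kernel(A, B, lengthscale=1.0):
    """Squared-exponential kernel matrix between the rows of A and B."""
    sq_dists = (
        np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2.0 * A @ B.T
    )
    return np.exp(-0.5 * sq_dists / lengthscale**2)


def gaussian_logpdf(y, cov):
    """log N(y | 0, cov), computed via a Cholesky factorization."""
    L = np.linalg.cholesky(cov)
    alpha = np.linalg.solve(L, y)
    log_det_half = np.sum(np.log(np.diag(L)))
    return -0.5 * alpha @ alpha - log_det_half - 0.5 * len(y) * np.log(2.0 * np.pi)


def sgpr_elbo(X, y, Z, beta, jitter=1e-8):
    """Collapsed ELBO: log N(y | 0, Q_ff + I / beta) - (beta / 2) tr(S)."""
    K_ff = rbf_kernel(X, X)
    K_uf = rbf_kernel(Z, X)
    K_uu = rbf_kernel(Z, Z) + jitter * np.eye(len(Z))
    Q_ff = K_uf.T @ np.linalg.solve(K_uu, K_uf)
    S = K_ff - Q_ff
    return gaussian_logpdf(y, Q_ff + np.eye(len(X)) / beta) - 0.5 * beta * np.trace(S)


rng = np.random.default_rng(0)
X = rng.uniform(-2.0, 2.0, size=(30, 1))
y = np.sin(3.0 * X[:, 0]) + 0.1 * rng.standard_normal(30)
Z = np.linspace(-1.5, 1.5, 4)[:, None]
beta = 100.0

elbo = sgpr_elbo(X, y, Z, beta)
exact = gaussian_logpdf(y, rbf_kernel(X, X) + np.eye(len(X)) / beta)
```

&lt;p&gt;The gap between &lt;code&gt;exact&lt;/code&gt; and &lt;code&gt;elbo&lt;/code&gt; is the KL
divergence from the optimal variational distribution to the true posterior,
so it shrinks as the inducing inputs summarize the data better.&lt;/p&gt;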
&lt;h3 id="non-gaussian-likelihoods"&gt;Non-Gaussian Likelihoods&lt;/h3&gt;
&lt;p&gt;Recall from earlier that the ELBO is written as
&lt;/p&gt;
$$
\begin{align*}
\mathrm{ELBO}(\boldsymbol{\phi}, \mathbf{Z}) &amp; =
\int \log{\left(\frac{\Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u})}{q_{\boldsymbol{\phi}}(\mathbf{u})}\right)} q_{\boldsymbol{\phi}}(\mathbf{u}) \,\mathrm{d}\mathbf{u} \\\\ &amp; =
\int \left(\log{\Phi(\mathbf{y}, \mathbf{u})} + \log{\frac{p(\mathbf{u})}{q_{\boldsymbol{\phi}}(\mathbf{u})}}\ \right) q_{\boldsymbol{\phi}}(\mathbf{u}) \,\mathrm{d}\mathbf{u} \\\\ &amp; =
\mathrm{ELL}(\boldsymbol{\phi}, \mathbf{Z}) - \mathrm{KL}[q_{\boldsymbol{\phi}}(\mathbf{u}) \,\|\, p(\mathbf{u})],
\end{align*}
$$&lt;p&gt;
where we define $\mathrm{ELL}(\boldsymbol{\phi}, \mathbf{Z})$, the &lt;em&gt;expected log-likelihood (ELL)&lt;/em&gt;, as
&lt;/p&gt;
$$
\mathrm{ELL}(\boldsymbol{\phi}, \mathbf{Z}) \triangleq \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{u})}\left[\log{\Phi(\mathbf{y}, \mathbf{u})}\right].
$$&lt;p&gt;
This constitutes the first term in the ELBO, and can be written as
&lt;/p&gt;
$$
\begin{align*}
\mathrm{ELL}(\boldsymbol{\phi}, \mathbf{Z}) &amp; =
\int \log{\Phi(\mathbf{y}, \mathbf{u})} q_{\boldsymbol{\phi}}(\mathbf{u}) \,\mathrm{d}\mathbf{u} \\\\ &amp; =
\int \left(\int \log{p(\mathbf{y} | \mathbf{f})} p(\mathbf{f} | \mathbf{u}) \,\mathrm{d}\mathbf{f}\right) q_{\boldsymbol{\phi}}(\mathbf{u}) \,\mathrm{d}\mathbf{u} \\\\ &amp; =
\int \log{p(\mathbf{y} | \mathbf{f})} \left(\int p(\mathbf{f} | \mathbf{u}) q_{\boldsymbol{\phi}}(\mathbf{u}) \,\mathrm{d}\mathbf{u} \right) \,\mathrm{d}\mathbf{f} \\\\ &amp; =
\int \log{p(\mathbf{y} | \mathbf{f})} q(\mathbf{f}) \,\mathrm{d}\mathbf{f} \\\\ &amp; =
\mathbb{E}_{q(\mathbf{f})}[\log{p(\mathbf{y} | \mathbf{f})}].
\end{align*}
$$&lt;p&gt;
While this integral is analytically intractable in general, we can nonetheless
approximate it efficiently using numerical integration techniques such as
Monte Carlo (MC) estimation or quadrature rules.
In particular, because $q(\mathbf{f})$ is Gaussian, we can utilize simple yet
effective rules such as
.&lt;/p&gt;
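&lt;p&gt;As an illustration, here is a minimal sketch that uses Gauss-Hermite
quadrature, chosen here as one common rule for one-dimensional Gaussian
expectations. It assumes that the likelihood factorizes across data points, so
that only the marginal means and variances of the Gaussian over $\mathbf{f}$
are needed. The function names and the Bernoulli likelihood with a logistic
link are hypothetical choices.&lt;/p&gt;

```python
import numpy as np


def expected_log_lik(log_lik, y, mu, var, num_points=20):
    """Gauss-Hermite estimate of sum_n E_{N(f_n | mu_n, var_n)}[log p(y_n | f_n)]."""
    nodes, weights = np.polynomial.hermite.hermgauss(num_points)
    # Change of variables f = mu + sqrt(2 var) * t maps N(mu, var) onto the
    # exp(-t^2) weight function used by the Hermite rule.
    f = mu[:, None] + np.sqrt(2.0 * var)[:, None] * nodes[None, :]
    integrand = log_lik(y[:, None], f)  # shape (N, num_points)
    return np.sum(integrand @ weights) / np.sqrt(np.pi)


def bernoulli_log_lik(y, f):
    """log p(y | f) for labels y in {0, 1} under a logistic link."""
    return -np.logaddexp(0.0, -(2.0 * y - 1.0) * f)


# Hypothetical labels and marginal moments of the Gaussian over f.
y = np.array([1.0, 0.0, 1.0])
mu = np.array([0.5, -1.0, 2.0])
var = np.array([0.3, 0.1, 0.5])

ell = expected_log_lik(bernoulli_log_lik, y, mu, var)
```

&lt;p&gt;For a Gaussian likelihood the integrand is quadratic in $f$, so the rule is
exact and recovers the closed-form expression from the previous section.&lt;/p&gt;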
&lt;p&gt;Now, the second term in the ELBO is the KL divergence between $q_{\boldsymbol{\phi}}(\mathbf{u})$ and $p(\mathbf{u})$, which are both multivariate Gaussians,
&lt;/p&gt;
$$
\mathrm{KL}[q_{\boldsymbol{\phi}}(\mathbf{u}) \,\|\, p(\mathbf{u})] =
\mathrm{KL}[\mathcal{N}(\mathbf{b}, \mathbf{W} {\mathbf{W}}^\top) \,\|\, \mathcal{N}(\mathbf{0}, \mathbf{K}_\mathbf{uu})],
$$&lt;p&gt;
and has a
.
In the case of the whitened parameterization, it can be simplified as
&lt;/p&gt;
$$
\begin{align*}
\mathrm{KL}[q_{\boldsymbol{\phi}}(\mathbf{u}) \,\|\, p(\mathbf{u})] &amp; =
\mathrm{KL}[\mathcal{N}(\mathbf{b}, \mathbf{W} {\mathbf{W}}^\top) \,\|\, \mathcal{N}(\mathbf{0}, \mathbf{K}_\mathbf{uu})] \\\\ &amp; =
\mathrm{KL}[\mathcal{N}(\mathbf{b}', \mathbf{W}' {\mathbf{W}'}^\top) \,\|\, \mathcal{N}(\mathbf{0}, \mathbf{I})].
\end{align*}
$$&lt;p&gt;
This comes from the fact that
&lt;/p&gt;
$$
\begin{aligned}
&amp;\mathrm{KL}\left[\mathcal{N}(\mathbf{A} \boldsymbol{\mu}_0, \mathbf{A} \boldsymbol{\Sigma}_0 \mathbf{A}^\top) \,\|\, \mathcal{N}(\mathbf{A} \boldsymbol{\mu}_1, \mathbf{A} \boldsymbol{\Sigma}_1 \mathbf{A}^\top) \right] \\
&amp;\qquad = \mathrm{KL}\left[\mathcal{N}(\boldsymbol{\mu}_0, \boldsymbol{\Sigma}_0) \,\|\, \mathcal{N}(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1) \right]
\end{aligned}
$$&lt;p&gt;
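which holds for any invertible matrix $\mathbf{A}$, where we set $\boldsymbol{\mu}_0 = \mathbf{b}', \boldsymbol{\Sigma}_0 = \mathbf{W}' {\mathbf{W}'}^\top, \boldsymbol{\mu}_1 = \mathbf{0}, \boldsymbol{\Sigma}_1 = \mathbf{I}$ and $\mathbf{A} = \mathbf{L}$, with $\mathbf{L}$ the Cholesky factor of $\mathbf{K}_\mathbf{uu}$, i.e. the lower triangular matrix such that $\mathbf{L}\mathbf{L}^\top = \mathbf{K}_\mathbf{uu}$. With these choices, $\mathbf{A} \boldsymbol{\mu}_0 = \mathbf{b}$, $\mathbf{A} \boldsymbol{\Sigma}_0 \mathbf{A}^\top = \mathbf{W} \mathbf{W}^\top$ and $\mathbf{A} \boldsymbol{\Sigma}_1 \mathbf{A}^\top = \mathbf{K}_\mathbf{uu}$.&lt;/p&gt;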
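&lt;p&gt;This invariance is easy to confirm numerically using the standard
closed-form KL divergence between two multivariate Gaussians. The sketch below
uses a hypothetical squared-exponential $\mathbf{K}_\mathbf{uu}$ and randomly
drawn whitened parameters; the helper name &lt;code&gt;kl_gaussians&lt;/code&gt; is
likewise an assumption.&lt;/p&gt;

```python
import numpy as np


def kl_gaussians(mu0, Sigma0, mu1, Sigma1):
    """KL[N(mu0, Sigma0) || N(mu1, Sigma1)] in closed form."""
    k = len(mu0)
    diff = mu1 - mu0
    trace_term = np.trace(np.linalg.solve(Sigma1, Sigma0))
    maha_term = diff @ np.linalg.solve(Sigma1, diff)
    _, logdet0 = np.linalg.slogdet(Sigma0)
    _, logdet1 = np.linalg.slogdet(Sigma1)
    return 0.5 * (trace_term + maha_term - k + logdet1 - logdet0)


rng = np.random.default_rng(0)
num_inducing = 4
Z = np.linspace(-1.5, 1.5, num_inducing)[:, None]

# Squared-exponential K_uu with a small jitter, and its Cholesky factor L.
K_uu = np.exp(-0.5 * (Z - Z.T) ** 2) + 1e-8 * np.eye(num_inducing)
L = np.linalg.cholesky(K_uu)

# Whitened free parameters b' and W' (W' kept well away from singular).
b_white = rng.standard_normal(num_inducing)
W_white = np.tril(0.3 * rng.standard_normal((num_inducing, num_inducing)))
np.fill_diagonal(W_white, 1.0 + rng.uniform(size=num_inducing))

# Unwhitened parameters b = L b' and W = L W'.
b = L @ b_white
W = L @ W_white

zeros = np.zeros(num_inducing)
kl_original = kl_gaussians(b, W @ W.T, zeros, K_uu)
kl_whitened = kl_gaussians(b_white, W_white @ W_white.T, zeros, np.eye(num_inducing))
```

&lt;p&gt;The whitened form needs no solve against $\mathbf{K}_\mathbf{uu}$ at all,
which is part of why it is attractive in practice.&lt;/p&gt;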
&lt;h3 id="large-scale-data-with-stochastic-optimization"&gt;Large-Scale Data with Stochastic Optimization&lt;/h3&gt;
&lt;div class="callout flex px-4 py-3 mb-6 rounded-md border-l-4 bg-orange-100 dark:bg-orange-900 border-orange-500"
data-callout="warning"
data-callout-metadata=""&gt;
&lt;span class="callout-icon pr-3 pt-1 text-orange-600 dark:text-orange-300"&gt;
&lt;svg height="24" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M12 9v3.75m-9.303 3.376c-.866 1.5.217 3.374 1.948 3.374h14.71c1.73 0 2.813-1.874 1.948-3.374L13.949 3.378c-.866-1.5-3.032-1.5-3.898 0zM12 15.75h.007v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content dark:text-neutral-300"&gt;
&lt;div class="callout-title font-semibold mb-1"&gt;Warning&lt;/div&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;Coming soon.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;!-- \int \left ( \int \log{p(\mathbf{y} \| \mathbf{f})} p(\mathbf{f} \| \mathbf{u}) \\,\mathrm{d}\mathbf{f} + \log{\frac{p(\mathbf{u})}{q_{\boldsymbol{\phi}}(\mathbf{u})}} \right ) q_{\boldsymbol{\phi}}(\mathbf{u}) \\,\mathrm{d}\mathbf{u} \newline &amp; = --&gt;
&lt;!-- \int \left ( \log{\Phi(\mathbf{y}, \mathbf{u})} + \log{\frac{p(\mathbf{u})}{q_{\boldsymbol{\phi}}(\mathbf{u})}} \right ) q_{\boldsymbol{\phi}}(\mathbf{u}) \\,\mathrm{d}\mathbf{u} \newline &amp; = --&gt;
&lt;!-- Therefore,
$$
q(\mathbf{u}) = \mathcal{N}(\mathbf{u} \mid \beta \mathbf{K}_\mathbf{uu} \mathbf{M}^{-1} \mathbf{K}_\mathbf{uf} \mathbf{y}, \mathbf{K}_\mathbf{uu} \mathbf{M}^{-1} \mathbf{K}_\mathbf{uu})
$$
since $\mathbf{K}_\mathbf{uu} \boldsymbol{\Psi} = \mathbf{K}_\mathbf{uf}$.
$$
\exp \left ( - \frac{1}{2} \left ( \mathbf{u}^\top (\beta \boldsymbol{\Psi}\boldsymbol{\Psi}^\top) \mathbf{u} - 2 \beta \mathbf{y}^\top \boldsymbol{\Psi}^\top \mathbf{u} + \mathbf{u}^\top \mathbf{K}_\mathbf{uu}^{-1} \mathbf{u} \right ) \right )
$$
$$
\exp \left ( - \frac{1}{2} \left ( \mathbf{u}^\top ( \mathbf{K}_\mathbf{uu}^{-1} + \beta \boldsymbol{\Psi}\boldsymbol{\Psi}^\top) \mathbf{u} - 2 \beta (\boldsymbol{\Psi} \mathbf{y})^\top \mathbf{u} \right ) \right )
$$ --&gt;
&lt;h2 id="links-and-further-readings"&gt;Links and Further Readings&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Papers:
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Forerunners:&lt;/strong&gt; Deterministic Training Conditional (DTC; Csató &amp;amp; Opper, 2002&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;; Seeger, 2003&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;); Fully Independent Training Conditional (FITC; Snelson &amp;amp; Ghahramani, 2005&lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt;; Quiñonero-Candela &amp;amp; Rasmussen, 2005&lt;sup id="fnref:7"&gt;&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref"&gt;7&lt;/a&gt;&lt;/sup&gt;)&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Inter-domain Gaussian processes:&lt;/strong&gt; Lázaro-Gredilla &amp;amp; Figueiras-Vidal, 2009&lt;sup id="fnref:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Deep Gaussian processes:&lt;/strong&gt; Damianou &amp;amp; Lawrence, 2013&lt;sup id="fnref:9"&gt;&lt;a href="#fn:9" class="footnote-ref" role="doc-noteref"&gt;9&lt;/a&gt;&lt;/sup&gt;, Salimbeni et al, 2017&lt;sup id="fnref:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Non-Gaussian likelihoods:&lt;/strong&gt; Hensman et al, 2013&lt;sup id="fnref:11"&gt;&lt;a href="#fn:11" class="footnote-ref" role="doc-noteref"&gt;11&lt;/a&gt;&lt;/sup&gt;; Dezfouli &amp;amp; Bonilla, 2015&lt;sup id="fnref:12"&gt;&lt;a href="#fn:12" class="footnote-ref" role="doc-noteref"&gt;12&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Unifying inducing-/pseudo-point approximations:&lt;/strong&gt; Bui et al, 2017&lt;sup id="fnref:13"&gt;&lt;a href="#fn:13" class="footnote-ref" role="doc-noteref"&gt;13&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Orthogonal decompositions:&lt;/strong&gt; Salimbeni et al, 2018&lt;sup id="fnref:14"&gt;&lt;a href="#fn:14" class="footnote-ref" role="doc-noteref"&gt;14&lt;/a&gt;&lt;/sup&gt;; Shi et al, 2020&lt;sup id="fnref:15"&gt;&lt;a href="#fn:15" class="footnote-ref" role="doc-noteref"&gt;15&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Convergence analysis:&lt;/strong&gt; Burt et al, 2019&lt;sup id="fnref:16"&gt;&lt;a href="#fn:16" class="footnote-ref" role="doc-noteref"&gt;16&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Efficient sampling:&lt;/strong&gt; Wilson et al, 2020&lt;sup id="fnref:17"&gt;&lt;a href="#fn:17" class="footnote-ref" role="doc-noteref"&gt;17&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;Technical Reports:
&lt;ul&gt;
&lt;li&gt;
by M. Titsias&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;Notes:
&lt;ul&gt;
&lt;li&gt;
by T. Bui and R. Turner&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;Blog posts:
&lt;ul&gt;
&lt;li&gt;
by J. Hensman&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;hr&gt;
&lt;p&gt;Cite as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-gdscript3" data-lang="gdscript3"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="err"&gt;@&lt;/span&gt;&lt;span class="n"&gt;article&lt;/span&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tiao2020svgp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;title&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;{A} {H}andbook for {S}parse {V}ariational {G}aussian {P}rocesses&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;author&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;Tiao, Louis C&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;journal&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;tiao.io&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;year&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;2020&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;url&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;https://tiao.io/post/sparse-variational-gaussian-processes/&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h2 id="appendix"&gt;Appendix&lt;/h2&gt;
&lt;h3 id="i"&gt;I&lt;/h3&gt;
&lt;h4 id="whitened-parameterization-1"&gt;Whitened parameterization&lt;/h4&gt;
&lt;p&gt;Recall the definition $\boldsymbol{\Lambda} \triangleq \mathbf{L}^\top \boldsymbol{\Psi}$.
Then, the mean simplifies to
&lt;/p&gt;
$$
\boldsymbol{\mu} = \boldsymbol{\Psi}^\top \mathbf{b} = \boldsymbol{\Psi}^\top (\mathbf{L} \mathbf{b}') = (\mathbf{L}^\top \boldsymbol{\Psi})^\top \mathbf{b}' = \boldsymbol{\Lambda}^\top \mathbf{b}'.
$$&lt;p&gt;
Similarly, the covariance simplifies to
&lt;/p&gt;
$$
\begin{align*}
\mathbf{\Sigma} &amp; = \mathbf{K}_\mathbf{ff} - \boldsymbol{\Psi}^{\top} (\mathbf{K}_\mathbf{uu} - \mathbf{W} \mathbf{W}^{\top}) \boldsymbol{\Psi} \newline &amp; =
\mathbf{K}_\mathbf{ff} - \boldsymbol{\Psi}^{\top} (\mathbf{L} \mathbf{L}^{\top} - \mathbf{L} ({\mathbf{W}'}{\mathbf{W}'}^{\top}) \mathbf{L}^{\top}) \boldsymbol{\Psi} \newline &amp; =
\mathbf{K}_\mathbf{ff} - (\mathbf{L}^{\top} \boldsymbol{\Psi})^{\top} ( \mathbf{I}_M - {\mathbf{W}'}{\mathbf{W}'}^{\top}) (\mathbf{L}^{\top} \boldsymbol{\Psi}) \newline &amp; =
\mathbf{K}_\mathbf{ff} - \boldsymbol{\Lambda}^{\top} ( \mathbf{I}_M - {\mathbf{W}'}{\mathbf{W}'}^{\top}) \boldsymbol{\Lambda}.
\end{align*}
$$&lt;h3 id="ii"&gt;II&lt;/h3&gt;
&lt;h4 id="svgp-implementation-details"&gt;SVGP Implementation Details&lt;/h4&gt;
&lt;p&gt;&lt;em&gt;Single input index point&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;Here is an efficient and numerically stable way to compute $q_{\boldsymbol{\phi}}(f(\mathbf{x}))$
for an input $\mathbf{x}$.
We take the following steps:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;Cholesky decomposition: $\mathbf{L} \triangleq \mathrm{cholesky}(\mathbf{K}_\textbf{uu})$&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Note:&lt;/em&gt; $\mathcal{O}(M^3)$ complexity.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Solve system of linear equations: $\boldsymbol{\lambda}(\mathbf{x}) \triangleq \mathbf{L} \backslash \mathbf{k}_\mathbf{u}(\mathbf{x})$&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Note:&lt;/em&gt; $\mathcal{O}(M^2)$ complexity since $\mathbf{L}$ is lower triangular; $\boldsymbol{\beta} = \mathbf{A} \backslash \mathbf{x}$ denotes the vector $\boldsymbol{\beta}$ such that $\mathbf{A} \boldsymbol{\beta} = \mathbf{x} \Leftrightarrow \boldsymbol{\beta} = \mathbf{A}^{-1} \mathbf{x}$.
Hence, $\boldsymbol{\lambda}(\mathbf{x}) = \mathbf{L}^{-1} \mathbf{k}_\mathbf{u}(\mathbf{x})$.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;$s(\mathbf{x}, \mathbf{x}) \triangleq k_{\theta}(\mathbf{x}, \mathbf{x}) - \boldsymbol{\lambda}^\top(\mathbf{x}) \boldsymbol{\lambda}(\mathbf{x})$&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Note:&lt;/em&gt;
&lt;/p&gt;
$$
\begin{aligned}
\boldsymbol{\lambda}^\top(\mathbf{x}) \boldsymbol{\lambda}(\mathbf{x})
&amp;= \mathbf{k}_\mathbf{u}^\top(\mathbf{x}) \mathbf{L}^{-\top} \mathbf{L}^{-1} \mathbf{k}_\mathbf{u}(\mathbf{x}) \\
&amp;= \mathbf{k}_\mathbf{u}^\top(\mathbf{x}) \mathbf{K}_\mathbf{uu}^{-1} \mathbf{k}_\mathbf{u}(\mathbf{x}) \\
&amp;= \boldsymbol{\psi}_\mathbf{u}^\top(\mathbf{x}) \mathbf{K}_\mathbf{uu} \boldsymbol{\psi}_\mathbf{u}(\mathbf{x}).
\end{aligned}
$$&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;For &lt;strong&gt;whitened parameterization&lt;/strong&gt;:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;$\mu(\mathbf{x}) \triangleq \boldsymbol{\lambda}^\top(\mathbf{x}) \mathbf{b}'$&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;$\mathbf{v}^\top(\mathbf{x}) \triangleq \boldsymbol{\lambda}^\top(\mathbf{x}) {\mathbf{W}'}$&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Note:&lt;/em&gt; $\mathbf{v}^\top(\mathbf{x}) \mathbf{v}(\mathbf{x}) = \mathbf{k}_\mathbf{u}^\top(\mathbf{x}) \mathbf{L}^{-\top} ({\mathbf{W}'} {\mathbf{W}'}^{\top}) \mathbf{L}^{-1} \mathbf{k}_\mathbf{u}(\mathbf{x})$&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;otherwise:&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;Solve system of linear equations: $\boldsymbol{\psi}_\mathbf{u}(\mathbf{x}) \triangleq \mathbf{L}^\top \backslash \boldsymbol{\lambda}(\mathbf{x})$&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Note:&lt;/em&gt; $\mathcal{O}(M^2)$ complexity since $\mathbf{L}^{\top}$ is upper triangular. Further,
&lt;/p&gt;
$$
\boldsymbol{\psi}_\mathbf{u}(\mathbf{x}) = \mathbf{L}^{-\top} \boldsymbol{\lambda}(\mathbf{x}) = \mathbf{L}^{-\top} \mathbf{L}^{-1} \mathbf{k}_\mathbf{u}(\mathbf{x}) = \mathbf{K}_\mathbf{uu}^{-1} \mathbf{k}_\mathbf{u}(\mathbf{x})
$$&lt;p&gt;
and
&lt;/p&gt;
$$
\boldsymbol{\psi}_\mathbf{u}^\top(\mathbf{x}) = \mathbf{k}_\mathbf{u}^\top(\mathbf{x}) \mathbf{K}_\mathbf{uu}^{-\top} = \mathbf{k}_\mathbf{u}^\top(\mathbf{x}) \mathbf{K}_\mathbf{uu}^{-1}
$$&lt;p&gt;
since $\mathbf{K}_\mathbf{uu}$ is symmetric and nonsingular.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;$\mu(\mathbf{x}) \triangleq \boldsymbol{\psi}_\mathbf{u}^\top(\mathbf{x}) \mathbf{b}$&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;$\mathbf{v}^\top(\mathbf{x}) \triangleq \boldsymbol{\psi}_\mathbf{u}^\top(\mathbf{x}) \mathbf{W}$&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Note:&lt;/em&gt; $\mathbf{v}^\top(\mathbf{x}) \mathbf{v}(\mathbf{x}) = \mathbf{k}_\mathbf{u}^\top(\mathbf{x}) \mathbf{K}_\mathbf{uu}^{-1} (\mathbf{W} \mathbf{W}^{\top}) \mathbf{K}_\mathbf{uu}^{-1} \mathbf{k}_\mathbf{u}(\mathbf{x})$&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;$\sigma^2(\mathbf{x}) \triangleq s(\mathbf{x}, \mathbf{x}) + \mathbf{v}^\top(\mathbf{x}) \mathbf{v}(\mathbf{x})$&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Return $\mathcal{N}(f(\mathbf{x}) ; \mu(\mathbf{x}), \sigma^2(\mathbf{x}))$&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
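&lt;p&gt;The steps above can be sketched in NumPy/SciPy as follows. This is a minimal illustration under stated assumptions, not a reference implementation: the function name &lt;code&gt;predict_single&lt;/code&gt;, the squared-exponential kernel, the jitter value and the toy inputs are all hypothetical. The final lines check that the two parameterizations give the same predictive distribution when the mean and scale of $q_{\boldsymbol{\phi}}(\mathbf{u})$ are obtained by left-multiplying the whitened parameters by $\mathbf{L}$.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;import numpy as np
from scipy.linalg import solve_triangular


def predict_single(kxx, ku, Kuu, W, b, whiten=True, jitter=1e-6):
    # Step 1: Cholesky factor of K_uu (jitter added for numerical stability).
    L = np.linalg.cholesky(Kuu + jitter * np.eye(Kuu.shape[0]))
    # Step 2: lambda(x) = L^{-1} k_u(x) by forward substitution.
    lam = solve_triangular(L, ku, lower=True)
    # Step 3: s(x, x) = k(x, x) - lambda(x)^T lambda(x).
    s = kxx - lam @ lam
    # Step 4: the whitened case uses lambda(x) directly; otherwise
    # psi_u(x) = L^{-T} lambda(x) by back substitution.
    a = lam if whiten else solve_triangular(L.T, lam, lower=False)
    mu = a @ b
    v = W.T @ a
    # Steps 5 and 6: mean and variance sigma^2(x) = s(x, x) + v(x)^T v(x).
    return mu, s + v @ v


def sq_exp_kernel(A, B):
    # Squared-exponential kernel with unit amplitude and lengthscale (an assumption).
    sq_dist = np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq_dist)


rng = np.random.default_rng(0)
Z = rng.normal(size=(5, 2))    # M = 5 toy inducing inputs
x = rng.normal(size=(1, 2))    # a single index point
Kuu = sq_exp_kernel(Z, Z)
ku = sq_exp_kernel(Z, x)[:, 0]
kxx = sq_exp_kernel(x, x)[0, 0]

b_whitened = rng.normal(size=5)
W_whitened = np.tril(rng.normal(size=(5, 5)))
L = np.linalg.cholesky(Kuu + 1e-6 * np.eye(5))

mu_w, var_w = predict_single(kxx, ku, Kuu, W_whitened, b_whitened, whiten=True)
mu_u, var_u = predict_single(kxx, ku, Kuu, L @ W_whitened, L @ b_whitened, whiten=False)

# Both parameterizations should agree up to floating-point error.
print(mu_w, mu_u, var_w, var_u)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;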
&lt;p&gt;&lt;em&gt;Multiple input index points&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;It is simple to extend this to compute $q_{\boldsymbol{\phi}}(\mathbf{f})$ for an
arbitrary number of index points $\mathbf{X}$:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;Cholesky decomposition: $\mathbf{L} = \mathrm{cholesky}(\mathbf{K}_\textbf{uu})$&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Note:&lt;/em&gt; $\mathcal{O}(M^3)$ complexity.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Solve system of linear equations: $\boldsymbol{\Lambda} = \mathbf{L} \backslash \mathbf{K}_\mathbf{uf}$&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Note:&lt;/em&gt; $\mathcal{O}(M^2 N)$ complexity for $N$ index points, i.e. $\mathcal{O}(M^2)$ per column of $\mathbf{K}_\mathbf{uf}$, since $\mathbf{L}$ is lower triangular; $\mathbf{B} = \mathbf{A} \backslash \mathbf{X}$ denotes the matrix $\mathbf{B}$ such that $\mathbf{A} \mathbf{B} = \mathbf{X} \Leftrightarrow \mathbf{B} = \mathbf{A}^{-1} \mathbf{X}$.
Hence, $\boldsymbol{\Lambda} = \mathbf{L}^{-1} \mathbf{K}_\mathbf{uf}$.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;$\mathbf{S} \triangleq \mathbf{K}_\mathbf{ff} - \boldsymbol{\Lambda}^{\top} \boldsymbol{\Lambda}$&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Note:&lt;/em&gt;
&lt;/p&gt;
$$
\begin{aligned}
\boldsymbol{\Lambda}^{\top} \boldsymbol{\Lambda}
&amp;= \mathbf{K}_\mathbf{fu} \mathbf{L}^{-\top} \mathbf{L}^{-1} \mathbf{K}_\mathbf{uf} \\
&amp;= \mathbf{K}_\mathbf{fu} \mathbf{K}_\textbf{uu}^{-1} \mathbf{K}_\mathbf{uf} \\
&amp;= \mathbf{K}_\mathbf{fu} \mathbf{K}_\textbf{uu}^{-1} (\mathbf{K}_\textbf{uu}) \mathbf{K}_\textbf{uu}^{-1} \mathbf{K}_\mathbf{uf} \\
&amp;= \boldsymbol{\Psi}^\top \mathbf{K}_\textbf{uu} \boldsymbol{\Psi}.
\end{aligned}
$$&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;For &lt;strong&gt;whitened parameterization&lt;/strong&gt;:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;$\boldsymbol{\mu} \triangleq \boldsymbol{\Lambda}^\top \mathbf{b}'$&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;$\mathbf{V}^\top \triangleq \boldsymbol{\Lambda}^\top {\mathbf{W}'}$&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Note:&lt;/em&gt; $\mathbf{V}^\top \mathbf{V} = \mathbf{K}_\mathbf{fu} \mathbf{L}^{-\top} ({\mathbf{W}'} {\mathbf{W}'}^{\top}) \mathbf{L}^{-1} \mathbf{K}_\mathbf{uf}.$&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;otherwise:&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;Solve system of linear equations: $\boldsymbol{\Psi} = \mathbf{L}^{\top} \backslash \boldsymbol{\Lambda}$&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Note:&lt;/em&gt; $\mathcal{O}(M^2 N)$ complexity for $N$ index points, i.e. $\mathcal{O}(M^2)$ per column of $\boldsymbol{\Lambda}$, since $\mathbf{L}^{\top}$ is upper triangular. Further,&lt;/p&gt;
$$
\boldsymbol{\Psi} = \mathbf{L}^{-\top} \boldsymbol{\Lambda} = \mathbf{L}^{-\top} \mathbf{L}^{-1} \mathbf{K}_\mathbf{uf} = (\mathbf{L}\mathbf{L}^\top)^{-1} \mathbf{K}_\mathbf{uf} = \mathbf{K}_\mathbf{uu}^{-1} \mathbf{K}_\mathbf{uf},
$$&lt;p&gt;
and
&lt;/p&gt;
$$
\boldsymbol{\Psi}^\top = \mathbf{K}_\mathbf{fu} \mathbf{K}_\mathbf{uu}^{-\top} = \mathbf{K}_\mathbf{fu} \mathbf{K}_\mathbf{uu}^{-1},
$$&lt;p&gt;
since $\mathbf{K}_\mathbf{uu}$ is symmetric and nonsingular.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;$\boldsymbol{\mu} \triangleq \boldsymbol{\Psi}^\top \mathbf{b}$&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;$\mathbf{V}^\top \triangleq \boldsymbol{\Psi}^\top \mathbf{W}$&lt;/p&gt;
&lt;p&gt;&lt;em&gt;Note:&lt;/em&gt; $\mathbf{V}^\top \mathbf{V} = \mathbf{K}_\mathbf{fu} \mathbf{K}_\mathbf{uu}^{-1} (\mathbf{W} \mathbf{W}^{\top}) \mathbf{K}_\mathbf{uu}^{-1} \mathbf{K}_\mathbf{uf}$.&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;$\mathbf{\Sigma} \triangleq \mathbf{S} + \mathbf{V}^\top \mathbf{V}$&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Return $\mathcal{N}(\mathbf{f} ; \boldsymbol{\mu}, \mathbf{\Sigma})$&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;In TensorFlow, this looks something like:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;variational_predictive&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Knn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Kmm&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Kmn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;whiten&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;jitter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
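&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;m&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Kmm&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# number of inducing points&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;L&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Kmm&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;jitter&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;Kmm&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="c1"&gt;# L L^T = Kmm + jitter I_m&lt;/span&gt;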
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Lambda&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;triangular_solve&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;L&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Kmn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lower&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# Lambda = L^{-1} Kmn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;S&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Knn&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Lambda&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Lambda&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;adjoint_a&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# Knn - Lambda^T Lambda&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Phi = Lambda if whiten, else Phi = L^{-T} Lambda = Kmm^{-1} Kmn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Phi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Lambda&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;whiten&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;triangular_solve&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;L&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Lambda&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;adjoint&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lower&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;U&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;adjoint_a&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# U = V^T = Phi^T W&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Phi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;adjoint_a&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# Phi^T b&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;S&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;U&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;U&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;adjoint_b&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# S + UU^T = S + V^T V&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id="iii"&gt;III&lt;/h3&gt;
&lt;h4 id="optimal-variational-distribution-in-general"&gt;Optimal variational distribution (in general)&lt;/h4&gt;
&lt;p&gt;Taking the functional derivative of the ELBO with respect to $q_{\boldsymbol{\phi}}(\mathbf{u})$, we get
&lt;/p&gt;
$$
\begin{align*}
\frac{\partial}{\partial q_{\boldsymbol{\phi}}(\mathbf{u})} \mathrm{ELBO}(\boldsymbol{\phi}, \mathbf{Z}) &amp; =
\frac{\partial}{\partial q_{\boldsymbol{\phi}}(\mathbf{u})} \left ( \int \log{\frac{\Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u})}{q_{\boldsymbol{\phi}}(\mathbf{u})}} q_{\boldsymbol{\phi}}(\mathbf{u}) \,\mathrm{d}\mathbf{u} \right ) \newline &amp; =
\int \frac{\partial}{\partial q_{\boldsymbol{\phi}}(\mathbf{u})} \left ( \log{\frac{\Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u})}{q_{\boldsymbol{\phi}}(\mathbf{u})}} q_{\boldsymbol{\phi}}(\mathbf{u}) \right ) \,\mathrm{d}\mathbf{u} \newline &amp; =
\begin{split}
&amp; \int \log{\frac{\Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u})}{q_{\boldsymbol{\phi}}(\mathbf{u})}} \left ( \frac{\partial}{\partial q_{\boldsymbol{\phi}}(\mathbf{u})} q_{\boldsymbol{\phi}}(\mathbf{u}) \right ) + \newline
&amp; \qquad q_{\boldsymbol{\phi}}(\mathbf{u}) \left ( \frac{\partial}{\partial q_{\boldsymbol{\phi}}(\mathbf{u})} \log{\frac{\Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u})}{q_{\boldsymbol{\phi}}(\mathbf{u})}} \right ) \,\mathrm{d}\mathbf{u}
\end{split}
\newline &amp; =
\int \log{\frac{\Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u})}{q_{\boldsymbol{\phi}}(\mathbf{u})}} +
q_{\boldsymbol{\phi}}(\mathbf{u}) \left ( -\frac{1}{q_{\boldsymbol{\phi}}(\mathbf{u})} \right ) \,\mathrm{d}\mathbf{u}
\newline &amp; =
\int \log{\Phi(\mathbf{y}, \mathbf{u})} + \log{p(\mathbf{u})} - \log{q_{\boldsymbol{\phi}}(\mathbf{u})} - 1 \,\mathrm{d}\mathbf{u}.
\end{align*}
$$&lt;p&gt;
Setting the integrand to zero for all $\mathbf{u}$, we have
&lt;/p&gt;
$$
\begin{align*}
\log{q_{\boldsymbol{\phi}^{\star}}(\mathbf{u})} &amp; = \log{\Phi(\mathbf{y}, \mathbf{u})} + \log{p(\mathbf{u})} - 1 \\\\
\Rightarrow \qquad
q_{\boldsymbol{\phi}^{\star}}(\mathbf{u}) &amp; \propto \Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u}).
\end{align*}
$$&lt;h3 id="iv"&gt;IV&lt;/h3&gt;
&lt;h4 id="variational-lower-bound-partial-for-gaussian-likelihoods"&gt;Variational lower bound (partial) for Gaussian likelihoods&lt;/h4&gt;
&lt;p&gt;To carry out this derivation, we will need to recall the following two simple
identities. First, we can write the inner product between two vectors as the
trace of their outer product,
&lt;/p&gt;
$$
\mathbf{a}^\top \mathbf{b} = \mathrm{tr}(\mathbf{a} \mathbf{b}^\top).
$$&lt;p&gt;
Second, the auto-correlation matrix $\mathbb{E}[\mathbf{a}\mathbf{a}^{\top}]$
is related to the covariance matrix by
&lt;/p&gt;
$$
\begin{align*}
\mathrm{Cov}[\mathbf{a}] &amp; = \mathbb{E}[\mathbf{a}\mathbf{a}^{\top}] - \mathbb{E}[\mathbf{a}] \, \mathbb{E}[\mathbf{a}]^\top \\\\
\Leftrightarrow \quad
\mathbb{E}[\mathbf{a}\mathbf{a}^{\top}] &amp; = \mathrm{Cov}[\mathbf{a}] + \mathbb{E}[\mathbf{a}] \, \mathbb{E}[\mathbf{a}]^\top
\end{align*}
$$&lt;p&gt;
These allow us to write
&lt;/p&gt;
$$
\begin{align*}
\log{\Phi(\mathbf{y}, \mathbf{u})} &amp; =
\int \log{\mathcal{N}(\mathbf{y} | \mathbf{f}, \beta^{-1} \mathbf{I})} \mathcal{N}(\mathbf{f} | \mathbf{m}, \mathbf{S}) \,\mathrm{d}\mathbf{f}
\newline &amp; = - \frac{\beta}{2} \int (\mathbf{y} - \mathbf{f})^{\top} (\mathbf{y} - \mathbf{f}) \mathcal{N}(\mathbf{f} | \mathbf{m}, \mathbf{S}) \,\mathrm{d}\mathbf{f}
\newline &amp; \quad - \frac{N}{2}\log{(2\pi\beta^{-1})}
\newline &amp; = - \frac{\beta}{2} \int \mathrm{tr} \left (\mathbf{y}\mathbf{y}^{\top} - 2 \mathbf{y}\mathbf{f}^{\top} + \mathbf{f}\mathbf{f}^{\top} \right) \mathcal{N}(\mathbf{f} | \mathbf{m}, \mathbf{S}) \,\mathrm{d}\mathbf{f}
\newline &amp; \quad - \frac{N}{2}\log{(2\pi\beta^{-1})}
\newline &amp; = - \frac{\beta}{2} \mathrm{tr} \left (\mathbf{y}\mathbf{y}^{\top} - 2 \mathbf{y}\mathbf{m}^{\top} + \mathbf{S} + \mathbf{m} \mathbf{m}^{\top} \right)
\newline &amp; \quad - \frac{N}{2}\log{(2\pi\beta^{-1})}
\newline &amp; = - \frac{\beta}{2} (\mathbf{y} - \mathbf{m})^{\top} (\mathbf{y} - \mathbf{m}) - \frac{N}{2}\log{(2\pi\beta^{-1})}
\newline &amp; \quad - \frac{\beta}{2} \mathrm{tr}(\mathbf{S})
\newline &amp; = \log{\mathcal{N}(\mathbf{y} | \mathbf{m}, \beta^{-1} \mathbf{I} )} - \frac{\beta}{2} \mathrm{tr}(\mathbf{S}).
\end{align*}
$$&lt;h3 id="v"&gt;V&lt;/h3&gt;
&lt;h4 id="optimal-variational-distribution-for-gaussian-likelihoods"&gt;Optimal variational distribution for Gaussian likelihoods&lt;/h4&gt;
&lt;p&gt;Firstly, the optimal variational distribution can be found in closed-form as
&lt;/p&gt;
$$
\begin{align*}
q_{\boldsymbol{\phi}^{\star}}(\mathbf{u}) &amp; \propto \Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u}) \\\\
&amp; \propto \mathcal{N}(\mathbf{y} \mid \boldsymbol{\Psi}^\top \mathbf{u}, \beta^{-1} \mathbf{I}) \mathcal{N}(\mathbf{u} \mid \mathbf{0}, \mathbf{K}_\mathbf{uu}) \\\\ &amp; \propto
\exp \left ( - \frac{\beta}{2} (\mathbf{y} - \boldsymbol{\Psi}^\top \mathbf{u})^\top
(\mathbf{y} - \boldsymbol{\Psi}^\top \mathbf{u}) - \frac{1}{2} \mathbf{u}^\top \mathbf{K}_\mathbf{uu}^{-1} \mathbf{u} \right ) \\\\ &amp; \propto
\exp \left ( - \frac{1}{2} \left ( \mathbf{u}^\top \mathbf{C} \mathbf{u} - 2 \beta (\boldsymbol{\Psi} \mathbf{y})^\top \mathbf{u} \right ) \right ),
\end{align*}
$$&lt;p&gt;
where
&lt;/p&gt;
$$
\mathbf{C} \triangleq \mathbf{K}_\mathbf{uu}^{-1} + \beta \boldsymbol{\Psi} \boldsymbol{\Psi}^\top =
\mathbf{K}_\mathbf{uu}^{-1} (\mathbf{K}_\mathbf{uu} + \beta \mathbf{K}_\mathbf{uf} \mathbf{K}_\mathbf{fu} ) \mathbf{K}_\mathbf{uu}^{-1}.
$$&lt;p&gt;
By completing the square, we get
&lt;/p&gt;
$$
\begin{align*}
q_{\boldsymbol{\phi}^{\star}}(\mathbf{u}) &amp; \propto
\exp \left ( - \frac{1}{2} (\mathbf{u} - \beta \mathbf{C}^{-1} \boldsymbol{\Psi} \mathbf{y})^\top \mathbf{C} (\mathbf{u} - \beta \mathbf{C}^{-1} \boldsymbol{\Psi} \mathbf{y}) \right ) \\\\ &amp; \propto
\mathcal{N}(\mathbf{u} \mid \beta \mathbf{C}^{-1} \boldsymbol{\Psi} \mathbf{y}, \mathbf{C}^{-1}).
\end{align*}
$$&lt;p&gt;
We define
&lt;/p&gt;
$$
\mathbf{M} \triangleq \mathbf{K}_\mathbf{uu} + \beta \mathbf{K}_\mathbf{uf} \mathbf{K}_\mathbf{fu}
$$&lt;p&gt;
so that
&lt;/p&gt;
$$
\mathbf{C} = \mathbf{K}_\mathbf{uu}^{-1} \mathbf{M} \mathbf{K}_\mathbf{uu}^{-1},
$$&lt;p&gt;
which allows us to write
&lt;/p&gt;
$$
q_{\boldsymbol{\phi}^{\star}}(\mathbf{u}) =
\mathcal{N}(\mathbf{u} \mid \beta \mathbf{K}_\mathbf{uu} \mathbf{M}^{-1} \mathbf{K}_\mathbf{uf} \mathbf{y}, \mathbf{K}_\mathbf{uu} \mathbf{M}^{-1} \mathbf{K}_\mathbf{uu}).
$$&lt;h3 id="vi"&gt;VI&lt;/h3&gt;
&lt;h4 id="variational-lower-bound-complete-for-gaussian-likelihoods"&gt;Variational lower bound (complete) for Gaussian likelihoods&lt;/h4&gt;
&lt;p&gt;We have
&lt;/p&gt;
$$
\begin{align*}
\mathrm{ELBO}(\boldsymbol{\phi}^{\star}, \mathbf{Z}) &amp; =
\log \mathcal{Z} \\\\ &amp; =
\log \int \Phi(\mathbf{y}, \mathbf{u}) p(\mathbf{u}) \,\mathrm{d}\mathbf{u} \\\\ &amp; =
\log \biggl[ \exp{\left(-\frac{\beta}{2} \mathrm{tr}(\mathbf{S})\right)}
\newline &amp; \qquad \cdot \int \mathcal{N}(\mathbf{y} | \boldsymbol{\Psi}^{\top} \mathbf{u}, \beta^{-1} \mathbf{I}) p(\mathbf{u}) \,\mathrm{d}\mathbf{u} \biggr] \\\\ &amp; =
\log \int \mathcal{N}(\mathbf{y} \mid \boldsymbol{\Psi}^{\top} \mathbf{u}, \beta^{-1} \mathbf{I}) \mathcal{N}(\mathbf{u} \mid \mathbf{0}, \mathbf{K}_\mathbf{uu}) \,\mathrm{d}\mathbf{u} - \frac{\beta}{2} \mathrm{tr}(\mathbf{S}) \\\\ &amp; =
\log \mathcal{N}(\mathbf{y} \mid \mathbf{0}, \beta^{-1} \mathbf{I} + \boldsymbol{\Psi}^{\top} \mathbf{K}_\textbf{uu} \boldsymbol{\Psi}) - \frac{\beta}{2} \mathrm{tr}(\mathbf{S}) \\\\ &amp; =
\log \mathcal{N}(\mathbf{y} \mid \mathbf{0}, \mathbf{Q}_\mathbf{ff} + \beta^{-1} \mathbf{I}) - \frac{\beta}{2} \mathrm{tr}(\mathbf{S}).
\end{align*}
$$&lt;h3 id="vii"&gt;VII&lt;/h3&gt;
&lt;h4 id="sgpr-implementation-details"&gt;SGPR Implementation Details&lt;/h4&gt;
&lt;p&gt;Here we provide implementation details that minimize the
computational demands while avoiding numerically unstable calculations.&lt;/p&gt;
&lt;p&gt;The difficulty in calculating the ELBO stems from terms involving
the &lt;em&gt;inverse&lt;/em&gt; and the &lt;em&gt;determinant&lt;/em&gt; of $\mathbf{Q}_\mathbf{ff} + \beta^{-1} \mathbf{I}$.
More specifically, we have
&lt;/p&gt;
$$
\begin{split}
\mathrm{ELBO}(\boldsymbol{\phi}^{\star}, \mathbf{Z}) &amp; = - \frac{1}{2} \Bigl( \log \det \left ( \mathbf{Q}_\mathbf{ff} + \beta^{-1} \mathbf{I} \right ) \\\\
&amp; \qquad + \mathbf{y}^\top \left ( \mathbf{Q}_\mathbf{ff} + \beta^{-1} \mathbf{I} \right )^{-1} \mathbf{y} + N \log {2\pi} \Bigr) \\\\
&amp; \qquad - \frac{\beta}{2} \mathrm{tr}(\mathbf{S}).
\end{split}
$$&lt;p&gt;
It turns out that many of the required terms can be expressed in terms of the
symmetric positive definite matrix
&lt;/p&gt;
$$
\mathbf{B} \triangleq \mathbf{U} \mathbf{U}^\top + \mathbf{I},
$$&lt;p&gt;
where $\mathbf{U} \triangleq \beta^{\frac{1}{2}} \boldsymbol{\Lambda}$.&lt;/p&gt;
&lt;p&gt;First, let&amp;rsquo;s tackle the inverse term.
Using the Woodbury identity, we can write it as
&lt;/p&gt;
$$
\begin{align*}
\left(\mathbf{Q}_\mathbf{ff} + \beta^{-1} \mathbf{I}\right)^{-1}
&amp; = \left(\beta^{-1} \mathbf{I} + \boldsymbol{\Psi}^\top \mathbf{K}_\mathbf{uu} \boldsymbol{\Psi}\right)^{-1} \\\\
&amp; = \beta \mathbf{I} - \beta^2 \boldsymbol{\Psi}^\top \left(\mathbf{K}_\mathbf{uu}^{-1} + \beta \boldsymbol{\Psi} \boldsymbol{\Psi}^\top \right)^{-1} \boldsymbol{\Psi} \\\\
&amp; = \beta \left(\mathbf{I} - \beta \boldsymbol{\Psi}^\top \mathbf{C}^{-1} \boldsymbol{\Psi}\right).
\end{align*}
$$&lt;p&gt;Recall that $\mathbf{C}^{-1} = \mathbf{K}_\mathbf{uu} \mathbf{M}^{-1} \mathbf{K}_\mathbf{uu}$.
We can expand $\mathbf{M}$ as
&lt;/p&gt;
$$
\begin{align*}
\mathbf{M} &amp; \triangleq \mathbf{K}_\mathbf{uu} + \beta \mathbf{K}_\mathbf{uf} \mathbf{K}_\mathbf{fu} \\\\
&amp; = \mathbf{L} \mathbf{L}^\top + \beta \mathbf{L} \mathbf{L}^{-1} \mathbf{K}_\mathbf{uf} \mathbf{K}_\mathbf{fu} \mathbf{L}^{-\top} \mathbf{L}^\top \\\\
&amp; = \mathbf{L} \left( \mathbf{I} + \beta \boldsymbol{\Lambda} \boldsymbol{\Lambda}^\top \right) \mathbf{L}^\top \\\\
&amp; = \mathbf{L} \mathbf{B} \mathbf{L}^{\top},
\end{align*}
$$&lt;p&gt;
so its inverse is simply
&lt;/p&gt;
$$
\mathbf{M}^{-1} = \mathbf{L}^{-\top} \mathbf{B}^{-1} \mathbf{L}^{-1}.
$$&lt;p&gt;
Therefore, we have
&lt;/p&gt;
$$
\begin{align*}
\mathbf{C}^{-1}
&amp; = \mathbf{K}_\mathbf{uu} \mathbf{L}^{-\top} \mathbf{B}^{-1} \mathbf{L}^{-1} \mathbf{K}_\mathbf{uu} \\\\
&amp; = \mathbf{L} \mathbf{B}^{-1} \mathbf{L}^\top \\\\
&amp; = \mathbf{W} \mathbf{W}^\top
\end{align*}
$$&lt;p&gt;
where
&lt;/p&gt;
$$
\mathbf{W} \triangleq \mathbf{L} \mathbf{L}_\mathbf{B}^{-\top}
$$&lt;p&gt;
and $\mathbf{L}_\mathbf{B}$ is the Cholesky factor of $\mathbf{B}$,
i.e. the lower triangular matrix such
that $\mathbf{L}_\mathbf{B}\mathbf{L}_\mathbf{B}^\top = \mathbf{B}$.
All in all, we now have
&lt;/p&gt;
$$
\begin{align*}
\left(\mathbf{Q}_\mathbf{ff} + \beta^{-1} \mathbf{I}\right)^{-1}
&amp; = \beta \left(\mathbf{I} - \beta \boldsymbol{\Psi}^\top \mathbf{W} \mathbf{W}^\top \boldsymbol{\Psi}\right),
\end{align*}
$$&lt;p&gt;
so we can compute the quadratic term in $\mathbf{y}$ as
&lt;/p&gt;
$$
\begin{align*}
\mathbf{y}^\top \left ( \mathbf{Q}_\mathbf{ff} + \beta^{-1} \mathbf{I} \right )^{-1} \mathbf{y}
&amp; = \beta \left( \mathbf{y}^\top \mathbf{y} - \beta \mathbf{y}^\top \boldsymbol{\Psi}^\top \mathbf{W} \mathbf{W}^\top \boldsymbol{\Psi} \mathbf{y} \right) \\\\
&amp; = \beta \mathbf{y}^\top \mathbf{y} - \mathbf{c}^\top \mathbf{c},
\end{align*}
$$&lt;p&gt;
where
&lt;/p&gt;
$$
\mathbf{c} \triangleq \beta \mathbf{W}^\top \boldsymbol{\Psi} \mathbf{y} = \beta \mathbf{L}_\mathbf{B}^{-1} \boldsymbol{\Lambda} \mathbf{y} = \beta^{\frac{1}{2}} \mathbf{L}_\mathbf{B}^{-1} \mathbf{U} \mathbf{y}.
$$&lt;p&gt;Next, let&amp;rsquo;s address the determinant term.
To this end, first note that the determinant of $\mathbf{M}$ is
&lt;/p&gt;
$$
\begin{align*}
\det \left( \mathbf{M} \right) &amp; = \det \left( \mathbf{L} \mathbf{B} \mathbf{L}^{\top} \right) \\\\ &amp; =
\det \left( \mathbf{L} \right) \det \left( \mathbf{B} \right) \det \left( \mathbf{L}^{\top} \right) \\\\ &amp; =
\det \left( \mathbf{K}_\mathbf{uu} \right) \det \left( \mathbf{B} \right).
\end{align*}
$$&lt;p&gt;
Hence, the determinant of $\mathbf{C}$ is
&lt;/p&gt;
$$
\begin{align*}
\det \left( \mathbf{C} \right) &amp; =
\det \left( \mathbf{K}_\mathbf{uu}^{-1} \mathbf{M} \mathbf{K}_\mathbf{uu}^{-1} \right) \\\\ &amp; =
\frac{\det \left( \mathbf{M} \right)}{\det \left( \mathbf{K}_\mathbf{uu} \right )^2} \\\\ &amp; =
\frac{\det \left( \mathbf{B} \right)}{\det \left( \mathbf{K}_\mathbf{uu} \right )}.
\end{align*}
$$&lt;p&gt;
Therefore, by the matrix determinant lemma, we have
&lt;/p&gt;
$$
\begin{align*}
\det \left( \mathbf{Q}_\mathbf{ff} + \beta^{-1} \mathbf{I} \right) &amp; =
\det \left( \beta^{-1} \mathbf{I} + \boldsymbol{\Psi}^\top \mathbf{K}_\mathbf{uu} \boldsymbol{\Psi} \right) \\\\ &amp; =
\det \left( \mathbf{K}_\mathbf{uu}^{-1} + \beta \boldsymbol{\Psi} \boldsymbol{\Psi}^\top \right)
\det \left( \mathbf{K}_\mathbf{uu} \right)
\det \left( \beta^{-1} \mathbf{I} \right) \\\\ &amp; =
\det \left( \mathbf{C} \right)
\det \left( \mathbf{K}_\mathbf{uu} \right)
\det \left( \beta^{-1} \mathbf{I} \right) \\\\ &amp; =
\det \left( \mathbf{B} \right) \det \left( \beta^{-1} \mathbf{I} \right).
\end{align*}
$$&lt;p&gt;
We can re-use $\mathbf{L}_\mathbf{B}$ to calculate $\det \left( \mathbf{B} \right)$
in linear time.&lt;/p&gt;
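&lt;p&gt;The inverse and determinant identities above can be sanity-checked numerically. The following is a minimal NumPy sketch, not the implementation accompanying this text: the squared-exponential kernel, the random inputs, the jitter value, and every variable name are illustrative assumptions. It compares the Cholesky-based quadratic and log-determinant terms against a direct dense computation on the $N \times N$ matrix, using $\mathbf{Q}_\mathbf{ff} = \boldsymbol{\Lambda}^\top \boldsymbol{\Lambda}$.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)


def rbf(A, C):
    # Squared-exponential kernel with unit lengthscale (an assumed choice).
    sq_dists = np.sum((A[:, None, :] - C[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq_dists)


N, M_ind, beta = 50, 8, 4.0            # data size, inducing points, noise precision
X = rng.normal(size=(N, 1))
Z = np.linspace(-2.0, 2.0, M_ind)[:, None]
y = np.sin(3.0 * X[:, 0]) + rng.normal(scale=beta ** -0.5, size=N)

Kuu = rbf(Z, Z) + 1e-6 * np.eye(M_ind)  # small jitter for stability (assumption)
Kuf = rbf(Z, X)

L = np.linalg.cholesky(Kuu)
Lam = np.linalg.solve(L, Kuf)           # Lambda = L^{-1} K_uf
U = np.sqrt(beta) * Lam
B = U @ U.T + np.eye(M_ind)
LB = np.linalg.cholesky(B)
c = np.sqrt(beta) * np.linalg.solve(LB, U @ y)

# Quadratic term via the Woodbury identity: beta * y^T y - c^T c.
quad_chol = beta * (y @ y) - c @ c
# Log-determinant: log det(B) - N log(beta), with log det(B) read off diag(L_B).
logdet_chol = 2.0 * np.sum(np.log(np.diag(LB))) - N * np.log(beta)

# Dense reference on Q_ff + beta^{-1} I, where Q_ff = Lambda^T Lambda.
A = Lam.T @ Lam + np.eye(N) / beta
quad_dense = y @ np.linalg.solve(A, y)
logdet_dense = np.linalg.slogdet(A)[1]

print(np.isclose(quad_chol, quad_dense), np.isclose(logdet_chol, logdet_dense))
```

&lt;p&gt;Only $M \times M$ factorizations appear in the Cholesky-based path; the dense $N \times N$ solve is there purely as a reference.&lt;/p&gt;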
&lt;p&gt;The last non-trivial component of the ELBO is the trace term, which can be
calculated as
&lt;/p&gt;
$$
\frac{\beta}{2} \mathrm{tr}(\mathbf{S}) = \frac{\beta}{2} \mathrm{tr}\left(\mathbf{K}_\mathbf{ff}\right) - \frac{1}{2} \mathrm{tr}\left(\mathbf{U} \mathbf{U}^\top \right),
$$&lt;p&gt;
since
&lt;/p&gt;
$$
\begin{align*}
\mathrm{tr}\left(\mathbf{U} \mathbf{U}^\top\right) &amp; =
\mathrm{tr}\left(\mathbf{U}^\top \mathbf{U}\right) \\\\ &amp; =
\beta \cdot \mathrm{tr}\left(\boldsymbol{\Lambda}^\top \boldsymbol{\Lambda}\right) \\\\ &amp; =
\beta \cdot \mathrm{tr}\left( \boldsymbol{\Psi}^{\top} \mathbf{K}_\mathbf{uu} \boldsymbol{\Psi} \right).
\end{align*}
$$&lt;p&gt;
Again, we can re-use $\mathbf{U} \mathbf{U}^\top$ computed earlier.&lt;/p&gt;
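&lt;p&gt;Putting the three pieces together, the ELBO could be assembled roughly as follows. This is a hedged sketch under stated assumptions rather than a reference implementation: the function name &lt;code&gt;sgpr_elbo&lt;/code&gt;, its argument layout, the one-dimensional squared-exponential kernel, and the jitter are all hypothetical choices made for illustration. The result is checked against a dense evaluation of $\log \mathcal{N}(\mathbf{y} \mid \mathbf{0}, \mathbf{Q}_\mathbf{ff} + \beta^{-1} \mathbf{I}) - \frac{\beta}{2} \mathrm{tr}(\mathbf{S})$.&lt;/p&gt;

```python
import numpy as np


def sgpr_elbo(Kuu, Kuf, kff_diag, y, beta):
    """Collapsed ELBO built from L, U, B, L_B and c as defined in the text."""
    N, M = y.shape[0], Kuu.shape[0]
    L = np.linalg.cholesky(Kuu)
    U = np.sqrt(beta) * np.linalg.solve(L, Kuf)   # U = beta^{1/2} L^{-1} K_uf
    UUt = U @ U.T                                 # reused for B and the trace term
    LB = np.linalg.cholesky(UUt + np.eye(M))
    c = np.sqrt(beta) * np.linalg.solve(LB, U @ y)
    logdet = 2.0 * np.sum(np.log(np.diag(LB))) - N * np.log(beta)
    quad = beta * (y @ y) - c @ c
    trace_term = 0.5 * beta * np.sum(kff_diag) - 0.5 * np.trace(UUt)
    return -0.5 * (logdet + quad + N * np.log(2.0 * np.pi)) - trace_term


def kernel(A, C):
    # 1-D squared-exponential kernel with unit lengthscale (assumed).
    return np.exp(-0.5 * (A - C.T) ** 2)


rng = np.random.default_rng(1)
N, beta = 40, 10.0
X = rng.uniform(-2.0, 2.0, size=(N, 1))
Z = np.linspace(-2.0, 2.0, 6)[:, None]
y = np.sin(2.0 * X[:, 0]) + rng.normal(scale=beta ** -0.5, size=N)
Kuu = kernel(Z, Z) + 1e-6 * np.eye(6)             # jitter is an assumption
Kuf = kernel(Z, X)
kff_diag = np.ones(N)                             # k(x, x) = 1 for this kernel

elbo = sgpr_elbo(Kuu, Kuf, kff_diag, y, beta)

# Dense reference: log N(y | 0, Q_ff + I / beta) - (beta / 2) tr(K_ff - Q_ff).
Qff = Kuf.T @ np.linalg.solve(Kuu, Kuf)
cov = Qff + np.eye(N) / beta
log_marginal = -0.5 * (np.linalg.slogdet(cov)[1]
                       + y @ np.linalg.solve(cov, y)
                       + N * np.log(2.0 * np.pi))
elbo_dense = log_marginal - 0.5 * beta * (np.sum(kff_diag) - np.trace(Qff))

print(np.isclose(elbo, elbo_dense))
```

&lt;p&gt;Note that only the diagonal of $\mathbf{K}_\mathbf{ff}$ is needed for the trace term, so the full $N \times N$ kernel matrix never has to be formed in &lt;code&gt;sgpr_elbo&lt;/code&gt;.&lt;/p&gt;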
&lt;p&gt;Finally, let us address the posterior predictive.
Recall that
&lt;/p&gt;
$$
q_{\boldsymbol{\phi}^{\star}}(\mathbf{u}) = \mathcal{N}(\mathbf{u} \mid \beta \mathbf{C}^{-1} \boldsymbol{\Psi} \mathbf{y}, \mathbf{C}^{-1}).
$$&lt;p&gt;
Re-writing this in terms of $\mathbf{W}$, we get
&lt;/p&gt;
$$
\begin{align*}
q_{\boldsymbol{\phi}^{\star}}(\mathbf{u})
&amp; = \mathcal{N}\left(\mathbf{u} \mid \beta \mathbf{W} \mathbf{W}^\top \boldsymbol{\Psi} \mathbf{y}, \mathbf{W} \mathbf{W}^\top \right) \\\\
&amp; = \mathcal{N}\left(\mathbf{u} \mid \beta \mathbf{L} \mathbf{L}_\mathbf{B}^{-\top} \mathbf{W}^\top \boldsymbol{\Psi} \mathbf{y}, \mathbf{L} \mathbf{L}_\mathbf{B}^{-\top} \mathbf{L}_\mathbf{B}^{-1} \mathbf{L}^\top\right) \\\\
&amp; = \mathcal{N}\left(\mathbf{u} \mid \mathbf{L} \left(\mathbf{L}_\mathbf{B}^{-\top} \mathbf{c}\right), \mathbf{L} \mathbf{B}^{-1} \mathbf{L}^\top\right).
\end{align*}
$$&lt;p&gt;
Hence, we see that the optimal variational distribution is itself a
whitened parameterization with $\mathbf{b}' = \mathbf{L}_\mathbf{B}^{-\top} \mathbf{c}$
and $\mathbf{W}' = \mathbf{L}_\mathbf{B}^{-\top}$ (such that ${\mathbf{W}'} {\mathbf{W}'}^\top = \mathbf{B}^{-1}$).
Combined with results from a previous section,
we can directly write the predictive $q_{\boldsymbol{\phi}^{\star}}(\mathbf{f}) = \int p(\mathbf{f} \mid \mathbf{u}) q_{\boldsymbol{\phi}^{\star}}(\mathbf{u}) \, \mathrm{d}\mathbf{u}$ as
&lt;/p&gt;
$$
q_{\boldsymbol{\phi}^{\star}}(\mathbf{f}) =
\mathcal{N}\left(\boldsymbol{\Lambda}^\top \mathbf{L}_\mathbf{B}^{-\top} \mathbf{c},
\mathbf{K}_\mathbf{ff} - \boldsymbol{\Lambda}^\top \left( \mathbf{I} - \mathbf{B}^{-1} \right) \boldsymbol{\Lambda} \right).
$$&lt;p&gt;
Alternatively, we can derive this by noting the following simple identity,
&lt;/p&gt;
$$
\boldsymbol{\Psi}^\top \mathbf{C}^{-1} \boldsymbol{\Psi} = \boldsymbol{\Psi}^\top \mathbf{L} \mathbf{B}^{-1} \mathbf{L}^\top \boldsymbol{\Psi} = \boldsymbol{\Lambda}^\top \mathbf{B}^{-1} \boldsymbol{\Lambda},
$$&lt;p&gt;
and applying the rules for marginalizing Gaussians to obtain
&lt;/p&gt;
$$
\begin{align*}
q_{\boldsymbol{\phi}^{\star}}(\mathbf{f})
&amp; = \mathcal{N}\left(\beta \boldsymbol{\Psi}^\top \mathbf{C}^{-1} \boldsymbol{\Psi} \mathbf{y},
\mathbf{K}_\mathbf{ff} - \boldsymbol{\Psi}^\top \mathbf{K}_\mathbf{uu} \boldsymbol{\Psi} + \boldsymbol{\Psi}^\top \mathbf{C}^{-1} \boldsymbol{\Psi} \right) \\\\
&amp; = \mathcal{N}\left(\beta \boldsymbol{\Lambda}^\top \mathbf{B}^{-1} \boldsymbol{\Lambda} \mathbf{y},
\mathbf{K}_\mathbf{ff} - \boldsymbol{\Lambda}^\top \boldsymbol{\Lambda} + \boldsymbol{\Lambda}^\top \mathbf{B}^{-1} \boldsymbol{\Lambda} \right) \\\\
&amp; = \mathcal{N}\left(\boldsymbol{\Lambda}^\top \mathbf{L}_\mathbf{B}^{-\top} \mathbf{c},
\mathbf{K}_\mathbf{ff} - \boldsymbol{\Lambda}^\top \left( \mathbf{I} - \mathbf{B}^{-1} \right) \boldsymbol{\Lambda} \right).
\end{align*}
$$&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;Titsias, M. (2009, April). Variational Learning of Inducing Variables in Sparse Gaussian Processes. In Artificial Intelligence and Statistics (pp. 567-574).&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;Murray, I., &amp;amp; Adams, R. P. (2010). Slice Sampling Covariance Hyperparameters of Latent Gaussian Models. In Advances in Neural Information Processing Systems (pp. 1732-1740).&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;Hensman, J., Matthews, A. G., Filippone, M., &amp;amp; Ghahramani, Z. (2015). MCMC for Variationally Sparse Gaussian Processes. In Advances in Neural Information Processing Systems (pp. 1648-1656).&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;Csató, L., &amp;amp; Opper, M. (2002). Sparse On-line Gaussian Processes. Neural Computation, 14(3), 641-668.&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;Seeger, M. (2003). Bayesian Gaussian Process Models: PAC-Bayesian Generalisation Error Bounds and Sparse Approximations (PhD Thesis). University of Edinburgh.&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;Snelson, E., &amp;amp; Ghahramani, Z. (2005). Sparse Gaussian Processes using Pseudo-inputs. Advances in Neural Information Processing Systems, 18, 1257-1264.&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:7"&gt;
&lt;p&gt;Quiñonero-Candela, J., &amp;amp; Rasmussen, C. E. (2005). A Unifying View of Sparse Approximate Gaussian Process Regression. The Journal of Machine Learning Research, 6, 1939-1959.&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:8"&gt;
&lt;p&gt;Lázaro-Gredilla, M., &amp;amp; Figueiras-Vidal, A. R. (2009, December). Inter-domain Gaussian Processes for Sparse Inference using Inducing Features. In Advances in Neural Information Processing Systems.&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:9"&gt;
&lt;p&gt;Damianou, A., &amp;amp; Lawrence, N. D. (2013, April). Deep Gaussian Processes. In Artificial Intelligence and Statistics (pp. 207-215). PMLR.&amp;#160;&lt;a href="#fnref:9" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:10"&gt;
&lt;p&gt;Salimbeni, H., &amp;amp; Deisenroth, M. (2017). Doubly Stochastic Variational Inference for Deep Gaussian Processes. Advances in Neural Information Processing Systems, 30.&amp;#160;&lt;a href="#fnref:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:11"&gt;
&lt;p&gt;Hensman, J., Fusi, N., &amp;amp; Lawrence, N. D. (2013, August). Gaussian Processes for Big Data. In Proceedings of the Twenty-Ninth Conference on Uncertainty in Artificial Intelligence (pp. 282-290).&amp;#160;&lt;a href="#fnref:11" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:12"&gt;
&lt;p&gt;Dezfouli, A., &amp;amp; Bonilla, E. V. (2015). Scalable Inference for Gaussian Process Models with Black-box Likelihoods. In Advances in Neural Information Processing Systems (pp. 1414-1422).&amp;#160;&lt;a href="#fnref:12" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:13"&gt;
&lt;p&gt;Bui, T. D., Yan, J., &amp;amp; Turner, R. E. (2017). A Unifying Framework for Gaussian Process Pseudo-point Approximations using Power Expectation Propagation. The Journal of Machine Learning Research, 18(1), 3649-3720.&amp;#160;&lt;a href="#fnref:13" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:14"&gt;
&lt;p&gt;Salimbeni, H., Cheng, C. A., Boots, B., &amp;amp; Deisenroth, M. (2018). Orthogonally Decoupled Variational Gaussian Processes. In Advances in Neural Information Processing Systems (pp. 8711-8720).&amp;#160;&lt;a href="#fnref:14" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:15"&gt;
&lt;p&gt;Shi, J., Titsias, M., &amp;amp; Mnih, A. (2020, June). Sparse Orthogonal Variational Inference for Gaussian Processes. In International Conference on Artificial Intelligence and Statistics (pp. 1932-1942). PMLR.&amp;#160;&lt;a href="#fnref:15" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:16"&gt;
&lt;p&gt;Burt, D., Rasmussen, C. E., &amp;amp; Van Der Wilk, M. (2019, May). Rates of Convergence for Sparse Variational Gaussian Process Regression. In International Conference on Machine Learning (pp. 862-871). PMLR.&amp;#160;&lt;a href="#fnref:16" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:17"&gt;
&lt;p&gt;Wilson, J., Borovitskiy, V., Terenin, A., Mostowsky, P., &amp;amp; Deisenroth, M. (2020, November). Efficiently Sampling Functions from Gaussian Process Posteriors. In International Conference on Machine Learning (pp. 10292-10302). PMLR.&amp;#160;&lt;a href="#fnref:17" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Tech Talk: Cycle-Consistent Adversarial Learning as Approximate Bayesian Inference</title><link>https://tiao.io/events/amazon-ml-tech-talk-2019/</link><pubDate>Sat, 15 Jun 2019 13:00:00 +0000</pubDate><guid>https://tiao.io/events/amazon-ml-tech-talk-2019/</guid><description/></item><item><title>Density Ratio Estimation for KL Divergence Minimization between Implicit Distributions</title><link>https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/</link><pubDate>Mon, 27 Aug 2018 00:00:00 +0000</pubDate><guid>https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/</guid><description>&lt;!-- TODO: Clarify that optimal classifier refers to the classifier that minimizes the Bayes risk --&gt;
&lt;p&gt;The Kullback-Leibler (KL) divergence between distributions $p$ and $q$ is
defined as&lt;/p&gt;
$$
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)] :=
\mathbb{E}_{p(x)} \left [ \log \left ( \frac{p(x)}{q(x)} \right ) \right ].
$$&lt;p&gt;It can be expressed more succinctly as&lt;/p&gt;
$$
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)] = \mathbb{E}_{p(x)} [ \log r^{*}(x) ],
$$&lt;p&gt;where $r^{*}(x)$ is defined to be the ratio between the densities $p(x)$ and
$q(x)$,&lt;/p&gt;
$$
r^{*}(x) := \frac{p(x)}{q(x)}.
$$&lt;p&gt;This density ratio is crucial for computing not only the KL divergence but also
all $f$-divergences, defined as&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;
$$
\mathcal{D}_f[p(x) || q(x)] :=
\mathbb{E}_{q(x)} \left [ f \left ( \frac{p(x)}{q(x)} \right ) \right ].
$$&lt;p&gt;Rarely can this expectation (i.e. integral) be calculated analytically; in
most cases, we must resort to Monte Carlo approximation methods, which
explicitly require the density ratio.
In the more severe case where this density ratio is unavailable, because either
or both of $p(x)$ and $q(x)$ are not calculable, we must resort to methods for
&lt;em&gt;density ratio estimation&lt;/em&gt;.
In this post, we illustrate how to perform density ratio estimation by
exploiting its tight correspondence to &lt;em&gt;probabilistic classification&lt;/em&gt;.&lt;/p&gt;
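&lt;p&gt;As a quick check that the KL divergence is indeed an instance of the $f$-divergence defined above, take $f(u) = u \log u$ (a standard choice, stated here as a worked example rather than taken from the original text). Then&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_f[p(x) || q(x)]
&amp; = \mathbb{E}_{q(x)} \left [ \frac{p(x)}{q(x)} \log \left ( \frac{p(x)}{q(x)} \right ) \right ] \newline
&amp; = \int p(x) \log \left ( \frac{p(x)}{q(x)} \right ) \mathrm{d}x \newline
&amp; = \mathbb{E}_{p(x)} [ \log r^{*}(x) ]
= \mathcal{D}_{\mathrm{KL}}[p(x) || q(x)].
\end{align*}
$$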
&lt;h3 id="example-univariate-gaussians"&gt;Example: Univariate Gaussians&lt;/h3&gt;
&lt;p&gt;Let us consider the following univariate Gaussian distributions as the running
example for this post,&lt;/p&gt;
$$
p(x) = \mathcal{N}(x \mid 1, 1^2),
\qquad
\text{and}
\qquad
q(x) = \mathcal{N}(x \mid 0, 2^2).
$$&lt;p&gt;We will be using &lt;em&gt;TensorFlow&lt;/em&gt;, &lt;em&gt;TensorFlow Probability&lt;/em&gt;, and &lt;em&gt;Keras&lt;/em&gt; in the
code snippets throughout this post.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow_probability&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tfp&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We first instantiate the distributions:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;distributions&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;distributions&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;2.&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Their densities are shown below:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Univariate Gaussian densities"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/gaussian_1d_densities.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;For any pair of distributions, we can implement their density ratio function $r$
as follows:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;log_density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;ratio&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Let&amp;rsquo;s create the density ratio function for the Gaussian distributions we just
instantiated:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This density ratio function is plotted as the orange dotted line below,
alongside the individual densities shown in the previous plot:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Ratio of Gaussian densities"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/gaussian_1d_density_ratios.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
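&lt;p&gt;For readers who want to reproduce the curve without TensorFlow, here is a small NumPy-only sketch of the same density ratio. It is a cross-check rather than the code used in this post; the helper names &lt;code&gt;gaussian_log_prob&lt;/code&gt; and &lt;code&gt;log_ratio_np&lt;/code&gt; and the evaluation points are assumptions chosen for illustration.&lt;/p&gt;

```python
import numpy as np


def gaussian_log_prob(x, loc, scale):
    # Log-density of a univariate Gaussian N(x | loc, scale^2).
    z = (x - loc) / scale
    return -0.5 * z ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def log_ratio_np(x):
    # log r*(x) = log p(x) - log q(x) with p = N(1, 1^2) and q = N(0, 2^2).
    return gaussian_log_prob(x, 1.0, 1.0) - gaussian_log_prob(x, 0.0, 2.0)


# A few evaluation points, including x = 4/3 where the log-ratio is stationary.
x = np.array([-2.0, 0.0, 1.0, 4.0 / 3.0, 4.0])
ratio = np.exp(log_ratio_np(x))
print(ratio)
```

&lt;p&gt;Working in log space and exponentiating at the end mirrors the structure of &lt;code&gt;log_density_ratio&lt;/code&gt; and &lt;code&gt;density_ratio&lt;/code&gt; above. For this pair of Gaussians the ratio equals $2 e^{-1/2}$ at $x = 0$, $2 e^{1/8}$ at $x = 1$, and attains its maximum $2 e^{1/6}$ at $x = 4/3$.&lt;/p&gt;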
&lt;h2 id="analytical-form"&gt;Analytical Form&lt;/h2&gt;
&lt;p&gt;For our running example, we picked $p(x)$ and $q(x)$ to be Gaussians so that
it is possible to integrate out $x$ and compute the KL divergence &lt;em&gt;analytically&lt;/em&gt;.
When we introduce the approximate methods later, this will provide us a &amp;ldquo;gold
standard&amp;rdquo; to benchmark against.&lt;/p&gt;
&lt;p&gt;In general, for Gaussian distributions&lt;/p&gt;
$$
p(x) = \mathcal{N}(x \mid \mu_p, \sigma_p^2),
\qquad
\text{and}
\qquad
q(x) = \mathcal{N}(x \mid \mu_q, \sigma_q^2),
$$&lt;p&gt;
it is easy to verify that
&lt;/p&gt;
$$
\mathcal{D}_{\mathrm{KL}}[ p(x) || q(x) ]
= \log \sigma_q - \log \sigma_p - \frac{1}{2}
\left [
1 - \left ( \frac{\sigma_p^2 + (\mu_p - \mu_q)^2}{\sigma_q^2} \right )
\right ].
$$&lt;p&gt;This is implemented below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_kl_divergence_gaussians&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;r&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="mf"&gt;.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We can use this to compute the KL divergence between $p(x)$ and $q(x)$
&lt;em&gt;exactly&lt;/em&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;_kl_divergence_gaussians&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.44314718&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Equivalently, we could also use &lt;code&gt;kl_divergence&lt;/code&gt; from &lt;em&gt;TensorFlow
Probability&amp;ndash;Distributions&lt;/em&gt; (&lt;code&gt;tfp.distributions&lt;/code&gt;), which implements the
analytical closed-form expression of the KL divergence between distributions
when such exists.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;distributions&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;kl_divergence&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.44314718&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="monte-carlo-estimation--prescribed-distributions"&gt;Monte Carlo Estimation &amp;mdash; prescribed distributions&lt;/h2&gt;
&lt;p&gt;For distributions whose KL divergence is not analytically tractable, we
may appeal to Monte Carlo (MC) estimation:&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)]
&amp; = \mathbb{E}_{p(x)} [ \log r^{*}(x) ] \newline
&amp; \approx \frac{1}{M} \sum_{i=1}^{M} \log r^{*}(x_p^{(i)}),
\quad x_p^{(i)} \sim p(x).
\end{align*}
$$&lt;p&gt;Clearly, this requires the density ratio $r^{*}(x)$ and, in turn, the densities
$p(x)$ and $q(x)$ to be analytically tractable. Distributions for which the
density function can be readily evaluated are sometimes referred to as
&lt;strong&gt;prescribed distributions&lt;/strong&gt;. As before, we &lt;em&gt;prescribed&lt;/em&gt; Gaussian distributions
in our running example, so the Monte Carlo estimate can be compared against the exact value.
We approximate their KL divergence using $M = 5000$ Monte Carlo samples as
follows:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;p_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;true_log_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_density_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reduce_mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;true_log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.44670376&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Or equivalently, using the &lt;code&gt;expectation&lt;/code&gt; function from &lt;em&gt;TensorFlow
Probability&amp;ndash;Monte Carlo&lt;/em&gt; (&lt;code&gt;tfp.monte_carlo&lt;/code&gt;):&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;monte_carlo&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;expectation&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;true_log_ratio&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;samples&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.4581419&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;More generally, we can approximate any $f$-divergence with MC estimation:&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_f[p(x) || q(x)]
&amp; = \mathbb{E}_{q(x)} [ f(r^{*}(x)) ] \newline
&amp; \approx \frac{1}{M} \sum_{i=1}^{M} f(r^{*}(x_q^{(i)})),
\quad x_q^{(i)} \sim q(x).
\end{align*}
$$&lt;p&gt;This can be done using the &lt;code&gt;monte_carlo_csiszar_f_divergence&lt;/code&gt; function from
&lt;em&gt;TensorFlow Probability&amp;ndash;Variational Inference&lt;/em&gt; (&lt;code&gt;tfp.vi&lt;/code&gt;).
One simply needs to specify the appropriate convex function $f$.
The convex function that instantiates the (forward) KL divergence is provided
in &lt;code&gt;tfp.vi&lt;/code&gt; as &lt;code&gt;kl_forward&lt;/code&gt;, alongside many other common $f$-divergences.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;monte_carlo_csiszar_f_divergence&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;kl_forward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="n"&gt;p_log_prob&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="n"&gt;num_draws&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.4430853&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="density-ratio-estimation--implicit-distributions"&gt;Density Ratio Estimation &amp;mdash; implicit distributions&lt;/h2&gt;
&lt;p&gt;When either density $p(x)$ or $q(x)$ is unavailable, things become trickier,
which brings us to the topic of this post. Suppose we only have samples from
$p(x)$ and $q(x)$&amp;mdash;these could be natural images, outputs from a neural
network with stochastic inputs, or in the case of our running example, i.i.d.
samples drawn from Gaussians, etc.
Distributions for which we are only able to observe their samples are known as
&lt;strong&gt;implicit distributions&lt;/strong&gt;, since their samples &lt;em&gt;imply&lt;/em&gt; some underlying true
density which we may not have direct access to.&lt;/p&gt;
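&lt;p&gt;To make the distinction concrete, the following sketch treats the running example as a pair of prescribed distributions and computes the KL divergence three ways: in closed form, as a Monte Carlo average of $\log r^{*}(x)$ over samples from $p(x)$, and as an average of $f(r^{*}(x))$ with $f(u) = u \log u$ over samples from $q(x)$. It uses only the Python standard library rather than TensorFlow Probability, and it assumes $p = \mathcal{N}(1, 1^2)$ and $q = \mathcal{N}(0, 2^2)$. These parameters are not stated in this section; they are chosen because they reproduce the closed-form value 0.44314718 shown earlier. All helper names are hypothetical.&lt;/p&gt;

```python
import math
import random

# Assumed parameters for the running example: p = N(1, 1^2), q = N(0, 2^2).
MU_P, SIGMA_P = 1.0, 1.0
MU_Q, SIGMA_Q = 0.0, 2.0


def log_normal_pdf(x, mu, sigma):
    """Log density of a univariate Gaussian N(mu, sigma^2) at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def log_ratio(x):
    """Exact log density ratio log p(x) - log q(x)."""
    return log_normal_pdf(x, MU_P, SIGMA_P) - log_normal_pdf(x, MU_Q, SIGMA_Q)


def kl_closed_form():
    """Closed-form KL[p || q] between two univariate Gaussians."""
    return (math.log(SIGMA_Q / SIGMA_P)
            + (SIGMA_P ** 2 + (MU_P - MU_Q) ** 2) / (2.0 * SIGMA_Q ** 2)
            - 0.5)


def kl_monte_carlo(num_samples, rng):
    """Average of log r(x) over samples x drawn from p."""
    total = 0.0
    for _ in range(num_samples):
        total += log_ratio(rng.gauss(MU_P, SIGMA_P))
    return total / num_samples


def kl_f_divergence(num_samples, rng):
    """Average of f(r(x)) over samples x drawn from q, with f(u) = u log u."""
    total = 0.0
    for _ in range(num_samples):
        lr = log_ratio(rng.gauss(MU_Q, SIGMA_Q))
        total += math.exp(lr) * lr
    return total / num_samples


rng = random.Random(0)
kl_exact = kl_closed_form()
kl_mc = kl_monte_carlo(5000, rng)
kl_f = kl_f_divergence(5000, rng)
print(kl_exact, kl_mc, kl_f)
```

&lt;p&gt;All three quantities rely on evaluating the log densities, which is exactly what is unavailable for implicit distributions.&lt;/p&gt;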
&lt;p&gt;Density ratio estimation is concerned with estimating the ratio of densities
$r^{*}(x) = p(x) / q(x)$ given access only to samples from $p(x)$ and $q(x)$.
Moreover, density ratio estimation usually encompasses methods that achieve this
without resorting to direct &lt;em&gt;density estimation&lt;/em&gt; of the individual densities
$p(x)$ or $q(x)$, since any error in the estimation of the denominator $q(x)$
can be greatly magnified by the division.&lt;/p&gt;
&lt;p&gt;Of the many density ratio estimation methods that now
flourish&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;, the classical approach of &lt;em&gt;probabilistic
classification&lt;/em&gt; remains dominant, due in no small part to its simplicity.&lt;/p&gt;
&lt;h3 id="reducing-density-ratio-estimation-to-probabilistic-classification"&gt;Reducing Density Ratio Estimation to Probabilistic Classification&lt;/h3&gt;
&lt;p&gt;We now demonstrate that density ratio estimation can be reduced to probabilistic
classification. We shall do this by highlighting the one-to-one correspondence
between the density ratio of $p(x)$ and $q(x)$ and the optimal probabilistic
classifier that discriminates between their samples.
Specifically, suppose we have a collection of samples from both $p(x)$ and $q(x)$,
where each sample is assigned a class label indicating which distribution it was
drawn from. Then, from an estimator of the class-membership probabilities, it is
straightforward to recover an estimator of the density ratio.&lt;/p&gt;
&lt;p&gt;Suppose we have $N_p$ and $N_q$ samples drawn from $p(x)$ and $q(x)$,
respectively,&lt;/p&gt;
$$
x_p^{(1)}, \dotsc, x_p^{(N_p)} \sim p(x),
\qquad \text{and} \qquad
x_q^{(1)}, \dotsc, x_q^{(N_q)} \sim q(x).
$$&lt;p&gt;Then, we form the dataset $\{ (x_n, y_n) \}_{n=1}^N$, where $N = N_p + N_q$
and&lt;/p&gt;
$$
\begin{align*}
(x_1, \dotsc, x_N) &amp; = (x_p^{(1)}, \dotsc, x_p^{(N_p)},
x_q^{(1)}, \dotsc, x_q^{(N_q)}), \newline
(y_1, \dotsc, y_N) &amp; = (\underbrace{1, \dotsc, 1}_{N_p},
\underbrace{0, \dotsc, 0}_{N_q}).
\end{align*}
$$&lt;p&gt;In other words, we label samples drawn from $p(x)$ as 1 and those drawn from
$q(x)$ as 0. In code, this looks like:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;p_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sample_shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;q_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sample_shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q_samples&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q_samples&lt;/span&gt;&lt;span class="p"&gt;)],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This dataset is visualized below. The blue squares in the top row are samples
$x_p^{(i)} \sim p(x)$ with label 1; red squares in the bottom row are samples
$x_q^{(j)} \sim q(x)$ with label 0.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Classification dataset"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/dataset.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Now, by construction, we have&lt;/p&gt;
$$
p(x) = \mathcal{P}(x \mid y = 1),
\qquad
\text{and}
\qquad
q(x) = \mathcal{P}(x \mid y = 0).
$$&lt;p&gt;Using Bayes&amp;rsquo; rule, we can write&lt;/p&gt;
$$
\mathcal{P}(x \mid y) =
\frac{\mathcal{P}(y \mid x) \mathcal{P}(x)}
{\mathcal{P}(y)}.
$$&lt;p&gt;Hence, we can express the density ratio $r^{*}(x)$ as&lt;/p&gt;
$$
\begin{align*}
r^{*}(x) &amp; = \frac{p(x)}{q(x)}
= \frac{\mathcal{P}(x \mid y = 1)}
{\mathcal{P}(x \mid y = 0)} \newline
&amp; = \left ( \frac{\mathcal{P}(y = 1 \mid x) \mathcal{P}(x)}
{\mathcal{P}(y = 1)} \right )
\left ( \frac{\mathcal{P}(y = 0 \mid x) \mathcal{P}(x)}
{\mathcal{P}(y = 0)} \right ) ^ {-1} \newline
&amp; = \frac{\mathcal{P}(y = 0)}{\mathcal{P}(y = 1)}
\frac{\mathcal{P}(y = 1 \mid x)}
{\mathcal{P}(y = 0 \mid x)}.
\end{align*}
$$&lt;p&gt;Let us approximate the ratio of class prior probabilities by the ratio of sample sizes,&lt;/p&gt;
$$
\frac{\mathcal{P}(y = 0)}
{\mathcal{P}(y = 1)}
\approx
\frac{N_q}{N_p + N_q}
\left ( \frac{N_p}{N_p + N_q} \right )^{-1}
= \frac{N_q}{N_p}.
$$&lt;p&gt;To avoid notational clutter, let us assume from now on that $N_q = N_p$.
We can then write $r^{*}(x)$ in terms of class-posterior probabilities,&lt;/p&gt;
$$
\begin{align*}
r^{*}(x) = \frac{\mathcal{P}(y = 1 \mid x)}
{\mathcal{P}(y = 0 \mid x)}.
\end{align*}
$$&lt;h4 id="recovering-the-density-ratio-from-the-class-probability"&gt;Recovering the Density Ratio from the Class Probability&lt;/h4&gt;
&lt;p&gt;This yields a one-to-one correspondence between the density ratio $r^{*}(x)$
and the class-posterior probability $\mathcal{P}(y = 1 \mid x)$.
Namely,&lt;/p&gt;
$$
\begin{align*}
r^{*}(x) = \frac{\mathcal{P}(y = 1 \mid x)}
{\mathcal{P}(y = 0 \mid x)}
&amp; = \frac{\mathcal{P}(y = 1 \mid x)}
{1 - \mathcal{P}(y = 1 \mid x)} \newline
&amp; = \exp
\left [
\log \frac{\mathcal{P}(y = 1 \mid x)}
{1 - \mathcal{P}(y = 1 \mid x)} \right ] \newline
&amp; = \exp[ \sigma^{-1}(\mathcal{P}(y = 1 \mid x)) ],
\end{align*}
$$&lt;p&gt;where $\sigma^{-1}$ is the &lt;em&gt;logit&lt;/em&gt; function, or inverse sigmoid function, given
by $\sigma^{-1}(\rho) = \log \left ( \frac{\rho}{1-\rho} \right )$.&lt;/p&gt;
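&lt;p&gt;As a quick numerical check of this correspondence, the sketch below evaluates the exact class-posterior probability for a pair of Gaussians and maps it back to the density ratio through the logit. It keeps the prior correction factor $N_q / N_p$ from the derivation above, so it also covers unequal sample sizes. The Gaussian parameters ($p = \mathcal{N}(1, 1^2)$ and $q = \mathcal{N}(0, 2^2)$), the evaluation point, and all function names are assumptions made for illustration.&lt;/p&gt;

```python
import math


def normal_pdf(x, mu, sigma):
    """Density of a univariate Gaussian N(mu, sigma^2) at x."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def logit(rho):
    """Inverse sigmoid: log(rho / (1 - rho))."""
    return math.log(rho / (1.0 - rho))


def class_posterior(x, n_p, n_q):
    """Exact P(y = 1 | x) when n_p samples come from p and n_q from q."""
    weighted_p = n_p * normal_pdf(x, 1.0, 1.0)  # assumed p = N(1, 1^2)
    weighted_q = n_q * normal_pdf(x, 0.0, 2.0)  # assumed q = N(0, 2^2)
    return weighted_p / (weighted_p + weighted_q)


def ratio_from_class_prob(prob, n_p, n_q):
    """Density ratio r(x) = (N_q / N_p) * exp(logit(P(y = 1 | x)))."""
    return (n_q / n_p) * math.exp(logit(prob))


x = 0.5
exact_ratio = normal_pdf(x, 1.0, 1.0) / normal_pdf(x, 0.0, 2.0)
# Equal sample sizes: the prior correction factor is 1.
balanced = ratio_from_class_prob(class_posterior(x, 1000, 1000), 1000, 1000)
# Unequal sample sizes: the factor N_q / N_p undoes the shift in the posterior.
imbalanced = ratio_from_class_prob(class_posterior(x, 1000, 4000), 1000, 4000)
print(exact_ratio, balanced, imbalanced)
```

&lt;p&gt;Both recovered values should agree with the exact ratio up to floating-point error, since the class-posterior probability used here is exact rather than estimated.&lt;/p&gt;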
&lt;h4 id="recovering-the-class-probability-from-the-density-ratio"&gt;Recovering the Class Probability from the Density Ratio&lt;/h4&gt;
&lt;p&gt;Conversely, by solving this relationship for the class probability, we can also recover
the exact class-posterior probability as a function of the density ratio,&lt;/p&gt;
$$
\mathcal{P}(y=1 \mid x) = \sigma(\log r^{*}(x)) = \frac{p(x)}{p(x) + q(x)}.
$$
&lt;p&gt;This is implemented below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;optimal_classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;truediv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;classifier&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;In the figure below, the class-posterior probability $\mathcal{P}(y=1 \mid x)$
is plotted against the dataset visualized earlier.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Optimal classifier&amp;mdash;class-posterior probabilities"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/optimal_classifier.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h3 id="probabilistic-classification-with-logistic-regression"&gt;Probabilistic Classification with Logistic Regression&lt;/h3&gt;
&lt;p&gt;The class-posterior probability $\mathcal{P}(y = 1 \mid x)$ can be approximated
using a parameterized function $D_{\theta}(x)$ with parameters $\theta$. This
function takes as input a sample drawn from either $p(x)$ or $q(x)$ and outputs a &lt;em&gt;score&lt;/em&gt;,
or probability, in the range $[0, 1]$ that the sample was drawn from $p(x)$.
Hence, we refer to $D_{\theta}(x)$ as the probabilistic classifier.&lt;/p&gt;
&lt;p&gt;From before, it is clear how an estimator of the density ratio
$r_{\theta}(x)$ might be constructed as a function of the probabilistic classifier
$D_{\theta}(x)$. Namely,&lt;/p&gt;
$$
\begin{align*}
r_{\theta}(x) &amp; = \exp[ \sigma^{-1}(D_{\theta}(x)) ] \newline
&amp; \approx \exp[ \sigma^{-1}(\mathcal{P}(y = 1 \mid x)) ] = r^{*}(x),
\end{align*}
$$&lt;p&gt;
and &lt;em&gt;vice versa&lt;/em&gt;,
&lt;/p&gt;
$$
\begin{align*}
D_{\theta}(x) &amp; = \sigma(\log r_{\theta}(x)) \newline
&amp; \approx \sigma(\log r^{*}(x)) = \mathcal{P}(y = 1 \mid x).
\end{align*}
$$&lt;p&gt;Instead of $D_{\theta}(x)$, we usually specify the parameterized function
$\log r_{\theta}(x)$. This is also referred to as the &lt;em&gt;log-odds&lt;/em&gt;, or &lt;em&gt;logits&lt;/em&gt;,
since it is equivalent to the unnormalized output of the classifier before being
fed through the logistic sigmoid function.&lt;/p&gt;
&lt;p&gt;We define a small fully-connected neural network with two hidden layers and ReLU
activations:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;log_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This simple architecture is visualized in the diagram below:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Log Density Ratio Architecture"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/log_ratio_architecture.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;We learn the optimal class probability estimator by optimizing it with respect
to a &lt;em&gt;proper scoring rule&lt;/em&gt;&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt; that yields well-calibrated probabilistic predictions, such as the &lt;em&gt;binary cross-entropy loss&lt;/em&gt;,&lt;/p&gt;
$$
\begin{align*}
\mathcal{L}(\theta) &amp; :=
-\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ]
-\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ] \newline
&amp; =
-\mathbb{E}_{p(x)} [ \log \sigma ( \log r_{\theta} (x) ) ]
-\mathbb{E}_{q(x)} [ \log(1 - \sigma ( \log r_{\theta} (x) )) ].
\end{align*}
$$&lt;p&gt;An implementation optimized for numerical stability is given below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_binary_crossentropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;loss_p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sigmoid_cross_entropy_with_logits&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;loss_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sigmoid_cross_entropy_with_logits&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reduce_mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss_p&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;loss_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Now we can build a model with two inputs, where the inputs are fixed to tensors holding samples from
$p(x)$ and $q(x)$, respectively.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;x_p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;x_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;q_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;log_ratio_p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_p&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;log_ratio_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The model can now be compiled and finalized. Since we&amp;rsquo;re using a custom loss
that takes the two sets of log-ratios as input, we specify &lt;code&gt;loss=None&lt;/code&gt; and
define it instead through the &lt;code&gt;add_loss&lt;/code&gt; method.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;x_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_q&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_binary_crossentropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_ratio_p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_ratio_q&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;compile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;rmsprop&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;As a sanity-check, the loss evaluated on a random batch can be obtained like so:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;evaluate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;1.3765026330947876&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We can now fit our estimator, recording the loss at the end of each epoch:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;hist&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;steps_per_epoch&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;500&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The following animation shows how the predictions of the probabilistic
classifier, density ratio, and log density ratio evolve after every epoch:&lt;/p&gt;
&lt;p&gt;&lt;video controls autoplay src="https://giant.gfycat.com/FrighteningThunderousFlicker.webm"&gt;&lt;/video&gt;&lt;/p&gt;
&lt;p&gt;The predictions are overlaid on their exact, analytical counterparts, which are only
available because we prescribed the distributions to be Gaussian.
For implicit distributions, these won&amp;rsquo;t be accessible at all.&lt;/p&gt;
&lt;p&gt;Below is the final plot of how the binary cross-entropy loss converges:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Binary Cross-entropy Loss"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/binary_crossentropy.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Below is a plot of the probabilistic classifier $D_{\theta}(x)$ (&lt;em&gt;dotted green&lt;/em&gt;),
plotted against the optimal classifier, which is the class-posterior probability
$\mathcal{P}(y=1 \mid x) = \frac{p(x)}{p(x) + q(x)}$ (&lt;em&gt;solid blue&lt;/em&gt;):&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Class Probability Estimator"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/class_probability_estimation.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Below is a plot of the density ratio estimator $r_{\theta}(x)$
(&lt;em&gt;dotted green&lt;/em&gt;), plotted against the exact density ratio function
$r^{*}(x) = \frac{p(x)}{q(x)}$ (&lt;em&gt;solid blue&lt;/em&gt;):&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Density Ratio Estimator"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/density_ratio_estimation.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;And finally, the previous plot in logarithmic scale:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Log Density Ratio Estimator"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/log_density_ratio_estimation.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;While it may appear that we are simply performing regression on the latent
function $r^{*}(x)$ (which is not wrong&amp;mdash;we are), it is important to emphasize that
we do this without ever having observed values of $r^{*}(x)$.
Instead, we only ever observed samples from $p(x)$ and $q(x)$.
This has profound implications and potential for a great number of applications
that we shall explore later on.&lt;/p&gt;
&lt;h3 id="back-to-monte-carlo-estimation"&gt;Back to Monte Carlo estimation&lt;/h3&gt;
&lt;p&gt;Having obtained an estimate of the log density ratio, it is now feasible to
perform Monte Carlo estimation:&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)]
&amp; = \mathbb{E}_{p(x)} [ \log r^{*}(x) ] \newline
&amp; \approx \frac{1}{M} \sum_{i=1}^{M} \log r^{*}(x_p^{(i)}),
\quad x_p^{(i)} \sim p(x) \newline
&amp; \approx \frac{1}{M} \sum_{i=1}^{M} \log r_{\theta}(x_p^{(i)}),
\quad x_p^{(i)} \sim p(x).
\end{align*}
$$&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;monte_carlo&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;expectation&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;samples&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_samples&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="mf"&gt;0.4570999&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;In other words, we draw MC samples from $p(x)$ as before. But instead of taking
the mean of the function $\log r^{*}(x)$ evaluated on these samples (which is
unavailable for implicit distributions), we do so on a proxy function
$\log r_{\theta}(x)$ that is estimated through probabilistic classification as
described above.&lt;/p&gt;
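&lt;p&gt;The estimator above can be sketched in a few lines of plain Python, independent of TensorFlow. This is a minimal illustration rather than the code used in this post: the Gaussian parameters, the helper names &lt;code&gt;gaussian_log_pdf&lt;/code&gt; and &lt;code&gt;kl_monte_carlo&lt;/code&gt;, and the use of the exact log ratio as a stand-in for a learned $\log r_{\theta}(x)$ are all assumptions made for the example.&lt;/p&gt;

```python
import math
import random


def gaussian_log_pdf(x, mean, std):
    """Log density of a univariate Gaussian N(mean, std^2) at x."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def kl_monte_carlo(log_ratio, samples):
    """Average a log density ratio over samples drawn from p(x)."""
    return sum(log_ratio(x) for x in samples) / len(samples)


# Hypothetical parameters, chosen only for illustration.
P_MEAN, P_STD = 1.0, 1.0
Q_MEAN, Q_STD = 0.0, 2.0


def exact_log_ratio(x):
    """log r*(x) = log p(x) - log q(x); stands in for a learned log r_theta(x)."""
    return gaussian_log_pdf(x, P_MEAN, P_STD) - gaussian_log_pdf(x, Q_MEAN, Q_STD)


rng = random.Random(0)
p_samples = [rng.gauss(P_MEAN, P_STD) for _ in range(20000)]
kl_estimate = kl_monte_carlo(exact_log_ratio, p_samples)

# Closed-form KL between two univariate Gaussians, for comparison.
kl_exact = (math.log(Q_STD / P_STD)
            + (P_STD ** 2 + (P_MEAN - Q_MEAN) ** 2) / (2.0 * Q_STD ** 2)
            - 0.5)

print(kl_estimate, kl_exact)
```

&lt;p&gt;With implicit distributions, &lt;code&gt;exact_log_ratio&lt;/code&gt; would be replaced by the logit of the trained classifier, and the closed-form value would no longer be available for comparison.&lt;/p&gt;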
&lt;h2 id="learning-in-implicit-generative-models"&gt;Learning in Implicit Generative Models&lt;/h2&gt;
&lt;p&gt;Now let&amp;rsquo;s take a look at where these ideas are being used in practice.
Consider a collection of natural images, such as the MNIST handwritten
digits shown below, which are assumed to be samples drawn from some implicit
distribution $q(\mathbf{x})$:&lt;/p&gt;
&lt;figure&gt;&lt;img src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/MnistExamples.png"&gt;&lt;figcaption&gt;
&lt;h4&gt;MNIST hand-written digits&lt;/h4&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;Directly estimating the density of $q(\mathbf{x})$ may not always be feasible&amp;mdash;in
some cases, it may not even exist.
Instead, consider defining a parametric function $G_{\phi}: \mathbf{z} \mapsto
\mathbf{x}$ with parameters $\phi$, that takes as input $\mathbf{z}$ drawn from
some fixed distribution $p(\mathbf{z})$.
The outputs $\mathbf{x}$ of this generative process are assumed to be samples
following some implicit distribution $p_{\phi}(\mathbf{x})$. In other words,
we can write&lt;/p&gt;
$$
\mathbf{x} \sim p_{\phi}(\mathbf{x}) \quad
\Leftrightarrow \quad
\mathbf{x} = G_{\phi}(\mathbf{z}),
\quad \mathbf{z} \sim p(\mathbf{z}).
$$&lt;p&gt;By optimizing parameters $\phi$, we can make $p_{\phi}(\mathbf{x})$ close to
the real data distribution $q(\mathbf{x})$. This is a compelling alternative to
density estimation since there are many situations where being able to generate
samples is more important than being able to calculate the numerical value of
the density. Some examples of these include &lt;em&gt;image super-resolution&lt;/em&gt; and
&lt;em&gt;semantic segmentation&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;One approach might be to introduce a classifier $D_{\theta}$ that discriminates
between real and synthetic samples.
Then we optimize $G_{\phi}$ to synthesize samples that are indistinguishable,
to classifier $D_{\theta}$, from the real samples. This can be achieved by
simultaneously optimizing the binary cross-entropy loss, resulting in the
saddle-point objective,&lt;/p&gt;
$$
\begin{align*}
&amp; \min_{\phi} \max_{\theta}
\mathbb{E}_{q(\mathbf{x})} [ \log D_{\theta} (\mathbf{x}) ] +
\mathbb{E}_{p_{\phi}(\mathbf{x})} [ \log(1-D_{\theta} (\mathbf{x})) ] \newline =
&amp; \min_{\phi} \max_{\theta}
\mathbb{E}_{q(\mathbf{x})} [ \log D_{\theta} (\mathbf{x}) ] +
\mathbb{E}_{p(\mathbf{z})} [ \log(1-D_{\theta} (G_{\phi}(\mathbf{z}))) ].
\end{align*}
$$&lt;p&gt;This is, of course, none other than the groundbreaking &lt;em&gt;generative adversarial
network (GAN)&lt;/em&gt;&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;.
You can read more about the density ratio estimation perspective of GANs in
the paper by Uehara et al. 2016&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;. For an even more general and complete treatment of learning in implicit models, I recommend the paper
from Mohamed and Lakshminarayanan, 2016&lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt;, which partially inspired this post.&lt;/p&gt;
&lt;p&gt;For the remainder of this section, I want to highlight a variant of this
approach that specifically aims to minimize the KL divergence w.r.t. parameters
$\phi$,&lt;/p&gt;
$$
\min_{\phi} \mathcal{D}_{\mathrm{KL}}[p_{\phi}(\mathbf{x}) || q(\mathbf{x})].
$$&lt;p&gt;To overcome the fact that the densities of both $p_{\phi}(\mathbf{x})$ and
$q(\mathbf{x})$ are unknown, we can readily adopt the density ratio estimation
approach outlined in this post.
Namely, by maximizing the following objective,&lt;/p&gt;
$$
\begin{align*}
&amp; \max_{\theta}
\mathbb{E}_{q(\mathbf{x})} [ \log D_{\theta} (\mathbf{x}) ] +
\mathbb{E}_{p(\mathbf{z})} [ \log(1-D_{\theta} (G_{\phi}(\mathbf{z}))) ] \newline
= &amp; \max_{\theta}
\mathbb{E}_{q(\mathbf{x})} [ \log \sigma ( \log r_{\theta} (\mathbf{x}) ) ] +
\mathbb{E}_{p(\mathbf{z})} [ \log(1 - \sigma ( \log r_{\theta} (G_{\phi}(\mathbf{z})) )) ],
\end{align*}
$$&lt;p&gt;which attains its maximum at&lt;/p&gt;
$$
r_{\theta}(\mathbf{x}) = \frac{q(\mathbf{x})}{p_{\phi}(\mathbf{x})}.
$$&lt;p&gt;Concurrently, we also minimize the current best estimate of the KL divergence,&lt;/p&gt;
$$
\begin{align*}
\min_{\phi} \mathcal{D}_{\mathrm{KL}}[p_{\phi}(\mathbf{x}) || q(\mathbf{x})]
&amp; =
\min_{\phi} \mathbb{E}_{p_{\phi}(\mathbf{x})} \left [ \log \frac{p_{\phi}(\mathbf{x})}{q(\mathbf{x})} \right ] \newline
&amp; \approx
\min_{\phi} \mathbb{E}_{p_{\phi}(\mathbf{x})} [ - \log r_{\theta}(\mathbf{x}) ] \newline
&amp; =
\min_{\phi} \mathbb{E}_{p(\mathbf{z})} [ - \log r_{\theta}(G_{\phi}(\mathbf{z})) ].
\end{align*}
$$&lt;p&gt;In addition to being more stable than the vanilla GAN approach (it alleviates
saturating gradients), this is especially important in contexts where there is
a specific need to minimize the KL divergence, such as in &lt;em&gt;variational inference
(VI)&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;This was first used in &lt;em&gt;AffGAN&lt;/em&gt; by Sønderby et al. 2016&lt;sup id="fnref:7"&gt;&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref"&gt;7&lt;/a&gt;&lt;/sup&gt;,
and has since been incorporated in many papers that deal with implicit
distributions in variational inference, such as
(Mescheder et al. 2017&lt;sup id="fnref:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;,
Huszar 2017&lt;sup id="fnref:9"&gt;&lt;a href="#fn:9" class="footnote-ref" role="doc-noteref"&gt;9&lt;/a&gt;&lt;/sup&gt;,
Tran et al. 2017&lt;sup id="fnref:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt;,
Pu et al. 2017&lt;sup id="fnref:11"&gt;&lt;a href="#fn:11" class="footnote-ref" role="doc-noteref"&gt;11&lt;/a&gt;&lt;/sup&gt;,
Chen et al. 2018&lt;sup id="fnref:12"&gt;&lt;a href="#fn:12" class="footnote-ref" role="doc-noteref"&gt;12&lt;/a&gt;&lt;/sup&gt;,
Tiao et al. 2018&lt;sup id="fnref:13"&gt;&lt;a href="#fn:13" class="footnote-ref" role="doc-noteref"&gt;13&lt;/a&gt;&lt;/sup&gt;), and many others.&lt;/p&gt;
&lt;h2 id="bound-on-the-jensen-shannon-divergence"&gt;Bound on the Jensen-Shannon Divergence&lt;/h2&gt;
&lt;p&gt;Before we wrap things up, let us take another look at the plot of the
binary cross-entropy loss recorded at the end of each epoch.
We see that it converges quickly to some value.
It is natural to wonder: what is the significance, if any, of this value?&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Binary cross-entropy loss converges to Jensen Shannon divergence (up to constants)"
src="https://tiao.io/posts/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/binary_crossentropy_vs_jensen_shannon.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;It is in fact the (negative) Jensen-Shannon (JS) divergence, up to constants,&lt;/p&gt;
$$
-2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] + \log 4.
$$&lt;p&gt;Recall the Jensen-Shannon divergence is defined as&lt;/p&gt;
$$
\mathcal{D}_{\mathrm{JS}}[p(x) || q(x)]
= \frac{1}{2} \mathcal{D}_{\mathrm{KL}}[p(x) || m(x)] +
\frac{1}{2} \mathcal{D}_{\mathrm{KL}}[q(x) || m(x)],
$$&lt;p&gt;where $m$ is the mixture density&lt;/p&gt;
$$
m(x) = \frac{p(x) + q(x)}{2}.
$$&lt;p&gt;With our running example, this cannot be evaluated exactly since the KL
divergence between a Gaussian and a mixture of Gaussians is analytically
intractable.
However, like the KL, we can still estimate their JS divergence with Monte
Carlo estimation&lt;sup id="fnref:14"&gt;&lt;a href="#fn:14" class="footnote-ref" role="doc-noteref"&gt;14&lt;/a&gt;&lt;/sup&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;js&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;monte_carlo_csiszar_f_divergence&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tfp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vi&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;jensen_shannon&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="n"&gt;p_log_prob&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_draws&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This value is shown as the horizontal black line in the plot above. Along the
right margin, we also plot a histogram of the binary cross-entropy loss
values over epochs. We can see that this value indeed coincides with the mode of
this histogram.&lt;/p&gt;
&lt;p&gt;It is straightforward to show that the loss satisfies the lower bound&lt;/p&gt;
$$
\inf_{\theta} \mathcal{L}(\theta) \geq - 2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] + \log 4.
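$$&lt;p&gt;Firstly, supposing the classifier family is flexible enough to attain the optimal classifier $\mathcal{P}(y=1 \mid x)$, we have&lt;/p&gt;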
$$
\begin{align*}
\sup_{\theta} &amp;
\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ] +
\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ] \newline
&amp; =
\mathbb{E}_{p(x)} [ \log \mathcal{P}(y=1 \mid x) ] +
\mathbb{E}_{q(x)} [ \log \mathcal{P}(y=0 \mid x) ] \newline
&amp; =
\mathbb{E}_{p(x)} \left [ \log \frac{p(x)}{p(x) + q(x)} \right ] +
\mathbb{E}_{q(x)} \left [ \log \frac{q(x)}{p(x) + q(x)} \right ] \newline
&amp; =
\mathbb{E}_{p(x)} \left [ \log \frac{1}{2} \frac{p(x)}{m(x)} \right ] +
\mathbb{E}_{q(x)} \left [ \log \frac{1}{2} \frac{q(x)}{m(x)} \right ] \newline
&amp; =
\mathbb{E}_{p(x)} \left [ \log \frac{p(x)}{m(x)} \right ] +
\mathbb{E}_{q(x)} \left [ \log \frac{q(x)}{m(x)} \right ] - 2 \log 2 \newline
&amp; = 2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4.
\end{align*}
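$$&lt;p&gt;In general, the family $D_{\theta}$ may not contain the optimal classifier, in which case the supremum can only be smaller. Therefore,&lt;/p&gt;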
$$
2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4
\geq
\sup_{\theta}
\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ] +
\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ].
$$&lt;p&gt;Negating both sides, we get&lt;/p&gt;
$$
\begin{align*}
-2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] + \log 4
\leq &amp;
-\sup_{\theta}
\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ] +
\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ] \newline
= &amp; \inf_{\theta}
-\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ]
-\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ] \newline
= &amp; \inf_{\theta} \mathcal{L}(\theta),
\end{align*}
$$&lt;p&gt;as required.&lt;/p&gt;
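&lt;p&gt;The bound can also be checked numerically. The sketch below is an illustration under stated assumptions, not code from this post: it uses two hypothetical Gaussians for $p(x)$ and $q(x)$, replaces expectations with midpoint-rule quadrature, and compares the loss of the optimal classifier, and of an arbitrary linear-logit classifier, against $-2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] + \log 4$. All helper names are made up for the example.&lt;/p&gt;

```python
import math


def gaussian_log_pdf(x, mean, std):
    """Log density of a univariate Gaussian N(mean, std^2) at x."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def integrate(fn, lo=-15.0, hi=15.0, n=20000):
    """Midpoint-rule quadrature of fn over [lo, hi]."""
    h = (hi - lo) / n
    return h * sum(fn(lo + (i + 0.5) * h) for i in range(n))


# Hypothetical Gaussians standing in for p(x) and q(x).
def log_p(x):
    return gaussian_log_pdf(x, 1.0, 1.0)


def log_q(x):
    return gaussian_log_pdf(x, 0.0, 2.0)


def bce_loss(logit):
    """Expected binary cross-entropy of the classifier D(x) = sigmoid(logit(x))."""
    # -log D(x) = softplus(-logit(x)) under p; -log(1 - D(x)) = softplus(logit(x)) under q.
    return integrate(lambda x: math.exp(log_p(x)) * softplus(-logit(x))
                     + math.exp(log_q(x)) * softplus(logit(x)))


def js_divergence():
    """Jensen-Shannon divergence between p and q by quadrature."""
    def integrand(x):
        p, q = math.exp(log_p(x)), math.exp(log_q(x))
        log_m = math.log(0.5 * (p + q))
        return 0.5 * p * (log_p(x) - log_m) + 0.5 * q * (log_q(x) - log_m)
    return integrate(integrand)


bound = -2.0 * js_divergence() + math.log(4.0)
optimal_loss = bce_loss(lambda x: log_p(x) - log_q(x))  # logit of the optimal classifier
linear_loss = bce_loss(lambda x: 0.5 * x)               # an arbitrary, suboptimal logit

print(bound, optimal_loss, linear_loss)
```

&lt;p&gt;The optimal classifier attains the bound up to quadrature error, while the linear-logit classifier sits strictly above it, which is the behaviour the inequality predicts.&lt;/p&gt;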
&lt;p&gt;In short, this tells us that the minimized binary cross-entropy loss is &lt;em&gt;itself&lt;/em&gt; an
approximation (up to constants) to the negative Jensen-Shannon divergence.
This raises the question: is it possible to construct a more general loss that bounds any given $f$-divergence?&lt;/p&gt;
&lt;h2 id="teaser-lower-bound-on-any--divergence"&gt;Teaser: Lower Bound on any $f$-divergence&lt;/h2&gt;
&lt;p&gt;Using convex analysis, one can actually show that for any $f$-divergence, we
have the lower bound&lt;sup id="fnref:15"&gt;&lt;a href="#fn:15" class="footnote-ref" role="doc-noteref"&gt;15&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;
$$
\mathcal{D}_f[p(x) || q(x)]
\geq
\sup_{\theta}
\mathbb{E}_{p(x)} [ f'(r_{\theta}(x)) ] -
\mathbb{E}_{q(x)} [ f^{\star}(f'(r_{\theta}(x))) ],
$$&lt;p&gt;with equality exactly when $r_{\theta}(x) = r^{*}(x)$.
Importantly, this lower bound can be computed without requiring the densities of
$p(x)$ or $q(x)$&amp;mdash;only their samples are needed.&lt;/p&gt;
&lt;p&gt;In the special case of $f(u) = u \log u - (u + 1) \log (u + 1)$, we recover the
binary cross-entropy loss and the previous result, as expected,&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_f[p(x) || q(x)]
&amp; = 2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4 \newline
&amp; \geq \sup_{\theta}
\mathbb{E}_{p(x)} [ \log \sigma ( \log r_{\theta} (x) ) ] +
\mathbb{E}_{q(x)} [ \log(1 - \sigma ( \log r_{\theta} (x) )) ] \newline
&amp; = \sup_{\theta}
\mathbb{E}_{p(x)} [ \log D_{\theta} (x) ] +
\mathbb{E}_{q(x)} [ \log(1-D_{\theta} (x)) ].
\end{align*}
$$&lt;p&gt;Alternatively, in the special case of $f(u) = u \log u$, we get&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_f[p(x) || q(x)]
&amp; = \mathcal{D}_{\mathrm{KL}}[p(x) || q(x)] \newline
&amp; \geq \sup_{\theta}
\mathbb{E}_{p(x)} [ \log r_{\theta} (x) ] -
\mathbb{E}_{q(x)} [ r_{\theta} (x) - 1 ].
\end{align*}
$$&lt;p&gt;This gives us &lt;em&gt;yet&lt;/em&gt; another way to estimate the KL divergence between
implicit distributions, in the form of a direct lower bound on the KL divergence
itself.
As it turns out, this lower bound is closely related to the objective of the
&lt;em&gt;KL Importance Estimation Procedure (KLIEP)&lt;/em&gt;&lt;sup id="fnref:16"&gt;&lt;a href="#fn:16" class="footnote-ref" role="doc-noteref"&gt;16&lt;/a&gt;&lt;/sup&gt;, and will be
the topic of our next post in this series.&lt;/p&gt;
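&lt;p&gt;As a quick sanity check of the KL lower bound, the following sketch evaluates $\mathbb{E}_{p(x)} [ \log r(x) ] - \mathbb{E}_{q(x)} [ r(x) - 1 ]$ by quadrature for two choices of $r$: the exact ratio, and a deliberately misspecified one. The Gaussian parameters and helper names are hypothetical and chosen only for illustration; the misspecified ratio uses a Gaussian numerator with the wrong mean.&lt;/p&gt;

```python
import math


def gaussian_log_pdf(x, mean, std):
    """Log density of a univariate Gaussian N(mean, std^2) at x."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def integrate(fn, lo=-15.0, hi=15.0, n=20000):
    """Midpoint-rule quadrature of fn over [lo, hi]."""
    h = (hi - lo) / n
    return h * sum(fn(lo + (i + 0.5) * h) for i in range(n))


# Hypothetical Gaussians standing in for p(x) and q(x).
def log_p(x):
    return gaussian_log_pdf(x, 1.0, 1.0)


def log_q(x):
    return gaussian_log_pdf(x, 0.0, 2.0)


def kl_lower_bound(log_r):
    """E_p[log r(x)] - E_q[r(x) - 1] for a candidate ratio given as log r."""
    first = integrate(lambda x: math.exp(log_p(x)) * log_r(x))
    second = integrate(lambda x: math.exp(log_q(x)) * (math.exp(log_r(x)) - 1.0))
    return first - second


# Closed-form KL[p || q] for the two Gaussians above.
kl_exact = math.log(2.0) + (1.0 + 1.0) / 8.0 - 0.5

# Candidate 1: the exact ratio r*(x) = p(x) / q(x).
bound_exact = kl_lower_bound(lambda x: log_p(x) - log_q(x))

# Candidate 2: a misspecified ratio whose numerator has mean 1.5 instead of 1.
bound_misspecified = kl_lower_bound(
    lambda x: gaussian_log_pdf(x, 1.5, 1.0) - log_q(x))

print(kl_exact, bound_exact, bound_misspecified)
```

&lt;p&gt;The exact ratio recovers the KL divergence, while the misspecified ratio falls short of it, here by the KL divergence between the true and the wrong numerator.&lt;/p&gt;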
&lt;h1 id="summary"&gt;Summary&lt;/h1&gt;
&lt;p&gt;This post covered how to evaluate the KL divergence, or any $f$-divergence,
between implicit distributions&amp;mdash;distributions which we can only sample from.
First, we underscored the crucial role of the density ratio in the estimation of
$f$-divergences.
Next, we showed the correspondence between the density ratio and the optimal
classifier.
By exploiting this link, we demonstrated how one can use a trained probabilistic classifier to construct a proxy for the exact density ratio, and use this to
enable estimation of any $f$-divergence.
Finally, we provided some context on where this method is used, touching upon
some recent advances in implicit generative models and variational inference.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;Cite as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{tiao2018dre,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = &amp;#34;{D}ensity {R}atio {E}stimation for {KL} {D}ivergence {M}inimization between {I}mplicit {D}istributions&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = &amp;#34;Tiao, Louis C&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; journal = &amp;#34;tiao.io&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = &amp;#34;2018&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = &amp;#34;https://tiao.io/post/density-ratio-estimation-for-kl-divergence-minimization-between-implicit-distributions/&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h1 id="acknowledgements"&gt;Acknowledgements&lt;/h1&gt;
&lt;p&gt;I am grateful to
for providing
extensive feedback and insightful discussions. I would also like to thank
Alistair Reid and
for their comments and suggestions.&lt;/p&gt;
&lt;h1 id="links-and-resources"&gt;Links and Resources&lt;/h1&gt;
&lt;ul&gt;
&lt;li&gt;The
used to generate the figures in this post, which you can
.&lt;/li&gt;
&lt;li&gt;The very readable textbook on
&lt;sup id="fnref1:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;, which I highly recommend. (Note: the Gaussian distributions example was borrowed from this book.)&lt;/li&gt;
&lt;li&gt;Shakir Mohamed&amp;rsquo;s blog post
.&lt;/li&gt;
&lt;li&gt;The paper by Menon and Ong, 2016&lt;sup id="fnref:17"&gt;&lt;a href="#fn:17" class="footnote-ref" role="doc-noteref"&gt;17&lt;/a&gt;&lt;/sup&gt;, which gives a generalized treatment of the theoretical link between density ratio estimation and probabilistic classification.&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;The (forward) KL divergence can be recovered with
&lt;/p&gt;
$$
f_{\mathrm{KL}}(u) := u \log u.
$$&lt;p&gt;
This is easy to verify,
&lt;/p&gt;
$$
\begin{align*}
\mathcal{D}_{\mathrm{KL}}[p(x) || q(x)] &amp; :=
\mathbb{E}_{p(x)} \left [ \log \left ( \frac{p(x)}{q(x)} \right ) \right ] \newline
&amp; = \mathbb{E}_{q(x)} \left [ \frac{p(x)}{q(x)} \log \left ( \frac{p(x)}{q(x)} \right ) \right ] \newline
&amp; = \mathbb{E}_{q(x)} \left [ f_{\mathrm{KL}} \left ( \frac{p(x)}{q(x)} \right ) \right ].
\end{align*}
$$&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;Sugiyama, M., Suzuki, T., &amp;amp; Kanamori, T. (2012). &lt;em&gt;Density Ratio Estimation in Machine Learning&lt;/em&gt;. Cambridge University Press.&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;Gneiting, T., &amp;amp; Raftery, A. E. (2007). Strictly Proper Scoring Rules, Prediction, and Estimation. &lt;em&gt;Journal of the American Statistical Association&lt;/em&gt;, 102(477), 359-378.&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., &amp;hellip; &amp;amp; Bengio, Y. (2014). Generative Adversarial Nets. In &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt; (pp. 2672-2680).&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;Uehara, M., Sato, I., Suzuki, M., Nakayama, K., &amp;amp; Matsuo, Y. (2016). Generative Adversarial Nets from a Density Ratio Estimation Perspective. &lt;em&gt;arXiv preprint arXiv:1610.02920&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;Mohamed, S., &amp;amp; Lakshminarayanan, B. (2016). Learning in Implicit Generative Models. &lt;em&gt;arXiv preprint arXiv:1610.03483&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:7"&gt;
&lt;p&gt;Sønderby, C. K., Caballero, J., Theis, L., Shi, W., &amp;amp; Huszár, F. (2016). Amortised MAP inference for image super-resolution. &lt;em&gt;arXiv preprint arXiv:1610.04490&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:8"&gt;
&lt;p&gt;Mescheder, L., Nowozin, S., &amp;amp; Geiger, A. (2017). Adversarial Variational Bayes: Unifying Variational Autoencoders and Generative Adversarial Networks. In &lt;em&gt;International Conference on Machine Learning (ICML)&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:9"&gt;
&lt;p&gt;Huszár, F. (2017). Variational inference using implicit distributions. &lt;em&gt;arXiv preprint arXiv:1702.08235&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:9" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:10"&gt;
&lt;p&gt;Tran, D., Ranganath, R., &amp;amp; Blei, D. (2017). Hierarchical implicit models and likelihood-free variational inference. In &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt; (pp. 5523-5533).&amp;#160;&lt;a href="#fnref:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:11"&gt;
&lt;p&gt;Pu, Y., Wang, W., Henao, R., Chen, L., Gan, Z., Li, C., &amp;amp; Carin, L. (2017). Adversarial symmetric variational autoencoder. In &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt; (pp. 4330-4339).&amp;#160;&lt;a href="#fnref:11" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:12"&gt;
&lt;p&gt;Chen, L., Dai, S., Pu, Y., Zhou, E., Li, C., Su, Q., &amp;hellip; &amp;amp; Carin, L. (2018, March). Symmetric variational autoencoder and connections to adversarial learning. In &lt;em&gt;International Conference on Artificial Intelligence and Statistics&lt;/em&gt; (pp. 661-669).&amp;#160;&lt;a href="#fnref:12" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:13"&gt;
&lt;p&gt;Tiao, L. C., Bonilla, E. V., &amp;amp; Ramos, F. (2018). Cycle-Consistent Adversarial Learning as Approximate Bayesian Inference. &lt;em&gt;arXiv preprint arXiv:1806.01771&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:13" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:14"&gt;
&lt;p&gt;Note that &lt;code&gt;jensen_shannon&lt;/code&gt; with &lt;code&gt;self_normalized=False&lt;/code&gt; (the default) corresponds to $2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)] - \log 4$, while &lt;code&gt;self_normalized=True&lt;/code&gt; adds back the constant and corresponds to $2 \cdot \mathcal{D}_{\mathrm{JS}}[p(x) || q(x)]$.&amp;#160;&lt;a href="#fnref:14" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:15"&gt;
&lt;p&gt;Nguyen, X., Wainwright, M. J., &amp;amp; Jordan, M. I. (2010). Estimating divergence functionals and the likelihood ratio by convex risk minimization. &lt;em&gt;IEEE Transactions on Information Theory&lt;/em&gt;, 56(11), 5847-5861.&amp;#160;&lt;a href="#fnref:15" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:16"&gt;
&lt;p&gt;Sugiyama, M., Nakajima, S., Kashima, H., Buenau, P. V., &amp;amp; Kawanabe, M. (2008). Direct importance estimation with model selection and its application to covariate shift adaptation. In &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt; (pp. 1433-1440).&amp;#160;&lt;a href="#fnref:16" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:17"&gt;
&lt;p&gt;Menon, A., &amp;amp; Ong, C. S. (2016, June). Linking Losses for Density Ratio and Class-Probability Estimation. In &lt;em&gt;International Conference on Machine Learning&lt;/em&gt; (pp. 304-313).&amp;#160;&lt;a href="#fnref:17" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Building Probability Distributions with the TensorFlow Probability Bijector API</title><link>https://tiao.io/posts/building-probability-distributions-with-tensorflow-probability-bijector-api/</link><pubDate>Mon, 30 Jul 2018 00:00:00 +0000</pubDate><guid>https://tiao.io/posts/building-probability-distributions-with-tensorflow-probability-bijector-api/</guid><description>&lt;p&gt;TensorFlow Distributions, now under the broader umbrella of
, is a fantastic TensorFlow library for efficient and
composable manipulation of probability distributions&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;Among the many features it has to offer, one of the most powerful in my opinion
is the &lt;code&gt;Bijector&lt;/code&gt; API, which provides the modular building blocks necessary to
construct a broad class of probability distributions.
Instead of describing it any further in the abstract, let&amp;rsquo;s dive right in with
a simple example.&lt;/p&gt;
&lt;h2 id="example-banana-shaped-distribution"&gt;Example: Banana-shaped distribution&lt;/h2&gt;
&lt;p&gt;Consider the &lt;em&gt;banana-shaped distribution&lt;/em&gt;, a commonly-used testbed for adaptive
MCMC methods&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;.
Denote the density of this distribution as $p_{Y}(\mathbf{y})$.
To illustrate, 1k samples randomly drawn from this distribution are shown below:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Banana distribution samples"
src="https://tiao.io/posts/building-probability-distributions-with-tensorflow-probability-bijector-api/banana_samples.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;The underlying process that generates samples
$\tilde{\mathbf{y}} \sim p_{Y}(\mathbf{y})$ is simple to describe,
and is of the general form,&lt;/p&gt;
$$
\tilde{\mathbf{y}} \sim p_{Y}(\mathbf{y}) \quad
\Leftrightarrow \quad
\tilde{\mathbf{y}} = G(\tilde{\mathbf{x}}),
\quad \tilde{\mathbf{x}} \sim p_{X}(\mathbf{x}).
$$&lt;p&gt;In other words, a sample $\tilde{\mathbf{y}}$ is the output of a transformation
$G$, given a sample $\tilde{\mathbf{x}}$ drawn from some underlying
base distribution $p_{X}(\mathbf{x})$.&lt;/p&gt;
&lt;p&gt;However, it is not as straightforward to compute an analytical expression for
the density $p_{Y}(\mathbf{y})$.
In fact, this is only possible if $G$ is a &lt;em&gt;differentiable&lt;/em&gt; and &lt;em&gt;invertible&lt;/em&gt;
transformation (a &lt;em&gt;diffeomorphism&lt;/em&gt;&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;), and if there is an analytical
expression for $p_{X}(\mathbf{x})$.&lt;/p&gt;
&lt;p&gt;Transformations that fail to satisfy these conditions (which include something
as simple as a multi-layer perceptron with non-linear activations) give rise to
&lt;em&gt;implicit distributions&lt;/em&gt;, and will be the subject of many posts to come.
But for now, we will restrict our attention to diffeomorphisms.&lt;/p&gt;
&lt;h3 id="base-distribution"&gt;Base distribution&lt;/h3&gt;
&lt;p&gt;Following on with our example, the base distribution $p_{X}(\mathbf{x})$ is
given by a two-dimensional Gaussian with unit variances and covariance
$\rho = 0.95$:&lt;/p&gt;
$$
p_{X}(\mathbf{x}) = \mathcal{N}(\mathbf{x} | \mathbf{0}, \mathbf{\Sigma}),
\qquad
\mathbf{\Sigma} =
\begin{bmatrix}
1 &amp; 0.95 \newline
0.95 &amp; 1
\end{bmatrix}
$$&lt;p&gt;This can be encapsulated by an instance of
&lt;code&gt;MultivariateNormalTriL&lt;/code&gt;,
which is parameterized by a lower-triangular matrix.
First let&amp;rsquo;s import TensorFlow Distributions:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow.contrib.distributions&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tfd&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Then we create the covariance matrix, compute its lower-triangular Cholesky factor, and instantiate the distribution:&lt;/p&gt;
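&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;tensorflow&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;rho&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.95&lt;/span&gt;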
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;rho&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)[::&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;Sigma&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([[&lt;/span&gt;&lt;span class="mf"&gt;1.&lt;/span&gt; &lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.&lt;/span&gt; &lt;span class="p"&gt;]],&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;p_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;MultivariateNormalTriL&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scale_tril&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cholesky&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Sigma&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;As with all subclasses of &lt;code&gt;tfd.Distribution&lt;/code&gt;, we can evaluate the probability
density function of this distribution by calling the &lt;code&gt;p_x.prob&lt;/code&gt; method.
Evaluating this on a uniformly spaced grid yields the equiprobability contour
plot below:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Base density"
src="https://tiao.io/posts/building-probability-distributions-with-tensorflow-probability-bijector-api/banana_base_density.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
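&lt;p&gt;As a quick check on the &lt;code&gt;scale_tril&lt;/code&gt; argument, the Cholesky factor of this $2 \times 2$ covariance can be worked out by hand. The symbol $\mathbf{L}$ is notation introduced here purely for illustration, and the numerical entry is rounded to three decimal places:&lt;/p&gt;
$$
\mathbf{L} =
\begin{bmatrix}
1 &amp; 0 \newline
\rho &amp; \sqrt{1 - \rho^2}
\end{bmatrix}
\approx
\begin{bmatrix}
1 &amp; 0 \newline
0.95 &amp; 0.312
\end{bmatrix},
\qquad
\mathbf{L} \mathbf{L}^{\top} =
\begin{bmatrix}
1 &amp; \rho \newline
\rho &amp; \rho^2 + (1 - \rho^2)
\end{bmatrix}
= \mathbf{\Sigma}
$$
&lt;p&gt;The small second diagonal entry reflects how little independent variation remains in $x_2$ once $x_1$ is known.&lt;/p&gt;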
&lt;h3 id="forward-transformation"&gt;Forward Transformation&lt;/h3&gt;
&lt;p&gt;The required transformation $G$ is defined as:&lt;/p&gt;
$$
G(\mathbf{x}) =
\begin{bmatrix}
x_1 \newline
x_2 - x_1^2 - 1 \newline
\end{bmatrix}
$$&lt;p&gt;We implement this in the &lt;code&gt;_forward&lt;/code&gt; function below&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;y_0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_tail&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;y_0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_tail&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We can now use this to generate samples from $p_{Y}(\mathbf{y})$.
To do this we first sample from the base distribution $p_{X}(\mathbf{x})$ by
calling &lt;code&gt;p_x.sample&lt;/code&gt;. For this illustration, we generate 1k samples, which is
specified through the &lt;code&gt;sample_shape&lt;/code&gt; argument. We then transform these samples
through $G$ by calling &lt;code&gt;_forward&lt;/code&gt; on them.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;x_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p_x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;y_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The figure below contains scatterplots of the 1k samples &lt;code&gt;x_samples&lt;/code&gt; (left)
and the transformed &lt;code&gt;y_samples&lt;/code&gt; (right):&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Banana and base samples"
src="https://tiao.io/posts/building-probability-distributions-with-tensorflow-probability-bijector-api/banana_base_samples.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
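&lt;p&gt;To make the mapping concrete, here is the arithmetic for a single point. The input $(1, 2)$ is an arbitrary value chosen for illustration, not one of the samples plotted above:&lt;/p&gt;
$$
G \left (
\begin{bmatrix}
1 \newline
2
\end{bmatrix}
\right ) =
\begin{bmatrix}
1 \newline
2 - 1^2 - 1
\end{bmatrix}
=
\begin{bmatrix}
1 \newline
0
\end{bmatrix}
$$
&lt;p&gt;The first coordinate passes through unchanged, while the second is shifted down by $x_1^2 + 1$, which is what bends the elongated Gaussian cloud into a banana shape.&lt;/p&gt;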
&lt;h3 id="instantiating-a-transformeddistribution-with-a-bijector"&gt;Instantiating a &lt;code&gt;TransformedDistribution&lt;/code&gt; with a &lt;code&gt;Bijector&lt;/code&gt;&lt;/h3&gt;
&lt;p&gt;Having specified the forward transformation and the underlying distribution, we
have now fully described the sample generation process, which is the bare
minimum necessary to define a probability distribution.&lt;/p&gt;
&lt;p&gt;The forward transformation is also the &lt;em&gt;first&lt;/em&gt; of &lt;strong&gt;three&lt;/strong&gt; operations needed to
fully specify a &lt;code&gt;Bijector&lt;/code&gt;, which can be used to instantiate a
&lt;code&gt;TransformedDistribution&lt;/code&gt; that encapsulates the banana-shaped distribution.&lt;/p&gt;
&lt;h4 id="creating-a-bijector"&gt;Creating a &lt;code&gt;Bijector&lt;/code&gt;&lt;/h4&gt;
&lt;p&gt;First, let&amp;rsquo;s subclass &lt;code&gt;Bijector&lt;/code&gt; to define the &lt;code&gt;Banana&lt;/code&gt; bijector and implement
the forward transformation as an instance method:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Banana&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tfd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bijectors&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Bijector&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;banana&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Banana&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inverse_min_event_ndims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;y_0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_tail&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;y_0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_tail&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Note that we need to specify either &lt;code&gt;forward_min_event_ndims&lt;/code&gt; or
&lt;code&gt;inverse_min_event_ndims&lt;/code&gt;, the number of dimensions on which the forward or inverse
transformation operates (the two can sometimes differ).
In our example, both the inverse and forward transformations operate on vectors
(rank 1 tensors), so we set &lt;code&gt;inverse_min_event_ndims=1&lt;/code&gt;.&lt;/p&gt;
&lt;p&gt;With an instance of the &lt;code&gt;Banana&lt;/code&gt; bijector, we can call the &lt;code&gt;forward&lt;/code&gt; method on
&lt;code&gt;x_samples&lt;/code&gt; to produce &lt;code&gt;y_samples&lt;/code&gt; as before:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;y_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Banana&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_samples&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h4 id="instantiating-a-transformeddistribution"&gt;Instantiating a &lt;code&gt;TransformedDistribution&lt;/code&gt;&lt;/h4&gt;
&lt;p&gt;More importantly, we can now create a &lt;code&gt;TransformedDistribution&lt;/code&gt; with the base
distribution &lt;code&gt;p_x&lt;/code&gt; and an instance of the &lt;code&gt;Banana&lt;/code&gt; bijector:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;p_y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;TransformedDistribution&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;distribution&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bijector&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;Banana&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This now allows us to directly sample from &lt;code&gt;p_y&lt;/code&gt; just as we could with &lt;code&gt;p_x&lt;/code&gt;,
and any other TensorFlow Probability &lt;code&gt;Distribution&lt;/code&gt;:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;y_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p_y&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Neat!&lt;/p&gt;
&lt;h3 id="probability-density-function"&gt;Probability Density Function&lt;/h3&gt;
&lt;p&gt;Although we can now sample from this distribution, we have yet to define the
operations necessary to evaluate its probability density function&amp;mdash;the
remaining &lt;em&gt;two&lt;/em&gt; of &lt;strong&gt;three&lt;/strong&gt; operations needed to fully specify a &lt;code&gt;Bijector&lt;/code&gt;.&lt;/p&gt;
&lt;p&gt;Indeed, calling &lt;code&gt;p_y.prob&lt;/code&gt; at this stage would simply raise a
&lt;code&gt;NotImplementedError&lt;/code&gt; exception. So what else do we need to define?&lt;/p&gt;
&lt;p&gt;Recall the probability density of $p_{Y}(\mathbf{y})$ is given by:&lt;/p&gt;
$$
p_{Y}(\mathbf{y}) = p_{X}(G^{-1}(\mathbf{y})) \left | \mathrm{det}
\left ( \frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) \right ) \right |
$$&lt;p&gt;Hence we need to specify the inverse transformation $G^{-1}(\mathbf{y})$ and its
Jacobian determinant
$\mathrm{det} \left ( \frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) \right )$.&lt;/p&gt;
&lt;p&gt;For numerical stability, the &lt;code&gt;Bijector&lt;/code&gt; API requires that this be defined in
log-space. Hence, it is useful to recall that the forward and inverse log
determinant Jacobians differ only in their signs&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;,&lt;/p&gt;
$$
\begin{align}
\log \left | \mathrm{det} \left ( \frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) \right ) \right |
&amp; = - \log \left | \mathrm{det} \left ( \frac{\partial}{\partial\mathbf{x}} G(\mathbf{x}) \right ) \right |,
\end{align}
$$&lt;p&gt;which gives us the option of implementing either (or both).
However, do note the following from the official
API docs:&lt;/p&gt;
&lt;blockquote class="border-l-4 border-neutral-300 dark:border-neutral-600 pl-4 italic text-neutral-600 dark:text-neutral-400 my-6"&gt;
&lt;p&gt;Generally its preferable to directly implement the inverse Jacobian
determinant. This should have superior numerical stability and will often share
subgraphs with the &lt;code&gt;_inverse&lt;/code&gt; implementation.&lt;/p&gt;
&lt;/blockquote&gt;
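&lt;p&gt;The sign relationship between the two log determinants follows from the chain rule. As a brief sketch, assume $G$ is a diffeomorphism as before and write $\mathbf{y} = G(\mathbf{x})$, with $\mathbf{I}$ denoting the identity matrix. Differentiating $G^{-1}(G(\mathbf{x})) = \mathbf{x}$ with respect to $\mathbf{x}$ and then taking determinants gives:&lt;/p&gt;
$$
\begin{align}
\frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) \, \frac{\partial}{\partial\mathbf{x}} G(\mathbf{x})
&amp; = \mathbf{I} \newline
\mathrm{det} \left ( \frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) \right )
\mathrm{det} \left ( \frac{\partial}{\partial\mathbf{x}} G(\mathbf{x}) \right )
&amp; = 1
\end{align}
$$
&lt;p&gt;Taking the logarithm of the absolute value of both sides turns the product into a sum equal to zero, so each log determinant is the negative of the other.&lt;/p&gt;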
&lt;h3 id="inverse-transformation"&gt;Inverse Transformation&lt;/h3&gt;
&lt;p&gt;So let&amp;rsquo;s implement the inverse transform $G^{-1}$, which is given by:&lt;/p&gt;
$$
G^{-1}(\mathbf{y}) =
\begin{bmatrix}
y_1 \newline
y_2 + y_1^2 + 1 \newline
\end{bmatrix}
$$&lt;p&gt;We define this in the &lt;code&gt;_inverse&lt;/code&gt; function below:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_inverse&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;x_0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_tail&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;x_0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_tail&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id="jacobian-determinant"&gt;Jacobian determinant&lt;/h3&gt;
&lt;p&gt;Now we compute the log determinant of the Jacobian of the &lt;em&gt;inverse&lt;/em&gt;
transformation.
In this simple example, the transformation is &lt;em&gt;volume-preserving&lt;/em&gt;, meaning its
Jacobian determinant is equal to 1.&lt;/p&gt;
&lt;p&gt;This is easy to verify:&lt;/p&gt;
$$
\begin{align}
\mathrm{det} \left ( \frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) \right )
&amp; = \mathrm{det}
\begin{pmatrix}
\frac{\partial}{\partial y_1} y_1 &amp; \frac{\partial}{\partial y_2} y_1 \newline
\frac{\partial}{\partial y_1} \left ( y_2 + y_1^2 + 1 \right ) &amp; \frac{\partial}{\partial y_2} \left ( y_2 + y_1^2 + 1 \right ) \newline
\end{pmatrix} \newline
&amp; = \mathrm{det}
\begin{pmatrix}
1 &amp; 0 \newline
2 y_1 &amp; 1 \newline
\end{pmatrix}
= 1
\end{align}
$$&lt;p&gt;Hence, the log determinant Jacobian is given by zeros shaped like input &lt;code&gt;y&lt;/code&gt;, up
to the last &lt;code&gt;inverse_min_event_ndims=1&lt;/code&gt; dimensions:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_inverse_log_det_jacobian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Since the log determinant Jacobian is constant, i.e. independent of the input,
we can just specify it for one input by setting the flag &lt;code&gt;is_constant_jacobian=True&lt;/code&gt;&lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt;,
and the &lt;code&gt;Bijector&lt;/code&gt; class will handle the necessary shape inference for us.&lt;/p&gt;
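&lt;p&gt;With the inverse and its unit Jacobian determinant in hand, the change-of-variables formula can be written out for this example. The following is a sketch for the two-dimensional case only, using nothing beyond the definitions above:&lt;/p&gt;
$$
p_{Y}(\mathbf{y}) = p_{X}(G^{-1}(\mathbf{y})) \cdot 1
= \mathcal{N}(\mathbf{x} | \mathbf{0}, \mathbf{\Sigma}),
\qquad
\mathbf{x} =
\begin{bmatrix}
y_1 \newline
y_2 + y_1^2 + 1
\end{bmatrix}
$$
&lt;p&gt;In log-space the Jacobian term contributes $\log 1 = 0$, which is the value returned by &lt;code&gt;_inverse_log_det_jacobian&lt;/code&gt;.&lt;/p&gt;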
&lt;p&gt;Putting it all together in the &lt;code&gt;Banana&lt;/code&gt; bijector subclass, we have:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Banana&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tfd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bijectors&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Bijector&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;banana&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Banana&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inverse_min_event_ndims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;is_constant_jacobian&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;y_0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;y_tail&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;y_0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_tail&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_inverse&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;x_0&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_tail&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;x_0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_tail&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_inverse_log_det_jacobian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Finally, we can instantiate the distribution &lt;code&gt;p_y&lt;/code&gt; by calling
&lt;code&gt;tfd.TransformedDistribution&lt;/code&gt; as we did before and, &lt;em&gt;et voilà&lt;/em&gt;,
we can simply call &lt;code&gt;p_y.prob&lt;/code&gt; to evaluate the probability density function.&lt;/p&gt;
&lt;p&gt;Evaluating this on the same uniformly-spaced grid as before yields the following
equiprobability contour plot:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Banana density"
src="https://tiao.io/posts/building-probability-distributions-with-tensorflow-probability-bijector-api/banana_density.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h4 id="inline-bijector"&gt;Inline Bijector&lt;/h4&gt;
&lt;p&gt;Before we conclude, we note that instead of creating a subclass, one can also
opt for a more lightweight and functional approach by creating an
&lt;code&gt;Inline&lt;/code&gt; bijector:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;banana&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bijectors&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Inline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;forward_fn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;_forward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inverse_fn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;_inverse&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inverse_log_det_jacobian_fn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;_inverse_log_det_jacobian&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inverse_min_event_ndims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;is_constant_jacobian&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;p_y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;TransformedDistribution&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;distribution&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bijector&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;banana&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
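&lt;p&gt;As a quick sanity check of the banana transformation used above, the following is a minimal pure-Python sketch (no TensorFlow) that mirrors the forward and inverse maps and estimates the Jacobian determinant numerically. The helper names &lt;code&gt;banana_forward&lt;/code&gt;, &lt;code&gt;banana_inverse&lt;/code&gt; and &lt;code&gt;jacobian_det_2d&lt;/code&gt; are hypothetical and exist only for this illustration; the point is that the round trip recovers the input and that the determinant is one, which is why the log-determinant can be a constant zero.&lt;/p&gt;

```python
def banana_forward(x):
    """Map x to y: keep x[0], shift x[1] by -(x[0]**2 + 1), pass the rest through."""
    y = list(x)
    y[1] = x[1] - x[0] ** 2 - 1.0
    return y


def banana_inverse(y):
    """Undo banana_forward."""
    x = list(y)
    x[1] = y[1] + y[0] ** 2 + 1.0
    return x


def jacobian_det_2d(fn, point, eps=1e-6):
    """Central-difference estimate of the Jacobian determinant of a 2-D map."""
    cols = []
    for j in range(2):
        hi = list(point)
        lo = list(point)
        hi[j] += eps
        lo[j] -= eps
        f_hi, f_lo = fn(hi), fn(lo)
        # cols[j][i] approximates d fn_i / d x_j
        cols.append([(f_hi[i] - f_lo[i]) / (2.0 * eps) for i in range(2)])
    return cols[0][0] * cols[1][1] - cols[1][0] * cols[0][1]


x = [0.5, -1.0, 2.0]                      # illustrative 3-D input
y = banana_forward(x)
x_back = banana_inverse(y)                # should match x up to rounding
det = jacobian_det_2d(banana_forward, x[:2])
print(y, x_back, det)
```

&lt;p&gt;Because the map only shifts the second coordinate by a function of the first, its Jacobian is triangular with ones on the diagonal, so the estimate should sit very close to one wherever it is evaluated.&lt;/p&gt;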
&lt;h1 id="summary"&gt;Summary&lt;/h1&gt;
&lt;p&gt;In this post, we showed that using diffeomorphisms (mappings that are
differentiable and invertible), it is possible to transform standard distributions
into interesting and complicated distributions, while still being able to
compute their densities analytically.&lt;/p&gt;
&lt;p&gt;The &lt;code&gt;Bijector&lt;/code&gt; API provides an interface that encapsulates the basic properties
of a diffeomorphism needed to transform a distribution. These are: the
forward transform itself, its inverse and the determinant of their Jacobians.&lt;/p&gt;
&lt;p&gt;Using this, &lt;code&gt;TransformedDistribution&lt;/code&gt; &lt;em&gt;automatically&lt;/em&gt; implements perhaps the two
most important methods of a probability distribution: sampling (&lt;code&gt;sample&lt;/code&gt;), and
density evaluation (&lt;code&gt;prob&lt;/code&gt;).&lt;/p&gt;
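&lt;p&gt;To illustrate the idea in isolation, here is a minimal sketch in plain Python of how sampling and density evaluation follow from a bijector and a base distribution. The classes &lt;code&gt;ExpBijector&lt;/code&gt; and &lt;code&gt;ToyTransformedDistribution&lt;/code&gt; are hypothetical stand-ins written for this example, not the TensorFlow Probability classes, and the example assumes a scalar standard Normal base distribution pushed through the exponential function (which gives a log-normal).&lt;/p&gt;

```python
import math
import random


def normal_pdf(x, mu=0.0, sigma=1.0):
    """Density of a univariate Normal distribution."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


class ExpBijector:
    """Toy bijector y = exp(x); hypothetical, for illustration only."""

    def forward(self, x):
        return math.exp(x)

    def inverse(self, y):
        return math.log(y)

    def inverse_log_det_jacobian(self, y):
        # d/dy log(y) = 1 / y, so log|det J| = -log(y)
        return -math.log(y)


class ToyTransformedDistribution:
    """Toy version of the idea behind a transformed distribution."""

    def __init__(self, base_sample, base_prob, bijector):
        self.base_sample = base_sample
        self.base_prob = base_prob
        self.bijector = bijector

    def sample(self):
        # Sampling: push a base sample through the forward transform.
        return self.bijector.forward(self.base_sample())

    def prob(self, y):
        # Change of variables: base density at the inverse image of y,
        # scaled by the absolute Jacobian determinant of the inverse.
        x = self.bijector.inverse(y)
        log_det = self.bijector.inverse_log_det_jacobian(y)
        return self.base_prob(x) * math.exp(log_det)


log_normal = ToyTransformedDistribution(
    base_sample=lambda: random.gauss(0.0, 1.0),
    base_prob=normal_pdf,
    bijector=ExpBijector(),
)
print(log_normal.sample(), log_normal.prob(2.0))
```

&lt;p&gt;Only the forward transform is needed to sample, and only the inverse and its log-determinant are needed to evaluate the density, which is the division of labour described above.&lt;/p&gt;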
&lt;p&gt;Needless to say, this is a very powerful combination.
Through the &lt;code&gt;Bijector&lt;/code&gt; API, the number of possible distributions that can be
implemented and used directly with other functionalities in the TensorFlow
Probability ecosystem effectively becomes &lt;em&gt;endless&lt;/em&gt;.&lt;/p&gt;
&lt;!-- And I haven't even mentioned the fact that you can easily *parameterize* and
*compose* `Bijector`s to implement *normalizing flows* such as the
*autoregressive flows*!
--&gt;
&lt;hr&gt;
&lt;p&gt;Cite as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;@article{tiao2018bijector,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; title = &amp;#34;{B}uilding {P}robability {D}istributions with the {T}ensor{F}low {P}robability {B}ijector {API}&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; author = &amp;#34;Tiao, Louis C&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; journal = &amp;#34;tiao.io&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; year = &amp;#34;2018&amp;#34;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; url = &amp;#34;https://tiao.io/post/building-probability-distributions-with-tensorflow-probability-bijector-api/&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h2 id="links--resources"&gt;Links &amp;amp; Resources&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;Paper: see footnote&lt;sup id="fnref1:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;Dillon, J.V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., Patton, B., Alemi, A., Hoffman, M. and Saurous, R.A., 2017. &lt;em&gt;TensorFlow Distributions.&lt;/em&gt;
arXiv preprint arXiv:1711.10604.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;Haario, H., Saksman, E., &amp;amp; Tamminen, J. (1999).
Adaptive proposal distribution for random walk Metropolis algorithm. &lt;em&gt;Computational Statistics&lt;/em&gt;, 14(3), 375-396.&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;For the transformation to be a diffeomorphism, it also needs to be &lt;em&gt;smooth&lt;/em&gt;.&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;We implement this for the general case of $K \geq 2$ dimensional inputs since this actually turns out to be easier and cleaner (a phenomenon known as
the inventor&amp;rsquo;s paradox).&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;This is a straightforward consequence of the inverse function theorem,
which says that the matrix inverse of the Jacobian of $G$ is the Jacobian of
its inverse $G^{-1}$,
&lt;/p&gt;
$$
\frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) =
\left ( \frac{\partial}{\partial\mathbf{x}} G(\mathbf{x}) \right )^{-1}
$$&lt;p&gt;
Taking the determinant of both sides, we get:
&lt;/p&gt;
$$
\begin{align}
\mathrm{det} \left ( \frac{\partial}{\partial\mathbf{y}} G^{-1}(\mathbf{y}) \right )
&amp; = \mathrm{det} \left ( \left ( \frac{\partial}{\partial\mathbf{x}} G(\mathbf{x}) \right )^{-1} \right ) \newline
&amp; = \mathrm{det} \left ( \frac{\partial}{\partial\mathbf{x}} G(\mathbf{x}) \right )^{-1}
\end{align}
$$&lt;p&gt;
as required.&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;See description of
argument for further details.&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>Contributed Talk: Cycle-Consistent Adversarial Learning as Approximate Bayesian Inference</title><link>https://tiao.io/events/icml2018-tagdm/</link><pubDate>Sat, 14 Jul 2018 15:20:00 +0000</pubDate><guid>https://tiao.io/events/icml2018-tagdm/</guid><description/></item><item><title>Cycle-Consistent Adversarial Learning as Approximate Bayesian Inference</title><link>https://tiao.io/publications/cycle-bayes/</link><pubDate>Sun, 01 Jul 2018 00:00:00 +0000</pubDate><guid>https://tiao.io/publications/cycle-bayes/</guid><description/></item><item><title>A Tutorial on Variational Autoencoders with a Concise Keras Implementation</title><link>https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/</link><pubDate>Wed, 20 Apr 2016 00:00:00 +0000</pubDate><guid>https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/</guid><description>&lt;p&gt;
Keras is awesome. It is a very well-designed library that clearly abides by
its guiding principles of modularity and extensibility, enabling us to
easily assemble powerful, complex models from primitive building blocks.
This has been demonstrated in numerous blog posts and tutorials, in particular,
the excellent tutorial on &lt;em&gt;Building Autoencoders in Keras&lt;/em&gt;.
As the name suggests, that tutorial provides examples of how to implement
various kinds of autoencoders in Keras, including the variational autoencoder
(VAE)&lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;Like all autoencoders, the variational autoencoder is primarily used for
unsupervised learning of hidden representations.
However, they are fundamentally different to your usual neural network-based
autoencoder in that they approach the problem from a probabilistic perspective.
They specify a joint distribution over the observed and latent variables, and
approximate the intractable posterior conditional density over latent
variables with variational inference, using an &lt;em&gt;inference network&lt;/em&gt;
&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt; (or more classically, a &lt;em&gt;recognition model&lt;/em&gt;
&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;) to amortize the cost of inference.&lt;/p&gt;
&lt;p&gt;While the examples in the aforementioned tutorial do well to showcase the
versatility of Keras on a wide range of autoencoder model architectures,
its implementation of the variational autoencoder doesn&amp;rsquo;t properly take
advantage of Keras&amp;rsquo; modular design, making it difficult to generalize and
extend in important ways. As we will see, it relies on implementing custom
layers and constructs that are restricted to a specific instance of
variational autoencoders. This is a shame because when combined, Keras&amp;rsquo;
building blocks are powerful enough to encapsulate most variants of the
variational autoencoder and more generally, recognition-generative model
combinations for which the generative model belongs to a large family of
&lt;em&gt;deep latent Gaussian models&lt;/em&gt; (DLGMs)&lt;sup id="fnref:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;The goal of this post is to propose a clean and elegant alternative
implementation that takes better advantage of Keras&amp;rsquo; modular design.
It is not intended as a tutorial on variational autoencoders &lt;sup id="fnref:6"&gt;&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref"&gt;6&lt;/a&gt;&lt;/sup&gt;.
Rather, we study variational autoencoders as a special case of variational
inference in deep latent Gaussian models using inference networks, and
demonstrate how we can use Keras to implement them in a modular fashion such
that they can be easily adapted to approximate inference in tasks beyond
unsupervised learning, and with complicated (non-Gaussian) likelihoods.&lt;/p&gt;
&lt;p&gt;This first post will lay the groundwork for a series of future posts that
explore ways to extend this basic modular framework to implement the
cutting-edge methods proposed in the latest research, such as normalizing
flows for building richer posterior approximations &lt;sup id="fnref:7"&gt;&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref"&gt;7&lt;/a&gt;&lt;/sup&gt;, importance
weighted autoencoders &lt;sup id="fnref:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;, the Gumbel-softmax trick for inference in
discrete latent variables &lt;sup id="fnref:9"&gt;&lt;a href="#fn:9" class="footnote-ref" role="doc-noteref"&gt;9&lt;/a&gt;&lt;/sup&gt;, and even the most recent GAN-based
density-ratio estimation techniques for likelihood-free inference
&lt;sup id="fnref:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:11"&gt;&lt;a href="#fn:11" class="footnote-ref" role="doc-noteref"&gt;11&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;h1 id="model-specification"&gt;Model specification&lt;/h1&gt;
&lt;p&gt;First, it is important to understand that the variational autoencoder
is not, in itself, a generative model.
Rather, the generative model is a component of the variational autoencoder and
is, in general, a deep latent Gaussian model.
In particular, let $\mathbf{x}$ be a local observed variable and
$\mathbf{z}$ its corresponding local latent variable, with joint
distribution&lt;/p&gt;
$$
p_{\theta}(\mathbf{x}, \mathbf{z})
= p_{\theta}(\mathbf{x} | \mathbf{z}) p(\mathbf{z}).
$$&lt;p&gt;In Bayesian modelling, we assume the distribution of observed variables to be
governed by the latent variables. Latent variables are drawn from a prior
density $p(\mathbf{z})$ and related to the observations through the
likelihood $p_{\theta}(\mathbf{x} | \mathbf{z})$.
Deep latent Gaussian models (DLGMs) are a general class of models where the
observed variable is governed by a &lt;em&gt;hierarchy&lt;/em&gt; of latent variables, and the
latent variables at each level of the hierarchy are Gaussian &lt;em&gt;a priori&lt;/em&gt;
&lt;sup id="fnref1:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;In a typical instance of the variational autoencoder, we have only a single
layer of latent variables with a Normal prior distribution,&lt;/p&gt;
$$
p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I}).
$$&lt;p&gt;Now, each local latent variable is related to its corresponding observation
through the likelihood $p_{\theta}(\mathbf{x} | \mathbf{z})$, which can
be viewed as a &lt;em&gt;probabilistic&lt;/em&gt; decoder. Given a hidden lower-dimensional
representation (or &amp;ldquo;code&amp;rdquo;) $\mathbf{z}$, it &amp;ldquo;decodes&amp;rdquo; it into a
&lt;em&gt;distribution&lt;/em&gt; over the observation $\mathbf{x}$.&lt;/p&gt;
&lt;h2 id="decoder"&gt;Decoder&lt;/h2&gt;
&lt;p&gt;In this example, we define $p_{\theta}(\mathbf{x} | \mathbf{z})$ to be a
multivariate Bernoulli whose probabilities are computed from $\mathbf{z}$ using
a fully-connected neural network with a single hidden layer,&lt;/p&gt;
$$
\begin{align*}
p_{\theta}(\mathbf{x} | \mathbf{z})
&amp; = \mathrm{Bern}( \sigma( \mathbf{W}_2 \mathbf{h} + \mathbf{b}_2 ) ), \newline
\mathbf{h}
&amp; = h(\mathbf{W}_1 \mathbf{z} + \mathbf{b}_1),
\end{align*}
$$&lt;p&gt;where $\sigma$ is the logistic sigmoid function, $h$ is some non-linearity, and
the model parameters
$\theta = \{ \mathbf{W}_1, \mathbf{W}_2, \mathbf{b}_1, \mathbf{b}_2 \}$
consist of the weights and biases of this neural network.&lt;/p&gt;
&lt;p&gt;It is straightforward to implement this in Keras with the
&lt;code&gt;Sequential&lt;/code&gt; model API:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;decoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;sigmoid&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;You can view a summary of the model parameters $\theta$ by calling
&lt;code&gt;decoder.summary()&lt;/code&gt;. Additionally, you can produce a high-level diagram of
the network architecture, and optionally the input and output shapes of each
layer using &lt;code&gt;plot_model&lt;/code&gt; from the
&lt;code&gt;keras.utils.vis_utils&lt;/code&gt; module. Although our architecture is about as
simple as it gets, it is included in the figure below as an example of what
the diagrams look like.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Decoder architecture"
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/decoder.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
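&lt;p&gt;To make the decoder equations concrete, the following is a minimal sketch of the same forward pass written in plain Python rather than Keras. The layer sizes, the random initialisation and the helper names (&lt;code&gt;dense&lt;/code&gt;, &lt;code&gt;random_matrix&lt;/code&gt;) are assumptions chosen for illustration; ReLU is used for $h$ to match the Keras snippet above.&lt;/p&gt;

```python
import math
import random


def dense(inputs, weights, biases, activation):
    """One fully-connected layer: activation(W inputs + b), row by row."""
    return [
        activation(sum(w * v for w, v in zip(row, inputs)) + b)
        for row, b in zip(weights, biases)
    ]


def relu(a):
    return max(0.0, a)


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def random_matrix(rows, cols, rng):
    """Small random weights; an arbitrary choice for this sketch."""
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
latent_dim, intermediate_dim, original_dim = 2, 4, 6   # illustrative sizes

W1 = random_matrix(intermediate_dim, latent_dim, rng)
b1 = [0.0] * intermediate_dim
W2 = random_matrix(original_dim, intermediate_dim, rng)
b2 = [0.0] * original_dim

z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]   # a draw from the prior
h = dense(z, W1, b1, relu)                             # hidden layer
probs = dense(h, W2, b2, sigmoid)                      # Bernoulli probabilities
print(probs)
```

&lt;p&gt;The output is a vector of probabilities, one per observed dimension, which is exactly what parameterizes the multivariate Bernoulli likelihood.&lt;/p&gt;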
&lt;p&gt;Note that by fixing $\mathbf{W}_1$, $\mathbf{b}_1$ and $h$ to be the identity
matrix, the zero vector, and the identity function, respectively (or
equivalently dropping the first &lt;code&gt;Dense&lt;/code&gt; layer in the snippet above
altogether), we recover &lt;em&gt;logistic factor analysis&lt;/em&gt;.
With similarly minor modifications, we can recover other members from the
family of DLGMs, which include &lt;em&gt;non-linear factor analysis&lt;/em&gt;,
&lt;em&gt;non-linear Gaussian belief networks&lt;/em&gt;, &lt;em&gt;sigmoid belief networks&lt;/em&gt;, and many
others &lt;sup id="fnref2:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;p&gt;Having specified how the probabilities are computed, we can now define the
negative log likelihood of a Bernoulli $- \log p_{\theta}(\mathbf{x}|\mathbf{z})$, which is in fact equivalent to the
binary cross-entropy loss:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;nll&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34; Negative log likelihood (Bernoulli). &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# keras.losses.binary_crossentropy gives the mean&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# over the last axis. we require the sum&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;binary_crossentropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;As we discuss later, this will not be the loss we ultimately minimize, but will
constitute the data-fitting term of our final loss.&lt;/p&gt;
&lt;p&gt;Note this is a valid definition of a Keras loss,
which is required to compile and optimize a model. It is a symbolic function
that returns a scalar for each data-point in &lt;code&gt;y_true&lt;/code&gt; and &lt;code&gt;y_pred&lt;/code&gt;.
In our example, &lt;code&gt;y_pred&lt;/code&gt; will be the output of our &lt;code&gt;decoder&lt;/code&gt; network, which
are the predicted probabilities, and &lt;code&gt;y_true&lt;/code&gt; will be the true probabilities.&lt;/p&gt;
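&lt;p&gt;The distinction between the sum and the mean over the last axis is easy to check by hand. Below is a minimal plain-Python sketch, with made-up values for a single observation of $D = 4$ binary dimensions, showing that the summed Bernoulli negative log likelihood is $D$ times the mean binary cross-entropy. The function names are hypothetical and do not correspond to Keras functions.&lt;/p&gt;

```python
import math


def bernoulli_nll(x, probs):
    """Summed negative log likelihood of independent Bernoulli variables."""
    return -sum(
        xi * math.log(pi) + (1.0 - xi) * math.log(1.0 - pi)
        for xi, pi in zip(x, probs)
    )


def mean_binary_crossentropy(x, probs):
    """Binary cross-entropy averaged over the dimensions of one observation."""
    return bernoulli_nll(x, probs) / len(x)


x = [1.0, 0.0, 1.0, 1.0]         # illustrative observation, D = 4
probs = [0.9, 0.2, 0.6, 0.7]     # illustrative decoder output

nll_value = bernoulli_nll(x, probs)
mean_bce_value = mean_binary_crossentropy(x, probs)
print(nll_value, mean_bce_value, len(x) * mean_bce_value)
```

&lt;p&gt;Using the mean instead of the sum would rescale the data-fitting term by $1/D$, which changes its weight relative to any other term added to the loss.&lt;/p&gt;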
&lt;hr&gt;
&lt;h4 id="side-note-using-tensorflow-distributions-in-loss"&gt;Side note: Using TensorFlow Distributions in loss&lt;/h4&gt;
&lt;p&gt;If you are using the TensorFlow backend, you can directly use the (negative)
log probability of &lt;code&gt;Bernoulli&lt;/code&gt; from TensorFlow Distributions as a Keras
loss, as I demonstrate in a separate post.&lt;/p&gt;
&lt;p&gt;Specifically we can define the loss as,&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;nll&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34; Negative log likelihood (Bernoulli). &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;lh&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;distributions&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Bernoulli&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;probs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;lh&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_true&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This is exactly equivalent to the previous definition, but does not call
&lt;code&gt;K.binary_crossentropy&lt;/code&gt; directly.&lt;/p&gt;
&lt;hr&gt;
&lt;h1 id="inference"&gt;Inference&lt;/h1&gt;
&lt;p&gt;Having specified the generative process, we would now like to perform inference
on the latent variables and model parameters $\mathbf{z}$ and $\theta$,
respectively.
In particular, our goal is to compute the posterior
$p_{\theta}(\mathbf{z} | \mathbf{x})$, the conditional density of the latent
variable $\mathbf{z}$ given observed variable $\mathbf{x}$.
Additionally, we wish to optimize the model parameters $\theta$ with respect to
the marginal likelihood $p_{\theta}(\mathbf{x})$.
Both depend on the marginal likelihood, whose calculation requires marginalizing
out the latent variables $\mathbf{z}$. In general, this is either computationally
intractable, requiring exponential time to compute, or analytically
intractable, in that it cannot be evaluated in closed form. In our case, we suffer from
the latter intractability, since our Gaussian prior is non-conjugate to the
Bernoulli likelihood.&lt;/p&gt;
&lt;p&gt;To circumvent this intractability we turn to &lt;em&gt;variational inference&lt;/em&gt;, which
formulates inference as an optimization problem. It seeks an approximate
posterior $q_{\phi}(\mathbf{z} | \mathbf{x})$ closest in Kullback-Leibler
(KL) divergence to the true posterior. More precisely, the approximate posterior
is parameterized by &lt;em&gt;variational parameters&lt;/em&gt; $\phi$, and we seek a setting
of these parameters that minimizes the aforementioned KL divergence,&lt;/p&gt;
$$
\phi^* = \mathrm{argmin}_{\phi}
\mathrm{KL} [q_{\phi}(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z} | \mathbf{x}) ]
$$&lt;p&gt;With the luck we&amp;rsquo;ve had so far, it shouldn&amp;rsquo;t come as a surprise anymore that
&lt;em&gt;this too&lt;/em&gt; is intractable. It also depends on the log marginal likelihood,
whose intractability is the reason we appealed to approximate inference in the
first place. Instead, we &lt;em&gt;maximize&lt;/em&gt; an alternative objective function, the
&lt;em&gt;evidence lower bound&lt;/em&gt; (ELBO), which is expressed as&lt;/p&gt;
$$
\begin{align*}
\mathrm{ELBO}(q)
&amp; =
\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})} [
\log p_{\theta}(\mathbf{x} | \mathbf{z}) +
\log p(\mathbf{z}) -
\log q_{\phi}(\mathbf{z} | \mathbf{x})
] \newline
&amp; =
\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}
[ \log p_{\theta}(\mathbf{x} | \mathbf{z}) ]
-\mathrm{KL} [ q_{\phi}(\mathbf{z} | \mathbf{x}) || p(\mathbf{z}) ].
\end{align*}
$$&lt;p&gt;Importantly, the ELBO is a lower bound to the log marginal likelihood.
Therefore, maximizing it with respect to the model parameters $\theta$
approximately maximizes the log marginal likelihood.
Additionally, maximizing it with respect to variational parameters $\phi$ can
be shown to minimize
$\mathrm{KL} [q_{\phi}(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z} | \mathbf{x}) ]$.
Also, it turns out that the KL divergence determines the tightness of the lower
bound, where we have equality iff the KL divergence is zero, which happens iff
$q_{\phi}(\mathbf{z} | \mathbf{x}) = p_{\theta}(\mathbf{z} | \mathbf{x})$.
Hence, simultaneously maximizing it with respect to $\theta$ and $\phi$ gets us
two birds with one stone.&lt;/p&gt;
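&lt;p&gt;These claims follow from a standard identity, sketched here in the notation
of this post with the joint density written as
$p_{\theta}(\mathbf{x}, \mathbf{z}) = p_{\theta}(\mathbf{x} | \mathbf{z}) p(\mathbf{z})$.
Since $\log p_{\theta}(\mathbf{x})$ does not depend on $\mathbf{z}$, we can take
its expectation under $q_{\phi}(\mathbf{z} | \mathbf{x})$ and then add and
subtract $\log q_{\phi}(\mathbf{z} | \mathbf{x})$,&lt;/p&gt;
$$
\begin{align*}
\log p_{\theta}(\mathbf{x})
&amp; =
\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})} [
\log p_{\theta}(\mathbf{x}, \mathbf{z}) -
\log p_{\theta}(\mathbf{z} | \mathbf{x})
] \newline
&amp; =
\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})} [
\log p_{\theta}(\mathbf{x}, \mathbf{z}) -
\log q_{\phi}(\mathbf{z} | \mathbf{x})
] +
\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})} [
\log q_{\phi}(\mathbf{z} | \mathbf{x}) -
\log p_{\theta}(\mathbf{z} | \mathbf{x})
] \newline
&amp; =
\mathrm{ELBO}(q) +
\mathrm{KL} [ q_{\phi}(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z} | \mathbf{x}) ].
\end{align*}
$$&lt;p&gt;Because the KL divergence is non-negative, the ELBO cannot exceed
$\log p_{\theta}(\mathbf{x})$. And because the left-hand side does not depend on
$\phi$, any increase in the ELBO with respect to $\phi$ must be matched by an
equal decrease in the KL divergence to the true posterior.&lt;/p&gt;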
&lt;p&gt;Next we discuss the form of the approximate posterior
$q_{\phi}(\mathbf{z} | \mathbf{x})$, which can be viewed as a
&lt;em&gt;probabilistic&lt;/em&gt; encoder. Its role is opposite to that of the decoder.
Given an observation $\mathbf{x}$, it &amp;ldquo;encodes&amp;rdquo; it into a &lt;em&gt;distribution&lt;/em&gt;
over its hidden lower-dimensional representations.&lt;/p&gt;
&lt;h2 id="encoder"&gt;Encoder&lt;/h2&gt;
&lt;p&gt;For each local observed variable $\mathbf{x}_n$, we wish to approximate
the true posterior distribution $p(\mathbf{z}_n|\mathbf{x}_n)$ over its
corresponding local latent variables $\mathbf{z}_n$. A common approach is to
approximate it using a &lt;em&gt;variational distribution&lt;/em&gt;
$q_{\lambda_n}(\mathbf{z}_n)$, specified as a diagonal
Gaussian, where the &lt;em&gt;local&lt;/em&gt; variational parameters
$\lambda_n = \{ \boldsymbol{\mu}_n, \boldsymbol{\sigma}_n \}$ are the mean and
standard deviation of this approximating distribution,
&lt;/p&gt;
$$
q_{\lambda_n}(\mathbf{z}_n) =
\mathcal{N}(
\mathbf{z}_n |
\boldsymbol{\mu}_n,
\mathrm{diag}(\boldsymbol{\sigma}_n^2)
).
$$&lt;p&gt;
This approach has a number of shortcomings. First, the number of local
variational parameters we need to optimize grows with the size of the dataset.
Second, a new set of local variational parameters needs to be optimized for new
unseen test points. This is not to mention the strong factorization assumption
we make by specifying diagonal Gaussian distributions as the family of
approximations. The last is still an active area of research, and the first
two can be addressed by introducing a further approximation using an inference
network.&lt;/p&gt;
&lt;h3 id="inference-network"&gt;Inference network&lt;/h3&gt;
&lt;p&gt;We &lt;em&gt;amortize&lt;/em&gt; the cost of inference by introducing an &lt;em&gt;inference network&lt;/em&gt; which
approximates the local variational parameters $\lambda_n$ for a given local
observed variable $\mathbf{x}_n$.
For our approximating distribution in particular, given $\mathbf{x}_n$ the
inference network yields two vector-valued outputs $\boldsymbol{\mu}_{\phi}(\mathbf{x}_n)$ and
$\boldsymbol{\sigma}_{\phi}(\mathbf{x}_n)$, which we use to approximate its local
variational parameters $\boldsymbol{\mu}_n$ and $\boldsymbol{\sigma}_n$, respectively.
Our approximate posterior distribution now becomes&lt;/p&gt;
$$
q_{\phi}(\mathbf{z}_n | \mathbf{x}_n) =
\mathcal{N}(
\mathbf{z}_n |
\boldsymbol{\mu}_{\phi}(\mathbf{x}_n),
\mathrm{diag}(\boldsymbol{\sigma}_{\phi}^2(\mathbf{x}_n))
).
$$&lt;p&gt;
Instead of learning &lt;em&gt;local&lt;/em&gt; variational parameters $\lambda_n$ for each data-point,
we now learn a fixed number of &lt;em&gt;global&lt;/em&gt; variational parameters $\phi$ which
constitute the parameters (i.e. weights) of the inference network.
Moreover, this approximation allows statistical strength to be shared across
observed data-points and also generalizes to unseen test points.&lt;/p&gt;
&lt;p&gt;We specify the mean $\boldsymbol{\mu}_{\phi}(\mathbf{x})$ and log variance
$\log \boldsymbol{\sigma}_{\phi}^2(\mathbf{x})$ of this distribution as the output of
an inference network. For this post, we keep the architecture of the network
simple, with only a single hidden layer and two fully-connected output layers.
Again, this is simple to define in Keras:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# input layer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# hidden layer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# output layer for mean and log variance&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Since this network has multiple outputs, we couldn&amp;rsquo;t use the Sequential model
API as we did for the decoder. Instead, we will resort to the more powerful
Keras functional API,
which allows us to implement complex models with shared layers, multiple
inputs, multiple outputs, and so on.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Inference network"
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/inference_network.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Note that we output the log variance instead of the standard deviation because
this is not only more convenient to work with, but also helps with numerical
stability. However, we still require the standard deviation later. To recover
it, we simply implement the appropriate transformation and encapsulate it in a
&lt;code&gt;Lambda&lt;/code&gt; layer.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# transform log variance to std dev&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Lambda&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;.5&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;))(&lt;/span&gt;&lt;span class="n"&gt;z_log_var&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Before moving on, we give a few words on nomenclature and context.
In the prelude and title of this section, we characterized the approximate
posterior distribution with an inference network as a probabilistic encoder
(analogously to its counterpart, the probabilistic decoder).
Although this is an accurate interpretation, it is a limited one.
Classically, inference networks are known as &lt;em&gt;recognition models&lt;/em&gt;, and have now
been used for decades in a wide variety of probabilistic methods.
When composed end-to-end, the recognition-generative model combination can be
seen as having an autoencoder structure. Indeed, this structure contains the
variational autoencoder as a special case, and also the now less fashionable
&lt;em&gt;Helmholtz machine&lt;/em&gt; &lt;sup id="fnref1:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;.
Even more generally, this recognition-generative model combination constitutes
a widely-applicable approach currently known as &lt;em&gt;amortized variational inference&lt;/em&gt;,
which can be used to perform approximate inference in models that lie beyond
even the large class of deep latent Gaussian models.&lt;/p&gt;
&lt;p&gt;Having specified all the ingredients necessary to carry out variational
inference (namely, the prior, likelihood and approximate posterior), we next
focus on finalizing the definition of the (negative) ELBO as our loss function
in Keras. As written earlier, the ELBO can be decomposed into two terms,
$\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})} [ \log p_{\theta}(\mathbf{x} | \mathbf{z}) ]$
the expected log likelihood (ELL) over $q_{\phi}(\mathbf{z} | \mathbf{x})$,
and $- \mathrm{KL} [q_{\phi}(\mathbf{z} | \mathbf{x}) || p(\mathbf{z}) ]$
the negative KL divergence between prior $p(\mathbf{z})$ and approximate
posterior $q_{\phi}(\mathbf{z} | \mathbf{x})$. We first turn our attention
to the KL divergence term.&lt;/p&gt;
&lt;h3 id="kl-divergence"&gt;KL Divergence&lt;/h3&gt;
&lt;p&gt;Intuitively, maximizing the negative KL divergence term encourages approximate
posterior densities that place their mass on configurations of the latent
variables which are closest to the prior. Effectively, this regularizes the
complexity of latent space. Now, since both the prior $p(\mathbf{z})$ and
approximate posterior $q_{\phi}(\mathbf{z} | \mathbf{x})$ are Gaussian,
the KL divergence can actually be calculated with the closed-form expression,&lt;/p&gt;
$$
\mathrm{KL} [ q_{\phi}(\mathbf{z} | \mathbf{x}) || p(\mathbf{z}) ]
= - \frac{1}{2} \sum_{k=1}^K \{ 1 + \log \sigma_k^2 - \mu_k^2 - \sigma_k^2 \}
$$&lt;p&gt;where $\mu_k$ and $\sigma_k$ are the $k$-th components of output vectors
$\mu_{\phi}(\mathbf{x})$ and $\sigma_{\phi}(\mathbf{x})$, respectively.
This is not too difficult to derive, and I would recommend verifying this as an
exercise. You can also find a derivation in the appendix of Kingma and Welling&amp;rsquo;s
(2014) paper &lt;sup id="fnref1:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
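&lt;p&gt;As a quick sanity check of the closed-form expression, the following NumPy
sketch compares it with a Monte Carlo estimate of the same divergence. It is
independent of the Keras model; the function names and the values of
&lt;code&gt;mu&lt;/code&gt; and &lt;code&gt;log_var&lt;/code&gt; are made up for illustration.&lt;/p&gt;

```python
import numpy as np


def kl_diag_gaussian_to_std_normal(mu, log_var):
    """Closed-form KL[N(mu, diag(exp(log_var))) || N(0, I)], summed over the last axis."""
    return -0.5 * np.sum(1.0 + log_var - np.square(mu) - np.exp(log_var), axis=-1)


def kl_monte_carlo(mu, log_var, num_samples=200_000, seed=0):
    """Monte Carlo estimate of the same KL: the average of log q(z) - log p(z) over z ~ q."""
    rng = np.random.default_rng(seed)
    sigma = np.exp(0.5 * log_var)
    z = mu + sigma * rng.standard_normal((num_samples, mu.shape[-1]))
    # The -0.5 * log(2 * pi) constants cancel between log q and log p.
    log_q = np.sum(-0.5 * log_var - 0.5 * np.square((z - mu) / sigma), axis=-1)
    log_p = np.sum(-0.5 * np.square(z), axis=-1)
    return np.mean(log_q - log_p)


# Illustrative (hypothetical) encoder outputs for one observation, latent_dim = 2.
mu = np.array([0.5, -1.0])
log_var = np.array([0.0, -0.7])

kl_exact = kl_diag_gaussian_to_std_normal(mu, log_var)
kl_estimate = kl_monte_carlo(mu, log_var)
print(kl_exact, kl_estimate)
```

&lt;p&gt;The two printed values should agree up to Monte Carlo error, and the
closed-form value is zero when &lt;code&gt;mu&lt;/code&gt; and &lt;code&gt;log_var&lt;/code&gt; are both
zero, that is, when the approximate posterior equals the prior.&lt;/p&gt;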
&lt;p&gt;Recall that earlier, we defined the expected log likelihood term of the ELBO as
a Keras loss. We were able to do this since the log likelihood is a function of
the network&amp;rsquo;s final output (the predicted probabilities), so it maps nicely to a
Keras loss. Unfortunately, the same does not apply for the KL divergence term,
which is a function of the network&amp;rsquo;s intermediate layer outputs, the mean &lt;code&gt;mu&lt;/code&gt;
and log variance &lt;code&gt;log_var&lt;/code&gt;.&lt;/p&gt;
&lt;p&gt;We define an auxiliary custom Keras layer
which takes &lt;code&gt;mu&lt;/code&gt; and &lt;code&gt;log_var&lt;/code&gt; as input and simply returns them as output
without modification. We do, however, explicitly introduce the
side effect of
calculating the KL divergence and adding it to a collection of losses, by
calling the method &lt;code&gt;add_loss&lt;/code&gt; &lt;sup id="fnref:12"&gt;&lt;a href="#fn:12" class="footnote-ref" role="doc-noteref"&gt;12&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Layer&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34; Identity transform layer that adds KL divergence
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; to the final model loss.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_placeholder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;call&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;kl_batch&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;log_var&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;square&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_var&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kl_batch&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Next we feed &lt;code&gt;z_mu&lt;/code&gt; and &lt;code&gt;z_log_var&lt;/code&gt; through this layer (this needs to take
place before feeding &lt;code&gt;z_log_var&lt;/code&gt; through the Lambda layer to recover &lt;code&gt;z_sigma&lt;/code&gt;).&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_log_var&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Now when the Keras model is finally compiled, the collection of losses will be
aggregated and added to the specified Keras loss function to form the loss we
ultimately minimize. If we specify the loss as the negative log-likelihood we
defined earlier (&lt;code&gt;nll&lt;/code&gt;), we recover the negative ELBO as the final loss we
minimize, as intended.&lt;/p&gt;
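&lt;p&gt;To make the composition of the final loss concrete, here is a minimal NumPy
sketch of the same arithmetic outside of Keras. It assumes the reconstruction
term is a Bernoulli negative log likelihood summed over the data dimensions and
that both terms are averaged over the batch; the helper names and the toy
numbers are hypothetical.&lt;/p&gt;

```python
import numpy as np


def bernoulli_nll(x, x_prob, eps=1e-7):
    """Negative Bernoulli log likelihood, summed over the data dimensions."""
    x_prob = np.clip(x_prob, eps, 1.0 - eps)
    return -np.sum(x * np.log(x_prob) + (1.0 - x) * np.log(1.0 - x_prob), axis=-1)


def kl_to_std_normal(mu, log_var):
    """Closed-form KL divergence from a diagonal Gaussian to N(0, I)."""
    return -0.5 * np.sum(1.0 + log_var - np.square(mu) - np.exp(log_var), axis=-1)


def negative_elbo(x, x_prob, mu, log_var):
    """Batch average of the reconstruction term plus the KL term."""
    return np.mean(bernoulli_nll(x, x_prob) + kl_to_std_normal(mu, log_var))


# Toy batch: 2 binary observations of dimension 3, latent dimension 2.
x = np.array([[1.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
x_prob = np.array([[0.9, 0.2, 0.8], [0.1, 0.3, 0.7]])  # made-up decoder outputs
mu = np.zeros((2, 2))
log_var = np.zeros((2, 2))

loss = negative_elbo(x, x_prob, mu, log_var)
print(loss)
```

&lt;p&gt;With &lt;code&gt;mu&lt;/code&gt; and &lt;code&gt;log_var&lt;/code&gt; at zero the KL term vanishes, so the
loss reduces to the average negative log likelihood; moving the approximate
posterior away from the prior adds a positive penalty.&lt;/p&gt;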
&lt;hr&gt;
&lt;h4 id="side-note-alternative-divergences"&gt;Side note: Alternative divergences&lt;/h4&gt;
&lt;p&gt;A key benefit of encapsulating the divergence in an auxiliary layer is that we
can easily implement and swap in other divergences, such as the
$\chi$-divergence or the $\alpha$-divergence.
Using alternative divergences for variational inference is an active research
topic &lt;sup id="fnref:13"&gt;&lt;a href="#fn:13" class="footnote-ref" role="doc-noteref"&gt;13&lt;/a&gt;&lt;/sup&gt; &lt;sup id="fnref:14"&gt;&lt;a href="#fn:14" class="footnote-ref" role="doc-noteref"&gt;14&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;hr&gt;
&lt;h4 id="side-note-implicit-models-and-adversarial-learning"&gt;Side note: Implicit models and adversarial learning&lt;/h4&gt;
&lt;p&gt;Additionally, we could also extend the divergence layer to use an auxiliary
density ratio estimator function, instead of evaluating the KL divergence in
the analytical form above.
This relaxes the requirement on approximate posterior
$q_{\phi}(\mathbf{z}|\mathbf{x})$ (and incidentally also prior $p(\mathbf{z})$)
to yield tractable densities, at the cost of maximizing a cruder estimate of the
ELBO.
This is known as Adversarial Variational Bayes&lt;sup id="fnref1:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt;, and is an
important line of recent research that, when taken to its logical conclusion,
can extend the applicability of variational inference to arbitrarily expressive
implicit probabilistic models with intractable likelihoods&lt;sup id="fnref1:11"&gt;&lt;a href="#fn:11" class="footnote-ref" role="doc-noteref"&gt;11&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;hr&gt;
&lt;h3 id="reparameterization-using-merge-layers"&gt;Reparameterization using Merge Layers&lt;/h3&gt;
&lt;p&gt;To perform gradient-based optimization of the ELBO with respect to model parameters
$\theta$ and variational parameters $\phi$, we require its gradients with
respect to these parameters, which are generally intractable.
Currently, the dominant approach for circumventing this is Monte Carlo (MC)
estimation of the gradients. The basic idea is to write the gradient of the
ELBO as the expectation of a gradient, approximate it with MC estimates, then
perform stochastic gradient descent with the repeated MC gradient estimates.&lt;/p&gt;
&lt;p&gt;There exist a number of estimators based on different variance reduction
techniques. However, MC gradient estimates based on the reparameterization trick,
known as the &lt;em&gt;reparameterization gradients&lt;/em&gt;, have been shown to have the lowest
variance among competing estimators for continuous latent variables&lt;sup id="fnref3:5"&gt;&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref"&gt;5&lt;/a&gt;&lt;/sup&gt;.
The reparameterization trick is a straightforward change of variables that
expresses the random variable $\mathbf{z} \sim q_{\phi}(\mathbf{z} | \mathbf{x})$
as a deterministic transformation $g_{\phi}$ of another random variable
$\boldsymbol{\epsilon}$ and input $\mathbf{x}$, with parameters $\phi$,&lt;/p&gt;
$$
\mathbf{z} = g_{\phi}(\mathbf{x}, \boldsymbol{\epsilon}), \quad
\boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon}).
$$&lt;p&gt;Note that $p(\boldsymbol{\epsilon})$ is a simpler base distribution which is
parameter-free and independent of $\mathbf{x}$ and $\phi$.
To prevent clutter, we write the ELBO as an expectation of the function
$f(\mathbf{x}, \mathbf{z}) = \log p_{\theta}(\mathbf{x} , \mathbf{z}) -
\log q_{\phi}(\mathbf{z} | \mathbf{x})$ over distribution
$q_{\phi}(\mathbf{z} | \mathbf{x})$.
Now, for any function $f(\mathbf{x}, \mathbf{z})$, taking the gradient of the
expectation with respect to $\phi$, and substituting all occurrences of
$\mathbf{z}$ with $g_{\phi}(\mathbf{x}, \boldsymbol{\epsilon})$, we have&lt;/p&gt;
$$
\begin{align*}
\nabla_{\phi} \mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}
[ f(\mathbf{x}, \mathbf{z}) ]
&amp; = \nabla_{\phi} \mathbb{E}_{p(\boldsymbol{\epsilon})}
[ f(\mathbf{x}, g_{\phi}(\mathbf{x}, \boldsymbol{\epsilon})) ] \newline
&amp; = \mathbb{E}_{p(\boldsymbol{\epsilon})}
[ \nabla_{\phi} f(\mathbf{x}, g_{\phi}(\mathbf{x}, \boldsymbol{\epsilon})) ].
\end{align*}
$$&lt;p&gt;In other words, this simple reparameterization allows the gradient and the
expectation to commute, thereby allowing us to compute unbiased stochastic
estimates of the ELBO gradients by drawing noise samples $\boldsymbol{\epsilon}$
from $p(\boldsymbol{\epsilon})$.&lt;/p&gt;
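&lt;p&gt;The identity can be checked numerically on a toy problem. The NumPy sketch
below is an illustration only: it uses a scalar Gaussian with hypothetical
parameters &lt;code&gt;mu&lt;/code&gt; and &lt;code&gt;sigma&lt;/code&gt; and the test function
$f(z) = z^2$ rather than the ELBO integrand, because
$\mathbb{E}[z^2] = \mu^2 + \sigma^2$ gives exact gradients to compare against.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)

mu, sigma = 1.5, 0.8                 # variational parameters phi = (mu, sigma)
eps = rng.standard_normal(500_000)   # noise from the parameter-free base distribution

z = mu + sigma * eps                 # z = g_phi(eps)

# f(z) = z**2, so df/dz = 2 * z. By the chain rule,
# d f(g_phi(eps)) / d mu = 2 * z and d f(g_phi(eps)) / d sigma = 2 * z * eps.
grad_mu_estimate = np.mean(2.0 * z)
grad_sigma_estimate = np.mean(2.0 * z * eps)

# Exact gradients of E[z**2] = mu**2 + sigma**2.
grad_mu_exact = 2.0 * mu
grad_sigma_exact = 2.0 * sigma

print(grad_mu_estimate, grad_mu_exact)
print(grad_sigma_estimate, grad_sigma_exact)
```

&lt;p&gt;Averaging the per-sample gradients over noise draws recovers the exact
gradients up to Monte Carlo error, which is what moving the gradient inside the
expectation buys us.&lt;/p&gt;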
&lt;hr&gt;
&lt;p&gt;To recover the diagonal Gaussian approximation we specified earlier
$q_{\phi}(\mathbf{z}_n | \mathbf{x}_n) = \mathcal{N}(\mathbf{z}_n |
\boldsymbol{\mu}_{\phi}(\mathbf{x}_n), \mathrm{diag}(\boldsymbol{\sigma}_{\phi}^2(\mathbf{x}_n)))$,
we draw noise from the Normal base distribution, and specify a simple
location-scale transformation&lt;/p&gt;
$$
\mathbf{z}
= g_{\phi}(\mathbf{x}, \boldsymbol{\epsilon})
= \mu_{\phi}(\mathbf{x}) +
\sigma_{\phi}(\mathbf{x}) \odot
\boldsymbol{\epsilon}, \quad
\boldsymbol{\epsilon}
\sim \mathcal{N}(\mathbf{0}, \mathbf{I}),
$$&lt;p&gt;where $\mu_{\phi}(\mathbf{x})$ and $\sigma_{\phi}(\mathbf{x})$ are the outputs
of the inference network defined earlier with parameters $\phi$, and $\odot$
denotes the elementwise product. In Keras, we explicitly make the noise vector
an input to the model by defining an Input layer for it. We then implement the
above location-scale transformation using
Keras merge layers, namely &lt;code&gt;Add&lt;/code&gt; and &lt;code&gt;Multiply&lt;/code&gt;.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Multiply&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Add&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_eps&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Reparameterization with simple location-scale transformation using Keras merge layers.
"
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/reparameterization.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
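&lt;p&gt;For intuition, the same location-scale transformation can be written in a
few lines of NumPy. This is a sketch under assumed values: &lt;code&gt;z_mu&lt;/code&gt; and
&lt;code&gt;z_log_var&lt;/code&gt; stand in for the encoder outputs of a single observation
and are not taken from a trained model.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical encoder outputs for a single observation (latent_dim = 2).
z_mu = np.array([1.0, -2.0])
z_log_var = np.array([0.0, np.log(0.25)])

z_sigma = np.exp(0.5 * z_log_var)         # log variance to standard deviation
eps = rng.standard_normal((100_000, 2))   # eps ~ N(0, I)
z = z_mu + z_sigma * eps                  # plays the role of Multiply followed by Add

print(z.mean(axis=0))  # empirical mean, to compare with z_mu
print(z.std(axis=0))   # empirical std dev, to compare with z_sigma
```

&lt;p&gt;The empirical mean and standard deviation of the transformed noise should
match &lt;code&gt;z_mu&lt;/code&gt; and &lt;code&gt;z_sigma&lt;/code&gt; up to sampling error, confirming
that the samples follow the intended diagonal Gaussian.&lt;/p&gt;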
&lt;hr&gt;
&lt;h4 id="side-note-monte-carlo-sample-size"&gt;Side note: Monte Carlo sample size&lt;/h4&gt;
&lt;p&gt;Note that both the inputs for observed variables and noise (&lt;code&gt;x&lt;/code&gt; and &lt;code&gt;eps&lt;/code&gt;) need to be
specified explicitly as inputs to our final model.
Furthermore, the sizes of their first dimensions (i.e. batch size) are required
to be the same.
This corresponds to using exactly one Monte Carlo sample to approximate the
expected log likelihood, drawing a single sample $\mathbf{z}_n$ from
$q_{\phi}(\mathbf{z}_n | \mathbf{x}_n)$ for each data-point $\mathbf{x}_n$ in
the batch. Although you might find an MC sample size of 1 surprisingly small,
it is actually adequate for a sufficiently large batch size (~100) &lt;sup id="fnref2:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;.
In a follow-up post,
I demonstrate how to extend our approach to support larger MC sample sizes using
just a few minor tweaks. This extension is crucial for implementing the
&lt;em&gt;importance weighted autoencoder&lt;/em&gt; &lt;sup id="fnref1:8"&gt;&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref"&gt;8&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;Now, since the noise input is drawn from the Normal distribution, we can avoid
having to feed in values for this input from outside the computation graph
by binding a tensor to this Input layer. Specifically, we bind a tensor created
using &lt;code&gt;K.random_normal&lt;/code&gt; with the required shape,&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random_normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;While &lt;code&gt;eps&lt;/code&gt; still needs to be explicitly specified as an input to compile the
model, values for this input will no longer be expected by methods such as
&lt;code&gt;fit&lt;/code&gt; and &lt;code&gt;predict&lt;/code&gt;. Instead, samples from this distribution will be lazily
generated inside the computation graph when required. See my separate notes on
this topic for more
details.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Encoder architecture."
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/encoder.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;In the official Keras VAE example,
all of this logic is encapsulated in a single
&lt;code&gt;Lambda&lt;/code&gt; layer, which simultaneously draws samples from a hard-coded base
distribution and also performs the location-scale transformation.
In contrast, this approach achieves a good level of
loose coupling and separation of concerns.
By decoupling the random noise vector from the layer&amp;rsquo;s internal logic and
explicitly making it a model input, we emphasize the fact that all sources of
stochasticity emanate from this input. It thereby becomes clear that a random
sample drawn from a particular approximating distribution is obtained by feeding
this source of stochasticity through a number of successive deterministic
transformations.&lt;/p&gt;
&lt;hr&gt;
&lt;h4 id="side-notes-gumbel-softmax-trick-for-discrete-latent-variables"&gt;Side note: Gumbel-softmax trick for discrete latent variables&lt;/h4&gt;
&lt;p&gt;As an example, we could provide samples drawn from the Uniform distribution
as noise input. By applying a number of deterministic transformations that
constitute the &lt;em&gt;Gumbel-softmax reparameterization trick&lt;/em&gt; &lt;sup id="fnref1:9"&gt;&lt;a href="#fn:9" class="footnote-ref" role="doc-noteref"&gt;9&lt;/a&gt;&lt;/sup&gt;, we
are able to obtain (continuously relaxed) samples from the Categorical distribution. This allows us
to perform approximate inference on &lt;em&gt;discrete&lt;/em&gt; latent variables, and can be
implemented in this framework by adding a dozen or so lines of code!&lt;/p&gt;
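&lt;p&gt;A rough sketch of this idea in plain Python, assuming the standard Gumbel-softmax construction: uniform noise is mapped to Gumbel noise through the inverse CDF, added to the log-probabilities, and passed through a tempered softmax. The function name, the temperature value and the example probabilities are hypothetical, and this is not the Keras layer code alluded to above.&lt;/p&gt;

```python
import math
import random


def gumbel_softmax_sample(logits, uniforms, temperature=0.5):
    """Map Uniform(0, 1) noise to a relaxed one-hot sample."""
    # Uniform to Gumbel(0, 1) via the inverse CDF: g = -log(-log(u)).
    gumbels = [-math.log(-math.log(u)) for u in uniforms]
    scores = [(logit + g) / temperature for logit, g in zip(logits, gumbels)]
    # Numerically stable softmax over the perturbed scores.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
probs = [0.2, 0.3, 0.5]
logits = [math.log(p) for p in probs]

counts = [0, 0, 0]
n_samples = 20000
for _ in range(n_samples):
    # Keep u strictly inside (0, 1) so both logarithms are finite.
    uniforms = [min(max(rng.random(), 1e-12), 1.0 - 1e-12) for _ in logits]
    sample = gumbel_softmax_sample(logits, uniforms)
    counts[sample.index(max(sample))] += 1

frequencies = [c / n_samples for c in counts]
print(frequencies)
```

&lt;p&gt;The largest entry of each relaxed sample is distributed according to the Categorical probabilities, so the empirical frequencies should be close to &lt;code&gt;probs&lt;/code&gt;. Lower temperatures push each sample closer to a one-hot vector.&lt;/p&gt;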
&lt;h1 id="putting-it-all-together"&gt;Putting it all together&lt;/h1&gt;
&lt;p&gt;So far, we&amp;rsquo;ve dissected the variational autoencoder into modular components and
discussed the role and implementation of each one at some length.
Now let&amp;rsquo;s compose these components together end-to-end to form the final
autoencoder architecture.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_log_var&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Lambda&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;.5&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;))(&lt;/span&gt;&lt;span class="n"&gt;z_log_var&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random_normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Multiply&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Add&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_eps&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;decoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;sigmoid&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_pred&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;It&amp;rsquo;s surprisingly concise, taking up around 20 lines of code.
The full model architecture is visualized below.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Variational autoencoder architecture."
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/vae_full.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Finally, we specify and compile the model, using the negative log likelihood
&lt;code&gt;nll&lt;/code&gt; defined earlier as the loss.&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vae&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;x_pred&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vae&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;compile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;rmsprop&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;nll&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h1 id="model-fitting"&gt;Model fitting&lt;/h1&gt;
&lt;h2 id="dataset-mnist-digits"&gt;Dataset: MNIST digits&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mnist&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load_data&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_train&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mf"&gt;255.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mf"&gt;255.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="Variational autoencoder architecture for the MNIST digits dataset."
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/vae_full_shapes.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vae&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;validation_data&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="loss-nelbo-convergence"&gt;Loss (NELBO) convergence&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;DataFrame&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hist&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt=""
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/nelbo.svg"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h1 id="model-evaluation"&gt;Model evaluation&lt;/h1&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# display a 2D plot of the digit classes in the latent space&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scatter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z_test&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;z_test&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cmap&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;viridis&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;colorbar&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt=""
srcset="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_latent_space_hu_8bb4eb676623e380.webp 320w, https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_latent_space_hu_64a98c5233df2a9d.webp 480w, https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_latent_space_hu_f97f8af14c434d9d.webp 600w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_latent_space_hu_8bb4eb676623e380.webp"
width="600"
height="500"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# display a 2D manifold of the digits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt; &lt;span class="c1"&gt;# figure with 15x15 digits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;digit_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;28&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# linearly spaced coordinates on the unit square were transformed&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# through the inverse CDF (ppf) of the Gaussian to produce values&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# of the latent variables z, since the prior of the latent space&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# is Gaussian&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ppf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.99&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ppf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.99&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;meshgrid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_pred_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z_grid&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; \
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;digit_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;digit_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;block&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_pred_grid&lt;/span&gt;&lt;span class="p"&gt;))),&lt;/span&gt; &lt;span class="n"&gt;cmap&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt=""
srcset="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_manifold_hu_e23f379b58eda1c7.webp 320w, https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_manifold_hu_e0d7ef0dff27fb2e.webp 480w, https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_manifold_hu_8a0316a94df89cca.webp 500w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://tiao.io/posts/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/result_manifold_hu_e23f379b58eda1c7.webp"
width="500"
height="500"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
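&lt;p&gt;The grid construction in the snippet above can be illustrated without NumPy or SciPy. The sketch below uses the standard library class &lt;code&gt;statistics.NormalDist&lt;/code&gt; as a stand-in for &lt;code&gt;norm.ppf&lt;/code&gt;, with a smaller &lt;code&gt;n&lt;/code&gt; chosen for readability; the variable names are assumptions for illustration.&lt;/p&gt;

```python
from statistics import NormalDist

n = 5  # smaller than the 15 used above, purely for readability
std_normal = NormalDist(mu=0.0, sigma=1.0)

# Evenly spaced probabilities in (0, 1), mirroring np.linspace(0.01, 0.99, n).
quantile_levels = [0.01 + i * (0.99 - 0.01) / (n - 1) for i in range(n)]

# The inverse CDF turns equal steps in probability into latent coordinates.
z_axis = [std_normal.inv_cdf(p) for p in quantile_levels]

# Pairing the axis with itself gives the n x n grid of 2D latent codes,
# laid out like np.dstack(np.meshgrid(z1, z2)).
z_grid = [[(z1, z2) for z1 in z_axis] for z2 in z_axis]

for row in z_grid:
    print(["({:+.2f}, {:+.2f})".format(a, b) for a, b in row])
```

&lt;p&gt;Neighbouring grid values enclose equal prior probability mass, so the grid is denser near the origin, where the Gaussian prior concentrates, and sparser in the tails.&lt;/p&gt;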
&lt;h1 id="recap"&gt;Recap&lt;/h1&gt;
&lt;p&gt;In this post, we covered the basics of amortized variational inference, looking
at variational autoencoders as a specific example. In particular, we&lt;/p&gt;
&lt;ul&gt;
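&lt;li&gt;Implemented the decoder and encoder using the
&lt;code&gt;Sequential&lt;/code&gt; model and the functional &lt;code&gt;Model&lt;/code&gt; API
respectively.&lt;/li&gt;
&lt;li&gt;Augmented the final loss with the KL divergence term by writing an auxiliary
custom layer (&lt;code&gt;KLDivergenceLayer&lt;/code&gt;).&lt;/li&gt;
&lt;li&gt;Worked with the log variance for numerical stability, and used a
&lt;code&gt;Lambda&lt;/code&gt; layer to transform it to the
standard deviation when necessary.&lt;/li&gt;
&lt;li&gt;Explicitly made the noise an Input layer, and implemented the
reparameterization trick using
&lt;code&gt;Multiply&lt;/code&gt; and &lt;code&gt;Add&lt;/code&gt; layers.&lt;/li&gt;
&lt;li&gt;
Wrapped a random tensor in an &lt;code&gt;Input&lt;/code&gt; layer via &lt;code&gt;Input(tensor=...)&lt;/code&gt;,
so random samples are generated &lt;em&gt;within&lt;/em&gt; the computation graph.&lt;/li&gt;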
&lt;/ul&gt;
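&lt;p&gt;For reference, the KL term and the log-variance conversion summarized above can be sketched in plain Python. The sketch assumes a diagonal Gaussian approximate posterior and a standard normal prior, for which the KL divergence has a well-known closed form; the function names are hypothetical and this is not the layer code from this post.&lt;/p&gt;

```python
import math


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ) for one point."""
    return -0.5 * sum(
        1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var)
    )


def sigma_from_log_var(log_var):
    """Standard deviation from log variance: sigma = exp(0.5 * log_var)."""
    return [math.exp(0.5 * lv) for lv in log_var]


# The posterior equals the prior, so the divergence vanishes.
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))

# Shifting the mean or shrinking the variance is penalized.
print(kl_to_standard_normal([1.0, 0.0], [0.0, math.log(0.25)]))
print(sigma_from_log_var([0.0, math.log(0.25)]))
```

&lt;p&gt;Parameterizing the network output as a log variance keeps the variance positive for any real-valued output, and the exponential is applied only where the standard deviation is actually needed.&lt;/p&gt;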
&lt;h1 id="whats-next"&gt;What&amp;rsquo;s next&lt;/h1&gt;
&lt;p&gt;Next, we will extend the divergence layer to use an auxiliary density ratio
estimator function, instead of evaluating the KL divergence in the analytical
form above.
This relaxes the requirement that the approximate posterior
$q_{\phi}(\mathbf{z}|\mathbf{x})$ (and, incidentally, the prior $p(\mathbf{z})$)
yield tractable densities, at the cost of maximizing a cruder estimate of the
ELBO.
This is known as Adversarial Variational Bayes&lt;sup id="fnref2:10"&gt;&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref"&gt;10&lt;/a&gt;&lt;/sup&gt;, and is an
important line of recent research that, when taken to its logical conclusion,
can extend the applicability of variational inference to arbitrarily expressive
implicit probabilistic models with intractable likelihoods&lt;sup id="fnref2:11"&gt;&lt;a href="#fn:11" class="footnote-ref" role="doc-noteref"&gt;11&lt;/a&gt;&lt;/sup&gt;.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;Cite as:&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-gdscript3" data-lang="gdscript3"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="err"&gt;@&lt;/span&gt;&lt;span class="n"&gt;article&lt;/span&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tiao2017vae&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;title&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;{A} {T}utorial on {V}ariational {A}utoencoders with a {C}oncise {K}eras {I}mplementation&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;author&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;Tiao, Louis C&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;journal&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;tiao.io&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;year&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;2017&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;url&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;https://tiao.io/post/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h2 id="links--resources"&gt;Links &amp;amp; Resources&lt;/h2&gt;
&lt;p&gt;Below, you can find:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;The
used to generate the diagrams and plots in this post.&lt;/li&gt;
&lt;li&gt;The above snippets combined in a single executable Python file:&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;plt&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;scipy.stats&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;keras&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;backend&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;keras.layers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Lambda&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Layer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Add&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Multiply&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;keras.models&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Sequential&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;keras.datasets&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;mnist&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;784&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;256&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;epochs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;50&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;epsilon_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;nll&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34; Negative log likelihood (Bernoulli). &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="c1"&gt;# keras.losses.binary_crossentropy gives the mean&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="c1"&gt;# over the last axis. we require the sum&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;binary_crossentropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Layer&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34; Identity transform layer that adds KL divergence
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt;    to the final model loss.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt;    &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_placeholder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;call&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;kl_batch&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;log_var&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                                &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;square&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                                &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_var&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kl_batch&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;decoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;    &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;sigmoid&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;intermediate_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;relu&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_log_var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;KLDivergenceLayer&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_log_var&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Lambda&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;.5&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;))(&lt;/span&gt;&lt;span class="n"&gt;z_log_var&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Input&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random_normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stddev&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;epsilon_std&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                                   &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;latent_dim&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Multiply&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Add&lt;/span&gt;&lt;span class="p"&gt;()([&lt;/span&gt;&lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_eps&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_pred&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vae&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;x_pred&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vae&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;compile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;rmsprop&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;nll&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# train the VAE on MNIST digits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mnist&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load_data&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_train&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mf"&gt;255.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;original_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mf"&gt;255.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vae&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;x_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;        &lt;span class="n"&gt;validation_data&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;z_mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# display a 2D plot of the digit classes in the latent space&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scatter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z_test&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;z_test&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;            &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cmap&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;viridis&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;colorbar&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# display a 2D manifold of the digits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt; &lt;span class="c1"&gt;# figure with 15x15 digits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;digit_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;28&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# linearly spaced coordinates on the unit square are transformed&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# through the inverse CDF (ppf) of the Gaussian to produce values&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# of the latent variables z, since the prior of the latent space&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# is Gaussian&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;u_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;meshgrid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.05&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;                               &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.05&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;z_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ppf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;u_grid&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_decoded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;z_grid&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;x_decoded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x_decoded&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;digit_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;digit_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;block&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x_decoded&lt;/span&gt;&lt;span class="p"&gt;))),&lt;/span&gt; &lt;span class="n"&gt;cmap&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;D. P. Kingma and M. Welling, &amp;ldquo;Auto-Encoding Variational Bayes,&amp;rdquo; in Proceedings of the 2nd International Conference on Learning Representations (ICLR), 2014.&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref2:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;
&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;Section &amp;ldquo;Recognition models and amortised inference&amp;rdquo; in
&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;Dayan, P., Hinton, G. E., Neal, R. M., &amp;amp; Zemel, R. S. (1995). The Helmholtz machine. Neural Computation, 7(5), 889–904.
&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:5"&gt;
&lt;p&gt;D. J. Rezende, S. Mohamed, and D. Wierstra, &amp;ldquo;Stochastic Backpropagation and Approximate Inference in Deep Generative Models,&amp;rdquo; in Proceedings of the 31st International Conference on Machine Learning, Beijing, China, 2014, vol. 32, pp. 1278–1286.
&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref2:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref3:5" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:6"&gt;
&lt;p&gt;For a complete treatment of variational autoencoders, and variational
inference in general, I highly recommend:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Jaan Altosaar&amp;rsquo;s blog post,
&lt;/li&gt;
&lt;li&gt;Diederik P. Kingma&amp;rsquo;s PhD Thesis,
.&lt;/li&gt;
&lt;/ul&gt;
&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/li&gt;
&lt;li id="fn:7"&gt;
&lt;p&gt;D. Rezende and S. Mohamed, &amp;ldquo;Variational Inference with Normalizing Flows,&amp;rdquo; in Proceedings of the 32nd International Conference on Machine Learning, 2015, vol. 37, pp. 1530–1538.&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:8"&gt;
&lt;p&gt;Y. Burda, R. Grosse, and R. Salakhutdinov, &amp;ldquo;Importance Weighted Autoencoders,&amp;rdquo; in Proceedings of the 4th International Conference on Learning Representations (ICLR), 2016.&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:8" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:9"&gt;
&lt;p&gt;E. Jang, S. Gu, and B. Poole, &amp;ldquo;Categorical Reparameterization with Gumbel-Softmax,&amp;rdquo; in Proceedings of the 5th International Conference on Learning Representations (ICLR), 2017.&amp;#160;&lt;a href="#fnref:9" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:9" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:10"&gt;
&lt;p&gt;L. Mescheder, S. Nowozin, and A. Geiger, &amp;ldquo;Adversarial Variational Bayes: Unifying Variational Autoencoders and Generative Adversarial Networks,&amp;rdquo; in Proceedings of the 34th International Conference on Machine Learning, 2017, vol. 70, pp. 2391–2400.&amp;#160;&lt;a href="#fnref:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref2:10" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:11"&gt;
&lt;p&gt;D. Tran, R. Ranganath, and D. Blei, &amp;ldquo;Hierarchical Implicit Models and Likelihood-Free Variational Inference,&amp;rdquo; in Advances in Neural Information Processing Systems 30, 2017.&amp;#160;&lt;a href="#fnref:11" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref1:11" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&amp;#160;&lt;a href="#fnref2:11" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:12"&gt;
&lt;p&gt;To support sample weighting (fine-tuning how much each data-point
contributes to the loss), Keras losses are expected to return a scalar for each
data-point in the batch. In contrast, losses appended with the &lt;code&gt;add_loss&lt;/code&gt;
method don&amp;rsquo;t support this, and are expected to be a single scalar.
Hence, we calculate the KL divergence for all data-points in the batch and
take the mean before passing it to &lt;code&gt;add_loss&lt;/code&gt;.&amp;#160;&lt;a href="#fnref:12" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:13"&gt;
&lt;p&gt;Y. Li and R. E. Turner, &amp;ldquo;Rényi Divergence Variational Inference,&amp;rdquo; in Advances in Neural Information Processing Systems 29, 2016.&amp;#160;&lt;a href="#fnref:13" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:14"&gt;
&lt;p&gt;A. B. Dieng, D. Tran, R. Ranganath, J. Paisley, and D. Blei, &amp;ldquo;Variational Inference via chi Upper Bound Minimization,&amp;rdquo; in Advances in Neural Information Processing Systems 30, 2017.&amp;#160;&lt;a href="#fnref:14" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item></channel></rss>