Efficient Cholesky decomposition of low-rank updates
A short and practical guide to efficiently computing the Cholesky decomposition of matrices perturbed by low-rank updates
Suppose we’re given a positive semidefinite (PSD) matrix $A \in \mathbb{R}^{N \times N}$, together with its lower-triangular Cholesky factor $L_A$, and a low-rank update factor $U \in \mathbb{R}^{N \times M}$ with $M \ll N$. What is the best way to calculate the Cholesky decomposition of the updated matrix $B = A + U U^\top$? Given no additional information, the obvious way is to calculate it directly from $B$, which incurs a cost of $O(N^3)$ and makes no use of the pre-computed factor $L_A$.
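Concretely, the setup used throughout this post is the following (the symbols $L_A$ and $L_B$ are my shorthand for what the code below calls a_factor and b_factor):

$$A = L_A L_A^\top, \qquad B = A + U U^\top = L_B L_B^\top, \qquad A \in \mathbb{R}^{N \times N},\ U \in \mathbb{R}^{N \times M},\ M \ll N.$$

The goal is to compute $L_B$ from $L_A$ and $U$ in $O(N^2 M)$ time rather than the $O(N^3)$ of a direct factorization.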
Rank-1 Updates
First, let’s consider the simpler case involving just rank-1 updates, i.e. $B = A + u u^\top$ for a single vector $u \in \mathbb{R}^N$.
In TFP, this is implemented in the function tfp.math.cholesky_update. For example:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

update_factor_vector = ...  # Tensor; shape [..., N]
a = ...                     # PSD Tensor; shape [..., N, N]

# [..., N, 1] @ [..., 1, N] -> [..., N, N]
update = tf.linalg.matmul(
    update_factor_vector[..., tf.newaxis],
    update_factor_vector[..., tf.newaxis],
    transpose_b=True,
)
b = a + update  # Tensor; shape [..., N, N]

a_factor = tf.linalg.cholesky(a)  # O(N^3); suppose this is pre-computed and stored
b_factor = tf.linalg.cholesky(b)  # O(N^3), ignores `a_factor`
b_factor_1 = tfp.math.cholesky_update(a_factor, update_factor_vector)  # O(N^2), uses `a_factor`

np.testing.assert_array_almost_equal(b_factor, b_factor_1)
Here cholesky_update takes as arguments chol with shape [B1, ..., Bn, N, N] and u with shape [B1, ..., Bn, N], and returns a lower-triangular Cholesky factor of the rank-1 updated matrix chol @ chol.T + u @ u.T in $O(N^2)$ time.
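To see where the $O(N^2)$ cost comes from, here is a minimal, unbatched NumPy sketch of the classic rank-1 update algorithm (see e.g. Seeger, 2004), which sweeps down the diagonal applying a Givens-style rotation per column. The function name is mine and this is only an illustration; TFP's actual implementation may differ.

def cholupdate_sketch(chol, u):
    # Illustrative O(N^2) rank-1 update: returns lower-triangular L such that
    # L @ L.T == chol @ chol.T + np.outer(u, u). NOT TFP's implementation.
    L, u = chol.copy(), u.copy()
    n = u.shape[0]
    for k in range(n):
        r = np.hypot(L[k, k], u[k])          # rotated diagonal entry
        c, s = r / L[k, k], u[k] / L[k, k]   # rotation that zeroes out u[k]
        L[k, k] = r
        # Each sweep only touches the N - k entries below the diagonal,
        # so the total work is O(N^2).
        L[k + 1:, k] = (L[k + 1:, k] + s * u[k + 1:]) / c
        u[k + 1:] = c * u[k + 1:] - s * L[k + 1:, k]
    return L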
Low-Rank Updates
Now let’s return to rank-$M$ updates, where the update factor is a matrix $U \in \mathbb{R}^{N \times M}$ rather than a single vector. The key observation is that the update decomposes into a sum of $M$ rank-1 terms, one per column $u_m$ of $U$:

$$U U^\top = \sum_{m=1}^{M} u_m u_m^\top.$$

Now we can write the rank-$M$ update in either form and check numerically that the two agree:
update_factor_matrix = ...  # Tensor; shape [..., N, M]

# [..., N, 1, M] * [..., 1, N, M] -> [..., N, N, M] -> [..., N, N]
update1 = tf.reduce_sum(update_factor_matrix[..., tf.newaxis, :] *
                        update_factor_matrix[..., tf.newaxis, :, :], axis=-1)

# [..., N, M] @ [..., M, N] -> [..., N, N]
update2 = tf.linalg.matmul(update_factor_matrix,
                           update_factor_matrix, transpose_b=True)

# Not exactly equal due to finite precision, but still equal to high precision.
np.testing.assert_array_almost_equal(update1, update2, decimal=14)
Seen this way, a low-rank update is nothing more than a repeated application of rank-1 updates, so we can simply leverage tfp.math.cholesky_update once per column of the update factor matrix. Hence, we have:
# [..., N, M] @ [..., M, N] -> [..., N, N]
update = tf.linalg.matmul(update_factor_matrix,
                          update_factor_matrix, transpose_b=True)
b = a + update  # Tensor; shape [..., N, N]

b_factor = tf.linalg.cholesky(b)  # O(N^3), ignores `a_factor`
b_factor_1 = cholesky_update_iterated(a_factor, update_factor_matrix)  # O(N^2 M), uses `a_factor`

np.testing.assert_array_almost_equal(b_factor_1, b_factor)
where the function cholesky_update_iterated can be implemented recursively as follows:
def cholesky_update_iterated(chol, update_factor_matrix):
    # Base case: an empty update factor leaves the Cholesky factor unchanged.
    if update_factor_matrix.shape[-1] == 0:
        return chol
    # Recurse on all but the last column, then apply the last column
    # as a rank-1 update.
    prev = cholesky_update_iterated(chol, update_factor_matrix[..., :-1])
    return tfp.math.cholesky_update(prev, update_factor_matrix[..., -1])
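As a quick sanity check, we can verify the recursive implementation on a small self-contained example (the shapes and values here are purely illustrative):

N, M = 5, 3
rng = np.random.default_rng(0)
u = tf.constant(rng.normal(size=(N, M)))  # float64 update factor
a = tf.eye(N, dtype=tf.float64) * 10.0    # a simple PSD matrix
a_factor = tf.linalg.cholesky(a)
b = a + tf.linalg.matmul(u, u, transpose_b=True)
np.testing.assert_array_almost_equal(
    cholesky_update_iterated(a_factor, u),
    tf.linalg.cholesky(b))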
We can also implement this iteratively, which avoids deep Python recursion when $M$ is large. First we’d use tf.unstack to turn the update factor matrix into a list of $M$ update factor vectors:
>>> update_factor_vectors = tf.unstack(update_factor_matrix, axis=-1)
>>> assert isinstance(update_factor_vectors, list) # `update_factor_vectors` is a list
>>> assert len(update_factor_vectors) == M # ... the list contains M vectors
>>> assert update_factor_vectors[0].shape == (*Bs, N) # ... and each vector has shape [B1, ..., Bn, N]
Then, we have:
def cholesky_update_iterated(chol, update_factor_matrix):
    new_chol = chol
    for update_factor_vector in tf.unstack(update_factor_matrix, axis=-1):
        new_chol = tfp.math.cholesky_update(new_chol, update_factor_vector)
    return new_chol
The astute reader will recognize that this is simply a special case of the itertools.accumulate or functools.reduce patterns, where

- the binary operator is tfp.math.cholesky_update,
- the iterable is tf.unstack(update_factor_matrix, axis=-1), and
- the initial value is chol.

Therefore, we can also implement it neatly using the one-liner:
from functools import reduce

def cholesky_update_iterated(chol, update_factor_matrix):
    return reduce(tfp.math.cholesky_update, tf.unstack(update_factor_matrix, axis=-1), chol)
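For completeness, the itertools.accumulate variant of the same pattern yields every intermediate factor along the way rather than just the final one. This is a sketch (the function name is mine, and the initial keyword requires Python 3.8+):

from itertools import accumulate

def cholesky_update_intermediates(chol, update_factor_matrix):
    # Returns [chol, factor after 1 rank-1 update, ..., factor after M updates];
    # the last element equals cholesky_update_iterated(chol, update_factor_matrix).
    return list(accumulate(tf.unstack(update_factor_matrix, axis=-1),
                           tfp.math.cholesky_update,
                           initial=chol))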
Summary
In summary, we showed that to efficiently calculate the Cholesky decomposition of a matrix perturbed by a rank-$M$ update, one just needs to iteratively apply $M$ rank-1 updates, each costing $O(N^2)$, for a total of $O(N^2 M)$ instead of the $O(N^3)$ of a direct factorization. Better yet, all of this can be done with a simple one-liner!
To receive updates on more posts like this, follow me on Twitter and GitHub!
Seeger, M. (2004). Low rank updates for the Cholesky decomposition. ↩︎