Sampling via Moment Sharing:
A New Framework for Distributed Bayesian Inference for Big Data
Yee Whye Teh (Oxford)
in collaboration with:
Minjie Xu, Jun Zhu, Bo Zhang (Tsinghua), Balaji Lakshminarayanan (Gatsby)
SMS: Distributed Bayesian Inference for Big Data Yee Whye Teh
Bayesian Inference
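! Parameter vector X.
! Data items Y = y1, y2, ..., yN.
! Model:
$$ p(X, Y) = p(X) \prod_{i=1}^{N} p(y_i \mid X) $$
[Figure: graphical model in which X is the parent of each data item y1, y2, y3, y4, ..., yN]
! Aim: compute the posterior
$$ p(X \mid Y) = \frac{p(X)\, p(Y \mid X)}{p(Y)} $$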
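As a small numerical illustration of the model and posterior above, the sketch below evaluates p(X | Y) on a grid for a scalar parameter. The prior N(0, 1), the likelihood yi ~ N(X, 1) and the helper names are assumptions made for this example only; the closed-form conjugate posterior mean serves as a check.

```python
import numpy as np

def log_normal_pdf(x, mean, var):
    # Elementwise log density of N(mean, var).
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (x - mean) ** 2 / var

def grid_posterior(grid, y, prior_var=1.0, noise_var=1.0):
    # Unnormalised log p(X) + sum_i log p(y_i | X) at every grid point X.
    log_joint = log_normal_pdf(grid, 0.0, prior_var)
    log_joint = log_joint + log_normal_pdf(grid[:, None], y[None, :], noise_var).sum(axis=1)
    # Dividing by p(Y) amounts to normalising over the grid (Riemann sum).
    dens = np.exp(log_joint - log_joint.max())
    return dens / (dens.sum() * (grid[1] - grid[0]))

rng = np.random.default_rng(0)
y = rng.normal(loc=1.5, scale=1.0, size=50)   # synthetic data, N = 50
grid = np.linspace(-3.0, 5.0, 4001)
dx = grid[1] - grid[0]
post = grid_posterior(grid, y)
post_mean = (grid * post).sum() * dx
# Conjugate closed form for this toy model: E[X | Y] = sum(y) / (N + 1).
exact_mean = y.sum() / (len(y) + 1)
```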
Why Bayes for Machine Learning?
! An important framework for formulating learning problems.
! Quantification of uncertainty.
! Flexible and intuitive construction of complex models.
! Straightforward derivation of learning algorithms.
! Mitigation of overfitting.
Big Data and Bayesian Inference?
! Large scale datasets are fast becoming the norm.
! Analysing and extracting understanding from these data is a driver of progress in many sectors of society.
! Current successes in scalable learning are optimization-based and non-Bayesian.
! What is the role of Bayesian learning in the world of Big Data?
Generic (Bayesian) Learning on Big Data
! Stochastic optimisation using mini-batches.
! Stochastic gradient descent.
> Stochastic Gradient Langevin Dynamics [Welling & Teh 2011, Patterson & Teh 2013, Teh et al (forthcoming)]
! Distributed/parallel computations on cores/clusters/GPUs.
! MapReduce, parameter server.
! Bringing the computations to the data, not the reverse.
! High communication costs.
> Distributed Bayesian Posterior Sampling via Moment Sharing [Xu et al 2014]
! High synchronisation costs.
> Asynchronous Anytime Sequential Monte Carlo [Paige et al 2014]
Machine Learning on Distributed Systems
[Figure: data subsets y1i, y2i, y3i, y4i]
! Distributed storage.
! Distributed computation.
! Network communication costs.
Embarrassingly Parallel MCMC Sampling
[Figure: data subsets y1i, y2i, y3i, y4i]
! Treat each data subset as an independent inference problem.
! Collect samples {Xji}, j = 1…m, i = 1…n, on the workers.
! "Combine" the samples into {Xi}, i = 1…n.
! Communication takes place only at the combination stage.
Local and Global Posteriors
! Each worker machine j has access only to its data subset yj = {yji}, i = 1…I, and targets
$$ p_j(X \mid y_j) \propto p_j(X) \prod_{i=1}^{I} p(y_{ji} \mid X) $$
where pj(X) is a local prior and pj(X | yj) is the local posterior.
! The (target) global posterior is
$$ p(X \mid y) \propto p(X) \prod_{j=1}^{m} p(y_j \mid X) \propto p(X) \prod_{j=1}^{m} \frac{p_j(X \mid y_j)}{p_j(X)} $$
! If the prior factorizes as p(X) = ∏j pj(X), then
$$ p(X \mid y) \propto \prod_{j=1}^{m} p_j(X \mid y_j) $$
! Given a collection of samples {Xji}, i = 1…n, from each pj(· | yj), how do we obtain samples {Xi}, i = 1…n, from p(· | y)?
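The factorised-prior identity above can be checked numerically. This sketch assumes a scalar X, a global prior N(0, 1) split into m local priors pj(X) = N(0, m) (so that their product is proportional to p(X)), and a unit-variance Gaussian likelihood; the helper names are our own. It compares the normalised product of the local posteriors with the global posterior on a grid.

```python
import numpy as np

def log_normal_pdf(x, mean, var):
    # Elementwise log density of N(mean, var).
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (x - mean) ** 2 / var

def normalise(log_density, dx):
    # Turn unnormalised log densities on a grid into a normalised density.
    d = np.exp(log_density - log_density.max())
    return d / (d.sum() * dx)

rng = np.random.default_rng(1)
m = 4                                     # number of workers
y = rng.normal(0.7, 1.0, size=200)
shards = np.array_split(y, m)             # y_j: data subset held by worker j
grid = np.linspace(-2.0, 3.0, 5001)
dx = grid[1] - grid[0]

def log_lik(data):
    # sum_i log p(y_ji | X) at every grid point X
    return log_normal_pdf(grid[:, None], data[None, :], 1.0).sum(axis=1)

# Global posterior: p(X | y) ∝ p(X) prod_j p(y_j | X), with p(X) = N(0, 1).
global_post = normalise(log_normal_pdf(grid, 0.0, 1.0) + log_lik(y), dx)

# Unnormalised log local posteriors with local priors p_j(X) = N(0, m).
local_log_posts = [log_normal_pdf(grid, 0.0, m) + log_lik(s) for s in shards]

# Product of local posteriors (a sum in log space), then renormalised.
product = normalise(np.sum(local_log_posts, axis=0), dx)
max_abs_diff = np.abs(product - global_post).max()
```

If the local priors were not chosen so that their product recovers p(X), the two densities would differ by the ratio of the priors, which is exactly the correction factor in the middle expression above.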
Consensus Monte Carlo
! Each worker machine j collects n samples {Xji}, i = 1…n, from:
$$ p_j(X \mid y_j) \propto p(X)^{1/m} \prod_{i=1}^{I} p(y_{ji} \mid X) $$
! Master machine combines samples by weighted average:
$$ X_i = \Big( \sum_{j=1}^{m} W_j \Big)^{-1} \sum_{j=1}^{m} W_j X_{ji} $$
[Scott et al 2013]
! The combination is exact if the local posteriors are Gaussian.
! The weights Wj are the local posterior precisions.
! If the local posteriors are not Gaussian, the method makes strong assumptions, and it is unclear which local priors and weights would make it work.
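A minimal sketch of the consensus combination step in the Gaussian case, where it is exact. The local means and covariances are made up, and the weights Wj are taken to be the known local precisions; in practice they would have to be estimated, e.g. from the inverse sample covariance of each worker's draws. The function name consensus_combine is hypothetical.

```python
import numpy as np

def consensus_combine(local_samples, weights):
    """Combine worker draws by a weighted average (Consensus Monte Carlo).

    local_samples: list of m arrays of shape (n, d); row i of array j is X_ji.
    weights: list of m (d, d) matrices W_j.
    Returns an (n, d) array whose row i is (sum_j W_j)^{-1} sum_j W_j X_ji.
    """
    total_w = sum(weights)
    weighted = sum(x @ w.T for x, w in zip(local_samples, weights))  # row i: sum_j W_j X_ji
    return np.linalg.solve(total_w, weighted.T).T

rng = np.random.default_rng(2)
n = 200_000
# Hypothetical Gaussian local posteriors N(means[j], covs[j]) for m = 3 workers.
means = [np.array([0.0, 1.0]), np.array([1.0, -0.5]), np.array([0.5, 0.5])]
covs = [np.array([[1.0, 0.3], [0.3, 0.5]]), 0.8 * np.eye(2),
        np.array([[0.6, -0.2], [-0.2, 1.2]])]
samples = [rng.multivariate_normal(mu, S, size=n) for mu, S in zip(means, covs)]
precisions = [np.linalg.inv(S) for S in covs]     # W_j = local posterior precision
combined = consensus_combine(samples, precisions)

# Exact product of the Gaussian local posteriors, for comparison.
exact_cov = np.linalg.inv(sum(precisions))
exact_mean = exact_cov @ sum(P @ mu for P, mu in zip(precisions, means))
```

With Gaussian local posteriors the combined draws reproduce the mean and covariance of the exact product; for skewed or multimodal local posteriors no such guarantee holds.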
Approximating Local Posterior Densities
! [Neiswanger et al 2013] proposed combining estimates of the local posterior densities instead of the samples themselves:
! Parametric: Gaussian approximation.
! Nonparametric: kernel density estimation (KDE) based on samples.
! Semiparametric: product of a parametric Gaussian approximation with a nonparametric KDE correction term.
! Combination: product of (approximate) densities.
! Sampling: resort to Metropolis-within-Gibbs.
! The Weierstrass sampler of [Wang & Dunson 2013] is similar, but uses rejection sampling instead.
[Neiswanger et al 2013, Wang & Dunson 2013]
$$ p(X \mid y) \propto \prod_{j=1}^{m} p_j(X \mid y_j) \approx \prod_{j=1}^{m} \frac{1}{n} \sum_{i=1}^{n} K_{h_j}(X; X_{ji}) $$
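To make the nonparametric estimator concrete, the sketch below evaluates the product of per-worker Gaussian KDEs on a 1-D grid. It is an illustration only: the "local samples" are drawn directly from made-up Gaussians, the bandwidths hj follow Silverman's rule of thumb (our choice), and the helper names are hypothetical. Grid evaluation is only feasible for very low-dimensional X, which is why the cited methods sample from the product instead.

```python
import numpy as np

def gaussian_kde_log(x, samples, h):
    # log((1/n) sum_i K_h(x; X_i)) with a Gaussian kernel, via log-sum-exp.
    z = (x[:, None] - samples[None, :]) / h
    log_k = -0.5 * z**2 - np.log(h * np.sqrt(2 * np.pi))
    mx = log_k.max(axis=1, keepdims=True)
    return (mx + np.log(np.exp(log_k - mx).mean(axis=1, keepdims=True))).ravel()

def silverman_bandwidth(samples):
    # Rule-of-thumb bandwidth for a 1-D Gaussian KDE.
    return 1.06 * samples.std() * len(samples) ** (-1 / 5)

rng = np.random.default_rng(3)
# Hypothetical local posterior samples X_ji from m = 3 workers.
local_means, local_sds = [0.2, 0.5, 0.35], [0.3, 0.25, 0.4]
local_samples = [rng.normal(mu, sd, size=2000) for mu, sd in zip(local_means, local_sds)]

grid = np.linspace(-1.0, 1.5, 1001)
dx = grid[1] - grid[0]
log_prod = sum(gaussian_kde_log(grid, s, silverman_bandwidth(s)) for s in local_samples)
density = np.exp(log_prod - log_prod.max())
density /= density.sum() * dx            # normalised approximation to p(X | y)
approx_mean = (grid * density).sum() * dx

# Exact product of the Gaussian local posteriors, for comparison.
prec = sum(1 / sd**2 for sd in local_sds)
exact_mean = sum(mu / sd**2 for mu, sd in zip(local_means, local_sds)) / prec
```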
Approximating Local Posterior Densities
! The parametric approximation can be quite poor unless the Bernstein-von Mises theorem kicks in.
! The combination step is complex and expensive for the non- and semiparametric estimates.
! KDE suffers from the curse of dimensionality.
! Performs poorly if the local posteriors differ significantly.
Intuition and Desiderata
! Distributed system with independent MCMC sampling.
! Identify regions of high (global) posterior probability mass.
! Each local sampler is based on local data, but should “concentrate on high probability regions”.
! High probability regions found using samples, by allowing for some small amount of communication.
(Not Quite) Embarrassingly Parallel MCMC
[Figure: data subsets y1i, y2i, y3i, y4i]
! Allow some amount of communication to align the worker MCMC samplers.
! "High probability region" defined by low-order moments.
! Align using Expectation Propagation (EP).
! Asynchronous and infrequent updates.
Expectation Propagation
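! If N is large, the likelihood term p(yj | X) of worker j should be well approximated by a Gaussian:
$$ p(y_j \mid X) \approx q_j(X) = \mathcal{N}(X; \mu_j, \Sigma_j) $$
! The parameters are fit iteratively, using a variational approach to minimize a KL divergence [Minka 2001]:
$$ p(X \mid y) \approx p_j(X \mid y) \propto p(y_j \mid X) \underbrace{p(X) \prod_{k \neq j} q_k(X)}_{p_j(X)} $$
$$ q_j^{\mathrm{new}}(\cdot) = \arg\min_{\mathcal{N}(\cdot;\, \mu, \Sigma)} \mathrm{KL}\big( p_j(\cdot \mid y) \,\big\|\, \mathcal{N}(\cdot;\, \mu, \Sigma)\, p_j(\cdot) \big) $$
Here pj(X), the prior times the Gaussian approximations of all other workers, plays the role of worker j's local prior.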
Expectation Propagation
! The update is performed as follows:
! Compute (or estimate) the first two moments µ*, Σ* of pj(X | y).
! Compute µj, Σj so that N(·; µj, Σj) pj(X)/Z has moments µ*, Σ*.
! Computations are done in terms of natural parameters.
! Generalizes to other exponential families.
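The moment-matching step can be sketched in a few lines using natural parameters (precision Λ = Σ⁻¹ and shift η = Σ⁻¹µ), in which multiplying Gaussians becomes adding parameters. The sketch assumes the local prior pj is itself Gaussian, as it is when p(X) is Gaussian and pj(X) = p(X) ∏_{k≠j} qk(X); the function names are hypothetical. The returned site precision need not be positive definite in general, which is one reason to keep the site in natural parameters.

```python
import numpy as np

def to_natural(mean, cov):
    # (mean, cov) -> (eta, Lambda) with Lambda = cov^{-1}, eta = Lambda @ mean.
    prec = np.linalg.inv(cov)
    return prec @ mean, prec

def from_natural(eta, prec):
    # (eta, Lambda) -> (mean, cov).
    cov = np.linalg.inv(prec)
    return cov @ eta, cov

def ep_site_update(tilted_mean, tilted_cov, prior_mean, prior_cov):
    """Site q_j such that N(.; mu_j, Sigma_j) p_j(.) / Z has the tilted moments,
    assuming p_j = N(prior_mean, prior_cov). Returned in natural parameters."""
    eta_star, prec_star = to_natural(tilted_mean, tilted_cov)
    eta_c, prec_c = to_natural(prior_mean, prior_cov)
    # Product of Gaussians adds natural parameters, so the site is the difference.
    return eta_star - eta_c, prec_star - prec_c

# Check: combining the new site with the local prior recovers the target moments.
mu_star = np.array([0.5, -0.2])
Sigma_star = np.array([[0.2, 0.05], [0.05, 0.3]])
mu_c, Sigma_c = np.zeros(2), np.eye(2)
eta_j, prec_j = ep_site_update(mu_star, Sigma_star, mu_c, Sigma_c)
eta_c, prec_c = to_natural(mu_c, Sigma_c)
mean_back, cov_back = from_natural(eta_j + eta_c, prec_j + prec_c)
```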
Expectation Propagation
! The variational parameters are fit iteratively until convergence.
! EP tends to converge very quickly (when it converges at all).
! Damping the updates can help convergence.
! At convergence, all local posteriors agree on their first two moments.
! Generalizes to hierarchical and graphical models [infer.NET, Gelman et al 2014].
[Figure: prior p(X) and data subsets y1i, ..., y4i, with each likelihood term approximated by a Gaussian: p(y1|X) ≈ q1(X), p(y2|X) ≈ q2(X), p(y3|X) ≈ q3(X), p(y4|X) ≈ q4(X)]
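A toy, deterministic run of the EP iteration with damping, illustrating the convergence property listed above. Assumptions (ours, not the slides'): scalar X, a N(0, 10) prior, logistic-regression likelihood terms split over m = 4 workers, tilted moments computed by quadrature on a grid rather than by sampling, sequential updates, and a damping factor of 0.5.

```python
import numpy as np

rng = np.random.default_rng(4)
m, n_per = 4, 100                                  # workers, data items per worker
covariates = rng.normal(size=m * n_per)
labels = (rng.random(m * n_per) < 1 / (1 + np.exp(-1.2 * covariates))).astype(float)
shards = np.array_split(np.arange(m * n_per), m)   # indices of y_j on worker j

grid = np.linspace(-5.0, 5.0, 4001)
prior_eta, prior_prec = 0.0, 1 / 10.0              # N(0, 10) prior, natural parameters

def shard_loglik(idx):
    # sum_i log p(y_ji | X) for logistic regression, at every grid point X.
    signs = 2 * labels[idx] - 1
    s = signs[None, :] * grid[:, None] * covariates[idx][None, :]
    return -np.logaddexp(0.0, -s).sum(axis=1)

logliks = [shard_loglik(idx) for idx in shards]

def tilted_moments(loglik, eta_cav, prec_cav):
    # Mean and variance of p_j(X | y) ∝ p(y_j | X) p_j(X) by quadrature,
    # with Gaussian local prior p_j given by natural parameters (eta_cav, prec_cav).
    logp = loglik + eta_cav * grid - 0.5 * prec_cav * grid**2
    w = np.exp(logp - logp.max())
    w /= w.sum()
    mean = (w * grid).sum()
    return mean, (w * (grid - mean) ** 2).sum()

def local_prior(j):
    # p_j(X) = p(X) prod_{k != j} q_k(X): add natural parameters of the other sites.
    return (prior_eta + site_eta.sum() - site_eta[j],
            prior_prec + site_prec.sum() - site_prec[j])

site_eta, site_prec = np.zeros(m), np.zeros(m)     # sites q_j start flat
damping = 0.5
for _ in range(50):                                # sequential EP sweeps
    for j in range(m):
        eta_cav, prec_cav = local_prior(j)
        mu_star, var_star = tilted_moments(logliks[j], eta_cav, prec_cav)
        # Moment-matched site, then a damped step towards it.
        new_eta, new_prec = mu_star / var_star - eta_cav, 1 / var_star - prec_cav
        site_eta[j] = (1 - damping) * site_eta[j] + damping * new_eta
        site_prec[j] = (1 - damping) * site_prec[j] + damping * new_prec

# At convergence every local (tilted) posterior has the same first two moments.
final = [tilted_moments(logliks[j], *local_prior(j)) for j in range(m)]
means = [mu for mu, _ in final]
variances = [var for _, var in final]
```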
Sampling via Moment Sharing (SMS)
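[Figure: prior p(X) and data subsets y1i, ..., y4i, with each likelihood term approximated by a Gaussian, p(yj|X) ≈ qj(X) for j = 1, ..., 4]
! The KL divergence is minimized by matching the moments of pj(X | y).
! The moments are estimated from MCMC samples.
! All samples from all machines can be treated as approximate samples from the full posterior given all data.
! Only moments are communicated, synchronously or asynchronously.
! On worker j, MCMC samples {Xji} from pj(· | y) give moment estimates (µ*, Σ*), which determine the site parameters (µj, Σj) and hence qj(·):
$$ \{X_{ji}\} \Rightarrow (\mu^*, \Sigma^*) \Rightarrow (\mu_j, \Sigma_j) \Rightarrow q_j(\cdot) $$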
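A toy simulation of the SMS loop, to show where the MCMC samples, the moment estimates and the site parameters enter. It is not the authors' Matlab code: the scalar logistic model, the random-walk Metropolis base sampler (the experiments used NUTS), sample counts, step size, damping and all names are assumptions. The m workers are simulated one after another, but each uses only its own data and the current sites, and all sites are then updated synchronously.

```python
import numpy as np

rng = np.random.default_rng(5)
m, n_per = 4, 100
covariates = rng.normal(size=(m, n_per))          # row j = data held by worker j
labels = (rng.random((m, n_per)) < 1 / (1 + np.exp(-1.2 * covariates))).astype(float)
signs = 2 * labels - 1
prior_eta, prior_prec = 0.0, 1 / 10.0             # N(0, 10) prior, natural parameters

def log_tilted(x, j, eta_cav, prec_cav):
    # log p(y_j | x) + log p_j(x) up to a constant; p_j is worker j's Gaussian local prior.
    return (-np.logaddexp(0.0, -signs[j] * x * covariates[j]).sum()
            + eta_cav * x - 0.5 * prec_cav * x**2)

def metropolis_moments(j, eta_cav, prec_cav, x0, n_samples=3000, step=0.3):
    # Random-walk Metropolis targeting p_j(x | y); returns sample mean and variance.
    x, lp = x0, log_tilted(x0, j, eta_cav, prec_cav)
    draws = np.empty(n_samples)
    for i in range(n_samples):
        prop = x + step * rng.normal()
        lp_prop = log_tilted(prop, j, eta_cav, prec_cav)
        if np.log(rng.random()) < lp_prop - lp:
            x, lp = prop, lp_prop
        draws[i] = x
    kept = draws[n_samples // 5:]                 # discard burn-in
    return kept.mean(), kept.var()

site_eta, site_prec = np.zeros(m), np.zeros(m)
damping = 0.5
for _ in range(10):                               # EP iterations k
    tot_eta, tot_prec = prior_eta + site_eta.sum(), prior_prec + site_prec.sum()
    updates = []
    for j in range(m):                            # would run in parallel, one per worker
        eta_cav, prec_cav = tot_eta - site_eta[j], tot_prec - site_prec[j]
        mu_star, var_star = metropolis_moments(j, eta_cav, prec_cav, x0=tot_eta / tot_prec)
        # Only moments (or the site parameters derived from them) are communicated.
        updates.append((mu_star / var_star - eta_cav, 1 / var_star - prec_cav))
    for j, (new_eta, new_prec) in enumerate(updates):   # synchronous, damped update
        site_eta[j] = (1 - damping) * site_eta[j] + damping * new_eta
        site_prec[j] = (1 - damping) * site_prec[j] + damping * new_prec

post_prec = prior_prec + site_prec.sum()
post_mean = (prior_eta + site_eta.sum()) / post_prec   # Gaussian approximation to p(X | y)

# Reference value by quadrature (possible only because X is scalar here; prior_eta = 0).
grid = np.linspace(-5.0, 5.0, 4001)
s = signs.ravel()[None, :] * grid[:, None] * covariates.ravel()[None, :]
log_post = -np.logaddexp(0.0, -s).sum(axis=1) - 0.5 * prior_prec * grid**2
w = np.exp(log_post - log_post.max())
exact_mean = (w * grid).sum() / w.sum()
```

Because the moment estimates are noisy, the sites keep fluctuating around the EP fixed point rather than converging exactly; damping dampens but does not remove this.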
Bayesian Logistic Regression
! Simulated dataset.
! d=20, # data items N=1000.
! NUTS base sampler.
! # workers m = 4,10,50.
! # MCMC iters T = 1000,1000,10000.
! # EP iters k given as vertical lines.
[Figure: three panels, one per setting of m, each plotted against k × T × N/m (× 10^3)]
Bayesian Logistic Regression
! MSE of the posterior mean, as a function of the total number of iterations.
[Figure: MSE (log scale) vs k × T × m (× 10^5); methods: SMS(s), SMS(a), SCOT, NEIS(p), NEIS(n), WANG]
Bayesian Logistic Regression
! Approximate KL and MSE of the predictive probabilities, as functions of the total number of iterations.
[Figure: two panels, log-scale y-axes vs k × T × m (× 10^5); methods SMS(s), SMS(a), SCOT, NEIS(n), WANG in one panel and SMS(s), SMS(a), SCOT, WANG in the other]
Bayesian Logistic Regression
! Approximate KL as a function of the number of nodes.
[Figure: approximate KL for m = 8, 16, 32, 48, 64; methods: SMS(s,s), SMS(s,e), SMS(a,s), SMS(a,e), SCOT, XING(p)]
Bayesian Logistic Regression
! Approximate KL as a function of the number of iterations per node and of the number of likelihood evaluations.
[Figure: two panels, approximate KL (log scale) vs k × T × N/m (× 10^8) and vs k × T (× 10^4); SMS(s) and SMS(a) for m = 8, 16, 32, 48, 64]
Spike-and-Slab Sparse Regression
! Posterior mean coefficients.
[Figure: two panels, posterior mean coefficients vs k × T × N/m (× 10^3)]
Some Remarks
! Scalable distributed MCMC sampling.
! A bit of communication goes a long way.
! Issue with stochasticity of moment estimates:
! EP theory does not cover stochastic updates.
! It is not clear which stochastic update is best.
! Nor is it clear how to characterise convergence and the quality of the approximation.
! Matlab source: https://github.com/chokkyvista/smssample
Other Approaches to Scalable Bayes
! Median posterior [Minsker et al 2014]:
! Embeds local posteriors into an RKHS, and computes the geometric median.
! Improves robustness to outliers in data.
! Stochastic gradient MCMC approaches:
! Reduce cost of each MCMC step by using data subset.
! A distributed version has been developed.
! [Welling & Teh 2011, Ahn et al 2012, 2014, Teh, Thiery & Vollmer (forthcoming), Bardenet et al 2014]
! Variational approaches:
! Faster convergence, with possibly significant bias.
! Recent works successfully extend these to large scale datasets using stochastic approximation techniques [Hoffman et al 2010, 2013, etc] and to flexible parameterized variational distributions [Mnih & Gregor 2014, Rezende et al 2014, Kingma & Welling 2014].
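For comparison with SMS, a minimal sketch of the simplest stochastic gradient MCMC method, stochastic gradient Langevin dynamics [Welling & Teh 2011]: each step follows a mini-batch estimate of the log-posterior gradient, with the likelihood part rescaled by N/batch, and adds Gaussian noise whose variance equals the step size. The toy model (Gaussian mean with a N(0, 10) prior), the fixed step size (the original analysis uses decreasing step sizes), the batch size and the names are our assumptions.

```python
import numpy as np

rng = np.random.default_rng(6)
N, batch = 1000, 50
data = rng.normal(2.0, 1.0, size=N)          # y_i ~ N(X, 1), true X = 2
prior_var = 10.0

def grad_log_prior(x):
    return -x / prior_var

def grad_log_lik(x, ys):
    # d/dX of sum_i log N(y_i; X, 1) over the mini-batch.
    return (ys - x).sum()

def sgld(n_steps=20_000, eps=1e-5, x0=0.0):
    x, out = x0, np.empty(n_steps)
    for t in range(n_steps):
        ys = data[rng.integers(0, N, size=batch)]            # random mini-batch
        grad = grad_log_prior(x) + (N / batch) * grad_log_lik(x, ys)
        x = x + 0.5 * eps * grad + np.sqrt(eps) * rng.normal()  # Langevin step
        out[t] = x
    return out

samples = sgld()[2000:]                          # drop burn-in
# Conjugate posterior for comparison.
exact_mean = data.sum() / (N + 1 / prior_var)
exact_sd = 1 / np.sqrt(N + 1 / prior_var)
```

Each step touches only `batch` data items instead of all N, at the price of a small, step-size-dependent bias in the stationary distribution.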
Bigger Picture
! The probabilistic modelling/Bayesian inference approach offers a principled and powerful data analysis framework.
! Standard methodologies do not extend easily to Big Data.
! Important to develop generic methodologies allowing these approaches to be applicable on Big Data.
! Bias/variance trade-offs are becoming more important.
! Low bias “exact” methods do not scale as well to Big Data.
Thank you!
Thanks for funding: