A Distributed Method for Fitting Laplacian Regularized Stratified Models
Jonathan Tuck Shane Barratt Stephen Boyd
International Conference on Statistical Optimization and Learning
Beijing Jiaotong University, June 19 2019
Outline
Stratified models
Data models
Regularization graphs
Solution method
Examples
Conclusions
Stratified models 2
Data records
• Data records have form $(z, x, y) \in \mathcal{Z} \times \mathcal{X} \times \mathcal{Y}$
• $z \in \mathcal{Z} = \{1, \ldots, K\}$ are categorical features (we’ll stratify over)
• $x \in \mathcal{X}$ are the other features
• $y \in \mathcal{Y}$ is the label/dependent variable
Stratified model
• Fit a model to data $(z, x, y)$
• Stratified model: fit different model for each value of $z$
• $\theta_k$ is model parameter for $z = k$
• $\theta = (\theta_1, \ldots, \theta_K) \in \Theta \subseteq \mathbf{R}^{Kn}$ parameterizes the stratified model
• Old idea [Kernan 99]
• Example: stratified regression model $\hat y = x^T \theta_z$
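A minimal sketch of the stratified regression prediction $\hat y = x^T \theta_z$, assuming the parameters are stored as a K × n NumPy array and strata are numbered from 0 rather than 1. The function name and the toy data are illustrative only.

```python
import numpy as np

def predict_stratified(theta, z, x):
    """Return x_i^T theta_{z_i} for every record i."""
    # theta[z] selects, for each record, the parameter row of its stratum.
    return np.sum(x * theta[z], axis=1)

theta = np.array([[1.0, 0.0],    # parameters for stratum 0
                  [0.0, 2.0]])   # parameters for stratum 1
z = np.array([0, 1, 1])
x = np.array([[3.0, 1.0], [3.0, 1.0], [0.0, 5.0]])
y_hat = predict_stratified(theta, z, x)
```

The first two records share the same x but get different predictions because they fall in different strata.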
Example
[Figure: plot over $x \in [0, 1]$ with legend entries $z = 1$ and $z = -1$]
$\hat y = \theta_1 + \theta_2 x + \theta_3 z, \quad z \in \{-1, 1\}$
Example
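[Figure: plot over $x \in [0, 1]$ with legend entries $z = 1$ and $z = -1$]
$$\hat y = \begin{cases} \theta_{-1,1} + \theta_{-1,2}\, x, & z = -1 \\ \theta_{1,1} + \theta_{1,2}\, x, & z = 1 \end{cases}$$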
Stratified model
• Stratified model is simple function of $x$ (e.g., linear), arbitrary function of $z$
• Examples: separate models for
  – $z \in \{\text{Male}, \text{Female}\}$
  – $z \in \{\text{Monday}, \ldots, \text{Sunday}\}$
• If $K$ is large, might not have enough training data to fit $\theta_k$
• Extreme case: no training data for some values of $z$
• We’ll add regularization so nearby $\theta_k$’s are close
Stratified model with Laplacian regularization
• Choose $\theta = (\theta_1, \ldots, \theta_K)$ to minimize
$$\sum_{i=1}^N l(\theta_{z_i}, x_i, y_i) + \sum_{k=1}^K r(\theta_k) + \sum_{u,v=1}^K W_{uv} \|\theta_u - \theta_v\|^2$$
• $l$ is loss function, $r$ is (local) regularization
• Last term is Laplacian regularization
• $W_{uv} \geq 0$ are edge weights of graph with node set $\mathcal{Z}$
• Convex problem when $l$, $r$ convex in $\theta$
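The objective above can be evaluated directly. The sketch below assumes a square loss and a sum-of-squares local regularizer with weight `lam` (both illustrative choices, not the only ones the slides allow), a symmetric weight matrix `W`, and strata numbered from 0.

```python
import numpy as np

def laplacian_penalty(theta, W):
    """Sum over all ordered pairs (u, v) of W[u, v] * ||theta_u - theta_v||^2."""
    diff = theta[:, None, :] - theta[None, :, :]       # shape (K, K, n)
    return float(np.sum(W * np.sum(diff ** 2, axis=2)))

def objective(theta, W, z, x, y, lam):
    """Loss + local regularization + Laplacian regularization."""
    resid = np.sum(x * theta[z], axis=1) - y           # x_i^T theta_{z_i} - y_i
    loss = np.sum(resid ** 2)                          # assumed square loss
    local = 0.5 * lam * np.sum(theta ** 2)             # assumed sum-of-squares r
    return float(loss + local + laplacian_penalty(theta, W))

# Path graph on three strata, scalar parameters.
W = np.array([[0.0, 1.0, 0.0],
              [1.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])
theta = np.array([[0.0], [1.0], [3.0]])
penalty = laplacian_penalty(theta, W)
common = laplacian_penalty(np.ones((3, 1)), W)
```

Because the sum runs over ordered pairs, each undirected edge is counted twice. The penalty vanishes when all strata share one parameter vector.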
Stratified model with Laplacian regularization
• Graph encodes prior that nearby values of $z$ should have similar models
• Can be used to capture periodicities, other structure
• Examples:
  – $\theta_{\text{male}}$ and $\theta_{\text{female}}$ should be close
  – $\theta_{\text{jan}}$ should be close to $\theta_{\text{feb}}$ and $\theta_{\text{dec}}$
• Model for each value of $z$ ‘borrows strength’ from its neighbors
• Works even when there’s no data for some values of $z$
• As $W_{uv} \to 0$, get traditional (unregularized) stratified model
• As $W_{uv} \to \infty$, get common model ($\theta_1 = \cdots = \theta_K$)
Related work
• Network lasso [Hallac 15]
• Pliable lasso [Tibshirani 17]
• Varying-coefficient models [Hastie 93, Fan 08]
• Multi-task learning [Caruana 97]
Data models 11
Point estimate: Predict y
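• Regression: $\mathcal{X} = \mathbf{R}^n$, $\mathcal{Y} = \mathbf{R}$
  – $l(\theta, x, y) = p(x^T\theta - y)$, $p$ is penalty function
  – $\hat y = x^T\theta_z$
• Classification: $\mathcal{X} = \mathbf{R}^n$, $\mathcal{Y} = \{-1, 1\}$
  – $l(\theta, x, y) = p(y x^T\theta)$
  – $\hat y = \mathrm{sign}(x^T\theta_z)$
• $M$-class classification: $\mathcal{X} = \mathbf{R}^n$, $\mathcal{Y} = \{1, \ldots, M\}$
  – $l(\theta, x, y) = p_y(x^T\theta)$, $\theta \in \mathbf{R}^{n \times M}$, $p_y$ is penalty function for class $y$
  – $\hat y = \operatorname{argmax}_j\, (x^T\theta_z)_j$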
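A small sketch of two of the point-estimate losses above. The slides leave the penalty function $p$ open. Here the square penalty is assumed for regression and the logistic penalty $p(u) = \log(1 + e^{-u})$ for classification, and all function names are illustrative.

```python
import numpy as np

def square_loss(theta, x, y):
    """Regression loss p(x^T theta - y) with p(u) = u^2 (assumed penalty)."""
    return (x @ theta - y) ** 2

def logistic_loss(theta, x, y):
    """Classification loss p(y x^T theta) with p(u) = log(1 + exp(-u)), y in {-1, 1}."""
    # logaddexp(0, -u) computes log(1 + exp(-u)) without overflow.
    return np.logaddexp(0.0, -y * (x @ theta))

def predict_regression(theta, x):
    return x @ theta

def predict_class(theta, x):
    # Returns 0 in the tie case x^T theta == 0.
    return np.sign(x @ theta)

theta = np.array([1.0, 2.0])
x = np.array([1.0, 1.0])
reg_loss = square_loss(theta, x, 2.0)            # (3 - 2)^2
clf_loss = logistic_loss(np.zeros(2), x, 1.0)    # log(2) at theta = 0
label = predict_class(theta, x)
```

In a stratified model, `theta` in these calls would be the parameter $\theta_z$ of the record's stratum.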
Conditional distribution estimate: Predict prob(y | x, z)
• Multinomial logistic regression: $\mathcal{X} = \mathbf{R}^n$, $\mathcal{Y} = \{1, \ldots, M\}$
• Conditional distribution:
$$\mathbf{prob}(y \mid x, z) = \frac{\exp\, (x^T\theta_z)_y}{\sum_{j=1}^M \exp\, (x^T\theta_z)_j}, \quad y = 1, \ldots, M$$
• Loss function (convex in $\theta$):
$$l(\theta, x, y) = \log\left(\sum_{j=1}^M \exp\, (x^T\theta)_j\right) - (x^T\theta)_y$$
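The loss above is the negative log of the conditional probability of the observed class. A minimal sketch, assuming $\theta$ is stored as an n × M array and classes are numbered from 0. The max-shift inside `log_sum_exp` is a standard numerical safeguard added here, not something the slides specify.

```python
import numpy as np

def log_sum_exp(s):
    """Numerically stable log(sum_j exp(s_j))."""
    m = np.max(s)
    return m + np.log(np.sum(np.exp(s - m)))

def class_probabilities(theta, x):
    """prob(y | x) for y = 0, ..., M-1, with theta of shape (n, M)."""
    s = x @ theta
    return np.exp(s - log_sum_exp(s))

def multinomial_loss(theta, x, y):
    """log(sum_j exp((x^T theta)_j)) - (x^T theta)_y."""
    s = x @ theta
    return log_sum_exp(s) - s[y]

theta = np.array([[1.0, 0.0, -1.0],
                  [0.5, 2.0, 0.0]])
x = np.array([1.0, -1.0])
probs = class_probabilities(theta, x)
loss = multinomial_loss(theta, x, 2)
```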
Distribution estimate: Predict p(y | z)
• Gaussian distribution: $\mathcal{Y} = \mathbf{R}^m$
• Density:
$$p(y \mid z) = (2\pi)^{-m/2} \det(\Sigma)^{-1/2} \exp\left(-\tfrac{1}{2}(y - \mu)^T \Sigma^{-1} (y - \mu)\right)$$
• Use natural parameter $\theta = (S, \nu) = (\Sigma^{-1}, \Sigma^{-1}\mu)$ (so $\Sigma = S^{-1}$, $\mu = S^{-1}\nu$)
• Loss function (convex in $\theta$):
$$l(\theta, y) = -\log\det S + y^T S y - 2 y^T \nu + \nu^T S^{-1} \nu$$
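As a numerical check, the loss above equals twice the negative log-density minus the constant $m \log(2\pi)$. The sketch below verifies this on a small example. The function names and the test values are illustrative.

```python
import numpy as np

def gaussian_loss(S, nu, y):
    """-log det S + y^T S y - 2 y^T nu + nu^T S^{-1} nu, natural parameters (S, nu)."""
    _, logdet = np.linalg.slogdet(S)
    return -logdet + y @ S @ y - 2 * y @ nu + nu @ np.linalg.solve(S, nu)

def gaussian_nll(Sigma, mu, y):
    """Negative log-density of N(mu, Sigma) at y."""
    m = y.size
    _, logdet = np.linalg.slogdet(Sigma)
    r = y - mu
    return 0.5 * (m * np.log(2 * np.pi) + logdet + r @ np.linalg.solve(Sigma, r))

Sigma = np.array([[2.0, 0.5],
                  [0.5, 1.0]])
mu = np.array([1.0, -1.0])
y = np.array([0.5, 0.2])

S = np.linalg.inv(Sigma)      # S = Sigma^{-1}
nu = S @ mu                   # nu = Sigma^{-1} mu
loss = gaussian_loss(S, nu, y)
nll = gaussian_nll(Sigma, mu, y)
```

Dropping the constant and the factor of two does not change the minimizer, and the loss is convex in $(S, \nu)$ whereas the negative log-density is not convex in $(\mu, \Sigma)$.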
Regularization graphs 15
Path graph
[Figure: path graph on nodes 0, 1, 2, 3, 4, 5]
• Models time, distance, ...
• Yields time-varying, distance-varying models
Cycle graph
[Figure: cycle graph on the days of the week: M, Tu, W, Th, F, Sa, Su]
• Yields diurnal, weekly, seasonally-varying models
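Path and cycle graphs are easy to write down as weight matrices $W$. A minimal sketch, assuming a single common edge weight `w` and NumPy arrays for storage (the released software uses networkx graphs instead). The helper names are illustrative.

```python
import numpy as np

def path_graph_weights(K, w=1.0):
    """Symmetric K x K weight matrix with edges k -- k+1."""
    W = np.zeros((K, K))
    idx = np.arange(K - 1)
    W[idx, idx + 1] = w
    W[idx + 1, idx] = w
    return W

def cycle_graph_weights(K, w=1.0):
    """Path graph plus the wrap-around edge (K-1) -- 0; intended for K >= 3."""
    W = path_graph_weights(K, w)
    W[0, K - 1] = w
    W[K - 1, 0] = w
    return W

W_path = path_graph_weights(6)     # nodes 0, ..., 5
W_week = cycle_graph_weights(7)    # Monday, ..., Sunday
```

In the weekly cycle, Sunday (node 6) is a neighbor of Monday (node 0), which is what lets the model capture periodic structure.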
Tree graph
[Figure: tree graph of an organization, with nodes CEO, CTO, COO, SVP1, SVP2, SVP3, VP1, VP2, VP3]
• Yields hierarchical models
Regularization graphs 18
Grid graph
[Figure: 4 × 4 grid graph with nodes (0,0) through (3,3)]
• Yields (2D) space-varying models
Products of graphs
[Figure: product graph with nodes (M,1), (F,1), (M,2), (F,2), ..., (M,99), (F,99), (M,100), (F,100)]
• $\mathcal{Z} = \{\text{Male}, \text{Female}\} \times \{1, 2, \ldots, 99, 100\}$ (sex × age)
• Yields sex, age-varying models
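One way to build such a product is the Cartesian product of graphs, whose weight matrix is a Kronecker sum of the factor weight matrices. The sketch below assumes that this is the product intended here, that node $(a, b)$ is stored at index $a K_2 + b$, and that all edge weights are 1. The helper names are illustrative.

```python
import numpy as np

def path_graph_weights(K, w=1.0):
    """Symmetric K x K weight matrix with edges k -- k+1."""
    W = np.zeros((K, K))
    idx = np.arange(K - 1)
    W[idx, idx + 1] = w
    W[idx + 1, idx] = w
    return W

def cartesian_product_weights(W1, W2):
    """Weights of the Cartesian product graph; node (a, b) has index a * K2 + b."""
    K1, K2 = W1.shape[0], W2.shape[0]
    # (a, b) ~ (a', b) when a ~ a' in graph 1; (a, b) ~ (a, b') when b ~ b' in graph 2.
    return np.kron(W1, np.eye(K2)) + np.kron(np.eye(K1), W2)

W_sex = np.array([[0.0, 1.0],
                  [1.0, 0.0]])                 # Male -- Female
W_age = path_graph_weights(100)                # ages 1, ..., 100
W_sex_age = cartesian_product_weights(W_sex, W_age)

# A 4 x 4 grid graph is the product of two 4-node paths.
W_grid = cartesian_product_weights(path_graph_weights(4), path_graph_weights(4))
```

The sex × age graph has 200 nodes and 298 edges: 100 edges linking the two sexes at each age, and 99 age-to-age edges for each sex.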
Solution method 21
Fitting problem
• To fit stratified model, minimize $\ell(\theta) + r(\theta) + \mathcal{L}(\theta)$
• $\ell(\theta) = \sum_{k=1}^K \ell_k(\theta_k)$ is loss, $\ell_k(\theta_k) = \sum_{i : z_i = k} l(\theta_k, x_i, y_i)$
• $r(\theta) = \sum_{k=1}^K r(\theta_k)$ is (local) regularization
• $\mathcal{L}(\theta) = \sum_{u,v=1}^K W_{uv} \|\theta_u - \theta_v\|^2$ is Laplacian regularization
• $\ell$, $r$ are separable in $\theta_k$
• $\mathcal{L}$ is quadratic, separable in components of $\theta_k$
• We’ll use operator splitting method (ADMM)
Reformulation
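• Replicate variables:
$$\begin{array}{ll} \text{minimize} & \ell(\theta) + r(\tilde\theta) + \mathcal{L}(\hat\theta) \\ \text{subject to} & \theta = \tilde\theta = \hat\theta \end{array}$$
• Augmented Lagrangian:
$$L(\theta, \tilde\theta, \hat\theta, u, \tilde u) = \ell(\theta) + r(\tilde\theta) + \mathcal{L}(\hat\theta) + \frac{1}{2t}\|\theta - \hat\theta + u\|_2^2 + \frac{1}{2t}\|\tilde\theta - \hat\theta + \tilde u\|_2^2$$
• $u$, $\tilde u$ dual variables for the two constraints, $t > 0$ is algorithm parameter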
ADMM
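• ADMM: for $i = 1, 2, \ldots$
$$\begin{aligned}
\theta^{i+1}, \tilde\theta^{i+1} &:= \operatorname*{argmin}_{\theta, \tilde\theta}\; L(\theta, \tilde\theta, \hat\theta^i, u^i, \tilde u^i) \\
\hat\theta^{i+1} &:= \operatorname*{argmin}_{\hat\theta}\; L(\theta^{i+1}, \tilde\theta^{i+1}, \hat\theta, u^i, \tilde u^i) \\
u^{i+1} &:= u^i + \theta^{i+1} - \hat\theta^{i+1} \\
\tilde u^{i+1} &:= \tilde u^i + \tilde\theta^{i+1} - \hat\theta^{i+1}
\end{aligned}$$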
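The ADMM iteration can be sketched end to end for one concrete case. The code below assumes a square loss, a sum-of-squares local regularizer $(\lambda/2)\|\theta_k\|_2^2$, a symmetric weight matrix, and strata numbered from 0. It uses dense direct solves where a real implementation would likely cache factorizations or use CG. It is an illustration of the update order, not the released implementation, and every name in it is hypothetical.

```python
import numpy as np

def fit_stratified_ridge(X, Y, Z, W, lam=1.0, t=1.0, iters=500):
    """ADMM sketch: square loss, (lam/2)||theta_k||^2 local regularization."""
    K, n = W.shape[0], X.shape[1]
    lap = np.diag(W.sum(axis=1)) - W           # graph Laplacian of W
    # Loss prox for stratum k solves (I + 2t Xk^T Xk) theta = v + 2t Xk^T yk.
    lhs, rhs = [], []
    for k in range(K):
        Xk, yk = X[Z == k], Y[Z == k]          # may be empty for some strata
        lhs.append(np.eye(n) + 2 * t * (Xk.T @ Xk))
        rhs.append(2 * t * (Xk.T @ yk))
    # theta_hat step solves (2I + 4t Lap) theta_hat = (theta + u) + (theta_tilde + u_tilde).
    A_hat = 2 * np.eye(K) + 4 * t * lap
    theta = np.zeros((K, n))
    theta_tilde, theta_hat = np.zeros((K, n)), np.zeros((K, n))
    u, u_tilde = np.zeros((K, n)), np.zeros((K, n))
    for _ in range(iters):
        for k in range(K):                     # independent across strata
            theta[k] = np.linalg.solve(lhs[k], theta_hat[k] - u[k] + rhs[k])
        theta_tilde = (theta_hat - u_tilde) / (1 + t * lam)
        theta_hat = np.linalg.solve(A_hat, theta + u + theta_tilde + u_tilde)
        u = u + theta - theta_hat              # dual updates use the new iterates
        u_tilde = u_tilde + theta_tilde - theta_hat
    return theta_hat

rng = np.random.default_rng(0)
K, n, N = 4, 2, 40
Z = rng.integers(0, 3, size=N)                 # stratum 3 receives no data
X = rng.standard_normal((N, n))
Y = rng.standard_normal(N)
W = np.array([[0.0, 1.0, 0.0, 1.0],
              [1.0, 0.0, 1.0, 0.0],
              [0.0, 1.0, 0.0, 1.0],
              [1.0, 0.0, 1.0, 0.0]])           # 4-node cycle
theta_fit = fit_stratified_ridge(X, Y, Z, W)
```

Stratum 3 has no training records in this example, yet it still receives a parameter vector through its two graph neighbors.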
ADMM
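• First step can be expressed as
$$\theta_k^{i+1} = \mathbf{prox}_{t\ell_k}(\hat\theta_k^i - u_k^i), \quad \tilde\theta_k^{i+1} = \mathbf{prox}_{tr}(\hat\theta_k^i - \tilde u_k^i), \quad k = 1, \ldots, K$$
• Proximal operator of $tg$ is
$$\mathbf{prox}_{tg}(\nu) = \operatorname*{argmin}_{\theta} \left( t g(\theta) + (1/2)\|\theta - \nu\|_2^2 \right)$$
• Can evaluate these $2K$ proximal operators in parallel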
Loss proximal operators
• Often have closed form expression or efficient implementation
• Square loss: $\ell_k(\theta_k) = \|X_k\theta_k - y_k\|_2^2$
  – $\mathbf{prox}_{t\ell_k}(\nu) = (I + 2tX_k^TX_k)^{-1}(\nu + 2tX_k^Ty_k)$
  – Cache factorization or warm-start CG
• Logistic loss: $\ell_k(\theta_k) = \sum_i \log(1 + \exp(-y_{ki}\theta_k^Tx_{ki}))$
  – Use L-BFGS to evaluate $\mathbf{prox}_{t\ell_k}(\nu)$
  – Warm-start
• Many others ...
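For the square loss $\ell_k(\theta_k) = \|X_k\theta_k - y_k\|_2^2$, setting the gradient of $t\,\ell_k(\theta) + (1/2)\|\theta - \nu\|_2^2$ to zero gives the linear system $(I + 2tX_k^TX_k)\theta = \nu + 2tX_k^Ty_k$. The sketch below solves it with a dense direct solve (a simple stand-in for a cached factorization or CG) and includes a helper for checking the result numerically. The names are illustrative.

```python
import numpy as np

def prox_square_loss(v, Xk, yk, t):
    """argmin over theta of t * ||Xk theta - yk||^2 + 0.5 * ||theta - v||^2."""
    n = Xk.shape[1]
    return np.linalg.solve(np.eye(n) + 2 * t * (Xk.T @ Xk),
                           v + 2 * t * (Xk.T @ yk))

def prox_objective(theta, v, Xk, yk, t):
    """The function that prox_square_loss minimizes."""
    return t * np.sum((Xk @ theta - yk) ** 2) + 0.5 * np.sum((theta - v) ** 2)

rng = np.random.default_rng(1)
Xk = rng.standard_normal((20, 3))
yk = rng.standard_normal(20)
v = rng.standard_normal(3)
t = 0.5
theta_star = prox_square_loss(v, Xk, yk, t)
```

Since the matrix $I + 2tX_k^TX_k$ does not change across ADMM iterations, factoring it once per stratum and reusing the factorization is what makes this step cheap.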
Regularizer proximal operators
• Often have closed form expression or efficient implementation

Regularizer | $r(\theta_k)$ | $\mathbf{prox}_{tr}(\nu)$
Sum of squares ($\ell_2$) | $(\lambda/2)\|\theta_k\|_2^2$ | $\nu/(t\lambda + 1)$
$\ell_1$ norm | $\lambda\|\theta_k\|_1$ | $(\nu - t\lambda)_+ - (-\nu - t\lambda)_+$
Nonnegative | $I_+(\theta_k)$ | $(\nu)_+$

• $\lambda > 0$ local regularization parameter
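The three regularizer proximal operators listed above are one-liners. A minimal sketch with illustrative function names. For the nonnegative indicator, the proximal operator is the projection of $\nu$ onto the nonnegative orthant.

```python
import numpy as np

def prox_sum_squares(v, t, lam):
    """prox of t * (lam/2) * ||theta||_2^2: uniform shrinkage toward zero."""
    return v / (t * lam + 1)

def prox_l1(v, t, lam):
    """prox of t * lam * ||theta||_1: soft thresholding at level t * lam."""
    return np.maximum(v - t * lam, 0) - np.maximum(-v - t * lam, 0)

def prox_nonneg(v):
    """prox of the indicator of theta >= 0: projection, independent of t."""
    return np.maximum(v, 0)

v = np.array([3.0, -0.5, -2.0])
shrunk = prox_sum_squares(v, t=1.0, lam=1.0)
thresholded = prox_l1(v, t=1.0, lam=1.0)
projected = prox_nonneg(v)
```

Soft thresholding sets the middle entry exactly to zero, which is how the $\ell_1$ regularizer produces sparse parameters.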
Laplacian proximal operator
• Separable across each component of $\theta_k$
• To find each component $(\theta_k)_i$, need to solve a Laplacian system
• Many efficient ways to solve, e.g., diagonally preconditioned CG
• These $n$ systems can be solved in parallel
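To make the Laplacian system concrete: for symmetric $W$ with degree matrix $D$, the penalty $\sum_{u,v} W_{uv}\|\theta_u - \theta_v\|^2$ equals $2\,\mathrm{tr}(\Theta^T (D - W) \Theta)$, where row $k$ of $\Theta$ is $\theta_k$. Its proximal operator with parameter $t$ at a point $V$ therefore solves $(I + 4t(D - W))\Theta = V$, one system per column. The sketch below uses a dense direct solve in place of preconditioned CG, and the names are illustrative.

```python
import numpy as np

def laplacian_penalty(theta, W):
    """Sum over all ordered pairs (u, v) of W[u, v] * ||theta_u - theta_v||^2."""
    diff = theta[:, None, :] - theta[None, :, :]
    return np.sum(W * np.sum(diff ** 2, axis=2))

def prox_laplacian(V, W, t):
    """argmin over Theta of t * penalty(Theta) + 0.5 * ||Theta - V||_F^2."""
    K = W.shape[0]
    lap = np.diag(W.sum(axis=1)) - W           # graph Laplacian D - W
    # Each column of V (one parameter component across all strata) is an
    # independent right-hand side of the same K x K system.
    return np.linalg.solve(np.eye(K) + 4 * t * lap, V)

W = np.array([[0.0, 1.0, 0.0],
              [1.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])                # 3-node path
V = np.array([[0.0, 1.0],
              [3.0, 0.0],
              [6.0, -1.0]])
t = 0.25
theta_hat = prox_laplacian(V, W, t)
```

The operator smooths each parameter component over the graph while preserving its average across strata.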
Software implementation
• Available at www.github.com/cvxgrp/strat_models
• numpy, scipy for matrix operations
• networkx for handling graphs and graph operations
• torch for L-BFGS and GPU computation
• multiprocessing for parallelism
• model.fit(X,Y,Z,G) (writes $\theta_k$ on graph nodes)
Examples 30
House price prediction
• $N \approx 22000$ records $(z, x, y)$ from King County, WA
• Split into train/test ≈ 16200/5400
• $z$ = (latitude bin, longitude bin)
• $x \in \mathbf{R}^{10}$ = features of house, $y$ = log of house price
• Graph is 50 × 50 grid with all edge weights same; $K = 2500$
• Stratified ridge regression model with two hyperparameters (one for local regularization, one for Laplacian regularization)
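The slides do not say how the latitude and longitude bins were formed. As one possible construction, the sketch below assumes equal-width bins spanning the observed coordinate range and maps each house to a node index on a `bins` × `bins` grid. The function name and the toy coordinates are hypothetical.

```python
import numpy as np

def grid_stratum(lat, lon, bins=50):
    """Map coordinates to a grid node index i * bins + j (equal-width bins assumed)."""
    lat_edges = np.linspace(lat.min(), lat.max(), bins + 1)
    lon_edges = np.linspace(lon.min(), lon.max(), bins + 1)
    # digitize returns 1-based bin numbers; clip so the maximum lands in the last bin.
    i = np.clip(np.digitize(lat, lat_edges) - 1, 0, bins - 1)
    j = np.clip(np.digitize(lon, lon_edges) - 1, 0, bins - 1)
    return i * bins + j

lat = np.array([0.0, 0.5, 1.0])
lon = np.array([0.0, 0.5, 1.0])
z = grid_stratum(lat, lon, bins=2)   # tiny 2 x 2 grid for illustration
```

With `bins=50` this yields $K = 2500$ strata, many of which may contain few or no houses, which is the situation the Laplacian regularization is meant to handle.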
House price prediction: Results
• Compare stratified, common, and random forest with 50 trees

Model | Parameters | RMS test error
Stratified | 25000 | 0.181
Common | 10 | 0.316
Random forest | 985888 | 0.184
House price prediction: Parameters
[Figure: one panel per model parameter: bedrooms, bathrooms, sqft living, sqft lot, floors, waterfront, condition, grade, yr built, intercept]