If you want to run the rmd files in this repo, start by running `R CMD BATCH make_example_big_data.R` in the `code` directory. This will produce a file used for testing later.
The Adaptive Shrinkage method (http://biorxiv.org/content/early/2016/06/08/038216) involves solving a convex optimization problem. Currently in my R package I use either an accelerated EM algorithm or an Interior Point (IP) method to solve this problem. (The IP method uses the REBayes package function KWDual, which interfaces with the mosek library.) Both EM and IP are quick enough for single data sets, even quite large ones (e.g. \(10^6\) observations), but in some applications (e.g. matrix factorization) we want to run this algorithm hundreds or thousands of times. So we seek to speed it up.
The R package software actually involves three steps:

1. Compute an \(n \times K\) likelihood matrix \(L\). Specifically, \(L\) is based on a normal likelihood, so \[L_{jk} = (s_j^2 + \sigma_k^2)^{-0.5} \exp[-0.5 x_j^2/(s_j^2 + \sigma_k^2)],\] where \(s_j, x_j\) are known data, and \(\sigma_1,\dots,\sigma_K\) are increasing from small to large.
2. Estimate mixture proportions \(\pi_1,\dots,\pi_K\) by maximizing \[f(\pi) = \sum_j \log \sum_k \pi_k L_{jk}.\] This is the convex optimization problem (a short R sketch of these two steps follows this list).
3. Compute posterior quantities given these mixture proportions.
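To make steps 1 and 2 concrete, here is a minimal R sketch; this is not the package code, and the names `lik_matrix` and `mix_loglik` are mine, chosen for illustration:

```r
# Sketch: compute the n x K likelihood matrix L from data x (estimates),
# s (standard errors) and the grid of mixture standard deviations sigma.
lik_matrix <- function(x, s, sigma) {
  v <- outer(s^2, sigma^2, "+")     # v[j,k] = s_j^2 + sigma_k^2
  v^(-0.5) * exp(-0.5 * x^2 / v)    # L[j,k]
}

# Sketch: the convex objective f(pi) = sum_j log sum_k pi_k L[j,k].
mix_loglik <- function(pi, L) sum(log(L %*% pi))
```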
All three steps take some time, but my expectation is that the convex optimization step takes most of the time. In this vignette I made a quick assessment of how these three stages break down, to check that this expectation is correct. In fact the last step (computing posterior quantities) also takes non-trivial time for large data sets. However, I believe this is due to some inefficiencies in the current implementation that would not be too hard to fix. I therefore focus on speeding up the optimization.
The IP method is faster than the accelerated EM. However, the EM has some advantages; in particular, it avoids the dependence on REBayes::KWDual (and hence on the mosek library). For this reason I start by working with EM. Hopefully things we learn there might also translate to the IP method; we can return to this in the future.
For future reference: I did some quick assessments of how the IP method scales with \(n\) and \(K\) here. These suggest that the cost grows faster than linearly, particularly in \(n\), presumably because the number of iterations needed increases with \(n\).
Peter Carbonetto suggested we use kd trees to speed up EM calculations, and pointed me to this paper. Based on this I came up with the following strategy.
First, let us simplify to the case where the \(s_j\) are all equal (wlog, to 1 say). So \[L_{jk} = (1 + \sigma_k^2)^{-0.5} \exp[-0.5 x_j^2/(1 + \sigma_k^2)].\] And let us assume that the \(x_j\) are sorted to be increasing in absolute value. In this case \(x_j\) and \(x_{j+1}\) will often be similar to one another, and so \(L_{j \cdot}\) and \(L_{j+1 \cdot}\) will often be similar. We can exploit this in the EM algorithm as follows.
First, let \[w_{jk}(\pi) := \pi_k L_{jk} / \sum_k (\pi_k L_{jk}).\] These quantities are sometimes called the “responsibilities” in the machine learning literature. Each EM iteration involves summing the responsibilities for the current value of \(\pi\) to compute a new value \(\hat{\pi}\):
\[\hat{\pi}_k = (1/n) \sum_{j=1}^n w_{jk}(\pi)\]
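For concreteness, a naive dense version of this update in R might look like the following (a sketch with an assumed name, not the package implementation):

```r
# Sketch: one EM update for the mixture proportions pi given the n x K
# likelihood matrix L. Each row of w holds the responsibilities w_jk(pi).
em_update <- function(pi, L) {
  w <- t(t(L) * pi)       # w[j,k] proportional to pi_k * L[j,k]
  w <- w / rowSums(w)     # normalize rows to get responsibilities
  colMeans(w)             # hat(pi)_k = (1/n) sum_j w[j,k]
}
```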
The basic idea is:

1. If many rows of \(w\) are similar, we can avoid summing them all.
2. If rows \(a\) and \(b\) of \(w\) are similar, then so are all the rows in between (because of the sorting step).
An assessment of how similar the nearby rows tend to be is given here.
This suggests the following algorithm (in pseudocode) to sum from row \(a\) to \(b\) of \(w\):
```
function wsum(w, a, b):
    if a == b: return w[a,]
    if w[a,] similar to w[b,]:
        return (b - a + 1) * (w[a,] + w[b,]) / 2
    else:
        c = midpoint of a and b    # e.g. floor((a+b)/2)
        return wsum(w, a, c) + wsum(w, c+1, b)
```
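For reference, here is a rough R transcription of that recursion; the similarity test (an absolute tolerance on the two endpoint rows), the tolerance value, and the midpoint split are my assumptions, and the rows of `w` are assumed to be pre-sorted by \(|x_j|\):

```r
# Sketch: approximately sum rows a..b of the responsibility matrix w.
# If the two endpoint rows are close, treat every row in between as their
# average; otherwise split at the midpoint and recurse.
wsum <- function(w, a, b, tol = 1e-3) {
  if (a == b) return(w[a, ])
  if (max(abs(w[a, ] - w[b, ])) < tol)
    return((b - a + 1) * (w[a, ] + w[b, ]) / 2)
  c <- floor((a + b) / 2)
  wsum(w, a, c, tol) + wsum(w, c + 1, b, tol)
}
```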
This is what I implemented in fast_ash.cpp (in the function wsum).
A similar idea can be applied to compute the objective function at the same time as performing the EM updates. The function add_to_wsum does this. (Since we now want to return two things, it passes these quantities by reference.)
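A rough R analogue of this combined computation is sketched below; since R does not naturally pass by reference, this sketch returns both quantities in a list (unlike the C++ add_to_wsum, which updates its accumulators in place), and the similarity test is again an assumption:

```r
# Sketch: given the unnormalized matrix m[j,k] = pi_k * L[j,k] (rows sorted
# by |x|), accumulate both the summed responsibilities and the log-likelihood
# contribution sum_j log(sum_k m[j,k]) over rows a..b in one recursion.
add_to_wsum <- function(m, a, b, tol = 1e-3) {
  ra <- m[a, ] / sum(m[a, ])   # responsibilities for row a
  rb <- m[b, ] / sum(m[b, ])   # responsibilities for row b
  if (a == b)
    return(list(wsum = ra, loglik = log(sum(m[a, ]))))
  if (max(abs(ra - rb)) < tol) {
    nr <- b - a + 1
    return(list(wsum   = nr * (ra + rb) / 2,
                loglik = nr * (log(sum(m[a, ])) + log(sum(m[b, ]))) / 2))
  }
  mid <- floor((a + b) / 2)
  left  <- add_to_wsum(m, a, mid, tol)
  right <- add_to_wsum(m, mid + 1, b, tol)
  list(wsum = left$wsum + right$wsum, loglik = left$loglik + right$loglik)
}
```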
In developing this, testing, and trying to get this to run fast I