Last updated: 2017-01-15
Code version: 086c3157516dcce3907fd3fd6857a0aec1f35bfa
The file fast_ash.cpp contains some Rcpp functions to compare.
The function wsum_direct simply sums the responsibilities (posterior class probabilities) directly. This is close to the approach currently used in ash.
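To make the quantity concrete, here is a minimal standalone C++ sketch of direct summation. It is not taken from fast_ash.cpp: the name `wsum_direct_sketch`, the `Matrix` alias and the row-major layout are assumptions made for illustration. For each row i in the 0-based inclusive range lo..hi it forms the responsibilities pi[k]*lik[i][k] / sum_j pi[j]*lik[i][j] and adds them up by component.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// lik[i][k]: likelihood of observation i under mixture component k.
// (Hypothetical row-major layout, chosen for clarity.)
using Matrix = std::vector<std::vector<double>>;

// Sum the responsibilities of rows lo..hi (0-based, inclusive).
// pi need not be normalized, because each row is normalized separately.
std::vector<double> wsum_direct_sketch(const std::vector<double>& pi,
                                       const Matrix& lik,
                                       std::size_t lo, std::size_t hi) {
    assert(lo <= hi && hi < lik.size());
    const std::size_t K = pi.size();
    std::vector<double> w(K, 0.0);  // accumulated responsibilities
    std::vector<double> r(K);       // unnormalized responsibilities of one row
    for (std::size_t i = lo; i <= hi; ++i) {
        assert(lik[i].size() == K);
        double total = 0.0;
        for (std::size_t k = 0; k < K; ++k) {
            r[k] = pi[k] * lik[i][k];
            total += r[k];
        }
        assert(total > 0.0);
        for (std::size_t k = 0; k < K; ++k) w[k] += r[k] / total;
    }
    return w;
}
```

Because every row contributes responsibilities that sum to one, the entries of the result sum to the number of rows in the range. The full-data output of wsum_direct printed on this page has the same property: its 17 values sum to 100000, the value of nsamp.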
In contrast, the wsum method uses a multi-resolution approach: it bifurcates the data and merges rows that yield the same results to within a specified tolerance.
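One way such a bifurcate-and-merge scheme could work is sketched below in standalone C++. This is an illustration under stated assumptions, not the code in fast_ash.cpp: the names `row_resp`, `add_range` and `wsum_sketch` are hypothetical, and the two extra arguments that wsum takes before tol are omitted because their meaning is not described here. The sketch compares the responsibilities at the two ends of a range. If they agree to within tol, every row in the range is treated as the endpoint average; otherwise the range is split in half and each half is handled recursively. It assumes the rows are ordered so that responsibilities change gradually along the range, which is not guaranteed in general.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // lik[i][k], hypothetical layout

// Responsibilities (posterior class probabilities) of a single row.
std::vector<double> row_resp(const std::vector<double>& pi,
                             const std::vector<double>& row) {
    assert(row.size() == pi.size());
    std::vector<double> r(pi.size());
    double total = 0.0;
    for (std::size_t k = 0; k < pi.size(); ++k) {
        r[k] = pi[k] * row[k];
        total += r[k];
    }
    assert(total > 0.0);
    for (double& x : r) x /= total;
    return r;
}

// Add the (approximate) responsibilities of rows lo..hi into w.
// r_lo and r_hi are the responsibilities of rows lo and hi.
void add_range(const std::vector<double>& pi, const Matrix& lik,
               std::size_t lo, std::size_t hi,
               const std::vector<double>& r_lo,
               const std::vector<double>& r_hi,
               double tol, std::vector<double>& w) {
    const std::size_t K = pi.size();
    if (lo == hi) {  // single row: exact
        for (std::size_t k = 0; k < K; ++k) w[k] += r_lo[k];
        return;
    }
    double gap = 0.0;  // largest endpoint disagreement over components
    for (std::size_t k = 0; k < K; ++k)
        gap = std::max(gap, std::fabs(r_lo[k] - r_hi[k]));
    if (hi - lo == 1 || gap <= tol) {
        // Merge: count every row in the range as the endpoint average.
        // For a two-row range this is exact.
        const double n = static_cast<double>(hi - lo + 1);
        for (std::size_t k = 0; k < K; ++k)
            w[k] += n * 0.5 * (r_lo[k] + r_hi[k]);
        return;
    }
    // Bifurcate: rows lo..mid and mid+1..hi.
    const std::size_t mid = lo + (hi - lo) / 2;
    const std::vector<double> r_mid = row_resp(pi, lik[mid]);
    const std::vector<double> r_next = row_resp(pi, lik[mid + 1]);
    add_range(pi, lik, lo, mid, r_lo, r_mid, tol, w);
    add_range(pi, lik, mid + 1, hi, r_next, r_hi, tol, w);
}

std::vector<double> wsum_sketch(const std::vector<double>& pi, const Matrix& lik,
                                std::size_t lo, std::size_t hi, double tol) {
    assert(lo <= hi && hi < lik.size());
    std::vector<double> w(pi.size(), 0.0);
    add_range(pi, lik, lo, hi, row_resp(pi, lik[lo]), row_resp(pi, lik[hi]), tol, w);
    return w;
}
```

In this sketch a looser tolerance merges larger ranges, so fewer rows are evaluated, while a very tight tolerance evaluates nearly every row and adds the recursion overhead on top.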
library(ashr)
set.seed(100)
nsamp=100000
z = rnorm(nsamp,0,2)
# sort z in order of increasing absolute value
z = z[order(abs(z))]
res <- ash(z,1,mixcompdist="normal",outputlevel=4)
lik = res$fit_details$matrix_lik
fitted_g = get_fitted_g(res)
# set up the initial value of pi as uniform
pi = rep(1, ncomp(fitted_g))
# This is closer to the usual initial value we use in ash
# normalize=function(x){x/sum(x)}
#pi = rep(1/nsamp, ncomp(fitted_g))
#pi[1]=1
#pi = normalize(pi)
Here we just check that the approximate methods produce answers similar to those of the direct method.
Rcpp::sourceCpp('fast_ash.cpp')
wsum_direct(pi,lik,0,nsamp-1)
[1] 6212.279 6222.939 6233.550 6254.627 6296.192 6376.958 6528.978
[8] 6795.816 7196.982 7621.205 7742.674 7240.996 6165.978 4867.429
[15] 3656.324 2668.879 1918.194
wsum(pi,lik,0,nsamp-1,0,0,tol=1e-3)
[1] 6211.553 6222.237 6232.873 6253.997 6295.650 6376.569 6528.829
[8] 6795.970 7197.414 7621.788 7743.275 7241.547 6166.447 4867.805
[15] 3656.610 2669.089 1918.345
wsum(pi,lik,0,nsamp-1,0,0,tol=1e-5)
[1] 6212.278 6222.938 6233.549 6254.626 6296.192 6376.957 6528.978
[8] 6795.816 7196.982 7621.205 7742.675 7240.997 6165.978 4867.430
[15] 3656.325 2668.879 1918.194
ws = rep(0,length(pi))
lprobsum = c(0)
add_to_wsum_direct(lprobsum, ws,pi,lik,0,nsamp-1)
ws = rep(0,length(pi))
lprobsum = c(0)
add_to_wsum(lprobsum, ws,pi,lik,0,nsamp-1,0,0,tol=1e-5)
ws
[1] 6212.278 6222.938 6233.549 6254.626 6296.192 6376.957 6528.978
[8] 6795.816 7196.982 7621.205 7742.675 7240.997 6165.978 4867.430
[15] 3656.325 2668.879 1918.194
wsum_direct(pi,lik,11,2000) # just check the functions are working ok for subsets that are not the whole data
[1] 198.20069 197.23456 196.28244 194.41882 190.84550 184.25274 172.88777
[8] 155.31315 131.92739 105.73009 80.72572 59.60723 43.13531 30.86898
[15] 21.96125 15.57683 11.03153
wsum(pi,lik,11,2000,0,0,tol=1e-3)
[1] 198.19021 197.22451 196.27279 194.40997 190.83810 184.24784 172.88664
[8] 155.31655 131.93464 105.73915 80.73449 59.61458 43.14099 30.87319
[15] 21.96429 15.57901 11.03308
wsum(pi,lik,11,2000,0,0,tol=1e-5)
[1] 198.19907 197.23302 196.28095 194.41746 190.84436 184.25198 172.88760
[8] 155.31367 131.92851 105.73149 80.72707 59.60836 43.13618 30.86963
[15] 21.96172 15.57717 11.03177
ws = rep(0,length(pi))
lprobsum = 0
add_to_wsum(lprobsum, ws,pi,lik,11,2000,0,0,tol=1e-5)
ws
[1] 198.19907 197.23302 196.28095 194.41746 190.84436 184.25198 172.88760
[8] 155.31367 131.92851 105.73149 80.72707 59.60836 43.13618 30.86963
[15] 21.96172 15.57717 11.03177
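The agreement in the outputs above can be summarized by a single number, the largest relative difference between two result vectors. A minimal standalone C++ sketch follows; the helper name `max_rel_diff` is hypothetical and the function is not part of fast_ash.cpp.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Largest |approx[k] - exact[k]| / |exact[k]| over all components.
// Assumes both vectors have the same length and exact has no zero entries.
double max_rel_diff(const std::vector<double>& exact,
                    const std::vector<double>& approx) {
    assert(exact.size() == approx.size());
    double worst = 0.0;
    for (std::size_t k = 0; k < exact.size(); ++k) {
        assert(exact[k] != 0.0);
        worst = std::max(worst, std::fabs(approx[k] - exact[k]) / std::fabs(exact[k]));
    }
    return worst;
}
```

Applied by hand to the 17 printed full-data values, wsum with tol=1e-3 differs from wsum_direct by at most about 1.2e-4 in relative terms (the first component, 6211.553 against 6212.279), while with tol=1e-5 the printed values differ only in the last displayed digit.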
Here we compare the compute times. If the tolerance is too tight (tol=1e-5), the approximate method is slower than direct summation: 0.051s elapsed versus 0.014s. A looser tolerance (tol=1e-3), however, gives a substantial gain: 0.002s elapsed.
tlik = t(lik)
system.time(wsum_direct(pi,lik,0,nsamp-1))
user system elapsed
0.012 0.002 0.014
system.time(wsum(pi,lik,0,nsamp-1,0,0,tol=1e-3))
user system elapsed
0.001 0.000 0.002
system.time(wsum(pi,lik,0,nsamp-1,0,0,tol=1e-5))
user system elapsed
0.049 0.002 0.051
# use ws rather than wsum as the accumulator name, so that the wsum function is not overwritten
ws = rep(0,length(pi))
lprobsum = 0
system.time(add_to_wsum(lprobsum,ws,pi,lik,0,nsamp-1,0,0,tol=1e-5))
user system elapsed
0.057 0.001 0.058
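Each timing above comes from a single call, and at the millisecond scale a single measurement can be noisy. A more stable comparison repeats each call and reports a summary such as the median. The standalone C++ sketch below shows the idea; `median_seconds` is a hypothetical helper, not something provided by fast_ash.cpp or ashr. From R, the microbenchmark package listed in the session information serves the same purpose.

```cpp
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <functional>
#include <vector>

// Median wall-clock time, in seconds, over `reps` calls of f.
double median_seconds(const std::function<void()>& f, int reps) {
    assert(reps > 0);
    std::vector<double> t;
    t.reserve(static_cast<std::size_t>(reps));
    for (int r = 0; r < reps; ++r) {
        const auto start = std::chrono::steady_clock::now();
        f();
        const auto stop = std::chrono::steady_clock::now();
        t.push_back(std::chrono::duration<double>(stop - start).count());
    }
    // Partial sort so that the middle element is the (upper) median.
    const std::size_t mid = t.size() / 2;
    std::nth_element(t.begin(), t.begin() + mid, t.end());
    return t[mid];
}
```

The function to be timed is passed as a closure, for example a lambda that calls the summation routine on a fixed input, so the same harness can be reused for each method and tolerance.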
R version 3.3.1 (2016-06-21)
Platform: x86_64-apple-darwin13.4.0 (64-bit)
Running under: OS X 10.11.5 (El Capitan)
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] microbenchmark_1.4-2.1 ashr_2.1 workflowr_0.2.0
[4] rmarkdown_1.3
loaded via a namespace (and not attached):
[1] Rcpp_0.12.8 rstudioapi_0.6 knitr_1.15.1
[4] magrittr_1.5 REBayes_0.73 MASS_7.3-45
[7] munsell_0.4.3 doParallel_1.0.10 pscl_1.4.9
[10] colorspace_1.2-7 SQUAREM_2016.8-2 lattice_0.20-34
[13] foreach_1.4.3 plyr_1.8.4 stringr_1.1.0
[16] tools_3.3.1 parallel_3.3.1 grid_3.3.1
[19] gtable_0.2.0 git2r_0.18.0 htmltools_0.3.5
[22] iterators_1.0.8 assertthat_0.1 yaml_2.1.14
[25] rprojroot_1.1 digest_0.6.10 Matrix_1.2-7.1
[28] ggplot2_2.1.0 codetools_0.2-15 evaluate_0.10
[31] stringi_1.1.2 scales_0.4.0 Rmosek_7.1.2
[34] backports_1.0.4 truncnorm_1.0-7