Last updated: 2017-01-15

Code version: 086c3157516dcce3907fd3fd6857a0aec1f35bfa

Overview

The file fast_ash.cpp contains several Rcpp functions, which we compare here.

The function wsum_direct sums the responsibilities (posterior class probabilities) directly, row by row. This is close to the approach currently used in ash.

In contrast, the wsum function uses a multi-resolution approach: it splits the data in two and merges rows that yield the same results to within a specified tolerance.

library(ashr)
set.seed(100)

# Simulate data: nsamp draws from a normal with mean 0 and sd 2.
nsamp <- 100000
z <- rnorm(nsamp, mean = 0, sd = 2)

# Reorder z by increasing absolute value (not by signed value).
# Presumably this places rows with similar likelihoods next to each other,
# which the row merging in wsum can exploit -- confirm in fast_ash.cpp.
z <- z[order(abs(z))]

# Fit ash with standard error 1 for every observation. outputlevel = 4 is
# requested; presumably this is what exposes res$fit_details$matrix_lik.
res <- ash(z, 1, mixcompdist = "normal", outputlevel = 4)
lik <- res$fit_details$matrix_lik
fitted_g <- get_fitted_g(res)

# Initial mixture weights: every component gets weight 1 (uniform, but not
# normalized to sum to 1).
# NOTE(review): this binding masks base R's constant `pi` globally.
pi <- rep(1, times = ncomp(fitted_g))

# Alternative initialization, closer to the usual starting value in ash:
# normalize <- function(x) {
#   x / sum(x)
# }
# pi <- rep(1 / nsamp, ncomp(fitted_g))
# pi[1] <- 1
# pi <- normalize(pi)

Here we just check that the approximate methods produce answers similar to those of the direct method.

Rcpp::sourceCpp("fast_ash.cpp")
# Reference result: direct sum over every row. The row range looks 0-based
# and inclusive (0 .. nsamp - 1 covers all rows) -- confirm in fast_ash.cpp.
wsum_direct(pi, lik, 0, nsamp - 1)
 [1] 6212.279 6222.939 6233.550 6254.627 6296.192 6376.958 6528.978
 [8] 6795.816 7196.982 7621.205 7742.674 7240.996 6165.978 4867.429
[15] 3656.324 2668.879 1918.194
wsum(pi, lik, 0, nsamp - 1, 0, 0, tol = 1e-3)  # loose tolerance
 [1] 6211.553 6222.237 6232.873 6253.997 6295.650 6376.569 6528.829
 [8] 6795.970 7197.414 7621.788 7743.275 7241.547 6166.447 4867.805
[15] 3656.610 2669.089 1918.345
wsum(pi, lik, 0, nsamp - 1, 0, 0, tol = 1e-5)  # tighter tolerance
 [1] 6212.278 6222.938 6233.549 6254.626 6296.192 6376.957 6528.978
 [8] 6795.816 7196.982 7621.205 7742.675 7240.997 6165.978 4867.430
[15] 3656.325 2668.879 1918.194
# The add_to_* variants write their result into the weight vector passed in
# (ws is all zeros before the call and holds the sums afterwards), so each
# call gets a fresh zero vector. They presumably also update lprobsum in
# place -- confirm in fast_ash.cpp.
# Keep the direct result in its own vector so it is not discarded before it
# can be compared with the approximate one.
ws_direct <- rep(0, length(pi))
lprobsum_direct <- c(0)
add_to_wsum_direct(lprobsum_direct, ws_direct, pi, lik, 0, nsamp - 1)
ws <- rep(0, length(pi))
lprobsum <- c(0)
add_to_wsum(lprobsum, ws, pi, lik, 0, nsamp - 1, 0, 0, tol = 1e-5)
ws
 [1] 6212.278 6222.938 6233.549 6254.626 6296.192 6376.957 6528.978
 [8] 6795.816 7196.982 7621.205 7742.675 7240.997 6165.978 4867.430
[15] 3656.325 2668.879 1918.194
# Check that the functions also work on a subset of rows, not the whole data
wsum_direct(pi, lik, 11, 2000)
 [1] 198.20069 197.23456 196.28244 194.41882 190.84550 184.25274 172.88777
 [8] 155.31315 131.92739 105.73009  80.72572  59.60723  43.13531  30.86898
[15]  21.96125  15.57683  11.03153
wsum(pi, lik, 11, 2000, 0, 0, tol = 1e-3)  # subset, loose tolerance
 [1] 198.19021 197.22451 196.27279 194.40997 190.83810 184.24784 172.88664
 [8] 155.31655 131.93464 105.73915  80.73449  59.61458  43.14099  30.87319
[15]  21.96429  15.57901  11.03308
wsum(pi, lik, 11, 2000, 0, 0, tol = 1e-5)  # subset, tighter tolerance
 [1] 198.19907 197.23302 196.28095 194.41746 190.84436 184.25198 172.88760
 [8] 155.31367 131.92851 105.73149  80.72707  59.60836  43.13618  30.86963
[15]  21.96172  15.57717  11.03177
# add_to_wsum writes into its arguments in place (ws is filled by the call),
# so start from fresh vectors. Use c(0) rather than the literal 0, matching
# the earlier chunk: writing into a literal constant from C++ would mutate
# the constant itself, a known Rcpp pitfall if this code is ever re-run
# inside a function or loop.
ws <- rep(0, length(pi))
lprobsum <- c(0)
add_to_wsum(lprobsum, ws, pi, lik, 11, 2000, 0, 0, tol = 1e-5)
ws
 [1] 198.19907 197.23302 196.28095 194.41746 190.84436 184.25198 172.88760
 [8] 155.31367 131.92851 105.73149  80.72707  59.60836  43.13618  30.86963
[15]  21.96172  15.57717  11.03177

Here we compare the compute times. If the tolerance is too tight, the approximate method is slower than simply summing directly. With a looser tolerance, however, it gives a substantial speed-up.

# Time the direct method over all rows. (The previously computed transpose
# t(lik) was never used, so it has been dropped.)
system.time(wsum_direct(pi, lik, 0, nsamp - 1))
   user  system elapsed 
  0.012   0.002   0.014 
system.time(wsum(pi, lik, 0, nsamp - 1, 0, 0, tol = 1e-3))  # loose tol
   user  system elapsed 
  0.001   0.000   0.002 
system.time(wsum(pi, lik, 0, nsamp - 1, 0, 0, tol = 1e-5))  # tight tol
   user  system elapsed 
  0.049   0.002   0.051 
# Accumulate into `ws`, not `wsum`: assigning a vector to `wsum` would
# overwrite the wsum() function that Rcpp::sourceCpp() defined in the global
# environment, breaking any later call to wsum().
ws <- rep(0, length(pi))
lprobsum <- c(0)
system.time(
  add_to_wsum(lprobsum, ws, pi, lik, 0, nsamp - 1, 0, 0, tol = 1e-5)
)
   user  system elapsed 
  0.057   0.001   0.058 

Session information

R version 3.3.1 (2016-06-21)
Platform: x86_64-apple-darwin13.4.0 (64-bit)
Running under: OS X 10.11.5 (El Capitan)

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] microbenchmark_1.4-2.1 ashr_2.1               workflowr_0.2.0       
[4] rmarkdown_1.3         

loaded via a namespace (and not attached):
 [1] Rcpp_0.12.8       rstudioapi_0.6    knitr_1.15.1     
 [4] magrittr_1.5      REBayes_0.73      MASS_7.3-45      
 [7] munsell_0.4.3     doParallel_1.0.10 pscl_1.4.9       
[10] colorspace_1.2-7  SQUAREM_2016.8-2  lattice_0.20-34  
[13] foreach_1.4.3     plyr_1.8.4        stringr_1.1.0    
[16] tools_3.3.1       parallel_3.3.1    grid_3.3.1       
[19] gtable_0.2.0      git2r_0.18.0      htmltools_0.3.5  
[22] iterators_1.0.8   assertthat_0.1    yaml_2.1.14      
[25] rprojroot_1.1     digest_0.6.10     Matrix_1.2-7.1   
[28] ggplot2_2.1.0     codetools_0.2-15  evaluate_0.10    
[31] stringi_1.1.2     scales_0.4.0      Rmosek_7.1.2     
[34] backports_1.0.4   truncnorm_1.0-7  

This site was created with R Markdown