Performance considerations for get() in data.table

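I have been using get() in a loop to manipulate columns in j while referring to several other columns in i.

Is there a faster or more efficient way to do this? Are there performance considerations I should be aware of?

Here is a minimal example of the kind of operation I have in mind: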

require(data.table) # version 1.12.8
dt = data.table(v1=c(1,2,NA),v2=c(0,0,1),v3=c(0,0,0))
for (i in 1:2){
     dt[ is.na(get(paste0('v',i))), (paste0('v',i)):= get(paste0('v',i+1))+2 ][]
}

The actual table I do this on is much larger (~5 million rows, ~300 columns).

Any help would be appreciated.

Comments
  • eipsum
    eipsum replied

    We can use set(), which assigns in place and avoids the overhead of calling [.data.table on every iteration:

    library(data.table)
    for (j in 1:2) {
        # rows where column j is NA
        i1 <- which(is.na(dt[[j]]))
        # fill those rows from the next column, plus 2, by reference
        set(dt, i = i1, j = j, value = dt[[j + 1]][i1] + 2)
    }
    
    dt
    #   v1 v2 v3
    #1:  1  0  0
    #2:  2  0  0
    #3:  3  1  0
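
If you would rather address columns by name than by position, set() also accepts a column name for j. A minimal self-contained sketch, assuming the columns follow the v1, v2, ... naming from the question (cols, target and src are just illustrative names):

```r
library(data.table)

# Same toy table as in the question
dt <- data.table(v1 = c(1, 2, NA), v2 = c(0, 0, 1), v3 = c(0, 0, 0))

cols <- paste0("v", 1:3)  # assumed naming pattern
for (k in seq_len(length(cols) - 1L)) {
  target <- cols[k]        # column to fill
  src    <- cols[k + 1L]   # column to take the replacement from
  rows   <- which(is.na(dt[[target]]))
  if (length(rows) > 0L) {
    # assign by reference, only on the rows that are NA
    set(dt, i = rows, j = target, value = dt[[src]][rows] + 2)
  }
}
dt
```

This skips get() entirely, since dt[[name]] already does the lookup by name.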
    
  • oalias
    oalias replied

    Yes, your for loop slows you down considerably. Even a simple lapply (there are probably more elegant ways to do it) brings a significant performance gain in the benchmark below:

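    library(data.table)
    # draw 100 values per column so that all columns have the same length
    dt <- data.table(v1 = rnorm(100),
                     v2 = sample(c(NA, 1:5), 100, replace = TRUE),
                     v3 = sample(c(NA, 1:5), 100, replace = TRUE),
                     v4 = sample(c(NA, 1:5), 100, replace = TRUE))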
    
    microbenchmark::microbenchmark(
      for (i in 1:2){
        dt[ is.na(get(paste0('v',i))), (paste0('v',i)):= get(paste0('v',i+1))+2 ]
        },
      for (i in 1:2){ 
        dt[ is.na(get(paste0('v',i))), (paste0('v',i)):= get(paste0('v',i+1))+2 ][]
      },
      lapply(1:2, function(i) dt[ is.na(get(paste0('v',i))), (paste0('v',i)):= get(paste0('v',i+1))+2 ])
    )
    
    Unit: milliseconds
                                                                                                                      expr      min       lq      mean   median       uq
       for (i in 1:2) {     dt[is.na(get(paste0("v", i))), `:=`((paste0("v", i)), get(paste0("v",          i + 1)) + 2)] } 8.364342 8.519589  9.809164 8.672182 9.133614
     for (i in 1:2) {     dt[is.na(get(paste0("v", i))), `:=`((paste0("v", i)), get(paste0("v",          i + 1)) + 2)][] } 8.779920 8.968042 10.608441 9.071600 9.613822
         lapply(1:2, function(i) dt[is.na(get(paste0("v", i))), `:=`((paste0("v",      i)), get(paste0("v", i + 1)) + 2)]) 1.036560 1.136559  1.401750 1.162624 1.226263
          max neval
     29.47810   100
     29.96673   100
     15.63642   100
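
One caveat when reading timings like these: := modifies dt in place, so every run after the first starts from a table that has already been changed. A rough sketch of a comparison where each variant gets its own copy() of the data, assuming a made-up table size and using base system.time() rather than microbenchmark (base, fill_get and fill_set are hypothetical names):

```r
library(data.table)

set.seed(1)  # arbitrary seed, only for reproducibility
n <- 1e5     # hypothetical size, far smaller than the real table
base <- data.table(
  v1 = sample(c(NA, 1, 2, 3, 4, 5), n, replace = TRUE),
  v2 = sample(c(NA, 1, 2, 3, 4, 5), n, replace = TRUE),
  v3 = sample(c(NA, 1, 2, 3, 4, 5), n, replace = TRUE)
)

# Variant from the question: get() inside [.data.table
fill_get <- function(d) {
  for (k in 1:2) {
    d[is.na(get(paste0("v", k))), (paste0("v", k)) := get(paste0("v", k + 1)) + 2]
  }
  d
}

# Variant using set(), indexing columns by position
fill_set <- function(d) {
  for (k in 1:2) {
    rows <- which(is.na(d[[k]]))
    set(d, i = rows, j = k, value = d[[k + 1L]][rows] + 2)
  }
  d
}

# Each variant works on a fresh copy, so neither sees the other's updates
d1 <- copy(base); t_get <- system.time(fill_get(d1))
d2 <- copy(base); t_set <- system.time(fill_set(d2))

identical(d1$v1, d2$v1)  # both variants should produce the same columns
c(get = t_get[["elapsed"]], set = t_set[["elapsed"]])
```

The timings depend on the machine and the data, so treat the numbers as indicative only.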
    

    Since you update your data.table by reference, you get the same output in both cases. To extract the result, you can take the last element of the output list:

    output <-   lapply(1:2, function(i) dt[ is.na(get(paste0('v',i))), (paste0('v',i)):= get(paste0('v',i+1))+2 ])
    output <- output[[length(output)]]
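
To make the by-reference point concrete, here is a minimal sketch that reuses the toy table from the question and throws away the list returned by lapply(); dt is updated anyway. The names target and src are only illustrative.

```r
library(data.table)

dt <- data.table(v1 = c(1, 2, NA), v2 = c(0, 0, 1), v3 = c(0, 0, 0))

# Discard the returned list; := has already modified dt in place
invisible(lapply(1:2, function(k) {
  target <- paste0("v", k)
  src    <- paste0("v", k + 1)
  dt[is.na(get(target)), (target) := get(src) + 2]
}))

dt$v1  # the NA in row 3 has been replaced
```

So the extraction step is optional: it only gives you another name for the table you already have.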
    

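    There may be a more elegant way, but for something that took only a few seconds to write, cutting the mean computation time by a factor of 10 is a pretty good result ;)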
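
One candidate for a more compact version, offered as a sketch rather than a benchmarked recommendation: fcoalesce() returns the first non-NA value across its arguments, so it can express "keep the value unless it is NA". This assumes data.table 1.12.4 or later, where fcoalesce() is available, and that both columns have the same type.

```r
library(data.table)

dt <- data.table(v1 = c(1, 2, NA), v2 = c(0, 0, 1), v3 = c(0, 0, 0))

for (k in 1:2) {
  # keep non-NA values of column k, otherwise take column k + 1 plus 2
  set(dt, j = k, value = fcoalesce(dt[[k]], dt[[k + 1L]] + 2))
}
dt
```

Note that this computes the replacement for every row, not just the NA rows, so on a very wide or long table it is worth timing against the set() loop with row indices.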