Optimal method of comparing a vector of numbers to values in another vector
Suppose I have two vectors of values:
a <- c(1,3,4,5,6,7,3)
b <- c(3,5,1,3,2)
And I want to apply some function, FUN, to each of the elements of a against the whole of b. What's the most efficient way to do it?
More specifically, in this case I want to know, for each value of a, how many of the elements in b are less than that value. The naïve approach is to do the following:
sum(a < b)
Of course, this doesn't work: R just recycles the shorter vector and compares element-wise, giving me the warning:
longer object length is not a multiple of shorter object length
The output, btw, of that command is 3.
However, in my situation, what I'd like to see is an output that is:
0 2 4 4 5 5 2
Of course, I realize I can do it using a for loop as such:
out <- c()
for (i in a) {
  out[length(out) + 1] <- sum(b < i)
}
Likewise, I could use sapply
as such:
sapply(a, function(x)sum(b<x))
However, I'm trying to be a good R programmer and stay away from for loops, and sapply seems to be very slow. Are there other alternatives?
For what it's worth, I'm doing this a couple of million times where length(b)
is always less than length(a)
and length(a)
ranges from 1 to 30.
Try this:
findInterval(a - 0.5, sort(b))
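To see why this works with the question's data (note that the shift by 0.5 assumes the values are integers, so that "less than or equal to a - 0.5" is the same as "strictly less than a"):
> a <- c(1, 3, 4, 5, 6, 7, 3)
> b <- c(3, 5, 1, 3, 2)
> # for each element of a - 0.5, count how many elements of sort(b) are <= it
> findInterval(a - 0.5, sort(b))
[1] 0 2 4 4 5 5 2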
Speed improvement comes from a) avoiding sort, and b) avoiding the overhead in findInterval and order by using simpler .Internal wrappers:
# Thin wrappers that call the underlying .Internal/.C routines directly,
# skipping the argument checking done by order() and findInterval()
order2 = function(x) .Internal(order(T, F, x))

findInterval2 = function(x, vec, rightmost.closed=F, all.inside=F) {
  nx <- length(x)
  index <- integer(nx)
  .C('find_interv_vec', xt=as.double(vec), n=length(vec),
     x=as.double(x), nx=nx, as.logical(rightmost.closed),
     as.logical(all.inside), index, DUP = FALSE, NAOK=T,
     PACKAGE='base')
  index
}
> system.time(for (i in 1:10000) findInterval(a - 0.5, sort(b)))
user system elapsed
1.22 0.00 1.22
> system.time(for (i in 1:10000) sapply(a, function(x)sum(b<x)))
user system elapsed
0.79 0.00 0.78
> system.time(for (i in 1:10000) rowSums(outer(a, b, ">")))
user system elapsed
0.72 0.00 0.72
> system.time(for (i in 1:10000) findInterval(a - 0.5, b[order(b)]))
user system elapsed
0.42 0.00 0.42
> system.time(for (i in 1:10000) findInterval2(a - 0.5, b[order2(b)]))
user system elapsed
0.16 0.00 0.15
The complexity of defining findInterval2 and order2 is probably only warranted if you have heaps of iterations with fairly small N.
Also timings for larger N:
> a = rep(a, 100)
> b = rep(b, 100)
> system.time(for (i in 1:100) findInterval(a - 0.5, sort(b)))
user system elapsed
0.01 0.00 0.02
> system.time(for (i in 1:100) sapply(a, function(x)sum(b<x)))
user system elapsed
0.67 0.00 0.68
> system.time(for (i in 1:100) rowSums(outer(a, b, ">")))
user system elapsed
3.67 0.26 3.94
> system.time(for (i in 1:100) findInterval(a - 0.5, b[order(b)]))
user system elapsed
0 0 0
> system.time(for (i in 1:100) findInterval2(a - 0.5, b[order2(b)]))
user system elapsed
0 0 0
One option is to use outer() to apply the binary operator > to a and b:
> outer(a, b, ">")
[,1] [,2] [,3] [,4] [,5]
[1,] FALSE FALSE FALSE FALSE FALSE
[2,] FALSE FALSE TRUE FALSE TRUE
[3,] TRUE FALSE TRUE TRUE TRUE
[4,] TRUE FALSE TRUE TRUE TRUE
[5,] TRUE TRUE TRUE TRUE TRUE
[6,] TRUE TRUE TRUE TRUE TRUE
[7,] FALSE FALSE TRUE FALSE TRUE
The answer to the Q is then given by the row sums of the result above:
> rowSums(outer(a, b, ">"))
[1] 0 2 4 4 5 5 2
For this example data set, this solution is slightly faster than findInterval(), but not by much:
> system.time(replicate(1000, findInterval(a - 0.5, sort(b))))
user system elapsed
0.131 0.000 0.132
> system.time(replicate(1000, rowSums(outer(a, b, ">"))))
user system elapsed
0.078 0.000 0.079
It is also faster than the sapply() version, but only marginally:
> system.time(replicate(1000, sapply(a, function(x)sum(b<x))))
user system elapsed
0.082 0.000 0.082
@Charles notes that most of the time in the findInterval()
example is used by sort()
, which can be circumvented via order()
. When this is done, the findInterval()
solution is faster than the outer()
solution:
> system.time(replicate(1000, findInterval(a - 0.5, b[order(b)])))
user system elapsed
0.049 0.000 0.049
I'd be very wary of using the internals of R in production code. The internals can easily change between releases.
sort.int is faster than sort - and it's just plain weird that b[order(b)] is faster than sort.int(b). R could definitely improve its sorting...
And unless you use the internals of R, it seems like using vapply is actually faster:
> system.time(for (i in 1:10000) findInterval(a - 0.5, sort(b)))
user system elapsed
0.99 0.00 0.98
> system.time(for (i in 1:10000) findInterval(a - 0.5, sort.int(b)))
user system elapsed
0.8 0.0 0.8
> system.time(for (i in 1:10000) findInterval(a - 0.5, b[order(b)]))
user system elapsed
0.32 0.00 0.32
> system.time(for (i in 1:10000) sapply(a, function(x)sum(b<x)))
user system elapsed
0.61 0.00 0.59
> system.time(for (i in 1:10000) vapply(a, function(x)sum(b<x), 0L))
user system elapsed
0.18 0.00 0.19
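For reference, here is that vapply call on its own, using the question's original vectors; the 0L template tells vapply that each result must be a single integer, so it can pre-allocate the output instead of guessing how to simplify it the way sapply does:
> a <- c(1, 3, 4, 5, 6, 7, 3)
> b <- c(3, 5, 1, 3, 2)
> vapply(a, function(x) sum(b < x), 0L)
[1] 0 2 4 4 5 5 2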
Just an add-on note: if you know the range of values in each vector, then it might be quicker to compute the max and min first, e.g.
order2 = function(x) .Internal(order(T, F, x))

findInterval2 = function(x, vec, rightmost.closed=F, all.inside=F) {
  nx <- length(x)
  index <- integer(nx)
  .C('find_interv_vec', xt=as.double(vec), n=length(vec),
     x=as.double(x), nx=nx, as.logical(rightmost.closed),
     as.logical(all.inside), index, DUP = FALSE, NAOK=T,
     PACKAGE='base')
  index
}
f <- function(a, b) {
  # set up vars
  a.length <- length(a)
  b.length <- length(b)
  b.sorted <- b[order2(b)]
  b.min <- b.sorted[1]
  b.max <- b.sorted[b.length]
  results <- integer(a.length)
  # pre-process minimums: for a values no bigger than min(b), the count is 0
  v.min <- which(a <= b.min)
  # pre-process maximums: for a values above max(b), every element of b is smaller
  v.max <- which(a > b.max)
  results[v.max] <- b.length
  # compare the rest
  ind <- c(v.min, v.max)
  if (length(ind) > 0) {
    results[-ind] <- findInterval2(a[-ind] - 0.5, b.sorted)
  } else {
    results <- findInterval2(a - 0.5, b.sorted)
  }
  results
}
This gives the following timings:
> N <- 10
> n <- 1e5
> b <- runif(n, 0, 100)
> a <- runif(n, 40, 60) # NB smaller range of values than b
> summary( replicate(N, system.time(findInterval2(a - 0.5, b[order2(b)]))[3]) )
Min. 1st Qu. Median Mean 3rd Qu. Max.
0.0300 0.0300 0.0400 0.0390 0.0475 0.0500
> summary( replicate(N, system.time(f(a, b))[3]) )
Min. 1st Qu. Median Mean 3rd Qu. Max.
0.010 0.030 0.030 0.027 0.030 0.040
However, if you don't know the ranges ahead of time, or can't make an educated guess about them, then this would probably be slower.