
Python bisect.bisect() counterpart in R?

I want to draw from a discrete distribution.

I have a matrix, pi, whose rows are probability vectors (every row has the same number of columns, and each row sums to 1).

In Python, I can do the following:

    import bisect
    from numpy import cumsum, random
    cumsumpi = cumsum(pi, axis=1)  # row-wise cumulative probabilities
    [bisect.bisect(k, random.rand()) for k in cumsumpi]

to get a vector of draws, one per row, with the probabilities given by pi.
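A minimal, self-contained sketch of what the snippet above computes, assuming pi is a NumPy array whose rows each sum to 1. The helper name draw_rows is hypothetical, and it takes the uniforms as an argument instead of calling random.rand() inside the loop, so the same uniforms (for example ones exported from R's runif) can be fed in. bisect.bisect (that is, bisect_right) returns the number of elements <= u, which is the 0-based index of the first cumulative probability strictly greater than u: inverse-CDF sampling.

```python
import bisect

import numpy as np


def draw_rows(pi, uniforms):
    """Inverse-CDF draw for each row of pi (hypothetical helper).

    pi       -- 2-D array, each row a probability vector summing to 1
    uniforms -- one U(0, 1) number per row, e.g. exported from R's runif()
    Returns 0-based category indices, as bisect.bisect does.
    """
    cumsumpi = np.cumsum(pi, axis=1)  # row-wise cumulative probabilities
    # bisect_right: count of cumulative probabilities <= u in that row
    return [bisect.bisect(row, u) for row, u in zip(cumsumpi, uniforms)]


pi = np.array([[0.2, 0.3, 0.5],
               [0.6, 0.4, 0.0]])
print(draw_rows(pi, [0.1, 0.7]))  # prints [0, 1]
```

Adding 1 to each result gives the 1-based category numbers an R version would return.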

Now I want to reproduce this in R. I know R has the sample function, but it seems to use a different algorithm from bisect, so I get different draws even though I use the same set.seed() in both cases.

I used rpy2 to get exactly the same random draws in Python as in R. For example,

instead of random.rand(), I used

    [bisect.bisect(k, asarray(robjects.r('runif(1)'))) for k in cumsumpi]

Please let me know if R has a function other than sample that does the same thing.

-Joon

Edit: I managed to reproduce exactly the same draws with the following, but it was slow:

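    # row-wise cumulative probabilities (apply returns one column per row, hence t())
    cumsumpi = t(apply(pi, 1, cumsum))

    # 1-based index of the first cumulative probability exceeding a fresh
    # uniform draw; this equals bisect.bisect(k, u) + 1 in Python
    getfirstindx = function(cumprobs) {
        return(which(cumprobs > runif(1))[1])
    }

    apply(cumsumpi, 1, getfirstindx)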
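Why this R code reproduces the Python draws: which(cumprobs > runif(1))[1] is the 1-based position of the first cumulative probability strictly greater than the uniform, and bisect.bisect returns the 0-based position of that same element. The sketch below checks that correspondence on random inputs; first_index_gt is a hypothetical Python stand-in for getfirstindx, with u passed in rather than drawn inside.

```python
import bisect
import random


def first_index_gt(cumprobs, u):
    """1-based index of the first element > u, like which(cumprobs > u)[1] in R."""
    for i, c in enumerate(cumprobs, start=1):
        if c > u:
            return i
    return None  # R would give NA here (u >= every cumulative probability)


def check_equivalence(trials=1000, seed=1):
    """Compare first_index_gt with bisect.bisect on random probability rows."""
    rng = random.Random(seed)
    for _ in range(trials):
        weights = [rng.random() for _ in range(5)]
        total = sum(weights)
        cumprobs, acc = [], 0.0
        for w in weights:
            acc += w / total
            cumprobs.append(acc)
        u = rng.random()
        idx = first_index_gt(cumprobs, u)
        if idx is not None:
            # R's 1-based index is bisect's 0-based index plus one
            assert idx == bisect.bisect(cumprobs, u) + 1
    return True
```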


Here is an alternative approach that replaces the per-row apply search with a single vectorized comparison. Initial checks indicate that it is about twice as fast, but this needs to be explored in more detail.

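    cumsumpi = t(apply(pi, 1, cumsum))
    u = runif(nrow(cumsumpi))  # one uniform per row, drawn in the same order as the apply version

    # u recycles down the columns, so row i is compared with u[i];
    # the first 1 in each row is the drawn (1-based) category
    max.col((cumsumpi > u) * 1, ties.method = "first")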
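For comparison, a NumPy sketch of the same vectorized idea, assuming pi is a 2-D array and u holds one uniform per row. Comparing against u[:, None] broadcasts each row's uniform across its columns (R gets the same effect from column-major recycling when length(u) equals nrow(cumsumpi)), and argmax on a boolean array returns the first True, which plays the role of max.col(..., ties.method = "first"). Indices are 0-based; draw_first_exceed is a hypothetical name.

```python
import numpy as np


def draw_first_exceed(pi, u):
    """0-based index of the first cumulative probability > u in each row."""
    cumsumpi = np.cumsum(pi, axis=1)   # row-wise cumulative sums
    exceeds = cumsumpi > u[:, None]    # broadcast u[i] across row i
    # argmax on booleans returns the first True; like max.col(..., "first")
    # on an all-zero row, a row with no True yields the first column.
    return exceeds.argmax(axis=1)


pi = np.array([[0.2, 0.3, 0.5],
               [0.6, 0.4, 0.0]])
u = np.array([0.1, 0.7])
print(draw_first_exceed(pi, u))  # prints [0 1]
```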

To speed it up further, one could also vectorize the row-wise cumulative sums. Run a profiler such as Rprof on your R code and let me know whether that step is the bottleneck.
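One possible way to vectorize the row-wise cumulative sums, sketched in NumPy and assuming pi has K columns: multiplying by an upper-triangular matrix of ones U (U[i, j] = 1 when i <= j) puts, in column j, the sum of the first j + 1 entries of each row. In R the analogous product would be pi %*% U with U = 1 * upper.tri(diag(K), diag = TRUE); this is one illustrative approach, not necessarily what either poster used. In NumPy, np.cumsum(pi, axis=1) is already vectorized, so the product is shown only to illustrate the idea.

```python
import numpy as np


def row_cumsum_matmul(pi):
    """Row-wise cumulative sums via one matrix product (illustrative)."""
    k = pi.shape[1]
    upper = np.triu(np.ones((k, k)))  # upper[i, j] = 1 if i <= j, else 0
    return pi @ upper                 # column j = pi[:, 0] + ... + pi[:, j]


pi = np.array([[0.2, 0.3, 0.5],
               [0.6, 0.4, 0.0]])
print(np.allclose(row_cumsum_matmul(pi), np.cumsum(pi, axis=1)))  # prints True
```

The product costs O(n K^2) operations instead of O(n K), so it only pays off when K is small and a per-row loop is the real overhead.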


I can't reconcile your question's title with the question body; in any event, the gtools package has a binary search function, binsearch, that is nearly identical to Python's bisect, e.g.,

> library(gtools)
> # search for 25 in the range 0 through 100
> binsearch(fun = function(x) x - 25, range = c(0, 100))

$call
binsearch(fun = function(x) x - 25, range = c(0, 100))

$numiter
[1] 2

$flag
[1] "Found"

$where
[1] 25

$value
[1] 0
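To see how this differs from bisect, here is a rough Python sketch of a bisection search over an integer range for a point where a non-decreasing function hits a target, which is what the call above does with fun(x) = x - 25. binsearch_int is a hypothetical helper written for illustration; it is not the gtools implementation and does not reproduce its return value or iteration count. bisect, by contrast, returns an insertion point into a sorted sequence, which is what drawing from cumulative probabilities needs.

```python
def binsearch_int(fun, lo, hi, target=0):
    """Find an integer x in [lo, hi] with fun(x) == target, assuming fun is
    non-decreasing. Returns (x, True) if found, else (last midpoint, False)."""
    mid = lo
    while lo <= hi:
        mid = (lo + hi) // 2
        value = fun(mid)
        if value == target:
            return mid, True
        if value < target:
            lo = mid + 1  # target lies to the right
        else:
            hi = mid - 1  # target lies to the left
    return mid, False


print(binsearch_int(lambda x: x - 25, 0, 100))  # prints (25, True)
```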


What I was looking for was findInterval ("Find Interval Numbers or Indices"). :)
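With its default arguments, findInterval(x, vec) returns the number of elements of the sorted vec that are <= x, which is exactly what bisect.bisect (bisect_right) returns; adding 1 gives R's 1-based category. Because findInterval takes a single vec, it would be applied row by row, e.g. findInterval(u[i], cumsumpi[i, ]) + 1. The NumPy counterpart is np.searchsorted(row, u, side="right"); this sketch, with the hypothetical helper name draw_searchsorted, checks that it agrees with bisect.

```python
import bisect

import numpy as np


def draw_searchsorted(cumsumpi, u):
    """0-based draws via searchsorted(side="right"), i.e. bisect_right per row."""
    return np.array([np.searchsorted(row, ui, side="right")
                     for row, ui in zip(cumsumpi, u)])


rng = np.random.default_rng(0)
pi = rng.random((4, 3))
pi /= pi.sum(axis=1, keepdims=True)  # normalise rows to sum to 1
cumsumpi = np.cumsum(pi, axis=1)
u = rng.random(4)
expected = [bisect.bisect(row, ui) for row, ui in zip(cumsumpi, u)]
print(draw_searchsorted(cumsumpi, u).tolist() == expected)  # prints True
```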


I did not post it, but what I ended up using was pretty similar:

    cumsumpi = t(apply(pi, 1, cumsum))

    # 1-based index = 1 + number of cumulative probabilities <= u (as in bisect_right)
    1 + rowSums(cumsumpi <= runif(nrow(pi)))
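The counting form rests on the same identity: within one row, the number of cumulative probabilities <= u equals the result of bisect.bisect, so adding 1 gives R's 1-based index. Counting the entries greater than u instead gives 1 + K - c rather than 1 + c (with K columns and c the <= count), which is not the bisect index; for a small u it even returns K + 1. A vectorized NumPy sketch of the count, with the hypothetical name draw_by_count:

```python
import numpy as np


def draw_by_count(pi, u):
    """0-based draws: number of cumulative probabilities <= u in each row."""
    cumsumpi = np.cumsum(pi, axis=1)
    return (cumsumpi <= u[:, None]).sum(axis=1)


pi = np.array([[0.2, 0.3, 0.5],
               [0.6, 0.4, 0.0]])
u = np.array([0.1, 0.7])
print(draw_by_count(pi, u) + 1)  # 1-based, as in R; prints [1 2]
```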

The speed was about the same as your code. If I had known about max.col, I would have used that.

Following your suggestion, I also vectorized the cumulative sums, which gave a nontrivial speedup. Thank you.

-Joon
