Numpy, problem with long arrays

I have two arrays (a and b) with n integer elements in the range (0,N).

typo: arrays with 2^n integers where the largest integer takes the value N = 3^n

I want to calculate the sum of every combination of elements in a and b (sum_ij = a_i + b_j for all i,j), then reduce each sum modulo N (sum_ij = sum_ij % N), and finally calculate the frequency of the different sums.

In order to do this fast with numpy, without any loops, I tried to use the meshgrid and bincount functions.

A,B = numpy.meshgrid(a,b)
A = A + B
A = A % N
A = numpy.reshape(A,A.size)
result = numpy.bincount(A)

Now, the problem is that my input arrays are long, and meshgrid gives me a MemoryError when I use inputs with 2^13 elements. I would like to calculate this for arrays with 2^15-2^20 elements.

that is n in the range 15 to 20

Are there any clever tricks to do this with numpy?

Any help will be highly appreciated.

-- jon


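Try chunking it. Your meshgrid is a full len(b) x len(a) matrix; split it into a 10x10 grid of blocks (each one tenth of the rows by one tenth of the columns), compute the bin counts for each of the 100 blocks, and add them up at the end. Each block needs only ~1% as much memory as doing the whole thing at once.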
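A minimal sketch of the blocking idea, assuming N is small enough that an array of N counters fits in memory. The function name chunked_sum_counts and the block parameter are made up for illustration; they are not from the original post.

```python
import numpy as np

def chunked_sum_counts(a, b, N, block=1024):
    """Count (a_i + b_j) % N over all pairs, one block of pairs at a time."""
    counts = np.zeros(N, dtype=np.int64)
    for i in range(0, len(a), block):
        a_blk = a[i:i + block]
        for j in range(0, len(b), block):
            b_blk = b[j:j + block]
            # Broadcasting builds only a (block x block) table of sums,
            # never the full len(a) x len(b) table.
            sums = (a_blk[:, np.newaxis] + b_blk[np.newaxis, :]) % N
            counts += np.bincount(sums.ravel(), minlength=N)
    return counts

# Small illustrative inputs (hypothetical sizes, chosen so a full check is cheap).
rng = np.random.default_rng(0)
N = 7
a = rng.integers(0, N, size=50)
b = rng.integers(0, N, size=50)
counts = chunked_sum_counts(a, b, N, block=16)
print(counts)
```

Peak memory is set by block*block rather than by the array lengths, so block can be tuned to whatever fits comfortably in RAM.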


Edit in response to jonalm's comment:

jonalm: N~3^n not n~3^N. N is max element in a and n is number of elements in a.

n is ~ 2^20. If N is ~ 3^n then N is ~ 3^(2^20) > 10^(500207). Scientists estimate (http://www.stormloader.com/ajy/reallife.html) that there are only around 10^87 particles in the universe. So there is no (naive) way a computer can handle an int of size 10^(500207).
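As a rough check of that magnitude, the exponent can be estimated with logarithms instead of building the number itself. This is only a sketch of the arithmetic:

```python
import math

n = 2**20
# log10(3**n) = n * log10(3), so 3**n is roughly 10**exponent
exponent = n * math.log10(3)
print(exponent)
```

The exponent comes out a little under 500298, which is comfortably above the 10^(500207) bound quoted above.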

jonalm: I am however a bit curious about the pv() function you define. (I do not manage to run it as text.find() is not defined (guess it's in another module)). How does this function work and what is its advantage?

pv is a little helper function I wrote to debug the value of variables. It works like print() except when you say pv(x) it prints the literal variable name (or expression string), a colon, and then the variable's value.

If you put

#!/usr/bin/env python
import traceback
def pv(var):
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2]
    print('%s: %s'%(text[text.find('(')+1:-1],var))
x=1
pv(x)

in a script you should get

x: 1

The modest advantage of using pv over print is that it saves you typing. Instead of having to write

print('x: %s'%x)

you can just slap down

pv(x)

When there are multiple variables to track, it's helpful to label the variables. I just got tired of writing it all out.
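As an aside that is not part of the original answer: on Python 3.8 or newer, f-strings have a built-in "=" specifier that gives a similar labelled output without a helper. A small sketch:

```python
x = 1
label = f"{x=}"          # expression text, '=', then the value
expr = f"{x + 1=}"       # works for expressions too
print(label)
print(expr)
```

On older interpreters this syntax is not available, so a helper like pv is still needed there.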

The pv function works by using the traceback module to peek at the line of code used to call the pv function itself. (See http://docs.python.org/library/traceback.html#module-traceback) That line of code is stored as a string in the variable text. text.find() is a call to the usual string method find(). For instance, if

text='pv(x)'

then

text.find('(') == 2               # The index of the '(' in string text
text[text.find('(')+1:-1] == 'x'  # Everything in between the parentheses

I'm assuming n ~ 3^N, and n~2**20

The idea is to work modulo N. This cuts down on the size of the arrays. The second idea (important when n is huge) is to use numpy ndarrays of 'object' type because if you use an integer dtype you run the risk of overflowing the size of the maximum integer allowed.
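To illustrate the overflow concern, here is a small sketch comparing a fixed-width integer dtype with dtype=object. The value 2**62 is just an illustrative choice near the int64 limit.

```python
import numpy as np

big = 2**62
fixed = np.array([big], dtype=np.int64)
flexible = np.array([big], dtype=object)

wrapped = fixed * 4     # int64 arithmetic wraps around silently
exact = flexible * 4    # elements are Python ints, so no overflow

print(wrapped[0])
print(exact[0])
```

The price of dtype=object is speed: each element is an ordinary Python int, so arithmetic runs at Python speed rather than in compiled loops.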

#!/usr/bin/env python
import traceback
import numpy as np

def pv(var):
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2]
    print('%s: %s'%(text[text.find('(')+1:-1],var))

You can change n to be 2**20, but below I show what happens with small n so the output is easier to read.

n=100
# Solve n ~ 3^N for N, i.e. N = log(n)/log(3)
N=int(np.log(n)/np.log(3))
pv(N)
# N: 4

a=np.random.randint(N,size=n)
b=np.random.randint(N,size=n)
pv(a)
pv(b)
# a: [1 0 3 0 1 0 1 2 0 2 1 3 1 0 1 2 2 0 2 3 3 3 1 0 1 1 2 0 1 2 3 1 2 1 0 0 3
#  1 3 2 3 2 1 1 2 2 0 3 0 2 0 0 2 2 1 3 0 2 1 0 2 3 1 0 1 1 0 1 3 0 2 2 0 2
#  0 2 3 0 2 0 1 1 3 2 2 3 2 0 3 1 1 1 1 2 3 3 2 2 3 1]
# b: [1 3 2 1 1 2 1 1 1 3 0 3 0 2 2 3 2 0 1 3 1 0 0 3 3 2 1 1 2 0 1 2 0 3 3 1 0
#  3 3 3 1 1 3 3 3 1 1 0 2 1 0 0 3 0 2 1 0 2 2 0 0 0 1 1 3 1 1 1 2 1 1 3 2 3
#  3 1 2 1 0 0 2 3 1 0 2 1 1 1 1 3 3 0 2 2 3 2 0 1 3 1]

wa holds the number of 0s, 1s, 2s, 3s in a; wb holds the number of 0s, 1s, 2s, 3s in b.

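# minlength=N keeps both count arrays at length N even if some value never occurs
wa=np.bincount(a,minlength=N)
wb=np.bincount(b,minlength=N)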
pv(wa)
pv(wb)
# wa: [24 28 28 20]
# wb: [21 34 20 25]
result=np.zeros(N,dtype='object')

Think of a 0 as a token or chip. Similarly for 1,2,3.

Think of wa=[24 28 28 20] as meaning there is a bag with 24 0-chips, 28 1-chips, 28 2-chips, 20 3-chips.

You have a wa-bag and a wb-bag. When you draw a chip from each bag, you "add" them together and form a new chip. You "mod" the answer (modulo N).

Imagine taking a 1-chip from the wb-bag and adding it with each chip in the wa-bag.

1-chip + 0-chip = 1-chip
1-chip + 1-chip = 2-chip
1-chip + 2-chip = 3-chip
1-chip + 3-chip = 4-chip = 0-chip  (we are mod'ing by N=4)

Since there are 34 1-chips in the wb bag, when you add them against all the chips in the wa=[24 28 28 20] bag, you get

34*24 1-chips
34*28 2-chips
34*28 3-chips
34*20 0-chips

This is just the partial count due to the 34 1-chips. You also have to handle the other types of chips in the wb-bag, but this shows you the method used below:

for i,count in enumerate(wb):
    partial_count=count*wa
    pv(partial_count)
    shifted_partial_count=np.roll(partial_count,i)
    pv(shifted_partial_count)
    result+=shifted_partial_count
# partial_count: [504 588 588 420]
# shifted_partial_count: [504 588 588 420]
# partial_count: [816 952 952 680]
# shifted_partial_count: [680 816 952 952]
# partial_count: [480 560 560 400]
# shifted_partial_count: [560 400 480 560]
# partial_count: [600 700 700 500]
# shifted_partial_count: [700 700 500 600]

pv(result)    
# result: [2444 2504 2520 2532]

This is the final result: 2444 0s, 2504 1s, 2520 2s, 2532 3s.
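The steps above can be gathered into one reusable function. This is a sketch; the name pair_sum_counts is made up for illustration, and the inputs are rebuilt from the wa and wb counts shown above rather than from the original random arrays.

```python
import numpy as np

def pair_sum_counts(a, b, N):
    """Frequency of (a_i + b_j) % N over all pairs, using only O(N) memory."""
    wa = np.bincount(a, minlength=N).astype(object)  # Python ints: no overflow
    wb = np.bincount(b, minlength=N)
    result = np.zeros(N, dtype=object)
    for i, count in enumerate(wb):
        # Adding i to every value in a shifts the histogram of a by i (mod N).
        result += np.roll(int(count) * wa, i)
    return result

# Arrays with the same value counts as in the walkthrough:
# wa = [24 28 28 20], wb = [21 34 20 25]
a = np.repeat(np.arange(4), [24, 28, 28, 20])
b = np.repeat(np.arange(4), [21, 34, 20, 25])
result = pair_sum_counts(a, b, 4)
print(result)   # [2444 2504 2520 2532]
```

Only the value counts matter, not the order of the elements, which is why rebuilding a and b from the counts reproduces the same result.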

# This is a test to make sure the result is correct.
# This uses a very memory intensive method.
# c is too huge when n is large.
if n>1000:
    print('n is too large to run the check')
else:
    c=(a[:]+b[:,np.newaxis])
    c=c.ravel()
    c=c%N
    result2=np.bincount(c)
    pv(result2)
    assert(all(r1==r2 for r1,r2 in zip(result,result2)))
# result2: [2444 2504 2520 2532]
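The loop over wb amounts to a circular convolution of wa and wb. One way to write that without a Python loop, offered as a sketch and not as the answer's own method, is to take the ordinary convolution with np.convolve and fold the tail back onto the front. This uses fixed-width integers, so it assumes the counts are small enough not to overflow.

```python
import numpy as np

wa = np.array([24, 28, 28, 20])
wb = np.array([21, 34, 20, 25])
N = len(wa)

# Ordinary (linear) convolution has length 2N-1; entry k counts pairs
# whose sum is exactly k, before taking the modulus.
linear = np.convolve(wa, wb)

# Sums k >= N wrap around to k - N, so fold the tail onto the front.
folded = linear[:N].copy()
folded[:N - 1] += linear[N:]
print(folded)   # [2444 2504 2520 2532]
```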


Check your math, that's a lot of space you're asking for:

2^20*2^20 = 2^40 = 1 099 511 627 776

If each of your elements was just one byte, that's already one terabyte of memory.
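The same arithmetic as a short sketch, for a one-byte and an eight-byte element type (the dtype choices are illustrative):

```python
import numpy as np

n = 2**20
sizes = {}
for dtype in (np.int8, np.int64):
    # Full table of pairwise sums: n * n elements of this dtype
    sizes[np.dtype(dtype).name] = n * n * np.dtype(dtype).itemsize

for name, nbytes in sizes.items():
    print(name, nbytes / 2**40, "TiB")
```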

Add a loop or two. This problem is not suited to maxing out your memory and minimizing your computation.
