开发者

How does python numpy.where() work?

I am playing with numpy and digging through documentation and I have come across some magic. Namely I am talking about numpy.where():

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(arra开发者_运维知识库y([2, 2, 2]), array([0, 1, 2]))

How do they achieve internally that you are able to pass something like x > 5 into a method? I guess it has something to do with __gt__ but I am looking for a detailed explanation.


How do they achieve internally that you are able to pass something like x > 5 into a method?

The short answer is that they don't.

Any sort of logical operation on a numpy array returns a boolean array. (i.e. __gt__, __lt__, etc all return boolean arrays where the given condition is true).

E.g.

x = np.arange(9).reshape(3,3)
print x > 5

yields:

array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]], dtype=bool)

This is the same reason why something like if x > 5: raises a ValueError if x is a numpy array. It's an array of True/False values, not a single value.

Furthermore, numpy arrays can be indexed by boolean arrays. E.g. x[x>5] yields [6 7 8], in this case.

Honestly, it's fairly rare that you actually need numpy.where but it just returns the indicies where a boolean array is True. Usually you can do what you need with simple boolean indexing.


Old Answer it is kind of confusing. It gives you the LOCATIONS (all of them) of where your statment is true.

so:

>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
       48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
       99]),)
>>> np.where(a == 90)
(array([90]),)

a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
       77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
       94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040

I use it as an alternative to list.index(), but it has many other uses as well. I have never used it with 2D arrays.

http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

New Answer It seems that the person was asking something more fundamental.

The question was how could YOU implement something that allows a function (such as where) to know what was requested.

First note that calling any of the comparison operators do an interesting thing.

a > 1000
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`

This is done by overloading the "__gt__" method. For instance:

>>> class demo(object):
    def __gt__(self, item):
        print item


>>> a = demo()
>>> a > 4
4

As you can see, "a > 4" was valid code.

You can get a full list and documentation of all overloaded functions here: http://docs.python.org/reference/datamodel.html

Something that is incredible is how simple it is to do this. ALL operations in python are done in such a way. Saying a > b is equivalent to a.gt(b)!


np.where returns a tuple of length equal to the dimension of the numpy ndarray on which it is called (in other words ndim) and each item of tuple is a numpy ndarray of indices of all those values in the initial ndarray for which the condition is True. (Please don't confuse dimension with shape)

For example:

x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
      [3, 4, 5],
      [6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))


y is a tuple of length 2 because x.ndim is 2. The 1st item in tuple contains row numbers of all elements greater than 4 and the 2nd item contains column numbers of all items greater than 4. As you can see, [1,2,2,2] corresponds to row numbers of 5,6,7,8 and [2,0,1,2] corresponds to column numbers of 5,6,7,8 Note that the ndarray is traversed along first dimension(row-wise).

Similarly,

x=np.arange(27).reshape(3,3,3)
np.where(x>4)


will return a tuple of length 3 because x has 3 dimensions.

But wait, there's more to np.where!

when two additional arguments are added to np.where; it will do a replace operation for all those pairwise row-column combinations which are obtained by the above tuple.

x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
   [0, 0, 1],
   [1, 1, 1]])


I had a tough time understanding the output, that I received for an input.

import numpy as np
pp =  np.array([[True,False,True,True],
               [False,True,False,True]])
np.where(pp)

The output was:

(array([0, 0, 0, 1, 1]), array([0, 2, 3, 1, 3]))

The best way to understand this is to read out the output tuple pair by pair, i.e. (0,0);(0,2);(0,3);(1,1);(1,3) and voila, these are the coordinates where the condition was True.

So on so forth for higher dimensions.

0

上一篇:

下一篇:

精彩评论

暂无评论...
验证码 换一张
取 消

最新问答

问答排行榜