MATLAB-style find() function in Python
In MATLAB it is easy to find the indices of values that meet a particular condition:
>> a = [1,2,3,1,2,3,1,2,3];
>> find(a > 2) % find the indices where this condition is true
[3, 6, 9] % (MATLAB uses 1-based indexing)
>> a(find(a > 2)) % get the values at those locations
[3, 3, 3]
What would be the best way to do this in Python?
So far, I have come up with the following. To just get the values:
>>> a = [1,2,3,1,2,3,1,2,3]
>>> [val for val in a if val > 2]
[3, 3, 3]
But if I want the index of each of those values it's a bit more complicated:
>>> a = [1,2,3,1,2,3,1,2,3]
>>> inds = [i for (i, val) in enumerate(a) if val > 2]
>>> inds
[2, 5, 8]
>>> [val for (i, val) in enumerate(a) if i in inds]
[3, 3, 3]
Is there a better way to do this in Python, especially for arbitrary conditions (not just 'val > 2')?
I found functions equivalent to MATLAB 'find' in NumPy but I currently do not have access to those libraries.
In NumPy you have where:
>>> import numpy as np
>>> x = np.random.randint(0, 20, 10)
>>> x
array([14, 13, 1, 15, 8, 0, 17, 11, 19, 13])
>>> np.where(x > 10)
(array([0, 1, 3, 6, 7, 8, 9], dtype=int64),)
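Called with only a condition, np.where returns a tuple holding one index array per dimension, which is why the result above is wrapped in parentheses. A minimal sketch of pulling out the plain index array, using the values printed above as a fixed array (instead of random data) so that it is reproducible:

```python
import numpy as np

# Fixed values (the ones printed above) so the result is reproducible.
x = np.array([14, 13, 1, 15, 8, 0, 17, 11, 19, 13])

# One-argument form: a tuple with one index array per dimension.
(idx,) = np.where(x > 10)

# np.flatnonzero returns the index array directly (indices into the flattened array).
flat_idx = np.flatnonzero(x > 10)

# Indexing with the index array gives the matching values,
# much like a(find(a > 10)) in MATLAB.
values = x[idx]
```

Keep in mind that these indices are 0-based, whereas MATLAB's find returns 1-based positions.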
You can make a function that takes a callable parameter which will be used in the condition part of your list comprehension. Then you can use a lambda or other function object to pass your arbitrary condition:
>>> def indices(a, func):
...     return [i for (i, val) in enumerate(a) if func(val)]
...
>>> a = [1, 2, 3, 1, 2, 3, 1, 2, 3]
>>> inds = indices(a, lambda x: x > 2)
>>> inds
[2, 5, 8]
It's a little closer to your MATLAB example, without having to load all of NumPy.
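Once you have the indices, the values can be read back by plain indexing rather than by scanning the list again. A small sketch, with the indices helper repeated so the block runs on its own; the "even" condition is just an illustrative second predicate:

```python
def indices(a, func):
    """Return the 0-based positions of the items in a for which func(item) is true."""
    return [i for i, val in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]
inds = indices(a, lambda x: x > 2)

# Look each index up directly instead of re-scanning the list.
vals = [a[i] for i in inds]

# Any callable works as the condition, e.g. "value is even".
even_inds = indices(a, lambda x: x % 2 == 0)
```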
Or use numpy's nonzero function:
>>> import numpy as np
>>> a = np.array([1,2,3,4,5])
>>> inds = np.nonzero(a>2)
>>> a[inds]
array([3, 4, 5])
Why not just use this:
[i for i in range(len(a)) if a[i] > 2]
or, for arbitrary conditions, define a function f for your condition and do:
[i for i in range(len(a)) if f(a[i])]
The NumPy routine more commonly used for this is numpy.where(); when called with only a condition, it is equivalent to numpy.nonzero().
import numpy
a = numpy.array([1,2,3,4,5])
inds = numpy.where(a>2)
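To get the values, you can store the indices and index with them:
a[inds]
numpy.where() also accepts two further arrays, in which case it returns values instead of indices, taking each element from the first array where the condition is true and from the second elsewhere. Both arrays must be given; numpy.where(a>2, a) on its own raises a ValueError.
b = numpy.array([11,22,33,44,55])
numpy.where(a>2, a, b)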
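A runnable sketch of the two call forms of numpy.where, using the arrays from this answer. The condition is passed either alone (giving indices) or together with both alternatives (giving an element-wise selection); boolean-mask indexing is included as a common shortcut for the values:

```python
import numpy as np

a = np.array([1, 2, 3, 4, 5])
b = np.array([11, 22, 33, 44, 55])

# One-argument form: a tuple of index arrays.
inds = np.where(a > 2)
selected = a[inds]          # values of a where the condition holds

# Three-argument form: take from a where the condition holds, from b elsewhere.
merged = np.where(a > 2, a, b)

# Boolean-mask indexing gives the same values without computing indices first.
masked = a[a > 2]
```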
To get values with arbitrary conditions, you could use filter() with a lambda function (in Python 3, wrap it in list(), because filter() returns an iterator):
>>> a = [1,2,3,1,2,3,1,2,3]
>>> list(filter(lambda x: x > 2, a))
[3, 3, 3]
One possible way to get the indices would be to use enumerate() to build a tuple with both indices and values, and then filter that:
>>> a = [1,2,3,1,2,3,1,2,3]
>>> aind = tuple(enumerate(a))
>>> print(aind)
((0, 1), (1, 2), (2, 3), (3, 1), (4, 2), (5, 3), (6, 1), (7, 2), (8, 3))
>>> tuple(filter(lambda x: x[1] > 2, aind))
((2, 3), (5, 3), (8, 3))
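The filtered (index, value) pairs can be split back into separate index and value lists with zip(). A minimal sketch, assuming Python 3; the variable names are only illustrative:

```python
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

# Keep only the (index, value) pairs whose value satisfies the condition.
pairs = [(i, v) for i, v in enumerate(a) if v > 2]

# zip(*pairs) transposes the pairs into (indices, values).
# Guard against the no-match case, where there is nothing to unpack.
if pairs:
    inds, vals = map(list, zip(*pairs))
else:
    inds, vals = [], []
```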
I've been trying to figure out a fast way to do this exact thing, and here is what I stumbled upon (uses numpy for its fast vector comparison):
a_bool = numpy.array(a) > 2
inds = [i for (i, val) in enumerate(a_bool) if val]
It turns out that this is much faster than:
inds = [i for (i, val) in enumerate(a) if val > 2]
It seems that Python is faster at comparison when done in a numpy array, and/or faster at doing list comprehensions when just checking truth rather than comparison.
Edit:
I was revisiting my code and I came across a possibly less memory-intensive, slightly faster, and very concise way of doing this in one line (the comparison needs a NumPy array, hence the asarray call):
inds = np.arange(len(a))[np.asarray(a) > 2]
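Relative speed is easy to check with timeit. The sketch below is not a measurement from this answer: the test data, the function names, and the repeat count are made up for illustration, the list is converted to an array once outside the timed code, and np.flatnonzero is added as a direct NumPy alternative. It prints timings without claiming any particular result, since those depend on the machine and the NumPy version.

```python
import timeit
import numpy as np

a = list(range(10)) * 1000   # hypothetical test data: 10,000 integers
arr = np.array(a)            # converted once, outside the timed functions

def pure_python():
    return [i for i, val in enumerate(a) if val > 2]

def numpy_mask_then_enumerate():
    # Vectorized comparison, then a Python loop over the boolean result.
    return [i for i, val in enumerate(arr > 2) if val]

def numpy_flatnonzero():
    # Comparison and index extraction both done inside NumPy.
    return np.flatnonzero(arr > 2)

# All three approaches should agree on the indices.
expected = pure_python()
assert numpy_mask_then_enumerate() == expected
assert numpy_flatnonzero().tolist() == expected

for fn in (pure_python, numpy_mask_then_enumerate, numpy_flatnonzero):
    seconds = timeit.timeit(fn, number=20)
    print(f"{fn.__name__}: {seconds:.4f} s for 20 runs")
```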
I think I may have found a quick and simple substitute. By the way, I find np.where() not very satisfactory here, because for a 2-D input like the one below it also returns an array of row indices that are all zero. Note that newer Matplotlib releases no longer include mlab.find; np.flatnonzero(a<0) returns the same flat indices.
import numpy as np
import matplotlib.mlab as mlab
a = np.random.randn(1,5)
print(a)
>> [[ 1.36406736 1.45217257 -0.06896245 0.98429727 -0.59281957]]
idx = mlab.find(a<0)
print(idx)
print(type(idx))
>> [2 4]
>> <class 'numpy.ndarray'>
Best, Da
MATLAB's find takes two arguments. John's code accounts for the first argument but not the second, which limits the result to the first n matches. For instance, to get only the first index at which the condition is satisfied, the MATLAB call would be:
find(x>2,1)
Using John's code, all you have to do is add [0] to the result of the indices function to get the first match (or slice it with [:n] to get the first n matches).
def indices(a, func):
    return [i for (i, val) in enumerate(a) if func(val)]
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]
inds = indices(a, lambda x: x > 2)[0]  # [0] plays the role of MATLAB's second argument (n = 1)
which returns 2, the index of the first value greater than 2.
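If you only need the first few matches, the search can also stop early instead of building the full index list first. A minimal sketch using itertools.islice; the function name find and its optional n parameter are my own choices, meant to mirror the MATLAB call rather than any existing Python API:

```python
from itertools import islice

def find(a, func, n=None):
    """Return the 0-based indices where func(item) is true; at most n of them if n is given."""
    matches = (i for i, val in enumerate(a) if func(val))
    # islice stops consuming the generator once n matches have been produced;
    # with n=None it yields every match.
    return list(islice(matches, n))

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]
first = find(a, lambda x: x > 2, 1)       # like find(a > 2, 1), but 0-based
first_two = find(a, lambda x: x > 2, 2)
all_matches = find(a, lambda x: x > 2)
none_found = find(a, lambda x: x > 5, 1)  # no match gives an empty list
```

Unlike indexing the full result with [0], this returns an empty list rather than raising an IndexError when nothing matches.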