Pythonic way to test if a row is in an array
This seems like a simple question, but I haven't been able to find a good answer.
I'm looking for a pythonic way to test whether a 2d numpy array contains a given row. For example:
myarray = numpy.array([[0,1],
                       [2,3],
                       [4,5]])
myrow1 = numpy.array([2,3])
myrow2 = numpy.array([2,5])
myrow3 = numpy.array([0,3])
myrow4 = numpy.array([6,7])
Given myarray, I want to write a function that returns True if I test myrow1, and False if I test myrow2, myrow3 and myrow4.
I tried the "in" keyword, and it didn't give me the results I expected:
>>> myrow1 in myarray
True
>>> myrow2 in myarray
True
>>> myrow3 in myarray
True
>>> myrow4 in myarray
False
It seems to only check if one or more of the elements are the same, not if all elements are the same. Can someone explain why that's happening?
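For what it's worth, as far as I know NumPy implements "row in array" as (array == row).any(). The row is broadcast against every row of the array, and the result is True if any single element matches anywhere. That is why myrow2 and myrow3 report True. A quick sketch that reproduces the session above (myarray is the array from the question):

```python
import numpy as np

myarray = np.array([[0, 1],
                    [2, 3],
                    [4, 5]])

for candidate in ([2, 3], [2, 5], [0, 3], [6, 7]):
    row = np.array(candidate)
    # Elementwise comparison broadcasts row against every row of myarray,
    # then .any() asks whether a single element matched anywhere
    any_element = (myarray == row).any()
    # 'in' gives the same answer for every candidate
    print(candidate, row in myarray, any_element)
```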
I can do this test element by element, something like this:
def test_for_row(array, row):
    return numpy.any(numpy.logical_and(array[:,0] == row[0], array[:,1] == row[1]))
But that's not very pythonic, and becomes problematic if the rows have many elements. There must be a more elegant solution. Any help is appreciated!
The SO question below should help you out, but basically you can use:
any((myrow1 == x).all() for x in myarray)
Numpy.Array in Python list?
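Wrapped in a function and run against the test rows from the question, the one-liner above might look like this. The name row_in_array_loop is hypothetical; the body is just the expression from this answer.

```python
import numpy as np

def row_in_array_loop(myarray, row):
    """Return True if some row of myarray equals row element for element."""
    row = np.asarray(row)
    # Built-in any() stops at the first row where every element matches
    return any((row == x).all() for x in myarray)

myarray = np.array([[0, 1], [2, 3], [4, 5]])
print(row_in_array_loop(myarray, [2, 3]))  # True
print(row_in_array_loop(myarray, [2, 5]))  # False
```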
You can simply subtract your test row from the array, mark the elements that are zero, and count them along each row. The matches are the rows where that count equals the number of columns.
For example:
In []: A= arange(12).reshape(4, 3)
In []: A
Out[]:
array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
In []: 3== (0== (A- [3, 4, 5])).sum(1)
Out[]: array([False, True, False, False], dtype=bool)
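The same idea as a small function, a sketch that reads the number of columns from A.shape[1] instead of hard-coding 3. The helper name matching_rows is hypothetical; calling .any() on the mask turns it into the single True/False the question asks for.

```python
import numpy as np

def matching_rows(A, row):
    """Boolean mask of the rows of A that equal row."""
    row = np.asarray(row)
    # Zero wherever A agrees with row; count the zeros in each row
    zeros_per_row = (A - row == 0).sum(axis=1)
    # A full match has as many zeros as there are columns
    return zeros_per_row == A.shape[1]

A = np.arange(12).reshape(4, 3)
mask = matching_rows(A, [3, 4, 5])
print(mask.tolist())  # [False, True, False, False]
print(mask.any())     # True
```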
Update: based on comments and other answers:
Paul's suggestion does indeed seem to streamline the code:
In []: ~np.all(A- [3, 4, 5], 1)
Out[]: array([False, True, False, False], dtype=bool)
JoshAdel's answer addresses the more general problem of determining equality in a 100% reliable manner. So my answer is obviously valid only in situations where equality can be determined unambiguously.
Update 2: But as Emma figured out, there are corner cases where Paul's solution does not produce correct results.
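One corner case that is easy to reproduce (my own example, not necessarily the one Emma found): np.all(A - row, 1) is False as soon as a single difference in a row is zero, so its negation flags any row that shares just one element, in the same position, with the test row. The zero-counting version does not have this problem. Negating np.any instead of np.all (my suggestion, not from the thread) keeps Paul's short form but requires every difference to be zero.

```python
import numpy as np

A = np.arange(12).reshape(4, 3)
row = np.array([3, 0, 0])  # shares only its first element with A[1]

# Paul's streamlined version: True if ANY difference in a row is zero
partial = ~np.all(A - row, 1)
# Zero-counting version: True only if ALL differences in a row are zero
exact = (A - row == 0).sum(1) == A.shape[1]
# Suggested fix: True only if NO difference in a row is nonzero
fixed = ~np.any(A - row, 1)

print(partial.tolist())  # [False, True, False, False]  (wrong)
print(exact.tolist())    # [False, False, False, False]
print(fixed.tolist())    # [False, False, False, False]
```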
This is a generalization of @maz's solution that handles floats more elegantly, where strict equality is going to fail:
import numpy as np
def test_for_row(myarray, row):
    return any(np.allclose(row, x) for x in myarray)
See http://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html for details. Also, as a side note, be careful that you haven't done something like from numpy import *, since np.any and Python's built-in any give different answers here, the former being incorrect.
I ran into the same problem, and the following approach works for me:
def is_row_in_matrix(row, matrix):
    return sum(np.prod(matrix == row, axis=1))
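Basically, this compares each element of the row with the corresponding column of the matrix, multiplies the results across each row (axis=1) so that only rows where every element matches give 1, and sums the result. The return value is therefore the number of matching rows, and any nonzero value means the row is present.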
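A quick check with a duplicated row (illustrative data). Wrap the call in bool() if you need a strict True/False rather than a count.

```python
import numpy as np

def is_row_in_matrix(row, matrix):
    # matrix == row is True where an element matches; the product along
    # axis=1 is 1 only for rows where every element matches
    return sum(np.prod(matrix == row, axis=1))

matrix = np.array([[0, 1], [2, 3], [4, 5], [2, 3]])
print(is_row_in_matrix([2, 3], matrix))        # 2 (a count, not a bool)
print(bool(is_row_in_matrix([2, 5], matrix)))  # False
```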
How about:
def row_in_array(myarray, myrow):
    return (myarray == myrow).all(-1).any()
This is what it looks like for your test cases:
myarray = numpy.array([[0,1],
                       [2,3],
                       [4,5]])
row_in_array(myarray, [2, 3])
# True
row_in_array(myarray, [2, 5])
# False
row_in_array(myarray, [0, 3])
# False
row_in_array(myarray, [6, 7])
# False
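If the rows hold floats, the same vectorized pattern can be combined with tolerance checking by swapping == for np.isclose, which broadcasts like == does. This is my own combination of the allclose answer and this one, and both function names are hypothetical. The second helper uses np.flatnonzero to return the positions of exact matches in case you need more than True/False.

```python
import numpy as np

def row_in_array_close(myarray, myrow, rtol=1e-05, atol=1e-08):
    """Tolerant, vectorized variant of row_in_array."""
    # np.isclose compares elementwise with broadcasting, so no Python loop
    return np.isclose(myarray, myrow, rtol=rtol, atol=atol).all(-1).any()

def row_indices(myarray, myrow):
    # Indices of every row that matches myrow exactly
    return np.flatnonzero((myarray == myrow).all(-1))

myarray = np.array([[0, 1], [2, 3], [4, 5]])
print(row_in_array_close(myarray, [2.0, 3.0 + 1e-12]))  # True
print(row_indices(myarray, [4, 5]).tolist())            # [2]
```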