Constructing a Python set from a Numpy matrix
I'm trying to execute the following:
>> from numpy import *
>> x = array([[3,2,3],[4,4,4]])
>> y = set(x)
TypeError: unhashable type: 'numpy.ndarray'
How can I easily and efficiently create a set with all the elements from the Numpy array?
If you want a set of the elements, here is another, probably faster way:
y = set(x.flatten())
PS: after performing comparisons between x.flat, x.flatten(), and x.ravel() on a 10x100 array, I found out that they all perform at about the same speed. For a 3x3 array, the fastest version is the iterator version:
y = set(x.flat)
which I would recommend because it is the least memory-hungry version: it iterates over the elements without building an intermediate flattened copy, so it scales well with the size of the array.
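As a quick check that the three flattening approaches are interchangeable for this purpose, here is a minimal sketch that builds the set each way from the array in the question. The variable names are only illustrative, and no timing is attempted here.

```python
import numpy as np

x = np.array([[3, 2, 3], [4, 4, 4]])

# x.flat is an iterator over the elements; no intermediate array is built.
set_from_flat = set(x.flat)
# x.flatten() always returns a flattened copy of the data.
set_from_flatten = set(x.flatten())
# x.ravel() returns a view when possible, otherwise a copy.
set_from_ravel = set(x.ravel())

print(set_from_flat == set_from_flatten == set_from_ravel)  # True
print(sorted(int(e) for e in set_from_flat))  # [2, 3, 4]
```

Note that the set elements are NumPy scalars rather than plain Python ints; they still compare and hash equal to the corresponding ints.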
PPS: There is also a NumPy function that does something similar:
y = numpy.unique(x)
This returns the same elements as set(x.flat), but as a sorted NumPy array. This is very fast (almost 10 times faster), but if you need a set, then doing set(numpy.unique(x)) is a bit slower than the other procedures (building a set comes with a large overhead).
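To make the difference concrete, the sketch below shows what numpy.unique returns for the array in the question and one way to turn the result into a set. The tolist() step is my own addition, used so that the set holds plain Python ints.

```python
import numpy as np

x = np.array([[3, 2, 3], [4, 4, 4]])

# np.unique flattens the input and returns its distinct values, sorted.
unique_values = np.unique(x)
print(unique_values)  # [2 3 4]

# tolist() converts to plain Python ints before the set is built.
y = set(unique_values.tolist())
print(y == {2, 3, 4})  # True
```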
The immutable counterpart to an array is the tuple; hence, try converting the array of arrays into a sequence of tuples:
>> from numpy import *
>> x = array([[3,2,3],[4,4,4]])
>> x_hashable = map(tuple, x)
>> y = set(x_hashable)
set([(3, 2, 3), (4, 4, 4)])
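The same idea can be sketched for Python 3, where map returns a lazy iterator that set() consumes directly. The input below adds a duplicate row (an assumption of mine, not part of the question) to show the deduplication, and x.tolist() is used so the tuples contain plain Python ints.

```python
import numpy as np

x = np.array([[3, 2, 3], [4, 4, 4], [3, 2, 3]])

# Each row becomes a tuple; tolist() first yields plain Python ints.
rows = set(map(tuple, x.tolist()))
print(rows == {(3, 2, 3), (4, 4, 4)})  # True
print(len(rows))  # 2
```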
The above answers work if you want to create a set out of the elements contained in an ndarray, but if you want to create a set of ndarray objects – or use ndarray objects as keys in a dictionary – then you'll have to provide a hashable wrapper for them. See the code below for a simple example:
from hashlib import sha1
from numpy import all, array, uint8

class hashable(object):
    r'''Hashable wrapper for ndarray objects.

    Instances of ndarray are not hashable, meaning they cannot be added to
    sets, nor used as keys in dictionaries. This is by design - ndarray
    objects are mutable, and therefore cannot reliably implement the
    __hash__() method.

    The hashable class allows a way around this limitation. It implements
    the required methods for hashable objects in terms of an encapsulated
    ndarray object. This can be either a copied instance (which is safer)
    or the original object (which requires the user to be careful enough
    not to modify it).
    '''
    def __init__(self, wrapped, tight=False):
        r'''Creates a new hashable object encapsulating an ndarray.

        wrapped
            The wrapped ndarray.
        tight
            Optional. If True, a copy of the input ndarray is created.
            Defaults to False.
        '''
        self.__tight = tight
        self.__wrapped = array(wrapped) if tight else wrapped
        # The hash is computed once, from the raw bytes of the array.
        self.__hash = int(sha1(wrapped.view(uint8)).hexdigest(), 16)

    def __eq__(self, other):
        return all(self.__wrapped == other.__wrapped)

    def __hash__(self):
        return self.__hash

    def unwrap(self):
        r'''Returns the encapsulated ndarray.

        If the wrapper is "tight", a copy of the encapsulated ndarray is
        returned. Otherwise, the encapsulated ndarray itself is returned.
        '''
        if self.__tight:
            return array(self.__wrapped)
        return self.__wrapped
Using the wrapper class is simple enough:
>>> from numpy import arange
>>> a = arange(0, 1024)
>>> d = {}
>>> d[a] = 'foo'
Traceback (most recent call last):
File "<input>", line 1, in <module>
TypeError: unhashable type: 'numpy.ndarray'
>>> b = hashable(a)
>>> d[b] = 'bar'
>>> d[b]
'bar'
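For comparison, here is a minimal sketch of a different technique from the wrapper class above: instead of wrapping the array, it derives an immutable key from the array's shape, dtype, and raw bytes. The helper name array_key is hypothetical and not part of NumPy.

```python
import numpy as np

def array_key(arr):
    """Hypothetical helper: build an immutable, hashable key from an ndarray."""
    arr = np.asarray(arr)
    # Shape and dtype are included so that arrays with identical bytes but
    # different layouts do not produce the same key.
    return (arr.shape, arr.dtype.str, arr.tobytes())

a = np.arange(0, 1024)
d = {array_key(a): 'bar'}

# An equal array built separately maps to the same key.
print(d[array_key(np.arange(0, 1024))])  # bar
# A reshaped array has the same bytes but a different shape, so a different key.
print(array_key(a.reshape(32, 32)) in d)  # False
```

The key is a snapshot: later changes to the array do not affect it, at the cost of copying the data once per key.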
If you want a set of the elements:
>> y = set(e for r in x for e in r)
set([2, 3, 4])
For a set of the rows:
>> y = set(tuple(r) for r in x)
set([(3, 2, 3), (4, 4, 4)])
I liked xperroni's idea, but I think the implementation can be simplified by inheriting directly from ndarray instead of wrapping it.
from hashlib import sha1
from numpy import ndarray, array

class HashableNdarray(ndarray):
    @classmethod
    def create(cls, array):
        return HashableNdarray(shape=array.shape, dtype=array.dtype, buffer=array.copy())

    def __hash__(self):
        if not hasattr(self, '_HashableNdarray__hash'):
            self.__hash = int(sha1(self.view()).hexdigest(), 16)
        return self.__hash

    def __eq__(self, other):
        if not isinstance(other, HashableNdarray):
            return super().__eq__(other)
        return super().__eq__(super(HashableNdarray, other)).all()
A NumPy ndarray can be viewed as the derived class and used as a hashable object. view(ndarray) can be used for the back transformation, but it is not even needed in most cases.
>>> a = array([1,2,3])
>>> b = array([2,3,4])
>>> c = array([1,2,3])
>>> s = set()
>>> s.add(a.view(HashableNdarray))
>>> s.add(b.view(HashableNdarray))
>>> s.add(c.view(HashableNdarray))
>>> print(s)
{HashableNdarray([2, 3, 4]), HashableNdarray([1, 2, 3])}
>>> d = next(iter(s))
>>> print(d == a)
[False False False]
>>> import ctypes
>>> print(d.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
<__main__.LP_c_double object at 0x7f99f4dbe488>
Adding to @Eric Lebigot and his great post.
The following did the trick for building a tensor lookup table:
import numpy as np
a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]])
np.unique(a, axis=0)
output:
array([[1, 0, 0], [2, 3, 4]])
np.unique documentation
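As a sketch of how the unique rows can serve as a lookup table, np.unique can also return, for each original row, its index into the table via return_inverse=True. The names table and inverse are illustrative, and the reshape(-1) is a precaution on my part, since the shape of the inverse array has differed between NumPy releases.

```python
import numpy as np

a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]])

# table holds the distinct rows (sorted lexicographically); inverse holds,
# for each original row, its index into table.
table, inverse = np.unique(a, axis=0, return_inverse=True)
inverse = inverse.reshape(-1)  # precaution: normalise to a 1-D index array

print(table.tolist())    # [[1, 0, 0], [2, 3, 4]]
print(inverse.tolist())  # [0, 0, 1]
# Indexing the table with the inverse reconstructs the original array.
print(np.array_equal(table[inverse], a))  # True
```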