Iteration through all 1 dimensional subarrays of a multi-dimensional array
What is the fastest way to iterate through all one-dimensional sub-arrays of an n-dimensional array in Python?
For example consider the 3-D array:
import numpy as np
a = np.arange(24)
a = a.reshape(2,3,4)
The desired sequence of yields from the iterator is:
a[:,0,0]
a[:,0,1]
..
a[:,2,3]
a[0,:,0]
..
a[1,:,3]
a[0,0,:]
..
a[1,2,:]
Here is a compact implementation of such an iterator:
import itertools
import numpy

def iter1d(a):
    # For each axis, move it to the end and flatten the rest into rows.
    return itertools.chain.from_iterable(
        numpy.rollaxis(a, axis, a.ndim).reshape(-1, dim)
        for axis, dim in enumerate(a.shape))
This will yield the subarrays in the order you gave in your post:
for x in iter1d(a):
    print(x)
prints
[ 0 12]
[ 1 13]
[ 2 14]
[ 3 15]
[ 4 16]
[ 5 17]
[ 6 18]
[ 7 19]
[ 8 20]
[ 9 21]
[10 22]
[11 23]
[0 4 8]
[1 5 9]
[ 2 6 10]
[ 3 7 11]
[12 16 20]
[13 17 21]
[14 18 22]
[15 19 23]
[0 1 2 3]
[4 5 6 7]
[ 8 9 10 11]
[12 13 14 15]
[16 17 18 19]
[20 21 22 23]
The trick here is to iterate over all axes, and for each axis reshape the array to a two-dimensional array the rows of which are the desired one-dimensional subarrays.
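To make the reshape step concrete, here is a minimal sketch for a single axis. It uses np.moveaxis(a, axis, -1), which I am assuming is available in your numpy version and which gives the same result as numpy.rollaxis(a, axis, a.ndim); the names moved and rows are only for illustration.

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)

# Move axis 0 to the end: shape (2, 3, 4) -> (3, 4, 2).
moved = np.moveaxis(a, 0, -1)

# Collapse the leading axes so that each row is one 1-D subarray along axis 0.
rows = moved.reshape(-1, a.shape[0])

print(rows.shape)  # (12, 2)
print(rows[0])     # same values as a[:, 0, 0]
print(rows[5])     # same values as a[:, 1, 1]
```

Note that reshape may have to copy the data when a view is not possible, so it is safest to treat the yielded rows as read-only.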
There may be a more efficient way, but this should work...
import itertools
import numpy as np
a = np.arange(24)
a = a.reshape(2,3,4)
colon = slice(None)
# list(...) is needed on Python 3, where range() does not return a list
dimensions = [list(range(dim)) + [colon] for dim in a.shape]
for dim in itertools.product(*dimensions):
    if dim.count(colon) == 1:
        print(a[dim])
This yields (I'm leaving out a trivial bit of code to print the left hand side of this...):
a[0,0,:] --> [0 1 2 3]
a[0,1,:] --> [4 5 6 7]
a[0,2,:] --> [ 8 9 10 11]
a[0,:,0] --> [0 4 8]
a[0,:,1] --> [1 5 9]
a[0,:,2] --> [ 2 6 10]
a[0,:,3] --> [ 3 7 11]
a[1,0,:] --> [12 13 14 15]
a[1,1,:] --> [16 17 18 19]
a[1,2,:] --> [20 21 22 23]
a[1,:,0] --> [12 16 20]
a[1,:,1] --> [13 17 21]
a[1,:,2] --> [14 18 22]
a[1,:,3] --> [15 19 23]
a[:,0,0] --> [ 0 12]
a[:,0,1] --> [ 1 13]
a[:,0,2] --> [ 2 14]
a[:,0,3] --> [ 3 15]
a[:,1,0] --> [ 4 16]
a[:,1,1] --> [ 5 17]
a[:,1,2] --> [ 6 18]
a[:,1,3] --> [ 7 19]
a[:,2,0] --> [ 8 20]
a[:,2,1] --> [ 9 21]
a[:,2,2] --> [10 22]
a[:,2,3] --> [11 23]
The key here is that indexing a with (for example) a[0,0,:] is equivalent to indexing a with a[(0,0,slice(None))]. (This is just generic python slicing, nothing numpy-specific. To prove it to yourself, you can write a dummy class with just a __getitem__ and print what's passed in when you index an instance of your dummy class.)
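The dummy-class experiment could look something like this. ShowIndex and probe are made-up names for the sketch; the class simply hands back whatever the [] operator received.

```python
class ShowIndex:
    """Hypothetical helper: returns whatever is passed to the [] operator."""

    def __getitem__(self, index):
        return index


probe = ShowIndex()

# The colon arrives as a slice object inside an ordinary tuple.
print(probe[0, 0, :])                         # (0, 0, slice(None, None, None))
print(probe[0, 0, :] == (0, 0, slice(None)))  # True
```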
So, what we want is every possible combination of 0 to nx-1, 0 to ny-1, 0 to nz-1, etc., plus a slice(None) for each axis.
However, we want 1D arrays, so we need to filter out anything with more or less than one slice(None) (i.e. we don't want a[:,:,:], a[0,:,:], a[0,0,0], etc.).
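As a rough check of how much the filter throws away, this sketch builds the index tuples for the (2, 3, 4) example without touching numpy at all. The variable names are just for illustration.

```python
import itertools

shape = (2, 3, 4)
colon = slice(None)

# Every axis may take any valid integer index or a full slice.
choices = [list(range(n)) + [colon] for n in shape]
all_indices = list(itertools.product(*choices))

# Keep only the index tuples with exactly one full slice (the 1-D subarrays).
kept = [idx for idx in all_indices if idx.count(colon) == 1]

print(len(all_indices))  # 60
print(len(kept))         # 26
```

So for this shape a bit more than half of the generated tuples are discarded, which is part of why this approach is unlikely to be the fastest one.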
Hopefully that makes some sense, anyway...
Edit: I'm assuming that the exact order doesn't matter... If you need the exact ordering you list in your question, you'll need to modify this...
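If the ordering from the question is needed, one possible modification (a sketch, not necessarily the fastest way) is to loop over the axes explicitly and put the full slice on one axis at a time. iter1d_ordered is a made-up name; it yields the index tuple along with the subarray so the order is easy to check.

```python
import itertools
import numpy as np


def iter1d_ordered(a):
    """Yield (index, subarray) pairs, axis by axis, in the question's order."""
    colon = slice(None)
    for axis in range(a.ndim):
        # Full slice on the current axis, every integer index on the others.
        choices = [[colon] if i == axis else range(n)
                   for i, n in enumerate(a.shape)]
        for index in itertools.product(*choices):
            yield index, a[index]


a = np.arange(24).reshape(2, 3, 4)
for index, sub in iter1d_ordered(a):
    print(index, sub)
```

This also avoids the filtering step, since only index tuples with exactly one full slice are ever generated.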
Your friends are the slice() objects, numpy's ndarray.__getitem__() method, and possibly itertools.chain.from_iterable.