开发者

Determine Index of Highest Value in Python's NumPy

I want to generate an array with the index of the highest max value of each row.

a = np.array([ [1,2,3], [6,5,4], [0,1,0] ])
maxIndexArray = getMaxIndexOnEachRow(a)
print maxIndexArray 

[[2], [0], [1]]

There'开发者_如何学Pythons a np.argmax function but it doesn't appear to do what I want...


The argmax() function does do what you want:

print a.argmax(axis=1)
array([2, 0, 1])
0

上一篇:

下一篇:

精彩评论

暂无评论...
验证码 换一张
取 消

最新问答

问答排行榜