How can I efficiently find the N largest values in a 2D numpy array?

I have a 2D numpy array 'a':

import numpy as np
a = np.array(range(0,25)).reshape(5,5)
---
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]]

I want to find the N largest values in each row and replace them with 100. My current approach is slow:

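N = 2
idx = a.argsort()  # fully sorts every row (ascending) just to get indices
for i in range(a.shape[0]):
    # the last N sorted indices point at the N largest values in row i
    a[i, idx[i, -N:]] = 100
print(a)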
---
[[  0   1   2 100 100]
 [  5   6   7 100 100]
 [ 10  11  12 100 100]
 [ 15  16  17 100 100]
 [ 20  21  22 100 100]]

My actual matrix has shape 6000 x 6000. Is there a better way to do this, something like apply?

Comments
  • 浮云、暖

    You can use argpartition here:

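    N = 2
    # column indices of the N largest values in each row (in no particular order)
    ix = a.argpartition(-N)[:, -N:]
    # pair each row index with its N column indices and assign in one step
    a[np.arange(a.shape[0])[:, None], ix] = 100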
    
    print(a)
    [[  0   1   2 100 100]
     [  5   6   7 100 100]
     [ 10  11  12 100 100]
     [ 15  16  17 100 100]
     [ 20  21  22 100 100]]
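
For reference, the argpartition idea above can be wrapped in a small helper. This is only a sketch: the function name `replace_top_n_per_row` and its parameters are made up for illustration, it assumes 1 <= n <= number of columns, and it uses `np.put_along_axis` (available in NumPy 1.15 and later) instead of the explicit row-index array. Unlike the snippets above, it works on a copy rather than modifying `a` in place.

```python
import numpy as np


def replace_top_n_per_row(arr, n, value):
    """Return a copy of `arr` with the n largest entries of each row set to `value`.

    Hypothetical helper for illustration; assumes 1 <= n <= arr.shape[1].
    """
    out = arr.copy()
    # argpartition places the indices of the n largest entries of each row
    # in the last n columns, without fully sorting the row.
    top_idx = np.argpartition(out, -n, axis=1)[:, -n:]
    # Write `value` at those column positions in every row.
    np.put_along_axis(out, top_idx, value, axis=1)
    return out


a = np.arange(25).reshape(5, 5)
result = replace_top_n_per_row(a, 2, 100)
print(result)
```

The gain over `argsort` comes from partitioning each row around the N-th largest element rather than sorting it completely. If a row contains ties at the cutoff, which of the tied positions gets replaced is not guaranteed, so the result may differ from the `argsort` loop in position (though not in how many entries are replaced).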