Testing whether a NumPy array contains a given row
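Is there a Pythonic and efficient way to check whether a NumPy array contains at least one instance of a given row? By "efficient" I mean that it terminates upon finding the first matching row, rather than iterating over the whole array even after a result has been found.
With plain Python lists, this can be done with the `in` operator: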
```python
>>> a = [[1,2],[10,20],[100,200]]
>>> [1,2] in a
True
>>> [1,20] in a
False
```
But NumPy arrays give different, and rather odd-looking, results:
```python
>>> a = np.array([[1,2],[10,20],[100,200]])
>>> np.array([1,2]) in a
True
>>> np.array([1,20]) in a
True
>>> np.array([1,42]) in a
True
>>> np.array([42,1]) in a
False
```
You can use `.tolist()`:
```python
>>> a = np.array([[1,2],[10,20],[100,200]])
>>> [1,2] in a.tolist()
True
>>> [1,20] in a.tolist()
False
>>> [1,42] in a.tolist()
False
>>> [42,1] in a.tolist()
False
```
Or use a view:
```python
>>> any((a[:]==[1,2]).all(1))
True
>>> any((a[:]==[1,20]).all(1))
False
```
Or use a generator over the NumPy array (potentially very slow):
```python
any(([1,2] == x).all() for x in a)     # stops on first occurrence
```
Or use NumPy's logic functions:
```python
any(np.equal(a,[1,2]).all(1))
```
If you time these:
```python
import numpy as np
import time

n=300000
a=np.arange(n*3).reshape(n,3)
b=a.tolist()

# First element of an early, a middle and a late row (as plain Python ints)
t1,t2,t3=(int(a[i][0]) for i in (n//100, n//2, -10))

tests=[ ('early hit',[t1, t1+1, t1+2]),
        ('middle hit',[t2,t2+1,t2+2]),
        ('late hit', [t3,t3+1,t3+2]),
        ('miss',[0,2,0])]

fmt='\t{:20}{:.5f} seconds and is {}'

for test, tgt in tests:
    print('\n{}: {} in {:,} elements:'.format(test,tgt,n))

    # Vectorized comparison on a view of the array
    name='view'
    t1=time.time()
    result=(a[...]==tgt).all(1).any()
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    # Plain Python list membership (stops at the first match)
    name='python list'
    t1=time.time()
    result = tgt in b
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    # Generator over the rows of the NumPy array (stops at the first match)
    name='gen over numpy'
    t1=time.time()
    result=any((tgt == x).all() for x in a)
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    # np.equal; the result is now stored so the printed value is this test's own
    name='logic equal'
    t1=time.time()
    result=np.equal(a,tgt).all(1).any()
    t2=time.time()
    print(fmt.format(name,t2-t1,result))
```
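You can see that, hit or miss, the NumPy routines take the same time to search the array. For an early hit, the Python list `in` test is faster (roughly 0.003 s versus 0.010 s in the first run below), whereas the generator over the NumPy array becomes very slow once it has to walk a large part of the array.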
Here are the results for a 300,000 x 3 element array:
```
early hit: [9000, 9001, 9002] in 300,000 elements:
    view                0.01002 seconds and is True
    python list         0.00305 seconds and is True
    gen over numpy      0.06470 seconds and is True
    logic equal         0.00909 seconds and is True

middle hit: [450000, 450001, 450002] in 300,000 elements:
    view                0.00915 seconds and is True
    python list         0.15458 seconds and is True
    gen over numpy      3.24386 seconds and is True
    logic equal         0.00937 seconds and is True

late hit: [899970, 899971, 899972] in 300,000 elements:
    view                0.00936 seconds and is True
    python list         0.30604 seconds and is True
    gen over numpy      6.47660 seconds and is True
    logic equal         0.00965 seconds and is True

miss: [0, 2, 0] in 300,000 elements:
    view                0.00936 seconds and is False
    python list         0.01287 seconds and is False
    gen over numpy      6.49190 seconds and is False
    logic equal         0.00965 seconds and is False
```
And for a 3,000,000 x 3 array:
```
early hit: [90000, 90001, 90002] in 3,000,000 elements:
    view                0.10128 seconds and is True
    python list         0.02982 seconds and is True
    gen over numpy      0.66057 seconds and is True
    logic equal         0.09128 seconds and is True

middle hit: [4500000, 4500001, 4500002] in 3,000,000 elements:
    view                0.09331 seconds and is True
    python list         1.48180 seconds and is True
    gen over numpy      32.69874 seconds and is True
    logic equal         0.09438 seconds and is True

late hit: [8999970, 8999971, 8999972] in 3,000,000 elements:
    view                0.09868 seconds and is True
    python list         3.01236 seconds and is True
    gen over numpy      65.15087 seconds and is True
    logic equal         0.09591 seconds and is True

miss: [0, 2, 0] in 3,000,000 elements:
    view                0.09588 seconds and is False
    python list         0.12904 seconds and is False
    gen over numpy      64.46789 seconds and is False
    logic equal         0.09671 seconds and is False
```
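This seems to indicate that the two vectorized approaches (the view comparison and `np.equal`) run at essentially the same speed wherever the match is, while the Python list only wins on an early hit.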
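At the time of writing, NumPy's `__contains__` is effectively `(a == b).any()`, which is why a row is reported as present as soon as any one of its elements matches in the same column.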
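The odd results from `in` are consistent with NumPy evaluating an element-wise comparison after broadcasting and then asking whether any single element matched. The sketch below, using the small array from the question, contrasts that with an explicit row test; the variable names are only illustrative.

```python
import numpy as np

a = np.array([[1, 2], [10, 20], [100, 200]])
row = np.array([1, 42])

# Broadcasting compares `row` against every row of `a`, element by element.
elementwise = (a == row)

# What `in` effectively reports: True if ANY single element matches.
print(row in a)                       # True
print(elementwise.any())              # True

# Row membership: ALL elements of some row must match.
print(elementwise.all(axis=1).any())  # False
```

Here `[1, 42]` is not a row of `a`, yet `in` returns `True` because the `1` matches the first element of the first row.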
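Edit: To be clear, this is not necessarily the expected result when broadcasting is involved. One could also argue that it should behave like …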
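Now, you want NumPy to stop when it finds the first match. AFAIK this does not currently exist. It is difficult because NumPy is based mostly on ufuncs, which do the same thing over the whole array.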
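NumPy does optimize this kind of reduction, but in practice only when the array being reduced is already a boolean array.
Otherwise it would need a dedicated function, which does not exist.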
You can of course write it in Python, or, since you probably know your data type, it is very simple to write it yourself in Cython/C.
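One possible pure-Python middle ground, offered as a sketch rather than a true first-match search, is to compare the array in blocks of rows and return as soon as a block contains a match. The function name `contains_row` and the `chunk_rows` default are hypothetical choices, not part of NumPy; the early exit happens at block granularity only.

```python
import numpy as np

def contains_row(haystack, needle, chunk_rows=4096):
    """Return True if `needle` equals some row of the 2-D array `haystack`.

    Rows are compared one block at a time, so an early hit avoids
    scanning the rest of the array. `chunk_rows` is an arbitrary default.
    """
    haystack = np.asarray(haystack)
    needle = np.asarray(needle)
    if haystack.ndim != 2 or needle.shape != (haystack.shape[1],):
        raise ValueError("expected a 2-D haystack and a 1-D needle of matching width")
    for start in range(0, haystack.shape[0], chunk_rows):
        block = haystack[start:start + chunk_rows]
        # Vectorized row test on this block only.
        if (block == needle).all(axis=1).any():
            return True
    return False

a = np.arange(300000 * 3).reshape(300000, 3)
print(contains_row(a, [9000, 9001, 9002]))  # early hit -> True
print(contains_row(a, [0, 2, 0]))           # miss -> False
```

A smaller `chunk_rows` exits sooner on early hits but spends more time in the Python loop on misses, so the best value likely depends on the data.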
That said, for this kind of problem a sorting-based approach is often much better. It is a little tedious:
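```python
import numpy as np

a = np.array([[1, 2], [10, 20], [100, 200]])
b = np.array([[10, 20], [1, 20]])  # rows to look up

# Unfortunately you need to use structured arrays, so each row is one element.
# copy() keeps the in-place sort below from reordering `a` itself.
row_dtype = [('', a.dtype)] * a.shape[-1]
sorted_rows = np.ascontiguousarray(a).view(row_dtype).ravel().copy()

# Actually, at this point you can also use np.in1d; if you already have
# many b, then that is even better.
sorted_rows.sort()

b_comp = np.ascontiguousarray(b, dtype=a.dtype).view(sorted_rows.dtype).ravel()
ind = sorted_rows.searchsorted(b_comp)

# searchsorted can return len(sorted_rows); clip so the lookup stays in range.
ind = np.minimum(ind, len(sorted_rows) - 1)
result = sorted_rows[ind] == b_comp
```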
This also works when `b` is an array of rows.
I think

```python
np.equal([1,2], a).all(axis=1)   # also, ([1,2]==a).all(axis=1)
# array([ True, False, False])
```

will list the rows that match. As Jamie pointed out, to know whether at least one such row exists, use `any`:

```python
np.equal([1,2], a).all(axis=1).any()
# True
```
Aside: I suspect …
If you really want to stop at the first occurrence, you can write a loop like:
```python
import numpy as np

needle = np.array([10, 20])
haystack = np.array([[1,2],[10,20],[100,200]])
found = False
for row in haystack:
    if np.all(row == needle):
        found = True
        break
print("Found:", found)
```
However, I strongly suspect that it will be much slower than the other suggestions, which use NumPy routines to process the whole array.
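A rough way to check that suspicion on your own machine is to time both variants with `timeit`. This is only a sketch: the array size, the choice of a last-row match, and the function names `loop_search` and `vectorized_search` are assumptions made for illustration, and the numbers will vary by system.

```python
import timeit
import numpy as np

haystack = np.arange(5000 * 2).reshape(5000, 2)
needle = haystack[-1].copy()  # worst case for the loop: match in the last row

def loop_search(haystack, needle):
    # Explicit Python loop that stops at the first matching row.
    for row in haystack:
        if np.all(row == needle):
            return True
    return False

def vectorized_search(haystack, needle):
    # Compares every row at once; no early exit.
    return bool((haystack == needle).all(axis=1).any())

for fn in (loop_search, vectorized_search):
    seconds = timeit.timeit(lambda: fn(haystack, needle), number=5)
    print(f"{fn.__name__}: {seconds / 5:.6f} s per call")
```

Moving the matching row toward the start of `haystack` shows where, if anywhere, the early exit of the loop starts to pay off.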