关于python:测试Numpy数组是否包含给定行

testing whether a Numpy array contains a given row

是否有Pythonic和有效的方法来检查Numpy数组是否包含给定行的至少一个实例?"有效"是指它在找到第一个匹配行时终止,而不是遍历整个数组,即使已经找到了结果。

使用Python数组,这可以用if row in array:非常干净地完成,但是这不像我对Numpy数组所期望的那样工作,如下所示。

使用Python数组:

1
2
3
4
5
>>> a = [[1,2],[10,20],[100,200]]
>>> [1,2] in a
True
>>> [1,20] in a
False

但是Numpy数组给出了不同的,而且看起来很奇怪的结果。 (ndarray__contains__方法似乎没有记录。)

1
2
3
4
5
6
7
8
9
>>> a = np.array([[1,2],[10,20],[100,200]])
>>> np.array([1,2]) in a
True
>>> np.array([1,20]) in a
True
>>> np.array([1,42]) in a
True
>>> np.array([42,1]) in a
False


你可以使用.tolist()

1
2
3
4
5
6
7
8
9
10
11
>>> a = np.array([[1,2],[10,20],[100,200]])
>>> [1,2] in a.tolist()
True
>>> [1,20] in a.tolist()
False
>>> [1,20] in a.tolist()
False
>>> [1,42] in a.tolist()
False
>>> [42,1] in a.tolist()
False

或使用视图:

1
2
3
4
>>> any((a[:]==[1,2]).all(1))
True
>>> any((a[:]==[1,20]).all(1))
False

或者生成numpy列表(可能非常慢):

1
any(([1,2] == x).all() for x in a)     # stops on first occurrence

或者使用numpy逻辑函数:

1
any(np.equal(a,[1,2]).all(1))

如果你计算时间:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import numpy as np
import time

n=300000
a=np.arange(n*3).reshape(n,3)
b=a.tolist()

t1,t2,t3=a[n//100][0],a[n//2][0],a[-10][0]

tests=[ ('early hit',[t1, t1+1, t1+2]),
        ('middle hit',[t2,t2+1,t2+2]),
        ('late hit', [t3,t3+1,t3+2]),
        ('miss',[0,2,0])]

fmt='\t{:20}{:.5f} seconds and is {}'    

for test, tgt in tests:
    print('
{}: {} in {:,} elements:'
.format(test,tgt,n))

    name='view'
    t1=time.time()
    result=(a[...]==tgt).all(1).any()
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='python list'
    t1=time.time()
    result = True if tgt in b else False
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='gen over numpy'
    t1=time.time()
    result=any((tgt == x).all() for x in a)
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='logic equal'
    t1=time.time()
    np.equal(a,tgt).all(1).any()
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

你可以看到命中或错过,numpy例程与搜索数组的速度相同。对于早期命中,Python in运算符可能要快得多,如果必须一直遍历数组,则生成器只是坏消息。

以下是300,000 x 3元素数组的结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
early hit: [9000, 9001, 9002] in 300,000 elements:
    view                0.01002 seconds and is True
    python list         0.00305 seconds and is True
    gen over numpy      0.06470 seconds and is True
    logic equal         0.00909 seconds and is True

middle hit: [450000, 450001, 450002] in 300,000 elements:
    view                0.00915 seconds and is True
    python list         0.15458 seconds and is True
    gen over numpy      3.24386 seconds and is True
    logic equal         0.00937 seconds and is True

late hit: [899970, 899971, 899972] in 300,000 elements:
    view                0.00936 seconds and is True
    python list         0.30604 seconds and is True
    gen over numpy      6.47660 seconds and is True
    logic equal         0.00965 seconds and is True

miss: [0, 2, 0] in 300,000 elements:
    view                0.00936 seconds and is False
    python list         0.01287 seconds and is False
    gen over numpy      6.49190 seconds and is False
    logic equal         0.00965 seconds and is False

对于3,000,000 x 3阵列:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
early hit: [90000, 90001, 90002] in 3,000,000 elements:
    view                0.10128 seconds and is True
    python list         0.02982 seconds and is True
    gen over numpy      0.66057 seconds and is True
    logic equal         0.09128 seconds and is True

middle hit: [4500000, 4500001, 4500002] in 3,000,000 elements:
    view                0.09331 seconds and is True
    python list         1.48180 seconds and is True
    gen over numpy      32.69874 seconds and is True
    logic equal         0.09438 seconds and is True

late hit: [8999970, 8999971, 8999972] in 3,000,000 elements:
    view                0.09868 seconds and is True
    python list         3.01236 seconds and is True
    gen over numpy      65.15087 seconds and is True
    logic equal         0.09591 seconds and is True

miss: [0, 2, 0] in 3,000,000 elements:
    view                0.09588 seconds and is False
    python list         0.12904 seconds and is False
    gen over numpy      64.46789 seconds and is False
    logic equal         0.09671 seconds and is False

这似乎表明np.equal是最快的纯粹numpy方式来做到这一点......


在编写本文时,Numpys __contains__(a == b).any(),如果b是标量(它有点毛茸茸,但我相信 - 只在1.7或更高版本中这样做 - 这可能是唯一正确的 - 这个将是正确的通用方法(a == b).all(np.arange(a.ndim - b.ndim, a.ndim)).any(),这对ab维度的所有组合都有意义...

编辑:为了清楚起见,这不一定是涉及广播时的预期结果。也有人可能认为它应该像np.in1d一样单独处理a中的项目。我不确定它应该有一个明确的方法。

现在你想要numpy在找到第一个匹配项时停止。此AFAIK目前不存在。这很难,因为numpy主要基于ufuncs,它在整个数组中做同样的事情。
Numpy确实优化了这种减少,但实际上只有在减少的数组已经是布尔数组(即np.ones(10, dtype=bool).any())时才有效。

否则它将需要一个不存在的__contains__的特殊功能。这可能看起来很奇怪,但你必须记住numpy支持许多数据类型,并且有更大的机制来选择正确的数据并选择正确的函数来处理它。换句话说,ufunc机器无法做到这一点,并且由于数据类型的原因,实现__contains__或其他特别实际上并不是那么简单。

您当然可以在python中编写它,或者因为您可能知道您的数据类型,所以在Cython / C中自己编写它非常简单。

那就是说。对于这些事情,使用基于排序的方法通常会好得多。这有点单调乏味,并且lexsort没有searchsorted这样的东西,但它有效(如果你愿意的话,你也可以滥用scipy.spatial.cKDTree)。这假设您只想沿最后一个轴进行比较:

1
2
3
4
5
6
7
8
9
10
11
12
# Unfortunatly you need to use structured arrays:
sorted = np.ascontiguousarray(a).view([('', a.dtype)] * a.shape[-1]).ravel()

# Actually at this point, you can also use np.in1d, if you already have many b
# then that is even better.

sorted.sort()

b_comp = np.ascontiguousarray(b).view(sorted.dtype)
ind = sorted.searchsorted(b_comp)

result = sorted[ind] == b_comp

这也适用于数组b,如果你保持排序的数组,如果你一次在b中为单个值(行)执行它,当a保持不变时(例如)否则我会在将其视为重新排列之后np.in1d。重要提示:您必须执行np.ascontiguousarray以确保安全。它通常什么都不做,但如果确实如此,那将是一个很大的潜在错误。


我认为

1
2
equal([1,2], a).all(axis=1)   # also,  ([1,2]==a).all(axis=1)
# array([ True, False, False], dtype=bool)

将列出匹配的行。正如Jamie所指出的,要知道是否存在至少一个这样的行,请使用any

1
2
equal([1,2], a).all(axis=1).any()
# True

旁白:我怀疑in(和__contains__)如上所述,但使用any而不是all


如果你真的想在第一次出现时停下来,你可以写一个循环,如:

1
2
3
4
5
6
7
8
9
10
import numpy as np

needle = np.array([10, 20])
haystack = np.array([[1,2],[10,20],[100,200]])
found = False
for row in haystack:
    if np.all(row == needle):
        found = True
        break
print("Found:", found)

但是,我强烈怀疑,它将比使用numpy例程为整个数组执行它的其他建议慢得多。