关于python:检查NumPy数组是否包含另一个数组

Checking if a NumPy array contains another array

本问题已经有最佳答案,请猛点这里访问。

在Python 2.7中使用NumPy,我想创建一个n-by-2数组y。然后,我要检查这个数组是否包含一个特定的1×2阵列z在其任何行。

到目前为止,这是我尝试过的,在这种情况下,n = 1:

1
2
3
4
5
x = np.array([1, 2]) # Create a 1-by-2 array
y = [x] # Create an n-by-2 array (n = 1), and assign the first row to x
z = np.array([1, 2]) # Create another 1-by-2 array
if z in y: # Check if y contains the row z
    print 'yes it is'

但是,这给了我以下错误:

1
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我究竟做错了什么?


您可以执行(y == z).all(1).any()

为了更详细一点,numpy将使用称为"广播"的东西自动在更高维度上进行逐元素比较。 因此,如果y是您的n-by-2数组,而z是您的1-by-2数组,则y == z会将y的每一行与z的每个元素进行比较。 然后,您可以使用all(axis=1)获取所有元素匹配的行,并使用any()找出是否匹配。

所以这里是实践中:

1
2
3
4
5
6
7
8
>>> y1 = np.array([[1, 2], [1, 3], [1, 2], [2, 2]])
>>> y2 = np.array([[100, 200], [100,300], [100, 200], [200, 200]])
>>> z = np.array([1, 2])
>>>
>>> (y1 == z).all(1).any()
True
>>> (y2 == z).all(1).any()
False

这比执行基于循环或基于生成器的方法要快得多,因为它可以矢量化操作。


您可以简单地使用any((z == x).all() for x in y)。 不过,我不知道它是否最快。