在Python类中支持等价(“相等”)的优雅方法

Elegant ways to support equivalence (“equality”) in Python classes

在编写自定义类时,通常重要的是允许通过==!=操作符实现等价。在python中,这可以通过分别实现__eq____ne__特殊方法来实现。我发现最简单的方法是以下方法:

1
2
3
4
5
6
7
8
9
10
11
12
class Foo:
    def __init__(self, item):
        self.item = item

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        else:
            return False

    def __ne__(self, other):
        return not self.__eq__(other)

你知道更优雅的方法吗?你知道用上述方法比较__dict__s有什么特别的缺点吗?

注:有点澄清——当__eq____ne__未定义时,您会发现这种行为:

1
2
3
4
5
6
>>> a = Foo(1)
>>> b = Foo(1)
>>> a is b
False
>>> a == b
False

也就是说,a == bFalse进行评估,因为它真正运行a is b,一个身份测试(即"ab是同一对象吗?").

当定义__eq____ne__时,您会发现这种行为(这是我们所追求的行为):

1
2
3
4
5
6
>>> a = Foo(1)
>>> b = Foo(1)
>>> a is b
False
>>> a == b
True


考虑这个简单的问题:

1
2
3
4
5
6
7
8
9
10
class Number:

    def __init__(self, number):
        self.number = number


n1 = Number(1)
n2 = Number(1)

n1 == n2 # False -- oops

因此,默认情况下,python使用对象标识符进行比较操作:

1
2
id(n1) # 140400634555856
id(n2) # 140400634555920

覆盖__eq__函数似乎可以解决以下问题:

1
2
3
4
5
6
7
8
9
def __eq__(self, other):
   """Overrides the default implementation"""
    if isinstance(other, Number):
        return self.number == other.number
    return False


n1 == n2 # True
n1 != n2 # True in Python 2 -- oops, False in Python 3

在python2中,始终记住重写__ne__函数,以及文档状态:

There are no implied relationships among the comparison operators. The
truth of x==y does not imply that x!=y is false. Accordingly, when
defining __eq__(), one should also define __ne__() so that the
operators will behave as expected.

1
2
3
4
5
6
7
def __ne__(self, other):
   """Overrides the default implementation (unnecessary in Python 3)"""
    return not self.__eq__(other)


n1 == n2 # True
n1 != n2 # False

在Python3中,这不再是必需的,因为文档说明:

By default, __ne__() delegates to __eq__() and inverts the result
unless it is NotImplemented. There are no other implied
relationships among the comparison operators, for example, the truth
of (x does not imply x<=y.

但这并不能解决我们所有的问题。让我们添加一个子类:

1
2
3
4
5
6
7
8
9
10
class SubNumber(Number):
    pass


n3 = SubNumber(1)

n1 == n3 # False for classic-style classes -- oops, True for new-style classes
n3 == n1 # True
n1 != n3 # True for classic-style classes -- oops, False for new-style classes
n3 != n1 # False

注意:python2有两种类:

  • 不继承于object的古典风格(或旧风格)类,称为class A:class A():class A(B):,其中B是古典风格类;

  • 继承自object的新型类,声明为class A(object)class A(B):,其中B是一种新型类。python3只有声明为class A:class A(object):class A(B):的新型类。

对于经典样式类,比较操作始终调用第一个操作数的方法,而对于新样式类,无论操作数的顺序如何,它始终调用子类操作数的方法。

因此,如果Number是一个经典的风格类:

  • n1 == n3呼叫n1.__eq__
  • n3 == n1呼叫n3.__eq__
  • n1 != n3呼叫n1.__ne__
  • n3 != n1呼叫n3.__ne__

如果Number是一个新的阶级:

  • n1 == n3n3 == n1都叫n3.__eq__
  • n1 != n3n3 != n1都叫n3.__ne__

为了解决python 2经典样式类的==!=运算符的非交换性问题,当不支持操作数类型时,__eq____ne__方法应返回NotImplemented值。文件将NotImplemented值定义为:

Numeric methods and rich comparison methods may return this value if
they do not implement the operation for the operands provided. (The
interpreter will then try the reflected operation, or some other
fallback, depending on the operator.) Its truth value is true.

在这种情况下,运算符将比较操作委托给另一个操作数的反射方法。文件将反射方法定义为:

There are no swapped-argument versions of these methods (to be used
when the left argument does not support the operation but the right
argument does); rather, __lt__() and __gt__() are each other’s
reflection, __le__() and __ge__() are each other’s reflection, and
__eq__() and __ne__() are their own reflection.

结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
def __eq__(self, other):
   """Overrides the default implementation"""
    if isinstance(other, Number):
        return self.number == other.number
    return NotImplemented

def __ne__(self, other):
   """Overrides the default implementation (unnecessary in Python 3)"""
    x = self.__eq__(other)
    if x is not NotImplemented:
        return not x
    return NotImplemented

如果操作数是不相关类型(无继承)时需要==!=运算符的交换性,则返回NotImplemented值而不是False值对于新类型的类也是正确的做法。

我们到了吗?不完全是这样。我们有多少个唯一的号码?

1
len(set([n1, n2, n3])) # 3 -- oops

集合使用对象的散列,默认情况下,python返回对象标识符的散列。让我们尝试重写它:

1
2
3
4
5
def __hash__(self):
   """Overrides the default implementation"""
    return hash(tuple(sorted(self.__dict__.items())))

len(set([n1, n2, n3])) # 1

最终结果如下(我在末尾添加了一些断言进行验证):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class Number:

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
       """Overrides the default implementation"""
        if isinstance(other, Number):
            return self.number == other.number
        return NotImplemented

    def __ne__(self, other):
       """Overrides the default implementation (unnecessary in Python 3)"""
        x = self.__eq__(other)
        if x is not NotImplemented:
            return not x
        return NotImplemented

    def __hash__(self):
       """Overrides the default implementation"""
        return hash(tuple(sorted(self.__dict__.items())))


class SubNumber(Number):
    pass


n1 = Number(1)
n2 = Number(1)
n3 = SubNumber(1)
n4 = SubNumber(4)

assert n1 == n2
assert n2 == n1
assert not n1 != n2
assert not n2 != n1

assert n1 == n3
assert n3 == n1
assert not n1 != n3
assert not n3 != n1

assert not n1 == n4
assert not n4 == n1
assert n1 != n4
assert n4 != n1

assert len(set([n1, n2, n3, ])) == 1
assert len(set([n1, n2, n3, n4])) == 2


你需要小心继承:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
>>> class Foo:
    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        else:
            return False

>>> class Bar(Foo):pass

>>> b = Bar()
>>> f = Foo()
>>> f == b
True
>>> b == f
False

更严格地检查类型,如下所示:

1
2
3
4
def __eq__(self, other):
    if type(other) is type(self):
        return self.__dict__ == other.__dict__
    return False

除此之外,你的方法会很好地工作,这就是有特殊方法的目的。


你描述的方式就是我一直这样做的。因为它是完全通用的,所以您可以将该功能分解为一个mixin类,并在需要该功能的类中继承它。

1
2
3
4
5
6
7
8
9
10
11
12
13
class CommonEqualityMixin(object):

    def __eq__(self, other):
        return (isinstance(other, self.__class__)
            and self.__dict__ == other.__dict__)

    def __ne__(self, other):
        return not self.__eq__(other)

class Foo(CommonEqualityMixin):

    def __init__(self, item):
        self.item = item


这不是一个直接的答案,但似乎有足够的相关性,因为有时它会节省一些冗长乏味的内容。直接从文档中剪切…

functools.total_排序(cls)

给定一个定义一个或多个丰富的比较排序方法的类,这个类修饰器提供其余的方法。这简化了指定所有可能的富比较操作所涉及的工作:

类必须定义lt()、le()、gt()或ge()中的一个。此外,类应该提供一个eq()方法。

2.7版新增功能

1
2
3
4
5
6
7
8
@total_ordering
class Student:
    def __eq__(self, other):
        return ((self.lastname.lower(), self.firstname.lower()) ==
                (other.lastname.lower(), other.firstname.lower()))
    def __lt__(self, other):
        return ((self.lastname.lower(), self.firstname.lower()) <
                (other.lastname.lower(), other.firstname.lower()))


您不必同时覆盖__eq____ne__,您只能覆盖__cmp__,但这将对结果=,!==,<,>等等。

is测试对象身份。这意味着当a和b都持有同一对象的引用时,isb将是True。在Python中,您总是持有对变量中对象的引用,而不是实际对象,因此本质上,对于a is b为true,其中的对象应该位于相同的内存位置。最重要的是,你为什么要凌驾于此?

编辑:我不知道从python 3中删除了__cmp__,所以要避免它。


从这个答案:https://stackoverflow.com/a/30676267/541136我已经证明了,虽然用__eq__来定义__ne__是正确的,而不是

1
2
def __ne__(self, other):
    return not self.__eq__(other)

你应该使用:

1
2
def __ne__(self, other):
    return not self == other


我认为你要找的两个术语是平等和身份。例如:

1
2
3
4
5
6
>>> a = [1,2,3]
>>> b = [1,2,3]
>>> a == b
True       <-- a and b have values which are equal
>>> a is b
False      <-- a and b are not the same list object


"is"测试将使用内置的"id()"函数测试标识,该函数实质上返回对象的内存地址,因此不可重载。

但是,在测试类的相等性的情况下,您可能希望对测试更严格一点,并且只比较类中的数据属性:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import types

class ComparesNicely(object):

    def __eq__(self, other):
        for key, value in self.__dict__.iteritems():
            if (isinstance(value, types.FunctionType) or
                    key.startswith("__")):
                continue

            if key not in other.__dict__:
                return False

            if other.__dict__[key] != value:
                return False

         return True

此代码将只比较类中的非函数数据成员,并跳过通常需要的私有数据。对于普通的旧python对象,我有一个基类,它实现了uuinit_uuuuuuu、uuu str_uuuuuu、uuu repr_uuuu和uuu eq_uuuuuu,因此我的popo对象不承担所有额外(在大多数情况下是相同的)逻辑的负担。


我不使用子类化/混合,而是使用通用类修饰器

1
2
3
4
5
6
7
8
9
10
11
12
def comparable(cls):
   """ Class decorator providing generic comparison functionality"""

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    cls.__eq__ = __eq__
    cls.__ne__ = __ne__
    return cls

用途:

1
2
3
4
5
6
7
8
@comparable
class Number(object):
    def __init__(self, x):
        self.x = x

a = Number(1)
b = Number(1)
assert a == b