关于python:重载__eq__以返回自定义对象

Overloading __eq__ to return custom objects

我正在用 Python 编写一个 DSL，想通过重载运算符来方便地书写 DSL 表达式。例如，我希望写 Var("a") + Var("b") 就能得到与 Add(Var("a"), Var("b")) 等价的表示。为此我重载了 __add__ 方法，对加法来说这样做是可行的。

我还尝试重载 __eq__ 方法来实现类似的效果：我希望写 Var("a") == Var("b") 就能得到与 Eq(Var("a"), Var("b")) 等价的表示。让 __eq__ 返回一个 Eq 实例后，我达到了这个目的。但是重载 __eq__ 显然会干扰 Python 的标准行为：例如 Var("b") in [Var("a")] 会返回 True，因为返回的 Eq 实例在布尔上下文中默认总被视为真值。

有没有办法既能写 Var("a") == Var("b") 得到 Eq(Var("a"), Var("b"))，又仍然能写 if Var("a") == Var("b"): ... 这样的条件判断、把表达式放进内置容器中使用，等等？

编辑

我尝试为 Eq 类实现 __bool__ 方法，看起来可以正常工作（见下面的代码）。我是否遗漏了什么，还是说这是一个可行的解决方案？

class Expr:
    """Base class for DSL expression nodes.

    Python operators applied to an ``Expr`` build new expression nodes
    instead of computing a value:

    * ``x + y``  -> ``Add(x, y)``
    * ``x == y`` -> ``Eq(x, y)``
    * ``x != y`` -> ``Neq(x, y)``

    Concrete subclasses provide ``__str__`` (used by ``__repr__``),
    ``equals`` (structural equality) and ``__hash__``.
    """

    # NOTE: because this class defines __eq__ without __hash__, Python sets
    # Expr.__hash__ to None, so a bare Expr is unhashable. Concrete
    # subclasses define their own __hash__ to become usable in sets/dicts.

    def __add__(self, other):
        """Return the expression node ``Add(self, other)``."""
        return Add(self, other)

    def __eq__(self, other):
        """Return the expression node ``Eq(self, other)`` rather than a bool.

        When the result is truth-tested (``if``, ``in``, dict/set lookup),
        its ``__bool__`` decides the outcome.
        """
        return Eq(self, other)

    def __repr__(self):
        # Delegate to __str__ so containers print nodes readably. A subclass
        # that does not override __str__ would recurse here (object.__str__
        # calls back into __repr__), so every concrete node defines __str__.
        return str(self)

    def __ne__(self, other):
        """Return the expression node ``Neq(self, other)`` rather than a bool."""
        return Neq(self, other)

class Var(Expr):
    """Leaf expression node naming a DSL variable, e.g. ``Var("a")``."""

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return "".join(("Var(", str(self.name), ")"))

    def equals(self, other):
        """Structural equality: same node class and equal names."""
        return type(self) is type(other) and self.name == other.name

    def __hash__(self):
        # Depends only on the name, so structurally equal Vars hash alike.
        name_hash = hash(self.name)
        return name_hash * 23 + 17

class Add(Expr):
    """Expression node for ``left + right`` (built by ``Expr.__add__``)."""

    def __init__(self, left, right):
        self.left, self.right = left, right

    def __str__(self):
        return "".join(("Add(", str(self.left), ",", str(self.right), ")"))

    def equals(self, other):
        """Structural equality; operand order is ignored since + is commutative."""
        if type(self) is not type(other):
            return False
        in_order = self.left.equals(other.left) and self.right.equals(other.right)
        return in_order or (
            self.left.equals(other.right) and self.right.equals(other.left)
        )

    def __hash__(self):
        # Operand hashes are summed, so the hash is order-independent and
        # therefore consistent with the commutative equals() above.
        return 17 + 23 * hash("+") + 529 * (hash(self.left) + hash(self.right))

class Eq(Expr):
    """Expression node for ``left == right`` (built by ``Expr.__eq__``).

    Besides representing the comparison in the DSL, its truth value is the
    structural equality of the two operands, so ``if x == y:``, ``in`` and
    dict/set lookups keep working.
    """

    def __init__(self, left, right):
        self.left, self.right = left, right

    def __str__(self):
        return "".join(("Eq(", str(self.left), ",", str(self.right), ")"))

    def equals(self, other):
        """Structural equality; operand order is ignored since == is symmetric."""
        if type(self) is not type(other):
            return False
        in_order = self.left.equals(other.left) and self.right.equals(other.right)
        return in_order or (
            self.left.equals(other.right) and self.right.equals(other.left)
        )

    def __bool__(self):
        # "left == right" is true exactly when the operands match structurally.
        return self.left.equals(self.right)

    def __hash__(self):
        # Operand hashes are summed, so the hash is order-independent and
        # therefore consistent with the symmetric equals() above.
        return 17 + 23 * hash("==") + 529 * (hash(self.left) + hash(self.right))

class Neq(Expr):
    """Expression node for ``left != right`` (built by ``Expr.__ne__``).

    Its truth value is the negation of the operands' structural equality,
    so ``if x != y:`` behaves as expected.
    """

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def __str__(self):
        return "Neq(" + str(self.left) + "," + str(self.right) + ")"

    def equals(self, other):
        """Return True if *other* is a structurally identical ``Neq`` node.

        ``!=`` is symmetric, so the operands may appear in either order.
        """
        # This is equality *between two Neq nodes* and must not be negated.
        # The "not equal" meaning lives only in __bool__. Negating it here
        # made Neq(a, b) unequal to another Neq(a, b) and "equal" to
        # Neq(a, c), which broke `in`, dict and set lookups even though
        # __hash__ is consistent.
        if type(self) is type(other):
            return ( ( self.left.equals(other.left) and
                       self.right.equals(other.right) ) or
                     ( self.left.equals(other.right) and
                       self.right.equals(other.left) ) )
        else:
            return False

    def __bool__(self):
        # "left != right" is true exactly when the operands differ structurally.
        return not self.left.equals(self.right)

    def __hash__(self):
        # Symmetric in left/right, consistent with the symmetric equals().
        return (17 + 23 * hash("!=") +
                23 * 23 * hash(self.left) + 23 * 23 * hash(self.right))


# Demo / smoke test for the DSL operators defined above. Each trailing
# "# ... => ..." comment shows the output that line is expected to print
# (assumes Python 3.7+, where dicts preserve insertion order).
a = Var("a")
aa = Var("a")   # a distinct object that is structurally equal to a
b = Var("b")
c = Var("c")


# The operators build expression nodes instead of evaluating anything.
print("a + b","=>", a + b)   # a + b => Add(Var(a),Var(b))
print("a == b","=>", a == b) # a == b => Eq(Var(a),Var(b))
print("a != b","=>", a != b) # a != b => Neq(Var(a),Var(b))

# In a boolean context an Eq node is truthy iff its operands are
# structurally equal (Eq.__bool__).
print("a if a == b else b","=>", a if a == b else b)
# a if a == b else b => Var(b)
print("a if a == aa else b","=>", a if a == aa else b)
# a if a == aa else b => Var(a)


# List membership: `in` uses == for non-identical items, so it relies on
# Eq.__bool__ (and on Add.equals treating a+b and b+a as equal).
l = [a, a+b]
print("l","=>", l)               # l => [Var(a), Add(Var(a),Var(b))]
print("b in l","=>", b in l)     # b in l => False
print("a in l","=>", a in l)     # a in l => True
print("aa in l","=>", aa in l)   # aa in l => True
print("a+b in l","=>", a+b in l) # a+b in l => True
print("b+a in l","=>", b+a in l) # b+a in l => True
print("a+c in l","=>", a+c in l) # a+c in l => False


# `if` truth-tests the Eq / Neq node returned by == / !=.
if a == b:
    print("a == b is True")
else:
    print("a == b is False")        # a == b is False
if a == aa:
    print("a == aa is True")        # a == aa is True
else:
    print("a == aa is False")

if a != b:
    print("a != b is True")         # a != b is True
else:
    print("a != b is False")
if a != aa:
    print("a != aa is True")
else:
    print("a != aa is False")       # a != aa is False


# `and` / `or` / `not` also go through the nodes' __bool__.
if a == b or a == aa:
    print("a == b or a == aa is True")   # a == b or a == aa is True
else:
    print("a == b or a == aa is False")
if a == aa and a == b:
    print("a == aa and a == b is True")
else:
    print("a == aa and a == b is False") # a == aa and a == b is False
if not a == aa:
    print("not a == aa is True")
else:
    print("not a == aa is False")        # not a == aa is False
if not a == b:
    print("not a == b is True")          # not a == b is True
else:
    print("not a == b is False")


# Comparing with a non-Expr value: for `3 == a`, int's comparison returns
# NotImplemented, so Python falls back to the reflected Var.__eq__ and
# the resulting node is falsy because Var.equals rejects other types.
if a == 3:
    print("a == 3 is True")
else:
    print("a == 3 is False")             # a == 3 is False
if a != 3:
    print("a != 3 is True")              # a != 3 is True
else:
    print("a != 3 is False")
if 3 == a:
    print("3 == a is True")
else:
    print("3 == a is False")             # 3 == a is False
if 3 != a:
    print("3 != a is True")              # 3 != a is True
else:
    print("3 != a is False")


# A string equal to the variable's name is still not equal to the Var.
if a == 'a':
    print("a == 'a' is True")
else:
    print("a == 'a' is False")           # a == 'a' is False
if a != 'a':
    print("a != 'a' is True")            # a != 'a' is True
else:
    print("a != 'a' is False")
if 'a' == a:
    print("'a' == a is True")
else:
    print("'a' == a is False")           # 'a' == a is False
if 'a' != a:
    print("'a' != a is True")            # 'a' != a is True
else:
    print("'a' != a is False")


# Sets and dicts: the lookup hashes with Var.__hash__, then confirms a
# hash match with == (i.e. Eq.__bool__), so aa finds the entry stored
# under a.
s = {a}
print("s","=>", s)             # s => {Var(a)}
print("a in s","=>", a in s)   # a in s => True
print("b in s","=>", b in s)   # b in s => False
print("aa in s","=>", aa in s) # aa in s => True

d = {a: 1, b: 2}
print("d","=>", d)             # d => {Var(a): 1, Var(b): 2}
print("d[a]","=>", d[a])       # d[a] => 1
print("d[b]","=>", d[b])       # d[b] => 2
print("c in d","=>", c in d)   # c in d => False
print("aa in d","=>", aa in d) # aa in d => True
print("d[aa]","=>", d[aa])     # d[aa] => 1


做不到。你必须在两种行为中选择其一：调用 .__eq__() 方法时所处的上下文是无法（可靠地）检测到的。

如果两者都需要，那就得改用其他运算符或方法来表达 DSL 的行为。