对 Python 特殊方法的拓展应用

对 Python 特殊方法的拓展应用

疑问

曾经我在学习SQLAlchemy库时对这段代码百思不得其解:

with get_session() as session:
    # <class 'sqlalchemy.orm.query.Query'>
    query = (
        session.query(User)
               .filter(User.username == "asd")  # line 5
               .filter_by(username="asd")
               .join(Addreess)  # 使用ForeignKey
               .join(Addreess,Addreess.user_id==User.id)  # 使用显式声明
               .limit(10)
               .offset(0)
    )

在第5行中,使用了会话对象的filter方法来查询数据库中匹配条件的数据,而令我疑惑的地方正是这里。

在我看来,filter方法的参数User.username == "asd"应该是返回一个bool值,要么是True要么是False

那么问题来了,filter方法怎么可能通过一个bool值参数来完成查询操作呢?这不科学!

这个问题一直困扰了我一年多。

灵感

前不久在V2EX上看到了这一篇帖子:六行代码实现 Python 管道

在这篇帖子中,作者大胆地使用了Python的按位或运算符(|)实现了类似Shell的管道功能(即|左边的表达式的输出作为|右边的表达式的输入)。

作者编写的代码如下:

from functools import partial

class F(partial):
    def __ror__(self, other):
        if isinstance(other, tuple):
            return self(*other)
        return self(other)

# 使用样例
range(10) | F(filter, lambda x: x % 2) | F(sum)

首先要知道按位或运算符|的常规用法:对于两个值,只要相应位有一个为1时,结果位就为1。即:

0b01 | 0b00 == 0b01  # True
0b01 | 0b10 == 0b11  # True
0b01 | 0b11 == 0b11  # True

另外一个作用就是取两个集合的并集:

{1, 2, 3} | {2, 3, 4} == {1, 2, 3, 4}  # True

然而,Python允许我们自己编写特殊方法来改变该运算符的行为,摆脱运算符的“限制”,由此便可实现管道功能。

答疑

回到我们一开始提出的问题,换个角度思考:

我们可以假设User.username__eq__方法被重写了,现在User.username == "asd"返回的不再是TrueFalse,而是一个三元元组:(User.username, "==", "asd")

这样一来,filter方法接收到了这个三元元组参数,自然就可以进行数据库查询了。妙啊!

由于sqlalchemy的源码太过复杂,我没能验证自己的猜测,不过我猜自己的猜测是对的,嘿嘿

举一反三

简单地写了一段代码进行实践:

class KeyValue:  # 假设这就是sql数据库中的每条数据的字段

    def __init__(self, key, value=None):
        self.key = key
        self.value = value

    def __gt__(self, obj):  # 重写">"运算符方法
        return self, ">", obj

class Model:  # 假设这就是sql数据库中的每条数据

    name = KeyValue("name")
    age = KeyValue("age")

    def __init__(self, name, age):
        self.name = KeyValue("name", name)
        self.age = KeyValue("age", age)

class ModelManager:

    def __init__(self):
        self.__objs = []

    def add(self, obj):
        self.__objs.append(obj)

    def filter(self, expression):
        val_left, operator, val_right = expression
        if operator == ">":
            return [obj for obj in self.__objs if getattr(obj, val_left.key).value > val_right]
        else:
            raise ValueError("Unparseable expression: %s" % expression)

model_a = Model("A", 9)
model_b = Model("B", 18)
model_c = Model("C", 27)

model_manager = ModelManager()

for model in [model_a, model_b, model_c]:
    model_manager.add(model)

assert model_manager.filter(Model.age > 17) == [model_b, model_c, ]
assert model_manager.filter(Model.age > 18) == [model_c, ]

参考文献

  1. SQLAlchemy入门和进阶
  2. Python3 运算符