开发者

利用Python实现一个类似MybatisPlus的简易SQL注解

目录
  • 前言
  • 实现思路
  • 定义一个类
  • 然后开始手撸这个微型框架
  • 根据字符串获取到所定义的DTO类
  • 构建返回结果
  • 装饰器
    • 解析字符串,获得变量
    • SQL字符串拼接
  • 使用装饰器

    前言

    在实际开发中,根据业务拼接SQL所需要考虑的内容太多了。于是,有没有一种办法,可以像MyBATisPlus一样通过配置注解实现SQL注入呢?

    就像是:

    @mybatis.select("select * from user where id = #{id}")
    def get_user(id): ...
    

    那可就降低了好多工作量。

    P.S.:本文并不希望完全复现MyBatisPlus的所有功能,能够基本配置SQL注解就基本能够完成大部分工作了。

    实现思路

    那我们这么考虑:

    1. 首先,我们需要定义一个类,类中给一个或者多个装饰器;
    2. 我们先在类内定义一个字符串,这个字符串能够配置到指定的DTO类,用于存储结果;
    3. 我们针对装饰器中的SQL字符串进行解析,解析到其中的变量个数与名称;
    4. 我们针对被装饰的函数进行解析,与SQL变量进行匹配;
    5. 替换变量;
    6. 执行SQL

    听起来并不难。我们一步步来。

    定义一个类

    首先定义:

    # dto/student.py
    class Student:
        def __init__(self, name, age):
            self.name = name
            self.age = age
    

    为了简化操作,这个类就不放在任意位置了,直接放在dto文件夹下,后续导入这个类也就直接从dto文件夹中引入,就不考虑做这个包名定位的接口了。

    当然,为了更方便后续的操作,我们需要在dto文件夹中定义一个__init__.py文件,用于对外暴露这个类:

    # dto/__init__.py
    from dto.student import Student
    __all__ = ["Student"]
    

    最后呢,我们为了方便这个类的序列化,让他能够变成dict类型,加一些魔法函数:

    # dto/student.py
    class Student:
        def __init__(self, name, age):
            self.name = name
            self.age = age
        def __iter__(self):
            for key, value in self.__dict__.items():
                yield key, value
        def __getitem__(self, key):
            return getattr(self, key)
        def keys(self):
            return self.__dict__.keys()
    

    当然,一个项目里面肯定不止这一个返回结果,所以各位也可以这么操作:

    # dto/common.py
    class CommonResult:
        def __init__(self):www.devze.com ...
        def __iter__(self):
            for key, value in self.__dict__.items():
                yield key, value
        def __getitem__(self, key):
            return getattr(self, key)
        def keys(self):
            return self.__dict__.keys()
    # dto/student.py
    from dto.common import CommonResult
    class Student(CommonResult):
        def __init__(self, name, age):
            self.name = name
            self.age = age
    

    至于实际业务中还有很多复杂的联立等操作需要新的类,受限于篇幅,就不展开了。如果能够把本篇看懂的话,相信各位也没什么其他的困难了。

    然后开始手撸这个微型框架

    # db/common.py
    from pydantic import BaseModel, Field
    
    class DBManager(BaseModel):
      base_type: str = Field(..., description="数据库表名")
      link: str = Field(..., description="数据库连接地址")
      local_generator: Any = Field(..., description="实体类实例化解析生成器")
      def search(query_template): ...
    

    在这里呢,我们定义了一个DBManager作为父类,要求后面的子类必须有:

    • str类型的base_type,表示返回结果类的名称;
    • str类型的link,表示数据库连接地址;
    • Any类型的local_generator,表示实体类实例化解析生成器,- 任意返回值的query方法,用于执行SQL

    为什么一定要用BaseModel定义?直接定义self.xxx不好吗?

    因为这样会看起来代码量很大(逃)

    看着差不多。

    根据字符串获取到所定义的DTO类

    考虑到实际上我们所有的方法都需要特定到具体的位置,所以这个方法还是直接写到DBManager类中,这样子类就不需要再重写了。

    # db/common.py
    from pydantic import BaseModel, Field
    
    class DBManager(BaseModel):
        base_type: str = Field(..., description="数据库表名")
        link: str = Field(..., description="数据库连接地址")
        local_generator: Any = Field(..., description="实体类实例化解析生成器")
    
        def search(query_template): ...
    
        def import_class_from_package(self, package_name, class_name):
            # 根据包名获得`DTO`包
            _package = importlib.import_module(package_name)
            # 检测是不是有这么个类
            if class_name not in _package.__all__:
                raise ImportError(f"{class_name} not found in {package_name}")
            # 有就拿着
            cls = getattr(_package, class_name)
            # 返回这个类
            if cls is not None:
                return cls
            else:
                raise ImportError(f"{class_name} not found in {package_name}")
    

    这样子类就可以调用这个方法获得所需的类了。

    构建返回结果

    既然都已经能够动态导入类了,那我把返回结果导入到Student中,没问题吧?

    其中需要注意的是,我这边采用的数据库驱动是sqlalchemy,所以构造返回结果所需要的参数是sqlalchemyRow类型。

    同样的,为了减少子类重写的代码量,直接在父类给出来:

    # db/common.py
    from pydantic import BaseModel, Field
    from sqlalchemy.engine.row import Row
    
    class DBManager(BaseModel):
        base_type: str = Field(..., description="数据库表名")
        link: str = Field(..., description="数据库连接地址")
        local_generator: Any = Field(..., description="实体类实例化解析生成器")
    
        def search(query_template): ...
        # 为了方便看,省略掉细节
        def import_class_from_package(self, package_name, class_name): ...
    
        def build_obj(self, row: Row):
            return self.local_generator(**row._asdict()) if self.local_generator else None
    

    装饰器

    那么接下来就是重头戏了,怎么定义这个装饰器。

    我们先构建一个子类:

    # db/student.py
    class StudentDBManager(DBManager):
        base_type: ClassVar[str] = "Student"
        link: ClassVar[str] = 'SQLite:///school.db'
        local_generator: ClassVar[Any] = None
    
        """
        自定义PyMyBatis
        """
        def __init__(self):
            StudentDBManager.local_generator = self.import_class_from_package("dto", self.base_type)
    

    在这里,首先需要注意的是,需要用ClassVar修饰,将变量名定义为类内成员变量,否则无法使用self.xxx访问。

    其次,我们利用base_type指定返回值对应的DTO类、link指定数据库连接地址,local_generator指定实体类实例化解析生成器。

    在这个类实例化的过程中,我们还需要进一步构建local_generator,也就是动态执行from xxx import xxx

    然后定义一个装饰器:

    def query(query_template: str):
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper
        return decorator
    

    这可以算得上是比较基础的模板了。至于之后怎么改,管他呢,先套公式。

    在这里,我们首先定义的装饰器是decorator,没有参数;其次再用query装饰器包装,从而给无参的装饰器给一个参数,从而接收一个SQL字符串参数。

    好的,我们再进一步。

    解析字符串,获得变量

    首先当然是解析SQL字符串,获得变量。如何做呢?为了简便,这里直接采用正则匹配的方式:

    def query(self, query_template):
        def decorator(func):
            # 解析 SQL 中的 #{变量} 语法
            param_pattern = re.compile(r"#{(\w+)}")
            required_params = set(param_pattern.findall(query_template))
            @wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper
        return decorator
    

    没啥问题。

    接下来,调用的时候,我们需要检测是否完整给出了SQL字符串所需的参数。

    我们考虑到,如果但凡SQL中的参数有变化,方法就会有变化,因此每个SQL都有一个方法也太麻烦了。主要是这么多相似的方法起方法名太烦了

    所以,直接上反射,获取 调用 的时侯传入的参数。

    值得注意的是,这里说的是 调用 的时候。因为python定义 方法的时候可以使用**kargs传入多个参数,但是如果反射直接获取到 定义 的参数,将会只有一个kargs,这显然不是我们所希望的。

    所以,再加一些:

    def query(self, query_template):
        def decorator(func):
            # 解析 SQL 中的 #{变量} 语法
            param_pattern = re.compile(r"#{(\w+)}")
            required_params = set(param_pattern.findall(query_template))
            @wraps(func)
            def wrapper(*args, **kwargs):
                # 获取函数的参数签名
                sig = inspect.signature(func)
                bound_args = sig.bind_partial(*args, **kwargs)
                bound_args.apply_defaults()
                # 提取传递的参数,包括 **kwargs 中的参数
                provided_params = set(bound_args.arguments.keys()) | set(kwargs.keys())
                # 检查缺失的参数
                missing_params = required_params - provided_params
                if missing_params:
                    raise ValueError(f"Missing required parameters: {', '.join(missing_params)}")
                return func(*args, **kwargs)
            return wrapper
        return decorator
    

    这下应该就能够适配到所有的SQL情况了。

    SQL字符串拼接

    接下来就是直接替换值了。但是,拼接真的就是对的吗?我们不光是需要考虑不同的变量有着不同的植入格式,同时也需要考虑到植入过程中可能的SQL注入问题。

    所以,我们就直接采用sqlalchemytext函数,对SQL进行拼接与赋值。

    def query(self, query_template):
        def decorator(func):
            # 解析 SQL 中的 #{变量} 语法
            param_pattern = re.compile(r"#{(\w+)}")
            required_params = set(param_pattern.findall(query_template))
            @wraps(func)
            def wrapper(*args, **kwargs):
                # 获取函数的参数签名
                sig = inspect.signature(func)
                bound_args = sig.bind_partjsial(*args, **kwa编程客栈rgs)
                bound_args.apply_defaults()
                # 提取传递的参数,包括 **kwargs 中的参数
                provided_params = set(bound_args.arguments.keys()) | set(kwargs.keys())
                # 检查缺失的参数
                missing_params = required_params - provided_params
                if missing_params:
                    raise ValueError(f"Missing required parameters: {', '.join(missing_params)}")
                # 构建 SQL 语句,并考虑不同类型的数据格式
                sql_query = text(query_template.replace("#{", ":").replace("}", ""))
                pAtjIJjMVYsrint(f"Executing SQL: {sql_query}")
                return func(*args, **kwargs)
            return wrapper
        return decorator
    

    好了,到这一步也就基本完成了。最后,我们根据数据库存储数据的特点,最后修整一下查询的格式细节,就可以了:

    def query(self, query_template):
        def decorator(func):
            # 解析 SQL 中的 #{变量} 语法
            param_pattern = re.compile(r"#{(\w+)}")
            required_params = set(param_pattern.findall(query_template))
            @wraps(func)
            def wrapper(*args, **kwargs):
                # 获取函数的参数签名
                sig = inspect.signature(func)
                bound_args = sig.bind_partial(*args, **kwargs)
                bound_args.apply_defaults()
                # 提取传递的参数,包括 **kwargs 中的参数
                provided_params = set(bound_args.arguments.keys()) | set(kwargs.keys())
                # 检查缺失的参数
                missing_params = javascriptrequired_params - provided_params
                if missing_params:
                    raise ValueError(f"Missing required parameters: {', '.join(missing_params)}")
                # 构建 SQL 语句,并考虑不同类型的数据格式
                sql_query = text(query_template.replace("#{", ":").replace("}", ""))
                print(f"Executing SQL: {sql_query}")
                params = bound_args.arguments.copy()
                for key, value in params.items():
                    if isinstance(value, datetime):
                        params[key] = value.strftime('%Y-%m-%d')
                engine = create_engine(self.link)
                with engine.connect() as conn:
                    result = conn.execute(sql_query, params)
                    search_result = [self.create_item_obj(row) for row in result]
                return search_result
            return wrapper
        return decorator
    

    就是这样,我们就完成了这样一个装饰器。

    使用装饰器

    使用过程,其实就可以类比@Service中的调用了。而如果拿Python举例的话,其实更像Flaskapp.route。于是我们可以这么使用:

    sbd = StudentDBManager()
    @sbd.query("SELECT * FROM student WHERE id = #{id}")
    def find_student_by_id(**kargs): ...
    

    这也就实现了一个方法。

    当然,他也没那么智能。虽然写起来是这样,但是依然相当于:

    sbd = StudentDBManager()
    @sbd.query("SELECT * FROM student WHERE id = #{id}")
    def find_student_by_id(id: str): ...
    

    只是说,我们并不需要重复地去写驱动罢了。

    以上就是利用Python实现一个类似MybatisPlus的简易SQL注解的详细内容,更多关于Python简易SQL注解的资料请关注编程客栈(www.devze.com)其它相关文章!

    0

    上一篇:

    下一篇:

    精彩评论

    暂无评论...
    验证码 换一张
    取 消

    最新开发

    开发排行榜