Skip to content
🗂️ 文章分类: Python  
🏷️ 文章标签: FastAPI  SqlAlchemy  
📝 文章创建时间: 2025-12-27
🔥 文章最后更新时间:2025-12-27

[toc]

FastAPI笔记3-FastAPI与SQLAlchemy数据库集成

本文将详细介绍如何在 FastAPI 中使用 SQLAlchemy,包括数据库配置、模型定义、依赖注入、CRUD 操作等。

数据库配置

在 database_config.py 文件中定义数据库配置。

python
# 导入sqlalchemy框架中的各个工具
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

# mysql数据库的连接URL
MYSQL_DATABASE_URL = "mysql+pymysql://root:123456@localhost:3306/shuyx_db"

# 创建数据库引擎myEngine
myEngine = create_engine(MYSQL_DATABASE_URL,
    pool_size=10,            # 连接池大小
    pool_timeout=30,        # 池中没有线程最多等待的时间,否则报错
    echo=False              # 是否在控制台打印相关语句等
    )

# 创建会话工厂对象mySessionLocal
mySessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=myEngine, expire_on_commit=False)

# 创建统一数据库模型基类
myBase = declarative_base()

# 该函数每次通过会话工厂创建新的会话session对象。确保每个请求都有独立的会话。从而避免了并发访问同一个会话对象导致的事务冲突
def get_db_session():
    mysession = mySessionLocal()  #每次通过会话工厂创建新的会话session对象
    try:
        yield mysession
    except:
        mysession.rollback() # 发生异常时回滚事务
        raise
    finally:
        mysession.close()

SQLAlchemy 数据库模型类定义 (类似entity层)

SQLAlchemy 数据库模型类与数据库表是一一对应的关系。类似Java 后端中的entity层。

以mp_user_model.py 文件为例,定义用户模型类。

python
# 导入sqlalchemy框架中的相关字段
from sqlalchemy import Column, Integer, String, DateTime, CHAR, func, Index
# 导入database_config文件中的统一数据库模型基类
from config.database_config import myBase

class MpUserModel(myBase):
    """
    用户表 mp_user
    """
    __tablename__ = 'mp_user'

    id = Column("id",Integer, primary_key=True, autoincrement=True, comment='用户id')
    name = Column("name",String(500),nullable=False, comment='用户名')
    password = Column("password",String(500),nullable=False, comment='用户密码')
    phone = Column("phone", String(20), unique=True, nullable=False, comment='用户手机号')
    create_time = Column("create_time",DateTime, comment='创建时间', default=func.now())

    # 添加索引
    __table_args__ = (
        Index('index_id', 'id'),
        Index('index_phone', 'phone'),
    )

Pydantic 模型类定义 (类似dto层)

Pydantic 模型类与请求体、响应体等数据模型是一一对应的关系。类似Java 后端中的dto层。

以mp_user_dto.py 文件为例,定义用户数据传输对象类。

python
from datetime import datetime
from pydantic import BaseModel
from typing import Optional

# 定义用户模型类型
# 注意:DTO 是数据传输对象,用于在不同层之间传递数据,而不是直接与数据库交互。
class MpUserDTO(BaseModel):
    id:Optional[int] = None          # Optional[int] = None 表示类型可以是int,也可以是 None,默认值为 None
    name:Optional[str] = None
    password:Optional[str] = None
    phone:Optional[str] = None
    create_time:Optional[datetime] = None

数据库操作类定义 (类似dao层)

基础数据库操作类 BaseDao 定义 (封装通用CRUD方法)

以base_dao.py 文件为例,里面定义了通用CRUD方法。

python
from typing import Generic, TypeVar, Type, List, Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import update, delete, desc, asc
from config.database_config import myBase

# 定义泛型类型变量,约束为SQLAlchemy的Base模型
ModelType = TypeVar("ModelType", bound=myBase)

class BaseDao(Generic[ModelType]):
    """通用DAO基类,封装所有表的通用CRUD方法"""
    def __init__(self, model: Type[ModelType]):
        """
        初始化BaseDAO
        :param model: 关联的SQLAlchemy数据库模型(如User、Order)
        """
        self.model = model

    def get_by_id(self, db_session: Session, id: int) -> Optional[ModelType]:
        """
        根据ID获取单条记录(注意包含字段id)
            id: 记录ID
        """
        return db_session.query(self.model).filter(self.model.id == id).first()

    def get_total_by_filters(self,db_session: Session,filters: Dict = None) -> int:
        """
        根据条件获取记录总数
            filters: 查询条件,字典类型。例如 {"filed1": value1, "filed2": value2}
        """
        # 初始化查询对象
        query = db_session.query(self.model)
        # 动态构建查询条件
        if filters:
            for field, value in filters.items():
                if hasattr(self.model, field) and value is not None:
                    query = query.filter(getattr(self.model, field) == value)

        return query.count()

    def get_one_by_filter(self, db_session: Session, filters: Dict = None) -> Optional[ModelType]:
        """
        根据条件获取单条记录
            filters: 查询条件,字典类型。例如 {"filed1": value1, "filed2": value2}
        """
        # 初始化查询对象
        query = db_session.query(self.model)
        # 动态构建查询条件
        if filters:
            for field, value in filters.items():
                if hasattr(self.model, field) and value is not None:
                    query = query.filter(getattr(self.model, field) == value)

        return query.first()

    def get_list_by_filters(self,db_session: Session,filters: Dict = None,sort_by: List[str] = None) -> List[ModelType]:
        """
        根据条件获取查询列表(支持条件过滤+排序)
            filters: 查询条件字典。例如 {"filed1": value1, "filed2": value2}
            sort_by: 排序字段,是一个字符串列表。例如 ["field1", "-field2"] 表示按field1升序,按field2降序排序。
        """
        # 初始化查询对象
        query = db_session.query(self.model)
        # 动态构建查询条件
        if filters:
            for field, value in filters.items():
                if hasattr(self.model, field) and value is not None:
                    query = query.filter(getattr(self.model, field) == value)
        # 动态构建排序条件
        if sort_by:
            # 遍历排序字段列表
            for sort_field in sort_by:
                # 判断排序方向
                if sort_field.startswith('-'):
                    # 降序
                    field_name = sort_field[1:]
                    if hasattr(self.model, field_name):
                        query = query.order_by(desc(getattr(self.model, field_name)))
                else:
                    # 升序
                    if hasattr(self.model, sort_field):
                        query = query.order_by(asc(getattr(self.model, sort_field)))

        # 执行查询并获取所有记录
        return query.all()

    def get_page_list_by_filters(self,db_session: Session,page_num: int,page_size: int,filters: Dict = None,sort_by: List[str] = None) -> List[ModelType]:
        """
        根据条件获取分页查询列表(支持分页+条件过滤+排序)
            page_num: 页码
            page_size: 每页大小
            filters: 查询条件字典。例如 {"filed1": value1, "filed2": value2}
            sort_by: 排序字段,是一个字符串列表。例如 ["field1", "-field2"] 表示按field1升序,按field2降序排序。
        """
        # 初始化查询对象
        query = db_session.query(self.model)
        # 动态构建查询条件
        if filters:
            for field, value in filters.items():
                if hasattr(self.model, field) and value is not None:
                    query = query.filter(getattr(self.model, field) == value)
        # 动态构建排序条件
        if sort_by:
            # 遍历排序字段列表
            for sort_field in sort_by:
                # 判断排序方向
                if sort_field.startswith('-'):
                    # 降序
                    field_name = sort_field[1:]
                    if hasattr(self.model, field_name):
                        query = query.order_by(desc(getattr(self.model, field_name)))
                else:
                    # 升序
                    if hasattr(self.model, sort_field):
                        query = query.order_by(asc(getattr(self.model, sort_field)))

        # 计算分页偏移量
        offset_value = (page_num - 1) * page_size
        # 获取当前分页数据
        return query.offset(offset_value).limit(page_size).all()

    def add(self, db_session: Session, dict_data: Dict = None) -> ModelType:
        """
        添加新记录
            dict_data: 新记录的字典数据
        """
        # 将字典转换为对应的model实例
        new_instance = self.model(**dict_data)
        db_session.add(new_instance)
        # 显式提交事务
        db_session.commit()
        db_session.refresh(new_instance)
        return new_instance

    def update_by_id(self,db_session: Session,id: int,update_data: Dict = None) -> bool:
        """
        根据ID更新信息
            id: 要更新的记录ID
            update_data: 更新数据字典
        """

        # 1. 先查询记录是否存在
        db_obj = self.get_by_id(db_session, id)
        # 若不存在,则返回False
        if not db_obj:
            return False
        
        # 2. 遍历更新数据字典,更新model实例的字段
        for key, value in update_data.items():
            if hasattr(db_obj, key) and key != "id":  # 禁止更新id字段
                setattr(db_obj, key, value)
        
        # 3. 提交事务
        db_session.commit()
        return True 

    def delete_by_id(self, db_session: Session, id: int) -> bool:
        """
        根据ID删除记录
            id: 要删除的记录ID
        """
        
        # 1. 先查询记录是否存在
        db_obj = self.get_by_id(db_session, id)
        if not db_obj:
            return False
        
        # 2. 删除记录
        db_session.delete(db_obj)
        db_session.commit()
        return True

用户表数据库操作类 MpUserDao 定义 (封装用户表的CRUD方法)

当定义了基础数据库操作类 BaseDao 后,用户数据库操作类 MpUserDao 就可以直接继承 BaseDao,从而获得通用CRUD方法。

以mp_user_dao.py 文件为例,里面定义了用户表的CRUD方法。

python
from module_exam.model.mp_user_model import MpUserModel
from module_exam.dao.base_dao import BaseDao

# 继承BaseDao类,专注于数据访问操作, 可添加自定义方法
class MpUserDao(BaseDao[MpUserModel]):
    def __init__(self):
        """初始化DAO实例"""
        super().__init__(model = MpUserModel)

    # 可以根据业务需求添加自定义方法

服务类定义 (类似service层)

基础服务类 BaseService 定义 (封装通用业务逻辑)

以base_service.py 文件为例,里面定义了所有表的通用业务逻辑。

python
from typing import Generic, TypeVar, Type, List, Optional, Dict, Any
from sqlalchemy.orm import Session
from fastapi import HTTPException
from config.database_config import myBase
from module_exam.dao.base_dao import BaseDao

# 定义泛型类型变量,约束为SQLAlchemy的Base模型
ModelType = TypeVar("ModelType", bound=myBase)

class BaseService(Generic[ModelType]):
    """通用Service基类,封装通用业务逻辑"""
    def __init__(self, dao: BaseDao[ModelType]):
        """
        初始化BaseService
        :param dao: 关联的DAO实例(如UserDAO)
        """
        self.dao = dao

    def get_by_id(self, db_session: Session, id: int) -> Optional[ModelType]:
        """
        根据ID获取单条记录(注意包含字段id)
            id: 记录ID
        """
        return self.dao.get_by_id(db_session, id)

    def get_total_by_filters(self, db_session: Session, filters: Dict = None) -> int:
        """
        根据条件获取记录总数
            filters: 查询条件,字典类型。例如 {"filed1": value1, "filed2": value2}
        """
        return self.dao.get_total_by_filters(db_session, filters)
    
    def get_one_by_filter(self, db_session: Session, filters: Dict = None) -> Optional[ModelType]:
        """
        根据条件获取单条记录
            filters: 查询条件,字典类型。例如 {"filed1": value1, "filed2": value2}
        """
        return self.dao.get_one_by_filter(db_session, filters)
    
    def get_list_by_filters(self,db_session: Session,filters: Dict = None,sort_by: List[str] = None) -> List[ModelType]:
        """
        根据条件获取查询列表(支持条件过滤+排序)
            filters: 查询条件字典。例如 {"filed1": value1, "filed2": value2}
            sort_by: 排序字段,是一个字符串列表。例如 ["field1", "-field2"] 表示按field1升序,按field2降序排序。
        """
        return self.dao.get_list_by_filters(db_session, filters, sort_by)

    def get_page_list_by_filters(self,db_session: Session,page_num: int,page_size: int,filters: Dict = None,sort_by: List[str] = None) -> List[ModelType]:
        """
        根据条件获取分页查询列表(支持分页+条件过滤+排序)
            page_num: 页码
            page_size: 每页大小
            filters: 查询条件字典。例如 {"filed1": value1, "filed2": value2}
            sort_by: 排序字段,是一个字符串列表。例如 ["field1", "-field2"] 表示按field1升序,按field2降序排序。
        """
        return self.dao.get_page_list_by_filters(db_session, page_num, page_size, filters, sort_by)
    
    def add(self, db_session: Session, dict_data: Dict = None) -> ModelType:
        """
        添加新记录
            dict_data: 新记录的字典数据
        """
        return self.dao.add(db_session, dict_data)
    
    def update_by_id(self,db_session: Session,id: int,update_data: Dict = None) -> bool:
        """
        根据ID更新信息
            id: 要更新的记录ID
            update_data: 更新数据字典
        """
        return self.dao.update_by_id(db_session, id, update_data)
    
    def delete_by_id(self, db_session: Session, id: int) -> bool:
        """
        根据ID删除记录
            id: 要删除的记录ID
        """
        return self.dao.delete_by_id(db_session, id)

用户服务类定义 (封装用户表的业务逻辑)

以mp_user_service.py 文件为例,里面定义了用户表的业务逻辑。

python
from module_exam.dao.mp_user_dao import MpUserDao
from module_exam.model.mp_user_model import MpUserModel
from module_exam.service.base_service import BaseService

# 继承Service类,专注于业务操作, 可添加自定义方法
class MpUserService(BaseService[MpUserModel]):
    def __init__(self):
        """
        初始化服务实例
        创建DAO实例并传递给基类
        """
        self.dao_instance = MpUserDao()
        super().__init__(dao=self.dao_instance)

    # 可以根据业务需求添加自定义方法

路由操作函数定义 (类似controller层)

以mp_user_controller.py 文件为例,里面定义了用户表的路由操作函数。

python
from fastapi import APIRouter, Body, Depends
from sqlalchemy.orm import Session

from config.database_config import get_db_session
from config.log_config import logger
from module_exam.dto.mp_user_dto import MpUserDTO
from module_exam.service.mp_user_service import MpUserService
from utils.response_util import ResponseUtil

# 创建路由实例
router = APIRouter(prefix='/mp/user', tags=['mp_user接口'])

# 创建服务实例
MpUserService_instance = MpUserService()

@router.post("/saveUserInfo")
def saveUserInfo(userInfo:MpUserDTO,db_session:Session = Depends(get_db_session)):
        logger.info(f'/mp/user/saveUserInfo, userInfo = {userInfo}')
        
        # dto 转换为 dict
        updateuser_dict = userInfo.model_dump(exclude_unset=True)

        # 调用服务层方法,更新用户信息
        result = MpUserService_instance.update_by_id(db_session,id=userInfo.id,update_data=updateuser_dict),
        if result is False:
            return ResponseUtil.error(data={"message": "更新失败"})

        return ResponseUtil.success(data={"message": "更新成功"})

@router.post("/getUserInfo")
def getUserInfo(userId:int = Body(None),db_session:Session = Depends(get_db_session)):
        logger.info(f'/mp/user/getUserInfo, userId = {userId}')
        # 调用服务层方法,查询用户信息
        result = MpUserService_instance.get_one_by_filter(db_session,filters={"id": userId})
        # 若result为空,则返回空字典。不为空则返回result
        return ResponseUtil.success(data=result)

主应用文件

先在 main.py 文件中,引入路由实例。

python
# 导入FastAPI
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

# 创建FastAPI应用实例
app = FastAPI(
    title="微信小程序服务API",
    description="微信小程序-测试系统后端API",
    version="1.0.1"
)


# 导入控制器路由
from module_exam.controller.mp_user_controller import router as mp_user_router
# 通过include_router函数,把各个路由实例加入到FastAPI应用实例中,进行统一管理
app.include_router(mp_user_router)

# 测试运行接口
@app.get("/")
async def root():
    """根路径接口"""
    return {"message": "Hello World , 服务运行正常", "version": "1.0.0"}

if __name__ == "__main__":
    uvicorn.run(
        app='main:app',
        port=39666,
        reload=True,
        log_level="info"  # 添加日志级别配置
    )

然后通过下面命令启动服务:

bash
uvicorn main:app --reload

总结

通过使用 FastAPI 和 SQLAlchemy,可以构建高效的 Web 应用程序,并轻松实现与关系型数据库的交互。

本文介绍了如何配置数据库连接、定义数据库模型和 Pydantic 模型、依赖注入数据库会话,以及如何进行基本的 CRUD 操作。

Released under the MIT License.