[toc]
FastAPI笔记3-FastAPI与SQLAlchemy数据库集成
本文将详细介绍如何在 FastAPI 中使用 SQLAlchemy,包括数据库配置、模型定义、依赖注入、CRUD 操作等。
数据库配置
在 database_config.py 文件中定义数据库配置。
python
# 导入sqlalchemy框架中的各个工具
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
# mysql数据库的连接URL
MYSQL_DATABASE_URL = "mysql+pymysql://root:123456@localhost:3306/shuyx_db"
# 创建数据库引擎myEngine
myEngine = create_engine(MYSQL_DATABASE_URL,
pool_size=10, # 连接池大小
pool_timeout=30, # 池中没有线程最多等待的时间,否则报错
echo=False # 是否在控制台打印相关语句等
)
# 创建会话工厂对象mySessionLocal
mySessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=myEngine, expire_on_commit=False)
# 创建统一数据库模型基类
myBase = declarative_base()
# 该函数每次通过会话工厂创建新的会话session对象。确保每个请求都有独立的会话。从而避免了并发访问同一个会话对象导致的事务冲突
def get_db_session():
mysession = mySessionLocal() #每次通过会话工厂创建新的会话session对象
try:
yield mysession
except:
mysession.rollback() # 发生异常时回滚事务
raise
finally:
mysession.close()SQLAlchemy 数据库模型类定义 (类似entity层)
SQLAlchemy 数据库模型类与数据库表是一一对应的关系。类似Java 后端中的entity层。
以mp_user_model.py 文件为例,定义用户模型类。
python
# 导入sqlalchemy框架中的相关字段
from sqlalchemy import Column, Integer, String, DateTime, CHAR, func, Index
# 导入database_config文件中的统一数据库模型基类
from config.database_config import myBase
class MpUserModel(myBase):
"""
用户表 mp_user
"""
__tablename__ = 'mp_user'
id = Column("id",Integer, primary_key=True, autoincrement=True, comment='用户id')
name = Column("name",String(500),nullable=False, comment='用户名')
password = Column("password",String(500),nullable=False, comment='用户密码')
phone = Column("phone", String(20), unique=True, nullable=False, comment='用户手机号')
create_time = Column("create_time",DateTime, comment='创建时间', default=func.now())
# 添加索引
__table_args__ = (
Index('index_id', 'id'),
Index('index_phone', 'phone'),
)Pydantic 模型类定义 (类似dto层)
Pydantic 模型类与请求体、响应体等数据模型是一一对应的关系。类似Java 后端中的dto层。
以mp_user_dto.py 文件为例,定义用户数据传输对象类。
python
from datetime import datetime
from pydantic import BaseModel
from typing import Optional
# 定义用户模型类型
# 注意:DTO 是数据传输对象,用于在不同层之间传递数据,而不是直接与数据库交互。
class MpUserDTO(BaseModel):
id:Optional[int] = None # Optional[int] = None 表示类型可以是int,也可以是 None,默认值为 None
name:Optional[str] = None
password:Optional[str] = None
phone:Optional[str] = None
create_time:Optional[datetime] = None数据库操作类定义 (类似dao层)
基础数据库操作类 BaseDao 定义 (封装通用CRUD方法)
以base_dao.py 文件为例,里面定义了通用CRUD方法。
python
from typing import Generic, TypeVar, Type, List, Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import update, delete, desc, asc
from config.database_config import myBase
# 定义泛型类型变量,约束为SQLAlchemy的Base模型
ModelType = TypeVar("ModelType", bound=myBase)
class BaseDao(Generic[ModelType]):
"""通用DAO基类,封装所有表的通用CRUD方法"""
def __init__(self, model: Type[ModelType]):
"""
初始化BaseDAO
:param model: 关联的SQLAlchemy数据库模型(如User、Order)
"""
self.model = model
def get_by_id(self, db_session: Session, id: int) -> Optional[ModelType]:
"""
根据ID获取单条记录(注意包含字段id)
id: 记录ID
"""
return db_session.query(self.model).filter(self.model.id == id).first()
def get_total_by_filters(self,db_session: Session,filters: Dict = None) -> int:
"""
根据条件获取记录总数
filters: 查询条件,字典类型。例如 {"filed1": value1, "filed2": value2}
"""
# 初始化查询对象
query = db_session.query(self.model)
# 动态构建查询条件
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.filter(getattr(self.model, field) == value)
return query.count()
def get_one_by_filter(self, db_session: Session, filters: Dict = None) -> Optional[ModelType]:
"""
根据条件获取单条记录
filters: 查询条件,字典类型。例如 {"filed1": value1, "filed2": value2}
"""
# 初始化查询对象
query = db_session.query(self.model)
# 动态构建查询条件
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.filter(getattr(self.model, field) == value)
return query.first()
def get_list_by_filters(self,db_session: Session,filters: Dict = None,sort_by: List[str] = None) -> List[ModelType]:
"""
根据条件获取查询列表(支持条件过滤+排序)
filters: 查询条件字典。例如 {"filed1": value1, "filed2": value2}
sort_by: 排序字段,是一个字符串列表。例如 ["field1", "-field2"] 表示按field1升序,按field2降序排序。
"""
# 初始化查询对象
query = db_session.query(self.model)
# 动态构建查询条件
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.filter(getattr(self.model, field) == value)
# 动态构建排序条件
if sort_by:
# 遍历排序字段列表
for sort_field in sort_by:
# 判断排序方向
if sort_field.startswith('-'):
# 降序
field_name = sort_field[1:]
if hasattr(self.model, field_name):
query = query.order_by(desc(getattr(self.model, field_name)))
else:
# 升序
if hasattr(self.model, sort_field):
query = query.order_by(asc(getattr(self.model, sort_field)))
# 执行查询并获取所有记录
return query.all()
def get_page_list_by_filters(self,db_session: Session,page_num: int,page_size: int,filters: Dict = None,sort_by: List[str] = None) -> List[ModelType]:
"""
根据条件获取分页查询列表(支持分页+条件过滤+排序)
page_num: 页码
page_size: 每页大小
filters: 查询条件字典。例如 {"filed1": value1, "filed2": value2}
sort_by: 排序字段,是一个字符串列表。例如 ["field1", "-field2"] 表示按field1升序,按field2降序排序。
"""
# 初始化查询对象
query = db_session.query(self.model)
# 动态构建查询条件
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.filter(getattr(self.model, field) == value)
# 动态构建排序条件
if sort_by:
# 遍历排序字段列表
for sort_field in sort_by:
# 判断排序方向
if sort_field.startswith('-'):
# 降序
field_name = sort_field[1:]
if hasattr(self.model, field_name):
query = query.order_by(desc(getattr(self.model, field_name)))
else:
# 升序
if hasattr(self.model, sort_field):
query = query.order_by(asc(getattr(self.model, sort_field)))
# 计算分页偏移量
offset_value = (page_num - 1) * page_size
# 获取当前分页数据
return query.offset(offset_value).limit(page_size).all()
def add(self, db_session: Session, dict_data: Dict = None) -> ModelType:
"""
添加新记录
dict_data: 新记录的字典数据
"""
# 将字典转换为对应的model实例
new_instance = self.model(**dict_data)
db_session.add(new_instance)
# 显式提交事务
db_session.commit()
db_session.refresh(new_instance)
return new_instance
def update_by_id(self,db_session: Session,id: int,update_data: Dict = None) -> bool:
"""
根据ID更新信息
id: 要更新的记录ID
update_data: 更新数据字典
"""
# 1. 先查询记录是否存在
db_obj = self.get_by_id(db_session, id)
# 若不存在,则返回False
if not db_obj:
return False
# 2. 遍历更新数据字典,更新model实例的字段
for key, value in update_data.items():
if hasattr(db_obj, key) and key != "id": # 禁止更新id字段
setattr(db_obj, key, value)
# 3. 提交事务
db_session.commit()
return True
def delete_by_id(self, db_session: Session, id: int) -> bool:
"""
根据ID删除记录
id: 要删除的记录ID
"""
# 1. 先查询记录是否存在
db_obj = self.get_by_id(db_session, id)
if not db_obj:
return False
# 2. 删除记录
db_session.delete(db_obj)
db_session.commit()
return True用户表数据库操作类 MpUserDao 定义 (封装用户表的CRUD方法)
当定义了基础数据库操作类 BaseDao 后,用户数据库操作类 MpUserDao 就可以直接继承 BaseDao,从而获得通用CRUD方法。
以mp_user_dao.py 文件为例,里面定义了用户表的CRUD方法。
python
from module_exam.model.mp_user_model import MpUserModel
from module_exam.dao.base_dao import BaseDao
# 继承BaseDao类,专注于数据访问操作, 可添加自定义方法
class MpUserDao(BaseDao[MpUserModel]):
def __init__(self):
"""初始化DAO实例"""
super().__init__(model = MpUserModel)
# 可以根据业务需求添加自定义方法服务类定义 (类似service层)
基础服务类 BaseService 定义 (封装通用业务逻辑)
以base_service.py 文件为例,里面定义了所有表的通用业务逻辑。
python
from typing import Generic, TypeVar, Type, List, Optional, Dict, Any
from sqlalchemy.orm import Session
from fastapi import HTTPException
from config.database_config import myBase
from module_exam.dao.base_dao import BaseDao
# 定义泛型类型变量,约束为SQLAlchemy的Base模型
ModelType = TypeVar("ModelType", bound=myBase)
class BaseService(Generic[ModelType]):
"""通用Service基类,封装通用业务逻辑"""
def __init__(self, dao: BaseDao[ModelType]):
"""
初始化BaseService
:param dao: 关联的DAO实例(如UserDAO)
"""
self.dao = dao
def get_by_id(self, db_session: Session, id: int) -> Optional[ModelType]:
"""
根据ID获取单条记录(注意包含字段id)
id: 记录ID
"""
return self.dao.get_by_id(db_session, id)
def get_total_by_filters(self, db_session: Session, filters: Dict = None) -> int:
"""
根据条件获取记录总数
filters: 查询条件,字典类型。例如 {"filed1": value1, "filed2": value2}
"""
return self.dao.get_total_by_filters(db_session, filters)
def get_one_by_filter(self, db_session: Session, filters: Dict = None) -> Optional[ModelType]:
"""
根据条件获取单条记录
filters: 查询条件,字典类型。例如 {"filed1": value1, "filed2": value2}
"""
return self.dao.get_one_by_filter(db_session, filters)
def get_list_by_filters(self,db_session: Session,filters: Dict = None,sort_by: List[str] = None) -> List[ModelType]:
"""
根据条件获取查询列表(支持条件过滤+排序)
filters: 查询条件字典。例如 {"filed1": value1, "filed2": value2}
sort_by: 排序字段,是一个字符串列表。例如 ["field1", "-field2"] 表示按field1升序,按field2降序排序。
"""
return self.dao.get_list_by_filters(db_session, filters, sort_by)
def get_page_list_by_filters(self,db_session: Session,page_num: int,page_size: int,filters: Dict = None,sort_by: List[str] = None) -> List[ModelType]:
"""
根据条件获取分页查询列表(支持分页+条件过滤+排序)
page_num: 页码
page_size: 每页大小
filters: 查询条件字典。例如 {"filed1": value1, "filed2": value2}
sort_by: 排序字段,是一个字符串列表。例如 ["field1", "-field2"] 表示按field1升序,按field2降序排序。
"""
return self.dao.get_page_list_by_filters(db_session, page_num, page_size, filters, sort_by)
def add(self, db_session: Session, dict_data: Dict = None) -> ModelType:
"""
添加新记录
dict_data: 新记录的字典数据
"""
return self.dao.add(db_session, dict_data)
def update_by_id(self,db_session: Session,id: int,update_data: Dict = None) -> bool:
"""
根据ID更新信息
id: 要更新的记录ID
update_data: 更新数据字典
"""
return self.dao.update_by_id(db_session, id, update_data)
def delete_by_id(self, db_session: Session, id: int) -> bool:
"""
根据ID删除记录
id: 要删除的记录ID
"""
return self.dao.delete_by_id(db_session, id)用户服务类定义 (封装用户表的业务逻辑)
以mp_user_service.py 文件为例,里面定义了用户表的业务逻辑。
python
from module_exam.dao.mp_user_dao import MpUserDao
from module_exam.model.mp_user_model import MpUserModel
from module_exam.service.base_service import BaseService
# 继承Service类,专注于业务操作, 可添加自定义方法
class MpUserService(BaseService[MpUserModel]):
def __init__(self):
"""
初始化服务实例
创建DAO实例并传递给基类
"""
self.dao_instance = MpUserDao()
super().__init__(dao=self.dao_instance)
# 可以根据业务需求添加自定义方法路由操作函数定义 (类似controller层)
以mp_user_controller.py 文件为例,里面定义了用户表的路由操作函数。
python
from fastapi import APIRouter, Body, Depends
from sqlalchemy.orm import Session
from config.database_config import get_db_session
from config.log_config import logger
from module_exam.dto.mp_user_dto import MpUserDTO
from module_exam.service.mp_user_service import MpUserService
from utils.response_util import ResponseUtil
# 创建路由实例
router = APIRouter(prefix='/mp/user', tags=['mp_user接口'])
# 创建服务实例
MpUserService_instance = MpUserService()
@router.post("/saveUserInfo")
def saveUserInfo(userInfo:MpUserDTO,db_session:Session = Depends(get_db_session)):
logger.info(f'/mp/user/saveUserInfo, userInfo = {userInfo}')
# dto 转换为 dict
updateuser_dict = userInfo.model_dump(exclude_unset=True)
# 调用服务层方法,更新用户信息
result = MpUserService_instance.update_by_id(db_session,id=userInfo.id,update_data=updateuser_dict),
if result is False:
return ResponseUtil.error(data={"message": "更新失败"})
return ResponseUtil.success(data={"message": "更新成功"})
@router.post("/getUserInfo")
def getUserInfo(userId:int = Body(None),db_session:Session = Depends(get_db_session)):
logger.info(f'/mp/user/getUserInfo, userId = {userId}')
# 调用服务层方法,查询用户信息
result = MpUserService_instance.get_one_by_filter(db_session,filters={"id": userId})
# 若result为空,则返回空字典。不为空则返回result
return ResponseUtil.success(data=result)主应用文件
先在 main.py 文件中,引入路由实例。
python
# 导入FastAPI
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
# 创建FastAPI应用实例
app = FastAPI(
title="微信小程序服务API",
description="微信小程序-测试系统后端API",
version="1.0.1"
)
# 导入控制器路由
from module_exam.controller.mp_user_controller import router as mp_user_router
# 通过include_router函数,把各个路由实例加入到FastAPI应用实例中,进行统一管理
app.include_router(mp_user_router)
# 测试运行接口
@app.get("/")
async def root():
"""根路径接口"""
return {"message": "Hello World , 服务运行正常", "version": "1.0.0"}
if __name__ == "__main__":
uvicorn.run(
app='main:app',
port=39666,
reload=True,
log_level="info" # 添加日志级别配置
)然后通过下面命令启动服务:
bash
uvicorn main:app --reload总结
通过使用 FastAPI 和 SQLAlchemy,可以构建高效的 Web 应用程序,并轻松实现与关系型数据库的交互。
本文介绍了如何配置数据库连接、定义数据库模型和 Pydantic 模型、依赖注入数据库会话,以及如何进行基本的 CRUD 操作。
