본문 바로가기
Python

[FastAPI] aiomysql connection pool 사용하기

by lllIIl 2023. 4. 11.

dbms.py

import aiomysql
import os
from dotenv import load_dotenv

# Populate os.environ from a .env file so the MYSQL_* settings read by
# Database.create_pool are available. Variables already present in the real
# environment are kept (override=False is python-dotenv's default, spelled
# out here); verbose=True reports a missing .env file.
load_dotenv(override=False, verbose=True)


class Database:
    """Async helper around a shared aiomysql connection pool.

    Connection settings come from the ``MYSQL_*`` environment variables.
    Call :meth:`create_pool` once (e.g. at application startup) before
    running queries, and :meth:`disconnect` at shutdown.
    """

    def __init__(self, minsize=3, maxsize=10):
        """Prepare an unconnected instance.

        Args:
            minsize: Minimum number of connections the pool keeps open.
            maxsize: Maximum number of connections the pool may open.
        """
        self.pool = None  # set by create_pool(), cleared by disconnect()
        self.minsize = minsize
        self.maxsize = maxsize

    async def create_pool(self):
        """Open the connection pool using the ``MYSQL_*`` environment variables.

        Does nothing if a pool is already open, so a second call cannot
        orphan the first pool's connections.

        Raises:
            ValueError: If ``MYSQL_PORT`` is set to a non-integer value.
        """
        if self.pool is not None:
            return
        self.pool = await aiomysql.create_pool(
            host=os.getenv("MYSQL_HOST"),
            # Default to MySQL's standard port; int(None) would raise
            # TypeError when MYSQL_PORT is unset.
            port=int(os.getenv("MYSQL_PORT", "3306")),
            user=os.getenv("MYSQL_USER"),
            password=os.getenv("MYSQL_PASSWORD"),
            db=os.getenv("MYSQL_DATABASE"),
            # Every statement is committed immediately; this class does no
            # explicit transaction handling.
            autocommit=True,
            # Rows are returned as dicts keyed by column name.
            cursorclass=aiomysql.DictCursor,
            minsize=self.minsize,
            maxsize=self.maxsize,
        )

    async def disconnect(self):
        """Close the pool and wait until its connections are released.

        Safe to call when no pool was created or it was already closed.
        """
        if self.pool is None:
            return
        # Detach first so no new query can acquire from a closing pool.
        pool, self.pool = self.pool, None
        pool.close()
        await pool.wait_closed()

    def _require_pool(self):
        """Return the open pool, or raise if ``create_pool`` has not run."""
        if self.pool is None:
            raise RuntimeError(
                "Database pool is not initialized; call create_pool() first"
            )
        return self.pool

    async def execute_write_sql(self, sql, sql_data=None):
        """Execute a data-modifying statement (INSERT / UPDATE / DELETE).

        Args:
            sql: SQL text using ``%s`` placeholders.
            sql_data: Values bound to the placeholders, or None.

        Returns:
            The cursor's ``rowcount`` for the executed statement.

        Raises:
            RuntimeError: If the pool has not been created.
        """
        async with self._require_pool().acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(sql, sql_data)
                return cursor.rowcount

    async def execute_read_sql(self, sql, sql_data=None, fetchone=False):
        """Execute a SELECT statement and return its result.

        Args:
            sql: SQL text using ``%s`` placeholders.
            sql_data: Values bound to the placeholders, or None.
            fetchone: If true, return only the first row (``fetchone``);
                otherwise return every row (``fetchall``).

        Returns:
            A single row (or None when there is none) if ``fetchone`` is
            true, otherwise the collection of all rows.

        Raises:
            RuntimeError: If the pool has not been created.
        """
        async with self._require_pool().acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(sql, sql_data)
                if fetchone:
                    return await cursor.fetchone()
                return await cursor.fetchall()

 

 

main.py (일부 발췌: `from fastapi import FastAPI, Request` 및 `app = FastAPI()` 선언은 생략됨)

from dbms import Database

@app.on_event("startup")
async def startup():
    """Create the shared Database wrapper and open its connection pool.

    The wrapper is published on ``app.state.db_pool`` (before the pool is
    opened) so the HTTP middleware can hand it to each request.
    """
    # NOTE: current FastAPI docs recommend lifespan handlers over on_event.
    database = Database()
    app.state.db_pool = database
    await database.create_pool()
    print("Created db pool")


@app.on_event("shutdown")
async def shutdown():
    """Close the connection pool stored on ``app.state.db_pool`` at startup."""
    database = app.state.db_pool
    await database.disconnect()
    print("Disconnected db pool")


@app.middleware("http")
async def state_insert(request: Request, call_next):
    """Attach the application-wide Database wrapper to every request.

    Route dependencies can then read it from ``request.state.db_pool``.
    """
    shared_db = app.state.db_pool
    request.state.db_pool = shared_db
    return await call_next(request)

 

 

routers/test.py

import aiomysql
from fastapi import APIRouter, Depends, HTTPException
from fastapi.requests import Request

# Every route in this module is mounted under /community, grouped under the
# "community" tag, and documents 404 as a possible response.
router = APIRouter(
    tags=["community"],
    prefix="/community",
    responses={404: {"description": "Not found"}},
)


# Dependency-injection helper exposing the pool object set by the middleware.
def get_db_pool(request: Request):
    """Return the shared database object for the current request.

    The HTTP middleware stores it on ``request.state.db_pool``; routes
    receive it through ``Depends(get_db_pool)``.
    """
    state = request.state
    return state.db_pool


@router.get("/{site_name}")
async def get_posts(site_name: str, db=Depends(get_db_pool)):
    """Return all posts whose ``site_name`` column matches, ordered by ``no``.

    Args:
        site_name: Path parameter compared against the ``site_name`` column.
        db: The ``Database`` wrapper placed on ``request.state.db_pool`` by
            the HTTP middleware and injected via ``get_db_pool``. It is NOT
            a raw ``aiomysql.Pool`` (that type has no ``execute_read_sql``),
            so no misleading annotation is given here.

    Returns:
        ``{"result": True, "data": rows}`` with the fetched rows.

    Raises:
        HTTPException: 404 when no row matches ``site_name``.
    """
    # NOTE(review): `table` is a placeholder name; TABLE is a reserved word
    # in MySQL, so a real table with that name would need backtick quoting.
    sql = "SELECT * FROM table WHERE `site_name` = %s order by no;"
    rows = await db.execute_read_sql(sql, (site_name,))
    if not rows:
        raise HTTPException(status_code=404, detail="Posts not found")
    return {"result": True, "data": rows}