dbms.py
import aiomysql
import os
from dotenv import load_dotenv
# Load variables from a .env file into os.environ so the MYSQL_* settings read by
# Database are available; verbose=True makes python-dotenv warn if no .env is found.
load_dotenv(verbose=True)
class Database:
    """Thin async wrapper around an aiomysql connection pool.

    Connection settings come from the MYSQL_* environment variables. Every
    statement runs on a connection borrowed from the pool, and the pool is
    created with autocommit enabled, so writes are committed immediately.
    """

    def __init__(self):
        # None means "not connected yet"; create_pool() fills it in.
        self.pool = None

    @staticmethod
    def _pool_settings():
        """Return the environment-derived keyword arguments for aiomysql.create_pool."""
        return {
            "host": os.getenv("MYSQL_HOST"),
            # Fall back to MySQL's standard port instead of crashing with
            # int(None) -> TypeError when MYSQL_PORT is not set.
            "port": int(os.getenv("MYSQL_PORT", "3306")),
            "user": os.getenv("MYSQL_USER"),
            "password": os.getenv("MYSQL_PASSWORD"),
            "db": os.getenv("MYSQL_DATABASE"),
            "autocommit": True,
            "minsize": 3,
            "maxsize": 10,
        }

    async def create_pool(self):
        """Open the connection pool; rows are returned as dicts (DictCursor)."""
        self.pool = await aiomysql.create_pool(
            cursorclass=aiomysql.DictCursor,
            **self._pool_settings(),
        )

    async def disconnect(self):
        """Close the pool and wait for its connections to be released.

        Safe to call more than once; afterwards the instance reports
        "not initialised" until create_pool() is awaited again.
        """
        if self.pool is None:
            return
        pool = self.pool
        # Clear first so no new work is started on a pool that is shutting down.
        self.pool = None
        pool.close()
        await pool.wait_closed()

    def _require_pool(self):
        """Return the live pool, or raise RuntimeError if it was never created."""
        if self.pool is None:
            raise RuntimeError(
                "Database pool is not initialised; await create_pool() first"
            )
        return self.pool

    async def execute_write_sql(self, sql, sql_data=None):  # insert/update/delete
        """Execute a data-modifying statement.

        Args:
            sql: SQL text using %s placeholders.
            sql_data: Parameters bound to the placeholders (or None).

        Returns:
            The value returned by cursor.execute (the affected-row count).

        Raises:
            RuntimeError: if create_pool() has not been awaited.
        """
        async with self._require_pool().acquire() as conn:
            async with conn.cursor() as cursor:
                return await cursor.execute(sql, sql_data)

    async def execute_read_sql(self, sql, sql_data=None, fetchone=False):  # select
        """Execute a query and return its rows.

        Args:
            sql: SQL text using %s placeholders.
            sql_data: Parameters bound to the placeholders (or None).
            fetchone: Return only the first row instead of all rows.

        Returns:
            A single row (None if there is none) when fetchone is true,
            otherwise all rows; rows are dicts because of DictCursor.

        Raises:
            RuntimeError: if create_pool() has not been awaited.
        """
        async with self._require_pool().acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(sql, sql_data)
                if fetchone:
                    return await cursor.fetchone()
                return await cursor.fetchall()
main.py
from dbms import Database
@app.on_event("startup")
async def startup():
    """Create the shared Database wrapper and open its connection pool.

    The instance is published on app.state.db_pool (despite the name it is the
    Database wrapper, not a raw pool) so the HTTP middleware can pass it on.
    """
    database = Database()
    app.state.db_pool = database
    await database.create_pool()
    print("Created db pool")
@app.on_event("shutdown")
async def shutdown():
    """Close the connection pool owned by the Database stored on app.state."""
    database = app.state.db_pool
    await database.disconnect()
    print("Disconnected db pool")
@app.middleware("http")
async def state_insert(request: Request, call_next):
    """Attach the application-wide Database to each request, then continue."""
    # request.state is request-scoped, so the shared instance is re-attached on
    # every request for dependencies such as get_db_pool to read.
    request.state.db_pool = app.state.db_pool
    return await call_next(request)
routers/test.py
import aiomysql
from fastapi import APIRouter, Depends, HTTPException
from fastapi.requests import Request
# Router for the community endpoints: every route is mounted under /community and
# declares a generic 404 response for the generated OpenAPI docs.
_community_responses = {404: {"description": "Not found"}}
router = APIRouter(
    tags=["community"],
    prefix="/community",
    responses=_community_responses,
)
# Dependency used to inject the shared Database that the "http" middleware
# attached to request.state.
def get_db_pool(request: Request):
    """Return the Database wrapper (not a raw aiomysql pool) stored on request.state."""
    state = request.state
    return state.db_pool
@router.get("/{site_name}")
async def get_posts(site_name: str, db=Depends(get_db_pool)):
    """Return every post for *site_name*, ordered by ``no``.

    ``db`` is the Database wrapper injected by get_db_pool. It is deliberately
    not annotated as aiomysql.Pool: a raw pool has no execute_read_sql method,
    so that annotation was wrong.

    Raises:
        HTTPException: 404 when the site has no posts.
    """
    # TABLE is a reserved word in MySQL, so the identifier must be backtick-quoted
    # (replace it with the real table name if "table" is only a placeholder).
    sql = "SELECT * FROM `table` WHERE `site_name` = %s order by no;"
    rows = await db.execute_read_sql(sql, (site_name,))
    if not rows:
        raise HTTPException(status_code=404, detail="Posts not found")
    return {"result": True, "data": rows}