# ai_sandbox/db.py
#
# 177 lines
# 4.1 KiB
# Python

from pydantic import BaseModel
import psycopg2 as pg
from psycopg2 import sql
from psycopg2.extras import RealDictCursor
from os import environ as env, stat
from functools import wraps
from inspect import getfullargspec
from typing import Callable, Type
from data import Recipe, Ingredient
class NoConnectionException(Exception):
    """Raised when a database helper is invoked without a usable connection."""
def connect():
    """Open a new psycopg2 connection configured from POSTGRES_* env vars.

    Each connection parameter is read with ``env.get``, so an unset variable
    is passed through as ``None``.
    """
    # Maps each psycopg2.connect keyword to the environment variable it comes from.
    env_keys = {
        "dbname": "POSTGRES_DB",
        "user": "POSTGRES_USER",
        "password": "POSTGRES_PASSWORD",
        "host": "POSTGRES_HOST",
        "port": "POSTGRES_PORT",
    }
    return pg.connect(**{param: env.get(key) for param, key in env_keys.items()})
def connect_by_default(
    func: Callable,
    kword: str = "conn",
    conn_func: Callable[..., pg.extensions.connection] = connect,
) -> Callable:
    """Decorate *func* so it opens its own connection when none is supplied.

    Positional arguments are converted to keyword arguments using *func*'s
    parameter names. If the ``kword`` argument is then missing or ``None``,
    ``conn_func()`` is called to create a connection for the duration of the
    call. A connection created here is closed once *func* returns or raises;
    a connection supplied by the caller is passed through and left open.

    :param func: function that accepts a connection via the ``kword`` parameter.
    :param kword: name of *func*'s connection parameter.
    :param conn_func: zero-argument factory returning a new connection.
    :returns: the wrapped function.
    :raises TypeError: (from the wrapper) if more positional arguments are
        given than *func* declares.
    """
    # Resolved once at decoration time rather than on every call/argument.
    arg_names = getfullargspec(func).args

    @wraps(func)
    def wrapped(*args, **kargs):
        if len(args) > len(arg_names):
            raise TypeError(
                f"{func.__name__}() takes {len(arg_names)} positional "
                f"arguments but {len(args)} were given"
            )
        kargs.update(zip(arg_names, args))
        if kargs.get(kword) is not None:
            # Caller owns this connection: don't close it.
            return func(**kargs)
        conn = conn_func()
        kargs[kword] = conn
        try:
            return func(**kargs)
        finally:
            # We opened it, so we close it; previously this connection leaked.
            # Any work func did not commit is discarded by closing.
            conn.close()

    return wrapped
@connect_by_default
def select_by_id(
    schema_name: str,
    table_name: str,
    id: int,
    columns: list[str] | None = None,
    conn: pg.extensions.connection | None = None,
) -> dict | None:
    """Fetch the row of ``schema_name.table_name`` whose ``id`` column equals *id*.

    :param schema_name: schema containing the table (quoted as an identifier).
    :param table_name: table to read from (quoted as an identifier).
    :param id: value matched against the ``id`` column.
    :param columns: column names to select; ``None`` or empty selects ``*``.
    :param conn: open connection; the decorator supplies one when omitted.
    :returns: the first matching row as a dict (``RealDictCursor`` row), or
        ``None`` when no row matches.
    :raises NoConnectionException: if no connection is available.
    """
    if not conn:
        raise NoConnectionException()
    # Build the column list from quoted identifiers; fall back to "*" when
    # no explicit columns are requested (an empty join would be invalid SQL).
    selected = (
        sql.SQL(",").join([sql.Identifier(col) for col in columns])
        if columns
        else sql.SQL("*")
    )
    query = sql.SQL("SELECT {cols} FROM {schema}.{table} WHERE id = %s;").format(
        cols=selected,
        schema=sql.Identifier(schema_name),
        table=sql.Identifier(table_name),
    )
    with conn.cursor(cursor_factory=RealDictCursor) as curs:
        curs.execute(query, (id,))
        row = curs.fetchone()
    return row or None
@connect_by_default
def insert_model(
    schema_name: str,
    table_name: str,
    model: BaseModel,
    conn: pg.extensions.connection | None = None,
) -> dict | None:
    """Insert *model* as a new row of ``schema_name.table_name`` and commit.

    Only fields whose value is not ``None`` (``model_dump(exclude_none=True)``)
    are written; the remaining columns are left to the database.

    :param schema_name: schema containing the table (quoted as an identifier).
    :param table_name: target table (quoted as an identifier).
    :param model: pydantic model whose dumped fields become column values.
    :param conn: open connection; the decorator supplies one when omitted.
    :returns: the dumped field values plus the ``"id"`` returned by the
        database, or ``None`` if the statement returned no row.
    :raises NoConnectionException: if no connection is available.
    :raises psycopg2.Error: on any database error (e.g. ``UniqueViolation``);
        the transaction is rolled back first so *conn* remains usable.
    """
    if not conn:
        raise NoConnectionException()
    data = model.model_dump(exclude_none=True)
    target = sql.SQL("{}.{}").format(
        sql.Identifier(schema_name), sql.Identifier(table_name)
    )
    if data:
        query = sql.SQL(
            "INSERT INTO {target} ({fields}) VALUES ({values}) RETURNING id;"
        ).format(
            target=target,
            fields=sql.SQL(",").join([sql.Identifier(col) for col in data]),
            values=sql.SQL(", ").join([sql.Placeholder()] * len(data)),
        )
    else:
        # "(...) VALUES ()" with no columns is a syntax error in PostgreSQL;
        # insert a row made entirely of column defaults instead.
        query = sql.SQL("INSERT INTO {target} DEFAULT VALUES RETURNING id;").format(
            target=target
        )
    with conn.cursor() as curs:
        try:
            curs.execute(query, tuple(data.values()))
            row = curs.fetchone()
            conn.commit()
        except pg.Error:
            # Any failed statement aborts the transaction, not just unique
            # violations; roll back so later statements on conn still work.
            conn.rollback()
            raise
    if not row:
        return None
    data["id"] = row[0]
    return data
@connect_by_default
def insert_recipe(
    recipe: Recipe,
    conn: pg.extensions.connection | None = None,
) -> Recipe | None:
    """Insert *recipe* into ``recipe.recipe`` via ``insert_model``.

    Returns a new ``Recipe`` built from the inserted values (including the
    generated id), or ``None`` when ``insert_model`` returns nothing.
    """
    inserted = insert_model("recipe", "recipe", recipe, conn=conn)
    return Recipe(**inserted) if inserted else None
@connect_by_default
def get_recipe(
    id: int,
    conn: pg.extensions.connection | None = None,
) -> Recipe | None:
    """Load the ``recipe.recipe`` row with the given *id* as a ``Recipe``.

    Returns ``None`` when ``select_by_id`` finds no row.
    """
    row = select_by_id("recipe", "recipe", id, conn=conn)
    if row:
        return Recipe(**row)
    return None
@connect_by_default
def insert_ingredient(
    ingredient: Ingredient,
    conn: pg.extensions.connection | None = None,
) -> Ingredient | None:
    """Insert *ingredient* into ``recipe.ingredient`` via ``insert_model``.

    Returns a new ``Ingredient`` built from the inserted values (including the
    generated id), or ``None`` when ``insert_model`` returns nothing.
    """
    inserted = insert_model("recipe", "ingredient", ingredient, conn=conn)
    return Ingredient(**inserted) if inserted else None
@connect_by_default
def get_ingredient(
    id: int,
    conn: pg.extensions.connection | None = None,
) -> Ingredient | None:
    """Load the ``recipe.ingredient`` row with the given *id* as an ``Ingredient``.

    Returns ``None`` when ``select_by_id`` finds no row.
    """
    row = select_by_id("recipe", "ingredient", id, conn=conn)
    if row:
        return Ingredient(**row)
    return None