177 lines
4.1 KiB
Python
177 lines
4.1 KiB
Python
from pydantic import BaseModel
|
|
import psycopg2 as pg
|
|
from psycopg2 import sql
|
|
from psycopg2.extras import RealDictCursor
|
|
from os import environ as env, stat
|
|
from functools import wraps
|
|
from inspect import getfullargspec
|
|
from typing import Callable, Type
|
|
from data import Recipe, Ingredient
|
|
|
|
|
|
class NoConnectionException(Exception):
    """Raised by the database helpers when ``conn`` is missing or falsy."""
|
|
|
|
|
|
def connect():
    """Open a new psycopg2 connection configured from the environment.

    Reads POSTGRES_DB, POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST and
    POSTGRES_PORT; any variable that is unset is passed through as ``None``.
    """
    env_var_for = {
        "dbname": "POSTGRES_DB",
        "user": "POSTGRES_USER",
        "password": "POSTGRES_PASSWORD",
        "host": "POSTGRES_HOST",
        "port": "POSTGRES_PORT",
    }
    settings = {param: env.get(var_name) for param, var_name in env_var_for.items()}
    return pg.connect(**settings)
|
|
|
|
|
|
def connect_by_default(
    func: Callable,
    kword: str = "conn",
    conn_func: Callable[..., pg.extensions.connection] = connect,
) -> Callable:
    """Wrap *func* so it always receives a database connection.

    Positional arguments are re-keyed by *func*'s parameter names, and every
    argument is forwarded as a keyword. When the ``kword`` argument is missing
    or falsy, a connection is opened with ``conn_func()`` for the duration of
    the call and closed afterwards, even if *func* raises. A connection passed
    in by the caller is forwarded untouched and left open. Nested decorated
    calls that pass ``conn`` along therefore reuse the outer connection.

    Args:
        func: Function to wrap; its connection parameter must be named ``kword``.
        kword: Name of the connection keyword argument.
        conn_func: Zero-argument factory used to open a connection on demand.

    Returns:
        The wrapped function.

    Raises (from the wrapper):
        TypeError: If more positional arguments are given than *func* declares.
    """
    # Resolve parameter names once at decoration time, not on every call.
    param_names = getfullargspec(func).args

    @wraps(func)
    def wrapped(*args, **kargs):
        if len(args) > len(param_names):
            raise TypeError(
                f"{func.__name__}() takes {len(param_names)} positional "
                f"arguments but {len(args)} were given"
            )
        # Positional values override same-named keywords, matching the
        # original re-keying loop.
        kargs.update(zip(param_names, args))

        if kargs.get(kword):
            return func(**kargs)

        conn = conn_func()
        kargs[kword] = conn
        try:
            return func(**kargs)
        finally:
            # This wrapper opened the connection, so it must also close it;
            # otherwise every call without an explicit conn leaks one.
            conn.close()

    return wrapped
|
|
|
|
|
|
@connect_by_default
def select_by_id(
    schema_name: str,
    table_name: str,
    id: int,
    columns: list[str] | None = None,
    conn: pg.extensions.connection | None = None,
) -> dict | None:
    """Fetch one row from ``schema_name.table_name`` whose ``id`` column matches.

    Args:
        schema_name: Schema containing the table (quoted as an identifier).
        table_name: Table to read from (quoted as an identifier).
        id: Value compared against the table's ``id`` column.
        columns: Column names to select; ``None`` or an empty list selects ``*``.
        conn: Open connection; supplied by ``connect_by_default`` when omitted.

    Returns:
        The row as a dict keyed by column name, or ``None`` if no row matches.

    Raises:
        NoConnectionException: If no connection is available.
    """
    if not conn:
        raise NoConnectionException()

    # Build the statement only from psycopg2.sql objects so every identifier
    # is quoted and no Python string formatting touches the SQL text.
    if columns:
        selected = sql.SQL(", ").join([sql.Identifier(col) for col in columns])
    else:
        selected = sql.SQL("*")
    query = sql.SQL("SELECT {cols} FROM {schema}.{table} WHERE id = %s;").format(
        cols=selected,
        schema=sql.Identifier(schema_name),
        table=sql.Identifier(table_name),
    )

    with conn.cursor(cursor_factory=RealDictCursor) as curs:
        curs.execute(query, (id,))
        row = curs.fetchone()

    # fetchone() yields None when nothing matched; normalise any falsy row too.
    return row or None
|
|
|
|
|
|
@connect_by_default
def insert_model(
    schema_name: str,
    table_name: str,
    model: BaseModel,
    conn: pg.extensions.connection | None = None,
) -> dict | None:
    """Insert *model* as a new row and return its data plus the generated id.

    Fields whose value is ``None`` are left out of the INSERT. If every field
    is ``None``, the row is inserted with ``DEFAULT VALUES``. The transaction
    is committed on success. On any database error it is rolled back before
    the error is re-raised, so the connection stays usable.

    Args:
        schema_name: Schema containing the table (quoted as an identifier).
        table_name: Target table (quoted as an identifier).
        model: Pydantic model whose non-None fields become the column values.
        conn: Open connection; supplied by ``connect_by_default`` when omitted.

    Returns:
        The dumped (non-None) model fields with ``"id"`` set from
        ``RETURNING id``, or ``None`` if the statement returned no row.

    Raises:
        NoConnectionException: If no connection is available.
        psycopg2.Error: Any database error (e.g. ``UniqueViolation``), raised
            after the transaction has been rolled back.
    """
    if not conn:
        raise NoConnectionException()

    data = model.model_dump(exclude_none=True)
    target = sql.SQL("{}.{}").format(
        sql.Identifier(schema_name), sql.Identifier(table_name)
    )
    if data:
        query = sql.SQL(
            "INSERT INTO {target} ({fields}) VALUES ({values}) RETURNING id;"
        ).format(
            target=target,
            fields=sql.SQL(", ").join([sql.Identifier(col) for col in data]),
            values=sql.SQL(", ").join([sql.Placeholder()] * len(data)),
        )
    else:
        # "INSERT ... () VALUES ()" is a syntax error in PostgreSQL.
        query = sql.SQL("INSERT INTO {target} DEFAULT VALUES RETURNING id;").format(
            target=target
        )

    with conn.cursor() as curs:
        try:
            curs.execute(query, tuple(data.values()))
            row = curs.fetchone()
            conn.commit()
        except pg.Error:
            # A failed statement leaves the transaction aborted; roll back so
            # later statements on this connection do not fail too, then let
            # the caller see the original error (UniqueViolation included).
            conn.rollback()
            raise

    if not row:
        return None

    data["id"] = row[0]
    return data
|
|
|
|
|
|
@connect_by_default
def insert_recipe(
    recipe: Recipe,
    conn: pg.extensions.connection | None = None,
) -> Recipe | None:
    """Insert *recipe* into ``recipe.recipe`` and return it with its new id.

    Returns ``None`` when ``insert_model`` reports no inserted row.
    """
    inserted = insert_model("recipe", "recipe", recipe, conn=conn)
    return Recipe(**inserted) if inserted else None
|
|
|
|
|
|
@connect_by_default
def get_recipe(
    id: int,
    conn: pg.extensions.connection | None = None,
) -> Recipe | None:
    """Load the ``recipe.recipe`` row with the given *id* as a ``Recipe``.

    Returns ``None`` when no such row exists.
    """
    row = select_by_id("recipe", "recipe", id, conn=conn)
    if row:
        return Recipe(**row)
    return None
|
|
|
|
|
|
@connect_by_default
def insert_ingredient(
    ingredient: Ingredient,
    conn: pg.extensions.connection | None = None,
) -> Ingredient | None:
    """Insert *ingredient* into ``recipe.ingredient`` and return it with its new id.

    Returns ``None`` when ``insert_model`` reports no inserted row.
    """
    inserted = insert_model("recipe", "ingredient", ingredient, conn=conn)
    return Ingredient(**inserted) if inserted else None
|
|
|
|
|
|
@connect_by_default
def get_ingredient(
    id: int,
    conn: pg.extensions.connection | None = None,
) -> Ingredient | None:
    """Load the ``recipe.ingredient`` row with the given *id* as an ``Ingredient``.

    Returns ``None`` when no such row exists.
    """
    row = select_by_id("recipe", "ingredient", id, conn=conn)
    if row:
        return Ingredient(**row)
    return None
|