"""PostgreSQL helpers for reading and inserting recipe-domain models."""

from functools import wraps
from inspect import getfullargspec, signature
from os import environ as env, stat
from typing import Callable, Type

import psycopg2 as pg
from psycopg2 import sql
from psycopg2.extras import RealDictCursor
from pydantic import BaseModel

from data import Recipe, Ingredient


class NoConnectionException(Exception):
    """Raised when a query helper is invoked without a database connection."""


def connect() -> pg.extensions.connection:
    """Open a new PostgreSQL connection configured from environment variables.

    Reads POSTGRES_DB, POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST and
    POSTGRES_PORT. Unset variables are passed as ``None``; presumably psycopg2
    then falls back to its own defaults for them — confirm for your version.
    """
    return pg.connect(
        dbname=env.get("POSTGRES_DB"),
        user=env.get("POSTGRES_USER"),
        password=env.get("POSTGRES_PASSWORD"),
        host=env.get("POSTGRES_HOST"),
        port=env.get("POSTGRES_PORT"),
    )


def connect_by_default(
    func: Callable,
    kword: str = "conn",
    conn_func: Callable[..., pg.extensions.connection] = connect,
) -> Callable:
    """Decorate ``func`` so it gets a connection when the caller passes none.

    If the ``kword`` argument is omitted or ``None``, a connection is opened
    with ``conn_func`` for the duration of the call and closed afterwards.
    A connection supplied by the caller is passed through untouched and left
    open, so several helpers can share one connection.

    Args:
        func: The function to wrap; it must have a parameter named ``kword``.
        kword: Name of the connection parameter.
        conn_func: Factory used to open a connection when none is given.

    Returns:
        The wrapped function, with the same call signature as ``func``.

    Raises:
        TypeError: At decoration time if ``func`` has no ``kword`` parameter,
            or at call time if the arguments do not match ``func``'s signature.
    """
    # Computed once: the signature of ``func`` never changes between calls.
    sig = signature(func)
    if kword not in sig.parameters:
        raise TypeError(f"{func.__qualname__} has no parameter named {kword!r}")

    @wraps(func)
    def wrapped(*args, **kwargs):
        # bind() maps positional and keyword arguments to parameter names and
        # rejects duplicates / unknown names instead of silently overwriting.
        bound = sig.bind(*args, **kwargs)
        if bound.arguments.get(kword) is not None:
            return func(*bound.args, **bound.kwargs)

        conn = conn_func()
        bound.arguments[kword] = conn
        try:
            return func(*bound.args, **bound.kwargs)
        finally:
            # Close only the connection this wrapper opened itself; otherwise
            # every defaulted call would leak a server connection.
            conn.close()

    return wrapped


@connect_by_default
def select_by_id(
    schema_name: str,
    table_name: str,
    id: int,
    columns: list[str] | None = None,
    conn: pg.extensions.connection | None = None,
) -> dict | None:
    """Fetch a single row from ``schema_name.table_name`` by its ``id`` column.

    Args:
        schema_name: Schema containing the table (quoted as an identifier).
        table_name: Table to read from (quoted as an identifier).
        id: Value matched against the table's ``id`` column.
        columns: Columns to select; all columns (``*``) when ``None`` or empty.
        conn: Open connection; the decorator supplies one if omitted.

    Returns:
        The first matching row as a dict keyed by column name, or ``None`` if
        no row matches.

    Raises:
        NoConnectionException: If no connection is available.
    """
    if conn is None:
        raise NoConnectionException()

    # Names are composed as quoted identifiers and the id is a bound
    # parameter, so nothing caller-supplied is spliced in as raw SQL.
    if columns:
        selected = sql.SQL(", ").join([sql.Identifier(col) for col in columns])
    else:
        selected = sql.SQL("*")
    query = sql.SQL("SELECT {} FROM {}.{} WHERE id = %s;").format(
        selected, sql.Identifier(schema_name), sql.Identifier(table_name)
    )

    with conn.cursor(cursor_factory=RealDictCursor) as curs:
        curs.execute(query, (id,))
        row = curs.fetchone()
    return row or None


@connect_by_default
def insert_model(
    schema_name: str,
    table_name: str,
    model: BaseModel,
    conn: pg.extensions.connection | None = None,
) -> dict | None:
    """Insert ``model`` as a new row and return its data plus the new ``id``.

    Fields whose value is ``None`` are omitted from the INSERT (via
    ``model_dump(exclude_none=True)``). The insert is committed on success;
    on any psycopg2 error the transaction is rolled back and the error is
    re-raised unchanged (e.g. ``psycopg2.errors.UniqueViolation``).

    Args:
        schema_name: Schema containing the table (quoted as an identifier).
        table_name: Table to insert into (quoted as an identifier).
        model: Pydantic model whose dumped fields become the column values.
        conn: Open connection; the decorator supplies one if omitted.

    Returns:
        The dumped fields with ``"id"`` set from ``RETURNING id``, or ``None``
        if the statement returned no row.

    Raises:
        NoConnectionException: If no connection is available.
        psycopg2.Error: Any database error raised by the insert.
    """
    if conn is None:
        raise NoConnectionException()

    data = model.model_dump(exclude_none=True)
    target = sql.SQL("{}.{}").format(
        sql.Identifier(schema_name), sql.Identifier(table_name)
    )
    if data:
        query = sql.SQL("INSERT INTO {} ({}) VALUES ({}) RETURNING id;").format(
            target,
            sql.SQL(", ").join([sql.Identifier(col) for col in data]),
            sql.SQL(", ").join([sql.Placeholder()] * len(data)),
        )
    else:
        # "INSERT ... () VALUES ()" is a syntax error in PostgreSQL; when every
        # field is None, insert a row made entirely of column defaults.
        query = sql.SQL("INSERT INTO {} DEFAULT VALUES RETURNING id;").format(target)
    params = tuple(data.values())

    with conn.cursor() as curs:
        try:
            curs.execute(query, params or None)
            row = curs.fetchone()
            conn.commit()
        except pg.Error:
            # Without a rollback the connection stays in an aborted
            # transaction and rejects every later statement.
            conn.rollback()
            raise

    if not row:
        return None
    data["id"] = row[0]
    return data


@connect_by_default
def insert_recipe(
    recipe: Recipe,
    conn: pg.extensions.connection | None = None,
) -> Recipe | None:
    """Insert ``recipe`` into ``recipe.recipe``; return it with its new id, or ``None``."""
    data = insert_model("recipe", "recipe", recipe, conn=conn)
    return Recipe(**data) if data else None


@connect_by_default
def get_recipe(
    id: int,
    conn: pg.extensions.connection | None = None,
) -> Recipe | None:
    """Load the ``recipe.recipe`` row with the given id, or ``None`` if absent."""
    data = select_by_id("recipe", "recipe", id, conn=conn)
    return Recipe(**data) if data else None


@connect_by_default
def insert_ingredient(
    ingredient: Ingredient,
    conn: pg.extensions.connection | None = None,
) -> Ingredient | None:
    """Insert ``ingredient`` into ``recipe.ingredient``; return it with its new id, or ``None``."""
    data = insert_model("recipe", "ingredient", ingredient, conn=conn)
    return Ingredient(**data) if data else None


@connect_by_default
def get_ingredient(
    id: int,
    conn: pg.extensions.connection | None = None,
) -> Ingredient | None:
    """Load the ``recipe.ingredient`` row with the given id, or ``None`` if absent."""
    data = select_by_id("recipe", "ingredient", id, conn=conn)
    return Ingredient(**data) if data else None