diff --git a/docker/psql/init.sql b/docker/psql/init.sql index 7d242d5..0559369 100644 --- a/docker/psql/init.sql +++ b/docker/psql/init.sql @@ -1 +1,42 @@ -CREATE EXTENSION plpython3u \ No newline at end of file +CREATE EXTENSION plpython3u; + +DROP FUNCTION text_pairs; +CREATE OR REPLACE FUNCTION text_pairs(list TEXT[]) +RETURNS SETOF TEXT[] +LANGUAGE plpgsql AS +$$ +DECLARE + x TEXT; + y TEXT; +BEGIN + for x, y in( + SELECT * + FROM unnest(list) a + CROSS JOIN unnest(list) b + WHERE a > b) + LOOP + RETURN NEXT ARRAY [x, y]; + END LOOP; +END; +$$; + + +DROP FUNCTION int_pairs; +CREATE OR REPLACE FUNCTION int_pairs(list INTEGER[]) +RETURNS SETOF INTEGER[] +LANGUAGE plpgsql AS +$$ +DECLARE + x INTEGER; + y INTEGER; +BEGIN + for x, y in( + SELECT * + FROM unnest(list) a + CROSS JOIN unnest(list) b + WHERE a > b) + LOOP + RETURN NEXT ARRAY [x, y]; + END LOOP; +END; +$$; \ No newline at end of file diff --git a/src/db.py b/src/db.py index 81c90b1..7e2265d 100644 --- a/src/db.py +++ b/src/db.py @@ -1,11 +1,15 @@ -from typing import Text -from sqlalchemy import create_engine, Column, Integer, String, \ - ForeignKey, UniqueConstraint +import os +import logging +from types import NoneType +from dotenv import load_dotenv +from xmlrpc.client import Boolean +from sqlalchemy import create_engine, Column, Integer, String, Boolean, \ + ForeignKey, UniqueConstraint, func, select, and_, or_, \ + not_ +from sqlalchemy.types import ARRAY from sqlalchemy.engine import URL from sqlalchemy.ext.declarative import declarative_base -import os -from dotenv import load_dotenv -import logging +from sqlalchemy.orm import Session Base = declarative_base() @@ -55,11 +59,31 @@ class RecipeIngredientParts(Base): class IngredientConnection(Base): __tablename__ = 'IngredientConnection' - ingredient_a = Column(String, primary_key = True) - ingredient_b = Column(String, primary_key = True) + ingredient_a = Column(String, + ForeignKey("RecipeIngredientParts.ingredient"), + primary_key = True) + ingredient_b = Column(String, + ForeignKey("RecipeIngredientParts.ingredient"), + primary_key = True) recipe_count = Column(Integer) UniqueConstraint(ingredient_a, ingredient_b) +class RecipeConnection(Base): + __tablename__ = 'RecipeConnection' + + recipe_a = Column(Integer, + ForeignKey("Recipe.id"), + primary_key = True) + recipe_b = Column(Integer, + ForeignKey("Recipe.id"), + primary_key = True) + ingredient_count = Column(Integer) + +class RecipeGraphed(Base): + __tablename__ = "RecipeGraphed" + + recipe_id = Column(Integer, ForeignKey("Recipe.id"), primary_key = True) + status = Column(Boolean, nullable = False, default = False) def get_engine(use_dotenv = True, **kargs): @@ -82,6 +106,113 @@ def create_tables(eng): logging.info(f"Createing DB Tables: {eng.url}") Base.metadata.create_all(eng, checkfirst=True) +def pair_query(pairable, groupable, recipe_ids = None, pair_type = String): + pair_func= func.text_pairs + if pair_type == Integer: + pair_func=func.int_pairs + + new_pairs = select(groupable, + pair_func(func.array_agg(pairable.distinct()), + type_=ARRAY(pair_type)).label("pair"))\ + .join(RecipeIngredientParts) + + if not type(recipe_ids) == NoneType: + new_pairs = new_pairs.where(RecipeIngredient.recipe_id.in_(recipe_ids)) + + new_pairs = new_pairs.group_by(groupable)\ + .cte() + + return new_pairs + +def pair_count_query(pairs, countable, recipe_ids = None): + new_counts = select(pairs, func.count(func.distinct(countable))) + + if not type(recipe_ids) == NoneType: + new_counts = new_counts.where(or_(pairs[0].in_(recipe_ids), + pairs[1].in_(recipe_ids))) + + + new_counts = new_counts.group_by(pairs) + + return new_counts + +def update_graph_connectivity(session = None): + # this is pure SQLAlchemy so it is more portable + # This would have been simpler if I utilized Postgres specific feature + if not session: + session = Session(get_engine()) + + with session.begin(): + ids = select(Recipe.id)\ + .join(RecipeGraphed, isouter = True)\ + .where(RecipeGraphed.status.is_not(True)) + + num_recipes = session.execute(select(func.count('*')).select_from(ids.cte())).fetchone()[0] + if num_recipes <= 0: + logging.info("no new recipies") + return + + logging.info(f"adding {num_recipes} recipes to the graphs") + + new_pairs = pair_query(RecipeIngredientParts.ingredient, + RecipeIngredient.recipe_id, + recipe_ids = ids) + + + new_counts = pair_count_query(new_pairs.c.pair, + new_pairs.c.recipe_id) + + logging.info("addeing new ingredient connections") + for pair, count in session.execute(new_counts): + connection = session.query(IngredientConnection)\ + .where(and_(IngredientConnection.ingredient_a == pair[0], + IngredientConnection.ingredient_b == pair[1]))\ + .first() + if connection: + connection.recipe_count += count + session.merge(connection) + else: + session.add(IngredientConnection(ingredient_a = pair[0], + ingredient_b = pair[1], + recipe_count = count)) + + # update RecipeConnection + logging.info("adding new recipe connections") + all_pairs = pair_query(RecipeIngredient.recipe_id, + RecipeIngredientParts.ingredient, + pair_type=Integer) + + new_counts = pair_count_query(all_pairs.c.pair, + all_pairs.c.ingredient, + recipe_ids=ids) + + i = 0 + for pair, count in session.execute(new_counts): + session.add(RecipeConnection(recipe_a = pair[0], + recipe_b = pair[1], + ingredient_count = count)) + # flush often to reduce memory usage + i += 1 + if (i % 100000) == 0: + session.flush() + + # update RecipeGraphed.status + logging.info("updating existing RecipeGraphed rows") + for recipeGraphed in session.query(RecipeGraphed)\ + .where(RecipeGraphed.recipe_id.in_(ids)): + recipeGraphed.status = True + session.merge(recipeGraphed) + + graphed = select(RecipeGraphed.recipe_id) + + # add recipies that aren't in the table + logging.info("adding new RecipeGraphed rows") + for recipe in session.query(Recipe)\ + .where(and_(Recipe.id.in_(ids), + not_(Recipe.id.in_(graphed)))): + session.add(RecipeGraphed(recipe_id=recipe.id, status=True)) + + if __name__ == "__main__": eng = get_engine()