recipe-graph/recipe_graph/db.py

220 lines
7.9 KiB
Python

import logging
import os
from types import NoneType
from xmlrpc.client import Boolean  # NOTE: shadowed by sqlalchemy's Boolean below

from dotenv import load_dotenv
from sqlalchemy import (
    Boolean, Column, ForeignKey, Integer, String, UniqueConstraint,
    and_, create_engine, func, not_, or_, select,
)
from sqlalchemy.engine import URL
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.orm import declarative_base as orm_declarative_base
from sqlalchemy.types import ARRAY
# Declarative base shared by every ORM model in this module; create_tables()
# creates the tables registered on Base.metadata. Built with
# sqlalchemy.orm.declarative_base (available since SQLAlchemy 1.4) instead of
# the deprecated sqlalchemy.ext.declarative alias.
Base = orm_declarative_base()
class Ingredient(Base):
    """Ingredient row: integer surrogate key plus a required name.

    RecipeIngredient.ingredient_id references this table.
    """
    __tablename__ = 'Ingredient'

    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
class RecipeSite(Base):
    """A source site for recipes; Recipe.recipe_site_id references it.

    ``name`` and ``base_url`` are each unique across sites.
    ``ingredient_class`` and ``name_class`` are required strings. They are
    presumably the markup class names used to locate ingredient lines and the
    recipe name on the site's pages (TODO: confirm against the scraper).
    """
    __tablename__ = 'RecipeSite'

    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False, unique=True)
    ingredient_class = Column(String, nullable=False)
    name_class = Column(String, nullable=False)
    base_url = Column(String, nullable=False, unique=True)
class Recipe(Base):
    """A recipe belonging to a RecipeSite.

    ``identifier`` is required, and the (identifier, recipe_site_id) pair is
    declared unique. Presumably ``identifier`` is the recipe's site-specific
    id or slug (TODO: confirm).
    """
    __tablename__ = 'Recipe'

    id = Column(Integer, primary_key=True)
    name = Column(String)
    identifier = Column(String, nullable=False)
    recipe_site_id = Column(Integer, ForeignKey('RecipeSite.id'))

    # NOTE(review): this constraint is built from the Column objects in the
    # class body rather than via __table_args__. SQLAlchemy attaches such a
    # constraint to the table once both columns are bound to it. Moving it to
    # __table_args__ would make that explicit.
    UniqueConstraint(identifier, recipe_site_id)
class RecipeIngredient(Base):
    """One ingredient entry of a Recipe.

    ``text`` is required and presumably holds the ingredient line as written
    on the site. ``ingredient_id`` optionally links the entry to an
    Ingredient row.
    """
    __tablename__ = 'RecipeIngredient'

    id = Column(Integer, primary_key=True)
    text = Column(String, nullable=False)
    recipe_id = Column(Integer, ForeignKey('Recipe.id'))
    ingredient_id = Column(Integer, ForeignKey("Ingredient.id"))
class RecipeIngredientParts(Base):
    """Component fields of a RecipeIngredient.

    ``id`` is both the primary key and a foreign key to RecipeIngredient.id,
    so each RecipeIngredient has at most one parts row. The string columns
    presumably hold the pieces a parser split the ingredient text into
    (TODO: confirm). ``ingredient`` is the value used to build ingredient
    pairs for the graph.
    """
    __tablename__ = 'RecipeIngredientParts'

    id = Column(Integer, ForeignKey("RecipeIngredient.id"), primary_key=True)
    quantity = Column(String)
    unit = Column(String)
    instruction = Column(String)
    ingredient = Column(String)
    supplement = Column(String)
class IngredientConnection(Base):
    """Graph edge between two ingredient names.

    The (ingredient_a, ingredient_b) pair is the composite primary key.
    ``recipe_count`` is accumulated by update_graph_connectivity: for each
    pair it adds the number of distinct newly graphed recipes in which the
    pair occurs.
    """
    __tablename__ = 'IngredientConnection'

    # Values are ingredient names as stored in RecipeIngredientParts.ingredient.
    # No ForeignKey is declared: that column is not unique, and PostgreSQL
    # refuses a foreign key whose referenced column has no unique/primary-key
    # constraint. Declaring one would make create_all() fail on this table.
    ingredient_a = Column(String, primary_key=True)
    ingredient_b = Column(String, primary_key=True)
    recipe_count = Column(Integer)
    # The composite primary key already enforces uniqueness of
    # (ingredient_a, ingredient_b), so no separate UniqueConstraint is needed.
class RecipeConnection(Base):
    """Graph edge between two recipes.

    Keyed by the (recipe_a, recipe_b) pair of Recipe ids.
    update_graph_connectivity fills ``ingredient_count`` with the number of
    distinct ingredients counted for the pair.
    """
    __tablename__ = 'RecipeConnection'

    recipe_a = Column(Integer, ForeignKey("Recipe.id"), primary_key=True)
    recipe_b = Column(Integer, ForeignKey("Recipe.id"), primary_key=True)
    ingredient_count = Column(Integer)
class RecipeGraphed(Base):
    """Records whether a recipe has been folded into the connection tables.

    ``status`` defaults to False. update_graph_connectivity sets it to True,
    creating the row when missing, once the recipe has been processed.
    """
    __tablename__ = "RecipeGraphed"

    recipe_id = Column(Integer, ForeignKey("Recipe.id"), primary_key=True)
    status = Column(Boolean, nullable=False, default=False)
def get_engine(use_dotenv=True, **kargs):
    """Create a PostgreSQL engine from ``POSTGRES_*`` environment variables.

    Args:
        use_dotenv: When true, call ``load_dotenv()`` first so values from a
            ``.env`` file are available through ``os.getenv``.
        **kargs: Extra keyword arguments forwarded to ``create_engine``
            (for example ``echo=True``).

    Returns:
        A new SQLAlchemy Engine.
    """
    if use_dotenv:
        load_dotenv()
    # POSTGRES_URL is used as the *host* of the URL, not a full DSN.
    DB_URL = os.getenv("POSTGRES_URL")
    DB_USER = os.getenv("POSTGRES_USER")
    DB_PASSWORD = os.getenv("POSTGRES_PASSWORD")
    DB_NAME = os.getenv("POSTGRES_DB")
    eng_url = URL.create('postgresql',
                         username=DB_USER,
                         password=DB_PASSWORD,
                         host=DB_URL,
                         database=DB_NAME)
    # Forward caller options; they used to be accepted and silently dropped.
    return create_engine(eng_url, **kargs)
def create_tables(eng):
    """Create every table registered on ``Base.metadata`` that is missing.

    Args:
        eng: SQLAlchemy Engine to run the DDL against.
    """
    # Mask the password. With SQLAlchemy < 2.0, formatting the URL via str()
    # / an f-string renders it in plain text, leaking credentials into logs.
    safe_url = eng.url.render_as_string(hide_password=True)
    logging.info("Creating DB Tables: %s", safe_url)
    # checkfirst=True skips tables that already exist. Existing tables are
    # never altered, so this is safe to run repeatedly.
    Base.metadata.create_all(eng, checkfirst=True)
def pair_query(pairable, groupable, recipe_ids=None, pair_type=String):
    """Build a CTE that pairs up the distinct ``pairable`` values per group.

    For every value of ``groupable``, the distinct ``pairable`` values are
    collected with ``array_agg`` and passed to the database function
    ``int_pairs`` (when ``pair_type`` is ``Integer``) or ``text_pairs``
    (otherwise).

    Args:
        pairable: Column whose values are paired.
        groupable: Column to group by; it is also selected into the CTE.
        recipe_ids: Optional collection or SELECT of recipe ids. When given,
            only rows whose ``RecipeIngredient.recipe_id`` is in it are used.
        pair_type: SQLAlchemy type class of the paired values.

    Returns:
        A CTE with ``groupable``'s column and a ``pair`` column typed
        ``ARRAY(pair_type)``.
    """
    # text_pairs/int_pairs are custom database functions not defined here;
    # presumably each returns one 2-element array per pair (TODO: confirm).
    pair_func = func.int_pairs if pair_type is Integer else func.text_pairs
    new_pairs = (
        select(groupable,
               pair_func(func.array_agg(pairable.distinct()),
                         type_=ARRAY(pair_type)).label("pair"))
        # Join condition is inferred from the RecipeIngredientParts.id ->
        # RecipeIngredient.id foreign key.
        .join(RecipeIngredientParts)
    )
    if recipe_ids is not None:
        new_pairs = new_pairs.where(RecipeIngredient.recipe_id.in_(recipe_ids))
    return new_pairs.group_by(groupable).cte()
def pair_count_query(pairs, countable, recipe_ids=None):
    """Count the distinct ``countable`` values associated with each pair.

    Args:
        pairs: ARRAY column holding a 2-element pair, e.g. the ``pair``
            column of ``pair_query(...)``.
        countable: Column whose distinct values are counted per pair.
        recipe_ids: Optional collection or SELECT of ids. When given, only
            pairs with at least one element in it are kept.

    Returns:
        A SELECT of ``(pair, count)`` rows grouped by pair.
    """
    new_counts = select(pairs, func.count(func.distinct(countable)))
    if recipe_ids is not None:
        # PostgreSQL arrays are 1-based by default. An ARRAY type without
        # zero_indexes=True passes Python indexes through verbatim, so
        # pairs[0] would render as pair[0] (always NULL) and the second
        # element would never be checked. Assumes the pair functions return
        # ordinary 1-based arrays (TODO: confirm against text_pairs/int_pairs).
        if getattr(pairs.type, "zero_indexes", False):
            first, second = pairs[0], pairs[1]
        else:
            first, second = pairs[1], pairs[2]
        new_counts = new_counts.where(or_(first.in_(recipe_ids),
                                          second.in_(recipe_ids)))
    return new_counts.group_by(pairs)
# Add/increment IngredientConnection rows for the recipes selected by ``ids``.
def _add_ingredient_connections(session, ids):
    """Fold the ingredient pairs of the recipes in ``ids`` into IngredientConnection.

    Each pair contributes the number of distinct recipes (among ``ids``) it
    occurs in. Existing rows have that number added to ``recipe_count``;
    missing rows are created with it.
    """
    new_pairs = pair_query(RecipeIngredientParts.ingredient,
                           RecipeIngredient.recipe_id,
                           recipe_ids=ids)
    new_counts = pair_count_query(new_pairs.c.pair,
                                  new_pairs.c.recipe_id)
    logging.info("adding new ingredient connections")
    for pair, count in session.execute(new_counts):
        connection = (
            session.query(IngredientConnection)
            .where(and_(IngredientConnection.ingredient_a == pair[0],
                        IngredientConnection.ingredient_b == pair[1]))
            .first()
        )
        if connection is None:
            session.add(IngredientConnection(ingredient_a=pair[0],
                                             ingredient_b=pair[1],
                                             recipe_count=count))
        else:
            # The row is already tracked by this session, so the change is
            # flushed with the transaction; merge() is unnecessary. A NULL
            # recipe_count is treated as 0 instead of raising TypeError.
            connection.recipe_count = (connection.recipe_count or 0) + count


# Insert RecipeConnection rows for the recipe pairs kept for ``ids``.
def _add_recipe_connections(session, ids, flush_every=100000):
    """Insert RecipeConnection rows for recipe pairs filtered against ``ids``.

    Pairs are built over all recipes, then filtered by ``pair_count_query``
    using ``ids``. ``ingredient_count`` is the number of distinct ingredients
    counted for the pair.
    """
    logging.info("adding new recipe connections")
    all_pairs = pair_query(RecipeIngredient.recipe_id,
                           RecipeIngredientParts.ingredient,
                           pair_type=Integer)
    new_counts = pair_count_query(all_pairs.c.pair,
                                  all_pairs.c.ingredient,
                                  recipe_ids=ids)
    for i, (pair, count) in enumerate(session.execute(new_counts), start=1):
        session.add(RecipeConnection(recipe_a=pair[0],
                                     recipe_b=pair[1],
                                     ingredient_count=count))
        # flush often to reduce memory usage
        if i % flush_every == 0:
            session.flush()


# Mark the recipes selected by ``ids`` as graphed in RecipeGraphed.
def _mark_recipes_graphed(session, ids):
    """Set RecipeGraphed.status to True for ``ids``, creating missing rows."""
    logging.info("updating existing RecipeGraphed rows")
    for recipe_graphed in (session.query(RecipeGraphed)
                           .where(RecipeGraphed.recipe_id.in_(ids))):
        # Persistent object: the update is flushed with the transaction.
        recipe_graphed.status = True
    graphed = select(RecipeGraphed.recipe_id)
    # add recipes that aren't in the table
    logging.info("adding new RecipeGraphed rows")
    missing_ids = session.execute(
        select(Recipe.id).where(and_(Recipe.id.in_(ids),
                                     not_(Recipe.id.in_(graphed))))
    ).scalars().all()
    for recipe_id in missing_ids:
        session.add(RecipeGraphed(recipe_id=recipe_id, status=True))


def update_graph_connectivity(session=None):
    """Fold recipes that are not yet graphed into the connection tables.

    A recipe is "not yet graphed" when it has no RecipeGraphed row or its
    row's ``status`` is not true. For those recipes this, in one transaction:

    1. adds their ingredient pairs to IngredientConnection,
    2. inserts RecipeConnection rows for recipe pairs filtered against them,
    3. marks them as graphed in RecipeGraphed.

    Nothing is written when there are no such recipes.

    Args:
        session: Optional SQLAlchemy Session that must not already be in a
            transaction (``session.begin()`` is called on it). When omitted, a
            Session on a new ``get_engine()`` engine is used, then closed and
            its engine disposed before returning.
    """
    # Written with SQLAlchemy constructs rather than raw SQL, though it still
    # relies on the database providing array_agg and the text_pairs/int_pairs
    # functions.
    owns_session = session is None
    engine = None
    if owns_session:
        engine = get_engine()
        session = Session(engine)
    try:
        with session.begin():
            ids = (
                select(Recipe.id)
                .join(RecipeGraphed, isouter=True)
                # IS NOT TRUE matches both status = false and the NULL of a
                # missing RecipeGraphed row.
                .where(RecipeGraphed.status.is_not(True))
            )
            num_recipes = session.execute(
                select(func.count()).select_from(ids.cte())
            ).scalar_one()
            if not num_recipes:
                logging.info("no new recipes")
                return
            logging.info("adding %s recipes to the graphs", num_recipes)
            _add_ingredient_connections(session, ids)
            _add_recipe_connections(session, ids)
            _mark_recipes_graphed(session, ids)
    finally:
        if owns_session:
            # Release the session and the pooled connections of the engine
            # created just for this call.
            session.close()
            engine.dispose()
if __name__ == "__main__":
    # Create any missing tables in the database configured through the
    # POSTGRES_* environment variables (loaded from .env first).
    eng = get_engine(use_dotenv=True)
    create_tables(eng)