From 3a45cfb02a354a7e6539717f038caed7a27e1bdf Mon Sep 17 00:00:00 2001 From: Andrei Stoica Date: Sat, 15 Oct 2022 12:35:37 -0400 Subject: [PATCH] formatting --- src/recipe_graph/db.py | 263 +++++++++++++++++-------------- src/recipe_graph/insert_sites.py | 8 +- 2 files changed, 155 insertions(+), 116 deletions(-) diff --git a/src/recipe_graph/db.py b/src/recipe_graph/db.py index 07a36c2..73ccd6c 100644 --- a/src/recipe_graph/db.py +++ b/src/recipe_graph/db.py @@ -3,9 +3,20 @@ import logging from types import NoneType from dotenv import load_dotenv from xmlrpc.client import Boolean -from sqlalchemy import create_engine, Column, Integer, String, Boolean, \ - ForeignKey, UniqueConstraint, func, select, and_, or_, \ - not_ +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + Boolean, + ForeignKey, + UniqueConstraint, + func, + select, + and_, + or_, + not_, +) from sqlalchemy.types import ARRAY from sqlalchemy.engine import URL from sqlalchemy.ext.declarative import declarative_base @@ -14,41 +25,46 @@ from sqlalchemy.orm import Session, sessionmaker Base = declarative_base() + class Ingredient(Base): - __tablename__ = 'Ingredient' - - id = Column(Integer, primary_key = True) - name = Column(String, nullable = False) + __tablename__ = "Ingredient" + + id = Column(Integer, primary_key=True) + name = Column(String, nullable=False) + class RecipeSite(Base): - __tablename__ = 'RecipeSite' - - id = Column(Integer, primary_key = True) - name = Column(String, nullable = False, unique = True) - ingredient_class = Column(String, nullable = False) - name_class = Column(String, nullable = False) - base_url = Column(String, nullable = False, unique = True) + __tablename__ = "RecipeSite" + + id = Column(Integer, primary_key=True) + name = Column(String, nullable=False, unique=True) + ingredient_class = Column(String, nullable=False) + name_class = Column(String, nullable=False) + base_url = Column(String, nullable=False, unique=True) + class Recipe(Base): - __tablename__ = 'Recipe' - - id = Column(Integer, primary_key = True) + __tablename__ = "Recipe" + + id = Column(Integer, primary_key=True) name = Column(String) - identifier = Column(String, nullable = False) - recipe_site_id = Column(Integer, ForeignKey('RecipeSite.id')) + identifier = Column(String, nullable=False) + recipe_site_id = Column(Integer, ForeignKey("RecipeSite.id")) UniqueConstraint(identifier, recipe_site_id) + class RecipeIngredient(Base): - __tablename__ = 'RecipeIngredient' - - id = Column(Integer, primary_key = True) - text = Column(String, nullable = False) - recipe_id = Column(Integer, ForeignKey('Recipe.id')) + __tablename__ = "RecipeIngredient" + + id = Column(Integer, primary_key=True) + text = Column(String, nullable=False) + recipe_id = Column(Integer, ForeignKey("Recipe.id")) ingredient_id = Column(Integer, ForeignKey("Ingredient.id")) + class RecipeIngredientParts(Base): - __tablename__ = 'RecipeIngredientParts' - + __tablename__ = "RecipeIngredientParts" + id = Column(Integer, ForeignKey("RecipeIngredient.id"), primary_key=True) quantity = Column(String) unit = Column(String) @@ -56,169 +72,188 @@ class RecipeIngredientParts(Base): ingredient = Column(String) supplement = Column(String) + class IngredientConnection(Base): - __tablename__ = 'IngredientConnection' - - ingredient_a = Column(String, - primary_key = True) - ingredient_b = Column(String, - primary_key = True) + __tablename__ = "IngredientConnection" + + ingredient_a = Column(String, primary_key=True) + ingredient_b = Column(String, primary_key=True) recipe_count = Column(Integer) UniqueConstraint(ingredient_a, ingredient_b) + class RecipeConnection(Base): - __tablename__ = 'RecipeConnection' - - recipe_a = Column(Integer, - ForeignKey("Recipe.id"), - primary_key = True) - recipe_b = Column(Integer, - ForeignKey("Recipe.id"), - primary_key = True) + __tablename__ = "RecipeConnection" + + recipe_a = Column(Integer, ForeignKey("Recipe.id"), primary_key=True) + recipe_b = Column(Integer, ForeignKey("Recipe.id"), primary_key=True) ingredient_count = Column(Integer) + class RecipeGraphed(Base): __tablename__ = "RecipeGraphed" - - recipe_id = Column(Integer, ForeignKey("Recipe.id"), primary_key = True) - status = Column(Boolean, nullable = False, default = False) - -def get_engine(use_dotenv = True, **kargs): + recipe_id = Column(Integer, ForeignKey("Recipe.id"), primary_key=True) + status = Column(Boolean, nullable=False, default=False) + + +def get_engine(use_dotenv=True, **kargs): if use_dotenv: - load_dotenv() + load_dotenv() DB_URL = os.getenv("POSTGRES_URL") - DB_USER = os.getenv("POSTGRES_USER") + DB_USER = os.getenv("POSTGRES_USER") DB_PASSWORD = os.getenv("POSTGRES_PASSWORD") - DB_NAME = os.getenv("POSTGRES_DB") + DB_NAME = os.getenv("POSTGRES_DB") - eng_url = URL.create('postgresql', - username=DB_USER, - password=DB_PASSWORD, - host=DB_URL, - database=DB_NAME) + eng_url = URL.create( + "postgresql", + username=DB_USER, + password=DB_PASSWORD, + host=DB_URL, + database=DB_NAME, + ) return create_engine(eng_url) + def get_session(**kargs) -> Session: eng = get_engine(**kargs) return sessionmaker(eng) - def create_tables(eng): logging.info(f"Createing DB Tables: {eng.url}") Base.metadata.create_all(eng, checkfirst=True) -def pair_query(pairable, groupable, recipe_ids = None, pair_type = String): - pair_func= func.text_pairs - if pair_type == Integer: - pair_func=func.int_pairs - new_pairs = select(groupable, - pair_func(func.array_agg(pairable.distinct()), - type_=ARRAY(pair_type)).label("pair"))\ - .join(RecipeIngredientParts) +def pair_query(pairable, groupable, recipe_ids=None, pair_type=String): + pair_func = func.text_pairs + if pair_type == Integer: + pair_func = func.int_pairs + + new_pairs = select( + groupable, + pair_func(func.array_agg(pairable.distinct()), type_=ARRAY(pair_type)).label( + "pair" + ), + ).join(RecipeIngredientParts) if not type(recipe_ids) == NoneType: new_pairs = new_pairs.where(RecipeIngredient.recipe_id.in_(recipe_ids)) - new_pairs = new_pairs.group_by(groupable)\ - .cte() + new_pairs = new_pairs.group_by(groupable).cte() return new_pairs -def pair_count_query(pairs, countable, recipe_ids = None): + +def pair_count_query(pairs, countable, recipe_ids=None): new_counts = select(pairs, func.count(func.distinct(countable))) - + if not type(recipe_ids) == NoneType: - new_counts = new_counts.where(or_(pairs[0].in_(recipe_ids), - pairs[1].in_(recipe_ids))) - - + new_counts = new_counts.where( + or_(pairs[0].in_(recipe_ids), pairs[1].in_(recipe_ids)) + ) + new_counts = new_counts.group_by(pairs) - + return new_counts -def update_graph_connectivity(session = None): + +def update_graph_connectivity(session=None): # this is pure SQLAlchemy so it is more portable # This would have been simpler if I utilized Postgres specific feature if not session: session = Session(get_engine()) with session.begin(): - ids = select(Recipe.id)\ - .join(RecipeGraphed, isouter = True)\ - .where(RecipeGraphed.status.is_not(True)) + ids = ( + select(Recipe.id) + .join(RecipeGraphed, isouter=True) + .where(RecipeGraphed.status.is_not(True)) + ) - num_recipes = session.execute(select(func.count('*')).select_from(ids.cte())).fetchone()[0] + num_recipes = session.execute( + select(func.count("*")).select_from(ids.cte()) + ).fetchone()[0] if num_recipes <= 0: logging.info("no new recipies") return logging.info(f"adding {num_recipes} recipes to the graphs") - - new_pairs = pair_query(RecipeIngredientParts.ingredient, - RecipeIngredient.recipe_id, - recipe_ids = ids) - - new_counts = pair_count_query(new_pairs.c.pair, - new_pairs.c.recipe_id) - + new_pairs = pair_query( + RecipeIngredientParts.ingredient, RecipeIngredient.recipe_id, recipe_ids=ids + ) + + new_counts = pair_count_query(new_pairs.c.pair, new_pairs.c.recipe_id) + logging.info("addeing new ingredient connections") for pair, count in session.execute(new_counts): - connection = session.query(IngredientConnection)\ - .where(and_(IngredientConnection.ingredient_a == pair[0], - IngredientConnection.ingredient_b == pair[1]))\ - .first() + connection = ( + session.query(IngredientConnection) + .where( + and_( + IngredientConnection.ingredient_a == pair[0], + IngredientConnection.ingredient_b == pair[1], + ) + ) + .first() + ) if connection: connection.recipe_count += count session.merge(connection) else: - session.add(IngredientConnection(ingredient_a = pair[0], - ingredient_b = pair[1], - recipe_count = count)) - + session.add( + IngredientConnection( + ingredient_a=pair[0], ingredient_b=pair[1], recipe_count=count + ) + ) + # update RecipeConnection logging.info("adding new recipe connections") - all_pairs = pair_query(RecipeIngredient.recipe_id, - RecipeIngredientParts.ingredient, - pair_type=Integer) - - new_counts = pair_count_query(all_pairs.c.pair, - all_pairs.c.ingredient, - recipe_ids=ids) - + all_pairs = pair_query( + RecipeIngredient.recipe_id, + RecipeIngredientParts.ingredient, + pair_type=Integer, + ) + + new_counts = pair_count_query( + all_pairs.c.pair, all_pairs.c.ingredient, recipe_ids=ids + ) + i = 0 for pair, count in session.execute(new_counts): - session.add(RecipeConnection(recipe_a = pair[0], - recipe_b = pair[1], - ingredient_count = count)) + session.add( + RecipeConnection( + recipe_a=pair[0], recipe_b=pair[1], ingredient_count=count + ) + ) # flush often to reduce memory usage i += 1 if (i % 100000) == 0: session.flush() - + # update RecipeGraphed.status logging.info("updating existing RecipeGraphed rows") - for recipeGraphed in session.query(RecipeGraphed)\ - .where(RecipeGraphed.recipe_id.in_(ids)): + for recipeGraphed in session.query(RecipeGraphed).where( + RecipeGraphed.recipe_id.in_(ids) + ): recipeGraphed.status = True session.merge(recipeGraphed) - + graphed = select(RecipeGraphed.recipe_id) - - # add recipies that aren't in the table + + # add recipies that aren't in the table logging.info("adding new RecipeGraphed rows") - for recipe in session.query(Recipe)\ - .where(and_(Recipe.id.in_(ids), - not_(Recipe.id.in_(graphed)))): + for recipe in session.query(Recipe).where( + and_(Recipe.id.in_(ids), not_(Recipe.id.in_(graphed))) + ): session.add(RecipeGraphed(recipe_id=recipe.id, status=True)) - + + def main(): eng = get_engine() create_tables(eng) - + if __name__ == "__main__": main() diff --git a/src/recipe_graph/insert_sites.py b/src/recipe_graph/insert_sites.py index 886196e..50db413 100644 --- a/src/recipe_graph/insert_sites.py +++ b/src/recipe_graph/insert_sites.py @@ -11,6 +11,7 @@ def load_file(f_name: str): sites = json.load(f) return sites + def setup_argparser() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Import recipes into database") parser.add_argument("file", type=str, help="JSON file with recipe site information") @@ -19,11 +20,13 @@ def setup_argparser() -> argparse.Namespace: args = parser.parse_args() return args + def setup_logging(args: argparse.Namespace): if args.verbose: logging.basicConfig(level=logging.INFO) logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) + def main(): args = setup_argparser() setup_logging(args) @@ -35,6 +38,7 @@ def main(): for site in sites: logging.info(f"Adding {site}") session.add(db.RecipeSite(**site)) - + + if __name__ == "__main__": - main() \ No newline at end of file + main()