formatting

This commit is contained in:
Andrei Stoica 2022-10-15 12:35:37 -04:00
parent a5153e2406
commit 3a45cfb02a
2 changed files with 155 additions and 116 deletions

View File

@ -3,9 +3,20 @@ import logging
from types import NoneType from types import NoneType
from dotenv import load_dotenv from dotenv import load_dotenv
from xmlrpc.client import Boolean from xmlrpc.client import Boolean
from sqlalchemy import create_engine, Column, Integer, String, Boolean, \ from sqlalchemy import (
ForeignKey, UniqueConstraint, func, select, and_, or_, \ create_engine,
not_ Column,
Integer,
String,
Boolean,
ForeignKey,
UniqueConstraint,
func,
select,
and_,
or_,
not_,
)
from sqlalchemy.types import ARRAY from sqlalchemy.types import ARRAY
from sqlalchemy.engine import URL from sqlalchemy.engine import URL
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
@ -14,40 +25,45 @@ from sqlalchemy.orm import Session, sessionmaker
Base = declarative_base() Base = declarative_base()
class Ingredient(Base):
    """ORM model mapped to the ``Ingredient`` table."""

    __tablename__ = "Ingredient"

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Ingredient name; NOT NULL.
    name = Column(String, nullable=False)
class RecipeSite(Base):
    """ORM model mapped to the ``RecipeSite`` table."""

    __tablename__ = "RecipeSite"

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Site name; NOT NULL and unique across sites.
    name = Column(String, nullable=False, unique=True)
    # NOTE(review): presumably the markup class used to locate ingredient
    # entries on a site's pages -- confirm against the scraper.
    ingredient_class = Column(String, nullable=False)
    # NOTE(review): presumably the markup class used to locate the recipe
    # name on a site's pages -- confirm against the scraper.
    name_class = Column(String, nullable=False)
    # Base URL of the site; NOT NULL and unique across sites.
    base_url = Column(String, nullable=False, unique=True)
class Recipe(Base):
    """ORM model mapped to the ``Recipe`` table."""

    __tablename__ = "Recipe"

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Recipe name; nullable.
    name = Column(String)
    # Identifier of the recipe; NOT NULL.
    identifier = Column(String, nullable=False)
    # Foreign key to RecipeSite.id.
    recipe_site_id = Column(Integer, ForeignKey("RecipeSite.id"))

    # NOTE(review): this constraint is built as a bare expression from the
    # Column objects, not listed in ``__table_args__`` (the documented
    # declarative form). Confirm it actually appears in the generated DDL.
    UniqueConstraint(identifier, recipe_site_id)
class RecipeIngredient(Base):
    """ORM model mapped to the ``RecipeIngredient`` table."""

    __tablename__ = "RecipeIngredient"

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Ingredient text; NOT NULL.
    text = Column(String, nullable=False)
    # Foreign key to Recipe.id.
    recipe_id = Column(Integer, ForeignKey("Recipe.id"))
    # Foreign key to Ingredient.id.
    ingredient_id = Column(Integer, ForeignKey("Ingredient.id"))
class RecipeIngredientParts(Base): class RecipeIngredientParts(Base):
__tablename__ = 'RecipeIngredientParts' __tablename__ = "RecipeIngredientParts"
id = Column(Integer, ForeignKey("RecipeIngredient.id"), primary_key=True) id = Column(Integer, ForeignKey("RecipeIngredient.id"), primary_key=True)
quantity = Column(String) quantity = Column(String)
@ -56,144 +72,161 @@ class RecipeIngredientParts(Base):
ingredient = Column(String) ingredient = Column(String)
supplement = Column(String) supplement = Column(String)
class IngredientConnection(Base):
    """ORM model mapped to the ``IngredientConnection`` table.

    A row is keyed by a pair of ingredient strings and carries a recipe
    count for that pair.
    """

    __tablename__ = "IngredientConnection"

    # The two ingredient strings together form the composite primary key.
    ingredient_a = Column(String, primary_key=True)
    ingredient_b = Column(String, primary_key=True)
    # Recipe count stored for the pair; nullable integer.
    recipe_count = Column(Integer)

    # NOTE(review): the composite primary key above already enforces
    # uniqueness of (ingredient_a, ingredient_b); this looks redundant.
    UniqueConstraint(ingredient_a, ingredient_b)
class RecipeConnection(Base):
    """ORM model mapped to the ``RecipeConnection`` table.

    A row is keyed by a pair of recipe ids and carries an ingredient count
    for that pair.
    """

    __tablename__ = "RecipeConnection"

    # Both columns reference Recipe.id and together form the primary key.
    recipe_a = Column(Integer, ForeignKey("Recipe.id"), primary_key=True)
    recipe_b = Column(Integer, ForeignKey("Recipe.id"), primary_key=True)
    # Ingredient count stored for the pair; nullable integer.
    ingredient_count = Column(Integer)
class RecipeGraphed(Base):
    """ORM model mapped to the ``RecipeGraphed`` table.

    Holds one boolean status flag per recipe id.
    """

    __tablename__ = "RecipeGraphed"

    # Primary key that is also a foreign key to Recipe.id.
    recipe_id = Column(Integer, ForeignKey("Recipe.id"), primary_key=True)
    # NOT NULL flag; defaults to False when no value is supplied.
    status = Column(Boolean, nullable=False, default=False)
def get_engine(use_dotenv=True, **kargs):
    """Create a SQLAlchemy engine for the PostgreSQL database.

    Connection settings come from the environment variables
    ``POSTGRES_URL`` (used as the host), ``POSTGRES_USER``,
    ``POSTGRES_PASSWORD`` and ``POSTGRES_DB``.

    Args:
        use_dotenv: When true, ``load_dotenv()`` is called before the
            environment variables are read.
        **kargs: Accepted so callers can pass extra keywords through;
            this function does not use them.

    Returns:
        The engine returned by ``create_engine`` for the assembled URL.
    """
    if use_dotenv:
        load_dotenv()

    DB_URL = os.getenv("POSTGRES_URL")
    DB_USER = os.getenv("POSTGRES_USER")
    DB_PASSWORD = os.getenv("POSTGRES_PASSWORD")
    DB_NAME = os.getenv("POSTGRES_DB")

    # URL.create assembles the URL from its parts instead of string-formatting.
    eng_url = URL.create(
        "postgresql",
        username=DB_USER,
        password=DB_PASSWORD,
        host=DB_URL,
        database=DB_NAME,
    )
    return create_engine(eng_url)
def get_session(**kargs) -> sessionmaker:
    """Build a session factory bound to the engine from ``get_engine``.

    Despite the name, the value returned is the ``sessionmaker`` factory
    itself, not an open ``Session``; callers must call the result (or use
    its ``begin()``) to obtain a session.

    Args:
        **kargs: Forwarded unchanged to ``get_engine``.

    Returns:
        A ``sessionmaker`` bound to the newly created engine.
    """
    eng = get_engine(**kargs)
    return sessionmaker(eng)
def create_tables(eng):
    """Create every table registered on ``Base.metadata``.

    Args:
        eng: Engine to create the tables on; its ``url`` is logged first.
    """
    logging.info(f"Createing DB Tables: {eng.url}")
    # checkfirst=True asks SQLAlchemy to skip tables that already exist.
    Base.metadata.create_all(eng, checkfirst=True)
def pair_query(pairable, groupable, recipe_ids=None, pair_type=String):
    """Build a CTE producing a ``pair`` column for each ``groupable`` value.

    For every group, the distinct ``pairable`` values are aggregated into an
    array and passed to the SQL function ``text_pairs`` (or ``int_pairs``
    when ``pair_type`` is ``Integer``); the result is labelled ``pair``.
    NOTE(review): ``text_pairs`` / ``int_pairs`` are resolved by name through
    ``func`` -- presumably functions defined in the database; confirm.

    Args:
        pairable: Column expression whose distinct values are paired.
        groupable: Column expression to select and group by.
        recipe_ids: Optional value for ``RecipeIngredient.recipe_id.in_``;
            when ``None`` no recipe filter is applied.
        pair_type: SQLAlchemy type of the array elements.

    Returns:
        The grouped select statement wrapped as a CTE.
    """
    pair_func = func.text_pairs
    if pair_type == Integer:
        pair_func = func.int_pairs

    paired = pair_func(
        func.array_agg(pairable.distinct()),
        type_=ARRAY(pair_type),
    ).label("pair")
    new_pairs = select(groupable, paired).join(RecipeIngredientParts)

    # Identity check: recipe_ids may be a SQL construct, so neither its
    # truthiness nor its equality operators are touched.
    if recipe_ids is not None:
        new_pairs = new_pairs.where(RecipeIngredient.recipe_id.in_(recipe_ids))

    return new_pairs.group_by(groupable).cte()
def pair_count_query(pairs, countable, recipe_ids=None):
    """Build a select that counts distinct ``countable`` values per pair.

    Args:
        pairs: Indexable column expression holding the pair; it is selected
            and grouped by.
        countable: Column expression whose distinct values are counted.
        recipe_ids: Optional value for ``in_``; when given, only pairs whose
            element 0 or element 1 is in it are kept. ``None`` disables the
            filter.

    Returns:
        The select statement (not executed, not wrapped in a CTE).
    """
    new_counts = select(pairs, func.count(func.distinct(countable)))

    # Identity check: recipe_ids may be a SQL construct, so neither its
    # truthiness nor its equality operators are touched.
    if recipe_ids is not None:
        # NOTE(review): these indexes are rendered into SQL. PostgreSQL
        # arrays are 1-based unless the ARRAY type was built with
        # zero_indexes=True, so pairs[0] may always be NULL -- confirm.
        new_counts = new_counts.where(
            or_(pairs[0].in_(recipe_ids), pairs[1].in_(recipe_ids))
        )

    return new_counts.group_by(pairs)
def update_graph_connectivity(session = None):
def update_graph_connectivity(session=None):
# this is pure SQLAlchemy so it is more portable # this is pure SQLAlchemy so it is more portable
# This would have been simpler if I utilized Postgres specific feature # This would have been simpler if I utilized Postgres specific feature
if not session: if not session:
session = Session(get_engine()) session = Session(get_engine())
with session.begin(): with session.begin():
ids = select(Recipe.id)\ ids = (
.join(RecipeGraphed, isouter = True)\ select(Recipe.id)
.where(RecipeGraphed.status.is_not(True)) .join(RecipeGraphed, isouter=True)
.where(RecipeGraphed.status.is_not(True))
)
num_recipes = session.execute(select(func.count('*')).select_from(ids.cte())).fetchone()[0] num_recipes = session.execute(
select(func.count("*")).select_from(ids.cte())
).fetchone()[0]
if num_recipes <= 0: if num_recipes <= 0:
logging.info("no new recipies") logging.info("no new recipies")
return return
logging.info(f"adding {num_recipes} recipes to the graphs") logging.info(f"adding {num_recipes} recipes to the graphs")
new_pairs = pair_query(RecipeIngredientParts.ingredient, new_pairs = pair_query(
RecipeIngredient.recipe_id, RecipeIngredientParts.ingredient, RecipeIngredient.recipe_id, recipe_ids=ids
recipe_ids = ids) )
new_counts = pair_count_query(new_pairs.c.pair, new_pairs.c.recipe_id)
new_counts = pair_count_query(new_pairs.c.pair,
new_pairs.c.recipe_id)
logging.info("addeing new ingredient connections") logging.info("addeing new ingredient connections")
for pair, count in session.execute(new_counts): for pair, count in session.execute(new_counts):
connection = session.query(IngredientConnection)\ connection = (
.where(and_(IngredientConnection.ingredient_a == pair[0], session.query(IngredientConnection)
IngredientConnection.ingredient_b == pair[1]))\ .where(
.first() and_(
IngredientConnection.ingredient_a == pair[0],
IngredientConnection.ingredient_b == pair[1],
)
)
.first()
)
if connection: if connection:
connection.recipe_count += count connection.recipe_count += count
session.merge(connection) session.merge(connection)
else: else:
session.add(IngredientConnection(ingredient_a = pair[0], session.add(
ingredient_b = pair[1], IngredientConnection(
recipe_count = count)) ingredient_a=pair[0], ingredient_b=pair[1], recipe_count=count
)
)
# update RecipeConnection # update RecipeConnection
logging.info("adding new recipe connections") logging.info("adding new recipe connections")
all_pairs = pair_query(RecipeIngredient.recipe_id, all_pairs = pair_query(
RecipeIngredientParts.ingredient, RecipeIngredient.recipe_id,
pair_type=Integer) RecipeIngredientParts.ingredient,
pair_type=Integer,
)
new_counts = pair_count_query(all_pairs.c.pair, new_counts = pair_count_query(
all_pairs.c.ingredient, all_pairs.c.pair, all_pairs.c.ingredient, recipe_ids=ids
recipe_ids=ids) )
i = 0 i = 0
for pair, count in session.execute(new_counts): for pair, count in session.execute(new_counts):
session.add(RecipeConnection(recipe_a = pair[0], session.add(
recipe_b = pair[1], RecipeConnection(
ingredient_count = count)) recipe_a=pair[0], recipe_b=pair[1], ingredient_count=count
)
)
# flush often to reduce memory usage # flush often to reduce memory usage
i += 1 i += 1
if (i % 100000) == 0: if (i % 100000) == 0:
@ -201,8 +234,9 @@ def update_graph_connectivity(session = None):
# update RecipeGraphed.status # update RecipeGraphed.status
logging.info("updating existing RecipeGraphed rows") logging.info("updating existing RecipeGraphed rows")
for recipeGraphed in session.query(RecipeGraphed)\ for recipeGraphed in session.query(RecipeGraphed).where(
.where(RecipeGraphed.recipe_id.in_(ids)): RecipeGraphed.recipe_id.in_(ids)
):
recipeGraphed.status = True recipeGraphed.status = True
session.merge(recipeGraphed) session.merge(recipeGraphed)
@ -210,11 +244,12 @@ def update_graph_connectivity(session = None):
# add recipies that aren't in the table # add recipies that aren't in the table
logging.info("adding new RecipeGraphed rows") logging.info("adding new RecipeGraphed rows")
for recipe in session.query(Recipe)\ for recipe in session.query(Recipe).where(
.where(and_(Recipe.id.in_(ids), and_(Recipe.id.in_(ids), not_(Recipe.id.in_(graphed)))
not_(Recipe.id.in_(graphed)))): ):
session.add(RecipeGraphed(recipe_id=recipe.id, status=True)) session.add(RecipeGraphed(recipe_id=recipe.id, status=True))
def main(): def main():
eng = get_engine() eng = get_engine()
create_tables(eng) create_tables(eng)

View File

@ -11,6 +11,7 @@ def load_file(f_name: str):
sites = json.load(f) sites = json.load(f)
return sites return sites
def setup_argparser() -> argparse.Namespace: def setup_argparser() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Import recipes into database") parser = argparse.ArgumentParser(description="Import recipes into database")
parser.add_argument("file", type=str, help="JSON file with recipe site information") parser.add_argument("file", type=str, help="JSON file with recipe site information")
@ -19,11 +20,13 @@ def setup_argparser() -> argparse.Namespace:
args = parser.parse_args() args = parser.parse_args()
return args return args
def setup_logging(args: argparse.Namespace): def setup_logging(args: argparse.Namespace):
if args.verbose: if args.verbose:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)
def main(): def main():
args = setup_argparser() args = setup_argparser()
setup_logging(args) setup_logging(args)
@ -36,5 +39,6 @@ def main():
logging.info(f"Adding {site}") logging.info(f"Adding {site}")
session.add(db.RecipeSite(**site)) session.add(db.RecipeSite(**site))
# Run main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()