diff --git a/.drone.yml b/.drone.yml new file mode 100644 index 0000000..caf5ab6 --- /dev/null +++ b/.drone.yml @@ -0,0 +1,65 @@ +--- +kind: pipeline +name: test +environment: + project_name: rgraph +trigger: + event: + include: + - pull_request + +steps: + - name: db-up + image: docker/compose:alpine-1.29.2 + environment: + POSTGRES_USER: + from_secret: TESTING_USER + POSTGRES_PASSWORD: + from_secret: TESTING_PASSWORD + POSTGRES_DB: + from_secret: TESTING_DB + volumes: + - name: docker_sock + path: /var/run/docker.sock + commands: + - docker-compose -p rgraph-test up -d + + - name: requirements + image: python:3.10-alpine + commands: + - python -m venv .venv + - . .venv/bin/activate + - pip install -r requirements.txt + + - name: build + image: python:3.10-alpine + commands: + - . .venv/bin/activate + - pip install . + + - name: test + image: python:3.10-alpine + environment: + POSTGRES_USER: + from_secret: TESTING_USER + POSTGRES_PASSWORD: + from_secret: TESTING_PASSWORD + POSTGRES_DB: + from_secret: TESTING_DB + commands: + - hostip=$(ip route show | awk '/default/ {print $3}') + - export POSTGRES_URL=$hostip + - . 
.venv/bin/activate + - pytest + - name: db-cleanup + image: docker/compose:alpine-1.29.2 + volumes: + - name: docker_sock + path: /var/run/docker.sock + commands: + - docker-compose -p rgraph-test down + - docker volume rm rgraph-test_dbdata +volumes: + - name: docker_sock + host: + path: /var/run/docker.sock diff --git a/.gitignore b/.gitignore index 64b7803..44010e0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,12 @@ data/ *__pycache__ *env *.code-workspace -sandbox/ \ No newline at end of file +sandbox/ +htmlcov +.coverage +*.lcov +.vscode/ +*.pytest_cache/ +*.egg-info/ +dist/ +build/ diff --git a/README.md b/README.md index 006be47..28a7c8e 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,19 @@ POSTGRES_DB=rgraph Start database ```sh -docker-compose up +docker-compose -p recipe-dev up +``` + +Example `sites.json` +```json +[ + { + "name": "Example Site Name", + "ingredient_class": "example-ingredients-item-name", + "name_class" : "example-heading-content", + "base_url" : "https://www.example.com/recipe/" + } +] ``` Initialize database and recipe sites @@ -29,6 +41,11 @@ python src/db.py python src/insert_sites.py data/sites.json ``` +Shutdown database +```sh +docker-compose -p recipe-dev down +``` + ## Usage ### Scrape import new recipes @@ -60,6 +77,44 @@ options: -v, --verbose ``` +## Testing +For testing, create a new set of docker containers. Tests will fail if +the database is already initialized. + +Starting testing db +```sh +docker-compose -p recipe-test up +``` + +Running tests +```sh +pytest +``` + +**WARNING**: If you get `ERROR at setup of test_db_connection` and +`ERROR at setup of test_db_class_creation`, please check whether the testing database is +already initialized. Testing is destructive and should be done on a fresh database. + + +Shutting down testing db +```sh +docker-compose -p recipe-test down +``` + + +Tests are written with the pytest framework. Currently focused on unit tests. +Integration tests to come. 
+ +To run test use: +``` +pytest --cov=src/recipe_graph --cov-report lcov --cov-report html +``` + +The html report is under `htmlcov/` and can be viewed through any browser. +The `lcov` file can be used for the [Coverage Gutters](https://marketplace.visualstudio.com/items?itemName=ryanluker.vscode-coverage-gutters) +plugin for VS Code to view coverage in your editor. + + ## TODO > ☑ automate scraping\ > ☑ extracting quantity and name (via regex)\ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a19e8bb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[metadata] +name = "recepie_graph" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b4b5e06..9bce4cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,17 @@ +attrs==22.1.0 beautifulsoup4==4.11.1 +coverage==6.5.0 greenlet==1.1.2 +iniconfig==1.1.1 +packaging==21.3 +pluggy==1.0.0 psycopg2-binary==2.9.3 +py==1.11.0 PyMySQL==1.0.2 +pyparsing==3.0.9 +pytest==7.1.3 +pytest-cov==4.0.0 python-dotenv==0.20.0 soupsieve==2.3.2.post1 SQLAlchemy==1.4.39 +tomli==2.0.1 diff --git a/src/db.py b/src/db.py deleted file mode 100644 index 7e2265d..0000000 --- a/src/db.py +++ /dev/null @@ -1,219 +0,0 @@ -import os -import logging -from types import NoneType -from dotenv import load_dotenv -from xmlrpc.client import Boolean -from sqlalchemy import create_engine, Column, Integer, String, Boolean, \ - ForeignKey, UniqueConstraint, func, select, and_, or_, \ - not_ -from sqlalchemy.types import ARRAY -from sqlalchemy.engine import URL -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session - - -Base = declarative_base() - -class Ingredient(Base): - __tablename__ = 'Ingredient' - - id = Column(Integer, primary_key = True) - name = Column(String, nullable = False) - -class RecipeSite(Base): - __tablename__ = 'RecipeSite' - - id = 
Column(Integer, primary_key = True) - name = Column(String, nullable = False, unique = True) - ingredient_class = Column(String, nullable = False) - name_class = Column(String, nullable = False) - base_url = Column(String, nullable = False, unique = True) - -class Recipe(Base): - __tablename__ = 'Recipe' - - id = Column(Integer, primary_key = True) - name = Column(String) - identifier = Column(String, nullable = False) - recipe_site_id = Column(Integer, ForeignKey('RecipeSite.id')) - UniqueConstraint(identifier, recipe_site_id) - -class RecipeIngredient(Base): - __tablename__ = 'RecipeIngredient' - - id = Column(Integer, primary_key = True) - text = Column(String, nullable = False) - recipe_id = Column(Integer, ForeignKey('Recipe.id')) - ingredient_id = Column(Integer, ForeignKey("Ingredient.id")) - -class RecipeIngredientParts(Base): - __tablename__ = 'RecipeIngredientParts' - - id = Column(Integer, ForeignKey("RecipeIngredient.id"), primary_key=True) - quantity = Column(String) - unit = Column(String) - instruction = Column(String) - ingredient = Column(String) - supplement = Column(String) - -class IngredientConnection(Base): - __tablename__ = 'IngredientConnection' - - ingredient_a = Column(String, - ForeignKey("RecipeIngredientParts.ingredient"), - primary_key = True) - ingredient_b = Column(String, - ForeignKey("RecipeIngredientParts.ingredient"), - primary_key = True) - recipe_count = Column(Integer) - UniqueConstraint(ingredient_a, ingredient_b) - -class RecipeConnection(Base): - __tablename__ = 'RecipeConnection' - - recipe_a = Column(Integer, - ForeignKey("Recipe.id"), - primary_key = True) - recipe_b = Column(Integer, - ForeignKey("Recipe.id"), - primary_key = True) - ingredient_count = Column(Integer) - -class RecipeGraphed(Base): - __tablename__ = "RecipeGraphed" - - recipe_id = Column(Integer, ForeignKey("Recipe.id"), primary_key = True) - status = Column(Boolean, nullable = False, default = False) - - -def get_engine(use_dotenv = True, **kargs): - if 
use_dotenv: - load_dotenv() - DB_URL = os.getenv("POSTGRES_URL") - DB_USER = os.getenv("POSTGRES_USER") - DB_PASSWORD = os.getenv("POSTGRES_PASSWORD") - DB_NAME = os.getenv("POSTGRES_DB") - - eng_url = URL.create('postgresql', - username=DB_USER, - password=DB_PASSWORD, - host=DB_URL, - database=DB_NAME) - return create_engine(eng_url) - - -def create_tables(eng): - logging.info(f"Createing DB Tables: {eng.url}") - Base.metadata.create_all(eng, checkfirst=True) - -def pair_query(pairable, groupable, recipe_ids = None, pair_type = String): - pair_func= func.text_pairs - if pair_type == Integer: - pair_func=func.int_pairs - - new_pairs = select(groupable, - pair_func(func.array_agg(pairable.distinct()), - type_=ARRAY(pair_type)).label("pair"))\ - .join(RecipeIngredientParts) - - if not type(recipe_ids) == NoneType: - new_pairs = new_pairs.where(RecipeIngredient.recipe_id.in_(recipe_ids)) - - new_pairs = new_pairs.group_by(groupable)\ - .cte() - - return new_pairs - -def pair_count_query(pairs, countable, recipe_ids = None): - new_counts = select(pairs, func.count(func.distinct(countable))) - - if not type(recipe_ids) == NoneType: - new_counts = new_counts.where(or_(pairs[0].in_(recipe_ids), - pairs[1].in_(recipe_ids))) - - - new_counts = new_counts.group_by(pairs) - - return new_counts - -def update_graph_connectivity(session = None): - # this is pure SQLAlchemy so it is more portable - # This would have been simpler if I utilized Postgres specific feature - if not session: - session = Session(get_engine()) - - with session.begin(): - ids = select(Recipe.id)\ - .join(RecipeGraphed, isouter = True)\ - .where(RecipeGraphed.status.is_not(True)) - - num_recipes = session.execute(select(func.count('*')).select_from(ids.cte())).fetchone()[0] - if num_recipes <= 0: - logging.info("no new recipies") - return - - logging.info(f"adding {num_recipes} recipes to the graphs") - - new_pairs = pair_query(RecipeIngredientParts.ingredient, - RecipeIngredient.recipe_id, - recipe_ids = 
ids) - - - new_counts = pair_count_query(new_pairs.c.pair, - new_pairs.c.recipe_id) - - logging.info("addeing new ingredient connections") - for pair, count in session.execute(new_counts): - connection = session.query(IngredientConnection)\ - .where(and_(IngredientConnection.ingredient_a == pair[0], - IngredientConnection.ingredient_b == pair[1]))\ - .first() - if connection: - connection.recipe_count += count - session.merge(connection) - else: - session.add(IngredientConnection(ingredient_a = pair[0], - ingredient_b = pair[1], - recipe_count = count)) - - # update RecipeConnection - logging.info("adding new recipe connections") - all_pairs = pair_query(RecipeIngredient.recipe_id, - RecipeIngredientParts.ingredient, - pair_type=Integer) - - new_counts = pair_count_query(all_pairs.c.pair, - all_pairs.c.ingredient, - recipe_ids=ids) - - i = 0 - for pair, count in session.execute(new_counts): - session.add(RecipeConnection(recipe_a = pair[0], - recipe_b = pair[1], - ingredient_count = count)) - # flush often to reduce memory usage - i += 1 - if (i % 100000) == 0: - session.flush() - - # update RecipeGraphed.status - logging.info("updating existing RecipeGraphed rows") - for recipeGraphed in session.query(RecipeGraphed)\ - .where(RecipeGraphed.recipe_id.in_(ids)): - recipeGraphed.status = True - session.merge(recipeGraphed) - - graphed = select(RecipeGraphed.recipe_id) - - # add recipies that aren't in the table - logging.info("adding new RecipeGraphed rows") - for recipe in session.query(Recipe)\ - .where(and_(Recipe.id.in_(ids), - not_(Recipe.id.in_(graphed)))): - session.add(RecipeGraphed(recipe_id=recipe.id, status=True)) - - - -if __name__ == "__main__": - eng = get_engine() - create_tables(eng) diff --git a/src/insert_sites.py b/src/insert_sites.py deleted file mode 100644 index a3a7107..0000000 --- a/src/insert_sites.py +++ /dev/null @@ -1,29 +0,0 @@ -from sqlalchemy.orm import sessionmaker -import db -import json -import argparse -import logging - -parser = 
argparse.ArgumentParser(description='Import recipes into database') -parser.add_argument('file', type=str, - help='JSON file with recipe site information') -parser.add_argument('-v', '--verbose', action='store_true') - - -args = parser.parse_args() -if args.verbose: - logging.basicConfig(level=logging.INFO) - logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) - -with open(args.file) as f: - sites = json.load(f) - -eng = db.get_engine() -S = sessionmaker(eng) - -with S.begin() as session: - for site in sites: - logging.info(f"Adding {site}") - session.add(db.RecipeSite(**site)) - - \ No newline at end of file diff --git a/src/recipe_graph/__init__.py b/src/recipe_graph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/recipe_graph/db.py b/src/recipe_graph/db.py new file mode 100644 index 0000000..8dc7ad6 --- /dev/null +++ b/src/recipe_graph/db.py @@ -0,0 +1,259 @@ +import os +import logging +from types import NoneType +from dotenv import load_dotenv +from xmlrpc.client import Boolean +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + Boolean, + ForeignKey, + UniqueConstraint, + func, + select, + and_, + or_, + not_, +) +from sqlalchemy.types import ARRAY +from sqlalchemy.engine import URL +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session, sessionmaker + + +Base = declarative_base() + + +class Ingredient(Base): + __tablename__ = "Ingredient" + + id = Column(Integer, primary_key=True) + name = Column(String, nullable=False) + + +class RecipeSite(Base): + __tablename__ = "RecipeSite" + + id = Column(Integer, primary_key=True) + name = Column(String, nullable=False, unique=True) + ingredient_class = Column(String, nullable=False) + name_class = Column(String, nullable=False) + base_url = Column(String, nullable=False, unique=True) + + +class Recipe(Base): + __tablename__ = "Recipe" + + id = Column(Integer, primary_key=True) + name = Column(String) + identifier = 
Column(String, nullable=False) + recipe_site_id = Column(Integer, ForeignKey("RecipeSite.id")) + UniqueConstraint(identifier, recipe_site_id) + + +class RecipeIngredient(Base): + __tablename__ = "RecipeIngredient" + + id = Column(Integer, primary_key=True) + text = Column(String, nullable=False) + recipe_id = Column(Integer, ForeignKey("Recipe.id")) + ingredient_id = Column(Integer, ForeignKey("Ingredient.id")) + + +class RecipeIngredientParts(Base): + __tablename__ = "RecipeIngredientParts" + + id = Column(Integer, ForeignKey("RecipeIngredient.id"), primary_key=True) + quantity = Column(String) + unit = Column(String) + instruction = Column(String) + ingredient = Column(String) + supplement = Column(String) + + +class IngredientConnection(Base): + __tablename__ = "IngredientConnection" + + ingredient_a = Column(String, primary_key=True) + ingredient_b = Column(String, primary_key=True) + recipe_count = Column(Integer) + UniqueConstraint(ingredient_a, ingredient_b) + + +class RecipeConnection(Base): + __tablename__ = "RecipeConnection" + + recipe_a = Column(Integer, ForeignKey("Recipe.id"), primary_key=True) + recipe_b = Column(Integer, ForeignKey("Recipe.id"), primary_key=True) + ingredient_count = Column(Integer) + + +class RecipeGraphed(Base): + __tablename__ = "RecipeGraphed" + + recipe_id = Column(Integer, ForeignKey("Recipe.id"), primary_key=True) + status = Column(Boolean, nullable=False, default=False) + + +def get_engine(use_dotenv=True, **kargs): + if use_dotenv: + load_dotenv() + DB_URL = os.getenv("POSTGRES_URL") + DB_USER = os.getenv("POSTGRES_USER") + DB_PASSWORD = os.getenv("POSTGRES_PASSWORD") + DB_NAME = os.getenv("POSTGRES_DB") + + eng_url = URL.create( + "postgresql", + username=DB_USER, + password=DB_PASSWORD, + host=DB_URL, + database=DB_NAME, + ) + return create_engine(eng_url) + + +def get_session(**kargs) -> Session: + eng = get_engine(**kargs) + return sessionmaker(eng) + + +def create_tables(eng): + logging.info(f"Createing DB Tables: 
{eng.url}") + Base.metadata.create_all(eng, checkfirst=True) + + +def pair_query(pairable, groupable, recipe_ids=None, pair_type=String): + pair_func = func.text_pairs + if pair_type == Integer: + pair_func = func.int_pairs + + new_pairs = select( + groupable, + pair_func(func.array_agg(pairable.distinct()), type_=ARRAY(pair_type)).label( + "pair" + ), + ).join(RecipeIngredientParts) + + if not type(recipe_ids) == NoneType: + new_pairs = new_pairs.where(RecipeIngredient.recipe_id.in_(recipe_ids)) + + new_pairs = new_pairs.group_by(groupable).cte() + + return new_pairs + + +def pair_count_query(pairs, countable, recipe_ids=None): + new_counts = select(pairs, func.count(func.distinct(countable))) + + if not type(recipe_ids) == NoneType: + new_counts = new_counts.where( + or_(pairs[0].in_(recipe_ids), pairs[1].in_(recipe_ids)) + ) + + new_counts = new_counts.group_by(pairs) + + return new_counts + + +def update_graph_connectivity(session=None): + # this is pure SQLAlchemy so it is more portable + # This would have been simpler if I utilized Postgres specific feature + if not session: + session = Session(get_engine()) + + with session.begin(): + ids = ( + select(Recipe.id) + .join(RecipeGraphed, isouter=True) + .where(RecipeGraphed.status.is_not(True)) + ) + + num_recipes = session.execute( + select(func.count("*")).select_from(ids.cte()) + ).fetchone()[0] + if num_recipes <= 0: + logging.info("no new recipies") + return + + logging.info(f"adding {num_recipes} recipes to the graphs") + + new_pairs = pair_query( + RecipeIngredientParts.ingredient, RecipeIngredient.recipe_id, recipe_ids=ids + ) + + new_counts = pair_count_query(new_pairs.c.pair, new_pairs.c.recipe_id) + + logging.info("addeing new ingredient connections") + for pair, count in session.execute(new_counts): + connection = ( + session.query(IngredientConnection) + .where( + and_( + IngredientConnection.ingredient_a == pair[0], + IngredientConnection.ingredient_b == pair[1], + ) + ) + .first() + ) + if 
connection: + connection.recipe_count += count + session.merge(connection) + else: + session.add( + IngredientConnection( + ingredient_a=pair[0], ingredient_b=pair[1], recipe_count=count + ) + ) + + # update RecipeConnection + logging.info("adding new recipe connections") + all_pairs = pair_query( + RecipeIngredient.recipe_id, + RecipeIngredientParts.ingredient, + pair_type=Integer, + ) + + new_counts = pair_count_query( + all_pairs.c.pair, all_pairs.c.ingredient, recipe_ids=ids + ) + + i = 0 + for pair, count in session.execute(new_counts): + session.add( + RecipeConnection( + recipe_a=pair[0], recipe_b=pair[1], ingredient_count=count + ) + ) + # flush often to reduce memory usage + i += 1 + if (i % 100000) == 0: + session.flush() + + # update RecipeGraphed.status + logging.info("updating existing RecipeGraphed rows") + for recipeGraphed in session.query(RecipeGraphed).where( + RecipeGraphed.recipe_id.in_(ids) + ): + recipeGraphed.status = True + session.merge(recipeGraphed) + + graphed = select(RecipeGraphed.recipe_id) + + # add recipies that aren't in the table + logging.info("adding new RecipeGraphed rows") + for recipe in session.query(Recipe).where( + and_(Recipe.id.in_(ids), not_(Recipe.id.in_(graphed))) + ): + session.add(RecipeGraphed(recipe_id=recipe.id, status=True)) + + +def main(): # pragma: no cover + eng = get_engine() + create_tables(eng) + + +if __name__ == "__main__": # pragma: no cover + main() diff --git a/src/recipe_graph/insert_sites.py b/src/recipe_graph/insert_sites.py new file mode 100644 index 0000000..812bd76 --- /dev/null +++ b/src/recipe_graph/insert_sites.py @@ -0,0 +1,46 @@ +from pydoc import apropos +from sqlalchemy.orm import sessionmaker +from recipe_graph import db +import json +import argparse +import logging +import sys + + +def load_file(f_name: str): + with open(f_name) as f: + sites = json.load(f) + return sites + + +def setup_argparser(args) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Import 
recipes into database") + parser.add_argument("file", type=str, help="JSON file with recipe site information") + parser.add_argument("-v", "--verbose", action="store_true") + + return parser.parse_args(args) + + +def setup_logging(args: argparse.Namespace) -> logging.Logger: + logger = logging.Logger("insert_sites", logging.WARNING) + if args.verbose: + logger.setLevel(logging.INFO) + logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) + return logger + + +def main(): # pragma: no cover + args = setup_argparser(sys.argv[1:]) + logger = setup_logging(args) + + S = db.get_session() + sites = load_file(args.file) + + with S.begin() as session: + for site in sites: + logger.info(f"Adding {site}") + session.add(db.RecipeSite(**site)) + + +if __name__ == "__main__": # pragma: no cover + main() diff --git a/src/scrape.py b/src/recipe_graph/scrape.py similarity index 69% rename from src/scrape.py rename to src/recipe_graph/scrape.py index 06ca61f..bdc4519 100644 --- a/src/scrape.py +++ b/src/recipe_graph/scrape.py @@ -1,9 +1,7 @@ -from ast import alias -from dis import Instruction -import db +import sys +from recipe_graph import db import re from sqlalchemy import select, desc, exists, not_, except_ -from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker import bs4 from urllib.request import urlopen @@ -129,62 +127,64 @@ def parse_recipe(session, recipe, site): return recipe +def main(): + parser = ArgumentParser(description="Scrape a recipe site for recipies") + parser.add_argument('site', + help='Name of site') + parser.add_argument('-id', '--identifier', dest='id', + help='url of recipe(reletive to base url of site) or commma seperated list') + parser.add_argument('-a', '--auto', action='store', dest='n', + help='automaticaly generate identifier(must supply number of recipies to scrape)') + parser.add_argument('-v', '--verbose', action='store_true') -parser = ArgumentParser(description="Scrape a recipe site for recipies") 
-parser.add_argument('site', - help='Name of site') -parser.add_argument('-id', '--identifier', dest='id', - help='url of recipe(reletive to base url of site) or commma seperated list') -parser.add_argument('-a', '--auto', action='store', dest='n', - help='automaticaly generate identifier(must supply number of recipies to scrape)') -parser.add_argument('-v', '--verbose', action='store_true') + args = parser.parse_args(sys.argv) + if args.verbose: + logging.basicConfig(level=logging.INFO) + logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) -args = parser.parse_args() -if args.verbose: - logging.basicConfig(level=logging.INFO) - logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) + eng = db.get_engine() + S = sessionmaker(eng) -eng = db.get_engine() -S = sessionmaker(eng) - -with S.begin() as sess: - site = sess.query(db.RecipeSite).where(db.RecipeSite.name == args.site).one() - site_id = site.id + with S.begin() as sess: + site = sess.query(db.RecipeSite).where(db.RecipeSite.name == args.site).one() + site_id = site.id + + recipe_ids = [] + starting_id = 0 + if args.id and not args.n: + recipe_ids.append(args.id) + logging.info(f'Retreiving single recipe: {args.id}') + elif args.n: + if not args.id: + last_recipe = sess.query(db.Recipe).\ + where(db.Recipe.recipe_site_id == site.id).\ + order_by(desc(db.Recipe.identifier)).\ + limit(1).\ + scalar() + starting_id = int(last_recipe.identifier) + 1 + else: + starting_id = int(args.id) + recipe_ids = range(starting_id, starting_id+int(args.n)) + logging.info(f'Retreving {args.n} recipes from {site.base_url} starting at {starting_id}') - recipe_ids = [] - starting_id = 0 - if args.id and not args.n: - recipe_ids.append(args.id) - logging.info(f'Retreiving single recipe: {args.id}') - elif args.n: - if not args.id: - last_recipe = sess.query(db.Recipe).\ - where(db.Recipe.recipe_site_id == site.id).\ - order_by(desc(db.Recipe.identifier)).\ - limit(1).\ - scalar() - starting_id = 
int(last_recipe.identifier) + 1 - else: - starting_id = int(args.id) - recipe_ids = range(starting_id, starting_id+int(args.n)) - logging.info(f'Retreving {args.n} recipes from {site.base_url} starting at {starting_id}') - - - - for recipe_id in recipe_ids: - try: - savepoint = sess.begin_nested() + + + for recipe_id in recipe_ids: + try: + savepoint = sess.begin_nested() - recipe = db.Recipe(identifier = recipe_id, recipe_site_id = site.id) - parse_recipe(sess, recipe, site) + recipe = db.Recipe(identifier = recipe_id, recipe_site_id = site.id) + parse_recipe(sess, recipe, site) - savepoint.commit() - except KeyboardInterrupt as e: - savepoint.rollback() - break - except Exception as e: - savepoint.rollback() - logging.error(e) - continue + savepoint.commit() + except KeyboardInterrupt as e: + savepoint.rollback() + break + except Exception as e: + savepoint.rollback() + logging.error(e) + continue - \ No newline at end of file + +if __name__ == "__main__": # pragma: no cover + main() \ No newline at end of file diff --git a/test/test_db.py b/test/test_db.py new file mode 100644 index 0000000..f40c5ce --- /dev/null +++ b/test/test_db.py @@ -0,0 +1,69 @@ +import inspect +from recipe_graph import db +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import sessionmaker +import sqlalchemy + +import pytest + + +@pytest.fixture +def engine() -> sqlalchemy.engine.Engine: + engine = db.get_engine() + # make sure db is empty otherwise might be testing a live db + # this makes sure that we don't drop all on a live db + inspector = sqlalchemy.inspect(engine) + schemas = inspector.get_schema_names() + assert len(list(schemas)) == 2 + assert schemas[0] == "information_schema" + assert schemas[1] == "public" + + db_tables = inspector.get_table_names(schema="public") + assert len(db_tables) == 0 + + yield engine + db.Base.metadata.drop_all(engine) + + +def init_db(engine) -> sqlalchemy.engine.Engine: + db.create_tables(engine) + 
return engine + + +@pytest.fixture +def tables() -> list[db.Base]: + tables = [] + for _, obj in inspect.getmembers(db): + if inspect.isclass(obj) and issubclass(obj, db.Base) and obj != db.Base: + tables.append(obj) + return tables + + +def test_db_connection(engine): + connected = False + try: + engine.connect() + connected = True + except (SQLAlchemyError): + pass + finally: + assert connected == True + +def test_get_session(): + session = db.get_session() + eng = db.get_engine() + session == sessionmaker(eng) + + +def test_db_classes(tables): + assert len(tables) > 0 + + +def test_db_class_creation(tables: list[db.Base], engine: sqlalchemy.engine.Engine): + db.create_tables(engine) + inspector = sqlalchemy.inspect(engine) + db_tables = inspector.get_table_names(schema="public") + assert len(db_tables) == len(tables) + table_names = [table.__tablename__ for table in tables] + assert sorted(db_tables) == sorted(table_names) diff --git a/test/test_insert_sites.py b/test/test_insert_sites.py new file mode 100644 index 0000000..2b01be0 --- /dev/null +++ b/test/test_insert_sites.py @@ -0,0 +1,57 @@ +import json +import os +from recipe_graph import insert_sites +from sqlalchemy import select +import logging + +import pytest + + +@pytest.fixture +def json_data() -> list[dict]: + return [{"key": "value"}, {"test": "value1", "test2": "value2"}] + + +@pytest.fixture +def json_file(json_data: list[dict]) -> str: + f_path = "test.json" + with open(f_path, 'w') as f: + json.dump(json_data, f) + yield f_path + if os.path.exists(f_path): + os.remove(f_path) + +def test_load_file(json_file: str, json_data): + test_data = insert_sites.load_file(json_file) + assert test_data == json_data + + +def test_setup_argparser(): + file_name = "test" + args = insert_sites.setup_argparser([file_name]) + assert len(vars(args)) == 2 + assert args.file == file_name + assert args.verbose == False + + + args = insert_sites.setup_argparser([file_name, "-v"]) + assert args.file == file_name + 
assert args.verbose == True + + args = insert_sites.setup_argparser([file_name, "--verbose"]) + assert args.file == file_name + assert args.verbose == True + +def test_setup_logging(): + args = insert_sites.setup_argparser(["test"]) + logger = insert_sites.setup_logging(args) + assert logger.level == logging.WARNING + + args = insert_sites.setup_argparser(["test", "-v"]) + logger = insert_sites.setup_logging(args) + assert logger.level == logging.INFO + + args = insert_sites.setup_argparser(["test", "--verbose"]) + logger = insert_sites.setup_logging(args) + assert logger.level == logging.INFO + \ No newline at end of file diff --git a/test/test_scrape.py b/test/test_scrape.py new file mode 100644 index 0000000..e7ac014 --- /dev/null +++ b/test/test_scrape.py @@ -0,0 +1,3 @@ +from recipe_graph import scrape + +import pytest \ No newline at end of file