From a5153e240623cca276f4dd2aefc7b4290445789e Mon Sep 17 00:00:00 2001 From: Andrei Stoica Date: Sat, 15 Oct 2022 12:32:58 -0400 Subject: [PATCH] refactor to facilitate testing --- src/recipe_graph/db.py | 15 +++++++--- src/recipe_graph/insert_sites.py | 49 +++++++++++++++++++------------- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/src/recipe_graph/db.py b/src/recipe_graph/db.py index 87e97a9..07a36c2 100644 --- a/src/recipe_graph/db.py +++ b/src/recipe_graph/db.py @@ -9,7 +9,7 @@ from sqlalchemy import create_engine, Column, Integer, String, Boolean, \ from sqlalchemy.types import ARRAY from sqlalchemy.engine import URL from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker Base = declarative_base() @@ -99,6 +99,11 @@ def get_engine(use_dotenv = True, **kargs): database=DB_NAME) return create_engine(eng_url) +def get_session(**kargs) -> Session: + eng = get_engine(**kargs) + return sessionmaker(eng) + + def create_tables(eng): logging.info(f"Createing DB Tables: {eng.url}") @@ -210,8 +215,10 @@ def update_graph_connectivity(session = None): not_(Recipe.id.in_(graphed)))): session.add(RecipeGraphed(recipe_id=recipe.id, status=True)) - - -if __name__ == "__main__": +def main(): eng = get_engine() create_tables(eng) + + +if __name__ == "__main__": + main() diff --git a/src/recipe_graph/insert_sites.py b/src/recipe_graph/insert_sites.py index a3a7107..886196e 100644 --- a/src/recipe_graph/insert_sites.py +++ b/src/recipe_graph/insert_sites.py @@ -1,29 +1,40 @@ +from pydoc import apropos from sqlalchemy.orm import sessionmaker -import db +from recipe_graph import db import json import argparse import logging -parser = argparse.ArgumentParser(description='Import recipes into database') -parser.add_argument('file', type=str, - help='JSON file with recipe site information') -parser.add_argument('-v', '--verbose', action='store_true') +def load_file(f_name: str): + with open(f_name) as f: + sites = json.load(f) + return sites -args = parser.parse_args() -if args.verbose: - logging.basicConfig(level=logging.INFO) - logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) +def setup_argparser() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Import recipes into database") + parser.add_argument("file", type=str, help="JSON file with recipe site information") + parser.add_argument("-v", "--verbose", action="store_true") -with open(args.file) as f: - sites = json.load(f) + args = parser.parse_args() + return args -eng = db.get_engine() -S = sessionmaker(eng) +def setup_logging(args: argparse.Namespace): + if args.verbose: + logging.basicConfig(level=logging.INFO) + logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) -with S.begin() as session: - for site in sites: - logging.info(f"Adding {site}") - session.add(db.RecipeSite(**site)) - - \ No newline at end of file +def main(): + args = setup_argparser() + setup_logging(args) + + S = db.get_session() + sites = load_file(args.file) + + with S.begin() as session: + for site in sites: + logging.info(f"Adding {site}") + session.add(db.RecipeSite(**site)) + +if __name__ == "__main__": + main() \ No newline at end of file