refactor to facilitate testing

This commit is contained in:
Andrei Stoica 2022-10-15 12:32:58 -04:00
parent 4c96bd8a28
commit a5153e2406
2 changed files with 41 additions and 23 deletions

View File

@ -9,7 +9,7 @@ from sqlalchemy import create_engine, Column, Integer, String, Boolean, \
from sqlalchemy.types import ARRAY from sqlalchemy.types import ARRAY
from sqlalchemy.engine import URL from sqlalchemy.engine import URL
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
Base = declarative_base() Base = declarative_base()
@ -99,6 +99,11 @@ def get_engine(use_dotenv = True, **kargs):
database=DB_NAME) database=DB_NAME)
return create_engine(eng_url) return create_engine(eng_url)
def get_session(**kargs) -> Session:
eng = get_engine(**kargs)
return sessionmaker(eng)
def create_tables(eng): def create_tables(eng):
logging.info(f"Createing DB Tables: {eng.url}") logging.info(f"Createing DB Tables: {eng.url}")
@ -210,8 +215,10 @@ def update_graph_connectivity(session = None):
not_(Recipe.id.in_(graphed)))): not_(Recipe.id.in_(graphed)))):
session.add(RecipeGraphed(recipe_id=recipe.id, status=True)) session.add(RecipeGraphed(recipe_id=recipe.id, status=True))
def main():
eng = get_engine()
create_tables(eng)
if __name__ == "__main__": if __name__ == "__main__":
eng = get_engine() main()
create_tables(eng)

View File

@ -1,29 +1,40 @@
from pydoc import apropos
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
import db from recipe_graph import db
import json import json
import argparse import argparse
import logging import logging
parser = argparse.ArgumentParser(description='Import recipes into database')
parser.add_argument('file', type=str,
help='JSON file with recipe site information')
parser.add_argument('-v', '--verbose', action='store_true')
def load_file(f_name: str):
args = parser.parse_args() with open(f_name) as f:
if args.verbose:
logging.basicConfig(level=logging.INFO)
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
with open(args.file) as f:
sites = json.load(f) sites = json.load(f)
return sites
eng = db.get_engine() def setup_argparser() -> argparse.Namespace:
S = sessionmaker(eng) parser = argparse.ArgumentParser(description="Import recipes into database")
parser.add_argument("file", type=str, help="JSON file with recipe site information")
parser.add_argument("-v", "--verbose", action="store_true")
with S.begin() as session: args = parser.parse_args()
return args
def setup_logging(args: argparse.Namespace):
if args.verbose:
logging.basicConfig(level=logging.INFO)
logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)
def main():
args = setup_argparser()
setup_logging(args)
S = db.get_session()
sites = load_file(args.file)
with S.begin() as session:
for site in sites: for site in sites:
logging.info(f"Adding {site}") logging.info(f"Adding {site}")
session.add(db.RecipeSite(**site)) session.add(db.RecipeSite(**site))
if __name__ == "__main__":
main()