"""Tests for the ``recipe_graph.insert_sites`` command-line helper."""

import json
import logging
import os
from collections.abc import Iterator
from typing import Any

import pytest
import sqlalchemy
from sqlalchemy import select

from recipe_graph import insert_sites, db

# `engine` is a pytest fixture: importing it makes it available to the
# fixtures below, even though it is never referenced by name here.
from test_db import engine, init_db  # noqa: F401


@pytest.fixture
def json_data() -> list[dict[str, Any]]:
    """Sample JSON-serialisable payload used by the file-loading test."""
    return [{"key": "value"}, {"test": "value1", "test2": "value2"}]


@pytest.fixture
def db_initialized(engine) -> sqlalchemy.engine.Engine:
    """Run ``init_db`` on the test engine and hand the engine back."""
    init_db(engine)
    return engine


@pytest.fixture
def mock_sites() -> list[dict[str, Any]]:
    """Two recipe-site definitions in the shape ``add_sites`` consumes."""
    return [
        {
            "name": "example-site",
            "ingredient_class": "example-item-name",
            "name_class": "example-content",
            "base_url": "https://www.example.com/recipe/",
        },
        {
            "name": "test-site",
            "ingredient_class": "test-item-name",
            "name_class": "test-content",
            "base_url": "https://www.test.com/recipe/",
        },
    ]


@pytest.fixture
def json_file(json_data: list[dict], tmp_path) -> Iterator[str]:
    """Write ``json_data`` to a temporary JSON file and yield its path.

    The file lives under pytest's per-test ``tmp_path``, so it never
    overwrites a real ``test.json`` in the working directory and parallel
    test runs cannot collide on it.
    """
    f_path = tmp_path / "test.json"
    with open(f_path, "w", encoding="utf-8") as f:
        json.dump(json_data, f)
    yield str(f_path)
    # tmp_path is managed by pytest; this removal just keeps cleanup explicit.
    if os.path.exists(f_path):
        os.remove(f_path)


def test_load_file(json_file: str, json_data: list[dict[str, Any]]) -> None:
    """``load_file`` round-trips the JSON written by the fixture."""
    test_data = insert_sites.load_file(json_file)
    assert test_data == json_data


def test_setup_argparser() -> None:
    """The parser takes a positional file plus an optional -v/--verbose flag."""
    file_name = "test"

    args = insert_sites.setup_argparser([file_name])
    assert len(vars(args)) == 2
    assert args.file == file_name
    assert args.verbose is False

    for flag in ("-v", "--verbose"):
        args = insert_sites.setup_argparser([file_name, flag])
        assert args.file == file_name
        assert args.verbose is True


def test_setup_logging() -> None:
    """Verbose mode raises the logger to INFO; otherwise it stays at WARNING."""
    args = insert_sites.setup_argparser(["test"])
    logger = insert_sites.setup_logging(args)
    assert logger.level == logging.WARNING

    for flag in ("-v", "--verbose"):
        args = insert_sites.setup_argparser(["test", flag])
        logger = insert_sites.setup_logging(args)
        assert logger.level == logging.INFO


def test_add_sites(
    mock_sites: list[dict[str, Any]],
    db_initialized: sqlalchemy.engine.Engine,
) -> None:
    """``add_sites`` persists every site with all of its fields."""
    db_session = db.get_session()
    insert_sites.add_sites(db_session, mock_sites)

    with db_session.begin() as session:
        results = session.execute(select(db.RecipeSite)).all()
        assert len(results) == len(mock_sites)

        # The query has no ORDER BY, so match rows by name, not position.
        # Attribute checks stay inside the session so the ORM objects are
        # still attached when their fields are read.
        sites_by_name = {site.name: site for (site,) in results}
        assert set(sites_by_name) == {s["name"] for s in mock_sites}

        for expected in mock_sites:
            site = sites_by_name[expected["name"]]
            assert site.name == expected["name"]
            assert site.ingredient_class == expected["ingredient_class"]
            assert site.name_class == expected["name_class"]
            assert site.base_url == expected["base_url"]