"""Tests for ``recipe_graph.db``: database connectivity and ORM table creation."""

import inspect
from collections.abc import Iterator
from sqlite3 import connect  # NOTE(review): unused in this file; kept in case it is relied on

import pytest
import sqlalchemy
from sqlalchemy import select  # NOTE(review): unused in this file
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session  # NOTE(review): unused in this file

from recipe_graph import db


@pytest.fixture
def engine() -> Iterator[sqlalchemy.engine.Engine]:
    """Yield the engine returned by ``db.get_engine()``.

    Teardown drops every table registered on ``db.Base.metadata`` and then
    disposes of the engine's connection pool so no pooled connection outlives
    the test, even if ``drop_all`` itself fails.
    """
    engine = db.get_engine()
    yield engine
    try:
        db.Base.metadata.drop_all(engine)
    finally:
        engine.dispose()


@pytest.fixture
def tables() -> list[type[db.Base]]:
    """Return every class found in ``db`` that subclasses ``db.Base``.

    ``db.Base`` itself is excluded. Order follows ``inspect.getmembers``,
    which sorts members by name.
    """
    return [
        obj
        for _, obj in inspect.getmembers(db, inspect.isclass)
        if issubclass(obj, db.Base) and obj is not db.Base
    ]


def test_db_connection(engine: sqlalchemy.engine.Engine) -> None:
    """Opening a connection on the engine must not raise a SQLAlchemy error."""
    try:
        # ``with`` closes the connection (returning it to the pool) instead of
        # leaving it checked out for the rest of the session.
        with engine.connect():
            pass
    except SQLAlchemyError as err:
        pytest.fail(f"could not connect to the database: {err}")


def test_db_classes(tables: list[type[db.Base]]) -> None:
    """``db`` must define at least one ORM table class."""
    assert tables


def test_db_class_creation(
    tables: list[type[db.Base]], engine: sqlalchemy.engine.Engine
) -> None:
    """``db.create_tables`` creates exactly one ``public`` table per ORM class.

    Assumes a fresh database whose schema list is exactly
    ``information_schema`` and ``public`` (presumably PostgreSQL -- confirm).
    """
    inspector = sqlalchemy.inspect(engine)

    # Compare sorted so the check does not depend on the order in which the
    # dialect reports schema names.
    schemas = sorted(inspector.get_schema_names())
    assert schemas == ["information_schema", "public"]
    assert inspector.get_table_names(schema="public") == []

    db.create_tables(engine)

    # Use a new inspector: the first one may have cached the pre-creation
    # (empty) table listing.
    inspector = sqlalchemy.inspect(engine)
    db_tables = inspector.get_table_names(schema="public")
    expected = [table.__tablename__ for table in tables]
    assert sorted(db_tables) == sorted(expected)