"""Integration tests for the ``recipe_graph.db`` SQLAlchemy models.

These tests talk to whatever database ``db.get_engine()`` points at. The
``engine`` fixture refuses to run unless that database looks empty, because
its teardown drops every table registered on ``db.Base``.
"""

import inspect
from collections.abc import Iterator

import pytest
import sqlalchemy
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError

from recipe_graph import db


@pytest.fixture
def engine() -> Iterator[sqlalchemy.engine.Engine]:
    """Yield the engine from ``db.get_engine()`` after checking the db is empty.

    The checks run before the ``yield``; if any of them fails, the teardown
    (``drop_all``) is never reached, so a non-empty database is left untouched.

    Yields:
        The engine returned by ``db.get_engine()``.
    """
    engine = db.get_engine()

    # Make sure the db is empty, otherwise we might be testing a live db.
    # This guarantees that we never drop_all on a live db.
    inspector = sqlalchemy.inspect(engine)
    schemas = inspector.get_schema_names()
    # Exactly these two schemas, in this order, are expected.
    # NOTE(review): presumably what a fresh PostgreSQL database reports -- confirm.
    assert list(schemas) == ["information_schema", "public"]
    db_tables = inspector.get_table_names(schema="public")
    assert len(db_tables) == 0

    yield engine

    # Teardown: remove every table registered on db.Base's metadata.
    db.Base.metadata.drop_all(engine)


def init_db(engine) -> sqlalchemy.engine.Engine:
    """Create the tables via ``db.create_tables`` and return the same engine."""
    db.create_tables(engine)
    return engine


@pytest.fixture
def tables() -> list[type[db.Base]]:
    """Return every class in the ``db`` module that subclasses ``db.Base``.

    ``db.Base`` itself is excluded. Order follows ``inspect.getmembers``,
    which sorts members by name.
    """
    return [
        obj
        for _, obj in inspect.getmembers(db)
        if inspect.isclass(obj) and issubclass(obj, db.Base) and obj != db.Base
    ]


def test_db_connection(engine):
    """A connection can be opened on the engine (and is closed afterwards)."""
    try:
        # The context manager closes the connection so it is not leaked
        # into the fixture teardown.
        with engine.connect() as connection:
            assert not connection.closed
    except SQLAlchemyError as exc:
        # Surface the real cause instead of a bare "assert False == True".
        pytest.fail(f"could not connect to the database: {exc}")


def test_db_classes(tables):
    """At least one model class is defined in the ``db`` module."""
    assert len(tables) > 0


def test_db_class_creation(
    tables: list[type[db.Base]], engine: sqlalchemy.engine.Engine
):
    """``db.create_tables`` creates exactly one table per model class."""
    db.create_tables(engine)

    # Use a fresh inspector so the newly created tables are visible.
    inspector = sqlalchemy.inspect(engine)
    db_tables = inspector.get_table_names(schema="public")
    assert len(db_tables) == len(tables)

    table_names = [table.__tablename__ for table in tables]
    assert sorted(db_tables) == sorted(table_names)