diff --git a/src/recipe_graph/db.py b/src/recipe_graph/db.py index 7e2265d..87e97a9 100644 --- a/src/recipe_graph/db.py +++ b/src/recipe_graph/db.py @@ -60,10 +60,8 @@ class IngredientConnection(Base): __tablename__ = 'IngredientConnection' ingredient_a = Column(String, - ForeignKey("RecipeIngredientParts.ingredient"), primary_key = True) ingredient_b = Column(String, - ForeignKey("RecipeIngredientParts.ingredient"), primary_key = True) recipe_count = Column(Integer) UniqueConstraint(ingredient_a, ingredient_b) diff --git a/test/test_db.py b/test/test_db.py index 6d6bfdd..bcda081 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -1,13 +1,29 @@ from sqlite3 import connect +import inspect from recipe_graph import db +from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session import sqlalchemy import pytest + @pytest.fixture def engine() -> sqlalchemy.engine.Engine: - return db.get_engine() + engine = db.get_engine() + yield engine + db.Base.metadata.drop_all(engine) + + +@pytest.fixture +def tables() -> list[db.Base]: + tables = [] + for _, obj in inspect.getmembers(db): + if inspect.isclass(obj) and issubclass(obj, db.Base) and obj != db.Base: + tables.append(obj) + return tables + def test_db_connection(engine): connected = False @@ -17,4 +33,26 @@ def test_db_connection(engine): except (SQLAlchemyError): pass finally: - assert(connected == True) + assert connected == True + + +def test_db_classes(tables): + assert len(tables) > 0 + + +def test_db_class_creation(tables: list[db.Base], engine: sqlalchemy.engine.Engine): + inspector = sqlalchemy.inspect(engine) + schemas = inspector.get_schema_names() + assert len(list(schemas)) == 2 + assert schemas[0] == "information_schema" + assert schemas[1] == "public" + + db_tables = inspector.get_table_names(schema="public") + assert len(db_tables) == 0 + + db.create_tables(engine) + inspector = sqlalchemy.inspect(engine) + db_tables = inspector.get_table_names(schema="public") + assert len(db_tables) == len(tables) + table_names = [table.__tablename__ for table in tables] + assert sorted(db_tables) == sorted(table_names) \ No newline at end of file