testing db connection and table creation
This commit is contained in:
parent
082c342256
commit
754cb1235c
|
|
@ -60,10 +60,8 @@ class IngredientConnection(Base):
|
||||||
__tablename__ = 'IngredientConnection'
|
__tablename__ = 'IngredientConnection'
|
||||||
|
|
||||||
ingredient_a = Column(String,
|
ingredient_a = Column(String,
|
||||||
ForeignKey("RecipeIngredientParts.ingredient"),
|
|
||||||
primary_key = True)
|
primary_key = True)
|
||||||
ingredient_b = Column(String,
|
ingredient_b = Column(String,
|
||||||
ForeignKey("RecipeIngredientParts.ingredient"),
|
|
||||||
primary_key = True)
|
primary_key = True)
|
||||||
recipe_count = Column(Integer)
|
recipe_count = Column(Integer)
|
||||||
UniqueConstraint(ingredient_a, ingredient_b)
|
UniqueConstraint(ingredient_a, ingredient_b)
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,29 @@
|
||||||
from sqlite3 import connect
|
from sqlite3 import connect
|
||||||
|
import inspect
|
||||||
from recipe_graph import db
|
from recipe_graph import db
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def engine() -> sqlalchemy.engine.Engine:
|
def engine() -> sqlalchemy.engine.Engine:
|
||||||
return db.get_engine()
|
engine = db.get_engine()
|
||||||
|
yield engine
|
||||||
|
db.Base.metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tables() -> list[db.Base]:
|
||||||
|
tables = []
|
||||||
|
for _, obj in inspect.getmembers(db):
|
||||||
|
if inspect.isclass(obj) and issubclass(obj, db.Base) and obj != db.Base:
|
||||||
|
tables.append(obj)
|
||||||
|
return tables
|
||||||
|
|
||||||
|
|
||||||
def test_db_connection(engine):
|
def test_db_connection(engine):
|
||||||
connected = False
|
connected = False
|
||||||
|
|
@ -17,4 +33,26 @@ def test_db_connection(engine):
|
||||||
except (SQLAlchemyError):
|
except (SQLAlchemyError):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
assert(connected == True)
|
assert connected == True
|
||||||
|
|
||||||
|
|
||||||
|
def test_db_classes(tables):
|
||||||
|
assert len(tables) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_db_class_creation(tables: list[db.Base], engine: sqlalchemy.engine.Engine):
|
||||||
|
inspector = sqlalchemy.inspect(engine)
|
||||||
|
schemas = inspector.get_schema_names()
|
||||||
|
assert len(list(schemas)) == 2
|
||||||
|
assert schemas[0] == "information_schema"
|
||||||
|
assert schemas[1] == "public"
|
||||||
|
|
||||||
|
db_tables = inspector.get_table_names(schema="public")
|
||||||
|
assert len(db_tables) == 0
|
||||||
|
|
||||||
|
db.create_tables(engine)
|
||||||
|
inspector = sqlalchemy.inspect(engine)
|
||||||
|
db_tables = inspector.get_table_names(schema="public")
|
||||||
|
assert len(db_tables) == len(tables)
|
||||||
|
table_names = [table.__tablename__ for table in tables]
|
||||||
|
assert sorted(db_tables) == sorted(table_names)
|
||||||
Loading…
Reference in New Issue