From 126feccc1ca3fc14b8a59505f05a646c3e88a77f Mon Sep 17 00:00:00 2001 From: Tony Wang Date: Sun, 21 Oct 2018 23:29:23 +0800 Subject: [PATCH 1/3] add tests to demonstrate queries using alias - multiple models are in one subquery - aggregation functions are used in queries --- gino/loader.py | 3 +++ tests/test_loader.py | 50 ++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/gino/loader.py b/gino/loader.py index 9cec1114..b565c700 100644 --- a/gino/loader.py +++ b/gino/loader.py @@ -2,6 +2,7 @@ from sqlalchemy import select from sqlalchemy.schema import Column +from sqlalchemy.sql.elements import Label from .declarative import Model @@ -19,6 +20,8 @@ def get(cls, value): rv = AliasLoader(value) elif isinstance(value, Column): rv = ColumnLoader(value) + elif isinstance(value, Label): + rv = ColumnLoader(value.name) elif isinstance(value, tuple): rv = TupleLoader(value) elif callable(value): diff --git a/tests/test_loader.py b/tests/test_loader.py index 3d001ce1..0d17d510 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -1,10 +1,12 @@ import random from datetime import datetime -import pytest from async_generator import yield_, async_generator +import pytest +from sqlalchemy import select +from sqlalchemy.sql.functions import count -from gino.loader import AliasLoader +from gino.loader import AliasLoader, ColumnLoader from .models import db, User, Team, Company pytestmark = pytest.mark.asyncio @@ -12,11 +14,11 @@ @pytest.fixture @async_generator -async def user(bind, random_name): +async def user(bind): c = await Company.create() t1 = await Team.create(company_id=c.id) t2 = await Team.create(company_id=c.id, parent_id=t1.id) - u = await User.create(nickname=random_name, team_id=t2.id) + u = await User.create(team_id=t2.id) u.team = t2 t2.parent = t1 t2.company = c @@ -161,6 +163,46 @@ async def test_alias_loader_columns(user): assert u.id is not None +async def test_multiple_models_in_one_query(bind): + for _ in range(3): + await User.create() + + ua1 = User.alias() + ua2 = User.alias() + join_query = select([ua1, ua2]).where(ua1.id < ua2.id) + result = await join_query.gino.load((ua1.load('id'), ua2.load('id'))).all() + assert len(result) == 3 + for u1, u2 in result: + assert u1.id is not None + assert u2.id is not None + assert u1.id < u2.id + + +async def test_loader_with_aggregation(user): + user_count = select( + [User.team_id, count().label('count')] + ).group_by( + User.team_id + ).alias() + query = Team.outerjoin(user_count).select() + result = await query.gino.load( + (Team.id, Team.name, user_count.columns.team_id, ColumnLoader('count')) + ).all() + assert len(result) == 2 + # team 1 doesn't have users, team 2 has 1 user + # third and forth columns are None for team 1 + for team_id, team_name, user_team_id, user_count in result: + if team_id == user.team_id: + assert team_name == user.team.name + assert user_team_id == user.team_id + assert user_count == 1 + else: + assert team_id is not None + assert team_name is not None + assert user_team_id is None + assert user_count is None + + async def test_adjacency_list_query_builder(user): group = Team.alias() u = await User.load(team=Team.load(parent=group.on( From a282cd2d1f41e292470fb6e096835a160997ef97 Mon Sep 17 00:00:00 2001 From: Tony Wang Date: Tue, 23 Oct 2018 00:13:38 +0800 Subject: [PATCH 2/3] update to use Label to cover changed code in test --- tests/test_loader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_loader.py b/tests/test_loader.py index 0d17d510..b6a492bc 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -179,14 +179,15 @@ async def test_multiple_models_in_one_query(bind): async def test_loader_with_aggregation(user): + count_col = count().label('count') user_count = select( - [User.team_id, count().label('count')] + [User.team_id, count_col] ).group_by( User.team_id ).alias() query = Team.outerjoin(user_count).select() result = await query.gino.load( - (Team.id, Team.name, user_count.columns.team_id, ColumnLoader('count')) + (Team.id, Team.name, user_count.columns.team_id, count_col) ).all() assert len(result) == 2 # team 1 doesn't have users, team 2 has 1 user From be1f37c8c6471791d7e7b13ba81f096df015c023 Mon Sep 17 00:00:00 2001 From: Tony Wang Date: Tue, 23 Oct 2018 11:30:26 +0800 Subject: [PATCH 3/3] remove unused import --- tests/test_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_loader.py b/tests/test_loader.py index b6a492bc..30b6b2ae 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -6,7 +6,7 @@ from sqlalchemy import select from sqlalchemy.sql.functions import count -from gino.loader import AliasLoader, ColumnLoader +from gino.loader import AliasLoader from .models import db, User, Team, Company pytestmark = pytest.mark.asyncio