From 2f565f84aca136d9cc4e4d061f3196ddf9358ab8 Mon Sep 17 00:00:00 2001 From: David Wobrock Date: Sat, 28 Dec 2019 22:42:46 +0100 Subject: [PATCH] Fixed #31097 -- Fixed crash of ArrayAgg and StringAgg with filter when used in Subquery. --- AUTHORS | 1 + django/contrib/postgres/aggregates/mixins.py | 11 +---- tests/postgres_tests/test_aggregates.py | 46 +++++++++++++++++++- 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/AUTHORS b/AUTHORS index 5b7d67d06c..4632c66a62 100644 --- a/AUTHORS +++ b/AUTHORS @@ -242,6 +242,7 @@ answer newbie questions, and generally made Django that much better: David Sanders David Schein David Tulig + David Wobrock Davide Ceretti Deepak Thukral Denis Kuzmichyov diff --git a/django/contrib/postgres/aggregates/mixins.py b/django/contrib/postgres/aggregates/mixins.py index 4625738beb..3a43ca1a63 100644 --- a/django/contrib/postgres/aggregates/mixins.py +++ b/django/contrib/postgres/aggregates/mixins.py @@ -40,16 +40,7 @@ class OrderableAggMixin: return super().set_source_expressions(exprs[:self._get_ordering_expressions_index()]) def get_source_expressions(self): - return self.source_expressions + self.ordering - - def get_source_fields(self): - # Filter out fields contributed by the ordering expressions as - # these should not be used to determine which the return type of the - # expression. - return [ - e._output_field_or_none - for e in self.get_source_expressions()[:self._get_ordering_expressions_index()] - ] + return super().get_source_expressions() + self.ordering def _get_ordering_expressions_index(self): """Return the index at which the ordering expressions start.""" diff --git a/tests/postgres_tests/test_aggregates.py b/tests/postgres_tests/test_aggregates.py index 9a042388bd..af84f12e91 100644 --- a/tests/postgres_tests/test_aggregates.py +++ b/tests/postgres_tests/test_aggregates.py @@ -21,7 +21,7 @@ except ImportError: class TestGeneralAggregate(PostgreSQLTestCase): @classmethod def setUpTestData(cls): - AggregateTestModel.objects.create(boolean_field=True, char_field='Foo1', integer_field=0) + cls.agg1 = AggregateTestModel.objects.create(boolean_field=True, char_field='Foo1', integer_field=0) AggregateTestModel.objects.create(boolean_field=False, char_field='Foo2', integer_field=1) AggregateTestModel.objects.create(boolean_field=False, char_field='Foo4', integer_field=2) AggregateTestModel.objects.create(boolean_field=True, char_field='Foo3', integer_field=0) @@ -249,6 +249,50 @@ class TestGeneralAggregate(PostgreSQLTestCase): ).order_by('char_field').values_list('char_field', 'agg') self.assertEqual(list(values), expected_result) + def test_string_agg_array_agg_filter_in_subquery(self): + StatTestModel.objects.bulk_create([ + StatTestModel(related_field=self.agg1, int1=0, int2=5), + StatTestModel(related_field=self.agg1, int1=1, int2=4), + StatTestModel(related_field=self.agg1, int1=2, int2=3), + ]) + for aggregate, expected_result in ( + ( + ArrayAgg('stattestmodel__int1', filter=Q(stattestmodel__int2__gt=3)), + [('Foo1', [0, 1]), ('Foo2', None)], + ), + ( + StringAgg( + Cast('stattestmodel__int2', CharField()), + delimiter=';', + filter=Q(stattestmodel__int1__lt=2), + ), + [('Foo1', '5;4'), ('Foo2', None)], + ), + ): + with self.subTest(aggregate=aggregate.__class__.__name__): + subquery = AggregateTestModel.objects.filter( + pk=OuterRef('pk'), + ).annotate(agg=aggregate).values('agg') + values = AggregateTestModel.objects.annotate( + agg=Subquery(subquery), + ).filter( + char_field__in=['Foo1', 'Foo2'], + ).order_by('char_field').values_list('char_field', 'agg') + self.assertEqual(list(values), expected_result) + + def test_string_agg_filter_in_subquery_with_exclude(self): + subquery = AggregateTestModel.objects.annotate( + stringagg=StringAgg( + 'char_field', + delimiter=';', + filter=Q(char_field__endswith='1'), + ) + ).exclude(stringagg='').values('id') + self.assertSequenceEqual( + AggregateTestModel.objects.filter(id__in=Subquery(subquery)), + [self.agg1], + ) + class TestAggregateDistinct(PostgreSQLTestCase): @classmethod