From e8a39da396b3ed8e469f569e4e865a54ae5dad0b Mon Sep 17 00:00:00 2001 From: Nils VAN ZUIJLEN Date: Mon, 6 Feb 2023 22:46:44 +0100 Subject: [PATCH] [4.2.x] Fixed #34285 -- Fixed index/slice lookups on filtered aggregates with ArrayField. Thanks Simon Charette for the review. Backport of ae1fe72e9b1f5fe3b05e5b670bd0c205cd305e71 from main --- django/contrib/postgres/fields/array.py | 8 +++-- tests/postgres_tests/test_aggregates.py | 43 +++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 8477dd9fff..c8e8e132e0 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -325,7 +325,9 @@ class IndexTransform(Transform): def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) - return "%s[%%s]" % lhs, params + [self.index] + if not lhs.endswith("]"): + lhs = "(%s)" % lhs + return "%s[%%s]" % lhs, (*params, self.index) @property def output_field(self): @@ -349,7 +351,9 @@ class SliceTransform(Transform): def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) - return "%s[%%s:%%s]" % lhs, params + [self.start, self.end] + if not lhs.endswith("]"): + lhs = "(%s)" % lhs + return "%s[%%s:%%s]" % lhs, (*params, self.start, self.end) class SliceTransformFactory: diff --git a/tests/postgres_tests/test_aggregates.py b/tests/postgres_tests/test_aggregates.py index b5474d361e..e781fd87ee 100644 --- a/tests/postgres_tests/test_aggregates.py +++ b/tests/postgres_tests/test_aggregates.py @@ -365,6 +365,49 @@ class TestGeneralAggregate(PostgreSQLTestCase): ) self.assertCountEqual(qs.get(), [1, 2]) + def test_array_agg_filter_index(self): + aggr1 = AggregateTestModel.objects.create(integer_field=1) + aggr2 = AggregateTestModel.objects.create(integer_field=2) + StatTestModel.objects.bulk_create( + [ + StatTestModel(related_field=aggr1, int1=1, int2=0), + StatTestModel(related_field=aggr1, int1=2, int2=1), + StatTestModel(related_field=aggr2, int1=3, int2=0), + StatTestModel(related_field=aggr2, int1=4, int2=1), + ] + ) + qs = ( + AggregateTestModel.objects.filter(pk__in=[aggr1.pk, aggr2.pk]) + .annotate( + array=ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2=0)) + ) + .annotate(array_value=F("array__0")) + .values_list("array_value", flat=True) + ) + self.assertCountEqual(qs, [1, 3]) + + def test_array_agg_filter_slice(self): + aggr1 = AggregateTestModel.objects.create(integer_field=1) + aggr2 = AggregateTestModel.objects.create(integer_field=2) + StatTestModel.objects.bulk_create( + [ + StatTestModel(related_field=aggr1, int1=1, int2=0), + StatTestModel(related_field=aggr1, int1=2, int2=1), + StatTestModel(related_field=aggr2, int1=3, int2=0), + StatTestModel(related_field=aggr2, int1=4, int2=1), + StatTestModel(related_field=aggr2, int1=5, int2=0), + ] + ) + qs = ( + AggregateTestModel.objects.filter(pk__in=[aggr1.pk, aggr2.pk]) + .annotate( + array=ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2=0)) + ) + .annotate(array_value=F("array__1_2")) + .values_list("array_value", flat=True) + ) + self.assertCountEqual(qs, [[], [5]]) + def test_bit_and_general(self): values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate( bitand=BitAnd("integer_field")