mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	Fixed #26067 -- Added ordering support to ArrayAgg and StringAgg.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							2a0116266c
						
					
				
				
					commit
					96199e562d
				
			| @@ -1,14 +1,16 @@ | ||||
| from django.contrib.postgres.fields import ArrayField, JSONField | ||||
| from django.db.models.aggregates import Aggregate | ||||
|  | ||||
| from .mixins import OrderableAggMixin | ||||
|  | ||||
| __all__ = [ | ||||
|     'ArrayAgg', 'BitAnd', 'BitOr', 'BoolAnd', 'BoolOr', 'JSONBAgg', 'StringAgg', | ||||
| ] | ||||
|  | ||||
|  | ||||
| class ArrayAgg(Aggregate): | ||||
| class ArrayAgg(OrderableAggMixin, Aggregate): | ||||
|     function = 'ARRAY_AGG' | ||||
|     template = '%(function)s(%(distinct)s%(expressions)s)' | ||||
|     template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)' | ||||
|  | ||||
|     @property | ||||
|     def output_field(self): | ||||
| @@ -49,9 +51,9 @@ class JSONBAgg(Aggregate): | ||||
|         return value | ||||
|  | ||||
|  | ||||
| class StringAgg(Aggregate): | ||||
| class StringAgg(OrderableAggMixin, Aggregate): | ||||
|     function = 'STRING_AGG' | ||||
|     template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s')" | ||||
|     template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)" | ||||
|  | ||||
|     def __init__(self, expression, delimiter, distinct=False, **extra): | ||||
|         distinct = 'DISTINCT ' if distinct else '' | ||||
|   | ||||
							
								
								
									
										47
									
								
								django/contrib/postgres/aggregates/mixins.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								django/contrib/postgres/aggregates/mixins.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | ||||
| from django.db.models.expressions import F, OrderBy | ||||
|  | ||||
|  | ||||
| class OrderableAggMixin: | ||||
|  | ||||
|     def __init__(self, expression, ordering=(), **extra): | ||||
|         if not isinstance(ordering, (list, tuple)): | ||||
|             ordering = [ordering] | ||||
|         ordering = ordering or [] | ||||
|         # Transform minus sign prefixed strings into an OrderBy() expression. | ||||
|         ordering = ( | ||||
|             (OrderBy(F(o[1:]), descending=True) if isinstance(o, str) and o[0] == '-' else o) | ||||
|             for o in ordering | ||||
|         ) | ||||
|         super().__init__(expression, **extra) | ||||
|         self.ordering = self._parse_expressions(*ordering) | ||||
|  | ||||
|     def resolve_expression(self, *args, **kwargs): | ||||
|         self.ordering = [expr.resolve_expression(*args, **kwargs) for expr in self.ordering] | ||||
|         return super().resolve_expression(*args, **kwargs) | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         if self.ordering: | ||||
|             self.extra['ordering'] = 'ORDER BY ' + ', '.join(( | ||||
|                 ordering_element.as_sql(compiler, connection)[0] | ||||
|                 for ordering_element in self.ordering | ||||
|             )) | ||||
|         else: | ||||
|             self.extra['ordering'] = '' | ||||
|         return super().as_sql(compiler, connection) | ||||
|  | ||||
|     def get_source_expressions(self): | ||||
|         return self.source_expressions + self.ordering | ||||
|  | ||||
|     def get_source_fields(self): | ||||
|         # Filter out fields contributed by the ordering expressions as | ||||
|         # these should not be used to determine which the return type of the | ||||
|         # expression. | ||||
|         return [ | ||||
|             e._output_field_or_none | ||||
|             for e in self.get_source_expressions()[:self._get_ordering_expressions_index()] | ||||
|         ] | ||||
|  | ||||
|     def _get_ordering_expressions_index(self): | ||||
|         """Return the index at which the ordering expressions start.""" | ||||
|         source_expressions = self.get_source_expressions() | ||||
|         return len(source_expressions) - len(self.ordering) | ||||
| @@ -22,7 +22,7 @@ General-purpose aggregation functions | ||||
| ``ArrayAgg`` | ||||
| ------------ | ||||
|  | ||||
| .. class:: ArrayAgg(expression, distinct=False, filter=None, **extra) | ||||
| .. class:: ArrayAgg(expression, distinct=False, filter=None, ordering=(), **extra) | ||||
|  | ||||
|     Returns a list of values, including nulls, concatenated into an array. | ||||
|  | ||||
| @@ -31,6 +31,22 @@ General-purpose aggregation functions | ||||
|         An optional boolean argument that determines if array values | ||||
|         will be distinct. Defaults to ``False``. | ||||
|  | ||||
|     .. attribute:: ordering | ||||
|  | ||||
|         .. versionadded:: 2.2 | ||||
|  | ||||
|         An optional string of a field name (with an optional ``"-"`` prefix | ||||
|         which indicates descending order) or an expression (or a tuple or list | ||||
|         of strings and/or expressions) that specifies the ordering of the | ||||
|         elements in the result list. | ||||
|  | ||||
|         Examples:: | ||||
|  | ||||
|             'some_field' | ||||
|             '-some_field' | ||||
|             from django.db.models import F | ||||
|             F('some_field').desc() | ||||
|  | ||||
| ``BitAnd`` | ||||
| ---------- | ||||
|  | ||||
| @@ -73,7 +89,7 @@ General-purpose aggregation functions | ||||
| ``StringAgg`` | ||||
| ------------- | ||||
|  | ||||
| .. class:: StringAgg(expression, delimiter, distinct=False, filter=None) | ||||
| .. class:: StringAgg(expression, delimiter, distinct=False, filter=None, ordering=()) | ||||
|  | ||||
|     Returns the input values concatenated into a string, separated by | ||||
|     the ``delimiter`` string. | ||||
| @@ -87,6 +103,17 @@ General-purpose aggregation functions | ||||
|         An optional boolean argument that determines if concatenated values | ||||
|         will be distinct. Defaults to ``False``. | ||||
|  | ||||
|     .. attribute:: ordering | ||||
|  | ||||
|         .. versionadded:: 2.2 | ||||
|  | ||||
|         An optional string of a field name (with an optional ``"-"`` prefix | ||||
|         which indicates descending order) or an expression (or a tuple or list | ||||
|         of strings and/or expressions) that specifies the ordering of the | ||||
|         elements in the result string. | ||||
|  | ||||
|         Examples are the same as for :attr:`ArrayAgg.ordering`. | ||||
|  | ||||
| Aggregate functions for statistics | ||||
| ================================== | ||||
|  | ||||
|   | ||||
| @@ -70,7 +70,10 @@ Minor features | ||||
| :mod:`django.contrib.postgres` | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| * ... | ||||
| * The new ``ordering`` argument for | ||||
|   :class:`~django.contrib.postgres.aggregates.ArrayAgg` and | ||||
|   :class:`~django.contrib.postgres.aggregates.StringAgg` determines the | ||||
|   ordering of the aggregated elements. | ||||
|  | ||||
| :mod:`django.contrib.redirects` | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|   | ||||
| @@ -22,21 +22,57 @@ class TestGeneralAggregate(PostgreSQLTestCase): | ||||
|     def setUpTestData(cls): | ||||
|         AggregateTestModel.objects.create(boolean_field=True, char_field='Foo1', integer_field=0) | ||||
|         AggregateTestModel.objects.create(boolean_field=False, char_field='Foo2', integer_field=1) | ||||
|         AggregateTestModel.objects.create(boolean_field=False, char_field='Foo3', integer_field=2) | ||||
|         AggregateTestModel.objects.create(boolean_field=True, char_field='Foo4', integer_field=0) | ||||
|         AggregateTestModel.objects.create(boolean_field=False, char_field='Foo4', integer_field=2) | ||||
|         AggregateTestModel.objects.create(boolean_field=True, char_field='Foo3', integer_field=0) | ||||
|  | ||||
|     def test_array_agg_charfield(self): | ||||
|         values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field')) | ||||
|         self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo3', 'Foo4']}) | ||||
|         self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']}) | ||||
|  | ||||
|     def test_array_agg_charfield_ordering(self): | ||||
|         ordering_test_cases = ( | ||||
|             (F('char_field').desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']), | ||||
|             (F('char_field').asc(), ['Foo1', 'Foo2', 'Foo3', 'Foo4']), | ||||
|             (F('char_field'), ['Foo1', 'Foo2', 'Foo3', 'Foo4']), | ||||
|             ([F('boolean_field'), F('char_field').desc()], ['Foo4', 'Foo2', 'Foo3', 'Foo1']), | ||||
|             ((F('boolean_field'), F('char_field').desc()), ['Foo4', 'Foo2', 'Foo3', 'Foo1']), | ||||
|             ('char_field', ['Foo1', 'Foo2', 'Foo3', 'Foo4']), | ||||
|             ('-char_field', ['Foo4', 'Foo3', 'Foo2', 'Foo1']), | ||||
|         ) | ||||
|         for ordering, expected_output in ordering_test_cases: | ||||
|             with self.subTest(ordering=ordering, expected_output=expected_output): | ||||
|                 values = AggregateTestModel.objects.aggregate( | ||||
|                     arrayagg=ArrayAgg('char_field', ordering=ordering) | ||||
|                 ) | ||||
|                 self.assertEqual(values, {'arrayagg': expected_output}) | ||||
|  | ||||
|     def test_array_agg_integerfield(self): | ||||
|         values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('integer_field')) | ||||
|         self.assertEqual(values, {'arrayagg': [0, 1, 2, 0]}) | ||||
|  | ||||
|     def test_array_agg_integerfield_ordering(self): | ||||
|         values = AggregateTestModel.objects.aggregate( | ||||
|             arrayagg=ArrayAgg('integer_field', ordering=F('integer_field').desc()) | ||||
|         ) | ||||
|         self.assertEqual(values, {'arrayagg': [2, 1, 0, 0]}) | ||||
|  | ||||
|     def test_array_agg_booleanfield(self): | ||||
|         values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field')) | ||||
|         self.assertEqual(values, {'arrayagg': [True, False, False, True]}) | ||||
|  | ||||
|     def test_array_agg_booleanfield_ordering(self): | ||||
|         ordering_test_cases = ( | ||||
|             (F('boolean_field').asc(), [False, False, True, True]), | ||||
|             (F('boolean_field').desc(), [True, True, False, False]), | ||||
|             (F('boolean_field'), [False, False, True, True]), | ||||
|         ) | ||||
|         for ordering, expected_output in ordering_test_cases: | ||||
|             with self.subTest(ordering=ordering, expected_output=expected_output): | ||||
|                 values = AggregateTestModel.objects.aggregate( | ||||
|                     arrayagg=ArrayAgg('boolean_field', ordering=ordering) | ||||
|                 ) | ||||
|                 self.assertEqual(values, {'arrayagg': expected_output}) | ||||
|  | ||||
|     def test_array_agg_empty_result(self): | ||||
|         AggregateTestModel.objects.all().delete() | ||||
|         values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field')) | ||||
| @@ -122,17 +158,36 @@ class TestGeneralAggregate(PostgreSQLTestCase): | ||||
|  | ||||
|     def test_string_agg_charfield(self): | ||||
|         values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';')) | ||||
|         self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo3;Foo4'}) | ||||
|         self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo4;Foo3'}) | ||||
|  | ||||
|     def test_string_agg_charfield_ordering(self): | ||||
|         ordering_test_cases = ( | ||||
|             (F('char_field').desc(), 'Foo4;Foo3;Foo2;Foo1'), | ||||
|             (F('char_field').asc(), 'Foo1;Foo2;Foo3;Foo4'), | ||||
|             (F('char_field'), 'Foo1;Foo2;Foo3;Foo4'), | ||||
|         ) | ||||
|         for ordering, expected_output in ordering_test_cases: | ||||
|             with self.subTest(ordering=ordering, expected_output=expected_output): | ||||
|                 values = AggregateTestModel.objects.aggregate( | ||||
|                     stringagg=StringAgg('char_field', delimiter=';', ordering=ordering) | ||||
|                 ) | ||||
|                 self.assertEqual(values, {'stringagg': expected_output}) | ||||
|  | ||||
|     def test_string_agg_empty_result(self): | ||||
|         AggregateTestModel.objects.all().delete() | ||||
|         values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';')) | ||||
|         self.assertEqual(values, {'stringagg': ''}) | ||||
|  | ||||
|     def test_orderable_agg_alternative_fields(self): | ||||
|         values = AggregateTestModel.objects.aggregate( | ||||
|             arrayagg=ArrayAgg('integer_field', ordering=F('char_field').asc()) | ||||
|         ) | ||||
|         self.assertEqual(values, {'arrayagg': [0, 1, 0, 2]}) | ||||
|  | ||||
|     @skipUnlessDBFeature('has_jsonb_agg') | ||||
|     def test_json_agg(self): | ||||
|         values = AggregateTestModel.objects.aggregate(jsonagg=JSONBAgg('char_field')) | ||||
|         self.assertEqual(values, {'jsonagg': ['Foo1', 'Foo2', 'Foo3', 'Foo4']}) | ||||
|         self.assertEqual(values, {'jsonagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']}) | ||||
|  | ||||
|     @skipUnlessDBFeature('has_jsonb_agg') | ||||
|     def test_json_agg_empty(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user