mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Fixed #24747 -- Allowed transforms in QuerySet.order_by() and distinct(*fields).
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							bf26f66029
						
					
				
				
					commit
					2162f0983d
				
			| @@ -158,7 +158,7 @@ class BaseDatabaseOperations: | |||||||
|         """ |         """ | ||||||
|         return '' |         return '' | ||||||
|  |  | ||||||
|     def distinct_sql(self, fields): |     def distinct_sql(self, fields, params): | ||||||
|         """ |         """ | ||||||
|         Return an SQL DISTINCT clause which removes duplicate rows from the |         Return an SQL DISTINCT clause which removes duplicate rows from the | ||||||
|         result set. If any fields are given, only check the given fields for |         result set. If any fields are given, only check the given fields for | ||||||
| @@ -167,7 +167,7 @@ class BaseDatabaseOperations: | |||||||
|         if fields: |         if fields: | ||||||
|             raise NotSupportedError('DISTINCT ON fields is not supported by this database backend') |             raise NotSupportedError('DISTINCT ON fields is not supported by this database backend') | ||||||
|         else: |         else: | ||||||
|             return 'DISTINCT' |             return ['DISTINCT'], [] | ||||||
|  |  | ||||||
|     def fetch_returned_insert_id(self, cursor): |     def fetch_returned_insert_id(self, cursor): | ||||||
|         """ |         """ | ||||||
|   | |||||||
| @@ -207,11 +207,12 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|         """ |         """ | ||||||
|         return 63 |         return 63 | ||||||
|  |  | ||||||
|     def distinct_sql(self, fields): |     def distinct_sql(self, fields, params): | ||||||
|         if fields: |         if fields: | ||||||
|             return 'DISTINCT ON (%s)' % ', '.join(fields) |             params = [param for param_list in params for param in param_list] | ||||||
|  |             return (['DISTINCT ON (%s)' % ', '.join(fields)], params) | ||||||
|         else: |         else: | ||||||
|             return 'DISTINCT' |             return ['DISTINCT'], [] | ||||||
|  |  | ||||||
|     def last_executed_query(self, cursor, sql, params): |     def last_executed_query(self, cursor, sql, params): | ||||||
|         # http://initd.org/psycopg/docs/cursor.html#cursor.query |         # http://initd.org/psycopg/docs/cursor.html#cursor.query | ||||||
|   | |||||||
| @@ -451,7 +451,7 @@ class SQLCompiler: | |||||||
|                     raise NotSupportedError('{} is not supported on this database backend.'.format(combinator)) |                     raise NotSupportedError('{} is not supported on this database backend.'.format(combinator)) | ||||||
|                 result, params = self.get_combinator_sql(combinator, self.query.combinator_all) |                 result, params = self.get_combinator_sql(combinator, self.query.combinator_all) | ||||||
|             else: |             else: | ||||||
|                 distinct_fields = self.get_distinct() |                 distinct_fields, distinct_params = self.get_distinct() | ||||||
|                 # This must come after 'select', 'ordering', and 'distinct' |                 # This must come after 'select', 'ordering', and 'distinct' | ||||||
|                 # (see docstring of get_from_clause() for details). |                 # (see docstring of get_from_clause() for details). | ||||||
|                 from_, f_params = self.get_from_clause() |                 from_, f_params = self.get_from_clause() | ||||||
| @@ -461,7 +461,12 @@ class SQLCompiler: | |||||||
|                 params = [] |                 params = [] | ||||||
|  |  | ||||||
|                 if self.query.distinct: |                 if self.query.distinct: | ||||||
|                     result.append(self.connection.ops.distinct_sql(distinct_fields)) |                     distinct_result, distinct_params = self.connection.ops.distinct_sql( | ||||||
|  |                         distinct_fields, | ||||||
|  |                         distinct_params, | ||||||
|  |                     ) | ||||||
|  |                     result += distinct_result | ||||||
|  |                     params += distinct_params | ||||||
|  |  | ||||||
|                 out_cols = [] |                 out_cols = [] | ||||||
|                 col_idx = 1 |                 col_idx = 1 | ||||||
| @@ -621,21 +626,22 @@ class SQLCompiler: | |||||||
|         This method can alter the tables in the query, and thus it must be |         This method can alter the tables in the query, and thus it must be | ||||||
|         called before get_from_clause(). |         called before get_from_clause(). | ||||||
|         """ |         """ | ||||||
|         qn = self.quote_name_unless_alias |  | ||||||
|         qn2 = self.connection.ops.quote_name |  | ||||||
|         result = [] |         result = [] | ||||||
|  |         params = [] | ||||||
|         opts = self.query.get_meta() |         opts = self.query.get_meta() | ||||||
|  |  | ||||||
|         for name in self.query.distinct_fields: |         for name in self.query.distinct_fields: | ||||||
|             parts = name.split(LOOKUP_SEP) |             parts = name.split(LOOKUP_SEP) | ||||||
|             _, targets, alias, joins, path, _ = self._setup_joins(parts, opts, None) |             _, targets, alias, joins, path, _, transform_function = self._setup_joins(parts, opts, None) | ||||||
|             targets, alias, _ = self.query.trim_joins(targets, joins, path) |             targets, alias, _ = self.query.trim_joins(targets, joins, path) | ||||||
|             for target in targets: |             for target in targets: | ||||||
|                 if name in self.query.annotation_select: |                 if name in self.query.annotation_select: | ||||||
|                     result.append(name) |                     result.append(name) | ||||||
|                 else: |                 else: | ||||||
|                     result.append("%s.%s" % (qn(alias), qn2(target.column))) |                     r, p = self.compile(transform_function(target, alias)) | ||||||
|         return result |                     result.append(r) | ||||||
|  |                     params.append(p) | ||||||
|  |         return result, params | ||||||
|  |  | ||||||
|     def find_ordering_name(self, name, opts, alias=None, default_order='ASC', |     def find_ordering_name(self, name, opts, alias=None, default_order='ASC', | ||||||
|                            already_seen=None): |                            already_seen=None): | ||||||
| @@ -647,7 +653,7 @@ class SQLCompiler: | |||||||
|         name, order = get_order_dir(name, default_order) |         name, order = get_order_dir(name, default_order) | ||||||
|         descending = order == 'DESC' |         descending = order == 'DESC' | ||||||
|         pieces = name.split(LOOKUP_SEP) |         pieces = name.split(LOOKUP_SEP) | ||||||
|         field, targets, alias, joins, path, opts = self._setup_joins(pieces, opts, alias) |         field, targets, alias, joins, path, opts, transform_function = self._setup_joins(pieces, opts, alias) | ||||||
|  |  | ||||||
|         # If we get to this point and the field is a relation to another model, |         # If we get to this point and the field is a relation to another model, | ||||||
|         # append the default ordering for that model unless the attribute name |         # append the default ordering for that model unless the attribute name | ||||||
| @@ -666,7 +672,7 @@ class SQLCompiler: | |||||||
|                                                        order, already_seen)) |                                                        order, already_seen)) | ||||||
|             return results |             return results | ||||||
|         targets, alias, _ = self.query.trim_joins(targets, joins, path) |         targets, alias, _ = self.query.trim_joins(targets, joins, path) | ||||||
|         return [(OrderBy(t.get_col(alias), descending=descending), False) for t in targets] |         return [(OrderBy(transform_function(t, alias), descending=descending), False) for t in targets] | ||||||
|  |  | ||||||
|     def _setup_joins(self, pieces, opts, alias): |     def _setup_joins(self, pieces, opts, alias): | ||||||
|         """ |         """ | ||||||
| @@ -677,10 +683,9 @@ class SQLCompiler: | |||||||
|         match. Executing SQL where this is not true is an error. |         match. Executing SQL where this is not true is an error. | ||||||
|         """ |         """ | ||||||
|         alias = alias or self.query.get_initial_alias() |         alias = alias or self.query.get_initial_alias() | ||||||
|         field, targets, opts, joins, path = self.query.setup_joins( |         field, targets, opts, joins, path, transform_function = self.query.setup_joins(pieces, opts, alias) | ||||||
|             pieces, opts, alias) |  | ||||||
|         alias = joins[-1] |         alias = joins[-1] | ||||||
|         return field, targets, alias, joins, path, opts |         return field, targets, alias, joins, path, opts, transform_function | ||||||
|  |  | ||||||
|     def get_from_clause(self): |     def get_from_clause(self): | ||||||
|         """ |         """ | ||||||
| @@ -786,7 +791,7 @@ class SQLCompiler: | |||||||
|             } |             } | ||||||
|             related_klass_infos.append(klass_info) |             related_klass_infos.append(klass_info) | ||||||
|             select_fields = [] |             select_fields = [] | ||||||
|             _, _, _, joins, _ = self.query.setup_joins( |             _, _, _, joins, _, _ = self.query.setup_joins( | ||||||
|                 [f.name], opts, root_alias) |                 [f.name], opts, root_alias) | ||||||
|             alias = joins[-1] |             alias = joins[-1] | ||||||
|             columns = self.get_default_columns(start_alias=alias, opts=f.remote_field.model._meta) |             columns = self.get_default_columns(start_alias=alias, opts=f.remote_field.model._meta) | ||||||
| @@ -843,7 +848,7 @@ class SQLCompiler: | |||||||
|                     break |                     break | ||||||
|                 if name in self.query._filtered_relations: |                 if name in self.query._filtered_relations: | ||||||
|                     fields_found.add(name) |                     fields_found.add(name) | ||||||
|                     f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias) |                     f, _, join_opts, joins, _, _ = self.query.setup_joins([name], opts, root_alias) | ||||||
|                     model = join_opts.model |                     model = join_opts.model | ||||||
|                     alias = joins[-1] |                     alias = joins[-1] | ||||||
|                     from_parent = issubclass(model, opts.model) and model is not opts.model |                     from_parent = issubclass(model, opts.model) and model is not opts.model | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ themselves do not have to (and could be backed by things other than SQL | |||||||
| databases). The abstraction barrier only works one way: this module has to know | databases). The abstraction barrier only works one way: this module has to know | ||||||
| all about the internals of models in order to get the information it needs. | all about the internals of models in order to get the information it needs. | ||||||
| """ | """ | ||||||
|  | import functools | ||||||
| from collections import Counter, OrderedDict, namedtuple | from collections import Counter, OrderedDict, namedtuple | ||||||
| from collections.abc import Iterator, Mapping | from collections.abc import Iterator, Mapping | ||||||
| from itertools import chain, count, product | from itertools import chain, count, product | ||||||
| @@ -18,6 +19,7 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections | |||||||
| from django.db.models.aggregates import Count | from django.db.models.aggregates import Count | ||||||
| from django.db.models.constants import LOOKUP_SEP | from django.db.models.constants import LOOKUP_SEP | ||||||
| from django.db.models.expressions import Col, Ref | from django.db.models.expressions import Col, Ref | ||||||
|  | from django.db.models.fields import Field | ||||||
| from django.db.models.fields.related_lookups import MultiColSource | from django.db.models.fields.related_lookups import MultiColSource | ||||||
| from django.db.models.lookups import Lookup | from django.db.models.lookups import Lookup | ||||||
| from django.db.models.query_utils import ( | from django.db.models.query_utils import ( | ||||||
| @@ -56,7 +58,7 @@ def get_children_from_q(q): | |||||||
|  |  | ||||||
| JoinInfo = namedtuple( | JoinInfo = namedtuple( | ||||||
|     'JoinInfo', |     'JoinInfo', | ||||||
|     ('final_field', 'targets', 'opts', 'joins', 'path') |     ('final_field', 'targets', 'opts', 'joins', 'path', 'transform_function') | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -1429,8 +1431,11 @@ class Query: | |||||||
|         generate a MultiJoin exception. |         generate a MultiJoin exception. | ||||||
|  |  | ||||||
|         Return the final field involved in the joins, the target field (used |         Return the final field involved in the joins, the target field (used | ||||||
|         for any 'where' constraint), the final 'opts' value, the joins and the |         for any 'where' constraint), the final 'opts' value, the joins, the | ||||||
|         field path travelled to generate the joins. |         field path traveled to generate the joins, and a transform function | ||||||
|  |         that takes a field and alias and is equivalent to `field.get_col(alias)` | ||||||
|  |         in the simple case but wraps field transforms if they were included in | ||||||
|  |         names. | ||||||
|  |  | ||||||
|         The target field is the field containing the concrete value. Final |         The target field is the field containing the concrete value. Final | ||||||
|         field can be something different, for example foreign key pointing to |         field can be something different, for example foreign key pointing to | ||||||
| @@ -1439,10 +1444,46 @@ class Query: | |||||||
|         key field for example). |         key field for example). | ||||||
|         """ |         """ | ||||||
|         joins = [alias] |         joins = [alias] | ||||||
|         # First, generate the path for the names |         # The transform can't be applied yet, as joins must be trimmed later. | ||||||
|         path, final_field, targets, rest = self.names_to_path( |         # To avoid making every caller of this method look up transforms | ||||||
|             names, opts, allow_many, fail_on_missing=True) |         # directly, compute transforms here and and create a partial that | ||||||
|  |         # converts fields to the appropriate wrapped version. | ||||||
|  |  | ||||||
|  |         def final_transformer(field, alias): | ||||||
|  |             return field.get_col(alias) | ||||||
|  |  | ||||||
|  |         # Try resolving all the names as fields first. If there's an error, | ||||||
|  |         # treat trailing names as lookups until a field can be resolved. | ||||||
|  |         last_field_exception = None | ||||||
|  |         for pivot in range(len(names), 0, -1): | ||||||
|  |             try: | ||||||
|  |                 path, final_field, targets, rest = self.names_to_path( | ||||||
|  |                     names[:pivot], opts, allow_many, fail_on_missing=True, | ||||||
|  |                 ) | ||||||
|  |             except FieldError as exc: | ||||||
|  |                 if pivot == 1: | ||||||
|  |                     # The first item cannot be a lookup, so it's safe | ||||||
|  |                     # to raise the field error here. | ||||||
|  |                     raise | ||||||
|  |                 else: | ||||||
|  |                     last_field_exception = exc | ||||||
|  |             else: | ||||||
|  |                 # The transforms are the remaining items that couldn't be | ||||||
|  |                 # resolved into fields. | ||||||
|  |                 transforms = names[pivot:] | ||||||
|  |                 break | ||||||
|  |         for name in transforms: | ||||||
|  |             def transform(field, alias, *, name, previous): | ||||||
|  |                 try: | ||||||
|  |                     wrapped = previous(field, alias) | ||||||
|  |                     return self.try_transform(wrapped, name) | ||||||
|  |                 except FieldError: | ||||||
|  |                     # FieldError is raised if the transform doesn't exist. | ||||||
|  |                     if isinstance(final_field, Field) and last_field_exception: | ||||||
|  |                         raise last_field_exception | ||||||
|  |                     else: | ||||||
|  |                         raise | ||||||
|  |             final_transformer = functools.partial(transform, name=name, previous=final_transformer) | ||||||
|         # Then, add the path to the query's joins. Note that we can't trim |         # Then, add the path to the query's joins. Note that we can't trim | ||||||
|         # joins at this stage - we will need the information about join type |         # joins at this stage - we will need the information about join type | ||||||
|         # of the trimmed joins. |         # of the trimmed joins. | ||||||
| @@ -1470,7 +1511,7 @@ class Query: | |||||||
|             joins.append(alias) |             joins.append(alias) | ||||||
|             if filtered_relation: |             if filtered_relation: | ||||||
|                 filtered_relation.path = joins[:] |                 filtered_relation.path = joins[:] | ||||||
|         return JoinInfo(final_field, targets, opts, joins, path) |         return JoinInfo(final_field, targets, opts, joins, path, final_transformer) | ||||||
|  |  | ||||||
|     def trim_joins(self, targets, joins, path): |     def trim_joins(self, targets, joins, path): | ||||||
|         """ |         """ | ||||||
| @@ -1683,7 +1724,7 @@ class Query: | |||||||
|                     join_info.path, |                     join_info.path, | ||||||
|                 ) |                 ) | ||||||
|                 for target in targets: |                 for target in targets: | ||||||
|                     cols.append(target.get_col(final_alias)) |                     cols.append(join_info.transform_function(target, final_alias)) | ||||||
|             if cols: |             if cols: | ||||||
|                 self.set_select(cols) |                 self.set_select(cols) | ||||||
|         except MultiJoin: |         except MultiJoin: | ||||||
|   | |||||||
| @@ -138,6 +138,21 @@ SQL:: | |||||||
| Note that in case there is no other lookup specified, Django interprets | Note that in case there is no other lookup specified, Django interprets | ||||||
| ``change__abs=27`` as ``change__abs__exact=27``. | ``change__abs=27`` as ``change__abs__exact=27``. | ||||||
|  |  | ||||||
|  | This also allows the result to be used in ``ORDER BY`` and ``DISTINCT ON`` | ||||||
|  | clauses. For example ``Experiment.objects.order_by('change__abs')`` generates:: | ||||||
|  |  | ||||||
|  |     SELECT ... ORDER BY ABS("experiments"."change") ASC | ||||||
|  |  | ||||||
|  | And on databases that support distinct on fields (such as PostgreSQL), | ||||||
|  | ``Experiment.objects.distinct('change__abs')`` generates:: | ||||||
|  |  | ||||||
|  |     SELECT ... DISTINCT ON ABS("experiments"."change") | ||||||
|  |  | ||||||
|  | .. versionchanged:: 2.1 | ||||||
|  |  | ||||||
|  |     Ordering and distinct support as described in the last two paragraphs was | ||||||
|  |     added. | ||||||
|  |  | ||||||
| When looking for which lookups are allowable after the ``Transform`` has been | When looking for which lookups are allowable after the ``Transform`` has been | ||||||
| applied, Django uses the ``output_field`` attribute. We didn't need to specify | applied, Django uses the ``output_field`` attribute. We didn't need to specify | ||||||
| this here as it didn't change, but supposing we were applying ``AbsoluteValue`` | this here as it didn't change, but supposing we were applying ``AbsoluteValue`` | ||||||
|   | |||||||
| @@ -64,10 +64,14 @@ Some examples | |||||||
|     # Aggregates can contain complex computations also |     # Aggregates can contain complex computations also | ||||||
|     Company.objects.annotate(num_offerings=Count(F('products') + F('services'))) |     Company.objects.annotate(num_offerings=Count(F('products') + F('services'))) | ||||||
|  |  | ||||||
|     # Expressions can also be used in order_by() |     # Expressions can also be used in order_by(), either directly | ||||||
|     Company.objects.order_by(Length('name').asc()) |     Company.objects.order_by(Length('name').asc()) | ||||||
|     Company.objects.order_by(Length('name').desc()) |     Company.objects.order_by(Length('name').desc()) | ||||||
|  |     # or using the double underscore lookup syntax. | ||||||
|  |     from django.db.models import CharField | ||||||
|  |     from django.db.models.functions import Length | ||||||
|  |     CharField.register_lookup(Length) | ||||||
|  |     Company.objects.order_by('name__length') | ||||||
|  |  | ||||||
| Built-in Expressions | Built-in Expressions | ||||||
| ==================== | ==================== | ||||||
|   | |||||||
| @@ -535,6 +535,19 @@ The ``values()`` method also takes optional keyword arguments, | |||||||
|     >>> Blog.objects.values(lower_name=Lower('name')) |     >>> Blog.objects.values(lower_name=Lower('name')) | ||||||
|     <QuerySet [{'lower_name': 'beatles blog'}]> |     <QuerySet [{'lower_name': 'beatles blog'}]> | ||||||
|  |  | ||||||
|  | You can use built-in and :doc:`custom lookups </howto/custom-lookups>` in | ||||||
|  | ordering. For example:: | ||||||
|  |  | ||||||
|  |     >>> from django.db.models import CharField | ||||||
|  |     >>> from django.db.models.functions import Lower | ||||||
|  |     >>> CharField.register_lookup(Lower, 'lower') | ||||||
|  |     >>> Blog.objects.values('name__lower') | ||||||
|  |     <QuerySet [{'name__lower': 'beatles blog'}]> | ||||||
|  |  | ||||||
|  | .. versionchanged:: 2.1 | ||||||
|  |  | ||||||
|  |     Support for lookups was added. | ||||||
|  |  | ||||||
| An aggregate within a ``values()`` clause is applied before other arguments | An aggregate within a ``values()`` clause is applied before other arguments | ||||||
| within the same ``values()`` clause. If you need to group by another value, | within the same ``values()`` clause. If you need to group by another value, | ||||||
| add it to an earlier ``values()`` clause instead. For example:: | add it to an earlier ``values()`` clause instead. For example:: | ||||||
| @@ -580,6 +593,25 @@ A few subtleties that are worth mentioning: | |||||||
| * Calling :meth:`only()` and :meth:`defer()` after ``values()`` doesn't make | * Calling :meth:`only()` and :meth:`defer()` after ``values()`` doesn't make | ||||||
|   sense, so doing so will raise a ``NotImplementedError``. |   sense, so doing so will raise a ``NotImplementedError``. | ||||||
|  |  | ||||||
|  | * Combining transforms and aggregates requires the use of two :meth:`annotate` | ||||||
|  |   calls, either explicitly or as keyword arguments to :meth:`values`. As above, | ||||||
|  |   if the transform has been registered on the relevant field type the first | ||||||
|  |   :meth:`annotate` can be omitted, thus the following examples are equivalent:: | ||||||
|  |  | ||||||
|  |     >>> from django.db.models import CharField, Count | ||||||
|  |     >>> from django.db.models.functions import Lower | ||||||
|  |     >>> CharField.register_lookup(Lower, 'lower') | ||||||
|  |     >>> Blog.objects.values('entry__authors__name__lower').annotate(entries=Count('entry')) | ||||||
|  |     <QuerySet [{'entry__authors__name__lower': 'test author', 'entries': 33}]> | ||||||
|  |     >>> Blog.objects.values( | ||||||
|  |     ...     entry__authors__name__lower=Lower('entry__authors__name') | ||||||
|  |     ... ).annotate(entries=Count('entry')) | ||||||
|  |     <QuerySet [{'entry__authors__name__lower': 'test author', 'entries': 33}]> | ||||||
|  |     >>> Blog.objects.annotate( | ||||||
|  |     ...     entry__authors__name__lower=Lower('entry__authors__name') | ||||||
|  |     ... ).values('entry__authors__name__lower').annotate(entries=Count('entry')) | ||||||
|  |     <QuerySet [{'entry__authors__name__lower': 'test author', 'entries': 33}]> | ||||||
|  |  | ||||||
| It is useful when you know you're only going to need values from a small number | It is useful when you know you're only going to need values from a small number | ||||||
| of the available fields and you won't need the functionality of a model | of the available fields and you won't need the functionality of a model | ||||||
| instance object. It's more efficient to select only the fields you need to use. | instance object. It's more efficient to select only the fields you need to use. | ||||||
|   | |||||||
| @@ -187,6 +187,9 @@ Models | |||||||
|  |  | ||||||
| * Query expressions can now be negated using a minus sign. | * Query expressions can now be negated using a minus sign. | ||||||
|  |  | ||||||
|  | * :meth:`.QuerySet.order_by` and :meth:`distinct(*fields) <.QuerySet.distinct>` | ||||||
|  |   now support using field transforms. | ||||||
|  |  | ||||||
| Requests and Responses | Requests and Responses | ||||||
| ~~~~~~~~~~~~~~~~~~~~~~ | ~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  |  | ||||||
| @@ -242,6 +245,9 @@ Database backend API | |||||||
| * Renamed the ``allow_sliced_subqueries`` database feature flag to | * Renamed the ``allow_sliced_subqueries`` database feature flag to | ||||||
|   ``allow_sliced_subqueries_with_in``. |   ``allow_sliced_subqueries_with_in``. | ||||||
|  |  | ||||||
|  | * ``DatabaseOperations.distinct_sql()`` now requires an additional ``params`` | ||||||
|  |   argument and returns a tuple of SQL and parameters instead of a SQL string. | ||||||
|  |  | ||||||
| :mod:`django.contrib.gis` | :mod:`django.contrib.gis` | ||||||
| ------------------------- | ------------------------- | ||||||
|  |  | ||||||
|   | |||||||
| @@ -17,7 +17,7 @@ class DatabaseOperationTests(SimpleTestCase): | |||||||
|     def test_distinct_on_fields(self): |     def test_distinct_on_fields(self): | ||||||
|         msg = 'DISTINCT ON fields is not supported by this database backend' |         msg = 'DISTINCT ON fields is not supported by this database backend' | ||||||
|         with self.assertRaisesMessage(NotSupportedError, msg): |         with self.assertRaisesMessage(NotSupportedError, msg): | ||||||
|             self.ops.distinct_sql(['a', 'b']) |             self.ops.distinct_sql(['a', 'b'], None) | ||||||
|  |  | ||||||
|     def test_deferrable_sql(self): |     def test_deferrable_sql(self): | ||||||
|         self.assertEqual(self.ops.deferrable_sql(), '') |         self.assertEqual(self.ops.deferrable_sql(), '') | ||||||
|   | |||||||
| @@ -63,6 +63,14 @@ class Mult3BilateralTransform(models.Transform): | |||||||
|         return '3 * (%s)' % lhs, lhs_params |         return '3 * (%s)' % lhs, lhs_params | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LastDigitTransform(models.Transform): | ||||||
|  |     lookup_name = 'lastdigit' | ||||||
|  |  | ||||||
|  |     def as_sql(self, compiler, connection): | ||||||
|  |         lhs, lhs_params = compiler.compile(self.lhs) | ||||||
|  |         return 'SUBSTR(CAST(%s AS CHAR(2)), 2, 1)' % lhs, lhs_params | ||||||
|  |  | ||||||
|  |  | ||||||
| class UpperBilateralTransform(models.Transform): | class UpperBilateralTransform(models.Transform): | ||||||
|     bilateral = True |     bilateral = True | ||||||
|     lookup_name = 'upper' |     lookup_name = 'upper' | ||||||
| @@ -379,6 +387,15 @@ class BilateralTransformTests(TestCase): | |||||||
|             self.assertSequenceEqual(baseqs.filter(age__mult3__div3=42), [a1, a2, a3, a4]) |             self.assertSequenceEqual(baseqs.filter(age__mult3__div3=42), [a1, a2, a3, a4]) | ||||||
|             self.assertSequenceEqual(baseqs.filter(age__div3__mult3=42), [a3]) |             self.assertSequenceEqual(baseqs.filter(age__div3__mult3=42), [a3]) | ||||||
|  |  | ||||||
|  |     def test_transform_order_by(self): | ||||||
|  |         with register_lookup(models.IntegerField, LastDigitTransform): | ||||||
|  |             a1 = Author.objects.create(name='a1', age=11) | ||||||
|  |             a2 = Author.objects.create(name='a2', age=23) | ||||||
|  |             a3 = Author.objects.create(name='a3', age=32) | ||||||
|  |             a4 = Author.objects.create(name='a4', age=40) | ||||||
|  |             qs = Author.objects.order_by('age__lastdigit') | ||||||
|  |             self.assertSequenceEqual(qs, [a4, a1, a3, a2]) | ||||||
|  |  | ||||||
|     def test_bilateral_fexpr(self): |     def test_bilateral_fexpr(self): | ||||||
|         with register_lookup(models.IntegerField, Mult3BilateralTransform): |         with register_lookup(models.IntegerField, Mult3BilateralTransform): | ||||||
|             a1 = Author.objects.create(name='a1', age=1, average_rating=3.2) |             a1 = Author.objects.create(name='a1', age=1, average_rating=3.2) | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| from django.db.models import Max | from django.db.models import CharField, Max | ||||||
|  | from django.db.models.functions import Lower | ||||||
| from django.test import TestCase, skipUnlessDBFeature | from django.test import TestCase, skipUnlessDBFeature | ||||||
|  |  | ||||||
| from .models import Celebrity, Fan, Staff, StaffTag, Tag | from .models import Celebrity, Fan, Staff, StaffTag, Tag | ||||||
| @@ -8,19 +9,19 @@ from .models import Celebrity, Fan, Staff, StaffTag, Tag | |||||||
| @skipUnlessDBFeature('supports_nullable_unique_constraints') | @skipUnlessDBFeature('supports_nullable_unique_constraints') | ||||||
| class DistinctOnTests(TestCase): | class DistinctOnTests(TestCase): | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|         t1 = Tag.objects.create(name='t1') |         self.t1 = Tag.objects.create(name='t1') | ||||||
|         Tag.objects.create(name='t2', parent=t1) |         self.t2 = Tag.objects.create(name='t2', parent=self.t1) | ||||||
|         t3 = Tag.objects.create(name='t3', parent=t1) |         self.t3 = Tag.objects.create(name='t3', parent=self.t1) | ||||||
|         Tag.objects.create(name='t4', parent=t3) |         self.t4 = Tag.objects.create(name='t4', parent=self.t3) | ||||||
|         Tag.objects.create(name='t5', parent=t3) |         self.t5 = Tag.objects.create(name='t5', parent=self.t3) | ||||||
|  |  | ||||||
|         self.p1_o1 = Staff.objects.create(id=1, name="p1", organisation="o1") |         self.p1_o1 = Staff.objects.create(id=1, name="p1", organisation="o1") | ||||||
|         self.p2_o1 = Staff.objects.create(id=2, name="p2", organisation="o1") |         self.p2_o1 = Staff.objects.create(id=2, name="p2", organisation="o1") | ||||||
|         self.p3_o1 = Staff.objects.create(id=3, name="p3", organisation="o1") |         self.p3_o1 = Staff.objects.create(id=3, name="p3", organisation="o1") | ||||||
|         self.p1_o2 = Staff.objects.create(id=4, name="p1", organisation="o2") |         self.p1_o2 = Staff.objects.create(id=4, name="p1", organisation="o2") | ||||||
|         self.p1_o1.coworkers.add(self.p2_o1, self.p3_o1) |         self.p1_o1.coworkers.add(self.p2_o1, self.p3_o1) | ||||||
|         StaffTag.objects.create(staff=self.p1_o1, tag=t1) |         StaffTag.objects.create(staff=self.p1_o1, tag=self.t1) | ||||||
|         StaffTag.objects.create(staff=self.p1_o1, tag=t1) |         StaffTag.objects.create(staff=self.p1_o1, tag=self.t1) | ||||||
|  |  | ||||||
|         celeb1 = Celebrity.objects.create(name="c1") |         celeb1 = Celebrity.objects.create(name="c1") | ||||||
|         celeb2 = Celebrity.objects.create(name="c2") |         celeb2 = Celebrity.objects.create(name="c2") | ||||||
| @@ -95,6 +96,19 @@ class DistinctOnTests(TestCase): | |||||||
|         c2 = c1.distinct('pk') |         c2 = c1.distinct('pk') | ||||||
|         self.assertNotIn('OUTER JOIN', str(c2.query)) |         self.assertNotIn('OUTER JOIN', str(c2.query)) | ||||||
|  |  | ||||||
|  |     def test_transform(self): | ||||||
|  |         new_name = self.t1.name.upper() | ||||||
|  |         self.assertNotEqual(self.t1.name, new_name) | ||||||
|  |         Tag.objects.create(name=new_name) | ||||||
|  |         CharField.register_lookup(Lower) | ||||||
|  |         try: | ||||||
|  |             self.assertCountEqual( | ||||||
|  |                 Tag.objects.order_by().distinct('name__lower'), | ||||||
|  |                 [self.t1, self.t2, self.t3, self.t4, self.t5], | ||||||
|  |             ) | ||||||
|  |         finally: | ||||||
|  |             CharField._unregister_lookup(Lower) | ||||||
|  |  | ||||||
|     def test_distinct_not_implemented_checks(self): |     def test_distinct_not_implemented_checks(self): | ||||||
|         # distinct + annotate not allowed |         # distinct + annotate not allowed | ||||||
|         msg = 'annotate() + distinct(fields) is not implemented.' |         msg = 'annotate() + distinct(fields) is not implemented.' | ||||||
|   | |||||||
| @@ -1363,6 +1363,40 @@ class ValueTests(TestCase): | |||||||
|             ExpressionList() |             ExpressionList() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FieldTransformTests(TestCase): | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def setUpTestData(cls): | ||||||
|  |         cls.sday = sday = datetime.date(2010, 6, 25) | ||||||
|  |         cls.stime = stime = datetime.datetime(2010, 6, 25, 12, 15, 30, 747000) | ||||||
|  |         cls.ex1 = Experiment.objects.create( | ||||||
|  |             name='Experiment 1', | ||||||
|  |             assigned=sday, | ||||||
|  |             completed=sday + datetime.timedelta(2), | ||||||
|  |             estimated_time=datetime.timedelta(2), | ||||||
|  |             start=stime, | ||||||
|  |             end=stime + datetime.timedelta(2), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_month_aggregation(self): | ||||||
|  |         self.assertEqual( | ||||||
|  |             Experiment.objects.aggregate(month_count=Count('assigned__month')), | ||||||
|  |             {'month_count': 1} | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_transform_in_values(self): | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             Experiment.objects.values('assigned__month'), | ||||||
|  |             ["{'assigned__month': 6}"] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_multiple_transforms_in_values(self): | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             Experiment.objects.values('end__date__month'), | ||||||
|  |             ["{'end__date__month': 6}"] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ReprTests(TestCase): | class ReprTests(TestCase): | ||||||
|  |  | ||||||
|     def test_expressions(self): |     def test_expressions(self): | ||||||
|   | |||||||
| @@ -309,6 +309,22 @@ class TestQuerying(PostgreSQLTestCase): | |||||||
|             self.objs[2:3] |             self.objs[2:3] | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_order_by_slice(self): | ||||||
|  |         more_objs = ( | ||||||
|  |             NullableIntegerArrayModel.objects.create(field=[1, 637]), | ||||||
|  |             NullableIntegerArrayModel.objects.create(field=[2, 1]), | ||||||
|  |             NullableIntegerArrayModel.objects.create(field=[3, -98123]), | ||||||
|  |             NullableIntegerArrayModel.objects.create(field=[4, 2]), | ||||||
|  |         ) | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             NullableIntegerArrayModel.objects.order_by('field__1'), | ||||||
|  |             [ | ||||||
|  |                 more_objs[2], more_objs[1], more_objs[3], self.objs[2], | ||||||
|  |                 self.objs[3], more_objs[0], self.objs[4], self.objs[1], | ||||||
|  |                 self.objs[0], | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @unittest.expectedFailure |     @unittest.expectedFailure | ||||||
|     def test_slice_nested(self): |     def test_slice_nested(self): | ||||||
|         instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) |         instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]]) | ||||||
|   | |||||||
| @@ -148,6 +148,18 @@ class TestQuerying(HStoreTestCase): | |||||||
|             self.objs[:2] |             self.objs[:2] | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_order_by_field(self): | ||||||
|  |         more_objs = ( | ||||||
|  |             HStoreModel.objects.create(field={'g': '637'}), | ||||||
|  |             HStoreModel.objects.create(field={'g': '002'}), | ||||||
|  |             HStoreModel.objects.create(field={'g': '042'}), | ||||||
|  |             HStoreModel.objects.create(field={'g': '981'}), | ||||||
|  |         ) | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             HStoreModel.objects.filter(field__has_key='g').order_by('field__g'), | ||||||
|  |             [more_objs[1], more_objs[2], more_objs[0], more_objs[3]] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def test_keys_contains(self): |     def test_keys_contains(self): | ||||||
|         self.assertSequenceEqual( |         self.assertSequenceEqual( | ||||||
|             HStoreModel.objects.filter(field__keys__contains=['a']), |             HStoreModel.objects.filter(field__keys__contains=['a']), | ||||||
|   | |||||||
| @@ -141,6 +141,31 @@ class TestQuerying(PostgreSQLTestCase): | |||||||
|             [self.objs[0]] |             [self.objs[0]] | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_ordering_by_transform(self): | ||||||
|  |         objs = [ | ||||||
|  |             JSONModel.objects.create(field={'ord': 93, 'name': 'bar'}), | ||||||
|  |             JSONModel.objects.create(field={'ord': 22.1, 'name': 'foo'}), | ||||||
|  |             JSONModel.objects.create(field={'ord': -1, 'name': 'baz'}), | ||||||
|  |             JSONModel.objects.create(field={'ord': 21.931902, 'name': 'spam'}), | ||||||
|  |             JSONModel.objects.create(field={'ord': -100291029, 'name': 'eggs'}), | ||||||
|  |         ] | ||||||
|  |         query = JSONModel.objects.filter(field__name__isnull=False).order_by('field__ord') | ||||||
|  |         self.assertSequenceEqual(query, [objs[4], objs[2], objs[3], objs[1], objs[0]]) | ||||||
|  |  | ||||||
|  |     def test_deep_values(self): | ||||||
|  |         query = JSONModel.objects.values_list('field__k__l') | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             query, | ||||||
|  |             [ | ||||||
|  |                 (None,), (None,), (None,), (None,), (None,), (None,), | ||||||
|  |                 (None,), (None,), ('m',), (None,), (None,), (None,), | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_deep_distinct(self): | ||||||
|  |         query = JSONModel.objects.distinct('field__k__l').values_list('field__k__l') | ||||||
|  |         self.assertSequenceEqual(query, [('m',), (None,)]) | ||||||
|  |  | ||||||
|     def test_isnull_key(self): |     def test_isnull_key(self): | ||||||
|         # key__isnull works the same as has_key='key'. |         # key__isnull works the same as has_key='key'. | ||||||
|         self.assertSequenceEqual( |         self.assertSequenceEqual( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user