mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Fixed #19500 -- Solved a regression in join reuse
The ORM didn't reuse joins for direct foreign key traversals when using
chained filters. For example:
    qs.filter(fk__somefield=1).filter(fk__somefield=2))
produced two joins.
As a bonus, reverse onetoone filters can now reuse joins correctly
The regression was caused by the join() method refactor in commit
68847135bc
Thanks for Simon Charette for spotting some issues with the first draft
of the patch.
			
			
This commit is contained in:
		| @@ -6,7 +6,7 @@ from django.db.backends.util import truncate_name | |||||||
| from django.db.models.constants import LOOKUP_SEP | from django.db.models.constants import LOOKUP_SEP | ||||||
| from django.db.models.query_utils import select_related_descend | from django.db.models.query_utils import select_related_descend | ||||||
| from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR, | from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR, | ||||||
|         GET_ITERATOR_CHUNK_SIZE, REUSE_ALL, SelectInfo) |         GET_ITERATOR_CHUNK_SIZE, SelectInfo) | ||||||
| from django.db.models.sql.datastructures import EmptyResultSet | from django.db.models.sql.datastructures import EmptyResultSet | ||||||
| from django.db.models.sql.expressions import SQLEvaluator | from django.db.models.sql.expressions import SQLEvaluator | ||||||
| from django.db.models.sql.query import get_order_dir, Query | from django.db.models.sql.query import get_order_dir, Query | ||||||
| @@ -317,7 +317,7 @@ class SQLCompiler(object): | |||||||
|  |  | ||||||
|         for name in self.query.distinct_fields: |         for name in self.query.distinct_fields: | ||||||
|             parts = name.split(LOOKUP_SEP) |             parts = name.split(LOOKUP_SEP) | ||||||
|             field, col, alias, _, _ = self._setup_joins(parts, opts, None) |             field, col, alias, _, _ = self._setup_joins(parts, opts) | ||||||
|             col, alias = self._final_join_removal(col, alias) |             col, alias = self._final_join_removal(col, alias) | ||||||
|             result.append("%s.%s" % (qn(alias), qn2(col))) |             result.append("%s.%s" % (qn(alias), qn2(col))) | ||||||
|         return result |         return result | ||||||
| @@ -450,7 +450,7 @@ class SQLCompiler(object): | |||||||
|         if not alias: |         if not alias: | ||||||
|             alias = self.query.get_initial_alias() |             alias = self.query.get_initial_alias() | ||||||
|         field, target, opts, joins, _ = self.query.setup_joins( |         field, target, opts, joins, _ = self.query.setup_joins( | ||||||
|             pieces, opts, alias, REUSE_ALL) |             pieces, opts, alias) | ||||||
|         # We will later on need to promote those joins that were added to the |         # We will later on need to promote those joins that were added to the | ||||||
|         # query afresh above. |         # query afresh above. | ||||||
|         joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2] |         joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2] | ||||||
| @@ -688,7 +688,7 @@ class SQLCompiler(object): | |||||||
|                         int_opts = int_model._meta |                         int_opts = int_model._meta | ||||||
|                         alias = self.query.join( |                         alias = self.query.join( | ||||||
|                             (alias, int_opts.db_table, lhs_col, int_opts.pk.column), |                             (alias, int_opts.db_table, lhs_col, int_opts.pk.column), | ||||||
|                             promote=True, |                             promote=True | ||||||
|                         ) |                         ) | ||||||
|                         alias_chain.append(alias) |                         alias_chain.append(alias) | ||||||
|                 alias = self.query.join( |                 alias = self.query.join( | ||||||
|   | |||||||
| @@ -44,6 +44,3 @@ ORDER_DIR = { | |||||||
|     'ASC': ('ASC', 'DESC'), |     'ASC': ('ASC', 'DESC'), | ||||||
|     'DESC': ('DESC', 'ASC'), |     'DESC': ('DESC', 'ASC'), | ||||||
| } | } | ||||||
|  |  | ||||||
| # A marker for join-reusability. |  | ||||||
| REUSE_ALL = object() |  | ||||||
|   | |||||||
| @@ -1,10 +1,9 @@ | |||||||
| from django.core.exceptions import FieldError | from django.core.exceptions import FieldError | ||||||
| from django.db.models.constants import LOOKUP_SEP | from django.db.models.constants import LOOKUP_SEP | ||||||
| from django.db.models.fields import FieldDoesNotExist | from django.db.models.fields import FieldDoesNotExist | ||||||
| from django.db.models.sql.constants import REUSE_ALL |  | ||||||
|  |  | ||||||
| class SQLEvaluator(object): | class SQLEvaluator(object): | ||||||
|     def __init__(self, expression, query, allow_joins=True, reuse=REUSE_ALL): |     def __init__(self, expression, query, allow_joins=True, reuse=None): | ||||||
|         self.expression = expression |         self.expression = expression | ||||||
|         self.opts = query.get_meta() |         self.opts = query.get_meta() | ||||||
|         self.cols = [] |         self.cols = [] | ||||||
| @@ -54,7 +53,7 @@ class SQLEvaluator(object): | |||||||
|                     field_list, query.get_meta(), |                     field_list, query.get_meta(), | ||||||
|                     query.get_initial_alias(), self.reuse) |                     query.get_initial_alias(), self.reuse) | ||||||
|                 col, _, join_list = query.trim_joins(source, join_list, path) |                 col, _, join_list = query.trim_joins(source, join_list, path) | ||||||
|                 if self.reuse is not None and self.reuse != REUSE_ALL: |                 if self.reuse is not None: | ||||||
|                     self.reuse.update(join_list) |                     self.reuse.update(join_list) | ||||||
|                 self.cols.append((node, (join_list[-1], col))) |                 self.cols.append((node, (join_list[-1], col))) | ||||||
|             except FieldDoesNotExist: |             except FieldDoesNotExist: | ||||||
|   | |||||||
| @@ -20,7 +20,7 @@ from django.db.models.fields import FieldDoesNotExist | |||||||
| from django.db.models.loading import get_model | from django.db.models.loading import get_model | ||||||
| from django.db.models.sql import aggregates as base_aggregates_module | from django.db.models.sql import aggregates as base_aggregates_module | ||||||
| from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, | from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, | ||||||
|         ORDER_PATTERN, REUSE_ALL, JoinInfo, SelectInfo, PathInfo) |         ORDER_PATTERN, JoinInfo, SelectInfo, PathInfo) | ||||||
| from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin | from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin | ||||||
| from django.db.models.sql.expressions import SQLEvaluator | from django.db.models.sql.expressions import SQLEvaluator | ||||||
| from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, | from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, | ||||||
| @@ -891,7 +891,7 @@ class Query(object): | |||||||
|         """ |         """ | ||||||
|         return len([1 for count in self.alias_refcount.values() if count]) |         return len([1 for count in self.alias_refcount.values() if count]) | ||||||
|  |  | ||||||
|     def join(self, connection, reuse=REUSE_ALL, promote=False, |     def join(self, connection, reuse=None, promote=False, | ||||||
|              outer_if_first=False, nullable=False, join_field=None): |              outer_if_first=False, nullable=False, join_field=None): | ||||||
|         """ |         """ | ||||||
|         Returns an alias for the join in 'connection', either reusing an |         Returns an alias for the join in 'connection', either reusing an | ||||||
| @@ -902,10 +902,9 @@ class Query(object): | |||||||
|  |  | ||||||
|             lhs.lhs_col = table.col |             lhs.lhs_col = table.col | ||||||
|  |  | ||||||
|         The 'reuse' parameter can be used in three ways: it can be REUSE_ALL |         The 'reuse' parameter can be either None which means all joins | ||||||
|         which means all joins (matching the connection) are reusable, it can |         (matching the connection) are reusable, or it can be a set containing | ||||||
|         be a set containing the aliases that can be reused, or it can be None |         the aliases that can be reused. | ||||||
|         which means a new join is always created. |  | ||||||
|  |  | ||||||
|         If 'promote' is True, the join type for the alias will be LOUTER (if |         If 'promote' is True, the join type for the alias will be LOUTER (if | ||||||
|         the alias previously existed, the join type will be promoted from INNER |         the alias previously existed, the join type will be promoted from INNER | ||||||
| @@ -926,10 +925,8 @@ class Query(object): | |||||||
|         """ |         """ | ||||||
|         lhs, table, lhs_col, col = connection |         lhs, table, lhs_col, col = connection | ||||||
|         existing = self.join_map.get(connection, ()) |         existing = self.join_map.get(connection, ()) | ||||||
|         if reuse == REUSE_ALL: |         if reuse is None: | ||||||
|             reuse = existing |             reuse = existing | ||||||
|         elif reuse is None: |  | ||||||
|             reuse = set() |  | ||||||
|         else: |         else: | ||||||
|             reuse = [a for a in existing if a in reuse] |             reuse = [a for a in existing if a in reuse] | ||||||
|         for alias in reuse: |         for alias in reuse: | ||||||
| @@ -1040,7 +1037,7 @@ class Query(object): | |||||||
|             # then we need to explore the joins that are required. |             # then we need to explore the joins that are required. | ||||||
|  |  | ||||||
|             field, source, opts, join_list, path = self.setup_joins( |             field, source, opts, join_list, path = self.setup_joins( | ||||||
|                 field_list, opts, self.get_initial_alias(), REUSE_ALL) |                 field_list, opts, self.get_initial_alias()) | ||||||
|  |  | ||||||
|             # Process the join chain to see if it can be trimmed |             # Process the join chain to see if it can be trimmed | ||||||
|             col, _, join_list = self.trim_joins(source, join_list, path) |             col, _, join_list = self.trim_joins(source, join_list, path) | ||||||
| @@ -1441,7 +1438,7 @@ class Query(object): | |||||||
|             raise MultiJoin(multijoin_pos + 1) |             raise MultiJoin(multijoin_pos + 1) | ||||||
|         return path, final_field, target |         return path, final_field, target | ||||||
|  |  | ||||||
|     def setup_joins(self, names, opts, alias, can_reuse, allow_many=True, |     def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True, | ||||||
|                     allow_explicit_fk=False): |                     allow_explicit_fk=False): | ||||||
|         """ |         """ | ||||||
|         Compute the necessary table joins for the passage through the fields |         Compute the necessary table joins for the passage through the fields | ||||||
| @@ -1450,9 +1447,9 @@ class Query(object): | |||||||
|         the table to start the joining from. |         the table to start the joining from. | ||||||
|  |  | ||||||
|         The 'can_reuse' defines the reverse foreign key joins we can reuse. It |         The 'can_reuse' defines the reverse foreign key joins we can reuse. It | ||||||
|         can be sql.constants.REUSE_ALL in which case all joins are reusable |         can be None in which case all joins are reusable or a set of aliases | ||||||
|         or a set of aliases that can be reused. Note that Non-reverse foreign |         that can be reused. Note that non-reverse foreign keys are always | ||||||
|         keys are always reusable. |         reusable when using setup_joins(). | ||||||
|  |  | ||||||
|         If 'allow_many' is False, then any reverse foreign key seen will |         If 'allow_many' is False, then any reverse foreign key seen will | ||||||
|         generate a MultiJoin exception. |         generate a MultiJoin exception. | ||||||
| @@ -1485,8 +1482,9 @@ class Query(object): | |||||||
|             else: |             else: | ||||||
|                 nullable = True |                 nullable = True | ||||||
|             connection = alias, opts.db_table, from_field.column, to_field.column |             connection = alias, opts.db_table, from_field.column, to_field.column | ||||||
|             alias = self.join(connection, reuse=can_reuse, nullable=nullable, |             reuse = None if direct or to_field.unique else can_reuse | ||||||
|                               join_field=join_field) |             alias = self.join(connection, reuse=reuse, | ||||||
|  |                               nullable=nullable, join_field=join_field) | ||||||
|             joins.append(alias) |             joins.append(alias) | ||||||
|         return final_field, target, opts, joins, path |         return final_field, target, opts, joins, path | ||||||
|  |  | ||||||
| @@ -1643,7 +1641,7 @@ class Query(object): | |||||||
|         try: |         try: | ||||||
|             for name in field_names: |             for name in field_names: | ||||||
|                 field, target, u2, joins, u3 = self.setup_joins( |                 field, target, u2, joins, u3 = self.setup_joins( | ||||||
|                         name.split(LOOKUP_SEP), opts, alias, REUSE_ALL, allow_m2m, |                         name.split(LOOKUP_SEP), opts, alias, None, allow_m2m, | ||||||
|                         True) |                         True) | ||||||
|                 final_alias = joins[-1] |                 final_alias = joins[-1] | ||||||
|                 col = target.column |                 col = target.column | ||||||
| @@ -1729,7 +1727,8 @@ class Query(object): | |||||||
|         else: |         else: | ||||||
|             opts = self.model._meta |             opts = self.model._meta | ||||||
|             if not self.select: |             if not self.select: | ||||||
|                 count = self.aggregates_module.Count((self.join((None, opts.db_table, None, None)), opts.pk.column), |                 count = self.aggregates_module.Count( | ||||||
|  |                     (self.join((None, opts.db_table, None, None)), opts.pk.column), | ||||||
|                     is_summary=True, distinct=True) |                     is_summary=True, distinct=True) | ||||||
|             else: |             else: | ||||||
|                 # Because of SQL portability issues, multi-column, distinct |                 # Because of SQL portability issues, multi-column, distinct | ||||||
| @@ -1934,7 +1933,7 @@ class Query(object): | |||||||
|         opts = self.model._meta |         opts = self.model._meta | ||||||
|         alias = self.get_initial_alias() |         alias = self.get_initial_alias() | ||||||
|         field, col, opts, joins, extra = self.setup_joins( |         field, col, opts, joins, extra = self.setup_joins( | ||||||
|                 start.split(LOOKUP_SEP), opts, alias, REUSE_ALL) |                 start.split(LOOKUP_SEP), opts, alias) | ||||||
|         select_col = self.alias_map[joins[1]].lhs_join_col |         select_col = self.alias_map[joins[1]].lhs_join_col | ||||||
|         select_alias = alias |         select_alias = alias | ||||||
|  |  | ||||||
|   | |||||||
| @@ -232,7 +232,6 @@ class DateQuery(Query): | |||||||
|                 field_name.split(LOOKUP_SEP), |                 field_name.split(LOOKUP_SEP), | ||||||
|                 self.get_meta(), |                 self.get_meta(), | ||||||
|                 self.get_initial_alias(), |                 self.get_initial_alias(), | ||||||
|                 False |  | ||||||
|             ) |             ) | ||||||
|         except FieldError: |         except FieldError: | ||||||
|             raise FieldDoesNotExist("%s has no field named '%s'" % ( |             raise FieldDoesNotExist("%s has no field named '%s'" % ( | ||||||
|   | |||||||
| @@ -72,5 +72,11 @@ class GenericRelationTests(TestCase): | |||||||
|             Q(notes__note__icontains=r'other note')) |             Q(notes__note__icontains=r'other note')) | ||||||
|         self.assertTrue(org_contact in qs) |         self.assertTrue(org_contact in qs) | ||||||
|  |  | ||||||
|  |     def test_join_reuse(self): | ||||||
|  |         qs = Person.objects.filter( | ||||||
|  |             addresses__street='foo' | ||||||
|  |         ).filter( | ||||||
|  |             addresses__street='bar' | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(str(qs.query).count('JOIN'), 2) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2418,3 +2418,36 @@ class ReverseJoinTrimmingTest(TestCase): | |||||||
|         qs = Tag.objects.filter(annotation__tag=t.pk) |         qs = Tag.objects.filter(annotation__tag=t.pk) | ||||||
|         self.assertIn('INNER JOIN', str(qs.query)) |         self.assertIn('INNER JOIN', str(qs.query)) | ||||||
|         self.assertEquals(list(qs), []) |         self.assertEquals(list(qs), []) | ||||||
|  |  | ||||||
|  | class JoinReuseTest(TestCase): | ||||||
|  |     """ | ||||||
|  |     Test that the queries reuse joins sensibly (for example, direct joins | ||||||
|  |     are always reused). | ||||||
|  |     """ | ||||||
|  |     def test_fk_reuse(self): | ||||||
|  |         qs = Annotation.objects.filter(tag__name='foo').filter(tag__name='bar') | ||||||
|  |         self.assertEqual(str(qs.query).count('JOIN'), 1) | ||||||
|  |  | ||||||
|  |     def test_fk_reuse_select_related(self): | ||||||
|  |         qs = Annotation.objects.filter(tag__name='foo').select_related('tag') | ||||||
|  |         self.assertEqual(str(qs.query).count('JOIN'), 1) | ||||||
|  |  | ||||||
|  |     def test_fk_reuse_annotation(self): | ||||||
|  |         qs = Annotation.objects.filter(tag__name='foo').annotate(cnt=Count('tag__name')) | ||||||
|  |         self.assertEqual(str(qs.query).count('JOIN'), 1) | ||||||
|  |  | ||||||
|  |     def test_fk_reuse_disjunction(self): | ||||||
|  |         qs = Annotation.objects.filter(Q(tag__name='foo') | Q(tag__name='bar')) | ||||||
|  |         self.assertEqual(str(qs.query).count('JOIN'), 1) | ||||||
|  |  | ||||||
|  |     def test_fk_reuse_order_by(self): | ||||||
|  |         qs = Annotation.objects.filter(tag__name='foo').order_by('tag__name') | ||||||
|  |         self.assertEqual(str(qs.query).count('JOIN'), 1) | ||||||
|  |  | ||||||
|  |     def test_revo2o_reuse(self): | ||||||
|  |         qs = Detail.objects.filter(member__name='foo').filter(member__name='foo') | ||||||
|  |         self.assertEqual(str(qs.query).count('JOIN'), 1) | ||||||
|  |  | ||||||
|  |     def test_revfk_noreuse(self): | ||||||
|  |         qs = Author.objects.filter(report__name='r4').filter(report__name='r1') | ||||||
|  |         self.assertEqual(str(qs.query).count('JOIN'), 2) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user