From 306b6875209cfedce2536a6679e69adee7c9bc6a Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Mon, 9 Sep 2019 22:58:29 -0500 Subject: [PATCH] Refs #11964 -- Removed SimpleCol in favor of Query(alias_cols). This prevent having to pass simple_col through multiple function calls by defining whether or not references should be resolved with aliases at the Query level. --- django/contrib/postgres/constraints.py | 7 +--- django/db/models/constraints.py | 4 +- django/db/models/expressions.py | 54 ++++++-------------------- django/db/models/indexes.py | 2 +- django/db/models/sql/query.py | 54 ++++++++++++-------------- tests/queries/test_query.py | 39 ++++++++++++------- 6 files changed, 66 insertions(+), 94 deletions(-) diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py index 67e415ddcf..8cc1f58a10 100644 --- a/django/contrib/postgres/constraints.py +++ b/django/contrib/postgres/constraints.py @@ -38,10 +38,7 @@ class ExclusionConstraint(BaseConstraint): for expression, operator in self.expressions: if isinstance(expression, str): expression = F(expression) - if isinstance(expression, F): - expression = expression.resolve_expression(query=query, simple_col=True) - else: - expression = expression.resolve_expression(query=query) + expression = expression.resolve_expression(query=query) sql, params = expression.as_sql(compiler, connection) expressions.append('%s WITH %s' % (sql % params, operator)) return expressions @@ -54,7 +51,7 @@ class ExclusionConstraint(BaseConstraint): return sql % tuple(schema_editor.quote_value(p) for p in params) def constraint_sql(self, model, schema_editor): - query = Query(model) + query = Query(model, alias_cols=False) compiler = query.get_compiler(connection=schema_editor.connection) expressions = self._get_expression_sql(compiler, schema_editor.connection, query) condition = self._get_condition_sql(compiler, schema_editor, query) diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index fe0d42a168..96205b44ad 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -33,7 +33,7 @@ class CheckConstraint(BaseConstraint): super().__init__(name) def _get_check_sql(self, model, schema_editor): - query = Query(model=model) + query = Query(model=model, alias_cols=False) where = query.build_where(self.check) compiler = query.get_compiler(connection=schema_editor.connection) sql, params = where.as_sql(compiler, schema_editor.connection) @@ -77,7 +77,7 @@ class UniqueConstraint(BaseConstraint): def _get_condition_sql(self, model, schema_editor): if self.condition is None: return None - query = Query(model=model) + query = Query(model=model, alias_cols=False) where = query.build_where(self.condition) compiler = query.get_compiler(connection=schema_editor.connection) sql, params = where.as_sql(compiler, schema_editor.connection) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 13e0ef172c..ec7b0e67b9 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -529,8 +529,8 @@ class F(Combinable): return "{}({})".format(self.__class__.__name__, self.name) def resolve_expression(self, query=None, allow_joins=True, reuse=None, - summarize=False, for_save=False, simple_col=False): - return query.resolve_ref(self.name, allow_joins, reuse, summarize, simple_col) + summarize=False, for_save=False): + return query.resolve_ref(self.name, allow_joins, reuse, summarize) def asc(self, **kwargs): return OrderBy(self, **kwargs) @@ -565,8 +565,7 @@ class ResolvedOuterRef(F): class OuterRef(F): - def resolve_expression(self, query=None, allow_joins=True, reuse=None, - summarize=False, for_save=False, simple_col=False): + def resolve_expression(self, *args, **kwargs): if isinstance(self.name, self.__class__): return self.name return ResolvedOuterRef(self.name) @@ -754,14 +753,19 @@ class Col(Expression): self.alias, self.target = alias, target def __repr__(self): - return "{}({}, {})".format( - self.__class__.__name__, self.alias, self.target) + alias, target = self.alias, self.target + identifiers = (alias, str(target)) if alias else (str(target),) + return '{}({})'.format(self.__class__.__name__, ', '.join(identifiers)) def as_sql(self, compiler, connection): - qn = compiler.quote_name_unless_alias - return "%s.%s" % (qn(self.alias), qn(self.target.column)), [] + alias, column = self.alias, self.target.column + identifiers = (alias, column) if alias else (column,) + sql = '.'.join(map(compiler.quote_name_unless_alias, identifiers)) + return sql, [] def relabeled_clone(self, relabels): + if self.alias is None: + return self return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field) def get_group_by_cols(self, alias=None): @@ -774,40 +778,6 @@ class Col(Expression): self.target.get_db_converters(connection)) -class SimpleCol(Expression): - """ - Represents the SQL of a column name without the table name. - - This variant of Col doesn't include the table name (or an alias) to - avoid a syntax error in check constraints. - """ - contains_column_references = True - - def __init__(self, target, output_field=None): - if output_field is None: - output_field = target - super().__init__(output_field=output_field) - self.target = target - - def __repr__(self): - return '{}({})'.format(self.__class__.__name__, self.target) - - def as_sql(self, compiler, connection): - qn = compiler.quote_name_unless_alias - return qn(self.target.column), [] - - def get_group_by_cols(self, alias=None): - return [self] - - def get_db_converters(self, connection): - if self.target == self.output_field: - return self.output_field.get_db_converters(connection) - return ( - self.output_field.get_db_converters(connection) + - self.target.get_db_converters(connection) - ) - - class Ref(Expression): """ Reference to column alias of the query. For example, Ref('sum_cost') in diff --git a/django/db/models/indexes.py b/django/db/models/indexes.py index 49f4989462..77a8423ef8 100644 --- a/django/db/models/indexes.py +++ b/django/db/models/indexes.py @@ -40,7 +40,7 @@ class Index: def _get_condition_sql(self, model, schema_editor): if self.condition is None: return None - query = Query(model=model) + query = Query(model=model, alias_cols=False) where = query.build_where(self.condition) compiler = query.get_compiler(connection=schema_editor.connection) sql, params = where.as_sql(compiler, schema_editor.connection) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index dd5889625f..e6803649e7 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -23,9 +23,7 @@ from django.core.exceptions import ( from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP -from django.db.models.expressions import ( - BaseExpression, Col, F, OuterRef, Ref, SimpleCol, -) +from django.db.models.expressions import BaseExpression, Col, F, OuterRef, Ref from django.db.models.fields import Field from django.db.models.fields.related_lookups import MultiColSource from django.db.models.lookups import Lookup @@ -69,12 +67,6 @@ JoinInfo = namedtuple( ) -def _get_col(target, field, alias, simple_col): - if simple_col: - return SimpleCol(target, field) - return target.get_col(alias, field) - - class RawQuery: """A single raw SQL query.""" @@ -151,7 +143,7 @@ class Query(BaseExpression): compiler = 'SQLCompiler' - def __init__(self, model, where=WhereNode): + def __init__(self, model, where=WhereNode, alias_cols=True): self.model = model self.alias_refcount = {} # alias_map is the most important data structure regarding joins. @@ -160,6 +152,8 @@ class Query(BaseExpression): # the table name) and the value is a Join-like object (see # sql.datastructures.Join for more information). self.alias_map = {} + # Whether to provide alias to columns during reference resolving. + self.alias_cols = alias_cols # Sometimes the query contains references to aliases in outer queries (as # a result of split_exclude). Correct alias quoting needs to know these # aliases too. @@ -360,6 +354,11 @@ class Query(BaseExpression): clone.change_aliases(change_map) return clone + def _get_col(self, target, field, alias): + if not self.alias_cols: + alias = None + return target.get_col(alias, field) + def rewrite_cols(self, annotation, col_cnt): # We must make sure the inner query has the referred columns in it. # If we are aggregating over an annotation, then Django uses Ref() @@ -1050,17 +1049,16 @@ class Query(BaseExpression): sql = '(%s)' % sql return sql, params - def resolve_lookup_value(self, value, can_reuse, allow_joins, simple_col): + def resolve_lookup_value(self, value, can_reuse, allow_joins): if hasattr(value, 'resolve_expression'): - kwargs = {'reuse': can_reuse, 'allow_joins': allow_joins} - if isinstance(value, F): - kwargs['simple_col'] = simple_col - value = value.resolve_expression(self, **kwargs) + value = value.resolve_expression( + self, reuse=can_reuse, allow_joins=allow_joins, + ) elif isinstance(value, (list, tuple)): # The items of the iterable may be expressions and therefore need # to be resolved independently. return type(value)( - self.resolve_lookup_value(sub_value, can_reuse, allow_joins, simple_col) + self.resolve_lookup_value(sub_value, can_reuse, allow_joins) for sub_value in value ) return value @@ -1192,7 +1190,7 @@ class Query(BaseExpression): def build_filter(self, filter_expr, branch_negated=False, current_negated=False, can_reuse=None, allow_joins=True, split_subq=True, - reuse_with_filtered_relation=False, simple_col=False): + reuse_with_filtered_relation=False): """ Build a WhereNode for a single filter clause but don't add it to this Query. Query.add_q() will then add this filter to the where @@ -1244,7 +1242,7 @@ class Query(BaseExpression): raise FieldError("Joined field references are not permitted in this query") pre_joins = self.alias_refcount.copy() - value = self.resolve_lookup_value(value, can_reuse, allow_joins, simple_col) + value = self.resolve_lookup_value(value, can_reuse, allow_joins) used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)} self.check_filterable(value) @@ -1289,11 +1287,11 @@ class Query(BaseExpression): if num_lookups > 1: raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0])) if len(targets) == 1: - col = _get_col(targets[0], join_info.final_field, alias, simple_col) + col = self._get_col(targets[0], join_info.final_field, alias) else: col = MultiColSource(alias, targets, join_info.targets, join_info.final_field) else: - col = _get_col(targets[0], join_info.final_field, alias, simple_col) + col = self._get_col(targets[0], join_info.final_field, alias) condition = self.build_lookup(lookups, col, value) lookup_type = condition.lookup_name @@ -1315,7 +1313,7 @@ class Query(BaseExpression): # <=> # NOT (col IS NOT NULL AND col = someval). lookup_class = targets[0].get_lookup('isnull') - col = _get_col(targets[0], join_info.targets[0], alias, simple_col) + col = self._get_col(targets[0], join_info.targets[0], alias) clause.add(lookup_class(col, False), AND) return clause, used_joins if not require_outer else () @@ -1340,11 +1338,10 @@ class Query(BaseExpression): self.demote_joins(existing_inner) def build_where(self, q_object): - return self._add_q(q_object, used_aliases=set(), allow_joins=False, simple_col=True)[0] + return self._add_q(q_object, used_aliases=set(), allow_joins=False)[0] def _add_q(self, q_object, used_aliases, branch_negated=False, - current_negated=False, allow_joins=True, split_subq=True, - simple_col=False): + current_negated=False, allow_joins=True, split_subq=True): """Add a Q-object to the current filter.""" connector = q_object.connector current_negated = current_negated ^ q_object.negated @@ -1356,13 +1353,13 @@ class Query(BaseExpression): if isinstance(child, Node): child_clause, needed_inner = self._add_q( child, used_aliases, branch_negated, - current_negated, allow_joins, split_subq, simple_col) + current_negated, allow_joins, split_subq) joinpromoter.add_votes(needed_inner) else: child_clause, needed_inner = self.build_filter( child, can_reuse=used_aliases, branch_negated=branch_negated, current_negated=current_negated, allow_joins=allow_joins, - split_subq=split_subq, simple_col=simple_col, + split_subq=split_subq, ) joinpromoter.add_votes(needed_inner) if child_clause: @@ -1639,7 +1636,7 @@ class Query(BaseExpression): else: yield from cls._gen_col_aliases(expr.get_source_expressions()) - def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False, simple_col=False): + def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False): if not allow_joins and LOOKUP_SEP in name: raise FieldError("Joined field references are not permitted in this query") annotation = self.annotations.get(name) @@ -1673,8 +1670,7 @@ class Query(BaseExpression): join_info.transform_function(targets[0], final_alias) if reuse is not None: reuse.update(join_list) - col = _get_col(targets[0], join_info.targets[0], join_list[-1], simple_col) - return col + return self._get_col(targets[0], join_info.targets[0], join_list[-1]) def split_exclude(self, filter_expr, can_reuse, names_with_path): """ diff --git a/tests/queries/test_query.py b/tests/queries/test_query.py index 012d56a02f..ecd9c96d8c 100644 --- a/tests/queries/test_query.py +++ b/tests/queries/test_query.py @@ -2,7 +2,7 @@ from datetime import datetime from django.core.exceptions import FieldError from django.db.models import CharField, F, Q -from django.db.models.expressions import SimpleCol +from django.db.models.expressions import Col from django.db.models.fields.related_lookups import RelatedIsNull from django.db.models.functions import Lower from django.db.models.lookups import Exact, GreaterThan, IsNull, LessThan @@ -23,20 +23,24 @@ class TestQuery(SimpleTestCase): self.assertEqual(lookup.rhs, 2) self.assertEqual(lookup.lhs.target, Author._meta.get_field('num')) - def test_simplecol_query(self): - query = Query(Author) + def test_non_alias_cols_query(self): + query = Query(Author, alias_cols=False) where = query.build_where(Q(num__gt=2, name__isnull=False) | Q(num__lt=F('id'))) name_isnull_lookup, num_gt_lookup = where.children[0].children self.assertIsInstance(num_gt_lookup, GreaterThan) - self.assertIsInstance(num_gt_lookup.lhs, SimpleCol) + self.assertIsInstance(num_gt_lookup.lhs, Col) + self.assertIsNone(num_gt_lookup.lhs.alias) self.assertIsInstance(name_isnull_lookup, IsNull) - self.assertIsInstance(name_isnull_lookup.lhs, SimpleCol) + self.assertIsInstance(name_isnull_lookup.lhs, Col) + self.assertIsNone(name_isnull_lookup.lhs.alias) num_lt_lookup = where.children[1] self.assertIsInstance(num_lt_lookup, LessThan) - self.assertIsInstance(num_lt_lookup.rhs, SimpleCol) - self.assertIsInstance(num_lt_lookup.lhs, SimpleCol) + self.assertIsInstance(num_lt_lookup.rhs, Col) + self.assertIsNone(num_lt_lookup.rhs.alias) + self.assertIsInstance(num_lt_lookup.lhs, Col) + self.assertIsNone(num_lt_lookup.lhs.alias) def test_complex_query(self): query = Query(Author) @@ -54,23 +58,26 @@ class TestQuery(SimpleTestCase): self.assertEqual(lookup.lhs.target, Author._meta.get_field('num')) def test_multiple_fields(self): - query = Query(Item) + query = Query(Item, alias_cols=False) where = query.build_where(Q(modified__gt=F('created'))) lookup = where.children[0] self.assertIsInstance(lookup, GreaterThan) - self.assertIsInstance(lookup.rhs, SimpleCol) - self.assertIsInstance(lookup.lhs, SimpleCol) + self.assertIsInstance(lookup.rhs, Col) + self.assertIsNone(lookup.rhs.alias) + self.assertIsInstance(lookup.lhs, Col) + self.assertIsNone(lookup.lhs.alias) self.assertEqual(lookup.rhs.target, Item._meta.get_field('created')) self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified')) def test_transform(self): - query = Query(Author) + query = Query(Author, alias_cols=False) with register_lookup(CharField, Lower): where = query.build_where(~Q(name__lower='foo')) lookup = where.children[0] self.assertIsInstance(lookup, Exact) self.assertIsInstance(lookup.lhs, Lower) - self.assertIsInstance(lookup.lhs.lhs, SimpleCol) + self.assertIsInstance(lookup.lhs.lhs, Col) + self.assertIsNone(lookup.lhs.lhs.alias) self.assertEqual(lookup.lhs.lhs.target, Author._meta.get_field('name')) def test_negated_nullable(self): @@ -96,15 +103,17 @@ class TestQuery(SimpleTestCase): query.build_where(Q(rank__gt=F('author__num'))) def test_foreign_key_exclusive(self): - query = Query(ObjectC) + query = Query(ObjectC, alias_cols=False) where = query.build_where(Q(objecta=None) | Q(objectb=None)) a_isnull = where.children[0] self.assertIsInstance(a_isnull, RelatedIsNull) - self.assertIsInstance(a_isnull.lhs, SimpleCol) + self.assertIsInstance(a_isnull.lhs, Col) + self.assertIsNone(a_isnull.lhs.alias) self.assertEqual(a_isnull.lhs.target, ObjectC._meta.get_field('objecta')) b_isnull = where.children[1] self.assertIsInstance(b_isnull, RelatedIsNull) - self.assertIsInstance(b_isnull.lhs, SimpleCol) + self.assertIsInstance(b_isnull.lhs, Col) + self.assertIsNone(b_isnull.lhs.alias) self.assertEqual(b_isnull.lhs.target, ObjectC._meta.get_field('objectb')) def test_clone_select_related(self):