diff --git a/AUTHORS b/AUTHORS index 2915761b96..faf6420618 100644 --- a/AUTHORS +++ b/AUTHORS @@ -152,6 +152,7 @@ answer newbie questions, and generally made Django that much better: Ben Lomax Ben Slavin Ben Sturmfels + Bendegúz Csirmaz Berker Peksag Bernd Schlapsi Bernhard Essl diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index ffb9f3c816..4a242012ee 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1295,6 +1295,52 @@ class Col(Expression): ) + self.target.get_db_converters(connection) +class ColPairs(Expression): + def __init__(self, alias, targets, sources, output_field): + super().__init__(output_field=output_field) + self.alias, self.targets, self.sources = alias, targets, sources + + def __len__(self): + return len(self.targets) + + def __iter__(self): + return iter(self.get_cols()) + + def get_cols(self): + return [ + Col(self.alias, target, source) + for target, source in zip(self.targets, self.sources) + ] + + def get_source_expressions(self): + return self.get_cols() + + def set_source_expressions(self, exprs): + assert all(isinstance(expr, Col) and expr.alias == self.alias for expr in exprs) + self.targets = [col.target for col in exprs] + self.sources = [col.field for col in exprs] + + def as_sql(self, compiler, connection): + cols_sql = [] + cols_params = [] + cols = self.get_cols() + + for col in cols: + sql, params = col.as_sql(compiler, connection) + cols_sql.append(sql) + cols_params.extend(params) + + return ", ".join(cols_sql), cols_params + + def relabeled_clone(self, relabels): + return self.__class__( + relabels.get(self.alias, self.alias), self.targets, self.sources, self.field + ) + + def resolve_expression(self, *args, **kwargs): + return self + + class Ref(Expression): """ Reference to column alias of the query. For example, Ref('sum_cost') in diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py index 07a06e1686..22fd17ab4f 100644 --- a/django/db/models/fields/related_lookups.py +++ b/django/db/models/fields/related_lookups.py @@ -1,3 +1,6 @@ +from django.db import NotSupportedError +from django.db.models.expressions import ColPairs +from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups from django.db.models.lookups import ( Exact, GreaterThan, @@ -9,34 +12,6 @@ from django.db.models.lookups import ( ) -class MultiColSource: - contains_aggregate = False - contains_over_clause = False - - def __init__(self, alias, targets, sources, field): - self.targets, self.sources, self.field, self.alias = ( - targets, - sources, - field, - alias, - ) - self.output_field = self.field - - def __repr__(self): - return "{}({}, {})".format(self.__class__.__name__, self.alias, self.field) - - def relabeled_clone(self, relabels): - return self.__class__( - relabels.get(self.alias, self.alias), self.targets, self.sources, self.field - ) - - def get_lookup(self, lookup): - return self.output_field.get_lookup(lookup) - - def resolve_expression(self, *args, **kwargs): - return self - - def get_normalized_value(value, lhs): from django.db.models import Model @@ -64,7 +39,7 @@ def get_normalized_value(value, lhs): class RelatedIn(In): def get_prep_lookup(self): - if not isinstance(self.lhs, MultiColSource): + if not isinstance(self.lhs, ColPairs): if self.rhs_is_direct_value(): # If we get here, we are dealing with single-column relations. self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs] @@ -98,49 +73,33 @@ class RelatedIn(In): return super().get_prep_lookup() def as_sql(self, compiler, connection): - if isinstance(self.lhs, MultiColSource): + if isinstance(self.lhs, ColPairs): # For multicolumn lookups we need to build a multicolumn where clause. # This clause is either a SubqueryConstraint (for values that need # to be compiled to SQL) or an OR-combined list of # (col1 = val1 AND col2 = val2 AND ...) clauses. - from django.db.models.sql.where import ( - AND, - OR, - SubqueryConstraint, - WhereNode, - ) + from django.db.models.sql.where import SubqueryConstraint - root_constraint = WhereNode(connector=OR) if self.rhs_is_direct_value(): values = [get_normalized_value(value, self.lhs) for value in self.rhs] - for value in values: - value_constraint = WhereNode() - for source, target, val in zip( - self.lhs.sources, self.lhs.targets, value - ): - lookup_class = target.get_lookup("exact") - lookup = lookup_class( - target.get_col(self.lhs.alias, source), val - ) - value_constraint.add(lookup, AND) - root_constraint.add(value_constraint, OR) + lookup = TupleIn(self.lhs, values) + return compiler.compile(lookup) else: - root_constraint.add( + return compiler.compile( SubqueryConstraint( self.lhs.alias, [target.column for target in self.lhs.targets], [source.name for source in self.lhs.sources], self.rhs, ), - AND, ) - return root_constraint.as_sql(compiler, connection) + return super().as_sql(compiler, connection) class RelatedLookupMixin: def get_prep_lookup(self): - if not isinstance(self.lhs, MultiColSource) and not hasattr( + if not isinstance(self.lhs, ColPairs) and not hasattr( self.rhs, "resolve_expression" ): # If we get here, we are dealing with single-column relations. @@ -158,20 +117,16 @@ class RelatedLookupMixin: return super().get_prep_lookup() def as_sql(self, compiler, connection): - if isinstance(self.lhs, MultiColSource): - assert self.rhs_is_direct_value() - self.rhs = get_normalized_value(self.rhs, self.lhs) - from django.db.models.sql.where import AND, WhereNode - - root_constraint = WhereNode() - for target, source, val in zip( - self.lhs.targets, self.lhs.sources, self.rhs - ): - lookup_class = target.get_lookup(self.lookup_name) - root_constraint.add( - lookup_class(target.get_col(self.lhs.alias, source), val), AND + if isinstance(self.lhs, ColPairs): + if not self.rhs_is_direct_value(): + raise NotSupportedError( + f"'{self.lookup_name}' doesn't support multi-column subqueries." ) - return root_constraint.as_sql(compiler, connection) + self.rhs = get_normalized_value(self.rhs, self.lhs) + lookup_class = tuple_lookups[self.lookup_name] + lookup = lookup_class(self.lhs, self.rhs) + return compiler.compile(lookup) + return super().as_sql(compiler, connection) diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py new file mode 100644 index 0000000000..468826f224 --- /dev/null +++ b/django/db/models/fields/tuple_lookups.py @@ -0,0 +1,244 @@ +import itertools + +from django.core.exceptions import EmptyResultSet +from django.db.models.expressions import ColPairs, Func, Value +from django.db.models.lookups import ( + Exact, + GreaterThan, + GreaterThanOrEqual, + In, + IsNull, + LessThan, + LessThanOrEqual, +) +from django.db.models.sql.where import AND, OR, WhereNode + + +class Tuple(Func): + function = "" + + +class TupleLookupMixin: + def get_prep_lookup(self): + self.check_tuple_lookup() + return super().get_prep_lookup() + + def check_tuple_lookup(self): + assert isinstance(self.lhs, ColPairs) + self.check_rhs_is_tuple_or_list() + self.check_rhs_length_equals_lhs_length() + + def check_rhs_is_tuple_or_list(self): + if not isinstance(self.rhs, (tuple, list)): + raise ValueError( + f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " + "must be a tuple or a list" + ) + + def check_rhs_length_equals_lhs_length(self): + if len(self.lhs) != len(self.rhs): + raise ValueError( + f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " + f"must have {len(self.lhs)} elements" + ) + + def check_rhs_is_collection_of_tuples_or_lists(self): + if not all(isinstance(vals, (tuple, list)) for vals in self.rhs): + raise ValueError( + f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " + f"must be a collection of tuples or lists" + ) + + def check_rhs_elements_length_equals_lhs_length(self): + if not all(len(self.lhs) == len(vals) for vals in self.rhs): + raise ValueError( + f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " + f"must have {len(self.lhs)} elements each" + ) + + def as_sql(self, compiler, connection): + # e.g.: (a, b, c) == (x, y, z) as SQL: + # WHERE (a, b, c) = (x, y, z) + vals = [ + Value(val, output_field=col.output_field) + for col, val in zip(self.lhs, self.rhs) + ] + lookup_class = self.__class__.__bases__[-1] + lookup = lookup_class(Tuple(self.lhs), Tuple(*vals)) + return lookup.as_sql(compiler, connection) + + +class TupleExact(TupleLookupMixin, Exact): + def as_oracle(self, compiler, connection): + # e.g.: (a, b, c) == (x, y, z) as SQL: + # WHERE a = x AND b = y AND c = z + cols = self.lhs.get_cols() + lookups = [Exact(col, val) for col, val in zip(cols, self.rhs)] + root = WhereNode(lookups, connector=AND) + + return root.as_sql(compiler, connection) + + +class TupleIsNull(IsNull): + def as_sql(self, compiler, connection): + # e.g.: (a, b, c) is None as SQL: + # WHERE a IS NULL AND b IS NULL AND c IS NULL + vals = self.rhs + if isinstance(vals, bool): + vals = [vals] * len(self.lhs) + + cols = self.lhs.get_cols() + lookups = [IsNull(col, val) for col, val in zip(cols, vals)] + root = WhereNode(lookups, connector=AND) + + return root.as_sql(compiler, connection) + + +class TupleGreaterThan(TupleLookupMixin, GreaterThan): + def as_oracle(self, compiler, connection): + # e.g.: (a, b, c) > (x, y, z) as SQL: + # WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z))) + cols = self.lhs.get_cols() + lookups = itertools.cycle([GreaterThan, Exact]) + connectors = itertools.cycle([OR, AND]) + cols_list = [col for col in cols for _ in range(2)] + vals_list = [val for val in self.rhs for _ in range(2)] + cols_iter = iter(cols_list[:-1]) + vals_iter = iter(vals_list[:-1]) + col, val = next(cols_iter), next(vals_iter) + lookup, connector = next(lookups), next(connectors) + root = node = WhereNode([lookup(col, val)], connector=connector) + + for col, val in zip(cols_iter, vals_iter): + lookup, connector = next(lookups), next(connectors) + child = WhereNode([lookup(col, val)], connector=connector) + node.children.append(child) + node = child + + return root.as_sql(compiler, connection) + + +class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual): + def as_oracle(self, compiler, connection): + # e.g.: (a, b, c) >= (x, y, z) as SQL: + # WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z)))) + cols = self.lhs.get_cols() + lookups = itertools.cycle([GreaterThan, Exact]) + connectors = itertools.cycle([OR, AND]) + cols_list = [col for col in cols for _ in range(2)] + vals_list = [val for val in self.rhs for _ in range(2)] + cols_iter = iter(cols_list) + vals_iter = iter(vals_list) + col, val = next(cols_iter), next(vals_iter) + lookup, connector = next(lookups), next(connectors) + root = node = WhereNode([lookup(col, val)], connector=connector) + + for col, val in zip(cols_iter, vals_iter): + lookup, connector = next(lookups), next(connectors) + child = WhereNode([lookup(col, val)], connector=connector) + node.children.append(child) + node = child + + return root.as_sql(compiler, connection) + + +class TupleLessThan(TupleLookupMixin, LessThan): + def as_oracle(self, compiler, connection): + # e.g.: (a, b, c) < (x, y, z) as SQL: + # WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z))) + cols = self.lhs.get_cols() + lookups = itertools.cycle([LessThan, Exact]) + connectors = itertools.cycle([OR, AND]) + cols_list = [col for col in cols for _ in range(2)] + vals_list = [val for val in self.rhs for _ in range(2)] + cols_iter = iter(cols_list[:-1]) + vals_iter = iter(vals_list[:-1]) + col, val = next(cols_iter), next(vals_iter) + lookup, connector = next(lookups), next(connectors) + root = node = WhereNode([lookup(col, val)], connector=connector) + + for col, val in zip(cols_iter, vals_iter): + lookup, connector = next(lookups), next(connectors) + child = WhereNode([lookup(col, val)], connector=connector) + node.children.append(child) + node = child + + return root.as_sql(compiler, connection) + + +class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual): + def as_oracle(self, compiler, connection): + # e.g.: (a, b, c) <= (x, y, z) as SQL: + # WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z)))) + cols = self.lhs.get_cols() + lookups = itertools.cycle([LessThan, Exact]) + connectors = itertools.cycle([OR, AND]) + cols_list = [col for col in cols for _ in range(2)] + vals_list = [val for val in self.rhs for _ in range(2)] + cols_iter = iter(cols_list) + vals_iter = iter(vals_list) + col, val = next(cols_iter), next(vals_iter) + lookup, connector = next(lookups), next(connectors) + root = node = WhereNode([lookup(col, val)], connector=connector) + + for col, val in zip(cols_iter, vals_iter): + lookup, connector = next(lookups), next(connectors) + child = WhereNode([lookup(col, val)], connector=connector) + node.children.append(child) + node = child + + return root.as_sql(compiler, connection) + + +class TupleIn(TupleLookupMixin, In): + def check_tuple_lookup(self): + assert isinstance(self.lhs, ColPairs) + self.check_rhs_is_tuple_or_list() + self.check_rhs_is_collection_of_tuples_or_lists() + self.check_rhs_elements_length_equals_lhs_length() + + def as_sql(self, compiler, connection): + if not self.rhs: + raise EmptyResultSet + + # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL: + # WHERE (a, b, c) IN ((x1, y1, z1), (x2, y2, z2)) + rhs = [] + for vals in self.rhs: + rhs.append( + Tuple( + *[ + Value(val, output_field=col.output_field) + for col, val in zip(self.lhs, vals) + ] + ) + ) + + lookup = In(Tuple(self.lhs), Tuple(*rhs)) + return lookup.as_sql(compiler, connection) + + def as_sqlite(self, compiler, connection): + if not self.rhs: + raise EmptyResultSet + + # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL: + # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2) + root = WhereNode([], connector=OR) + cols = self.lhs.get_cols() + + for vals in self.rhs: + lookups = [Exact(col, val) for col, val in zip(cols, vals)] + root.children.append(WhereNode(lookups, connector=AND)) + + return root.as_sql(compiler, connection) + + +tuple_lookups = { + "exact": TupleExact, + "gt": TupleGreaterThan, + "gte": TupleGreaterThanOrEqual, + "lt": TupleLessThan, + "lte": TupleLessThanOrEqual, + "in": TupleIn, + "isnull": TupleIsNull, +} diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 9a57af2bf3..09916277bc 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -23,6 +23,7 @@ from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import ( BaseExpression, Col, + ColPairs, Exists, F, OuterRef, @@ -32,7 +33,6 @@ from django.db.models.expressions import ( Value, ) from django.db.models.fields import Field -from django.db.models.fields.related_lookups import MultiColSource from django.db.models.lookups import Lookup from django.db.models.query_utils import ( Q, @@ -1549,9 +1549,7 @@ class Query(BaseExpression): if len(targets) == 1: col = self._get_col(targets[0], join_info.final_field, alias) else: - col = MultiColSource( - alias, targets, join_info.targets, join_info.final_field - ) + col = ColPairs(alias, targets, join_info.targets, join_info.final_field) else: col = self._get_col(targets[0], join_info.final_field, alias) diff --git a/tests/foreign_object/test_tuple_lookups.py b/tests/foreign_object/test_tuple_lookups.py new file mode 100644 index 0000000000..cf080d084b --- /dev/null +++ b/tests/foreign_object/test_tuple_lookups.py @@ -0,0 +1,242 @@ +import unittest + +from django.db import NotSupportedError, connection +from django.test import TestCase + +from .models import Contact, Customer + + +class TupleLookupsTests(TestCase): + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.customer_1 = Customer.objects.create(customer_id=1, company="a") + cls.customer_2 = Customer.objects.create(customer_id=1, company="b") + cls.customer_3 = Customer.objects.create(customer_id=2, company="c") + cls.customer_4 = Customer.objects.create(customer_id=3, company="d") + cls.customer_5 = Customer.objects.create(customer_id=1, company="e") + cls.contact_1 = Contact.objects.create(customer=cls.customer_1) + cls.contact_2 = Contact.objects.create(customer=cls.customer_1) + cls.contact_3 = Contact.objects.create(customer=cls.customer_2) + cls.contact_4 = Contact.objects.create(customer=cls.customer_3) + cls.contact_5 = Contact.objects.create(customer=cls.customer_1) + cls.contact_6 = Contact.objects.create(customer=cls.customer_5) + + def test_exact(self): + test_cases = ( + (self.customer_1, (self.contact_1, self.contact_2, self.contact_5)), + (self.customer_2, (self.contact_3,)), + (self.customer_3, (self.contact_4,)), + (self.customer_4, ()), + (self.customer_5, (self.contact_6,)), + ) + + for customer, contacts in test_cases: + with self.subTest(customer=customer, contacts=contacts): + self.assertSequenceEqual( + Contact.objects.filter(customer=customer).order_by("id"), contacts + ) + + def test_exact_subquery(self): + with self.assertRaisesMessage( + NotSupportedError, "'exact' doesn't support multi-column subqueries." + ): + subquery = Customer.objects.filter(id=self.customer_1.id)[:1] + self.assertSequenceEqual( + Contact.objects.filter(customer=subquery).order_by("id"), () + ) + + def test_in(self): + cust_1, cust_2, cust_3, cust_4, cust_5 = ( + self.customer_1, + self.customer_2, + self.customer_3, + self.customer_4, + self.customer_5, + ) + c1, c2, c3, c4, c5, c6 = ( + self.contact_1, + self.contact_2, + self.contact_3, + self.contact_4, + self.contact_5, + self.contact_6, + ) + test_cases = ( + ((), ()), + ((cust_1,), (c1, c2, c5)), + ((cust_1, cust_2), (c1, c2, c3, c5)), + ((cust_1, cust_2, cust_3), (c1, c2, c3, c4, c5)), + ((cust_1, cust_2, cust_3, cust_4), (c1, c2, c3, c4, c5)), + ((cust_1, cust_2, cust_3, cust_4, cust_5), (c1, c2, c3, c4, c5, c6)), + ) + + for contacts, customers in test_cases: + with self.subTest(contacts=contacts, customers=customers): + self.assertSequenceEqual( + Contact.objects.filter(customer__in=contacts).order_by("id"), + customers, + ) + + @unittest.skipIf( + connection.vendor == "mysql", + "MySQL doesn't support LIMIT & IN/ALL/ANY/SOME subquery", + ) + def test_in_subquery(self): + subquery = Customer.objects.filter(id=self.customer_1.id)[:1] + self.assertSequenceEqual( + Contact.objects.filter(customer__in=subquery).order_by("id"), + (self.contact_1, self.contact_2, self.contact_5), + ) + + def test_lt(self): + c1, c2, c3, c4, c5, c6 = ( + self.contact_1, + self.contact_2, + self.contact_3, + self.contact_4, + self.contact_5, + self.contact_6, + ) + test_cases = ( + (self.customer_1, ()), + (self.customer_2, (c1, c2, c5)), + (self.customer_5, (c1, c2, c3, c5)), + (self.customer_3, (c1, c2, c3, c5, c6)), + (self.customer_4, (c1, c2, c3, c4, c5, c6)), + ) + + for customer, contacts in test_cases: + with self.subTest(customer=customer, contacts=contacts): + self.assertSequenceEqual( + Contact.objects.filter(customer__lt=customer).order_by("id"), + contacts, + ) + + def test_lt_subquery(self): + with self.assertRaisesMessage( + NotSupportedError, "'lt' doesn't support multi-column subqueries." + ): + subquery = Customer.objects.filter(id=self.customer_1.id)[:1] + self.assertSequenceEqual( + Contact.objects.filter(customer__lt=subquery).order_by("id"), () + ) + + def test_lte(self): + c1, c2, c3, c4, c5, c6 = ( + self.contact_1, + self.contact_2, + self.contact_3, + self.contact_4, + self.contact_5, + self.contact_6, + ) + test_cases = ( + (self.customer_1, (c1, c2, c5)), + (self.customer_2, (c1, c2, c3, c5)), + (self.customer_5, (c1, c2, c3, c5, c6)), + (self.customer_3, (c1, c2, c3, c4, c5, c6)), + (self.customer_4, (c1, c2, c3, c4, c5, c6)), + ) + + for customer, contacts in test_cases: + with self.subTest(customer=customer, contacts=contacts): + self.assertSequenceEqual( + Contact.objects.filter(customer__lte=customer).order_by("id"), + contacts, + ) + + def test_lte_subquery(self): + with self.assertRaisesMessage( + NotSupportedError, "'lte' doesn't support multi-column subqueries." + ): + subquery = Customer.objects.filter(id=self.customer_1.id)[:1] + self.assertSequenceEqual( + Contact.objects.filter(customer__lte=subquery).order_by("id"), () + ) + + def test_gt(self): + test_cases = ( + (self.customer_1, (self.contact_3, self.contact_4, self.contact_6)), + (self.customer_2, (self.contact_4, self.contact_6)), + (self.customer_5, (self.contact_4,)), + (self.customer_3, ()), + (self.customer_4, ()), + ) + + for customer, contacts in test_cases: + with self.subTest(customer=customer, contacts=contacts): + self.assertSequenceEqual( + Contact.objects.filter(customer__gt=customer).order_by("id"), + contacts, + ) + + def test_gt_subquery(self): + with self.assertRaisesMessage( + NotSupportedError, "'gt' doesn't support multi-column subqueries." + ): + subquery = Customer.objects.filter(id=self.customer_1.id)[:1] + self.assertSequenceEqual( + Contact.objects.filter(customer__gt=subquery).order_by("id"), () + ) + + def test_gte(self): + c1, c2, c3, c4, c5, c6 = ( + self.contact_1, + self.contact_2, + self.contact_3, + self.contact_4, + self.contact_5, + self.contact_6, + ) + test_cases = ( + (self.customer_1, (c1, c2, c3, c4, c5, c6)), + (self.customer_2, (c3, c4, c6)), + (self.customer_5, (c4, c6)), + (self.customer_3, (c4,)), + (self.customer_4, ()), + ) + + for customer, contacts in test_cases: + with self.subTest(customer=customer, contacts=contacts): + self.assertSequenceEqual( + Contact.objects.filter(customer__gte=customer).order_by("pk"), + contacts, + ) + + def test_gte_subquery(self): + with self.assertRaisesMessage( + NotSupportedError, "'gte' doesn't support multi-column subqueries." + ): + subquery = Customer.objects.filter(id=self.customer_1.id)[:1] + self.assertSequenceEqual( + Contact.objects.filter(customer__gte=subquery).order_by("id"), () + ) + + def test_isnull(self): + with self.subTest("customer__isnull=True"): + self.assertSequenceEqual( + Contact.objects.filter(customer__isnull=True).order_by("id"), + (), + ) + with self.subTest("customer__isnull=False"): + self.assertSequenceEqual( + Contact.objects.filter(customer__isnull=False).order_by("id"), + ( + self.contact_1, + self.contact_2, + self.contact_3, + self.contact_4, + self.contact_5, + self.contact_6, + ), + ) + + def test_isnull_subquery(self): + with self.assertRaisesMessage( + NotSupportedError, "'isnull' doesn't support multi-column subqueries." + ): + subquery = Customer.objects.filter(id=0)[:1] + self.assertSequenceEqual( + Contact.objects.filter(customer__isnull=subquery).order_by("id"), () + )