mirror of
https://github.com/django/django.git
synced 2024-12-22 17:16:24 +00:00
Refs #373 -- Added tuple lookups.
This commit is contained in:
parent
3dac3271d2
commit
1eac690d25
1
AUTHORS
1
AUTHORS
@ -152,6 +152,7 @@ answer newbie questions, and generally made Django that much better:
|
||||
Ben Lomax <lomax.on.the.run@gmail.com>
|
||||
Ben Slavin <benjamin.slavin@gmail.com>
|
||||
Ben Sturmfels <ben@sturm.com.au>
|
||||
Bendegúz Csirmaz <csirmazbendeguz@gmail.com>
|
||||
Berker Peksag <berker.peksag@gmail.com>
|
||||
Bernd Schlapsi
|
||||
Bernhard Essl <me@bernhardessl.com>
|
||||
|
@ -1295,6 +1295,52 @@ class Col(Expression):
|
||||
) + self.target.get_db_converters(connection)
|
||||
|
||||
|
||||
class ColPairs(Expression):
|
||||
def __init__(self, alias, targets, sources, output_field):
|
||||
super().__init__(output_field=output_field)
|
||||
self.alias, self.targets, self.sources = alias, targets, sources
|
||||
|
||||
def __len__(self):
|
||||
return len(self.targets)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.get_cols())
|
||||
|
||||
def get_cols(self):
|
||||
return [
|
||||
Col(self.alias, target, source)
|
||||
for target, source in zip(self.targets, self.sources)
|
||||
]
|
||||
|
||||
def get_source_expressions(self):
|
||||
return self.get_cols()
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
assert all(isinstance(expr, Col) and expr.alias == self.alias for expr in exprs)
|
||||
self.targets = [col.target for col in exprs]
|
||||
self.sources = [col.field for col in exprs]
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
cols_sql = []
|
||||
cols_params = []
|
||||
cols = self.get_cols()
|
||||
|
||||
for col in cols:
|
||||
sql, params = col.as_sql(compiler, connection)
|
||||
cols_sql.append(sql)
|
||||
cols_params.extend(params)
|
||||
|
||||
return ", ".join(cols_sql), cols_params
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self.__class__(
|
||||
relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
|
||||
)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
|
||||
class Ref(Expression):
|
||||
"""
|
||||
Reference to column alias of the query. For example, Ref('sum_cost') in
|
||||
|
@ -1,3 +1,6 @@
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.expressions import ColPairs
|
||||
from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups
|
||||
from django.db.models.lookups import (
|
||||
Exact,
|
||||
GreaterThan,
|
||||
@ -9,34 +12,6 @@ from django.db.models.lookups import (
|
||||
)
|
||||
|
||||
|
||||
class MultiColSource:
|
||||
contains_aggregate = False
|
||||
contains_over_clause = False
|
||||
|
||||
def __init__(self, alias, targets, sources, field):
|
||||
self.targets, self.sources, self.field, self.alias = (
|
||||
targets,
|
||||
sources,
|
||||
field,
|
||||
alias,
|
||||
)
|
||||
self.output_field = self.field
|
||||
|
||||
def __repr__(self):
|
||||
return "{}({}, {})".format(self.__class__.__name__, self.alias, self.field)
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self.__class__(
|
||||
relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
|
||||
)
|
||||
|
||||
def get_lookup(self, lookup):
|
||||
return self.output_field.get_lookup(lookup)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
|
||||
def get_normalized_value(value, lhs):
|
||||
from django.db.models import Model
|
||||
|
||||
@ -64,7 +39,7 @@ def get_normalized_value(value, lhs):
|
||||
|
||||
class RelatedIn(In):
|
||||
def get_prep_lookup(self):
|
||||
if not isinstance(self.lhs, MultiColSource):
|
||||
if not isinstance(self.lhs, ColPairs):
|
||||
if self.rhs_is_direct_value():
|
||||
# If we get here, we are dealing with single-column relations.
|
||||
self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
|
||||
@ -98,49 +73,33 @@ class RelatedIn(In):
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if isinstance(self.lhs, MultiColSource):
|
||||
if isinstance(self.lhs, ColPairs):
|
||||
# For multicolumn lookups we need to build a multicolumn where clause.
|
||||
# This clause is either a SubqueryConstraint (for values that need
|
||||
# to be compiled to SQL) or an OR-combined list of
|
||||
# (col1 = val1 AND col2 = val2 AND ...) clauses.
|
||||
from django.db.models.sql.where import (
|
||||
AND,
|
||||
OR,
|
||||
SubqueryConstraint,
|
||||
WhereNode,
|
||||
)
|
||||
from django.db.models.sql.where import SubqueryConstraint
|
||||
|
||||
root_constraint = WhereNode(connector=OR)
|
||||
if self.rhs_is_direct_value():
|
||||
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
|
||||
for value in values:
|
||||
value_constraint = WhereNode()
|
||||
for source, target, val in zip(
|
||||
self.lhs.sources, self.lhs.targets, value
|
||||
):
|
||||
lookup_class = target.get_lookup("exact")
|
||||
lookup = lookup_class(
|
||||
target.get_col(self.lhs.alias, source), val
|
||||
)
|
||||
value_constraint.add(lookup, AND)
|
||||
root_constraint.add(value_constraint, OR)
|
||||
lookup = TupleIn(self.lhs, values)
|
||||
return compiler.compile(lookup)
|
||||
else:
|
||||
root_constraint.add(
|
||||
return compiler.compile(
|
||||
SubqueryConstraint(
|
||||
self.lhs.alias,
|
||||
[target.column for target in self.lhs.targets],
|
||||
[source.name for source in self.lhs.sources],
|
||||
self.rhs,
|
||||
),
|
||||
AND,
|
||||
)
|
||||
return root_constraint.as_sql(compiler, connection)
|
||||
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class RelatedLookupMixin:
|
||||
def get_prep_lookup(self):
|
||||
if not isinstance(self.lhs, MultiColSource) and not hasattr(
|
||||
if not isinstance(self.lhs, ColPairs) and not hasattr(
|
||||
self.rhs, "resolve_expression"
|
||||
):
|
||||
# If we get here, we are dealing with single-column relations.
|
||||
@ -158,20 +117,16 @@ class RelatedLookupMixin:
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if isinstance(self.lhs, MultiColSource):
|
||||
assert self.rhs_is_direct_value()
|
||||
self.rhs = get_normalized_value(self.rhs, self.lhs)
|
||||
from django.db.models.sql.where import AND, WhereNode
|
||||
|
||||
root_constraint = WhereNode()
|
||||
for target, source, val in zip(
|
||||
self.lhs.targets, self.lhs.sources, self.rhs
|
||||
):
|
||||
lookup_class = target.get_lookup(self.lookup_name)
|
||||
root_constraint.add(
|
||||
lookup_class(target.get_col(self.lhs.alias, source), val), AND
|
||||
if isinstance(self.lhs, ColPairs):
|
||||
if not self.rhs_is_direct_value():
|
||||
raise NotSupportedError(
|
||||
f"'{self.lookup_name}' doesn't support multi-column subqueries."
|
||||
)
|
||||
return root_constraint.as_sql(compiler, connection)
|
||||
self.rhs = get_normalized_value(self.rhs, self.lhs)
|
||||
lookup_class = tuple_lookups[self.lookup_name]
|
||||
lookup = lookup_class(self.lhs, self.rhs)
|
||||
return compiler.compile(lookup)
|
||||
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
|
244
django/db/models/fields/tuple_lookups.py
Normal file
244
django/db/models/fields/tuple_lookups.py
Normal file
@ -0,0 +1,244 @@
|
||||
import itertools
|
||||
|
||||
from django.core.exceptions import EmptyResultSet
|
||||
from django.db.models.expressions import ColPairs, Func, Value
|
||||
from django.db.models.lookups import (
|
||||
Exact,
|
||||
GreaterThan,
|
||||
GreaterThanOrEqual,
|
||||
In,
|
||||
IsNull,
|
||||
LessThan,
|
||||
LessThanOrEqual,
|
||||
)
|
||||
from django.db.models.sql.where import AND, OR, WhereNode
|
||||
|
||||
|
||||
class Tuple(Func):
|
||||
function = ""
|
||||
|
||||
|
||||
class TupleLookupMixin:
|
||||
def get_prep_lookup(self):
|
||||
self.check_tuple_lookup()
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def check_tuple_lookup(self):
|
||||
assert isinstance(self.lhs, ColPairs)
|
||||
self.check_rhs_is_tuple_or_list()
|
||||
self.check_rhs_length_equals_lhs_length()
|
||||
|
||||
def check_rhs_is_tuple_or_list(self):
|
||||
if not isinstance(self.rhs, (tuple, list)):
|
||||
raise ValueError(
|
||||
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
|
||||
"must be a tuple or a list"
|
||||
)
|
||||
|
||||
def check_rhs_length_equals_lhs_length(self):
|
||||
if len(self.lhs) != len(self.rhs):
|
||||
raise ValueError(
|
||||
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
|
||||
f"must have {len(self.lhs)} elements"
|
||||
)
|
||||
|
||||
def check_rhs_is_collection_of_tuples_or_lists(self):
|
||||
if not all(isinstance(vals, (tuple, list)) for vals in self.rhs):
|
||||
raise ValueError(
|
||||
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
|
||||
f"must be a collection of tuples or lists"
|
||||
)
|
||||
|
||||
def check_rhs_elements_length_equals_lhs_length(self):
|
||||
if not all(len(self.lhs) == len(vals) for vals in self.rhs):
|
||||
raise ValueError(
|
||||
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
|
||||
f"must have {len(self.lhs)} elements each"
|
||||
)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# e.g.: (a, b, c) == (x, y, z) as SQL:
|
||||
# WHERE (a, b, c) = (x, y, z)
|
||||
vals = [
|
||||
Value(val, output_field=col.output_field)
|
||||
for col, val in zip(self.lhs, self.rhs)
|
||||
]
|
||||
lookup_class = self.__class__.__bases__[-1]
|
||||
lookup = lookup_class(Tuple(self.lhs), Tuple(*vals))
|
||||
return lookup.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleExact(TupleLookupMixin, Exact):
|
||||
def as_oracle(self, compiler, connection):
|
||||
# e.g.: (a, b, c) == (x, y, z) as SQL:
|
||||
# WHERE a = x AND b = y AND c = z
|
||||
cols = self.lhs.get_cols()
|
||||
lookups = [Exact(col, val) for col, val in zip(cols, self.rhs)]
|
||||
root = WhereNode(lookups, connector=AND)
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleIsNull(IsNull):
|
||||
def as_sql(self, compiler, connection):
|
||||
# e.g.: (a, b, c) is None as SQL:
|
||||
# WHERE a IS NULL AND b IS NULL AND c IS NULL
|
||||
vals = self.rhs
|
||||
if isinstance(vals, bool):
|
||||
vals = [vals] * len(self.lhs)
|
||||
|
||||
cols = self.lhs.get_cols()
|
||||
lookups = [IsNull(col, val) for col, val in zip(cols, vals)]
|
||||
root = WhereNode(lookups, connector=AND)
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleGreaterThan(TupleLookupMixin, GreaterThan):
|
||||
def as_oracle(self, compiler, connection):
|
||||
# e.g.: (a, b, c) > (x, y, z) as SQL:
|
||||
# WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z)))
|
||||
cols = self.lhs.get_cols()
|
||||
lookups = itertools.cycle([GreaterThan, Exact])
|
||||
connectors = itertools.cycle([OR, AND])
|
||||
cols_list = [col for col in cols for _ in range(2)]
|
||||
vals_list = [val for val in self.rhs for _ in range(2)]
|
||||
cols_iter = iter(cols_list[:-1])
|
||||
vals_iter = iter(vals_list[:-1])
|
||||
col, val = next(cols_iter), next(vals_iter)
|
||||
lookup, connector = next(lookups), next(connectors)
|
||||
root = node = WhereNode([lookup(col, val)], connector=connector)
|
||||
|
||||
for col, val in zip(cols_iter, vals_iter):
|
||||
lookup, connector = next(lookups), next(connectors)
|
||||
child = WhereNode([lookup(col, val)], connector=connector)
|
||||
node.children.append(child)
|
||||
node = child
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
|
||||
def as_oracle(self, compiler, connection):
|
||||
# e.g.: (a, b, c) >= (x, y, z) as SQL:
|
||||
# WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z))))
|
||||
cols = self.lhs.get_cols()
|
||||
lookups = itertools.cycle([GreaterThan, Exact])
|
||||
connectors = itertools.cycle([OR, AND])
|
||||
cols_list = [col for col in cols for _ in range(2)]
|
||||
vals_list = [val for val in self.rhs for _ in range(2)]
|
||||
cols_iter = iter(cols_list)
|
||||
vals_iter = iter(vals_list)
|
||||
col, val = next(cols_iter), next(vals_iter)
|
||||
lookup, connector = next(lookups), next(connectors)
|
||||
root = node = WhereNode([lookup(col, val)], connector=connector)
|
||||
|
||||
for col, val in zip(cols_iter, vals_iter):
|
||||
lookup, connector = next(lookups), next(connectors)
|
||||
child = WhereNode([lookup(col, val)], connector=connector)
|
||||
node.children.append(child)
|
||||
node = child
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleLessThan(TupleLookupMixin, LessThan):
|
||||
def as_oracle(self, compiler, connection):
|
||||
# e.g.: (a, b, c) < (x, y, z) as SQL:
|
||||
# WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z)))
|
||||
cols = self.lhs.get_cols()
|
||||
lookups = itertools.cycle([LessThan, Exact])
|
||||
connectors = itertools.cycle([OR, AND])
|
||||
cols_list = [col for col in cols for _ in range(2)]
|
||||
vals_list = [val for val in self.rhs for _ in range(2)]
|
||||
cols_iter = iter(cols_list[:-1])
|
||||
vals_iter = iter(vals_list[:-1])
|
||||
col, val = next(cols_iter), next(vals_iter)
|
||||
lookup, connector = next(lookups), next(connectors)
|
||||
root = node = WhereNode([lookup(col, val)], connector=connector)
|
||||
|
||||
for col, val in zip(cols_iter, vals_iter):
|
||||
lookup, connector = next(lookups), next(connectors)
|
||||
child = WhereNode([lookup(col, val)], connector=connector)
|
||||
node.children.append(child)
|
||||
node = child
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
|
||||
def as_oracle(self, compiler, connection):
|
||||
# e.g.: (a, b, c) <= (x, y, z) as SQL:
|
||||
# WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z))))
|
||||
cols = self.lhs.get_cols()
|
||||
lookups = itertools.cycle([LessThan, Exact])
|
||||
connectors = itertools.cycle([OR, AND])
|
||||
cols_list = [col for col in cols for _ in range(2)]
|
||||
vals_list = [val for val in self.rhs for _ in range(2)]
|
||||
cols_iter = iter(cols_list)
|
||||
vals_iter = iter(vals_list)
|
||||
col, val = next(cols_iter), next(vals_iter)
|
||||
lookup, connector = next(lookups), next(connectors)
|
||||
root = node = WhereNode([lookup(col, val)], connector=connector)
|
||||
|
||||
for col, val in zip(cols_iter, vals_iter):
|
||||
lookup, connector = next(lookups), next(connectors)
|
||||
child = WhereNode([lookup(col, val)], connector=connector)
|
||||
node.children.append(child)
|
||||
node = child
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleIn(TupleLookupMixin, In):
|
||||
def check_tuple_lookup(self):
|
||||
assert isinstance(self.lhs, ColPairs)
|
||||
self.check_rhs_is_tuple_or_list()
|
||||
self.check_rhs_is_collection_of_tuples_or_lists()
|
||||
self.check_rhs_elements_length_equals_lhs_length()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not self.rhs:
|
||||
raise EmptyResultSet
|
||||
|
||||
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
|
||||
# WHERE (a, b, c) IN ((x1, y1, z1), (x2, y2, z2))
|
||||
rhs = []
|
||||
for vals in self.rhs:
|
||||
rhs.append(
|
||||
Tuple(
|
||||
*[
|
||||
Value(val, output_field=col.output_field)
|
||||
for col, val in zip(self.lhs, vals)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
lookup = In(Tuple(self.lhs), Tuple(*rhs))
|
||||
return lookup.as_sql(compiler, connection)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
if not self.rhs:
|
||||
raise EmptyResultSet
|
||||
|
||||
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
|
||||
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
|
||||
root = WhereNode([], connector=OR)
|
||||
cols = self.lhs.get_cols()
|
||||
|
||||
for vals in self.rhs:
|
||||
lookups = [Exact(col, val) for col, val in zip(cols, vals)]
|
||||
root.children.append(WhereNode(lookups, connector=AND))
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
tuple_lookups = {
|
||||
"exact": TupleExact,
|
||||
"gt": TupleGreaterThan,
|
||||
"gte": TupleGreaterThanOrEqual,
|
||||
"lt": TupleLessThan,
|
||||
"lte": TupleLessThanOrEqual,
|
||||
"in": TupleIn,
|
||||
"isnull": TupleIsNull,
|
||||
}
|
@ -23,6 +23,7 @@ from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.expressions import (
|
||||
BaseExpression,
|
||||
Col,
|
||||
ColPairs,
|
||||
Exists,
|
||||
F,
|
||||
OuterRef,
|
||||
@ -32,7 +33,6 @@ from django.db.models.expressions import (
|
||||
Value,
|
||||
)
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.fields.related_lookups import MultiColSource
|
||||
from django.db.models.lookups import Lookup
|
||||
from django.db.models.query_utils import (
|
||||
Q,
|
||||
@ -1549,9 +1549,7 @@ class Query(BaseExpression):
|
||||
if len(targets) == 1:
|
||||
col = self._get_col(targets[0], join_info.final_field, alias)
|
||||
else:
|
||||
col = MultiColSource(
|
||||
alias, targets, join_info.targets, join_info.final_field
|
||||
)
|
||||
col = ColPairs(alias, targets, join_info.targets, join_info.final_field)
|
||||
else:
|
||||
col = self._get_col(targets[0], join_info.final_field, alias)
|
||||
|
||||
|
242
tests/foreign_object/test_tuple_lookups.py
Normal file
242
tests/foreign_object/test_tuple_lookups.py
Normal file
@ -0,0 +1,242 @@
|
||||
import unittest
|
||||
|
||||
from django.db import NotSupportedError, connection
|
||||
from django.test import TestCase
|
||||
|
||||
from .models import Contact, Customer
|
||||
|
||||
|
||||
class TupleLookupsTests(TestCase):
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
super().setUpTestData()
|
||||
cls.customer_1 = Customer.objects.create(customer_id=1, company="a")
|
||||
cls.customer_2 = Customer.objects.create(customer_id=1, company="b")
|
||||
cls.customer_3 = Customer.objects.create(customer_id=2, company="c")
|
||||
cls.customer_4 = Customer.objects.create(customer_id=3, company="d")
|
||||
cls.customer_5 = Customer.objects.create(customer_id=1, company="e")
|
||||
cls.contact_1 = Contact.objects.create(customer=cls.customer_1)
|
||||
cls.contact_2 = Contact.objects.create(customer=cls.customer_1)
|
||||
cls.contact_3 = Contact.objects.create(customer=cls.customer_2)
|
||||
cls.contact_4 = Contact.objects.create(customer=cls.customer_3)
|
||||
cls.contact_5 = Contact.objects.create(customer=cls.customer_1)
|
||||
cls.contact_6 = Contact.objects.create(customer=cls.customer_5)
|
||||
|
||||
def test_exact(self):
|
||||
test_cases = (
|
||||
(self.customer_1, (self.contact_1, self.contact_2, self.contact_5)),
|
||||
(self.customer_2, (self.contact_3,)),
|
||||
(self.customer_3, (self.contact_4,)),
|
||||
(self.customer_4, ()),
|
||||
(self.customer_5, (self.contact_6,)),
|
||||
)
|
||||
|
||||
for customer, contacts in test_cases:
|
||||
with self.subTest(customer=customer, contacts=contacts):
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer=customer).order_by("id"), contacts
|
||||
)
|
||||
|
||||
def test_exact_subquery(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotSupportedError, "'exact' doesn't support multi-column subqueries."
|
||||
):
|
||||
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer=subquery).order_by("id"), ()
|
||||
)
|
||||
|
||||
def test_in(self):
|
||||
cust_1, cust_2, cust_3, cust_4, cust_5 = (
|
||||
self.customer_1,
|
||||
self.customer_2,
|
||||
self.customer_3,
|
||||
self.customer_4,
|
||||
self.customer_5,
|
||||
)
|
||||
c1, c2, c3, c4, c5, c6 = (
|
||||
self.contact_1,
|
||||
self.contact_2,
|
||||
self.contact_3,
|
||||
self.contact_4,
|
||||
self.contact_5,
|
||||
self.contact_6,
|
||||
)
|
||||
test_cases = (
|
||||
((), ()),
|
||||
((cust_1,), (c1, c2, c5)),
|
||||
((cust_1, cust_2), (c1, c2, c3, c5)),
|
||||
((cust_1, cust_2, cust_3), (c1, c2, c3, c4, c5)),
|
||||
((cust_1, cust_2, cust_3, cust_4), (c1, c2, c3, c4, c5)),
|
||||
((cust_1, cust_2, cust_3, cust_4, cust_5), (c1, c2, c3, c4, c5, c6)),
|
||||
)
|
||||
|
||||
for contacts, customers in test_cases:
|
||||
with self.subTest(contacts=contacts, customers=customers):
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__in=contacts).order_by("id"),
|
||||
customers,
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
connection.vendor == "mysql",
|
||||
"MySQL doesn't support LIMIT & IN/ALL/ANY/SOME subquery",
|
||||
)
|
||||
def test_in_subquery(self):
|
||||
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__in=subquery).order_by("id"),
|
||||
(self.contact_1, self.contact_2, self.contact_5),
|
||||
)
|
||||
|
||||
def test_lt(self):
|
||||
c1, c2, c3, c4, c5, c6 = (
|
||||
self.contact_1,
|
||||
self.contact_2,
|
||||
self.contact_3,
|
||||
self.contact_4,
|
||||
self.contact_5,
|
||||
self.contact_6,
|
||||
)
|
||||
test_cases = (
|
||||
(self.customer_1, ()),
|
||||
(self.customer_2, (c1, c2, c5)),
|
||||
(self.customer_5, (c1, c2, c3, c5)),
|
||||
(self.customer_3, (c1, c2, c3, c5, c6)),
|
||||
(self.customer_4, (c1, c2, c3, c4, c5, c6)),
|
||||
)
|
||||
|
||||
for customer, contacts in test_cases:
|
||||
with self.subTest(customer=customer, contacts=contacts):
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__lt=customer).order_by("id"),
|
||||
contacts,
|
||||
)
|
||||
|
||||
def test_lt_subquery(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotSupportedError, "'lt' doesn't support multi-column subqueries."
|
||||
):
|
||||
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__lt=subquery).order_by("id"), ()
|
||||
)
|
||||
|
||||
def test_lte(self):
|
||||
c1, c2, c3, c4, c5, c6 = (
|
||||
self.contact_1,
|
||||
self.contact_2,
|
||||
self.contact_3,
|
||||
self.contact_4,
|
||||
self.contact_5,
|
||||
self.contact_6,
|
||||
)
|
||||
test_cases = (
|
||||
(self.customer_1, (c1, c2, c5)),
|
||||
(self.customer_2, (c1, c2, c3, c5)),
|
||||
(self.customer_5, (c1, c2, c3, c5, c6)),
|
||||
(self.customer_3, (c1, c2, c3, c4, c5, c6)),
|
||||
(self.customer_4, (c1, c2, c3, c4, c5, c6)),
|
||||
)
|
||||
|
||||
for customer, contacts in test_cases:
|
||||
with self.subTest(customer=customer, contacts=contacts):
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__lte=customer).order_by("id"),
|
||||
contacts,
|
||||
)
|
||||
|
||||
def test_lte_subquery(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotSupportedError, "'lte' doesn't support multi-column subqueries."
|
||||
):
|
||||
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__lte=subquery).order_by("id"), ()
|
||||
)
|
||||
|
||||
def test_gt(self):
|
||||
test_cases = (
|
||||
(self.customer_1, (self.contact_3, self.contact_4, self.contact_6)),
|
||||
(self.customer_2, (self.contact_4, self.contact_6)),
|
||||
(self.customer_5, (self.contact_4,)),
|
||||
(self.customer_3, ()),
|
||||
(self.customer_4, ()),
|
||||
)
|
||||
|
||||
for customer, contacts in test_cases:
|
||||
with self.subTest(customer=customer, contacts=contacts):
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__gt=customer).order_by("id"),
|
||||
contacts,
|
||||
)
|
||||
|
||||
def test_gt_subquery(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotSupportedError, "'gt' doesn't support multi-column subqueries."
|
||||
):
|
||||
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__gt=subquery).order_by("id"), ()
|
||||
)
|
||||
|
||||
def test_gte(self):
|
||||
c1, c2, c3, c4, c5, c6 = (
|
||||
self.contact_1,
|
||||
self.contact_2,
|
||||
self.contact_3,
|
||||
self.contact_4,
|
||||
self.contact_5,
|
||||
self.contact_6,
|
||||
)
|
||||
test_cases = (
|
||||
(self.customer_1, (c1, c2, c3, c4, c5, c6)),
|
||||
(self.customer_2, (c3, c4, c6)),
|
||||
(self.customer_5, (c4, c6)),
|
||||
(self.customer_3, (c4,)),
|
||||
(self.customer_4, ()),
|
||||
)
|
||||
|
||||
for customer, contacts in test_cases:
|
||||
with self.subTest(customer=customer, contacts=contacts):
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__gte=customer).order_by("pk"),
|
||||
contacts,
|
||||
)
|
||||
|
||||
def test_gte_subquery(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotSupportedError, "'gte' doesn't support multi-column subqueries."
|
||||
):
|
||||
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__gte=subquery).order_by("id"), ()
|
||||
)
|
||||
|
||||
def test_isnull(self):
|
||||
with self.subTest("customer__isnull=True"):
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__isnull=True).order_by("id"),
|
||||
(),
|
||||
)
|
||||
with self.subTest("customer__isnull=False"):
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__isnull=False).order_by("id"),
|
||||
(
|
||||
self.contact_1,
|
||||
self.contact_2,
|
||||
self.contact_3,
|
||||
self.contact_4,
|
||||
self.contact_5,
|
||||
self.contact_6,
|
||||
),
|
||||
)
|
||||
|
||||
def test_isnull_subquery(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotSupportedError, "'isnull' doesn't support multi-column subqueries."
|
||||
):
|
||||
subquery = Customer.objects.filter(id=0)[:1]
|
||||
self.assertSequenceEqual(
|
||||
Contact.objects.filter(customer__isnull=subquery).order_by("id"), ()
|
||||
)
|
Loading…
Reference in New Issue
Block a user