mirror of
https://github.com/django/django.git
synced 2024-12-22 17:16:24 +00:00
Refs #373 -- Added TupleIn subqueries.
This commit is contained in:
parent
611bf6c2e2
commit
f7601aed51
@ -12,6 +12,7 @@ from django.db.models.lookups import (
|
|||||||
LessThan,
|
LessThan,
|
||||||
LessThanOrEqual,
|
LessThanOrEqual,
|
||||||
)
|
)
|
||||||
|
from django.db.models.sql import Query
|
||||||
from django.db.models.sql.where import AND, OR, WhereNode
|
from django.db.models.sql.where import AND, OR, WhereNode
|
||||||
|
|
||||||
|
|
||||||
@ -211,9 +212,14 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
|
|||||||
|
|
||||||
class TupleIn(TupleLookupMixin, In):
|
class TupleIn(TupleLookupMixin, In):
|
||||||
def get_prep_lookup(self):
|
def get_prep_lookup(self):
|
||||||
|
if self.rhs_is_direct_value():
|
||||||
self.check_rhs_is_tuple_or_list()
|
self.check_rhs_is_tuple_or_list()
|
||||||
self.check_rhs_is_collection_of_tuples_or_lists()
|
self.check_rhs_is_collection_of_tuples_or_lists()
|
||||||
self.check_rhs_elements_length_equals_lhs_length()
|
self.check_rhs_elements_length_equals_lhs_length()
|
||||||
|
else:
|
||||||
|
self.check_rhs_is_query()
|
||||||
|
self.check_rhs_select_length_equals_lhs_length()
|
||||||
|
|
||||||
return self.rhs # skip checks from mixin
|
return self.rhs # skip checks from mixin
|
||||||
|
|
||||||
def check_rhs_is_collection_of_tuples_or_lists(self):
|
def check_rhs_is_collection_of_tuples_or_lists(self):
|
||||||
@ -233,6 +239,25 @@ class TupleIn(TupleLookupMixin, In):
|
|||||||
f"must have {len_lhs} elements each"
|
f"must have {len_lhs} elements each"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_rhs_is_query(self):
|
||||||
|
if not isinstance(self.rhs, Query):
|
||||||
|
lhs_str = self.get_lhs_str()
|
||||||
|
rhs_cls = self.rhs.__class__.__name__
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.lookup_name!r} subquery lookup of {lhs_str} "
|
||||||
|
f"must be a Query object (received {rhs_cls!r})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_rhs_select_length_equals_lhs_length(self):
|
||||||
|
len_rhs = len(self.rhs.select)
|
||||||
|
len_lhs = len(self.lhs)
|
||||||
|
if len_rhs != len_lhs:
|
||||||
|
lhs_str = self.get_lhs_str()
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.lookup_name!r} subquery lookup of {lhs_str} "
|
||||||
|
f"must have {len_lhs} fields (received {len_rhs})"
|
||||||
|
)
|
||||||
|
|
||||||
def process_rhs(self, compiler, connection):
|
def process_rhs(self, compiler, connection):
|
||||||
rhs = self.rhs
|
rhs = self.rhs
|
||||||
if not rhs:
|
if not rhs:
|
||||||
@ -255,10 +280,17 @@ class TupleIn(TupleLookupMixin, In):
|
|||||||
|
|
||||||
return Tuple(*result).as_sql(compiler, connection)
|
return Tuple(*result).as_sql(compiler, connection)
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection):
|
||||||
|
if not self.rhs_is_direct_value():
|
||||||
|
return self.as_subquery(compiler, connection)
|
||||||
|
return super().as_sql(compiler, connection)
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection):
|
||||||
rhs = self.rhs
|
rhs = self.rhs
|
||||||
if not rhs:
|
if not rhs:
|
||||||
raise EmptyResultSet
|
raise EmptyResultSet
|
||||||
|
if not self.rhs_is_direct_value():
|
||||||
|
return self.as_subquery(compiler, connection)
|
||||||
|
|
||||||
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
|
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
|
||||||
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
|
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
|
||||||
@ -271,6 +303,9 @@ class TupleIn(TupleLookupMixin, In):
|
|||||||
|
|
||||||
return root.as_sql(compiler, connection)
|
return root.as_sql(compiler, connection)
|
||||||
|
|
||||||
|
def as_subquery(self, compiler, connection):
|
||||||
|
return compiler.compile(In(self.lhs, self.rhs))
|
||||||
|
|
||||||
|
|
||||||
tuple_lookups = {
|
tuple_lookups = {
|
||||||
"exact": TupleExact,
|
"exact": TupleExact,
|
||||||
|
@ -11,6 +11,7 @@ from django.db.models.fields.tuple_lookups import (
|
|||||||
TupleLessThan,
|
TupleLessThan,
|
||||||
TupleLessThanOrEqual,
|
TupleLessThanOrEqual,
|
||||||
)
|
)
|
||||||
|
from django.db.models.lookups import In
|
||||||
from django.test import TestCase, skipUnlessDBFeature
|
from django.test import TestCase, skipUnlessDBFeature
|
||||||
|
|
||||||
from .models import Contact, Customer
|
from .models import Contact, Customer
|
||||||
@ -126,6 +127,46 @@ class TupleLookupsTests(TestCase):
|
|||||||
(self.contact_1, self.contact_2, self.contact_5),
|
(self.contact_1, self.contact_2, self.contact_5),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_tuple_in_subquery_must_be_query(self):
|
||||||
|
lhs = (F("customer_code"), F("company_code"))
|
||||||
|
# If rhs is any non-Query object with an as_sql() function.
|
||||||
|
rhs = In(F("customer_code"), [1, 2, 3])
|
||||||
|
with self.assertRaisesMessage(
|
||||||
|
ValueError,
|
||||||
|
"'in' subquery lookup of ('customer_code', 'company_code') "
|
||||||
|
"must be a Query object (received 'In')",
|
||||||
|
):
|
||||||
|
TupleIn(lhs, rhs)
|
||||||
|
|
||||||
|
def test_tuple_in_subquery_must_have_2_fields(self):
|
||||||
|
lhs = (F("customer_code"), F("company_code"))
|
||||||
|
rhs = Customer.objects.values_list("customer_id").query
|
||||||
|
with self.assertRaisesMessage(
|
||||||
|
ValueError,
|
||||||
|
"'in' subquery lookup of ('customer_code', 'company_code') "
|
||||||
|
"must have 2 fields (received 1)",
|
||||||
|
):
|
||||||
|
TupleIn(lhs, rhs)
|
||||||
|
|
||||||
|
def test_tuple_in_subquery(self):
|
||||||
|
customers = Customer.objects.values_list("customer_id", "company")
|
||||||
|
test_cases = (
|
||||||
|
(self.customer_1, (self.contact_1, self.contact_2, self.contact_5)),
|
||||||
|
(self.customer_2, (self.contact_3,)),
|
||||||
|
(self.customer_3, (self.contact_4,)),
|
||||||
|
(self.customer_4, ()),
|
||||||
|
(self.customer_5, (self.contact_6,)),
|
||||||
|
)
|
||||||
|
|
||||||
|
for customer, contacts in test_cases:
|
||||||
|
lhs = (F("customer_code"), F("company_code"))
|
||||||
|
rhs = customers.filter(id=customer.id).query
|
||||||
|
lookup = TupleIn(lhs, rhs)
|
||||||
|
qs = Contact.objects.filter(lookup).order_by("id")
|
||||||
|
|
||||||
|
with self.subTest(customer=customer.id, query=str(qs.query)):
|
||||||
|
self.assertSequenceEqual(qs, contacts)
|
||||||
|
|
||||||
def test_tuple_in_rhs_must_be_collection_of_tuples_or_lists(self):
|
def test_tuple_in_rhs_must_be_collection_of_tuples_or_lists(self):
|
||||||
test_cases = (
|
test_cases = (
|
||||||
(1, 2, 3),
|
(1, 2, 3),
|
||||||
|
Loading…
Reference in New Issue
Block a user