Fixed #35972 -- Fixed lookup crashes after subquery annotations.
commit 8914f4703c
parent 079d31e698
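In short: lookups could crash as soon as a Subquery() was involved, either as an annotation on the left-hand side of a filter or as the right-hand value, because the parameters of a compiled subquery and the parameters produced by the lookup machinery disagreed about list versus tuple, and list-only operations (params.extend(...), params[0] = ..., params.insert(0, ...), list + tuple) then raised. The diff below settles on tuples throughout. A rough sketch of the right-hand-side shape, loosely following the test_lookups_subquery test added below (the Company model and its fields are assumed from Django's expressions test app, not defined here):

    from django.db.models import Subquery

    smallest = Company.objects.order_by("num_employees").values("name")[:1]

    # Before this fix, IExact.process_rhs() could do ``params[0] = ...`` on the
    # compiled subquery's parameter tuple and raise TypeError; it now rebuilds
    # the sequence instead.
    Company.objects.filter(name__iexact=Subquery(smallest))

The other crashing shape, a custom lookup applied to a Subquery() annotation, is exercised by the NotEqual regression test added to the custom-lookup tests further down.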
@@ -55,7 +55,7 @@ class GISLookup(Lookup):
     def get_db_prep_lookup(self, value, connection):
         # get_db_prep_lookup is called by process_rhs from super class
-        return ("%s", [connection.ops.Adapter(value)])
+        return ("%s", (connection.ops.Adapter(value),))
 
     def process_rhs(self, compiler, connection):
         if isinstance(self.rhs, Query):
@@ -284,7 +284,7 @@ class RelateLookup(GISLookup):
         elif not isinstance(pattern, str) or not self.pattern_regex.match(pattern):
             raise ValueError('Invalid intersection matrix pattern "%s".' % pattern)
         sql, params = super().process_rhs(compiler, connection)
-        return sql, [*params, pattern]
+        return sql, (*params, pattern)
 
 
 @BaseSpatialField.register_lookup
@@ -352,7 +352,7 @@ class DWithinLookup(DistanceLookupBase):
         dist_sql, dist_params = self.process_distance(compiler, connection)
         self.template_params["value"] = dist_sql
         rhs_sql, params = super().process_rhs(compiler, connection)
-        return rhs_sql, params + dist_params
+        return rhs_sql, (*params, *dist_params)
 
 
 class DistanceLookupFromFunction(DistanceLookupBase):
@@ -367,7 +367,7 @@ class DistanceLookupFromFunction(DistanceLookupBase):
         dist_sql, dist_params = self.process_distance(compiler, connection)
         return (
             "%(func)s %(op)s %(dist)s" % {"func": sql, "op": self.op, "dist": dist_sql},
-            params + dist_params,
+            (*params, *dist_params),
         )
 
 
@@ -27,7 +27,7 @@ class SearchVectorExact(Lookup):
     def as_sql(self, qn, connection):
         lhs, lhs_params = self.process_lhs(qn, connection)
         rhs, rhs_params = self.process_rhs(qn, connection)
-        params = lhs_params + rhs_params
+        params = (*lhs_params, *rhs_params)
         return "%s @@ %s" % (lhs, rhs), params
 
 
@@ -148,7 +148,7 @@ class SearchVector(SearchVectorCombinable, Func):
             weight_sql, extra_params = compiler.compile(clone.weight)
             sql = "setweight({}, {})".format(sql, weight_sql)
 
-        return sql, config_params + params + extra_params
+        return sql, (*config_params, *params, *extra_params)
 
 
 class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
@@ -318,13 +318,13 @@ class SearchHeadline(Func):
 
     def as_sql(self, compiler, connection, function=None, template=None):
         options_sql = ""
-        options_params = []
+        options_params = ()
         if self.options:
-            options_params.append(
+            options_params = (
                 ", ".join(
                     connection.ops.compose_sql(f"{option}=%s", [value])
                     for option, value in self.options.items()
-                )
+                ),
             )
            options_sql = ", %s"
         sql, params = super().as_sql(
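One detail worth noting in the SearchHeadline hunk above: the rewritten branch assigns a one-element tuple, and the trailing comma after the closing parenthesis of ", ".join(...) is what makes it a tuple rather than a parenthesized string. A standalone illustration of that Python subtlety (not Django code; the option names are just examples):

    joined = ", ".join(["StartSel=%s", "StopSel=%s"])

    grouped = (joined)     # parentheses only group the expression: still a str
    one_tuple = (joined,)  # the trailing comma builds a 1-tuple, as the new code does

    assert isinstance(grouped, str)
    assert one_tuple == ("StartSel=%s, StopSel=%s",)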
@@ -1,6 +1,7 @@
 import logging
 import operator
 from datetime import datetime
+from itertools import chain
 
 from django.conf import settings
 from django.core.exceptions import FieldError
@@ -1160,7 +1161,7 @@ class BaseDatabaseSchemaEditor:
         # Combine actions together if we can (e.g. postgres)
         if self.connection.features.supports_combined_alters and actions:
             sql, params = tuple(zip(*actions))
-            actions = [(", ".join(sql), sum(params, []))]
+            actions = [(", ".join(sql), tuple(chain(*params)))]
         # Apply those actions
         for sql, params in actions:
             self.execute(
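The schema-editor change above swaps sum(params, []) for tuple(chain(*params)). That matters once individual actions may carry their parameters as tuples rather than lists: summing onto a list start value raises TypeError at the first tuple, while itertools.chain flattens any mix of sequences. A standalone illustration (not Django code):

    from itertools import chain

    params = ([1, 2], (3,), [4])    # a mix of lists and tuples

    flat = tuple(chain(*params))    # (1, 2, 3, 4) regardless of the mix

    # sum(params, []) would raise here:
    #   TypeError: can only concatenate list (not "tuple") to list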
@@ -1127,7 +1127,7 @@ class Func(SQLiteNumericMixin, Expression):
         template = template or data.get("template", self.template)
         arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
         data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
-        return template % data, params
+        return template % data, tuple(params)
 
     def copy(self):
         copy = super().copy()
@@ -1323,7 +1323,7 @@ class Col(Expression):
         alias, column = self.alias, self.target.column
         identifiers = (alias, column) if alias else (column,)
         sql = ".".join(map(compiler.quote_name_unless_alias, identifiers))
-        return sql, []
+        return sql, ()
 
     def relabeled_clone(self, relabels):
         if self.alias is None:
@@ -26,7 +26,7 @@ class Cast(Func):
                 compiler, connection, template=template, **extra_context
             )
             format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
-            params.insert(0, format_string)
+            params = (format_string, *params)
             return sql, params
         elif db_type == "date":
             template = "date(%(expressions)s)"
@@ -101,7 +101,7 @@ class Lookup(Expression):
             return Value(self.lhs)
 
     def get_db_prep_lookup(self, value, connection):
-        return ("%s", [value])
+        return ("%s", (value,))
 
     def process_lhs(self, compiler, connection, lhs=None):
         lhs = lhs or self.lhs
@@ -415,7 +415,7 @@ class IExact(BuiltinLookup):
     def process_rhs(self, qn, connection):
         rhs, params = super().process_rhs(qn, connection)
         if params:
-            params[0] = connection.ops.prep_for_iexact_query(params[0])
+            params = (connection.ops.prep_for_iexact_query(params[0]), *params[1:])
         return rhs, params
 
 
@@ -603,8 +603,9 @@ class PatternLookup(BuiltinLookup):
     def process_rhs(self, qn, connection):
         rhs, params = super().process_rhs(qn, connection)
         if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
-            params[0] = self.param_pattern % connection.ops.prep_for_like_query(
-                params[0]
-            )
+            params = (
+                self.param_pattern % connection.ops.prep_for_like_query(params[0]),
+                *params[1:],
+            )
         return rhs, params
 
 
@@ -686,8 +687,9 @@ class Regex(BuiltinLookup):
         else:
             lhs, lhs_params = self.process_lhs(compiler, connection)
             rhs, rhs_params = self.process_rhs(compiler, connection)
+            params = (*lhs_params, *rhs_params)
             sql_template = connection.ops.regex_lookup(self.lookup_name)
-            return sql_template % (lhs, rhs), lhs_params + rhs_params
+            return sql_template % (lhs, rhs), params
 
 
 @Field.register_lookup
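The lookup changes above share one idiom: instead of mutating the parameter sequence in place (params[0] = ..., params.insert(0, ...)), they rebuild it, because in-place operations fail once params arrives as a tuple. A hedged sketch of that tuple-safe pattern outside any Django class (the prep_first helper is illustrative only; Django inlines the pattern rather than defining such a function):

    def prep_first(params, prepare):
        # Apply ``prepare`` to the first parameter without mutating ``params``.
        # Works whether ``params`` is a list or a tuple; always returns a tuple.
        if not params:
            return tuple(params)
        return (prepare(params[0]), *params[1:])


    assert prep_first(["plato"], str.upper) == ("PLATO",)
    assert prep_first(("plato",), str.upper) == ("PLATO",)
    assert prep_first((), str.upper) == ()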
@@ -1202,6 +1202,8 @@ calling the appropriate methods on the wrapped expression.
 :meth:`~django.db.models.query.QuerySet.reverse()` is called on a
 queryset.
 
+.. _writing-your-own-query-expressions:
+
 Writing your own Query Expressions
 ----------------------------------
 
@@ -1262,7 +1264,7 @@ Next, we write the method responsible for generating the SQL::
                 sql_params.extend(params)
             template = template or self.template
             data = {"expressions": ",".join(sql_expressions)}
-            return template % data, sql_params
+            return template % data, tuple(sql_params)
 
         def as_oracle(self, compiler, connection):
@@ -427,6 +427,24 @@ Email
 significantly, closely examine any custom subclasses that rely on overriding
 undocumented, internal underscore methods.
 
+Custom ORM expressions should return params as a tuple
+------------------------------------------------------
+
+Prior to Django 6.0, :doc:`custom lookups </howto/custom-lookups>` and
+:ref:`custom expressions <writing-your-own-query-expressions>` implementing the
+``as_sql()`` method (and its supporting methods ``process_lhs()`` and
+``process_rhs()``) were allowed to return a sequence of params in either a list
+or a tuple. To address the interoperability problems that resulted, the second
+return element of the ``as_sql()`` method should now be a tuple::
+
+    def as_sql(self, compiler, connection) -> tuple[str, tuple]: ...
+
+If your custom expressions support multiple versions of Django, you should
+adjust any pre-processing of parameters to be resilient against either tuples
+or lists. For instance, prefer unpacking like this::
+
+    params = (*lhs_params, *rhs_params)
+
 Miscellaneous
 -------------
 
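Putting the release-note advice above into practice, a custom lookup that has to run on several Django versions can combine whatever process_lhs()/process_rhs() hand back by unpacking, which always yields a tuple whether the inputs were lists or tuples. A minimal sketch (the NotEqual lookup mirrors the one used in the regression test added further down; only models.Lookup and Field.register_lookup are Django API here):

    from django.db import models


    class NotEqual(models.Lookup):
        lookup_name = "ne"

        def as_sql(self, compiler, connection):
            lhs, lhs_params = self.process_lhs(compiler, connection)
            rhs, rhs_params = self.process_rhs(compiler, connection)
            # Unpacking accepts lists (older Django) or tuples (Django 6.0+)
            # and always returns a tuple, matching the new contract.
            params = (*lhs_params, *rhs_params)
            return "%s <> %s" % (lhs, rhs), params


    # Register it with models.Field.register_lookup(NotEqual) (or
    # django.test.utils.register_lookup in tests) to enable field__ne=... filters.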
@@ -5,6 +5,7 @@ from datetime import date, datetime
 from django.core.exceptions import FieldError
 from django.db import connection, models
 from django.db.models.fields.related_lookups import RelatedGreaterThan
+from django.db.models.functions import Lower
 from django.db.models.lookups import EndsWith, StartsWith
 from django.test import SimpleTestCase, TestCase, override_settings
 from django.test.utils import register_lookup
@@ -17,15 +18,15 @@ class Div3Lookup(models.Lookup):
     lookup_name = "div3"
 
     def as_sql(self, compiler, connection):
-        lhs, params = self.process_lhs(compiler, connection)
+        lhs, lhs_params = self.process_lhs(compiler, connection)
         rhs, rhs_params = self.process_rhs(compiler, connection)
-        params.extend(rhs_params)
+        params = (*lhs_params, *rhs_params)
         return "(%s) %%%% 3 = %s" % (lhs, rhs), params
 
     def as_oracle(self, compiler, connection):
-        lhs, params = self.process_lhs(compiler, connection)
+        lhs, lhs_params = self.process_lhs(compiler, connection)
         rhs, rhs_params = self.process_rhs(compiler, connection)
-        params.extend(rhs_params)
+        params = (*lhs_params, *rhs_params)
         return "mod(%s, 3) = %s" % (lhs, rhs), params
 
 
@@ -249,6 +250,39 @@ class LookupTests(TestCase):
         self.assertSequenceEqual(qs1, [a1])
         self.assertSequenceEqual(qs2, [a1])
 
+    def test_custom_lookup_with_subquery(self):
+        class NotEqual(models.Lookup):
+            lookup_name = "ne"
+
+            def as_sql(self, compiler, connection):
+                lhs, lhs_params = self.process_lhs(compiler, connection)
+                rhs, rhs_params = self.process_rhs(compiler, connection)
+                # Although combining via (*lhs_params, *rhs_params) would be
+                # more resilient, the "simple" way works too.
+                params = lhs_params + rhs_params
+                return "%s <> %s" % (lhs, rhs), params
+
+        author = Author.objects.create(name="Isabella")
+
+        with register_lookup(models.Field, NotEqual):
+            qs = Author.objects.annotate(
+                unknown_age=models.Subquery(
+                    Author.objects.filter(age__isnull=True)
+                    .order_by("name")
+                    .values("name")[:1]
+                )
+            ).filter(unknown_age__ne="Plato")
+            self.assertSequenceEqual(qs, [author])
+
+            qs = Author.objects.annotate(
+                unknown_age=Lower(
+                    Author.objects.filter(age__isnull=True)
+                    .order_by("name")
+                    .values("name")[:1]
+                )
+            ).filter(unknown_age__ne="plato")
+            self.assertSequenceEqual(qs, [author])
+
     def test_custom_exact_lookup_none_rhs(self):
         """
         __exact=None is transformed to __isnull=True if a custom lookup class
@@ -749,6 +749,24 @@ class BasicExpressionsTests(TestCase):
         )
         self.assertCountEqual(subquery_test2, [self.foobar_ltd])
 
+    def test_lookups_subquery(self):
+        smallest_company = Company.objects.order_by("num_employees").values("name")[:1]
+        for lookup in CharField.get_lookups():
+            if lookup == "isnull":
+                continue  # not allowed, rhs must be a literal boolean.
+            if (
+                lookup == "in"
+                and not connection.features.allow_sliced_subqueries_with_in
+            ):
+                continue
+            if lookup == "range":
+                rhs = (Subquery(smallest_company), Subquery(smallest_company))
+            else:
+                rhs = Subquery(smallest_company)
+            with self.subTest(lookup=lookup):
+                qs = Company.objects.filter(**{f"name__{lookup}": rhs})
+                self.assertGreater(len(qs), 0)
+
     def test_uuid_pk_subquery(self):
         u = UUIDPK.objects.create()
         UUID.objects.create(uuid_fk=u)
@@ -2340,6 +2340,36 @@ class OperationTests(OperationTestBase):
         pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1)
         self.assertEqual(pony.pink, 3)
 
+    @skipUnlessDBFeature("supports_expression_defaults")
+    def test_alter_field_add_database_default_func(self):
+        app_label = "test_alfladdf"
+        project_state = self.set_up_test_model(app_label)
+        operation = migrations.AlterField(
+            "Pony", "weight", models.FloatField(db_default=Pi())
+        )
+        new_state = project_state.clone()
+        operation.state_forwards(app_label, new_state)
+        old_weight = project_state.models[app_label, "pony"].fields["weight"]
+        self.assertIs(old_weight.default, models.NOT_PROVIDED)
+        self.assertIs(old_weight.db_default, models.NOT_PROVIDED)
+        new_weight = new_state.models[app_label, "pony"].fields["weight"]
+        self.assertIs(new_weight.default, models.NOT_PROVIDED)
+        self.assertIsInstance(new_weight.db_default, Pi)
+        pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1)
+        self.assertEqual(pony.weight, 1)
+        # Alter field.
+        with connection.schema_editor() as editor:
+            operation.database_forwards(app_label, editor, project_state, new_state)
+        pony = new_state.apps.get_model(app_label, "pony").objects.create()
+        if not connection.features.can_return_columns_from_insert:
+            pony.refresh_from_db()
+        self.assertAlmostEqual(pony.weight, math.pi)
+        # Reversal.
+        with connection.schema_editor() as editor:
+            operation.database_backwards(app_label, editor, new_state, project_state)
+        pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1)
+        self.assertEqual(pony.weight, 1)
+
     def test_alter_field_change_nullable_to_database_default_not_null(self):
         """
         The AlterField operation changing a null field to db_default.