1
0
mirror of https://github.com/django/django.git synced 2025-08-21 01:09:13 +00:00

Fixed #35972 -- Fixed lookup crashes after subquery annotations.

This commit is contained in:
Jacob Walls 2024-12-29 01:13:48 -08:00 committed by Sarah Boyce
parent 079d31e698
commit 8914f4703c
11 changed files with 128 additions and 23 deletions

View File

@ -55,7 +55,7 @@ class GISLookup(Lookup):
def get_db_prep_lookup(self, value, connection):
# get_db_prep_lookup is called by process_rhs from super class
return ("%s", [connection.ops.Adapter(value)])
return ("%s", (connection.ops.Adapter(value),))
def process_rhs(self, compiler, connection):
if isinstance(self.rhs, Query):
@ -284,7 +284,7 @@ class RelateLookup(GISLookup):
elif not isinstance(pattern, str) or not self.pattern_regex.match(pattern):
raise ValueError('Invalid intersection matrix pattern "%s".' % pattern)
sql, params = super().process_rhs(compiler, connection)
return sql, [*params, pattern]
return sql, (*params, pattern)
@BaseSpatialField.register_lookup
@ -352,7 +352,7 @@ class DWithinLookup(DistanceLookupBase):
dist_sql, dist_params = self.process_distance(compiler, connection)
self.template_params["value"] = dist_sql
rhs_sql, params = super().process_rhs(compiler, connection)
return rhs_sql, params + dist_params
return rhs_sql, (*params, *dist_params)
class DistanceLookupFromFunction(DistanceLookupBase):
@ -367,7 +367,7 @@ class DistanceLookupFromFunction(DistanceLookupBase):
dist_sql, dist_params = self.process_distance(compiler, connection)
return (
"%(func)s %(op)s %(dist)s" % {"func": sql, "op": self.op, "dist": dist_sql},
params + dist_params,
(*params, *dist_params),
)

View File

@ -27,7 +27,7 @@ class SearchVectorExact(Lookup):
def as_sql(self, qn, connection):
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params = lhs_params + rhs_params
params = (*lhs_params, *rhs_params)
return "%s @@ %s" % (lhs, rhs), params
@ -148,7 +148,7 @@ class SearchVector(SearchVectorCombinable, Func):
weight_sql, extra_params = compiler.compile(clone.weight)
sql = "setweight({}, {})".format(sql, weight_sql)
return sql, config_params + params + extra_params
return sql, (*config_params, *params, *extra_params)
class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
@ -318,13 +318,13 @@ class SearchHeadline(Func):
def as_sql(self, compiler, connection, function=None, template=None):
options_sql = ""
options_params = []
options_params = ()
if self.options:
options_params.append(
options_params = (
", ".join(
connection.ops.compose_sql(f"{option}=%s", [value])
for option, value in self.options.items()
)
),
)
options_sql = ", %s"
sql, params = super().as_sql(

View File

@ -1,6 +1,7 @@
import logging
import operator
from datetime import datetime
from itertools import chain
from django.conf import settings
from django.core.exceptions import FieldError
@ -1160,7 +1161,7 @@ class BaseDatabaseSchemaEditor:
# Combine actions together if we can (e.g. postgres)
if self.connection.features.supports_combined_alters and actions:
sql, params = tuple(zip(*actions))
actions = [(", ".join(sql), sum(params, []))]
actions = [(", ".join(sql), tuple(chain(*params)))]
# Apply those actions
for sql, params in actions:
self.execute(

View File

@ -1127,7 +1127,7 @@ class Func(SQLiteNumericMixin, Expression):
template = template or data.get("template", self.template)
arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
return template % data, params
return template % data, tuple(params)
def copy(self):
copy = super().copy()
@ -1323,7 +1323,7 @@ class Col(Expression):
alias, column = self.alias, self.target.column
identifiers = (alias, column) if alias else (column,)
sql = ".".join(map(compiler.quote_name_unless_alias, identifiers))
return sql, []
return sql, ()
def relabeled_clone(self, relabels):
if self.alias is None:

View File

@ -26,7 +26,7 @@ class Cast(Func):
compiler, connection, template=template, **extra_context
)
format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
params.insert(0, format_string)
params = (format_string, *params)
return sql, params
elif db_type == "date":
template = "date(%(expressions)s)"

View File

@ -101,7 +101,7 @@ class Lookup(Expression):
return Value(self.lhs)
def get_db_prep_lookup(self, value, connection):
return ("%s", [value])
return ("%s", (value,))
def process_lhs(self, compiler, connection, lhs=None):
lhs = lhs or self.lhs
@ -415,7 +415,7 @@ class IExact(BuiltinLookup):
def process_rhs(self, qn, connection):
rhs, params = super().process_rhs(qn, connection)
if params:
params[0] = connection.ops.prep_for_iexact_query(params[0])
params = (connection.ops.prep_for_iexact_query(params[0]), *params[1:])
return rhs, params
@ -603,8 +603,9 @@ class PatternLookup(BuiltinLookup):
def process_rhs(self, qn, connection):
rhs, params = super().process_rhs(qn, connection)
if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
params[0] = self.param_pattern % connection.ops.prep_for_like_query(
params[0]
params = (
self.param_pattern % connection.ops.prep_for_like_query(params[0]),
*params[1:],
)
return rhs, params
@ -686,8 +687,9 @@ class Regex(BuiltinLookup):
else:
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = (*lhs_params, *rhs_params)
sql_template = connection.ops.regex_lookup(self.lookup_name)
return sql_template % (lhs, rhs), lhs_params + rhs_params
return sql_template % (lhs, rhs), params
@Field.register_lookup

View File

@ -1202,6 +1202,8 @@ calling the appropriate methods on the wrapped expression.
:meth:`~django.db.models.query.QuerySet.reverse()` is called on a
queryset.
.. _writing-your-own-query-expressions:
Writing your own Query Expressions
----------------------------------
@ -1262,7 +1264,7 @@ Next, we write the method responsible for generating the SQL::
sql_params.extend(params)
template = template or self.template
data = {"expressions": ",".join(sql_expressions)}
return template % data, sql_params
return template % data, tuple(sql_params)
def as_oracle(self, compiler, connection):

View File

@ -427,6 +427,24 @@ Email
significantly, closely examine any custom subclasses that rely on overriding
undocumented, internal underscore methods.
Custom ORM expressions should return params as a tuple
------------------------------------------------------
Prior to Django 6.0, :doc:`custom lookups </howto/custom-lookups>` and
:ref:`custom expressions <writing-your-own-query-expressions>` implementing the
``as_sql()`` method (and its supporting methods ``process_lhs()`` and
``process_rhs()``) were allowed to return a sequence of params in either a list
or a tuple. To address the interoperability problems that resulted, the second
return element of the ``as_sql()`` method should now be a tuple::
def as_sql(self, compiler, connection) -> tuple[str, tuple]: ...
If your custom expressions support multiple versions of Django, you should
adjust any pre-processing of parameters to be resilient against either tuples
or lists. For instance, prefer unpacking like this::
params = (*lhs_params, *rhs_params)
Miscellaneous
-------------

View File

@ -5,6 +5,7 @@ from datetime import date, datetime
from django.core.exceptions import FieldError
from django.db import connection, models
from django.db.models.fields.related_lookups import RelatedGreaterThan
from django.db.models.functions import Lower
from django.db.models.lookups import EndsWith, StartsWith
from django.test import SimpleTestCase, TestCase, override_settings
from django.test.utils import register_lookup
@ -17,15 +18,15 @@ class Div3Lookup(models.Lookup):
lookup_name = "div3"
def as_sql(self, compiler, connection):
lhs, params = self.process_lhs(compiler, connection)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
params = (*lhs_params, *rhs_params)
return "(%s) %%%% 3 = %s" % (lhs, rhs), params
def as_oracle(self, compiler, connection):
lhs, params = self.process_lhs(compiler, connection)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
params = (*lhs_params, *rhs_params)
return "mod(%s, 3) = %s" % (lhs, rhs), params
@ -249,6 +250,39 @@ class LookupTests(TestCase):
self.assertSequenceEqual(qs1, [a1])
self.assertSequenceEqual(qs2, [a1])
def test_custom_lookup_with_subquery(self):
class NotEqual(models.Lookup):
lookup_name = "ne"
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
# Although combining via (*lhs_params, *rhs_params) would be
# more resilient, the "simple" way works too.
params = lhs_params + rhs_params
return "%s <> %s" % (lhs, rhs), params
author = Author.objects.create(name="Isabella")
with register_lookup(models.Field, NotEqual):
qs = Author.objects.annotate(
unknown_age=models.Subquery(
Author.objects.filter(age__isnull=True)
.order_by("name")
.values("name")[:1]
)
).filter(unknown_age__ne="Plato")
self.assertSequenceEqual(qs, [author])
qs = Author.objects.annotate(
unknown_age=Lower(
Author.objects.filter(age__isnull=True)
.order_by("name")
.values("name")[:1]
)
).filter(unknown_age__ne="plato")
self.assertSequenceEqual(qs, [author])
def test_custom_exact_lookup_none_rhs(self):
"""
__exact=None is transformed to __isnull=True if a custom lookup class

View File

@ -749,6 +749,24 @@ class BasicExpressionsTests(TestCase):
)
self.assertCountEqual(subquery_test2, [self.foobar_ltd])
def test_lookups_subquery(self):
smallest_company = Company.objects.order_by("num_employees").values("name")[:1]
for lookup in CharField.get_lookups():
if lookup == "isnull":
continue # not allowed, rhs must be a literal boolean.
if (
lookup == "in"
and not connection.features.allow_sliced_subqueries_with_in
):
continue
if lookup == "range":
rhs = (Subquery(smallest_company), Subquery(smallest_company))
else:
rhs = Subquery(smallest_company)
with self.subTest(lookup=lookup):
qs = Company.objects.filter(**{f"name__{lookup}": rhs})
self.assertGreater(len(qs), 0)
def test_uuid_pk_subquery(self):
u = UUIDPK.objects.create()
UUID.objects.create(uuid_fk=u)

View File

@ -2340,6 +2340,36 @@ class OperationTests(OperationTestBase):
pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1)
self.assertEqual(pony.pink, 3)
@skipUnlessDBFeature("supports_expression_defaults")
def test_alter_field_add_database_default_func(self):
app_label = "test_alfladdf"
project_state = self.set_up_test_model(app_label)
operation = migrations.AlterField(
"Pony", "weight", models.FloatField(db_default=Pi())
)
new_state = project_state.clone()
operation.state_forwards(app_label, new_state)
old_weight = project_state.models[app_label, "pony"].fields["weight"]
self.assertIs(old_weight.default, models.NOT_PROVIDED)
self.assertIs(old_weight.db_default, models.NOT_PROVIDED)
new_weight = new_state.models[app_label, "pony"].fields["weight"]
self.assertIs(new_weight.default, models.NOT_PROVIDED)
self.assertIsInstance(new_weight.db_default, Pi)
pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1)
self.assertEqual(pony.weight, 1)
# Alter field.
with connection.schema_editor() as editor:
operation.database_forwards(app_label, editor, project_state, new_state)
pony = new_state.apps.get_model(app_label, "pony").objects.create()
if not connection.features.can_return_columns_from_insert:
pony.refresh_from_db()
self.assertAlmostEqual(pony.weight, math.pi)
# Reversal.
with connection.schema_editor() as editor:
operation.database_backwards(app_label, editor, new_state, project_state)
pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1)
self.assertEqual(pony.weight, 1)
def test_alter_field_change_nullable_to_database_default_not_null(self):
"""
The AlterField operation changing a null field to db_default.