1
0
mirror of https://github.com/django/django.git synced 2025-08-21 09:19:12 +00:00

Fixed #35972 -- Fixed lookup crashes after subquery annotations.

This commit is contained in:
Jacob Walls 2024-12-29 01:13:48 -08:00 committed by Sarah Boyce
parent 079d31e698
commit 8914f4703c
11 changed files with 128 additions and 23 deletions

View File

@ -55,7 +55,7 @@ class GISLookup(Lookup):
def get_db_prep_lookup(self, value, connection): def get_db_prep_lookup(self, value, connection):
# get_db_prep_lookup is called by process_rhs from super class # get_db_prep_lookup is called by process_rhs from super class
return ("%s", [connection.ops.Adapter(value)]) return ("%s", (connection.ops.Adapter(value),))
def process_rhs(self, compiler, connection): def process_rhs(self, compiler, connection):
if isinstance(self.rhs, Query): if isinstance(self.rhs, Query):
@ -284,7 +284,7 @@ class RelateLookup(GISLookup):
elif not isinstance(pattern, str) or not self.pattern_regex.match(pattern): elif not isinstance(pattern, str) or not self.pattern_regex.match(pattern):
raise ValueError('Invalid intersection matrix pattern "%s".' % pattern) raise ValueError('Invalid intersection matrix pattern "%s".' % pattern)
sql, params = super().process_rhs(compiler, connection) sql, params = super().process_rhs(compiler, connection)
return sql, [*params, pattern] return sql, (*params, pattern)
@BaseSpatialField.register_lookup @BaseSpatialField.register_lookup
@ -352,7 +352,7 @@ class DWithinLookup(DistanceLookupBase):
dist_sql, dist_params = self.process_distance(compiler, connection) dist_sql, dist_params = self.process_distance(compiler, connection)
self.template_params["value"] = dist_sql self.template_params["value"] = dist_sql
rhs_sql, params = super().process_rhs(compiler, connection) rhs_sql, params = super().process_rhs(compiler, connection)
return rhs_sql, params + dist_params return rhs_sql, (*params, *dist_params)
class DistanceLookupFromFunction(DistanceLookupBase): class DistanceLookupFromFunction(DistanceLookupBase):
@ -367,7 +367,7 @@ class DistanceLookupFromFunction(DistanceLookupBase):
dist_sql, dist_params = self.process_distance(compiler, connection) dist_sql, dist_params = self.process_distance(compiler, connection)
return ( return (
"%(func)s %(op)s %(dist)s" % {"func": sql, "op": self.op, "dist": dist_sql}, "%(func)s %(op)s %(dist)s" % {"func": sql, "op": self.op, "dist": dist_sql},
params + dist_params, (*params, *dist_params),
) )

View File

@ -27,7 +27,7 @@ class SearchVectorExact(Lookup):
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection)
params = lhs_params + rhs_params params = (*lhs_params, *rhs_params)
return "%s @@ %s" % (lhs, rhs), params return "%s @@ %s" % (lhs, rhs), params
@ -148,7 +148,7 @@ class SearchVector(SearchVectorCombinable, Func):
weight_sql, extra_params = compiler.compile(clone.weight) weight_sql, extra_params = compiler.compile(clone.weight)
sql = "setweight({}, {})".format(sql, weight_sql) sql = "setweight({}, {})".format(sql, weight_sql)
return sql, config_params + params + extra_params return sql, (*config_params, *params, *extra_params)
class CombinedSearchVector(SearchVectorCombinable, CombinedExpression): class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
@ -318,13 +318,13 @@ class SearchHeadline(Func):
def as_sql(self, compiler, connection, function=None, template=None): def as_sql(self, compiler, connection, function=None, template=None):
options_sql = "" options_sql = ""
options_params = [] options_params = ()
if self.options: if self.options:
options_params.append( options_params = (
", ".join( ", ".join(
connection.ops.compose_sql(f"{option}=%s", [value]) connection.ops.compose_sql(f"{option}=%s", [value])
for option, value in self.options.items() for option, value in self.options.items()
) ),
) )
options_sql = ", %s" options_sql = ", %s"
sql, params = super().as_sql( sql, params = super().as_sql(

View File

@ -1,6 +1,7 @@
import logging import logging
import operator import operator
from datetime import datetime from datetime import datetime
from itertools import chain
from django.conf import settings from django.conf import settings
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
@ -1160,7 +1161,7 @@ class BaseDatabaseSchemaEditor:
# Combine actions together if we can (e.g. postgres) # Combine actions together if we can (e.g. postgres)
if self.connection.features.supports_combined_alters and actions: if self.connection.features.supports_combined_alters and actions:
sql, params = tuple(zip(*actions)) sql, params = tuple(zip(*actions))
actions = [(", ".join(sql), sum(params, []))] actions = [(", ".join(sql), tuple(chain(*params)))]
# Apply those actions # Apply those actions
for sql, params in actions: for sql, params in actions:
self.execute( self.execute(

View File

@ -1127,7 +1127,7 @@ class Func(SQLiteNumericMixin, Expression):
template = template or data.get("template", self.template) template = template or data.get("template", self.template)
arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner) arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
data["expressions"] = data["field"] = arg_joiner.join(sql_parts) data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
return template % data, params return template % data, tuple(params)
def copy(self): def copy(self):
copy = super().copy() copy = super().copy()
@ -1323,7 +1323,7 @@ class Col(Expression):
alias, column = self.alias, self.target.column alias, column = self.alias, self.target.column
identifiers = (alias, column) if alias else (column,) identifiers = (alias, column) if alias else (column,)
sql = ".".join(map(compiler.quote_name_unless_alias, identifiers)) sql = ".".join(map(compiler.quote_name_unless_alias, identifiers))
return sql, [] return sql, ()
def relabeled_clone(self, relabels): def relabeled_clone(self, relabels):
if self.alias is None: if self.alias is None:

View File

@ -26,7 +26,7 @@ class Cast(Func):
compiler, connection, template=template, **extra_context compiler, connection, template=template, **extra_context
) )
format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f" format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
params.insert(0, format_string) params = (format_string, *params)
return sql, params return sql, params
elif db_type == "date": elif db_type == "date":
template = "date(%(expressions)s)" template = "date(%(expressions)s)"

View File

@ -101,7 +101,7 @@ class Lookup(Expression):
return Value(self.lhs) return Value(self.lhs)
def get_db_prep_lookup(self, value, connection): def get_db_prep_lookup(self, value, connection):
return ("%s", [value]) return ("%s", (value,))
def process_lhs(self, compiler, connection, lhs=None): def process_lhs(self, compiler, connection, lhs=None):
lhs = lhs or self.lhs lhs = lhs or self.lhs
@ -415,7 +415,7 @@ class IExact(BuiltinLookup):
def process_rhs(self, qn, connection): def process_rhs(self, qn, connection):
rhs, params = super().process_rhs(qn, connection) rhs, params = super().process_rhs(qn, connection)
if params: if params:
params[0] = connection.ops.prep_for_iexact_query(params[0]) params = (connection.ops.prep_for_iexact_query(params[0]), *params[1:])
return rhs, params return rhs, params
@ -603,8 +603,9 @@ class PatternLookup(BuiltinLookup):
def process_rhs(self, qn, connection): def process_rhs(self, qn, connection):
rhs, params = super().process_rhs(qn, connection) rhs, params = super().process_rhs(qn, connection)
if self.rhs_is_direct_value() and params and not self.bilateral_transforms: if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
params[0] = self.param_pattern % connection.ops.prep_for_like_query( params = (
params[0] self.param_pattern % connection.ops.prep_for_like_query(params[0]),
*params[1:],
) )
return rhs, params return rhs, params
@ -686,8 +687,9 @@ class Regex(BuiltinLookup):
else: else:
lhs, lhs_params = self.process_lhs(compiler, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params = (*lhs_params, *rhs_params)
sql_template = connection.ops.regex_lookup(self.lookup_name) sql_template = connection.ops.regex_lookup(self.lookup_name)
return sql_template % (lhs, rhs), lhs_params + rhs_params return sql_template % (lhs, rhs), params
@Field.register_lookup @Field.register_lookup

View File

@ -1202,6 +1202,8 @@ calling the appropriate methods on the wrapped expression.
:meth:`~django.db.models.query.QuerySet.reverse()` is called on a :meth:`~django.db.models.query.QuerySet.reverse()` is called on a
queryset. queryset.
.. _writing-your-own-query-expressions:
Writing your own Query Expressions Writing your own Query Expressions
---------------------------------- ----------------------------------
@ -1262,7 +1264,7 @@ Next, we write the method responsible for generating the SQL::
sql_params.extend(params) sql_params.extend(params)
template = template or self.template template = template or self.template
data = {"expressions": ",".join(sql_expressions)} data = {"expressions": ",".join(sql_expressions)}
return template % data, sql_params return template % data, tuple(sql_params)
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection):

View File

@ -427,6 +427,24 @@ Email
significantly, closely examine any custom subclasses that rely on overriding significantly, closely examine any custom subclasses that rely on overriding
undocumented, internal underscore methods. undocumented, internal underscore methods.
Custom ORM expressions should return params as a tuple
------------------------------------------------------
Prior to Django 6.0, :doc:`custom lookups </howto/custom-lookups>` and
:ref:`custom expressions <writing-your-own-query-expressions>` implementing the
``as_sql()`` method (and its supporting methods ``process_lhs()`` and
``process_rhs()``) were allowed to return a sequence of params in either a list
or a tuple. To address the interoperability problems that resulted, the second
return element of the ``as_sql()`` method should now be a tuple::
def as_sql(self, compiler, connection) -> tuple[str, tuple]: ...
If your custom expressions support multiple versions of Django, you should
adjust any pre-processing of parameters to be resilient against either tuples
or lists. For instance, prefer unpacking like this::
params = (*lhs_params, *rhs_params)
Miscellaneous Miscellaneous
------------- -------------

View File

@ -5,6 +5,7 @@ from datetime import date, datetime
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection, models from django.db import connection, models
from django.db.models.fields.related_lookups import RelatedGreaterThan from django.db.models.fields.related_lookups import RelatedGreaterThan
from django.db.models.functions import Lower
from django.db.models.lookups import EndsWith, StartsWith from django.db.models.lookups import EndsWith, StartsWith
from django.test import SimpleTestCase, TestCase, override_settings from django.test import SimpleTestCase, TestCase, override_settings
from django.test.utils import register_lookup from django.test.utils import register_lookup
@ -17,15 +18,15 @@ class Div3Lookup(models.Lookup):
lookup_name = "div3" lookup_name = "div3"
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
lhs, params = self.process_lhs(compiler, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params) params = (*lhs_params, *rhs_params)
return "(%s) %%%% 3 = %s" % (lhs, rhs), params return "(%s) %%%% 3 = %s" % (lhs, rhs), params
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection):
lhs, params = self.process_lhs(compiler, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params) params = (*lhs_params, *rhs_params)
return "mod(%s, 3) = %s" % (lhs, rhs), params return "mod(%s, 3) = %s" % (lhs, rhs), params
@ -249,6 +250,39 @@ class LookupTests(TestCase):
self.assertSequenceEqual(qs1, [a1]) self.assertSequenceEqual(qs1, [a1])
self.assertSequenceEqual(qs2, [a1]) self.assertSequenceEqual(qs2, [a1])
def test_custom_lookup_with_subquery(self):
class NotEqual(models.Lookup):
lookup_name = "ne"
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
# Although combining via (*lhs_params, *rhs_params) would be
# more resilient, the "simple" way works too.
params = lhs_params + rhs_params
return "%s <> %s" % (lhs, rhs), params
author = Author.objects.create(name="Isabella")
with register_lookup(models.Field, NotEqual):
qs = Author.objects.annotate(
unknown_age=models.Subquery(
Author.objects.filter(age__isnull=True)
.order_by("name")
.values("name")[:1]
)
).filter(unknown_age__ne="Plato")
self.assertSequenceEqual(qs, [author])
qs = Author.objects.annotate(
unknown_age=Lower(
Author.objects.filter(age__isnull=True)
.order_by("name")
.values("name")[:1]
)
).filter(unknown_age__ne="plato")
self.assertSequenceEqual(qs, [author])
def test_custom_exact_lookup_none_rhs(self): def test_custom_exact_lookup_none_rhs(self):
""" """
__exact=None is transformed to __isnull=True if a custom lookup class __exact=None is transformed to __isnull=True if a custom lookup class

View File

@ -749,6 +749,24 @@ class BasicExpressionsTests(TestCase):
) )
self.assertCountEqual(subquery_test2, [self.foobar_ltd]) self.assertCountEqual(subquery_test2, [self.foobar_ltd])
def test_lookups_subquery(self):
smallest_company = Company.objects.order_by("num_employees").values("name")[:1]
for lookup in CharField.get_lookups():
if lookup == "isnull":
continue # not allowed, rhs must be a literal boolean.
if (
lookup == "in"
and not connection.features.allow_sliced_subqueries_with_in
):
continue
if lookup == "range":
rhs = (Subquery(smallest_company), Subquery(smallest_company))
else:
rhs = Subquery(smallest_company)
with self.subTest(lookup=lookup):
qs = Company.objects.filter(**{f"name__{lookup}": rhs})
self.assertGreater(len(qs), 0)
def test_uuid_pk_subquery(self): def test_uuid_pk_subquery(self):
u = UUIDPK.objects.create() u = UUIDPK.objects.create()
UUID.objects.create(uuid_fk=u) UUID.objects.create(uuid_fk=u)

View File

@ -2340,6 +2340,36 @@ class OperationTests(OperationTestBase):
pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1) pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1)
self.assertEqual(pony.pink, 3) self.assertEqual(pony.pink, 3)
@skipUnlessDBFeature("supports_expression_defaults")
def test_alter_field_add_database_default_func(self):
app_label = "test_alfladdf"
project_state = self.set_up_test_model(app_label)
operation = migrations.AlterField(
"Pony", "weight", models.FloatField(db_default=Pi())
)
new_state = project_state.clone()
operation.state_forwards(app_label, new_state)
old_weight = project_state.models[app_label, "pony"].fields["weight"]
self.assertIs(old_weight.default, models.NOT_PROVIDED)
self.assertIs(old_weight.db_default, models.NOT_PROVIDED)
new_weight = new_state.models[app_label, "pony"].fields["weight"]
self.assertIs(new_weight.default, models.NOT_PROVIDED)
self.assertIsInstance(new_weight.db_default, Pi)
pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1)
self.assertEqual(pony.weight, 1)
# Alter field.
with connection.schema_editor() as editor:
operation.database_forwards(app_label, editor, project_state, new_state)
pony = new_state.apps.get_model(app_label, "pony").objects.create()
if not connection.features.can_return_columns_from_insert:
pony.refresh_from_db()
self.assertAlmostEqual(pony.weight, math.pi)
# Reversal.
with connection.schema_editor() as editor:
operation.database_backwards(app_label, editor, new_state, project_state)
pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1)
self.assertEqual(pony.weight, 1)
def test_alter_field_change_nullable_to_database_default_not_null(self): def test_alter_field_change_nullable_to_database_default_not_null(self):
""" """
The AlterField operation changing a null field to db_default. The AlterField operation changing a null field to db_default.