From 8914f4703cf03e2a01683c4ba00f5ae7d3fa449d Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Sun, 29 Dec 2024 01:13:48 -0800 Subject: [PATCH] Fixed #35972 -- Fixed lookup crashes after subquery annotations. --- django/contrib/gis/db/models/lookups.py | 8 ++--- django/contrib/postgres/search.py | 10 +++--- django/db/backends/base/schema.py | 3 +- django/db/models/expressions.py | 4 +-- django/db/models/functions/comparison.py | 2 +- django/db/models/lookups.py | 12 ++++--- docs/ref/models/expressions.txt | 4 ++- docs/releases/6.0.txt | 18 ++++++++++ tests/custom_lookups/tests.py | 42 +++++++++++++++++++++--- tests/expressions/tests.py | 18 ++++++++++ tests/migrations/test_operations.py | 30 +++++++++++++++++ 11 files changed, 128 insertions(+), 23 deletions(-) diff --git a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py index 3d30ffed5c..b9e5e47b27 100644 --- a/django/contrib/gis/db/models/lookups.py +++ b/django/contrib/gis/db/models/lookups.py @@ -55,7 +55,7 @@ class GISLookup(Lookup): def get_db_prep_lookup(self, value, connection): # get_db_prep_lookup is called by process_rhs from super class - return ("%s", [connection.ops.Adapter(value)]) + return ("%s", (connection.ops.Adapter(value),)) def process_rhs(self, compiler, connection): if isinstance(self.rhs, Query): @@ -284,7 +284,7 @@ class RelateLookup(GISLookup): elif not isinstance(pattern, str) or not self.pattern_regex.match(pattern): raise ValueError('Invalid intersection matrix pattern "%s".' % pattern) sql, params = super().process_rhs(compiler, connection) - return sql, [*params, pattern] + return sql, (*params, pattern) @BaseSpatialField.register_lookup @@ -352,7 +352,7 @@ class DWithinLookup(DistanceLookupBase): dist_sql, dist_params = self.process_distance(compiler, connection) self.template_params["value"] = dist_sql rhs_sql, params = super().process_rhs(compiler, connection) - return rhs_sql, params + dist_params + return rhs_sql, (*params, *dist_params) class DistanceLookupFromFunction(DistanceLookupBase): @@ -367,7 +367,7 @@ class DistanceLookupFromFunction(DistanceLookupBase): dist_sql, dist_params = self.process_distance(compiler, connection) return ( "%(func)s %(op)s %(dist)s" % {"func": sql, "op": self.op, "dist": dist_sql}, - params + dist_params, + (*params, *dist_params), ) diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py index 24f7a497da..4ab27605cb 100644 --- a/django/contrib/postgres/search.py +++ b/django/contrib/postgres/search.py @@ -27,7 +27,7 @@ class SearchVectorExact(Lookup): def as_sql(self, qn, connection): lhs, lhs_params = self.process_lhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection) - params = lhs_params + rhs_params + params = (*lhs_params, *rhs_params) return "%s @@ %s" % (lhs, rhs), params @@ -148,7 +148,7 @@ class SearchVector(SearchVectorCombinable, Func): weight_sql, extra_params = compiler.compile(clone.weight) sql = "setweight({}, {})".format(sql, weight_sql) - return sql, config_params + params + extra_params + return sql, (*config_params, *params, *extra_params) class CombinedSearchVector(SearchVectorCombinable, CombinedExpression): @@ -318,13 +318,13 @@ class SearchHeadline(Func): def as_sql(self, compiler, connection, function=None, template=None): options_sql = "" - options_params = [] + options_params = () if self.options: - options_params.append( + options_params = ( ", ".join( connection.ops.compose_sql(f"{option}=%s", [value]) for option, value in self.options.items() - ) + ), ) options_sql = ", %s" sql, params = super().as_sql( diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index 5262864e7f..cc33740195 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -1,6 +1,7 @@ import logging import operator from datetime import datetime +from itertools import chain from django.conf import settings from django.core.exceptions import FieldError @@ -1160,7 +1161,7 @@ class BaseDatabaseSchemaEditor: # Combine actions together if we can (e.g. postgres) if self.connection.features.supports_combined_alters and actions: sql, params = tuple(zip(*actions)) - actions = [(", ".join(sql), sum(params, []))] + actions = [(", ".join(sql), tuple(chain(*params)))] # Apply those actions for sql, params in actions: self.execute( diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 012a7c346b..1168c7ddbd 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1127,7 +1127,7 @@ class Func(SQLiteNumericMixin, Expression): template = template or data.get("template", self.template) arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner) data["expressions"] = data["field"] = arg_joiner.join(sql_parts) - return template % data, params + return template % data, tuple(params) def copy(self): copy = super().copy() @@ -1323,7 +1323,7 @@ class Col(Expression): alias, column = self.alias, self.target.column identifiers = (alias, column) if alias else (column,) sql = ".".join(map(compiler.quote_name_unless_alias, identifiers)) - return sql, [] + return sql, () def relabeled_clone(self, relabels): if self.alias is None: diff --git a/django/db/models/functions/comparison.py b/django/db/models/functions/comparison.py index 11af0c0750..8f7493f2fd 100644 --- a/django/db/models/functions/comparison.py +++ b/django/db/models/functions/comparison.py @@ -26,7 +26,7 @@ class Cast(Func): compiler, connection, template=template, **extra_context ) format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f" - params.insert(0, format_string) + params = (format_string, *params) return sql, params elif db_type == "date": template = "date(%(expressions)s)" diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 51817710e9..65128732fd 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -101,7 +101,7 @@ class Lookup(Expression): return Value(self.lhs) def get_db_prep_lookup(self, value, connection): - return ("%s", [value]) + return ("%s", (value,)) def process_lhs(self, compiler, connection, lhs=None): lhs = lhs or self.lhs @@ -415,7 +415,7 @@ class IExact(BuiltinLookup): def process_rhs(self, qn, connection): rhs, params = super().process_rhs(qn, connection) if params: - params[0] = connection.ops.prep_for_iexact_query(params[0]) + params = (connection.ops.prep_for_iexact_query(params[0]), *params[1:]) return rhs, params @@ -603,8 +603,9 @@ class PatternLookup(BuiltinLookup): def process_rhs(self, qn, connection): rhs, params = super().process_rhs(qn, connection) if self.rhs_is_direct_value() and params and not self.bilateral_transforms: - params[0] = self.param_pattern % connection.ops.prep_for_like_query( - params[0] + params = ( + self.param_pattern % connection.ops.prep_for_like_query(params[0]), + *params[1:], ) return rhs, params @@ -686,8 +687,9 @@ class Regex(BuiltinLookup): else: lhs, lhs_params = self.process_lhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection) + params = (*lhs_params, *rhs_params) sql_template = connection.ops.regex_lookup(self.lookup_name) - return sql_template % (lhs, rhs), lhs_params + rhs_params + return sql_template % (lhs, rhs), params @Field.register_lookup diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 4e8bf266b5..9e0eb3ba65 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -1202,6 +1202,8 @@ calling the appropriate methods on the wrapped expression. :meth:`~django.db.models.query.QuerySet.reverse()` is called on a queryset. +.. _writing-your-own-query-expressions: + Writing your own Query Expressions ---------------------------------- @@ -1262,7 +1264,7 @@ Next, we write the method responsible for generating the SQL:: sql_params.extend(params) template = template or self.template data = {"expressions": ",".join(sql_expressions)} - return template % data, sql_params + return template % data, tuple(sql_params) def as_oracle(self, compiler, connection): diff --git a/docs/releases/6.0.txt b/docs/releases/6.0.txt index 1e45fe9aec..828e2d0354 100644 --- a/docs/releases/6.0.txt +++ b/docs/releases/6.0.txt @@ -427,6 +427,24 @@ Email significantly, closely examine any custom subclasses that rely on overriding undocumented, internal underscore methods. +Custom ORM expressions should return params as a tuple +------------------------------------------------------ + +Prior to Django 6.0, :doc:`custom lookups ` and +:ref:`custom expressions ` implementing the +``as_sql()`` method (and its supporting methods ``process_lhs()`` and +``process_rhs()``) were allowed to return a sequence of params in either a list +or a tuple. To address the interoperability problems that resulted, the second +return element of the ``as_sql()`` method should now be a tuple:: + + def as_sql(self, compiler, connection) -> tuple[str, tuple]: ... + +If your custom expressions support multiple versions of Django, you should +adjust any pre-processing of parameters to be resilient against either tuples +or lists. For instance, prefer unpacking like this:: + + params = (*lhs_params, *rhs_params) + Miscellaneous ------------- diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index 6728c49afd..db985240d7 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -5,6 +5,7 @@ from datetime import date, datetime from django.core.exceptions import FieldError from django.db import connection, models from django.db.models.fields.related_lookups import RelatedGreaterThan +from django.db.models.functions import Lower from django.db.models.lookups import EndsWith, StartsWith from django.test import SimpleTestCase, TestCase, override_settings from django.test.utils import register_lookup @@ -17,15 +18,15 @@ class Div3Lookup(models.Lookup): lookup_name = "div3" def as_sql(self, compiler, connection): - lhs, params = self.process_lhs(compiler, connection) + lhs, lhs_params = self.process_lhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection) - params.extend(rhs_params) + params = (*lhs_params, *rhs_params) return "(%s) %%%% 3 = %s" % (lhs, rhs), params def as_oracle(self, compiler, connection): - lhs, params = self.process_lhs(compiler, connection) + lhs, lhs_params = self.process_lhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection) - params.extend(rhs_params) + params = (*lhs_params, *rhs_params) return "mod(%s, 3) = %s" % (lhs, rhs), params @@ -249,6 +250,39 @@ class LookupTests(TestCase): self.assertSequenceEqual(qs1, [a1]) self.assertSequenceEqual(qs2, [a1]) + def test_custom_lookup_with_subquery(self): + class NotEqual(models.Lookup): + lookup_name = "ne" + + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) + # Although combining via (*lhs_params, *rhs_params) would be + # more resilient, the "simple" way works too. + params = lhs_params + rhs_params + return "%s <> %s" % (lhs, rhs), params + + author = Author.objects.create(name="Isabella") + + with register_lookup(models.Field, NotEqual): + qs = Author.objects.annotate( + unknown_age=models.Subquery( + Author.objects.filter(age__isnull=True) + .order_by("name") + .values("name")[:1] + ) + ).filter(unknown_age__ne="Plato") + self.assertSequenceEqual(qs, [author]) + + qs = Author.objects.annotate( + unknown_age=Lower( + Author.objects.filter(age__isnull=True) + .order_by("name") + .values("name")[:1] + ) + ).filter(unknown_age__ne="plato") + self.assertSequenceEqual(qs, [author]) + def test_custom_exact_lookup_none_rhs(self): """ __exact=None is transformed to __isnull=True if a custom lookup class diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 5f61f65ac0..ac06c212ea 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -749,6 +749,24 @@ class BasicExpressionsTests(TestCase): ) self.assertCountEqual(subquery_test2, [self.foobar_ltd]) + def test_lookups_subquery(self): + smallest_company = Company.objects.order_by("num_employees").values("name")[:1] + for lookup in CharField.get_lookups(): + if lookup == "isnull": + continue # not allowed, rhs must be a literal boolean. + if ( + lookup == "in" + and not connection.features.allow_sliced_subqueries_with_in + ): + continue + if lookup == "range": + rhs = (Subquery(smallest_company), Subquery(smallest_company)) + else: + rhs = Subquery(smallest_company) + with self.subTest(lookup=lookup): + qs = Company.objects.filter(**{f"name__{lookup}": rhs}) + self.assertGreater(len(qs), 0) + def test_uuid_pk_subquery(self): u = UUIDPK.objects.create() UUID.objects.create(uuid_fk=u) diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index 0bf16cb481..ec4b772c13 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -2340,6 +2340,36 @@ class OperationTests(OperationTestBase): pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1) self.assertEqual(pony.pink, 3) + @skipUnlessDBFeature("supports_expression_defaults") + def test_alter_field_add_database_default_func(self): + app_label = "test_alfladdf" + project_state = self.set_up_test_model(app_label) + operation = migrations.AlterField( + "Pony", "weight", models.FloatField(db_default=Pi()) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + old_weight = project_state.models[app_label, "pony"].fields["weight"] + self.assertIs(old_weight.default, models.NOT_PROVIDED) + self.assertIs(old_weight.db_default, models.NOT_PROVIDED) + new_weight = new_state.models[app_label, "pony"].fields["weight"] + self.assertIs(new_weight.default, models.NOT_PROVIDED) + self.assertIsInstance(new_weight.db_default, Pi) + pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1) + self.assertEqual(pony.weight, 1) + # Alter field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + pony = new_state.apps.get_model(app_label, "pony").objects.create() + if not connection.features.can_return_columns_from_insert: + pony.refresh_from_db() + self.assertAlmostEqual(pony.weight, math.pi) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1) + self.assertEqual(pony.weight, 1) + def test_alter_field_change_nullable_to_database_default_not_null(self): """ The AlterField operation changing a null field to db_default.