From 9bbf97bcdb488bb11aebb5bd405549fbec6852cd Mon Sep 17 00:00:00 2001 From: David Wobrock Date: Tue, 18 Apr 2023 10:19:06 +0200 Subject: [PATCH] Fixed #16055 -- Fixed crash when filtering against char/text GenericRelation relation on PostgreSQL. --- django/db/backends/base/operations.py | 7 +++++ django/db/backends/oracle/features.py | 3 ++ django/db/backends/postgresql/operations.py | 11 +++++++ django/db/models/fields/related.py | 8 +++++ django/db/models/fields/reverse_related.py | 3 ++ django/db/models/sql/datastructures.py | 33 +++++++++++++------- tests/backends/base/test_operations.py | 15 +++++++++ tests/backends/postgresql/test_operations.py | 32 ++++++++++++++++++- tests/foreign_object/models/empty_join.py | 2 +- tests/generic_relations_regress/models.py | 2 ++ tests/generic_relations_regress/tests.py | 14 +++++++++ 11 files changed, 117 insertions(+), 13 deletions(-) diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index d2bc336dd8..6f10e31cd5 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -8,6 +8,7 @@ import sqlparse from django.conf import settings from django.db import NotSupportedError, transaction from django.db.backends import utils +from django.db.models.expressions import Col from django.utils import timezone from django.utils.encoding import force_str @@ -776,3 +777,9 @@ class BaseDatabaseOperations: def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): return "" + + def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field): + lhs_expr = Col(lhs_table, lhs_field) + rhs_expr = Col(rhs_table, rhs_field) + + return lhs_expr, rhs_expr diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 3d77a615c8..05dc552a98 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -120,6 +120,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): "migrations.test_operations.OperationTests." "test_alter_field_pk_fk_db_collation", }, + "Oracle doesn't support comparing NCLOB to NUMBER.": { + "generic_relations_regress.tests.GenericRelationTests.test_textlink_filter", + }, } django_test_expected_failures = { # A bug in Django/cx_Oracle with respect to string handling (#23843). diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 18cfcb29cb..aa839f5634 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -12,6 +12,7 @@ from django.db.backends.postgresql.psycopg_any import ( ) from django.db.backends.utils import split_tzname_delta from django.db.models.constants import OnConflict +from django.db.models.functions import Cast from django.utils.regex_helper import _lazy_re_compile @@ -413,3 +414,13 @@ class DatabaseOperations(BaseDatabaseOperations): update_fields, unique_fields, ) + + def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field): + lhs_expr, rhs_expr = super().prepare_join_on_clause( + lhs_table, lhs_field, rhs_table, rhs_field + ) + + if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection): + rhs_expr = Cast(rhs_expr, lhs_field) + + return lhs_expr, rhs_expr diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 0efbe53a0b..7a49861164 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -785,6 +785,14 @@ class ForeignObject(RelatedField): def get_reverse_joining_columns(self): return self.get_joining_columns(reverse_join=True) + def get_joining_fields(self, reverse_join=False): + return tuple( + self.reverse_related_fields if reverse_join else self.related_fields + ) + + def get_reverse_joining_fields(self): + return self.get_joining_fields(reverse_join=True) + def get_extra_descriptor_filter(self, instance): """ Return an extra filter condition for related object fetching when diff --git a/django/db/models/fields/reverse_related.py b/django/db/models/fields/reverse_related.py index b7d82f6258..f3da8f8bf2 100644 --- a/django/db/models/fields/reverse_related.py +++ b/django/db/models/fields/reverse_related.py @@ -195,6 +195,9 @@ class ForeignObjectRel(FieldCacheMixin): def get_joining_columns(self): return self.field.get_reverse_joining_columns() + def get_joining_fields(self): + return self.field.get_reverse_joining_fields() + def get_extra_restriction(self, alias, related_alias): return self.field.get_extra_restriction(related_alias, alias) diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 069eb1a301..46a977188a 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -61,7 +61,15 @@ class Join: self.join_type = join_type # A list of 2-tuples to use in the ON clause of the JOIN. # Each 2-tuple will create one join condition in the ON clause. - self.join_cols = join_field.get_joining_columns() + if hasattr(join_field, "get_joining_fields"): + self.join_fields = join_field.get_joining_fields() + self.join_cols = tuple( + (lhs_field.column, rhs_field.column) + for lhs_field, rhs_field in self.join_fields + ) + else: + self.join_fields = None + self.join_cols = join_field.get_joining_columns() # Along which field (or ForeignObjectRel in the reverse join case) self.join_field = join_field # Is this join nullabled? @@ -78,18 +86,21 @@ class Join: params = [] qn = compiler.quote_name_unless_alias qn2 = connection.ops.quote_name - # Add a join condition for each pair of joining columns. - for lhs_col, rhs_col in self.join_cols: - join_conditions.append( - "%s.%s = %s.%s" - % ( - qn(self.parent_alias), - qn2(lhs_col), - qn(self.table_alias), - qn2(rhs_col), + join_fields = self.join_fields or self.join_cols + for lhs, rhs in join_fields: + if isinstance(lhs, str): + lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs)) + rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs)) + else: + lhs, rhs = connection.ops.prepare_join_on_clause( + self.parent_alias, lhs, self.table_alias, rhs ) - ) + lhs_sql, lhs_params = compiler.compile(lhs) + lhs_full_name = lhs_sql % lhs_params + rhs_sql, rhs_params = compiler.compile(rhs) + rhs_full_name = rhs_sql % rhs_params + join_conditions.append(f"{lhs_full_name} = {rhs_full_name}") # Add a single condition inside parentheses for whatever # get_extra_restriction() returns. diff --git a/tests/backends/base/test_operations.py b/tests/backends/base/test_operations.py index 5260344da7..9d2828c8ce 100644 --- a/tests/backends/base/test_operations.py +++ b/tests/backends/base/test_operations.py @@ -4,6 +4,7 @@ from django.core.management.color import no_style from django.db import NotSupportedError, connection, transaction from django.db.backends.base.operations import BaseDatabaseOperations from django.db.models import DurationField, Value +from django.db.models.expressions import Col from django.test import ( SimpleTestCase, TestCase, @@ -159,6 +160,20 @@ class SimpleDatabaseOperationTests(SimpleTestCase): ): self.ops.datetime_extract_sql(None, None, None, None) + def test_prepare_join_on_clause(self): + author_table = Author._meta.db_table + author_id_field = Author._meta.get_field("id") + book_table = Book._meta.db_table + book_fk_field = Book._meta.get_field("author") + lhs_expr, rhs_expr = self.ops.prepare_join_on_clause( + author_table, + author_id_field, + book_table, + book_fk_field, + ) + self.assertEqual(lhs_expr, Col(author_table, author_id_field)) + self.assertEqual(rhs_expr, Col(book_table, book_fk_field)) + class DatabaseOperationTests(TestCase): def setUp(self): diff --git a/tests/backends/postgresql/test_operations.py b/tests/backends/postgresql/test_operations.py index c2f2417923..632928ff87 100644 --- a/tests/backends/postgresql/test_operations.py +++ b/tests/backends/postgresql/test_operations.py @@ -2,9 +2,11 @@ import unittest from django.core.management.color import no_style from django.db import connection +from django.db.models.expressions import Col +from django.db.models.functions import Cast from django.test import SimpleTestCase -from ..models import Person, Tag +from ..models import Author, Book, Person, Tag @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests.") @@ -48,3 +50,31 @@ class PostgreSQLOperationsTests(SimpleTestCase): ), ['TRUNCATE "backends_person", "backends_tag" RESTART IDENTITY CASCADE;'], ) + + def test_prepare_join_on_clause_same_type(self): + author_table = Author._meta.db_table + author_id_field = Author._meta.get_field("id") + lhs_expr, rhs_expr = connection.ops.prepare_join_on_clause( + author_table, + author_id_field, + author_table, + author_id_field, + ) + self.assertEqual(lhs_expr, Col(author_table, author_id_field)) + self.assertEqual(rhs_expr, Col(author_table, author_id_field)) + + def test_prepare_join_on_clause_different_types(self): + author_table = Author._meta.db_table + author_id_field = Author._meta.get_field("id") + book_table = Book._meta.db_table + book_fk_field = Book._meta.get_field("author") + lhs_expr, rhs_expr = connection.ops.prepare_join_on_clause( + author_table, + author_id_field, + book_table, + book_fk_field, + ) + self.assertEqual(lhs_expr, Col(author_table, author_id_field)) + self.assertEqual( + rhs_expr, Cast(Col(book_table, book_fk_field), author_id_field) + ) diff --git a/tests/foreign_object/models/empty_join.py b/tests/foreign_object/models/empty_join.py index 9c0ada378c..4c3839dcc1 100644 --- a/tests/foreign_object/models/empty_join.py +++ b/tests/foreign_object/models/empty_join.py @@ -50,7 +50,7 @@ class StartsWithRelation(models.ForeignObject): from_field = self.model._meta.get_field(self.from_fields[0]) return StartsWith(to_field.get_col(alias), from_field.get_col(related_alias)) - def get_joining_columns(self, reverse_join=False): + def get_joining_fields(self, reverse_join=False): return () def get_path_info(self, filtered_relation=None): diff --git a/tests/generic_relations_regress/models.py b/tests/generic_relations_regress/models.py index dc55b2a83b..6867747a26 100644 --- a/tests/generic_relations_regress/models.py +++ b/tests/generic_relations_regress/models.py @@ -64,12 +64,14 @@ class CharLink(models.Model): content_type = models.ForeignKey(ContentType, models.CASCADE) object_id = models.CharField(max_length=100) content_object = GenericForeignKey() + value = models.CharField(max_length=250) class TextLink(models.Model): content_type = models.ForeignKey(ContentType, models.CASCADE) object_id = models.TextField() content_object = GenericForeignKey() + value = models.CharField(max_length=250) class OddRelation1(models.Model): diff --git a/tests/generic_relations_regress/tests.py b/tests/generic_relations_regress/tests.py index 9b2f21b88b..b7ecb499eb 100644 --- a/tests/generic_relations_regress/tests.py +++ b/tests/generic_relations_regress/tests.py @@ -72,6 +72,20 @@ class GenericRelationTests(TestCase): TextLink.objects.create(content_object=oddrel) oddrel.delete() + def test_charlink_filter(self): + oddrel = OddRelation1.objects.create(name="clink") + CharLink.objects.create(content_object=oddrel, value="value") + self.assertSequenceEqual( + OddRelation1.objects.filter(clinks__value="value"), [oddrel] + ) + + def test_textlink_filter(self): + oddrel = OddRelation2.objects.create(name="clink") + TextLink.objects.create(content_object=oddrel, value="value") + self.assertSequenceEqual( + OddRelation2.objects.filter(tlinks__value="value"), [oddrel] + ) + def test_coerce_object_id_remote_field_cache_persistence(self): restaurant = Restaurant.objects.create() CharLink.objects.create(content_object=restaurant)