1
0
mirror of https://github.com/django/django.git synced 2025-10-25 14:46:09 +00:00

Fixed #16055 -- Fixed crash when filtering against char/text GenericRelation relation on PostgreSQL.

This commit is contained in:
David Wobrock
2023-04-18 10:19:06 +02:00
committed by Mariusz Felisiak
parent 594fcc2b74
commit 9bbf97bcdb
11 changed files with 117 additions and 13 deletions

View File

@@ -8,6 +8,7 @@ import sqlparse
from django.conf import settings
from django.db import NotSupportedError, transaction
from django.db.backends import utils
from django.db.models.expressions import Col
from django.utils import timezone
from django.utils.encoding import force_str
@@ -776,3 +777,9 @@ class BaseDatabaseOperations:
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
return ""
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
lhs_expr = Col(lhs_table, lhs_field)
rhs_expr = Col(rhs_table, rhs_field)
return lhs_expr, rhs_expr

View File

@@ -120,6 +120,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"migrations.test_operations.OperationTests."
"test_alter_field_pk_fk_db_collation",
},
"Oracle doesn't support comparing NCLOB to NUMBER.": {
"generic_relations_regress.tests.GenericRelationTests.test_textlink_filter",
},
}
django_test_expected_failures = {
# A bug in Django/cx_Oracle with respect to string handling (#23843).

View File

@@ -12,6 +12,7 @@ from django.db.backends.postgresql.psycopg_any import (
)
from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict
from django.db.models.functions import Cast
from django.utils.regex_helper import _lazy_re_compile
@@ -413,3 +414,13 @@ class DatabaseOperations(BaseDatabaseOperations):
update_fields,
unique_fields,
)
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
lhs_expr, rhs_expr = super().prepare_join_on_clause(
lhs_table, lhs_field, rhs_table, rhs_field
)
if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection):
rhs_expr = Cast(rhs_expr, lhs_field)
return lhs_expr, rhs_expr

View File

@@ -785,6 +785,14 @@ class ForeignObject(RelatedField):
def get_reverse_joining_columns(self):
return self.get_joining_columns(reverse_join=True)
def get_joining_fields(self, reverse_join=False):
return tuple(
self.reverse_related_fields if reverse_join else self.related_fields
)
def get_reverse_joining_fields(self):
return self.get_joining_fields(reverse_join=True)
def get_extra_descriptor_filter(self, instance):
"""
Return an extra filter condition for related object fetching when

View File

@@ -195,6 +195,9 @@ class ForeignObjectRel(FieldCacheMixin):
def get_joining_columns(self):
return self.field.get_reverse_joining_columns()
def get_joining_fields(self):
return self.field.get_reverse_joining_fields()
def get_extra_restriction(self, alias, related_alias):
return self.field.get_extra_restriction(related_alias, alias)

View File

@@ -61,7 +61,15 @@ class Join:
self.join_type = join_type
# A list of 2-tuples to use in the ON clause of the JOIN.
# Each 2-tuple will create one join condition in the ON clause.
self.join_cols = join_field.get_joining_columns()
if hasattr(join_field, "get_joining_fields"):
self.join_fields = join_field.get_joining_fields()
self.join_cols = tuple(
(lhs_field.column, rhs_field.column)
for lhs_field, rhs_field in self.join_fields
)
else:
self.join_fields = None
self.join_cols = join_field.get_joining_columns()
# Along which field (or ForeignObjectRel in the reverse join case)
self.join_field = join_field
# Is this join nullabled?
@@ -78,18 +86,21 @@ class Join:
params = []
qn = compiler.quote_name_unless_alias
qn2 = connection.ops.quote_name
# Add a join condition for each pair of joining columns.
for lhs_col, rhs_col in self.join_cols:
join_conditions.append(
"%s.%s = %s.%s"
% (
qn(self.parent_alias),
qn2(lhs_col),
qn(self.table_alias),
qn2(rhs_col),
join_fields = self.join_fields or self.join_cols
for lhs, rhs in join_fields:
if isinstance(lhs, str):
lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs))
rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs))
else:
lhs, rhs = connection.ops.prepare_join_on_clause(
self.parent_alias, lhs, self.table_alias, rhs
)
)
lhs_sql, lhs_params = compiler.compile(lhs)
lhs_full_name = lhs_sql % lhs_params
rhs_sql, rhs_params = compiler.compile(rhs)
rhs_full_name = rhs_sql % rhs_params
join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
# Add a single condition inside parentheses for whatever
# get_extra_restriction() returns.