diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 70878f92a2..0009de7a60 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -288,6 +288,9 @@ class BaseDatabaseFeatures: # field(s)? allows_multiple_constraints_on_same_fields = True + # Does the backend support boolean expressions in the SELECT clause? + supports_boolean_expr_in_select_clause = True + def __init__(self, connection): self.connection = connection diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 2da11355ed..73a6e86686 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -58,3 +58,4 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_partial_indexes = False supports_slicing_ordering_in_compound = True allows_multiple_constraints_on_same_fields = False + supports_boolean_expr_in_select_clause = False diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 16924be9f6..5f85b47423 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1075,12 +1075,11 @@ class Exists(Subquery): sql = 'NOT {}'.format(sql) return sql, params - def as_oracle(self, compiler, connection, template=None, **extra_context): - # Oracle doesn't allow EXISTS() in the SELECT list, so wrap it with a - # CASE WHEN expression. Change the template since the When expression - # requires a left hand side (column) to compare against. - sql, params = self.as_sql(compiler, connection, template, **extra_context) - sql = 'CASE WHEN {} THEN 1 ELSE 0 END'.format(sql) + def select_format(self, compiler, sql, params): + # Wrap EXISTS() with a CASE WHEN expression if a database backend + # (e.g. Oracle) doesn't support boolean expression in the SELECT list. + if not compiler.connection.features.supports_boolean_expr_in_select_clause: + sql = 'CASE WHEN {} THEN 1 ELSE 0 END'.format(sql) return sql, params @@ -1140,6 +1139,19 @@ class OrderBy(BaseExpression): template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s ' return self.as_sql(compiler, connection, template=template) + def as_oracle(self, compiler, connection): + # Oracle doesn't allow ORDER BY EXISTS() unless it's wrapped in + # a CASE WHEN. + if isinstance(self.expression, Exists): + copy = self.copy() + # XXX: Use Case(When(self.lhs)) once support for boolean + # expressions is added to When. + exists_sql, params = compiler.compile(self.expression) + case_sql = 'CASE WHEN %s THEN 1 ELSE 0 END' % exists_sql + copy.expression = RawSQL(case_sql, params) + return copy.as_sql(compiler, connection) + return self.as_sql(compiler, connection) + def get_group_by_cols(self, alias=None): cols = [] for source in self.get_source_expressions(): diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 186e217083..f76c1e391b 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -3,7 +3,7 @@ import math from copy import copy from django.core.exceptions import EmptyResultSet -from django.db.models.expressions import Func, Value +from django.db.models.expressions import Exists, Func, RawSQL, Value from django.db.models.fields import DateTimeField, Field, IntegerField from django.db.models.query_utils import RegisterLookupMixin from django.utils.datastructures import OrderedSet @@ -112,6 +112,23 @@ class Lookup: def as_sql(self, compiler, connection): raise NotImplementedError + def as_oracle(self, compiler, connection): + # Oracle doesn't allow EXISTS() to be compared to another expression + # unless it's wrapped in a CASE WHEN. + wrapped = False + exprs = [] + for expr in (self.lhs, self.rhs): + if isinstance(expr, Exists): + # XXX: Use Case(When(self.lhs)) once support for boolean + # expressions is added to When. + sql, params = compiler.compile(expr) + sql = 'CASE WHEN %s THEN 1 ELSE 0 END' % sql + expr = RawSQL(sql, params) + wrapped = True + exprs.append(expr) + lookup = type(self)(*exprs) if wrapped else self + return lookup.as_sql(compiler, connection) + @cached_property def contains_aggregate(self): return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)