1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Refs #25367 -- Moved Oracle Exists() handling to contextual methods.

Oracle requires the EXISTS expression to be wrapped in a CASE WHEN in
the following cases.

1. When part of a SELECT clause.
2. When part of a ORDER BY clause.
3. When compared against another expression in the WHERE clause.

This commit moves the systematic CASE WHEN wrapping of Exists.as_oracle
to contextual .select_format, Lookup.as_oracle, and OrderBy.as_oracle
methods in order to avoid unnecessary wrapping.
This commit is contained in:
Simon Charette 2019-08-11 02:12:11 -04:00 committed by Mariusz Felisiak
parent fff5186d32
commit efa1908f66
4 changed files with 40 additions and 7 deletions

View File

@ -288,6 +288,9 @@ class BaseDatabaseFeatures:
# field(s)?
allows_multiple_constraints_on_same_fields = True
# Does the backend support boolean expressions in the SELECT clause?
supports_boolean_expr_in_select_clause = True
def __init__(self, connection):
self.connection = connection

View File

@ -58,3 +58,4 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_partial_indexes = False
supports_slicing_ordering_in_compound = True
allows_multiple_constraints_on_same_fields = False
supports_boolean_expr_in_select_clause = False

View File

@ -1075,12 +1075,11 @@ class Exists(Subquery):
sql = 'NOT {}'.format(sql)
return sql, params
def as_oracle(self, compiler, connection, template=None, **extra_context):
# Oracle doesn't allow EXISTS() in the SELECT list, so wrap it with a
# CASE WHEN expression. Change the template since the When expression
# requires a left hand side (column) to compare against.
sql, params = self.as_sql(compiler, connection, template, **extra_context)
sql = 'CASE WHEN {} THEN 1 ELSE 0 END'.format(sql)
def select_format(self, compiler, sql, params):
# Wrap EXISTS() with a CASE WHEN expression if a database backend
# (e.g. Oracle) doesn't support boolean expression in the SELECT list.
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
sql = 'CASE WHEN {} THEN 1 ELSE 0 END'.format(sql)
return sql, params
@ -1140,6 +1139,19 @@ class OrderBy(BaseExpression):
template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s '
return self.as_sql(compiler, connection, template=template)
def as_oracle(self, compiler, connection):
# Oracle doesn't allow ORDER BY EXISTS() unless it's wrapped in
# a CASE WHEN.
if isinstance(self.expression, Exists):
copy = self.copy()
# XXX: Use Case(When(self.lhs)) once support for boolean
# expressions is added to When.
exists_sql, params = compiler.compile(self.expression)
case_sql = 'CASE WHEN %s THEN 1 ELSE 0 END' % exists_sql
copy.expression = RawSQL(case_sql, params)
return copy.as_sql(compiler, connection)
return self.as_sql(compiler, connection)
def get_group_by_cols(self, alias=None):
cols = []
for source in self.get_source_expressions():

View File

@ -3,7 +3,7 @@ import math
from copy import copy
from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Func, Value
from django.db.models.expressions import Exists, Func, RawSQL, Value
from django.db.models.fields import DateTimeField, Field, IntegerField
from django.db.models.query_utils import RegisterLookupMixin
from django.utils.datastructures import OrderedSet
@ -112,6 +112,23 @@ class Lookup:
def as_sql(self, compiler, connection):
raise NotImplementedError
def as_oracle(self, compiler, connection):
# Oracle doesn't allow EXISTS() to be compared to another expression
# unless it's wrapped in a CASE WHEN.
wrapped = False
exprs = []
for expr in (self.lhs, self.rhs):
if isinstance(expr, Exists):
# XXX: Use Case(When(self.lhs)) once support for boolean
# expressions is added to When.
sql, params = compiler.compile(expr)
sql = 'CASE WHEN %s THEN 1 ELSE 0 END' % sql
expr = RawSQL(sql, params)
wrapped = True
exprs.append(expr)
lookup = type(self)(*exprs) if wrapped else self
return lookup.as_sql(compiler, connection)
@cached_property
def contains_aggregate(self):
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)