diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index eeb042b361..20268ee660 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -534,7 +534,7 @@ class Func(Expression): c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save) return c - def as_sql(self, compiler, connection, function=None, template=None): + def as_sql(self, compiler, connection, function=None, template=None, arg_joiner=None, **extra_context): connection.ops.check_expression_support(self) sql_parts = [] params = [] @@ -542,13 +542,19 @@ class Func(Expression): arg_sql, arg_params = compiler.compile(arg) sql_parts.append(arg_sql) params.extend(arg_params) - if function is None: - self.extra['function'] = self.extra.get('function', self.function) + data = self.extra.copy() + data.update(**extra_context) + # Use the first supplied value in this order: the parameter to this + # method, a value supplied in __init__()'s **extra (the value in + # `data`), or the value defined on the class. + if function is not None: + data['function'] = function else: - self.extra['function'] = function - self.extra['expressions'] = self.extra['field'] = self.arg_joiner.join(sql_parts) - template = template or self.extra.get('template', self.template) - return template % self.extra, params + data.setdefault('function', self.function) + template = template or data.get('template', self.template) + arg_joiner = arg_joiner or data.get('arg_joiner', self.arg_joiner) + data['expressions'] = data['field'] = arg_joiner.join(sql_parts) + return template % data, params def as_sqlite(self, compiler, connection): sql, params = self.as_sql(compiler, connection) @@ -778,9 +784,9 @@ class When(Expression): c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save) return c - def as_sql(self, compiler, connection, template=None): + def as_sql(self, compiler, connection, template=None, **extra_context): connection.ops.check_expression_support(self) - template_params = {} + template_params = extra_context sql_params = [] condition_sql, condition_params = compiler.compile(self.condition) template_params['condition'] = condition_sql @@ -822,6 +828,7 @@ class Case(Expression): super(Case, self).__init__(output_field) self.cases = list(cases) self.default = self._parse_expressions(default)[0] + self.extra = extra def __str__(self): return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default) @@ -849,22 +856,24 @@ class Case(Expression): c.cases = c.cases[:] return c - def as_sql(self, compiler, connection, template=None, extra=None): + def as_sql(self, compiler, connection, template=None, case_joiner=None, **extra_context): connection.ops.check_expression_support(self) if not self.cases: return compiler.compile(self.default) - template_params = dict(extra) if extra else {} + template_params = self.extra.copy() + template_params.update(extra_context) case_parts = [] sql_params = [] for case in self.cases: case_sql, case_params = compiler.compile(case) case_parts.append(case_sql) sql_params.extend(case_params) - template_params['cases'] = self.case_joiner.join(case_parts) + case_joiner = case_joiner or self.case_joiner + template_params['cases'] = case_joiner.join(case_parts) default_sql, default_params = compiler.compile(self.default) template_params['default'] = default_sql sql_params.extend(default_params) - template = template or self.template + template = template or template_params.get('template', self.template) sql = template % template_params if self._output_field_or_none is not None: sql = connection.ops.unification_cast_sql(self.output_field) % sql @@ -995,14 +1004,16 @@ class OrderBy(BaseExpression): def get_source_expressions(self): return [self.expression] - def as_sql(self, compiler, connection): + def as_sql(self, compiler, connection, template=None, **extra_context): connection.ops.check_expression_support(self) expression_sql, params = compiler.compile(self.expression) placeholders = { 'expression': expression_sql, 'ordering': 'DESC' if self.descending else 'ASC', } - return (self.template % placeholders).rstrip(), params + placeholders.update(extra_context) + template = template or self.template + return (template % placeholders).rstrip(), params def get_group_by_cols(self): cols = [] diff --git a/django/db/models/functions.py b/django/db/models/functions.py index 8aecc99633..021535f017 100644 --- a/django/db/models/functions.py +++ b/django/db/models/functions.py @@ -43,9 +43,8 @@ class ConcatPair(Func): def as_sqlite(self, compiler, connection): coalesced = self.coalesce() - coalesced.arg_joiner = ' || ' return super(ConcatPair, coalesced).as_sql( - compiler, connection, template='%(expressions)s', + compiler, connection, template='%(expressions)s', arg_joiner=' || ' ) def as_mysql(self, compiler, connection): diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 8dc30a91c4..d7c4105c4f 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -261,12 +261,13 @@ The ``Func`` API is as follows: different number of expressions, ``TypeError`` will be raised. Defaults to ``None``. - .. method:: as_sql(compiler, connection, function=None, template=None) + .. method:: as_sql(compiler, connection, function=None, template=None, arg_joiner=None, **extra_context) Generates the SQL for the database function. - The ``as_vendor()`` methods should use the ``function`` and - ``template`` parameters to customize the SQL as needed. For example: + The ``as_vendor()`` methods should use the ``function``, ``template``, + ``arg_joiner``, and any other ``**extra_context`` parameters to + customize the SQL as needed. For example: .. snippet:: :filename: django/db/models/functions.py @@ -283,6 +284,11 @@ The ``Func`` API is as follows: template="%(function)s('', %(expressions)s)", ) + .. versionchanged:: 1.10 + + Support for the ``arg_joiner`` and ``**extra_context`` parameters + was added. + The ``*expressions`` argument is a list of positional expressions that the function will be applied to. The expressions will be converted to strings, joined together with ``arg_joiner``, and then interpolated into the ``template`` @@ -293,10 +299,10 @@ assumed to be column references and will be wrapped in ``F()`` expressions while other values will be wrapped in ``Value()`` expressions. The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated -into the ``template`` attribute. Note that the keywords ``function`` and -``template`` can be used to replace the ``function`` and ``template`` -attributes respectively, without having to define your own class. -``output_field`` can be used to define the expected return type. +into the ``template`` attribute. The ``function``, ``template``, and +``arg_joiner`` keywords can be used to replace the attributes of the same name +without having to define your own class. ``output_field`` can be used to define +the expected return type. ``Aggregate()`` expressions --------------------------- diff --git a/docs/releases/1.10.txt b/docs/releases/1.10.txt index ebce5b06e5..48a262362e 100644 --- a/docs/releases/1.10.txt +++ b/docs/releases/1.10.txt @@ -212,6 +212,13 @@ Database backends ``DatabaseOperations.fetch_returned_insert_ids()`` to set primary keys on objects created using ``QuerySet.bulk_create()``. +* Added keyword arguments to the ``as_sql()`` methods of various expressions + (``Func``, ``When``, ``Case``, and ``OrderBy``) to allow database backends to + customize them without mutating ``self``, which isn't safe when using + different database backends. See the ``arg_joiner`` and ``**extra_context`` + parameters of :meth:`Func.as_sql() ` for an + example. + Email ~~~~~