From 65246de7b1d70d25831ab394c4f4a75813f629fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Modzelewski?= Date: Fri, 2 Jan 2015 02:39:31 +0100 Subject: [PATCH] Fixed #24031 -- Added CASE expressions to the ORM. --- AUTHORS | 1 + django/db/backends/__init__.py | 8 + .../postgresql_psycopg2/operations.py | 12 + django/db/models/__init__.py | 2 +- django/db/models/aggregates.py | 4 +- django/db/models/expressions.py | 160 ++- django/db/models/query_utils.py | 21 + django/db/models/sql/compiler.py | 4 +- django/db/models/sql/query.py | 26 +- django/db/models/sql/where.py | 25 + docs/index.txt | 1 + docs/ref/models/conditional-expressions.txt | 212 ++++ docs/ref/models/expressions.txt | 9 + docs/ref/models/index.txt | 1 + docs/releases/1.8.txt | 10 +- tests/expressions/models.py | 16 + tests/expressions/tests.py | 17 +- tests/expressions_case/__init__.py | 0 tests/expressions_case/models.py | 80 ++ tests/expressions_case/tests.py | 1083 +++++++++++++++++ 20 files changed, 1659 insertions(+), 33 deletions(-) create mode 100644 docs/ref/models/conditional-expressions.txt create mode 100644 tests/expressions_case/__init__.py create mode 100644 tests/expressions_case/models.py create mode 100644 tests/expressions_case/tests.py diff --git a/AUTHORS b/AUTHORS index a151f530e9..b4863849dc 100644 --- a/AUTHORS +++ b/AUTHORS @@ -475,6 +475,7 @@ answer newbie questions, and generally made Django that much better: Michael Thornhill Michal Chruszcz michal@plovarna.cz + Michał Modzelewski Mihai Damian Mihai Preda Mikaël Barbero diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 1c4904a893..fdb4d35d04 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -821,6 +821,14 @@ class BaseDatabaseOperations(object): """ return "SELECT cache_key FROM %s ORDER BY cache_key LIMIT 1 OFFSET %%s" + def unification_cast_sql(self, output_field): + """ + Given a field instance, returns the SQL necessary to cast the result of + a union to that type. Note that the resulting string should contain a + '%s' placeholder for the expression being cast. + """ + return '%s' + def date_extract_sql(self, lookup_type, field_name): """ Given a lookup_type of 'year', 'month' or 'day', returns the SQL that diff --git a/django/db/backends/postgresql_psycopg2/operations.py b/django/db/backends/postgresql_psycopg2/operations.py index 9d5b10d01a..31cbe5919f 100644 --- a/django/db/backends/postgresql_psycopg2/operations.py +++ b/django/db/backends/postgresql_psycopg2/operations.py @@ -5,6 +5,18 @@ from django.db.backends import BaseDatabaseOperations class DatabaseOperations(BaseDatabaseOperations): + def unification_cast_sql(self, output_field): + internal_type = output_field.get_internal_type() + if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"): + # PostgreSQL will resolve a union as type 'text' if input types are + # 'unknown'. + # http://www.postgresql.org/docs/9.4/static/typeconv-union-case.html + # These fields cannot be implicitly cast back in the default + # PostgreSQL configuration so we need to explicitly cast them. + # We must also remove components of the type within brackets: + # varchar(255) -> varchar. + return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0] + return '%s' def date_extract_sql(self, lookup_type, field_name): # http://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 6fe8fe8aae..9348529625 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -4,7 +4,7 @@ import warnings from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured # NOQA from django.db.models.query import Q, QuerySet, Prefetch # NOQA -from django.db.models.expressions import ExpressionNode, F, Value, Func # NOQA +from django.db.models.expressions import ExpressionNode, F, Value, Func, Case, When # NOQA from django.db.models.manager import Manager # NOQA from django.db.models.base import Model # NOQA from django.db.models.aggregates import * # NOQA diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 4b57890cf8..06220123ca 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -14,8 +14,9 @@ class Aggregate(Func): contains_aggregate = True name = None - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): assert len(self.source_expressions) == 1 + # Aggregates are not allowed in UPDATE queries, so ignore for_save c = super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize) if c.source_expressions[0].contains_aggregate and not summarize: name = self.source_expressions[0].name @@ -101,7 +102,6 @@ class Count(Aggregate): def __init__(self, expression, distinct=False, **extra): if expression == '*': expression = Value(expression) - expression._output_field = IntegerField() super(Count, self).__init__( expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 0f55545961..97a2a9071d 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -6,7 +6,7 @@ from django.core.exceptions import FieldError from django.db.backends import utils as backend_utils from django.db.models import fields from django.db.models.constants import LOOKUP_SEP -from django.db.models.query_utils import refs_aggregate +from django.db.models.query_utils import refs_aggregate, Q from django.utils import timezone from django.utils.functional import cached_property @@ -173,7 +173,7 @@ class BaseExpression(object): return True return False - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): """ Provides the chance to do any preprocessing or validation before being added to the query. @@ -380,11 +380,11 @@ class Expression(ExpressionNode): sql = connection.ops.combine_expression(self.connector, expressions) return expression_wrapper % sql, expression_params - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): c = self.copy() c.is_summary = summarize - c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize) - c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize) + c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save) + c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save) return c @@ -426,7 +426,7 @@ class F(CombinableMixin): """ self.name = name - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): return query.resolve_ref(self.name, allow_joins, reuse, summarize) def refs_aggregate(self, existing_aggregates): @@ -465,11 +465,11 @@ class Func(ExpressionNode): for arg in expressions ] - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): c = self.copy() c.is_summary = summarize for pos, arg in enumerate(c.source_expressions): - c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize) + c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save) return c def as_sql(self, compiler, connection, function=None, template=None): @@ -511,12 +511,24 @@ class Value(ExpressionNode): self.value = value def as_sql(self, compiler, connection): - if self.value is None: + val = self.value + # check _output_field to avoid triggering an exception + if self._output_field is not None: + if self.for_save: + val = self.output_field.get_db_prep_save(val, connection=connection) + else: + val = self.output_field.get_db_prep_value(val, connection=connection) + if val is None: # cx_Oracle does not always convert None to the appropriate # NULL type (like in case expressions using numbers), so we # use a literal SQL NULL return 'NULL', [] - return '%s', [self.value] + return '%s', [val] + + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + c = super(Value, self).resolve_expression(query, allow_joins, reuse, summarize, for_save) + c.for_save = for_save + return c def get_group_by_cols(self): return [] @@ -599,6 +611,130 @@ class Ref(ExpressionNode): return [self] +class When(ExpressionNode): + template = 'WHEN %(condition)s THEN %(result)s' + + def __init__(self, condition=None, then=Value(None), **lookups): + if lookups and condition is None: + condition, lookups = Q(**lookups), None + if condition is None or not isinstance(condition, Q) or lookups: + raise TypeError("__init__() takes either a Q object or lookups as keyword arguments") + super(When, self).__init__(output_field=None) + self.condition = condition + self.result = self._parse_expression(then) + + def __str__(self): + return "WHEN %r THEN %r" % (self.condition, self.result) + + def __repr__(self): + return "<%s: %s>" % (self.__class__.__name__, self) + + def get_source_expressions(self): + return [self.condition, self.result] + + def set_source_expressions(self, exprs): + self.condition, self.result = exprs + + def get_source_fields(self): + # We're only interested in the fields of the result expressions. + return [self.result._output_field_or_none] + + def _parse_expression(self, expression): + return expression if hasattr(expression, 'resolve_expression') else F(expression) + + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + c = self.copy() + c.is_summary = summarize + c.condition = c.condition.resolve_expression(query, allow_joins, reuse, summarize, False) + c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save) + return c + + def as_sql(self, compiler, connection, template=None): + template_params = {} + sql_params = [] + condition_sql, condition_params = compiler.compile(self.condition) + template_params['condition'] = condition_sql + sql_params.extend(condition_params) + result_sql, result_params = compiler.compile(self.result) + template_params['result'] = result_sql + sql_params.extend(result_params) + template = template or self.template + return template % template_params, sql_params + + def get_group_by_cols(self): + # This is not a complete expression and cannot be used in GROUP BY. + cols = [] + for source in self.get_source_expressions(): + cols.extend(source.get_group_by_cols()) + return cols + + +class Case(ExpressionNode): + """ + An SQL searched CASE expression: + + CASE + WHEN n > 0 + THEN 'positive' + WHEN n < 0 + THEN 'negative' + ELSE 'zero' + END + """ + template = 'CASE %(cases)s ELSE %(default)s END' + case_joiner = ' ' + + def __init__(self, *cases, **extra): + if not all(isinstance(case, When) for case in cases): + raise TypeError("Positional arguments must all be When objects.") + default = extra.pop('default', Value(None)) + output_field = extra.pop('output_field', None) + super(Case, self).__init__(output_field) + self.cases = list(cases) + self.default = default if hasattr(default, 'resolve_expression') else F(default) + + def __str__(self): + return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default) + + def __repr__(self): + return "<%s: %s>" % (self.__class__.__name__, self) + + def get_source_expressions(self): + return self.cases + [self.default] + + def set_source_expressions(self, exprs): + self.cases = exprs[:-1] + self.default = exprs[-1] + + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + c = self.copy() + c.is_summary = summarize + for pos, case in enumerate(c.cases): + c.cases[pos] = case.resolve_expression(query, allow_joins, reuse, summarize, for_save) + c.default = c.default.resolve_expression(query, allow_joins, reuse, summarize, for_save) + return c + + def as_sql(self, compiler, connection, template=None, extra=None): + if not self.cases: + return compiler.compile(self.default) + template_params = dict(extra) if extra else {} + case_parts = [] + sql_params = [] + for case in self.cases: + case_sql, case_params = compiler.compile(case) + case_parts.append(case_sql) + sql_params.extend(case_params) + template_params['cases'] = self.case_joiner.join(case_parts) + default_sql, default_params = compiler.compile(self.default) + template_params['default'] = default_sql + sql_params.extend(default_params) + template = template or self.template + sql = template % template_params + if self._output_field_or_none is not None: + sql = connection.ops.unification_cast_sql(self.output_field) % sql + return sql, sql_params + + class Date(ExpressionNode): """ Add a date selection column. @@ -615,7 +751,7 @@ class Date(ExpressionNode): def set_source_expressions(self, exprs): self.col, = exprs - def resolve_expression(self, query, allow_joins, reuse, summarize): + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): copy = self.copy() copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize) field = copy.col.output_field @@ -664,7 +800,7 @@ class DateTime(ExpressionNode): def set_source_expressions(self, exprs): self.col, = exprs - def resolve_expression(self, query, allow_joins, reuse, summarize): + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): copy = self.copy() copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize) field = copy.col.output_field diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 3eae98ee65..9e4ddfec58 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -86,6 +86,27 @@ class Q(tree.Node): clone.children.append(child) return clone + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + clause, _ = query._add_q(self, reuse, allow_joins=allow_joins) + return clause + + def refs_aggregate(self, existing_aggregates): + def _refs_aggregate(obj, existing_aggregates): + if not isinstance(obj, tree.Node): + aggregate, aggregate_lookups = refs_aggregate(obj[0].split(LOOKUP_SEP), existing_aggregates) + if not aggregate and hasattr(obj[1], 'refs_aggregate'): + return obj[1].refs_aggregate(existing_aggregates) + return aggregate, aggregate_lookups + for c in obj.children: + aggregate, aggregate_lookups = _refs_aggregate(c, existing_aggregates) + if aggregate: + return aggregate, aggregate_lookups + return False, () + + if not existing_aggregates: + return False + return _refs_aggregate(self, existing_aggregates) + class DeferredAttribute(object): """ diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index b091571835..ebda3be96f 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -998,7 +998,9 @@ class SQLUpdateCompiler(SQLCompiler): values, update_params = [], [] for field, model, val in self.query.values: if hasattr(val, 'resolve_expression'): - val = val.resolve_expression(self.query, allow_joins=False) + val = val.resolve_expression(self.query, allow_joins=False, for_save=True) + if val.contains_aggregate: + raise FieldError("Aggregate functions are not allowed in this query") elif hasattr(val, 'prepare_database_save'): if field.rel: val = val.prepare_database_save(field) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 5d4538b2d9..c5e7eab28c 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -961,7 +961,7 @@ class Query(object): self.append_annotation_mask([alias]) self.annotations[alias] = annotation - def prepare_lookup_value(self, value, lookups, can_reuse): + def prepare_lookup_value(self, value, lookups, can_reuse, allow_joins=True): # Default lookup if none given is exact. used_joins = [] if len(lookups) == 0: @@ -980,7 +980,7 @@ class Query(object): value = value() elif hasattr(value, 'resolve_expression'): pre_joins = self.alias_refcount.copy() - value = value.resolve_expression(self, reuse=can_reuse) + value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)] # Subqueries need to use a different set of aliases than the # outer query. Call bump_prefix to change aliases of the inner @@ -1095,7 +1095,7 @@ class Query(object): (name, lhs.output_field.__class__.__name__)) def build_filter(self, filter_expr, branch_negated=False, current_negated=False, - can_reuse=None, connector=AND): + can_reuse=None, connector=AND, allow_joins=True): """ Builds a WhereNode for a single filter clause, but doesn't add it to this Query. Query.add_q() will then add this filter to the where @@ -1125,10 +1125,12 @@ class Query(object): if not arg: raise FieldError("Cannot parse keyword query %r" % arg) lookups, parts, reffed_aggregate = self.solve_lookup_type(arg) + if not allow_joins and len(parts) > 1: + raise FieldError("Joined field references are not permitted in this query") # Work out the lookup type and remove it from the end of 'parts', # if necessary. - value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse) + value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse, allow_joins) clause = self.where_class() if reffed_aggregate: @@ -1225,11 +1227,11 @@ class Query(object): """ if not self._annotations: return False - if not isinstance(obj, Node): - return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0] - or (hasattr(obj[1], 'refs_aggregate') - and obj[1].refs_aggregate(self.annotations)[0])) - return any(self.need_having(c) for c in obj.children) + if hasattr(obj, 'refs_aggregate'): + return obj.refs_aggregate(self.annotations)[0] + return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0] + or (hasattr(obj[1], 'refs_aggregate') + and obj[1].refs_aggregate(self.annotations)[0])) def split_having_parts(self, q_object, negated=False): """ @@ -1287,7 +1289,7 @@ class Query(object): self.demote_joins(existing_inner) def _add_q(self, q_object, used_aliases, branch_negated=False, - current_negated=False): + current_negated=False, allow_joins=True): """ Adds a Q-object to the current filter. """ @@ -1301,12 +1303,12 @@ class Query(object): if isinstance(child, Node): child_clause, needed_inner = self._add_q( child, used_aliases, branch_negated, - current_negated) + current_negated, allow_joins) joinpromoter.add_votes(needed_inner) else: child_clause, needed_inner = self.build_filter( child, can_reuse=used_aliases, branch_negated=branch_negated, - current_negated=current_negated, connector=connector) + current_negated=current_negated, connector=connector, allow_joins=allow_joins) joinpromoter.add_votes(needed_inner) target_clause.add(child_clause, connector) needed_inner = joinpromoter.update_join_types(self) diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 10445555e0..cbb709dfc8 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -11,6 +11,7 @@ from django.conf import settings from django.db.models.fields import DateTimeField, Field from django.db.models.sql.datastructures import EmptyResultSet, Empty from django.utils.deprecation import RemovedInDjango19Warning +from django.utils.functional import cached_property from django.utils.six.moves import range from django.utils import timezone from django.utils import tree @@ -309,6 +310,30 @@ class WhereNode(tree.Node): clone.children.append(child) return clone + def relabeled_clone(self, change_map): + clone = self.clone() + clone.relabel_aliases(change_map) + return clone + + @cached_property + def contains_aggregate(self): + def _contains_aggregate(obj): + if not isinstance(obj, tree.Node): + return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False) + return any(_contains_aggregate(c) for c in obj.children) + + return _contains_aggregate(self) + + def refs_field(self, aggregate_types, field_types): + def _refs_field(obj, aggregate_types, field_types): + if not isinstance(obj, tree.Node): + if hasattr(obj.rhs, 'refs_field'): + return obj.rhs.refs_field(aggregate_types, field_types) + return False + return any(_refs_field(c, aggregate_types, field_types) for c in obj.children) + + return _refs_field(self, aggregate_types, field_types) + class EmptyWhere(WhereNode): def add(self, data, connector): diff --git a/docs/index.txt b/docs/index.txt index 5204f00085..f4074ca5f2 100644 --- a/docs/index.txt +++ b/docs/index.txt @@ -87,6 +87,7 @@ manipulating the data of your Web application. Learn more about it below: :doc:`Multiple databases ` | :doc:`Custom lookups ` | :doc:`Query Expressions ` | + :doc:`Conditional Expressions ` | :doc:`Database Functions ` * **Other:** diff --git a/docs/ref/models/conditional-expressions.txt b/docs/ref/models/conditional-expressions.txt new file mode 100644 index 0000000000..5692bc1968 --- /dev/null +++ b/docs/ref/models/conditional-expressions.txt @@ -0,0 +1,212 @@ +======================= +Conditional Expressions +======================= + +.. currentmodule:: django.db.models.expressions + +.. versionadded:: 1.8 + +Conditional expressions let you use :keyword:`if` ... :keyword:`elif` ... +:keyword:`else` logic within filters, annotations, aggregations, and updates. A +conditional expression evaluates a series of conditions for each row of a +table and returns the matching result expression. Conditional expressions can +also be combined and nested like other :doc:`expressions `. + +The conditional expression classes +================================== + +We'll be using the following model in the subsequent examples:: + + from django.db import models + + class Client(models.Model): + REGULAR = 'R' + GOLD = 'G' + PLATINUM = 'P' + ACCOUNT_TYPE_CHOICES = ( + (REGULAR, 'Regular'), + (GOLD, 'Gold'), + (PLATINUM, 'Platinum'), + ) + name = models.CharField(max_length=50) + registered_on = models.DateField() + account_type = models.CharField( + max_length=1, + choices=ACCOUNT_TYPE_CHOICES, + default=REGULAR, + ) + +When +---- + +.. class:: When(condition=None, then=Value(None), **lookups) + +A ``When()`` object is used to encapsulate a condition and its result for use +in the conditional expression. Using a ``When()`` object is similar to using +the :meth:`~django.db.models.query.QuerySet.filter` method. The condition can +be specified using :ref:`field lookups ` or +:class:`~django.db.models.Q` objects. The result is provided using the ``then`` +keyword. + +Some examples:: + + >>> from django.db.models import When, F, Q + >>> # String arguments refer to fields; the following two examples are equivalent: + >>> When(account_type=Client.GOLD, then='name') + >>> When(account_type=Client.GOLD, then=F('name')) + >>> # You can use field lookups in the condition + >>> from datetime import date + >>> When(registered_on__gt=date(2014, 1, 1), + ... registered_on__lt=date(2015, 1, 1), + ... then='account_type') + >>> # Complex conditions can be created using Q objects + >>> When(Q(name__startswith="John") | Q(name__startswith="Paul"), + ... then='name') + +Keep in mind that each of these values can be an expression. + +.. note:: + + Since the ``then`` keyword argument is reserved for the result of the + ``When()``, there is a potential conflict if a + :class:`~django.db.models.Model` has a field named ``then``. This can be + resolved in two ways:: + + >>> from django.db.models import Value + >>> When(then__exact=0, then=Value(1)) + >>> When(Q(then=0), then=Value(1)) + +Case +---- + +.. class:: Case(*cases, **extra) + +A ``Case()`` expression is like the :keyword:`if` ... :keyword:`elif` ... +:keyword:`else` statement in ``Python``. Each ``condition`` in the provided +``When()`` objects is evaluated in order, until one evaluates to a +truthful value. The ``result`` expression from the matching ``When()`` object +is returned. + +A simple example:: + + >>> + >>> from datetime import date, timedelta + >>> from django.db.models import CharField, Case, Value, When + >>> Client.objects.create( + ... name='Jane Doe', + ... account_type=Client.REGULAR, + ... registered_on=date.today() - timedelta(days=36)) + >>> Client.objects.create( + ... name='James Smith', + ... account_type=Client.GOLD, + ... registered_on=date.today() - timedelta(days=5)) + >>> Client.objects.create( + ... name='Jack Black', + ... account_type=Client.PLATINUM, + ... registered_on=date.today() - timedelta(days=10 * 365)) + >>> # Get the discount for each Client based on the account type + >>> Client.objects.annotate( + ... discount=Case( + ... When(account_type=Client.GOLD, then=Value('5%')), + ... When(account_type=Client.PLATINUM, then=Value('10%')), + ... default=Value('0%'), + ... output_field=CharField(), + ... ), + ... ).values_list('name', 'discount') + [('Jane Doe', '0%'), ('James Smith', '5%'), ('Jack Black', '10%')] + +``Case()`` accepts any number of ``When()`` objects as individual arguments. +Other options are provided using keyword arguments. If none of the conditions +evaluate to ``TRUE``, then the expression given with the ``default`` keyword +argument is returned. If no ``default`` argument is provided, ``Value(None)`` +is used. + +If we wanted to change our previous query to get the discount based on how long +the ``Client`` has been with us, we could do so using lookups:: + + >>> a_month_ago = date.today() - timedelta(days=30) + >>> a_year_ago = date.today() - timedelta(days=365) + >>> # Get the discount for each Client based on the registration date + >>> Client.objects.annotate( + ... discount=Case( + ... When(registered_on__lte=a_year_ago, then=Value('10%')), + ... When(registered_on__lte=a_month_ago, then=Value('5%')), + ... default=Value('0%'), + ... output_field=CharField(), + ... ) + ... ).values_list('name', 'discount') + [('Jane Doe', '5%'), ('James Smith', '0%'), ('Jack Black', '10%')] + +.. note:: + + Remember that the conditions are evaluated in order, so in the above + example we get the correct result even though the second condition matches + both Jane Doe and Jack Black. This works just like an :keyword:`if` ... + :keyword:`elif` ... :keyword:`else` statement in ``Python``. + +Advanced queries +================ + +Conditional expressions can be used in annotations, aggregations, lookups, and +updates. They can also be combined and nested with other expressions. This +allows you to make powerful conditional queries. + +Conditional update +------------------ + +Let's say we want to change the ``account_type`` for our clients to match +their registration dates. We can do this using a conditional expression and the +:meth:`~django.db.models.query.QuerySet.update` method:: + + >>> a_month_ago = date.today() - timedelta(days=30) + >>> a_year_ago = date.today() - timedelta(days=365) + >>> # Update the account_type for each Client from the registration date + >>> Client.objects.update( + ... account_type=Case( + ... When(registered_on__lte=a_year_ago, + ... then=Value(Client.PLATINUM)), + ... When(registered_on__lte=a_month_ago, + ... then=Value(Client.GOLD)), + ... default=Value(Client.REGULAR) + ... ), + ... ) + >>> Client.objects.values_list('name', 'account_type') + [('Jane Doe', 'G'), ('James Smith', 'R'), ('Jack Black', 'P')] + +Conditional aggregation +----------------------- + +What if we want to find out how many clients there are for each +``account_type``? We can nest conditional expression within +:ref:`aggregate functions ` to achieve this:: + + >>> # Create some more Clients first so we can have something to count + >>> Client.objects.create( + ... name='Jean Grey', + ... account_type=Client.REGULAR, + ... registered_on=date.today()) + >>> Client.objects.create( + ... name='James Bond', + ... account_type=Client.PLATINUM, + ... registered_on=date.today()) + >>> Client.objects.create( + ... name='Jane Porter', + ... account_type=Client.PLATINUM, + ... registered_on=date.today()) + >>> # Get counts for each value of account_type + >>> from django.db.models import IntegerField, Sum + >>> Client.objects.aggregate( + ... regular=Sum( + ... Case(When(account_type=Client.REGULAR, then=Value(1)), + ... output_field=IntegerField()) + ... ), + ... gold=Sum( + ... Case(When(account_type=Client.GOLD, then=Value(1)), + ... output_field=IntegerField()) + ... ), + ... platinum=Sum( + ... Case(When(account_type=Client.PLATINUM, then=Value(1)), + ... output_field=IntegerField()) + ... ) + ... ) + {'regular': 2, 'gold': 1, 'platinum': 3} diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index b36afe5633..b6b4000278 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -332,6 +332,15 @@ instantiating the model field as any arguments relating to data validation (``max_length``, ``max_digits``, etc.) will not be enforced on the expression's output value. +Conditional expressions +----------------------- + +.. versionadded:: 1.8 + +Conditional expressions allow you to use :keyword:`if` ... :keyword:`elif` ... +:keyword:`else` logic in queries. Django natively supports SQL ``CASE`` +expressions. For more details see :doc:`conditional-expressions`. + Technical Information ===================== diff --git a/docs/ref/models/index.txt b/docs/ref/models/index.txt index c860ee2e04..103775e269 100644 --- a/docs/ref/models/index.txt +++ b/docs/ref/models/index.txt @@ -16,4 +16,5 @@ Model API reference. For introductory material, see :doc:`/topics/db/models`. querysets lookups expressions + conditional-expressions database-functions diff --git a/docs/releases/1.8.txt b/docs/releases/1.8.txt index 3e22ab06de..95ce56f241 100644 --- a/docs/releases/1.8.txt +++ b/docs/releases/1.8.txt @@ -93,16 +93,20 @@ New data types backends. There is a corresponding :class:`form field `. -Query Expressions and Database Functions -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Query Expressions, Conditional Expressions, and Database Functions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -:doc:`Query Expressions ` allow users to create, +:doc:`Query Expressions ` allow you to create, customize, and compose complex SQL expressions. This has enabled annotate to accept expressions other than aggregates. Aggregates are now able to reference multiple fields, as well as perform arithmetic, similar to ``F()`` objects. :meth:`~django.db.models.query.QuerySet.order_by` has also gained the ability to accept expressions. +:doc:`Conditional Expressions ` allow +you to use :keyword:`if` ... :keyword:`elif` ... :keyword:`else` logic within +queries. + A collection of :doc:`database functions ` is also included with functionality such as :class:`~django.db.models.functions.Coalesce`, diff --git a/tests/expressions/models.py b/tests/expressions/models.py index 53eb54ec48..69de52c308 100644 --- a/tests/expressions/models.py +++ b/tests/expressions/models.py @@ -56,3 +56,19 @@ class Experiment(models.Model): def duration(self): return self.end - self.start + + +@python_2_unicode_compatible +class Time(models.Model): + time = models.TimeField(null=True) + + def __str__(self): + return "%s" % self.time + + +@python_2_unicode_compatible +class UUID(models.Model): + uuid = models.UUIDField(null=True) + + def __str__(self): + return "%s" % self.uuid diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 8165a496dc..3c508b8520 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -2,15 +2,16 @@ from __future__ import unicode_literals from copy import deepcopy import datetime +import uuid from django.core.exceptions import FieldError from django.db import connection, transaction, DatabaseError -from django.db.models import F, Value +from django.db.models import F, Value, TimeField, UUIDField from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.utils import Approximate from django.utils import six -from .models import Company, Employee, Number, Experiment +from .models import Company, Employee, Number, Experiment, Time, UUID class BasicExpressionsTests(TestCase): @@ -799,3 +800,15 @@ class FTimeDeltaTests(TestCase): over_estimate = [e.name for e in Experiment.objects.filter(estimated_time__lt=F('end') - F('start'))] self.assertEqual(over_estimate, ['e4']) + + +class ValueTests(TestCase): + def test_update_TimeField_using_Value(self): + Time.objects.create() + Time.objects.update(time=Value(datetime.time(1), output_field=TimeField())) + self.assertEqual(Time.objects.get().time, datetime.time(1)) + + def test_update_UUIDField_using_Value(self): + UUID.objects.create() + UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField())) + self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012')) diff --git a/tests/expressions_case/__init__.py b/tests/expressions_case/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/expressions_case/models.py b/tests/expressions_case/models.py new file mode 100644 index 0000000000..fcee0cdbd3 --- /dev/null +++ b/tests/expressions_case/models.py @@ -0,0 +1,80 @@ +from __future__ import unicode_literals + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class CaseTestModel(models.Model): + integer = models.IntegerField() + integer2 = models.IntegerField(null=True) + string = models.CharField(max_length=100, default='') + + big_integer = models.BigIntegerField(null=True) + binary = models.BinaryField(default=b'') + boolean = models.BooleanField(default=False) + comma_separated_integer = models.CommaSeparatedIntegerField(max_length=100, default='') + date = models.DateField(null=True, db_column='date_field') + date_time = models.DateTimeField(null=True) + decimal = models.DecimalField(max_digits=2, decimal_places=1, null=True, db_column='decimal_field') + duration = models.DurationField(null=True) + email = models.EmailField(default='') + file = models.FileField(null=True, db_column='file_field') + file_path = models.FilePathField(null=True) + float = models.FloatField(null=True, db_column='float_field') + image = models.ImageField(null=True) + ip_address = models.IPAddressField(null=True) + generic_ip_address = models.GenericIPAddressField(null=True) + null_boolean = models.NullBooleanField() + positive_integer = models.PositiveIntegerField(null=True) + positive_small_integer = models.PositiveSmallIntegerField(null=True) + slug = models.SlugField(default='') + small_integer = models.SmallIntegerField(null=True) + text = models.TextField(default='') + time = models.TimeField(null=True, db_column='time_field') + url = models.URLField(default='') + uuid = models.UUIDField(null=True) + fk = models.ForeignKey('self', null=True) + + def __str__(self): + return "%i, %s" % (self.integer, self.string) + + +@python_2_unicode_compatible +class O2OCaseTestModel(models.Model): + o2o = models.OneToOneField(CaseTestModel, related_name='o2o_rel') + integer = models.IntegerField() + + def __str__(self): + return "%i, %s" % (self.id, self.o2o) + + +@python_2_unicode_compatible +class FKCaseTestModel(models.Model): + fk = models.ForeignKey(CaseTestModel, related_name='fk_rel') + integer = models.IntegerField() + + def __str__(self): + return "%i, %s" % (self.id, self.fk) + + +@python_2_unicode_compatible +class Client(models.Model): + REGULAR = 'R' + GOLD = 'G' + PLATINUM = 'P' + ACCOUNT_TYPE_CHOICES = ( + (REGULAR, 'Regular'), + (GOLD, 'Gold'), + (PLATINUM, 'Platinum'), + ) + name = models.CharField(max_length=50) + registered_on = models.DateField() + account_type = models.CharField( + max_length=1, + choices=ACCOUNT_TYPE_CHOICES, + default=REGULAR, + ) + + def __str__(self): + return self.name diff --git a/tests/expressions_case/tests.py b/tests/expressions_case/tests.py new file mode 100644 index 0000000000..20bdb840fa --- /dev/null +++ b/tests/expressions_case/tests.py @@ -0,0 +1,1083 @@ +from __future__ import unicode_literals + +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from operator import attrgetter, itemgetter +from uuid import UUID + +from django.core.exceptions import FieldError +from django.db import models +from django.db.models import F, Q, Value, Min, Max +from django.db.models.expressions import Case, When +from django.test import TestCase +from django.utils import six + +from .models import CaseTestModel, O2OCaseTestModel, FKCaseTestModel, Client + + +class CaseExpressionTests(TestCase): + @classmethod + def setUpTestData(cls): + o = CaseTestModel.objects.create(integer=1, integer2=1, string='1') + O2OCaseTestModel.objects.create(o2o=o, integer=1) + FKCaseTestModel.objects.create(fk=o, integer=1) + + o = CaseTestModel.objects.create(integer=2, integer2=3, string='2') + O2OCaseTestModel.objects.create(o2o=o, integer=2) + FKCaseTestModel.objects.create(fk=o, integer=2) + FKCaseTestModel.objects.create(fk=o, integer=3) + + o = CaseTestModel.objects.create(integer=3, integer2=4, string='3') + O2OCaseTestModel.objects.create(o2o=o, integer=3) + FKCaseTestModel.objects.create(fk=o, integer=3) + FKCaseTestModel.objects.create(fk=o, integer=4) + + o = CaseTestModel.objects.create(integer=2, integer2=2, string='2') + O2OCaseTestModel.objects.create(o2o=o, integer=2) + FKCaseTestModel.objects.create(fk=o, integer=2) + FKCaseTestModel.objects.create(fk=o, integer=3) + + o = CaseTestModel.objects.create(integer=3, integer2=4, string='3') + O2OCaseTestModel.objects.create(o2o=o, integer=3) + FKCaseTestModel.objects.create(fk=o, integer=3) + FKCaseTestModel.objects.create(fk=o, integer=4) + + o = CaseTestModel.objects.create(integer=3, integer2=3, string='3') + O2OCaseTestModel.objects.create(o2o=o, integer=3) + FKCaseTestModel.objects.create(fk=o, integer=3) + FKCaseTestModel.objects.create(fk=o, integer=4) + + o = CaseTestModel.objects.create(integer=4, integer2=5, string='4') + O2OCaseTestModel.objects.create(o2o=o, integer=1) + FKCaseTestModel.objects.create(fk=o, integer=5) + + # GROUP BY on Oracle fails with TextField/BinaryField; see #24096. + cls.non_lob_fields = [ + f.name for f in CaseTestModel._meta.get_fields() + if not (f.is_relation and f.auto_created) and not isinstance(f, (models.BinaryField, models.TextField)) + ] + + def test_annotate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(test=Case( + When(integer=1, then=Value('one')), + When(integer=2, then=Value('two')), + default=Value('other'), + output_field=models.CharField(), + )).order_by('pk'), + [(1, 'one'), (2, 'two'), (3, 'other'), (2, 'two'), (3, 'other'), (3, 'other'), (4, 'other')], + transform=attrgetter('integer', 'test') + ) + + def test_annotate_without_default(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(test=Case( + When(integer=1, then=Value(1)), + When(integer=2, then=Value(2)), + output_field=models.IntegerField(), + )).order_by('pk'), + [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'test') + ) + + def test_annotate_with_expression_as_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(f_test=Case( + When(integer=1, then=F('integer') + 1), + When(integer=2, then=F('integer') + 3), + default='integer', + )).order_by('pk'), + [(1, 2), (2, 5), (3, 3), (2, 5), (3, 3), (3, 3), (4, 4)], + transform=attrgetter('integer', 'f_test') + ) + + def test_annotate_with_expression_as_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(f_test=Case( + When(integer2=F('integer'), then=Value('equal')), + When(integer2=F('integer') + 1, then=Value('+1')), + output_field=models.CharField(), + )).order_by('pk'), + [(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, '+1')], + transform=attrgetter('integer', 'f_test') + ) + + def test_annotate_with_join_in_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(join_test=Case( + When(integer=1, then=F('o2o_rel__integer') + 1), + When(integer=2, then=F('o2o_rel__integer') + 3), + default='o2o_rel__integer', + )).order_by('pk'), + [(1, 2), (2, 5), (3, 3), (2, 5), (3, 3), (3, 3), (4, 1)], + transform=attrgetter('integer', 'join_test') + ) + + def test_annotate_with_join_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(join_test=Case( + When(integer2=F('o2o_rel__integer'), then=Value('equal')), + When(integer2=F('o2o_rel__integer') + 1, then=Value('+1')), + default=Value('other'), + output_field=models.CharField(), + )).order_by('pk'), + [(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, 'other')], + transform=attrgetter('integer', 'join_test') + ) + + def test_annotate_with_join_in_predicate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(join_test=Case( + When(o2o_rel__integer=1, then=Value('one')), + When(o2o_rel__integer=2, then=Value('two')), + When(o2o_rel__integer=3, then=Value('three')), + default=Value('other'), + output_field=models.CharField(), + )).order_by('pk'), + [(1, 'one'), (2, 'two'), (3, 'three'), (2, 'two'), (3, 'three'), (3, 'three'), (4, 'one')], + transform=attrgetter('integer', 'join_test') + ) + + def test_annotate_with_annotation_in_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + f_plus_1=F('integer') + 1, + f_plus_3=F('integer') + 3, + ).annotate( + f_test=Case( + When(integer=1, then='f_plus_1'), + When(integer=2, then='f_plus_3'), + default='integer', + ), + ).order_by('pk'), + [(1, 2), (2, 5), (3, 3), (2, 5), (3, 3), (3, 3), (4, 4)], + transform=attrgetter('integer', 'f_test') + ) + + def test_annotate_with_annotation_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + f_plus_1=F('integer') + 1, + ).annotate( + f_test=Case( + When(integer2=F('integer'), then=Value('equal')), + When(integer2=F('f_plus_1'), then=Value('+1')), + output_field=models.CharField(), + ), + ).order_by('pk'), + [(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, '+1')], + transform=attrgetter('integer', 'f_test') + ) + + def test_annotate_with_annotation_in_predicate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + f_minus_2=F('integer') - 2, + ).annotate( + test=Case( + When(f_minus_2=-1, then=Value('negative one')), + When(f_minus_2=0, then=Value('zero')), + When(f_minus_2=1, then=Value('one')), + default=Value('other'), + output_field=models.CharField(), + ), + ).order_by('pk'), + [(1, 'negative one'), (2, 'zero'), (3, 'one'), (2, 'zero'), (3, 'one'), (3, 'one'), (4, 'other')], + transform=attrgetter('integer', 'test') + ) + + def test_annotate_with_aggregation_in_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.values(*self.non_lob_fields).annotate( + min=Min('fk_rel__integer'), + max=Max('fk_rel__integer'), + ).annotate( + test=Case( + When(integer=2, then='min'), + When(integer=3, then='max'), + ), + ).order_by('pk'), + [(1, None, 1, 1), (2, 2, 2, 3), (3, 4, 3, 4), (2, 2, 2, 3), (3, 4, 3, 4), (3, 4, 3, 4), (4, None, 5, 5)], + transform=itemgetter('integer', 'test', 'min', 'max') + ) + + def test_annotate_with_aggregation_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.values(*self.non_lob_fields).annotate( + min=Min('fk_rel__integer'), + max=Max('fk_rel__integer'), + ).annotate( + test=Case( + When(integer2=F('min'), then=Value('min')), + When(integer2=F('max'), then=Value('max')), + output_field=models.CharField(), + ), + ).order_by('pk'), + [(1, 1, 'min'), (2, 3, 'max'), (3, 4, 'max'), (2, 2, 'min'), (3, 4, 'max'), (3, 3, 'min'), (4, 5, 'min')], + transform=itemgetter('integer', 'integer2', 'test') + ) + + def test_annotate_with_aggregation_in_predicate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.values(*self.non_lob_fields).annotate( + max=Max('fk_rel__integer'), + ).annotate( + test=Case( + When(max=3, then=Value('max = 3')), + When(max=4, then=Value('max = 4')), + default=Value(''), + output_field=models.CharField(), + ), + ).order_by('pk'), + [(1, 1, ''), (2, 3, 'max = 3'), (3, 4, 'max = 4'), (2, 3, 'max = 3'), + (3, 4, 'max = 4'), (3, 4, 'max = 4'), (4, 5, '')], + transform=itemgetter('integer', 'max', 'test') + ) + + def test_combined_expression(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + test=Case( + When(integer=1, then=Value(2)), + When(integer=2, then=Value(1)), + default=Value(3), + output_field=models.IntegerField(), + ) + 1, + ).order_by('pk'), + [(1, 3), (2, 2), (3, 4), (2, 2), (3, 4), (3, 4), (4, 4)], + transform=attrgetter('integer', 'test') + ) + + def test_in_subquery(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter( + pk__in=CaseTestModel.objects.annotate( + test=Case( + When(integer=F('integer2'), then='pk'), + When(integer=4, then='pk'), + output_field=models.IntegerField(), + ), + ).values('test')).order_by('pk'), + [(1, 1), (2, 2), (3, 3), (4, 5)], + transform=attrgetter('integer', 'integer2') + ) + + def test_aggregate(self): + self.assertEqual( + CaseTestModel.objects.aggregate( + one=models.Sum(Case( + When(integer=1, then=Value(1)), + output_field=models.IntegerField(), + )), + two=models.Sum(Case( + When(integer=2, then=Value(1)), + output_field=models.IntegerField(), + )), + three=models.Sum(Case( + When(integer=3, then=Value(1)), + output_field=models.IntegerField(), + )), + four=models.Sum(Case( + When(integer=4, then=Value(1)), + output_field=models.IntegerField(), + )), + ), + {'one': 1, 'two': 2, 'three': 3, 'four': 1} + ) + + def test_aggregate_with_expression_as_value(self): + self.assertEqual( + CaseTestModel.objects.aggregate( + one=models.Sum(Case(When(integer=1, then='integer'))), + two=models.Sum(Case(When(integer=2, then=F('integer') - 1))), + three=models.Sum(Case(When(integer=3, then=F('integer') + 1))), + ), + {'one': 1, 'two': 2, 'three': 12} + ) + + def test_aggregate_with_expression_as_condition(self): + self.assertEqual( + CaseTestModel.objects.aggregate( + equal=models.Sum(Case( + When(integer2=F('integer'), then=Value(1)), + output_field=models.IntegerField(), + )), + plus_one=models.Sum(Case( + When(integer2=F('integer') + 1, then=Value(1)), + output_field=models.IntegerField(), + )), + ), + {'equal': 3, 'plus_one': 4} + ) + + def test_filter(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(integer2=Case( + When(integer=2, then=Value(3)), + When(integer=3, then=Value(4)), + default=Value(1), + output_field=models.IntegerField(), + )).order_by('pk'), + [(1, 1), (2, 3), (3, 4), (3, 4)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_without_default(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(integer2=Case( + When(integer=2, then=Value(3)), + When(integer=3, then=Value(4)), + output_field=models.IntegerField(), + )).order_by('pk'), + [(2, 3), (3, 4), (3, 4)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_expression_as_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(integer2=Case( + When(integer=2, then=F('integer') + 1), + When(integer=3, then=F('integer')), + default='integer', + )).order_by('pk'), + [(1, 1), (2, 3), (3, 3)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_expression_as_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(string=Case( + When(integer2=F('integer'), then=Value('2')), + When(integer2=F('integer') + 1, then=Value('3')), + output_field=models.CharField(), + )).order_by('pk'), + [(3, 4, '3'), (2, 2, '2'), (3, 4, '3')], + transform=attrgetter('integer', 'integer2', 'string') + ) + + def test_filter_with_join_in_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(integer2=Case( + When(integer=2, then=F('o2o_rel__integer') + 1), + When(integer=3, then=F('o2o_rel__integer')), + default='o2o_rel__integer', + )).order_by('pk'), + [(1, 1), (2, 3), (3, 3)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_join_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(integer=Case( + When(integer2=F('o2o_rel__integer') + 1, then=Value(2)), + When(integer2=F('o2o_rel__integer'), then=Value(3)), + output_field=models.IntegerField(), + )).order_by('pk'), + [(2, 3), (3, 3)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_join_in_predicate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(integer2=Case( + When(o2o_rel__integer=1, then=Value(1)), + When(o2o_rel__integer=2, then=Value(3)), + When(o2o_rel__integer=3, then=Value(4)), + output_field=models.IntegerField(), + )).order_by('pk'), + [(1, 1), (2, 3), (3, 4), (3, 4)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_annotation_in_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + f=F('integer'), + f_plus_1=F('integer') + 1, + ).filter( + integer2=Case( + When(integer=2, then='f_plus_1'), + When(integer=3, then='f'), + ), + ).order_by('pk'), + [(2, 3), (3, 3)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_annotation_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + f_plus_1=F('integer') + 1, + ).filter( + integer=Case( + When(integer2=F('integer'), then=Value(2)), + When(integer2=F('f_plus_1'), then=Value(3)), + output_field=models.IntegerField(), + ), + ).order_by('pk'), + [(3, 4), (2, 2), (3, 4)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_annotation_in_predicate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + f_plus_1=F('integer') + 1, + ).filter( + integer2=Case( + When(f_plus_1=3, then=Value(3)), + When(f_plus_1=4, then=Value(4)), + default=Value(1), + output_field=models.IntegerField(), + ), + ).order_by('pk'), + [(1, 1), (2, 3), (3, 4), (3, 4)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_aggregation_in_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.values(*self.non_lob_fields).annotate( + min=Min('fk_rel__integer'), + max=Max('fk_rel__integer'), + ).filter( + integer2=Case( + When(integer=2, then='min'), + When(integer=3, then='max'), + ), + ).order_by('pk'), + [(3, 4, 3, 4), (2, 2, 2, 3), (3, 4, 3, 4)], + transform=itemgetter('integer', 'integer2', 'min', 'max') + ) + + def test_filter_with_aggregation_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.values(*self.non_lob_fields).annotate( + min=Min('fk_rel__integer'), + max=Max('fk_rel__integer'), + ).filter( + integer=Case( + When(integer2=F('min'), then=Value(2)), + When(integer2=F('max'), then=Value(3)), + ), + ).order_by('pk'), + [(3, 4, 3, 4), (2, 2, 2, 3), (3, 4, 3, 4)], + transform=itemgetter('integer', 'integer2', 'min', 'max') + ) + + def test_filter_with_aggregation_in_predicate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.values(*self.non_lob_fields).annotate( + max=Max('fk_rel__integer'), + ).filter( + integer=Case( + When(max=3, then=Value(2)), + When(max=4, then=Value(3)), + ), + ).order_by('pk'), + [(2, 3, 3), (3, 4, 4), (2, 2, 3), (3, 4, 4), (3, 3, 4)], + transform=itemgetter('integer', 'integer2', 'max') + ) + + def test_update(self): + CaseTestModel.objects.update( + string=Case( + When(integer=1, then=Value('one')), + When(integer=2, then=Value('two')), + default=Value('other'), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 'one'), (2, 'two'), (3, 'other'), (2, 'two'), (3, 'other'), (3, 'other'), (4, 'other')], + transform=attrgetter('integer', 'string') + ) + + def test_update_without_default(self): + CaseTestModel.objects.update( + integer2=Case( + When(integer=1, then=Value(1)), + When(integer=2, then=Value(2)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'integer2') + ) + + def test_update_with_expression_as_value(self): + CaseTestModel.objects.update( + integer=Case( + When(integer=1, then=F('integer') + 1), + When(integer=2, then=F('integer') + 3), + default='integer', + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [('1', 2), ('2', 5), ('3', 3), ('2', 5), ('3', 3), ('3', 3), ('4', 4)], + transform=attrgetter('string', 'integer') + ) + + def test_update_with_expression_as_condition(self): + CaseTestModel.objects.update( + string=Case( + When(integer2=F('integer'), then=Value('equal')), + When(integer2=F('integer') + 1, then=Value('+1')), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, '+1')], + transform=attrgetter('integer', 'string') + ) + + def test_update_with_join_in_condition_raise_field_error(self): + with self.assertRaisesMessage(FieldError, 'Joined field references are not permitted in this query'): + CaseTestModel.objects.update( + integer=Case( + When(integer2=F('o2o_rel__integer') + 1, then=Value(2)), + When(integer2=F('o2o_rel__integer'), then=Value(3)), + output_field=models.IntegerField(), + ), + ) + + def test_update_with_join_in_predicate_raise_field_error(self): + with self.assertRaisesMessage(FieldError, 'Joined field references are not permitted in this query'): + CaseTestModel.objects.update( + string=Case( + When(o2o_rel__integer=1, then=Value('one')), + When(o2o_rel__integer=2, then=Value('two')), + When(o2o_rel__integer=3, then=Value('three')), + default=Value('other'), + output_field=models.CharField(), + ), + ) + + def test_update_big_integer(self): + CaseTestModel.objects.update( + big_integer=Case( + When(integer=1, then=Value(1)), + When(integer=2, then=Value(2)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'big_integer') + ) + + def test_update_binary(self): + CaseTestModel.objects.update( + binary=Case( + # fails on postgresql on Python 2.7 if output_field is not + # set explicitly + When(integer=1, then=Value(b'one', output_field=models.BinaryField())), + When(integer=2, then=Value(b'two', output_field=models.BinaryField())), + default=Value(b'', output_field=models.BinaryField()), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, b'one'), (2, b'two'), (3, b''), (2, b'two'), (3, b''), (3, b''), (4, b'')], + transform=lambda o: (o.integer, six.binary_type(o.binary)) + ) + + def test_update_boolean(self): + CaseTestModel.objects.update( + boolean=Case( + When(integer=1, then=Value(True)), + When(integer=2, then=Value(True)), + default=Value(False), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, True), (2, True), (3, False), (2, True), (3, False), (3, False), (4, False)], + transform=attrgetter('integer', 'boolean') + ) + + def test_update_comma_separated_integer(self): + CaseTestModel.objects.update( + comma_separated_integer=Case( + When(integer=1, then=Value('1')), + When(integer=2, then=Value('2,2')), + default=Value(''), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '1'), (2, '2,2'), (3, ''), (2, '2,2'), (3, ''), (3, ''), (4, '')], + transform=attrgetter('integer', 'comma_separated_integer') + ) + + def test_update_date(self): + CaseTestModel.objects.update( + date=Case( + When(integer=1, then=Value(date(2015, 1, 1))), + When(integer=2, then=Value(date(2015, 1, 2))), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [ + (1, date(2015, 1, 1)), (2, date(2015, 1, 2)), (3, None), (2, date(2015, 1, 2)), + (3, None), (3, None), (4, None) + ], + transform=attrgetter('integer', 'date') + ) + + def test_update_date_time(self): + CaseTestModel.objects.update( + date_time=Case( + When(integer=1, then=Value(datetime(2015, 1, 1))), + When(integer=2, then=Value(datetime(2015, 1, 2))), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [ + (1, datetime(2015, 1, 1)), (2, datetime(2015, 1, 2)), (3, None), (2, datetime(2015, 1, 2)), + (3, None), (3, None), (4, None) + ], + transform=attrgetter('integer', 'date_time') + ) + + def test_update_decimal(self): + CaseTestModel.objects.update( + decimal=Case( + When(integer=1, then=Value(Decimal('1.1'))), + When(integer=2, then=Value(Decimal('2.2'))), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, Decimal('1.1')), (2, Decimal('2.2')), (3, None), (2, Decimal('2.2')), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'decimal') + ) + + def test_update_duration(self): + CaseTestModel.objects.update( + duration=Case( + # fails on sqlite if output_field is not set explicitly on all + # Values containing timedeltas + When(integer=1, then=Value(timedelta(1), output_field=models.DurationField())), + When(integer=2, then=Value(timedelta(2), output_field=models.DurationField())), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, timedelta(1)), (2, timedelta(2)), (3, None), (2, timedelta(2)), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'duration') + ) + + def test_update_email(self): + CaseTestModel.objects.update( + email=Case( + When(integer=1, then=Value('1@example.com')), + When(integer=2, then=Value('2@example.com')), + default=Value(''), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '1@example.com'), (2, '2@example.com'), (3, ''), (2, '2@example.com'), (3, ''), (3, ''), (4, '')], + transform=attrgetter('integer', 'email') + ) + + def test_update_file(self): + CaseTestModel.objects.update( + file=Case( + When(integer=1, then=Value('~/1')), + When(integer=2, then=Value('~/2')), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '~/1'), (2, '~/2'), (3, ''), (2, '~/2'), (3, ''), (3, ''), (4, '')], + transform=lambda o: (o.integer, six.text_type(o.file)) + ) + + def test_update_file_path(self): + CaseTestModel.objects.update( + file_path=Case( + When(integer=1, then=Value('~/1')), + When(integer=2, then=Value('~/2')), + default=Value(''), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '~/1'), (2, '~/2'), (3, ''), (2, '~/2'), (3, ''), (3, ''), (4, '')], + transform=attrgetter('integer', 'file_path') + ) + + def test_update_float(self): + CaseTestModel.objects.update( + float=Case( + When(integer=1, then=Value(1.1)), + When(integer=2, then=Value(2.2)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 1.1), (2, 2.2), (3, None), (2, 2.2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'float') + ) + + def test_update_image(self): + CaseTestModel.objects.update( + image=Case( + When(integer=1, then=Value('~/1')), + When(integer=2, then=Value('~/2')), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '~/1'), (2, '~/2'), (3, ''), (2, '~/2'), (3, ''), (3, ''), (4, '')], + transform=lambda o: (o.integer, six.text_type(o.image)) + ) + + def test_update_ip_address(self): + CaseTestModel.objects.update( + ip_address=Case( + # fails on postgresql if output_field is not set explicitly + When(integer=1, then=Value('1.1.1.1')), + When(integer=2, then=Value('2.2.2.2')), + output_field=models.IPAddressField(), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '1.1.1.1'), (2, '2.2.2.2'), (3, None), (2, '2.2.2.2'), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'ip_address') + ) + + def test_update_generic_ip_address(self): + CaseTestModel.objects.update( + generic_ip_address=Case( + # fails on postgresql if output_field is not set explicitly + When(integer=1, then=Value('1.1.1.1')), + When(integer=2, then=Value('2.2.2.2')), + output_field=models.GenericIPAddressField(), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '1.1.1.1'), (2, '2.2.2.2'), (3, None), (2, '2.2.2.2'), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'generic_ip_address') + ) + + def test_update_null_boolean(self): + CaseTestModel.objects.update( + null_boolean=Case( + When(integer=1, then=Value(True)), + When(integer=2, then=Value(False)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, True), (2, False), (3, None), (2, False), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'null_boolean') + ) + + def test_update_positive_integer(self): + CaseTestModel.objects.update( + positive_integer=Case( + When(integer=1, then=Value(1)), + When(integer=2, then=Value(2)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'positive_integer') + ) + + def test_update_positive_small_integer(self): + CaseTestModel.objects.update( + positive_small_integer=Case( + When(integer=1, then=Value(1)), + When(integer=2, then=Value(2)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'positive_small_integer') + ) + + def test_update_slug(self): + CaseTestModel.objects.update( + slug=Case( + When(integer=1, then=Value('1')), + When(integer=2, then=Value('2')), + default=Value(''), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '1'), (2, '2'), (3, ''), (2, '2'), (3, ''), (3, ''), (4, '')], + transform=attrgetter('integer', 'slug') + ) + + def test_update_small_integer(self): + CaseTestModel.objects.update( + small_integer=Case( + When(integer=1, then=Value(1)), + When(integer=2, then=Value(2)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'small_integer') + ) + + def test_update_string(self): + CaseTestModel.objects.filter(string__in=['1', '2']).update( + string=Case( + When(integer=1, then=Value('1', output_field=models.CharField())), + When(integer=2, then=Value('2', output_field=models.CharField())), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.filter(string__in=['1', '2']).order_by('pk'), + [(1, '1'), (2, '2'), (2, '2')], + transform=attrgetter('integer', 'string') + ) + + def test_update_text(self): + CaseTestModel.objects.update( + text=Case( + When(integer=1, then=Value('1')), + When(integer=2, then=Value('2')), + default=Value(''), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '1'), (2, '2'), (3, ''), (2, '2'), (3, ''), (3, ''), (4, '')], + transform=attrgetter('integer', 'text') + ) + + def test_update_time(self): + CaseTestModel.objects.update( + time=Case( + # fails on sqlite if output_field is not set explicitly on all + # Values containing times + When(integer=1, then=Value(time(1), output_field=models.TimeField())), + When(integer=2, then=Value(time(2), output_field=models.TimeField())), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, time(1)), (2, time(2)), (3, None), (2, time(2)), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'time') + ) + + def test_update_url(self): + CaseTestModel.objects.update( + url=Case( + When(integer=1, then=Value('http://1.example.com/')), + When(integer=2, then=Value('http://2.example.com/')), + default=Value(''), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [ + (1, 'http://1.example.com/'), (2, 'http://2.example.com/'), (3, ''), (2, 'http://2.example.com/'), + (3, ''), (3, ''), (4, '') + ], + transform=attrgetter('integer', 'url') + ) + + def test_update_uuid(self): + CaseTestModel.objects.update( + uuid=Case( + # fails on sqlite if output_field is not set explicitly on all + # Values containing UUIDs + When(integer=1, then=Value( + UUID('11111111111111111111111111111111'), + output_field=models.UUIDField(), + )), + When(integer=2, then=Value( + UUID('22222222222222222222222222222222'), + output_field=models.UUIDField(), + )), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [ + (1, UUID('11111111111111111111111111111111')), (2, UUID('22222222222222222222222222222222')), (3, None), + (2, UUID('22222222222222222222222222222222')), (3, None), (3, None), (4, None) + ], + transform=attrgetter('integer', 'uuid') + ) + + def test_update_fk(self): + obj1, obj2 = CaseTestModel.objects.all()[:2] + + CaseTestModel.objects.update( + fk=Case( + When(integer=1, then=Value(obj1.pk)), + When(integer=2, then=Value(obj2.pk)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, obj1.pk), (2, obj2.pk), (3, None), (2, obj2.pk), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'fk_id') + ) + + def test_lookup_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + test=Case( + When(integer__lt=2, then=Value('less than 2')), + When(integer__gt=2, then=Value('greater than 2')), + default=Value('equal to 2'), + output_field=models.CharField(), + ), + ).order_by('pk'), + [ + (1, 'less than 2'), (2, 'equal to 2'), (3, 'greater than 2'), (2, 'equal to 2'), (3, 'greater than 2'), + (3, 'greater than 2'), (4, 'greater than 2') + ], + transform=attrgetter('integer', 'test') + ) + + def test_lookup_different_fields(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + test=Case( + When(integer=2, integer2=3, then=Value('when')), + default=Value('default'), + output_field=models.CharField(), + ), + ).order_by('pk'), + [ + (1, 1, 'default'), (2, 3, 'when'), (3, 4, 'default'), (2, 2, 'default'), (3, 4, 'default'), + (3, 3, 'default'), (4, 5, 'default') + ], + transform=attrgetter('integer', 'integer2', 'test') + ) + + def test_combined_q_object(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + test=Case( + When(Q(integer=2) | Q(integer2=3), then=Value('when')), + default=Value('default'), + output_field=models.CharField(), + ), + ).order_by('pk'), + [ + (1, 1, 'default'), (2, 3, 'when'), (3, 4, 'default'), (2, 2, 'when'), (3, 4, 'default'), + (3, 3, 'when'), (4, 5, 'default') + ], + transform=attrgetter('integer', 'integer2', 'test') + ) + + +class CaseDocumentationExamples(TestCase): + @classmethod + def setUpTestData(cls): + Client.objects.create( + name='Jane Doe', + account_type=Client.REGULAR, + registered_on=date.today() - timedelta(days=36), + ) + Client.objects.create( + name='James Smith', + account_type=Client.GOLD, + registered_on=date.today() - timedelta(days=5), + ) + Client.objects.create( + name='Jack Black', + account_type=Client.PLATINUM, + registered_on=date.today() - timedelta(days=10 * 365), + ) + + def test_simple_example(self): + self.assertQuerysetEqual( + Client.objects.annotate( + discount=Case( + When(account_type=Client.GOLD, then=Value('5%')), + When(account_type=Client.PLATINUM, then=Value('10%')), + default=Value('0%'), + output_field=models.CharField(), + ), + ).order_by('pk'), + [('Jane Doe', '0%'), ('James Smith', '5%'), ('Jack Black', '10%')], + transform=attrgetter('name', 'discount') + ) + + def test_lookup_example(self): + a_month_ago = date.today() - timedelta(days=30) + a_year_ago = date.today() - timedelta(days=365) + self.assertQuerysetEqual( + Client.objects.annotate( + discount=Case( + When(registered_on__lte=a_year_ago, then=Value('10%')), + When(registered_on__lte=a_month_ago, then=Value('5%')), + default=Value('0%'), + output_field=models.CharField(), + ), + ).order_by('pk'), + [('Jane Doe', '5%'), ('James Smith', '0%'), ('Jack Black', '10%')], + transform=attrgetter('name', 'discount') + ) + + def test_conditional_update_example(self): + a_month_ago = date.today() - timedelta(days=30) + a_year_ago = date.today() - timedelta(days=365) + Client.objects.update( + account_type=Case( + When(registered_on__lte=a_year_ago, then=Value(Client.PLATINUM)), + When(registered_on__lte=a_month_ago, then=Value(Client.GOLD)), + default=Value(Client.REGULAR), + ), + ) + self.assertQuerysetEqual( + Client.objects.all().order_by('pk'), + [('Jane Doe', 'G'), ('James Smith', 'R'), ('Jack Black', 'P')], + transform=attrgetter('name', 'account_type') + ) + + def test_conditional_aggregation_example(self): + Client.objects.create( + name='Jean Grey', + account_type=Client.REGULAR, + registered_on=date.today(), + ) + Client.objects.create( + name='James Bond', + account_type=Client.PLATINUM, + registered_on=date.today(), + ) + Client.objects.create( + name='Jane Porter', + account_type=Client.PLATINUM, + registered_on=date.today(), + ) + self.assertEqual( + Client.objects.aggregate( + regular=models.Sum(Case( + When(account_type=Client.REGULAR, then=Value(1)), + output_field=models.IntegerField(), + )), + gold=models.Sum(Case( + When(account_type=Client.GOLD, then=Value(1)), + output_field=models.IntegerField(), + )), + platinum=models.Sum(Case( + When(account_type=Client.PLATINUM, then=Value(1)), + output_field=models.IntegerField(), + )), + ), + {'regular': 2, 'gold': 1, 'platinum': 3} + )