From c7805ee214802ff1c0de53660bd2594bc1abfebb Mon Sep 17 00:00:00 2001 From: Josh Smeaton Date: Sat, 23 May 2015 18:12:09 +1000 Subject: [PATCH] Fixed #24699 -- Added aggregate support for DurationField on Oracle --- django/db/backends/base/features.py | 3 - django/db/backends/oracle/features.py | 1 - django/db/backends/oracle/functions.py | 24 +++++++ django/db/models/aggregates.py | 18 ++++++ tests/aggregation/tests.py | 86 ++++++++++++++------------ 5 files changed, 89 insertions(+), 43 deletions(-) create mode 100644 django/db/backends/oracle/functions.py diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 8c7b31e471..e0a7e49512 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -157,9 +157,6 @@ class BaseDatabaseFeatures(object): # Support for the DISTINCT ON clause can_distinct_on_fields = False - # Can the backend use an Avg aggregate on DurationField? - can_avg_on_durationfield = True - # Does the backend decide to commit before SAVEPOINT statements # when autocommit is disabled? http://bugs.python.org/issue8145#msg109965 autocommits_when_autocommit_is_off = False diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 2be00acd99..1ef0f232e5 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -39,7 +39,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): uppercases_column_names = True # select for update with limit can be achieved on Oracle, but not with the current backend. supports_select_for_update_with_limit = False - can_avg_on_durationfield = False # Pending implementation (#24699). def introspected_boolean_field_type(self, field=None, created_separately=False): """ diff --git a/django/db/backends/oracle/functions.py b/django/db/backends/oracle/functions.py new file mode 100644 index 0000000000..384f092fd4 --- /dev/null +++ b/django/db/backends/oracle/functions.py @@ -0,0 +1,24 @@ +from django.db.models import DecimalField, DurationField, Func + + +class IntervalToSeconds(Func): + function = '' + template = """ + EXTRACT(day from %(expressions)s) * 86400 + + EXTRACT(hour from %(expressions)s) * 3600 + + EXTRACT(minute from %(expressions)s) * 60 + + EXTRACT(second from %(expressions)s) + """ + + def __init__(self, expression, **extra): + output_field = extra.pop('output_field', DecimalField()) + super(IntervalToSeconds, self).__init__(expression, output_field=output_field, **extra) + + +class SecondsToInterval(Func): + function = 'NUMTODSINTERVAL' + template = "%(function)s(%(expressions)s, 'SECOND')" + + def __init__(self, expression, **extra): + output_field = extra.pop('output_field', DurationField()) + super(SecondsToInterval, self).__init__(expression, output_field=output_field, **extra) diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 2d7c43c90e..08c9169c6c 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -78,6 +78,15 @@ class Avg(Aggregate): output_field = extra.pop('output_field', FloatField()) super(Avg, self).__init__(expression, output_field=output_field, **extra) + def as_oracle(self, compiler, connection): + if self.output_field.get_internal_type() == 'DurationField': + expression = self.get_source_expressions()[0] + from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval + return compiler.compile( + SecondsToInterval(Avg(IntervalToSeconds(expression))) + ) + return super(Avg, self).as_sql(compiler, connection) + class Count(Aggregate): function = 'COUNT' @@ -137,6 +146,15 @@ class Sum(Aggregate): function = 'SUM' name = 'Sum' + def as_oracle(self, compiler, connection): + if self.output_field.get_internal_type() == 'DurationField': + expression = self.get_source_expressions()[0] + from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval + return compiler.compile( + SecondsToInterval(Sum(IntervalToSeconds(expression))) + ) + return super(Sum, self).as_sql(compiler, connection) + class Variance(Aggregate): name = 'Variance' diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index 81152fd2ad..8f3c9fb410 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -10,7 +10,7 @@ from django.db.models import ( F, Aggregate, Avg, Count, DecimalField, DurationField, FloatField, Func, IntegerField, Max, Min, Sum, Value, ) -from django.test import TestCase, ignore_warnings, skipUnlessDBFeature +from django.test import TestCase, ignore_warnings from django.test.utils import Approximate, CaptureQueriesContext from django.utils import six, timezone from django.utils.deprecation import RemovedInDjango20Warning @@ -441,11 +441,16 @@ class AggregateTestCase(TestCase): vals = Book.objects.annotate(num_authors=Count("authors__id")).aggregate(Avg("num_authors")) self.assertEqual(vals, {"num_authors__avg": Approximate(1.66, places=1)}) - @skipUnlessDBFeature('can_avg_on_durationfield') def test_avg_duration_field(self): self.assertEqual( Publisher.objects.aggregate(Avg('duration', output_field=DurationField())), - {'duration__avg': datetime.timedelta(1, 43200)} # 1.5 days + {'duration__avg': datetime.timedelta(days=1, hours=12)} + ) + + def test_sum_duration_field(self): + self.assertEqual( + Publisher.objects.aggregate(Sum('duration', output_field=DurationField())), + {'duration__sum': datetime.timedelta(days=3)} ) def test_sum_distinct_aggregate(self): @@ -984,47 +989,50 @@ class AggregateTestCase(TestCase): Book.objects.annotate(Max('id')).annotate(Sum('id__max')) def test_add_implementation(self): - try: - # test completely changing how the output is rendered - def lower_case_function_override(self, compiler, connection): - sql, params = compiler.compile(self.source_expressions[0]) - substitutions = dict(function=self.function.lower(), expressions=sql) - substitutions.update(self.extra) - return self.template % substitutions, params - setattr(Sum, 'as_' + connection.vendor, lower_case_function_override) + class MySum(Sum): + pass - qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), - output_field=IntegerField())) - self.assertEqual(str(qs.query).count('sum('), 1) - b1 = qs.get(pk=self.b4.pk) - self.assertEqual(b1.sums, 383) + # test completely changing how the output is rendered + def lower_case_function_override(self, compiler, connection): + sql, params = compiler.compile(self.source_expressions[0]) + substitutions = dict(function=self.function.lower(), expressions=sql) + substitutions.update(self.extra) + return self.template % substitutions, params + setattr(MySum, 'as_' + connection.vendor, lower_case_function_override) - # test changing the dict and delegating - def lower_case_function_super(self, compiler, connection): - self.extra['function'] = self.function.lower() - return super(Sum, self).as_sql(compiler, connection) - setattr(Sum, 'as_' + connection.vendor, lower_case_function_super) + qs = Book.objects.annotate( + sums=MySum(F('rating') + F('pages') + F('price'), output_field=IntegerField()) + ) + self.assertEqual(str(qs.query).count('sum('), 1) + b1 = qs.get(pk=self.b4.pk) + self.assertEqual(b1.sums, 383) - qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), - output_field=IntegerField())) - self.assertEqual(str(qs.query).count('sum('), 1) - b1 = qs.get(pk=self.b4.pk) - self.assertEqual(b1.sums, 383) + # test changing the dict and delegating + def lower_case_function_super(self, compiler, connection): + self.extra['function'] = self.function.lower() + return super(MySum, self).as_sql(compiler, connection) + setattr(MySum, 'as_' + connection.vendor, lower_case_function_super) - # test overriding all parts of the template - def be_evil(self, compiler, connection): - substitutions = dict(function='MAX', expressions='2') - substitutions.update(self.extra) - return self.template % substitutions, () - setattr(Sum, 'as_' + connection.vendor, be_evil) + qs = Book.objects.annotate( + sums=MySum(F('rating') + F('pages') + F('price'), output_field=IntegerField()) + ) + self.assertEqual(str(qs.query).count('sum('), 1) + b1 = qs.get(pk=self.b4.pk) + self.assertEqual(b1.sums, 383) - qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), - output_field=IntegerField())) - self.assertEqual(str(qs.query).count('MAX('), 1) - b1 = qs.get(pk=self.b4.pk) - self.assertEqual(b1.sums, 2) - finally: - delattr(Sum, 'as_' + connection.vendor) + # test overriding all parts of the template + def be_evil(self, compiler, connection): + substitutions = dict(function='MAX', expressions='2') + substitutions.update(self.extra) + return self.template % substitutions, () + setattr(MySum, 'as_' + connection.vendor, be_evil) + + qs = Book.objects.annotate( + sums=MySum(F('rating') + F('pages') + F('price'), output_field=IntegerField()) + ) + self.assertEqual(str(qs.query).count('MAX('), 1) + b1 = qs.get(pk=self.b4.pk) + self.assertEqual(b1.sums, 2) def test_complex_values_aggregation(self): max_rating = Book.objects.values('rating').aggregate(