diff --git a/AUTHORS b/AUTHORS
index 4a35b84176..f39c8f4366 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -498,6 +498,7 @@ answer newbie questions, and generally made Django that much better:
     Matthew Schinckel
     Matthew Somerville
     Matthew Tretter
+    Matthew Wilkes
     Matthias Kestenholz
     Matthias Pronk
     Matt Hoskins
diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py
index e550c4b260..d123ce769d 100644
--- a/django/contrib/postgres/fields/array.py
+++ b/django/contrib/postgres/fields/array.py
@@ -239,7 +239,13 @@ class ArrayInLookup(In):
         values = super(ArrayInLookup, self).get_prep_lookup()
         # In.process_rhs() expects values to be hashable, so convert lists
         # to tuples.
-        return [tuple(value) for value in values]
+        prepared_values = []
+        for value in values:
+            if hasattr(value, 'resolve_expression'):
+                prepared_values.append(value)
+            else:
+                prepared_values.append(tuple(value))
+        return prepared_values
 
 
 class IndexTransform(Transform):
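Throughout this patch the same duck-typing test is used: anything exposing resolve_expression() is treated as a query expression and left for the SQL compiler, while plain values keep their previous handling (here, coercion to a hashable tuple for the IN clause). A minimal standalone sketch of that dispatch, where FakeF and prepare_in_values are illustrative stand-ins rather than Django API:

    class FakeF(object):
        """Stand-in for django.db.models.F: only the duck-typed hook matters."""
        def resolve_expression(self, query=None):
            return self

    def prepare_in_values(values):
        prepared = []
        for value in values:
            if hasattr(value, 'resolve_expression'):
                prepared.append(value)         # compiled into SQL later
            else:
                prepared.append(tuple(value))  # hashable literal for IN (...)
        return prepared

    print(prepare_in_values([[1, 2], FakeF()]))  # [(1, 2), <FakeF object ...>]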
diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py
index 659ee2e2ce..306a0378b2 100644
--- a/django/db/backends/mysql/operations.py
+++ b/django/db/backends/mysql/operations.py
@@ -155,6 +155,10 @@ class DatabaseOperations(BaseDatabaseOperations):
         if value is None:
             return None
 
+        # Expression values are adapted by the database.
+        if hasattr(value, 'resolve_expression'):
+            return value
+
         # MySQL doesn't support tz-aware datetimes
         if timezone.is_aware(value):
             if settings.USE_TZ:
@@ -171,6 +175,10 @@ class DatabaseOperations(BaseDatabaseOperations):
         if value is None:
             return None
 
+        # Expression values are adapted by the database.
+        if hasattr(value, 'resolve_expression'):
+            return value
+
         # MySQL doesn't support tz-aware times
         if timezone.is_aware(value):
             raise ValueError("MySQL backend does not support timezone-aware times.")
diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py
index 3ea45cb3bf..d30b093119 100644
--- a/django/db/backends/oracle/operations.py
+++ b/django/db/backends/oracle/operations.py
@@ -408,6 +408,10 @@ WHEN (new.%(col_name)s IS NULL)
         if value is None:
             return None
 
+        # Expression values are adapted by the database.
+        if hasattr(value, 'resolve_expression'):
+            return value
+
         # cx_Oracle doesn't support tz-aware datetimes
         if timezone.is_aware(value):
             if settings.USE_TZ:
@@ -421,6 +425,10 @@ WHEN (new.%(col_name)s IS NULL)
         if value is None:
             return None
 
+        # Expression values are adapted by the database.
+        if hasattr(value, 'resolve_expression'):
+            return value
+
         if isinstance(value, six.string_types):
             return datetime.datetime.strptime(value, '%H:%M:%S')
 
diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py
index 65847bd5ef..7f6de932ad 100644
--- a/django/db/backends/sqlite3/operations.py
+++ b/django/db/backends/sqlite3/operations.py
@@ -182,6 +182,10 @@ class DatabaseOperations(BaseDatabaseOperations):
         if value is None:
             return None
 
+        # Expression values are adapted by the database.
+        if hasattr(value, 'resolve_expression'):
+            return value
+
         # SQLite doesn't support tz-aware datetimes
         if timezone.is_aware(value):
             if settings.USE_TZ:
@@ -195,6 +199,10 @@ class DatabaseOperations(BaseDatabaseOperations):
         if value is None:
             return None
 
+        # Expression values are adapted by the database.
+        if hasattr(value, 'resolve_expression'):
+            return value
+
         # SQLite doesn't support tz-aware datetimes
         if timezone.is_aware(value):
             raise ValueError("SQLite backend does not support timezone-aware times.")
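The backend adapters above normally turn Python datetime/time objects into driver-friendly values; an expression has no Python value to convert yet, so it is returned untouched and rendered by the database instead. A rough standalone illustration of that early-return guard, where adapt_timefield_value and FakeExpression are simplified stand-ins rather than the real backend method:

    import datetime

    class FakeExpression(object):
        """Stand-in for a query expression: only the duck-typed hook matters."""
        def resolve_expression(self, query=None):
            return self

    def adapt_timefield_value(value):
        if value is None:
            return None
        # Expression values are adapted by the database, not by Python.
        if hasattr(value, 'resolve_expression'):
            return value
        return value.isoformat()

    print(adapt_timefield_value(datetime.time(13, 0)))  # '13:00:00'
    print(adapt_timefield_value(FakeExpression()))      # the expression, unchanged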
diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py
index 4f10458c4c..bd8beea53f 100644
--- a/django/db/models/lookups.py
+++ b/django/db/models/lookups.py
@@ -1,3 +1,4 @@
+import itertools
 import math
 import warnings
 from copy import copy
@@ -170,6 +171,12 @@ class FieldGetDbPrepValueMixin(object):
     """
     get_db_prep_lookup_value_is_iterable = False
 
+    @classmethod
+    def get_prep_lookup_value(cls, value, output_field):
+        if hasattr(value, '_prepare'):
+            return value._prepare(output_field)
+        return output_field.get_prep_value(value)
+
     def get_db_prep_lookup(self, value, connection):
         # For relational fields, use the output_field of the 'field' attribute.
         field = getattr(self.lhs.output_field, 'field', None)
@@ -191,6 +198,51 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
     """
     get_db_prep_lookup_value_is_iterable = True
 
+    def get_prep_lookup(self):
+        prepared_values = []
+        if hasattr(self.rhs, '_prepare'):
+            # A subquery is like an iterable but its items shouldn't be
+            # prepared independently.
+            return self.rhs._prepare(self.lhs.output_field)
+        for rhs_value in self.rhs:
+            if hasattr(rhs_value, 'resolve_expression'):
+                # An expression will be handled by the database but can coexist
+                # alongside real values.
+                pass
+            elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
+                rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
+            prepared_values.append(rhs_value)
+        return prepared_values
+
+    def process_rhs(self, compiler, connection):
+        if self.rhs_is_direct_value():
+            # rhs should be an iterable of values. Use batch_process_rhs()
+            # to prepare/transform those values.
+            return self.batch_process_rhs(compiler, connection)
+        else:
+            return super(FieldGetDbPrepValueIterableMixin, self).process_rhs(compiler, connection)
+
+    def resolve_expression_parameter(self, compiler, connection, sql, param):
+        params = [param]
+        if hasattr(param, 'resolve_expression'):
+            param = param.resolve_expression(compiler.query)
+        if hasattr(param, 'as_sql'):
+            sql, params = param.as_sql(compiler, connection)
+        return sql, params
+
+    def batch_process_rhs(self, compiler, connection, rhs=None):
+        pre_processed = super(FieldGetDbPrepValueIterableMixin, self).batch_process_rhs(compiler, connection, rhs)
+        # The params list may contain expressions which compile to a
+        # sql/param pair. Zip them to get sql and param pairs that refer to the
+        # same argument and attempt to replace them with the result of
+        # compiling the param step.
+        sql, params = zip(*(
+            self.resolve_expression_parameter(compiler, connection, sql, param)
+            for sql, param in zip(*pre_processed)
+        ))
+        params = itertools.chain.from_iterable(params)
+        return sql, tuple(params)
+
 
 class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
     lookup_name = 'exact'
@@ -255,13 +307,6 @@ IntegerField.register_lookup(IntegerLessThan)
 class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
     lookup_name = 'in'
 
-    def get_prep_lookup(self):
-        if hasattr(self.rhs, '_prepare'):
-            return self.rhs._prepare(self.lhs.output_field)
-        if hasattr(self.lhs.output_field, 'get_prep_value'):
-            return [self.lhs.output_field.get_prep_value(v) for v in self.rhs]
-        return self.rhs
-
     def process_rhs(self, compiler, connection):
         db_rhs = getattr(self.rhs, '_db', None)
         if db_rhs is not None and db_rhs != connection.alias:
@@ -409,21 +454,9 @@ Field.register_lookup(IEndsWith)
 class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
     lookup_name = 'range'
 
-    def get_prep_lookup(self):
-        if hasattr(self.rhs, '_prepare'):
-            return self.rhs._prepare(self.lhs.output_field)
-        return [self.lhs.output_field.get_prep_value(v) for v in self.rhs]
-
     def get_rhs_op(self, connection, rhs):
         return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
 
-    def process_rhs(self, compiler, connection):
-        if self.rhs_is_direct_value():
-            # rhs should be an iterable of 2 values, we use batch_process_rhs
-            # to prepare/transform those values
-            return self.batch_process_rhs(compiler, connection)
-        else:
-            return super(Range, self).process_rhs(compiler, connection)
 
 
 Field.register_lookup(Range)
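batch_process_rhs() initially produces one '%s' placeholder and one parameter per right-hand-side item; the override above then replaces each parameter that compiled to its own SQL fragment, splicing that fragment into the placeholder position and flattening the parameters back into a single sequence. A self-contained sketch of that zip-and-flatten step, where CompiledExpression and resolve_parameter are illustrative stand-ins for a compiled expression and resolve_expression_parameter():

    import itertools

    class CompiledExpression(object):
        """Stand-in for an F()-style expression that already knows its SQL."""
        def as_sql(self):
            return '"num_chairs" + %s', [10]

    def resolve_parameter(sql, param):
        # Literal params keep their %s placeholder; expression params replace
        # it with their own SQL and contribute their own parameters.
        if hasattr(param, 'as_sql'):
            return param.as_sql()
        return sql, [param]

    pre_processed = (('%s', '%s'), [100, CompiledExpression()])  # (sqls, params)
    sqls, params = zip(*(resolve_parameter(s, p) for s, p in zip(*pre_processed)))
    params = tuple(itertools.chain.from_iterable(params))
    print(sqls)    # ('%s', '"num_chairs" + %s')
    print(params)  # (100, 10)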
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index d94a3abda5..9c46ee4ca1 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -990,6 +990,20 @@ class Query(object):
             pre_joins = self.alias_refcount.copy()
             value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
             used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)]
+        elif isinstance(value, (list, tuple)):
+            # The items of the iterable may be expressions and therefore need
+            # to be resolved independently.
+            processed_values = []
+            used_joins = set()
+            for sub_value in value:
+                if hasattr(sub_value, 'resolve_expression'):
+                    pre_joins = self.alias_refcount.copy()
+                    processed_values.append(
+                        sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
+                    )
+                    # The used_joins for a tuple of expressions is the union of
+                    # the used_joins for the individual expressions.
+                    used_joins |= set(k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0))
         # Subqueries need to use a different set of aliases than the
         # outer query. Call bump_prefix to change aliases of the inner
         # query (the value).
diff --git a/docs/releases/1.11.txt b/docs/releases/1.11.txt
index 41baeb40dd..706ee1dbd4 100644
--- a/docs/releases/1.11.txt
+++ b/docs/releases/1.11.txt
@@ -234,6 +234,9 @@ Models
 * Added support for expressions in :meth:`.QuerySet.values` and
   :meth:`~.QuerySet.values_list`.
 
+* Added support for query expressions on lookups that take multiple arguments,
+  such as ``range``.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 
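In user-facing terms, the release note above means filters such as the following are now accepted. The snippet assumes a model along the lines of the Company model used in the tests below (integer fields num_employees and num_chairs) and is an illustration of the API rather than a runnable script:

    from django.db.models import F

    # Both bounds of __range, and individual __in members, may be expressions
    # resolved against the same row being filtered.
    Company.objects.filter(num_employees__range=(F('num_chairs') - 10, F('num_chairs') + 10))
    Company.objects.filter(num_employees__in=[F('num_chairs') - 10, F('num_chairs') + 10])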
diff --git a/tests/expressions/models.py b/tests/expressions/models.py
index 4633008868..264120d251 100644
--- a/tests/expressions/models.py
+++ b/tests/expressions/models.py
@@ -61,6 +61,15 @@ class Experiment(models.Model):
         return self.end - self.start
 
 
+@python_2_unicode_compatible
+class Result(models.Model):
+    experiment = models.ForeignKey(Experiment, models.CASCADE)
+    result_time = models.DateTimeField()
+
+    def __str__(self):
+        return "Result at %s" % self.result_time
+
+
 @python_2_unicode_compatible
 class Time(models.Model):
     time = models.TimeField(null=True)
@@ -69,6 +78,16 @@ class Time(models.Model):
         return "%s" % self.time
 
 
+@python_2_unicode_compatible
+class SimulationRun(models.Model):
+    start = models.ForeignKey(Time, models.CASCADE, null=True)
+    end = models.ForeignKey(Time, models.CASCADE, null=True)
+    midpoint = models.TimeField()
+
+    def __str__(self):
+        return "%s (%s to %s)" % (self.midpoint, self.start, self.end)
+
+
 @python_2_unicode_compatible
 class UUID(models.Model):
     uuid = models.UUIDField(null=True)
diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py
index 27929c9146..6d9ee8dc75 100644
--- a/tests/expressions/tests.py
+++ b/tests/expressions/tests.py
@@ -1,6 +1,7 @@
 from __future__ import unicode_literals
 
 import datetime
+import unittest
 import uuid
 from copy import deepcopy
 
@@ -17,11 +18,15 @@ from django.db.models.expressions import (
 from django.db.models.functions import (
     Coalesce, Concat, Length, Lower, Substr, Upper,
 )
+from django.db.models.sql import constants
+from django.db.models.sql.datastructures import Join
 from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
 from django.test.utils import Approximate
 from django.utils import six
 
-from .models import UUID, Company, Employee, Experiment, Number, Time
+from .models import (
+    UUID, Company, Employee, Experiment, Number, Result, SimulationRun, Time,
+)
 
 
 class BasicExpressionsTests(TestCase):
@@ -391,6 +396,144 @@ class BasicExpressionsTests(TestCase):
         self.assertEqual(str(qs.query).count('JOIN'), 2)
 
 
+class IterableLookupInnerExpressionsTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        ceo = Employee.objects.create(firstname='Just', lastname='Doit', salary=30)
+        # MySQL requires that the values calculated for expressions don't pass
+        # outside of the field's range, so it's inconvenient to use the values
+        # in the more general tests.
+        Company.objects.create(name='5020 Ltd', num_employees=50, num_chairs=20, ceo=ceo)
+        Company.objects.create(name='5040 Ltd', num_employees=50, num_chairs=40, ceo=ceo)
+        Company.objects.create(name='5050 Ltd', num_employees=50, num_chairs=50, ceo=ceo)
+        Company.objects.create(name='5060 Ltd', num_employees=50, num_chairs=60, ceo=ceo)
+        Company.objects.create(name='99300 Ltd', num_employees=99, num_chairs=300, ceo=ceo)
+
+    def test_in_lookup_allows_F_expressions_and_expressions_for_integers(self):
+        # __in lookups can use F() expressions for integers.
+        queryset = Company.objects.filter(num_employees__in=([F('num_chairs') - 10]))
+        self.assertQuerysetEqual(queryset, ['<Company: 5060 Ltd>'], ordered=False)
+        self.assertQuerysetEqual(
+            Company.objects.filter(num_employees__in=([F('num_chairs') - 10, F('num_chairs') + 10])),
+            ['<Company: 5040 Ltd>', '<Company: 5060 Ltd>'],
+            ordered=False
+        )
+        self.assertQuerysetEqual(
+            Company.objects.filter(
+                num_employees__in=([F('num_chairs') - 10, F('num_chairs'), F('num_chairs') + 10])
+            ),
+            ['<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'],
+            ordered=False
+        )
+
+    def test_expressions_in_lookups_join_choice(self):
+        midpoint = datetime.time(13, 0)
+        t1 = Time.objects.create(time=datetime.time(12, 0))
+        t2 = Time.objects.create(time=datetime.time(14, 0))
+        SimulationRun.objects.create(start=t1, end=t2, midpoint=midpoint)
+        SimulationRun.objects.create(start=t1, end=None, midpoint=midpoint)
+        SimulationRun.objects.create(start=None, end=t2, midpoint=midpoint)
+        SimulationRun.objects.create(start=None, end=None, midpoint=midpoint)
+
+        queryset = SimulationRun.objects.filter(midpoint__range=[F('start__time'), F('end__time')])
+        self.assertQuerysetEqual(
+            queryset,
+            ['<SimulationRun: 13:00:00 (12:00:00 to 14:00:00)>'],
+            ordered=False
+        )
+        for alias in queryset.query.alias_map.values():
+            if isinstance(alias, Join):
+                self.assertEqual(alias.join_type, constants.INNER)
+
+        queryset = SimulationRun.objects.exclude(midpoint__range=[F('start__time'), F('end__time')])
+        self.assertQuerysetEqual(queryset, [], ordered=False)
+        for alias in queryset.query.alias_map.values():
+            if isinstance(alias, Join):
+                self.assertEqual(alias.join_type, constants.LOUTER)
+
+    def test_range_lookup_allows_F_expressions_and_expressions_for_integers(self):
+        # Range lookups can use F() expressions for integers.
+        Company.objects.filter(num_employees__exact=F("num_chairs"))
+        self.assertQuerysetEqual(
+            Company.objects.filter(num_employees__range=(F('num_chairs'), 100)),
+            ['<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>'],
+            ordered=False
+        )
+        self.assertQuerysetEqual(
+            Company.objects.filter(num_employees__range=(F('num_chairs') - 10, F('num_chairs') + 10)),
+            ['<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'],
+            ordered=False
+        )
+        self.assertQuerysetEqual(
+            Company.objects.filter(num_employees__range=(F('num_chairs') - 10, 100)),
+            ['<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'],
+            ordered=False
+        )
+        self.assertQuerysetEqual(
+            Company.objects.filter(num_employees__range=(1, 100)),
+            [
+                '<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>',
+                '<Company: 5060 Ltd>', '<Company: 99300 Ltd>',
+            ],
+            ordered=False
+        )
+
+    @unittest.skipUnless(connection.vendor == 'sqlite',
+                         "This defensive test only works on databases that don't validate parameter types")
+    def test_complex_expressions_do_not_introduce_sql_injection_via_untrusted_string_inclusion(self):
+        """
+        This tests that SQL injection isn't possible using compilation of
+        expressions in iterable filters, as their compilation happens before
+        the main query compilation. It's limited to SQLite, as PostgreSQL,
+        Oracle and other vendors have defense in depth against this by type
+        checking. Testing against SQLite (the most permissive of the built-in
+        databases) demonstrates that the problem doesn't exist while keeping
+        the test simple.
+        """
+        queryset = Company.objects.filter(name__in=[F('num_chairs') + '1)) OR ((1==1'])
+        self.assertQuerysetEqual(queryset, [], ordered=False)
+
+    def test_in_lookup_allows_F_expressions_and_expressions_for_datetimes(self):
+        start = datetime.datetime(2016, 2, 3, 15, 0, 0)
+        end = datetime.datetime(2016, 2, 5, 15, 0, 0)
+        experiment_1 = Experiment.objects.create(
+            name='Integrity testing',
+            assigned=start.date(),
+            start=start,
+            end=end,
+            completed=end.date(),
+            estimated_time=end - start,
+        )
+        experiment_2 = Experiment.objects.create(
+            name='Taste testing',
+            assigned=start.date(),
+            start=start,
+            end=end,
+            completed=end.date(),
+            estimated_time=end - start,
+        )
+        Result.objects.create(
+            experiment=experiment_1,
+            result_time=datetime.datetime(2016, 2, 4, 15, 0, 0),
+        )
+        Result.objects.create(
+            experiment=experiment_1,
+            result_time=datetime.datetime(2016, 3, 10, 2, 0, 0),
+        )
+        Result.objects.create(
+            experiment=experiment_2,
+            result_time=datetime.datetime(2016, 1, 8, 5, 0, 0),
+        )
+
+        within_experiment_time = [F('experiment__start'), F('experiment__end')]
+        queryset = Result.objects.filter(result_time__range=within_experiment_time)
+        self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"])
+
+        within_experiment_time = [F('experiment__start'), F('experiment__end')]
+        queryset = Result.objects.filter(result_time__range=within_experiment_time)
+        self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"])
+
+
 class ExpressionsTests(TestCase):
     def test_F_object_deepcopy(self):
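The join-choice test above pins down a side effect of resolving expressions inside lookup values: filter() across the nullable start/end relations can use INNER JOINs, because a row with a NULL bound can never match, while exclude() keeps LEFT OUTER JOINs. Assuming the SimulationRun and Time models from tests/expressions/models.py, the chosen join types can also be inspected directly; the expectations mirror the test's assertions:

    from django.db.models import F

    qs = SimulationRun.objects.filter(midpoint__range=[F('start__time'), F('end__time')])
    print(str(qs.query))   # expect INNER JOINs onto the Time table
    qs = SimulationRun.objects.exclude(midpoint__range=[F('start__time'), F('end__time')])
    print(str(qs.query))   # expect LEFT OUTER JOINs instead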
+ """ + queryset = Company.objects.filter(name__in=[F('num_chairs') + '1)) OR ((1==1']) + self.assertQuerysetEqual(queryset, [], ordered=False) + + def test_in_lookup_allows_F_expressions_and_expressions_for_datetimes(self): + start = datetime.datetime(2016, 2, 3, 15, 0, 0) + end = datetime.datetime(2016, 2, 5, 15, 0, 0) + experiment_1 = Experiment.objects.create( + name='Integrity testing', + assigned=start.date(), + start=start, + end=end, + completed=end.date(), + estimated_time=end - start, + ) + experiment_2 = Experiment.objects.create( + name='Taste testing', + assigned=start.date(), + start=start, + end=end, + completed=end.date(), + estimated_time=end - start, + ) + Result.objects.create( + experiment=experiment_1, + result_time=datetime.datetime(2016, 2, 4, 15, 0, 0), + ) + Result.objects.create( + experiment=experiment_1, + result_time=datetime.datetime(2016, 3, 10, 2, 0, 0), + ) + Result.objects.create( + experiment=experiment_2, + result_time=datetime.datetime(2016, 1, 8, 5, 0, 0), + ) + + within_experiment_time = [F('experiment__start'), F('experiment__end')] + queryset = Result.objects.filter(result_time__range=within_experiment_time) + self.assertQuerysetEqual(queryset, [""]) + + within_experiment_time = [F('experiment__start'), F('experiment__end')] + queryset = Result.objects.filter(result_time__range=within_experiment_time) + self.assertQuerysetEqual(queryset, [""]) + + class ExpressionsTests(TestCase): def test_F_object_deepcopy(self): diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 6dd1fb85e0..2b5796dc6f 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -173,12 +173,40 @@ class TestQuerying(PostgreSQLTestCase): self.objs[:2] ) + @unittest.expectedFailure + def test_in_including_F_object(self): + # This test asserts that Array objects passed to filters can be + # constructed to contain F objects. This currently doesn't work as the + # psycopg2 mogrify method that generates the ARRAY() syntax is + # expecting literals, not column references (#27095). + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__in=[[models.F('id')]]), + self.objs[:2] + ) + + def test_in_as_F_object(self): + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__in=[models.F('field')]), + self.objs[:4] + ) + def test_contained_by(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]), self.objs[:2] ) + @unittest.expectedFailure + def test_contained_by_including_F_object(self): + # This test asserts that Array objects passed to filters can be + # constructed to contain F objects. This currently doesn't work as the + # psycopg2 mogrify method that generates the ARRAY() syntax is + # expecting literals, not column references (#27095). + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('id'), 2]), + self.objs[:2] + ) + def test_contains(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__contains=[2]),