diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index 30e87e63a1..6e901129d9 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -263,11 +263,10 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): from django.contrib.gis.db.backends.oracle.models import OracleSpatialRefSys return OracleSpatialRefSys - def modify_insert_params(self, placeholders, params): + def modify_insert_params(self, placeholder, params): """Drop out insert parameters for NULL placeholder. Needed for Oracle Spatial - backend due to #10888 + backend due to #10888. """ - # This code doesn't work for bulk insert cases. - assert len(placeholders) == 1 - return [[param for pholder, param - in six.moves.zip(placeholders[0], params[0]) if pholder != 'NULL'], ] + if placeholder == 'NULL': + return [] + return super(OracleOperations, self).modify_insert_params(placeholder, params) diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index fe7c27a7e7..79572651c4 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -576,7 +576,7 @@ class BaseDatabaseOperations(object): def combine_duration_expression(self, connector, sub_expressions): return self.combine_expression(connector, sub_expressions) - def modify_insert_params(self, placeholders, params): + def modify_insert_params(self, placeholder, params): """Allow modification of insert parameters. Needed for Oracle Spatial backend due to #10888. """ diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index a496079932..85854dc1b6 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -166,9 +166,10 @@ class DatabaseOperations(BaseDatabaseOperations): def max_name_length(self): return 64 - def bulk_insert_sql(self, fields, num_values): - items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) - return "VALUES " + ", ".join([items_sql] * num_values) + def bulk_insert_sql(self, fields, placeholder_rows): + placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) + values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql) + return "VALUES " + values_sql def combine_expression(self, connector, sub_expressions): """ diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index 1cd17a2af1..ce6946011b 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -439,6 +439,8 @@ WHEN (new.%(col_name)s IS NULL) name_length = self.max_name_length() - 3 return '%s_TR' % truncate_name(table, name_length).upper() - def bulk_insert_sql(self, fields, num_values): - items_sql = "SELECT %s FROM DUAL" % ", ".join(["%s"] * len(fields)) - return " UNION ALL ".join([items_sql] * num_values) + def bulk_insert_sql(self, fields, placeholder_rows): + return " UNION ALL ".join( + "SELECT %s FROM DUAL" % ", ".join(row) + for row in placeholder_rows + ) diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 866e2ca38b..3624c9cf56 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -221,9 +221,10 @@ class DatabaseOperations(BaseDatabaseOperations): def return_insert_id(self): return "RETURNING %s", () - def bulk_insert_sql(self, fields, num_values): - items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) - return "VALUES " + ", ".join([items_sql] * num_values) + def bulk_insert_sql(self, fields, placeholder_rows): + placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) + values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql) + return "VALUES " + values_sql def adapt_datefield_value(self, value): return value diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 7f99eaa271..91d1a27f8a 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -226,13 +226,11 @@ class DatabaseOperations(BaseDatabaseOperations): value = uuid.UUID(value) return value - def bulk_insert_sql(self, fields, num_values): - res = [] - res.append("SELECT %s" % ", ".join( - "%%s AS %s" % self.quote_name(f.column) for f in fields - )) - res.extend(["UNION ALL SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1)) - return " ".join(res) + def bulk_insert_sql(self, fields, placeholder_rows): + return " UNION ALL ".join( + "SELECT %s" % ", ".join(row) + for row in placeholder_rows + ) def combine_expression(self, connector, sub_expressions): # SQLite doesn't have a power function, so we fake it with a diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 0271c7d3d3..d9540055b9 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -180,6 +180,13 @@ class BaseExpression(object): return True return False + @cached_property + def contains_column_references(self): + for expr in self.get_source_expressions(): + if expr and expr.contains_column_references: + return True + return False + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): """ Provides the chance to do any preprocessing or validation before being @@ -339,6 +346,17 @@ class BaseExpression(object): def reverse_ordering(self): return self + def flatten(self): + """ + Recursively yield this expression and all subexpressions, in + depth-first order. + """ + yield self + for expr in self.get_source_expressions(): + if expr: + for inner_expr in expr.flatten(): + yield inner_expr + class Expression(BaseExpression, Combinable): """ @@ -613,6 +631,9 @@ class Random(Expression): class Col(Expression): + + contains_column_references = True + def __init__(self, alias, target, output_field=None): if output_field is None: output_field = target diff --git a/django/db/models/query.py b/django/db/models/query.py index 9361a8f597..b7d55b6a4b 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -458,6 +458,8 @@ class QuerySet(object): specifying whether an object was created. """ lookup, params = self._extract_model_params(defaults, **kwargs) + # The get() needs to be targeted at the write database in order + # to avoid potential transaction consistency problems. self._for_write = True try: return self.get(**lookup), False diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 3f50e951a3..90121bdb37 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -909,17 +909,102 @@ class SQLInsertCompiler(SQLCompiler): self.return_id = False super(SQLInsertCompiler, self).__init__(*args, **kwargs) - def placeholder(self, field, val): + def field_as_sql(self, field, val): + """ + Take a field and a value intended to be saved on that field, and + return placeholder SQL and accompanying params. Checks for raw values, + expressions and fields with get_placeholder() defined in that order. + + When field is None, the value is considered raw and is used as the + placeholder, with no corresponding parameters returned. + """ if field is None: # A field value of None means the value is raw. - return val + sql, params = val, [] + elif hasattr(val, 'as_sql'): + # This is an expression, let's compile it. + sql, params = self.compile(val) elif hasattr(field, 'get_placeholder'): # Some fields (e.g. geo fields) need special munging before # they can be inserted. - return field.get_placeholder(val, self, self.connection) + sql, params = field.get_placeholder(val, self, self.connection), [val] else: # Return the common case for the placeholder - return '%s' + sql, params = '%s', [val] + + # The following hook is only used by Oracle Spatial, which sometimes + # needs to yield 'NULL' and [] as its placeholder and params instead + # of '%s' and [None]. The 'NULL' placeholder is produced earlier by + # OracleOperations.get_geom_placeholder(). The following line removes + # the corresponding None parameter. See ticket #10888. + params = self.connection.ops.modify_insert_params(sql, params) + + return sql, params + + def prepare_value(self, field, value): + """ + Prepare a value to be used in a query by resolving it if it is an + expression and otherwise calling the field's get_db_prep_save(). + """ + if hasattr(value, 'resolve_expression'): + value = value.resolve_expression(self.query, allow_joins=False, for_save=True) + # Don't allow values containing Col expressions. They refer to + # existing columns on a row, but in the case of insert the row + # doesn't exist yet. + if value.contains_column_references: + raise ValueError( + 'Failed to insert expression "%s" on %s. F() expressions ' + 'can only be used to update, not to insert.' % (value, field) + ) + if value.contains_aggregate: + raise FieldError("Aggregate functions are not allowed in this query") + else: + value = field.get_db_prep_save(value, connection=self.connection) + return value + + def pre_save_val(self, field, obj): + """ + Get the given field's value off the given obj. pre_save() is used for + things like auto_now on DateTimeField. Skip it if this is a raw query. + """ + if self.query.raw: + return getattr(obj, field.attname) + return field.pre_save(obj, add=True) + + def assemble_as_sql(self, fields, value_rows): + """ + Take a sequence of N fields and a sequence of M rows of values, + generate placeholder SQL and parameters for each field and value, and + return a pair containing: + * a sequence of M rows of N SQL placeholder strings, and + * a sequence of M rows of corresponding parameter values. + + Each placeholder string may contain any number of '%s' interpolation + strings, and each parameter row will contain exactly as many params + as the total number of '%s's in the corresponding placeholder row. + """ + if not value_rows: + return [], [] + + # list of (sql, [params]) tuples for each object to be saved + # Shape: [n_objs][n_fields][2] + rows_of_fields_as_sql = ( + (self.field_as_sql(field, v) for field, v in zip(fields, row)) + for row in value_rows + ) + + # tuple like ([sqls], [[params]s]) for each object to be saved + # Shape: [n_objs][2][n_fields] + sql_and_param_pair_rows = (zip(*row) for row in rows_of_fields_as_sql) + + # Extract separate lists for placeholders and params. + # Each of these has shape [n_objs][n_fields] + placeholder_rows, param_rows = zip(*sql_and_param_pair_rows) + + # Params for each field are still lists, and need to be flattened. + param_rows = [[p for ps in row for p in ps] for row in param_rows] + + return placeholder_rows, param_rows def as_sql(self): # We don't need quote_name_unless_alias() here, since these are all @@ -933,35 +1018,27 @@ class SQLInsertCompiler(SQLCompiler): result.append('(%s)' % ', '.join(qn(f.column) for f in fields)) if has_fields: - params = values = [ - [ - f.get_db_prep_save( - getattr(obj, f.attname) if self.query.raw else f.pre_save(obj, True), - connection=self.connection - ) for f in fields - ] + value_rows = [ + [self.prepare_value(field, self.pre_save_val(field, obj)) for field in fields] for obj in self.query.objs ] else: - values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs] - params = [[]] + # An empty object. + value_rows = [[self.connection.ops.pk_default_value()] for _ in self.query.objs] fields = [None] - can_bulk = (not any(hasattr(field, "get_placeholder") for field in fields) and - not self.return_id and self.connection.features.has_bulk_insert) - if can_bulk: - placeholders = [["%s"] * len(fields)] - else: - placeholders = [ - [self.placeholder(field, v) for field, v in zip(fields, val)] - for val in values - ] - # Oracle Spatial needs to remove some values due to #10888 - params = self.connection.ops.modify_insert_params(placeholders, params) + # Currently the backends just accept values when generating bulk + # queries and generate their own placeholders. Doing that isn't + # necessary and it should be possible to use placeholders and + # expressions in bulk inserts too. + can_bulk = (not self.return_id and self.connection.features.has_bulk_insert) + + placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows) + if self.return_id and self.connection.features.can_return_id_from_insert: - params = params[0] + params = param_rows[0] col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) - result.append("VALUES (%s)" % ", ".join(placeholders[0])) + result.append("VALUES (%s)" % ", ".join(placeholder_rows[0])) r_fmt, r_params = self.connection.ops.return_insert_id() # Skip empty r_fmt to allow subclasses to customize behavior for # 3rd party backends. Refs #19096. @@ -969,13 +1046,14 @@ class SQLInsertCompiler(SQLCompiler): result.append(r_fmt % col) params += r_params return [(" ".join(result), tuple(params))] + if can_bulk: - result.append(self.connection.ops.bulk_insert_sql(fields, len(values))) - return [(" ".join(result), tuple(v for val in values for v in val))] + result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) + return [(" ".join(result), tuple(p for ps in param_rows for p in ps))] else: return [ (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals) - for p, vals in zip(placeholders, params) + for p, vals in zip(placeholder_rows, param_rows) ] def execute_sql(self, return_id=False): @@ -1034,10 +1112,11 @@ class SQLUpdateCompiler(SQLCompiler): connection=self.connection, ) else: - raise TypeError("Database is trying to update a relational field " - "of type %s with a value of type %s. Make sure " - "you are setting the correct relations" % - (field.__class__.__name__, val.__class__.__name__)) + raise TypeError( + "Tried to update field %s with a model instance, %r. " + "Use a value compatible with %s." + % (field, val, field.__class__.__name__) + ) else: val = field.get_db_prep_save(val, connection=self.connection) diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 2dbdf2edd7..35c814f903 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -139,9 +139,9 @@ class UpdateQuery(Query): def add_update_fields(self, values_seq): """ - Turn a sequence of (field, model, value) triples into an update query. - Used by add_update_values() as well as the "fast" update path when - saving models. + Append a sequence of (field, model, value) triples to the internal list + that will be used to generate the UPDATE query. Might be more usefully + called add_update_targets() to hint at the extra information here. """ self.values.extend(values_seq) diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index e114fb56ed..e3078b5449 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -5,10 +5,14 @@ Query Expressions .. currentmodule:: django.db.models Query expressions describe a value or a computation that can be used as part of -a filter, order by, annotation, or aggregate. There are a number of built-in -expressions (documented below) that can be used to help you write queries. -Expressions can be combined, or in some cases nested, to form more complex -computations. +an update, create, filter, order by, annotation, or aggregate. There are a +number of built-in expressions (documented below) that can be used to help you +write queries. Expressions can be combined, or in some cases nested, to form +more complex computations. + +.. versionchanged:: 1.9 + + Support for using expressions when creating new model instances was added. Supported arithmetic ==================== @@ -27,7 +31,7 @@ Some examples .. code-block:: python from django.db.models import F, Count - from django.db.models.functions import Length + from django.db.models.functions import Length, Upper, Value # Find companies that have more employees than chairs. Company.objects.filter(num_employees__gt=F('num_chairs')) @@ -49,6 +53,13 @@ Some examples >>> company.chairs_needed 70 + # Create a new company using expressions. + >>> company = Company.objects.create(name='Google', ticker=Upper(Value('goog'))) + # Be sure to refresh it if you need to access the field. + >>> company.refresh_from_db() + >>> company.ticker + 'GOOG' + # Annotate models with an aggregated value. Both forms # below are equivalent. Company.objects.annotate(num_products=Count('products')) @@ -122,6 +133,8 @@ and describe the operation. will need to be reloaded:: reporter = Reporters.objects.get(pk=reporter.pk) + # Or, more succinctly: + reporter.refresh_from_db() As well as being used in operations on single instances as above, ``F()`` can be used on ``QuerySets`` of object instances, with ``update()``. This reduces @@ -356,7 +369,10 @@ boolean, or string within an expression, you can wrap that value within a You will rarely need to use ``Value()`` directly. When you write the expression ``F('field') + 1``, Django implicitly wraps the ``1`` in a ``Value()``, -allowing simple values to be used in more complex expressions. +allowing simple values to be used in more complex expressions. You will need to +use ``Value()`` when you want to pass a string to an expression. Most +expressions interpret a string argument as the name of a field, like +``Lower('name')``. The ``value`` argument describes the value to be included in the expression, such as ``1``, ``True``, or ``None``. Django knows how to convert these Python diff --git a/docs/releases/1.9.txt b/docs/releases/1.9.txt index bc30395b13..0b1463ef13 100644 --- a/docs/releases/1.9.txt +++ b/docs/releases/1.9.txt @@ -542,6 +542,10 @@ Models * Added a new model field check that makes sure :attr:`~django.db.models.Field.default` is a valid value. +* :doc:`Query expressions ` can now be used when + creating new model instances using ``save()``, ``create()``, and + ``bulk_create()``. + Requests and Responses ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py index 3a0c654112..ce069504d0 100644 --- a/tests/bulk_create/tests.py +++ b/tests/bulk_create/tests.py @@ -3,6 +3,8 @@ from __future__ import unicode_literals from operator import attrgetter from django.db import connection +from django.db.models import Value +from django.db.models.functions import Lower from django.test import ( TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature, ) @@ -183,3 +185,12 @@ class BulkCreateTests(TestCase): TwoFields.objects.all().delete() with self.assertNumQueries(1): TwoFields.objects.bulk_create(objs, len(objs)) + + @skipUnlessDBFeature('has_bulk_insert') + def test_bulk_insert_expressions(self): + Restaurant.objects.bulk_create([ + Restaurant(name="Sam's Shake Shack"), + Restaurant(name=Lower(Value("Betty's Beetroot Bar"))) + ]) + bbb = Restaurant.objects.filter(name="betty's beetroot bar") + self.assertEqual(bbb.count(), 1) diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index a18084e06f..1af0b6e7a2 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -249,6 +249,32 @@ class BasicExpressionsTests(TestCase): test_gmbh = Company.objects.get(pk=test_gmbh.pk) self.assertEqual(test_gmbh.num_employees, 36) + def test_new_object_save(self): + # We should be able to use Funcs when inserting new data + test_co = Company( + name=Lower(Value("UPPER")), num_employees=32, num_chairs=1, + ceo=Employee.objects.create(firstname="Just", lastname="Doit", salary=30), + ) + test_co.save() + test_co.refresh_from_db() + self.assertEqual(test_co.name, "upper") + + def test_new_object_create(self): + test_co = Company.objects.create( + name=Lower(Value("UPPER")), num_employees=32, num_chairs=1, + ceo=Employee.objects.create(firstname="Just", lastname="Doit", salary=30), + ) + test_co.refresh_from_db() + self.assertEqual(test_co.name, "upper") + + def test_object_create_with_aggregate(self): + # Aggregates are not allowed when inserting new data + with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'): + Company.objects.create( + name='Company', num_employees=Max(Value(1)), num_chairs=1, + ceo=Employee.objects.create(firstname="Just", lastname="Doit", salary=30), + ) + def test_object_update_fk(self): # F expressions cannot be used to update attributes which are foreign # keys, or attributes which involve joins. @@ -272,7 +298,22 @@ class BasicExpressionsTests(TestCase): ceo=test_gmbh.ceo ) acme.num_employees = F("num_employees") + 16 - self.assertRaises(TypeError, acme.save) + msg = ( + 'Failed to insert expression "Col(expressions_company, ' + 'expressions.Company.num_employees) + Value(16)" on ' + 'expressions.Company.num_employees. F() expressions can only be ' + 'used to update, not to insert.' + ) + self.assertRaisesMessage(ValueError, msg, acme.save) + + acme.num_employees = 12 + acme.name = Lower(F('name')) + msg = ( + 'Failed to insert expression "Lower(Col(expressions_company, ' + 'expressions.Company.name))" on expressions.Company.name. F() ' + 'expressions can only be used to update, not to insert.' + ) + self.assertRaisesMessage(ValueError, msg, acme.save) def test_ticket_11722_iexact_lookup(self): Employee.objects.create(firstname="John", lastname="Doe") diff --git a/tests/model_fields/tests.py b/tests/model_fields/tests.py index 3d791980ba..23e90cccb8 100644 --- a/tests/model_fields/tests.py +++ b/tests/model_fields/tests.py @@ -98,8 +98,13 @@ class BasicFieldTests(test.TestCase): self.assertTrue(instance.id) # Set field to object on saved instance instance.size = instance + msg = ( + "Tried to update field model_fields.FloatModel.size with a model " + "instance, . Use a value " + "compatible with FloatField." + ) with transaction.atomic(): - with self.assertRaises(TypeError): + with self.assertRaisesMessage(TypeError, msg): instance.save() # Try setting field to object on retrieved object obj = FloatModel.objects.get(pk=instance.id)