Fixed #24509 -- Added Expression support to SQLInsertCompiler

This commit is contained in:
Alex Hill 2015-08-04 00:34:19 +10:00 committed by Josh Smeaton
parent 6e51d5d0e5
commit 134ca4d438
15 changed files with 247 additions and 67 deletions

View File

@ -263,11 +263,10 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
from django.contrib.gis.db.backends.oracle.models import OracleSpatialRefSys
return OracleSpatialRefSys
def modify_insert_params(self, placeholders, params):
def modify_insert_params(self, placeholder, params):
"""Drop out insert parameters for NULL placeholder. Needed for Oracle Spatial
backend due to #10888
backend due to #10888.
"""
# This code doesn't work for bulk insert cases.
assert len(placeholders) == 1
return [[param for pholder, param
in six.moves.zip(placeholders[0], params[0]) if pholder != 'NULL'], ]
if placeholder == 'NULL':
return []
return super(OracleOperations, self).modify_insert_params(placeholder, params)

View File

@ -576,7 +576,7 @@ class BaseDatabaseOperations(object):
def combine_duration_expression(self, connector, sub_expressions):
return self.combine_expression(connector, sub_expressions)
def modify_insert_params(self, placeholders, params):
def modify_insert_params(self, placeholder, params):
"""Allow modification of insert parameters. Needed for Oracle Spatial
backend due to #10888.
"""

View File

@ -166,9 +166,10 @@ class DatabaseOperations(BaseDatabaseOperations):
def max_name_length(self):
return 64
def bulk_insert_sql(self, fields, num_values):
items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
return "VALUES " + ", ".join([items_sql] * num_values)
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
return "VALUES " + values_sql
def combine_expression(self, connector, sub_expressions):
"""

View File

@ -439,6 +439,8 @@ WHEN (new.%(col_name)s IS NULL)
name_length = self.max_name_length() - 3
return '%s_TR' % truncate_name(table, name_length).upper()
def bulk_insert_sql(self, fields, num_values):
items_sql = "SELECT %s FROM DUAL" % ", ".join(["%s"] * len(fields))
return " UNION ALL ".join([items_sql] * num_values)
def bulk_insert_sql(self, fields, placeholder_rows):
return " UNION ALL ".join(
"SELECT %s FROM DUAL" % ", ".join(row)
for row in placeholder_rows
)

View File

@ -221,9 +221,10 @@ class DatabaseOperations(BaseDatabaseOperations):
def return_insert_id(self):
return "RETURNING %s", ()
def bulk_insert_sql(self, fields, num_values):
items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
return "VALUES " + ", ".join([items_sql] * num_values)
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
return "VALUES " + values_sql
def adapt_datefield_value(self, value):
return value

View File

@ -226,13 +226,11 @@ class DatabaseOperations(BaseDatabaseOperations):
value = uuid.UUID(value)
return value
def bulk_insert_sql(self, fields, num_values):
res = []
res.append("SELECT %s" % ", ".join(
"%%s AS %s" % self.quote_name(f.column) for f in fields
))
res.extend(["UNION ALL SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1))
return " ".join(res)
def bulk_insert_sql(self, fields, placeholder_rows):
return " UNION ALL ".join(
"SELECT %s" % ", ".join(row)
for row in placeholder_rows
)
def combine_expression(self, connector, sub_expressions):
# SQLite doesn't have a power function, so we fake it with a

View File

@ -180,6 +180,13 @@ class BaseExpression(object):
return True
return False
@cached_property
def contains_column_references(self):
for expr in self.get_source_expressions():
if expr and expr.contains_column_references:
return True
return False
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
"""
Provides the chance to do any preprocessing or validation before being
@ -339,6 +346,17 @@ class BaseExpression(object):
def reverse_ordering(self):
return self
def flatten(self):
"""
Recursively yield this expression and all subexpressions, in
depth-first order.
"""
yield self
for expr in self.get_source_expressions():
if expr:
for inner_expr in expr.flatten():
yield inner_expr
class Expression(BaseExpression, Combinable):
"""
@ -613,6 +631,9 @@ class Random(Expression):
class Col(Expression):
contains_column_references = True
def __init__(self, alias, target, output_field=None):
if output_field is None:
output_field = target

View File

@ -458,6 +458,8 @@ class QuerySet(object):
specifying whether an object was created.
"""
lookup, params = self._extract_model_params(defaults, **kwargs)
# The get() needs to be targeted at the write database in order
# to avoid potential transaction consistency problems.
self._for_write = True
try:
return self.get(**lookup), False

View File

@ -909,17 +909,102 @@ class SQLInsertCompiler(SQLCompiler):
self.return_id = False
super(SQLInsertCompiler, self).__init__(*args, **kwargs)
def placeholder(self, field, val):
def field_as_sql(self, field, val):
"""
Take a field and a value intended to be saved on that field, and
return placeholder SQL and accompanying params. Checks for raw values,
expressions and fields with get_placeholder() defined in that order.
When field is None, the value is considered raw and is used as the
placeholder, with no corresponding parameters returned.
"""
if field is None:
# A field value of None means the value is raw.
return val
sql, params = val, []
elif hasattr(val, 'as_sql'):
# This is an expression, let's compile it.
sql, params = self.compile(val)
elif hasattr(field, 'get_placeholder'):
# Some fields (e.g. geo fields) need special munging before
# they can be inserted.
return field.get_placeholder(val, self, self.connection)
sql, params = field.get_placeholder(val, self, self.connection), [val]
else:
# Return the common case for the placeholder
return '%s'
sql, params = '%s', [val]
# The following hook is only used by Oracle Spatial, which sometimes
# needs to yield 'NULL' and [] as its placeholder and params instead
# of '%s' and [None]. The 'NULL' placeholder is produced earlier by
# OracleOperations.get_geom_placeholder(). The following line removes
# the corresponding None parameter. See ticket #10888.
params = self.connection.ops.modify_insert_params(sql, params)
return sql, params
def prepare_value(self, field, value):
"""
Prepare a value to be used in a query by resolving it if it is an
expression and otherwise calling the field's get_db_prep_save().
"""
if hasattr(value, 'resolve_expression'):
value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
# Don't allow values containing Col expressions. They refer to
# existing columns on a row, but in the case of insert the row
# doesn't exist yet.
if value.contains_column_references:
raise ValueError(
'Failed to insert expression "%s" on %s. F() expressions '
'can only be used to update, not to insert.' % (value, field)
)
if value.contains_aggregate:
raise FieldError("Aggregate functions are not allowed in this query")
else:
value = field.get_db_prep_save(value, connection=self.connection)
return value
def pre_save_val(self, field, obj):
"""
Get the given field's value off the given obj. pre_save() is used for
things like auto_now on DateTimeField. Skip it if this is a raw query.
"""
if self.query.raw:
return getattr(obj, field.attname)
return field.pre_save(obj, add=True)
def assemble_as_sql(self, fields, value_rows):
"""
Take a sequence of N fields and a sequence of M rows of values,
generate placeholder SQL and parameters for each field and value, and
return a pair containing:
* a sequence of M rows of N SQL placeholder strings, and
* a sequence of M rows of corresponding parameter values.
Each placeholder string may contain any number of '%s' interpolation
strings, and each parameter row will contain exactly as many params
as the total number of '%s's in the corresponding placeholder row.
"""
if not value_rows:
return [], []
# list of (sql, [params]) tuples for each object to be saved
# Shape: [n_objs][n_fields][2]
rows_of_fields_as_sql = (
(self.field_as_sql(field, v) for field, v in zip(fields, row))
for row in value_rows
)
# tuple like ([sqls], [[params]s]) for each object to be saved
# Shape: [n_objs][2][n_fields]
sql_and_param_pair_rows = (zip(*row) for row in rows_of_fields_as_sql)
# Extract separate lists for placeholders and params.
# Each of these has shape [n_objs][n_fields]
placeholder_rows, param_rows = zip(*sql_and_param_pair_rows)
# Params for each field are still lists, and need to be flattened.
param_rows = [[p for ps in row for p in ps] for row in param_rows]
return placeholder_rows, param_rows
def as_sql(self):
# We don't need quote_name_unless_alias() here, since these are all
@ -933,35 +1018,27 @@ class SQLInsertCompiler(SQLCompiler):
result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
if has_fields:
params = values = [
[
f.get_db_prep_save(
getattr(obj, f.attname) if self.query.raw else f.pre_save(obj, True),
connection=self.connection
) for f in fields
]
value_rows = [
[self.prepare_value(field, self.pre_save_val(field, obj)) for field in fields]
for obj in self.query.objs
]
else:
values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs]
params = [[]]
# An empty object.
value_rows = [[self.connection.ops.pk_default_value()] for _ in self.query.objs]
fields = [None]
can_bulk = (not any(hasattr(field, "get_placeholder") for field in fields) and
not self.return_id and self.connection.features.has_bulk_insert)
if can_bulk:
placeholders = [["%s"] * len(fields)]
else:
placeholders = [
[self.placeholder(field, v) for field, v in zip(fields, val)]
for val in values
]
# Oracle Spatial needs to remove some values due to #10888
params = self.connection.ops.modify_insert_params(placeholders, params)
# Currently the backends just accept values when generating bulk
# queries and generate their own placeholders. Doing that isn't
# necessary and it should be possible to use placeholders and
# expressions in bulk inserts too.
can_bulk = (not self.return_id and self.connection.features.has_bulk_insert)
placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
if self.return_id and self.connection.features.can_return_id_from_insert:
params = params[0]
params = param_rows[0]
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
result.append("VALUES (%s)" % ", ".join(placeholders[0]))
result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
r_fmt, r_params = self.connection.ops.return_insert_id()
# Skip empty r_fmt to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096.
@ -969,13 +1046,14 @@ class SQLInsertCompiler(SQLCompiler):
result.append(r_fmt % col)
params += r_params
return [(" ".join(result), tuple(params))]
if can_bulk:
result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
return [(" ".join(result), tuple(v for val in values for v in val))]
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
else:
return [
(" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
for p, vals in zip(placeholders, params)
for p, vals in zip(placeholder_rows, param_rows)
]
def execute_sql(self, return_id=False):
@ -1034,10 +1112,11 @@ class SQLUpdateCompiler(SQLCompiler):
connection=self.connection,
)
else:
raise TypeError("Database is trying to update a relational field "
"of type %s with a value of type %s. Make sure "
"you are setting the correct relations" %
(field.__class__.__name__, val.__class__.__name__))
raise TypeError(
"Tried to update field %s with a model instance, %r. "
"Use a value compatible with %s."
% (field, val, field.__class__.__name__)
)
else:
val = field.get_db_prep_save(val, connection=self.connection)

View File

@ -139,9 +139,9 @@ class UpdateQuery(Query):
def add_update_fields(self, values_seq):
"""
Turn a sequence of (field, model, value) triples into an update query.
Used by add_update_values() as well as the "fast" update path when
saving models.
Append a sequence of (field, model, value) triples to the internal list
that will be used to generate the UPDATE query. Might be more usefully
called add_update_targets() to hint at the extra information here.
"""
self.values.extend(values_seq)

View File

@ -5,10 +5,14 @@ Query Expressions
.. currentmodule:: django.db.models
Query expressions describe a value or a computation that can be used as part of
a filter, order by, annotation, or aggregate. There are a number of built-in
expressions (documented below) that can be used to help you write queries.
Expressions can be combined, or in some cases nested, to form more complex
computations.
an update, create, filter, order by, annotation, or aggregate. There are a
number of built-in expressions (documented below) that can be used to help you
write queries. Expressions can be combined, or in some cases nested, to form
more complex computations.
.. versionchanged:: 1.9
Support for using expressions when creating new model instances was added.
Supported arithmetic
====================
@ -27,7 +31,7 @@ Some examples
.. code-block:: python
from django.db.models import F, Count
from django.db.models.functions import Length
from django.db.models.functions import Length, Upper, Value
# Find companies that have more employees than chairs.
Company.objects.filter(num_employees__gt=F('num_chairs'))
@ -49,6 +53,13 @@ Some examples
>>> company.chairs_needed
70
# Create a new company using expressions.
>>> company = Company.objects.create(name='Google', ticker=Upper(Value('goog')))
# Be sure to refresh it if you need to access the field.
>>> company.refresh_from_db()
>>> company.ticker
'GOOG'
# Annotate models with an aggregated value. Both forms
# below are equivalent.
Company.objects.annotate(num_products=Count('products'))
@ -122,6 +133,8 @@ and describe the operation.
will need to be reloaded::
reporter = Reporters.objects.get(pk=reporter.pk)
# Or, more succinctly:
reporter.refresh_from_db()
As well as being used in operations on single instances as above, ``F()`` can
be used on ``QuerySets`` of object instances, with ``update()``. This reduces
@ -356,7 +369,10 @@ boolean, or string within an expression, you can wrap that value within a
You will rarely need to use ``Value()`` directly. When you write the expression
``F('field') + 1``, Django implicitly wraps the ``1`` in a ``Value()``,
allowing simple values to be used in more complex expressions.
allowing simple values to be used in more complex expressions. You will need to
use ``Value()`` when you want to pass a string to an expression. Most
expressions interpret a string argument as the name of a field, like
``Lower('name')``.
The ``value`` argument describes the value to be included in the expression,
such as ``1``, ``True``, or ``None``. Django knows how to convert these Python

View File

@ -542,6 +542,10 @@ Models
* Added a new model field check that makes sure
:attr:`~django.db.models.Field.default` is a valid value.
* :doc:`Query expressions </ref/models/expressions>` can now be used when
creating new model instances using ``save()``, ``create()``, and
``bulk_create()``.
Requests and Responses
^^^^^^^^^^^^^^^^^^^^^^

View File

@ -3,6 +3,8 @@ from __future__ import unicode_literals
from operator import attrgetter
from django.db import connection
from django.db.models import Value
from django.db.models.functions import Lower
from django.test import (
TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature,
)
@ -183,3 +185,12 @@ class BulkCreateTests(TestCase):
TwoFields.objects.all().delete()
with self.assertNumQueries(1):
TwoFields.objects.bulk_create(objs, len(objs))
@skipUnlessDBFeature('has_bulk_insert')
def test_bulk_insert_expressions(self):
Restaurant.objects.bulk_create([
Restaurant(name="Sam's Shake Shack"),
Restaurant(name=Lower(Value("Betty's Beetroot Bar")))
])
bbb = Restaurant.objects.filter(name="betty's beetroot bar")
self.assertEqual(bbb.count(), 1)

View File

@ -249,6 +249,32 @@ class BasicExpressionsTests(TestCase):
test_gmbh = Company.objects.get(pk=test_gmbh.pk)
self.assertEqual(test_gmbh.num_employees, 36)
def test_new_object_save(self):
# We should be able to use Funcs when inserting new data
test_co = Company(
name=Lower(Value("UPPER")), num_employees=32, num_chairs=1,
ceo=Employee.objects.create(firstname="Just", lastname="Doit", salary=30),
)
test_co.save()
test_co.refresh_from_db()
self.assertEqual(test_co.name, "upper")
def test_new_object_create(self):
test_co = Company.objects.create(
name=Lower(Value("UPPER")), num_employees=32, num_chairs=1,
ceo=Employee.objects.create(firstname="Just", lastname="Doit", salary=30),
)
test_co.refresh_from_db()
self.assertEqual(test_co.name, "upper")
def test_object_create_with_aggregate(self):
# Aggregates are not allowed when inserting new data
with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'):
Company.objects.create(
name='Company', num_employees=Max(Value(1)), num_chairs=1,
ceo=Employee.objects.create(firstname="Just", lastname="Doit", salary=30),
)
def test_object_update_fk(self):
# F expressions cannot be used to update attributes which are foreign
# keys, or attributes which involve joins.
@ -272,7 +298,22 @@ class BasicExpressionsTests(TestCase):
ceo=test_gmbh.ceo
)
acme.num_employees = F("num_employees") + 16
self.assertRaises(TypeError, acme.save)
msg = (
'Failed to insert expression "Col(expressions_company, '
'expressions.Company.num_employees) + Value(16)" on '
'expressions.Company.num_employees. F() expressions can only be '
'used to update, not to insert.'
)
self.assertRaisesMessage(ValueError, msg, acme.save)
acme.num_employees = 12
acme.name = Lower(F('name'))
msg = (
'Failed to insert expression "Lower(Col(expressions_company, '
'expressions.Company.name))" on expressions.Company.name. F() '
'expressions can only be used to update, not to insert.'
)
self.assertRaisesMessage(ValueError, msg, acme.save)
def test_ticket_11722_iexact_lookup(self):
Employee.objects.create(firstname="John", lastname="Doe")

View File

@ -98,8 +98,13 @@ class BasicFieldTests(test.TestCase):
self.assertTrue(instance.id)
# Set field to object on saved instance
instance.size = instance
msg = (
"Tried to update field model_fields.FloatModel.size with a model "
"instance, <FloatModel: FloatModel object>. Use a value "
"compatible with FloatField."
)
with transaction.atomic():
with self.assertRaises(TypeError):
with self.assertRaisesMessage(TypeError, msg):
instance.save()
# Try setting field to object on retrieved object
obj = FloatModel.objects.get(pk=instance.id)