Mirror of https://github.com/django/django.git, synced 2025-03-31 11:37:06 +00:00
Fixed #29771 -- Support database-specific syntax for bulk_update
parent 857b1048d5
commit 29311aae15
@@ -302,6 +302,9 @@ class BaseDatabaseFeatures:
     # Does this backend require casting the results of CASE expressions used
     # in UPDATE statements to ensure the expression has the correct type?
     requires_casted_case_in_updates = False
+    # Does this backend require casting the results of a VALUES expression used
+    # in UPDATE statements to ensure the expression has the correct type?
+    requires_casted_case_in_values_updates = False
 
     # Does the backend support partial indexes (CREATE INDEX ... WHERE ...)?
     supports_partial_indexes = True
@@ -22,6 +22,17 @@ class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
     pass
 
 
+class SQLBulkUpdateCompiler(compiler.SQLBulkUpdateCompiler, SQLCompiler):
+    def get_update_clause(self, values):
+        qn = self.quote_name_unless_alias
+        values_sql, value_params = self.compile(values)
+        update_sql = f"UPDATE {qn(self.query.base_table)}, ({values_sql}) as subquery"
+        return update_sql, tuple(value_params)
+
+    def get_subquery_from_clause(self, values):
+        return None, tuple()
+
+
 class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
     def as_sql(self):
         # Prefer the non-standard DELETE FROM syntax over the SQL generated by
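For orientation, here is a rough sketch of the statement shape this MySQL compiler is aiming for once it is combined with the SET and WHERE clauses of the generic SQLBulkUpdateCompiler later in this diff. The table and column names (`app_model`, `id`, `name`) are hypothetical, and the %s placeholders stand in for row values; this is an assumption derived from the clause methods, not output captured from the patch.

# Illustrative only: assumed shape of the MySQL bulk UPDATE for a hypothetical
# table `app_model` with pk `id` and one updated column `name`. The
# column1/column2 aliases come from the SELECT ... UNION VALUES trick in
# RowTupleValues.as_mysql(), added further down in this commit.
expected_mysql_sql = (
    "UPDATE `app_model`, ("
    "SELECT %s AS column1, %s AS column2 "
    "UNION VALUES ROW(%s, %s), ROW(%s, %s)"
    ") as subquery "
    "SET `name` = `subquery`.`column2` "
    "WHERE `id` = `subquery`.`column1`"
)
print(expected_mysql_sql)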
@@ -60,6 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
         )
     """
     requires_casted_case_in_updates = True
+    requires_casted_case_in_values_updates = True
     supports_over_clause = True
     supports_frame_exclusion = True
     only_supports_unbounded_with_preceding_and_following = True
@@ -2140,3 +2140,72 @@ class ValueRange(WindowFrame):
 
     def window_frame_start_end(self, connection, start, end):
         return connection.ops.window_frame_range_start_end(start, end)
+
+
+class RowTuple(ExpressionList):
+    template = "(%(expressions)s)"
+
+    def as_mysql(self, compiler, connection, **extra_context):
+        # MySQL requires the ROW() function to be used for tuples in VALUES clauses.
+        if not connection.mysql_is_mariadb:
+            extra_context["template"] = "ROW(%(expressions)s)"
+        return self.as_sql(compiler, connection, **extra_context)
+
+    def __str__(self):
+        values = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
+        return self.template % {"expressions": values}
+
+    def __getitem__(self, item):
+        return self.source_expressions[item]
+
+
+class RowTupleValues(ExpressionList):
+    template = "VALUES %(expressions)s"
+
+    def __init__(self, expression_list, pk_field, field_list, **extra):
+        expressions = (
+            (
+                RowTuple(*expressions)
+                if not isinstance(expressions, RowTuple)
+                else expressions
+            )
+            for expressions in expression_list
+        )
+        self.pk_field = pk_field
+        self.field_list = field_list
+        super().__init__(*expressions, **extra)
+
+    def as_mysql(self, compiler, connection, **extra_context):
+        # MySQL doesn't support aliases in VALUES clauses. The workaround is to
+        # use a UNION whose first row is a SELECT that aliases the columns:
+        #   SELECT 1 AS x, 2 AS y
+        #   UNION VALUES (3, 4), (5, 6);
+        first_row = self.source_expressions.pop()
+
+        class MySQLJoiner:
+            def join(self, expressions):
+                return ", ".join(
+                    f"{col_sql} AS column{idx}"
+                    for idx, col_sql in enumerate(expressions, 1)
+                )
+
+        sql, params = first_row.as_sql(
+            compiler,
+            connection,
+            template="%(expressions)s",
+            **extra_context,
+            arg_joiner=MySQLJoiner(),
+        )
+        sql = f"SELECT {sql}"
+        if self.source_expressions:
+            rest_sql, rest_params = super().as_sql(
+                compiler, connection, **extra_context
+            )
+            sql += f" UNION {rest_sql}"
+            params += rest_params
+
+        return sql, params
+
+    def __str__(self):
+        values = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
+        return self.template % {"expressions": values}
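As a standalone illustration of the two templates above (no Django setup required), the sketch below composes RowTuple's "(%(expressions)s)" and RowTupleValues' "VALUES %(expressions)s" by hand to show the shape of the rendered clause on backends that take the generic path. The sample rows and placeholder style are assumptions, not output from the patch; real compilation goes through as_sql() and binds the row values as query parameters.

# Minimal sketch mimicking the template composition of RowTuple/RowTupleValues.
row_template = "(%(expressions)s)"
values_template = "VALUES %(expressions)s"

rows = [(1, "first"), (2, "second")]  # hypothetical (pk, field) pairs
row_sql = [
    row_template % {"expressions": ", ".join(["%s"] * len(row))} for row in rows
]
print(values_template % {"expressions": ", ".join(row_sql)})
# VALUES (%s, %s), (%s, %s)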
@@ -23,7 +23,7 @@ from django.db import (
 from django.db.models import AutoField, DateField, DateTimeField, Field, sql
 from django.db.models.constants import LOOKUP_SEP, OnConflict
 from django.db.models.deletion import Collector
-from django.db.models.expressions import Case, F, Value, When
+from django.db.models.expressions import Case, F, RowTuple, RowTupleValues, Value, When
 from django.db.models.functions import Cast, Trunc
 from django.db.models.query_utils import FilteredRelation, Q
 from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
@@ -890,31 +890,73 @@ class QuerySet(AltersData):
         max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
         batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
         requires_casting = connection.features.requires_casted_case_in_updates
+        requires_casting_in_values = (
+            connection.features.requires_casted_case_in_values_updates
+        )
         batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size))
         updates = []
+
+        has_related_fields = any(f.model is not self.model for f in fields)
+        has_field_references = False
         for batch_objs in batches:
-            update_kwargs = {}
-            for field in fields:
-                when_statements = []
-                for obj in batch_objs:
+            row_tuples = []
+            for obj in batch_objs:
+                values = [Value(getattr(obj, "pk"))]
+                for field in fields:
                     attr = getattr(obj, field.attname)
-                    if not hasattr(attr, "resolve_expression"):
+                    if hasattr(attr, "resolve_expression"):
+                        has_field_references = True
+                    else:
                         attr = Value(attr, output_field=field)
-                    when_statements.append(When(pk=obj.pk, then=attr))
-                case_statement = Case(*when_statements, output_field=field)
-                if requires_casting:
-                    case_statement = Cast(case_statement, output_field=field)
-                update_kwargs[field.attname] = case_statement
-            updates.append(([obj.pk for obj in batch_objs], update_kwargs))
+
+                    if requires_casting_in_values:
+                        attr = Cast(attr, output_field=field)
+                    values.append(attr)
+
+                row_tuples.append(RowTuple(*values))
+            updates.append(
+                RowTupleValues(
+                    row_tuples, pk_field=self.model._meta.pk, field_list=fields
+                )
+            )
+
         rows_updated = 0
-        queryset = self.using(self.db)
         with transaction.atomic(using=self.db, savepoint=False):
-            for pks, update_kwargs in updates:
-                rows_updated += queryset.filter(pk__in=pks).update(**update_kwargs)
+            for row_tuples in updates:
+                if has_field_references or has_related_fields:
+                    rows_updated += self._bulk_update_slow(
+                        row_tuples, requires_casting=requires_casting
+                    )
+                else:
+                    query = sql.BulkUpdateQuery(self.model, row_tuples=row_tuples)
+                    rows_updated += query.get_compiler(using=self.db).execute_sql(
+                        result_type=CURSOR
+                    )
         return rows_updated
 
     bulk_update.alters_data = True
 
+    def _bulk_update_slow(self, row_tuples, requires_casting=False):
+        pks = [
+            row_tuple.source_expressions[0].value
+            for row_tuple in row_tuples.source_expressions
+        ]
+        update_kwargs = {}
+        # Skip the ID column
+        for field_idx, field in enumerate(row_tuples.field_list, start=1):
+            when_statements = []
+            for row_tuple in row_tuples.source_expressions:
+                row_tuple: RowTuple
+                id_expression = row_tuple.source_expressions[0]
+                attr_expression = row_tuple.source_expressions[field_idx]
+                when_statements.append(When(pk=id_expression, then=attr_expression))
+            case_statement = Case(*when_statements, output_field=field)
+            if requires_casting:
+                case_statement = Cast(case_statement, output_field=field)
+            update_kwargs[field.attname] = case_statement
+
+        return self.filter(pk__in=pks).update(**update_kwargs)
+
     async def abulk_update(self, objs, fields, batch_size=None):
         return await sync_to_async(self.bulk_update)(
            objs=objs,
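A usage sketch for context: bulk_update() itself is existing QuerySet API, and the Entry model, field names, and app path below are hypothetical. With this patch, rows built purely from plain Python values take the new VALUES-based path, while objects whose attributes are expressions, or fields that live on related models, fall back to _bulk_update_slow(), i.e. the original CASE/WHEN UPDATE.

# Sketch only: assumes a configured Django project and a hypothetical Entry
# model with `headline` and `pub_year` fields.
from myapp.models import Entry  # hypothetical app and model

entries = list(Entry.objects.filter(pub_year=2024))
for entry in entries:
    entry.headline = entry.headline.title()  # plain values -> fast VALUES path
Entry.objects.bulk_update(entries, ["headline"], batch_size=500)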
@@ -2090,6 +2090,62 @@ class SQLUpdateCompiler(SQLCompiler):
         self.query.reset_refcounts(refcounts_before)
 
 
+class SQLBulkUpdateCompiler(SQLCompiler):
+    def execute_sql(self, result_type):
+        cursor = super().execute_sql(result_type)
+        try:
+            rows = cursor.rowcount if cursor else 0
+        finally:
+            if cursor:
+                cursor.close()
+        return rows
+
+    def get_update_clause(self, values):
+        qn = self.quote_name_unless_alias
+        return "UPDATE %s " % qn(self.query.base_table), tuple()
+
+    def get_subquery_from_clause(self, values):
+        values_sql, value_params = self.compile(values)
+        return "FROM (%s) AS subquery" % values_sql, tuple(value_params)
+
+    def get_set_sql_clause(self, values):
+        qn = self.quote_name_unless_alias
+        set_sql = []
+        for idx, field in enumerate(values.field_list, 2):
+            name = field.column
+            set_sql.append(
+                "%s = %s.%s" % (qn(name), qn("subquery"), qn(f"column{idx}"))
+            )
+
+        return "SET %s" % ", ".join(set_sql), tuple()
+
+    def get_where_clause(self, values):
+        qn = self.quote_name_unless_alias
+        return (
+            f'WHERE {qn(values.pk_field.column)} = {qn("subquery")}.{qn("column1")}',
+            tuple(),
+        )
+
+    def as_sql(self, with_limits=True, with_col_aliases=False):
+        self.pre_sql_setup()
+
+        values = self.query.row_values.resolve_expression(
+            self.query, allow_joins=True, reuse=None
+        )
+
+        update_clause, update_params = self.get_update_clause(values)
+        set_clause, set_params = self.get_set_sql_clause(values)
+        from_clause, from_params = self.get_subquery_from_clause(values)
+        where_clause, where_params = self.get_where_clause(values)
+
+        result = [update_clause, set_clause]
+        if from_clause is not None:
+            result.append(from_clause)
+        result.append(where_clause)
+        params = update_params + set_params + from_params + where_params
+        return " ".join(result), tuple(params)
+
+
 class SQLAggregateCompiler(SQLCompiler):
     def as_sql(self):
         """
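For comparison with the MySQL variant earlier in this diff, here is a sketch of the statement as_sql() above assembles from its four clause helpers on a backend that keeps the FROM subquery (e.g. PostgreSQL, where a bare VALUES list in FROM exposes default column names column1, column2, and so on). Table and column names are hypothetical and parameters are shown as placeholders, not real compiler output.

# Illustrative only: assumed output shape of the generic SQLBulkUpdateCompiler
# for a hypothetical table "app_model" with pk "id" and one updated column "name".
expected_sql = (
    'UPDATE "app_model" '
    'SET "name" = "subquery"."column2" '
    "FROM (VALUES (%s, %s), (%s, %s)) AS subquery "
    'WHERE "id" = "subquery"."column1"'
)
params = (1, "first", 2, "second")  # hypothetical row values
print(expected_sql, params)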
@@ -6,7 +6,13 @@ from django.core.exceptions import FieldError
 from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
 from django.db.models.sql.query import Query
 
-__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
+__all__ = [
+    "DeleteQuery",
+    "UpdateQuery",
+    "BulkUpdateQuery",
+    "InsertQuery",
+    "AggregateQuery",
+]
 
 
 class DeleteQuery(Query):
@@ -142,6 +148,23 @@ class UpdateQuery(Query):
         return result
 
 
+class BulkUpdateQuery(Query):
+    compiler = "SQLBulkUpdateCompiler"
+
+    def __init__(self, *args, row_tuples, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.row_values = row_tuples
+
+    def clone(self):
+        obj = super().clone()
+        if self.row_values is not None:
+            obj.row_values = self.row_values.copy()
+        return obj
+
+    def add_row_values(self, row_values):
+        self.row_values = row_values
+
+
 class InsertQuery(Query):
     compiler = "SQLInsertCompiler"
 