1
0
mirror of https://github.com/django/django.git synced 2025-03-31 11:37:06 +00:00

Fixed #29771 -- Support database-specific syntax for bulk_update

This commit is contained in:
Thomas Forbes 2024-11-21 21:09:52 +00:00 committed by Tom Forbes
parent 857b1048d5
commit 29311aae15
7 changed files with 221 additions and 16 deletions

View File

@ -302,6 +302,9 @@ class BaseDatabaseFeatures:
# Does this backend require casting the results of CASE expressions used
# in UPDATE statements to ensure the expression has the correct type?
requires_casted_case_in_updates = False
# Does this backend require casting the results of a VALUES expression used
# in UPDATE statements to ensure the expression has the correct type?
# When True, QuerySet.bulk_update() wraps every row value in Cast() before
# building the VALUES rows. NOTE: despite the "case" in its name, this flag
# concerns VALUES rows, not CASE expressions (those are governed by
# requires_casted_case_in_updates).
requires_casted_case_in_values_updates = False
# Does the backend support partial indexes (CREATE INDEX ... WHERE ...)?
supports_partial_indexes = True

View File

@ -22,6 +22,17 @@ class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass
class SQLBulkUpdateCompiler(compiler.SQLBulkUpdateCompiler, SQLCompiler):
    """
    Bulk UPDATE compiler variant that lists the compiled value rows next to
    the target table in the UPDATE clause instead of in a FROM clause.
    """

    def get_update_clause(self, values):
        """
        Return the ``UPDATE <table>, (<values>) as subquery`` clause together
        with the parameters of the compiled ``values`` expression.
        """
        subquery_sql, subquery_params = self.compile(values)
        table = self.quote_name_unless_alias(self.query.base_table)
        clause = f"UPDATE {table}, ({subquery_sql}) as subquery"
        return clause, tuple(subquery_params)

    def get_subquery_from_clause(self, values):
        """
        Return no FROM clause: the value rows are already part of the UPDATE
        clause built by get_update_clause().
        """
        return None, ()
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
def as_sql(self):
# Prefer the non-standard DELETE FROM syntax over the SQL generated by

View File

@ -60,6 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
)
"""
requires_casted_case_in_updates = True
requires_casted_case_in_values_updates = True
supports_over_clause = True
supports_frame_exclusion = True
only_supports_unbounded_with_preceding_and_following = True

View File

@ -2140,3 +2140,72 @@ class ValueRange(WindowFrame):
def window_frame_start_end(self, connection, start, end):
    """
    Delegate to the backend operations object and return whatever its
    window_frame_range_start_end(start, end) returns.
    """
    backend_ops = connection.ops
    return backend_ops.window_frame_range_start_end(start, end)
class RowTuple(ExpressionList):
    """
    A parenthesised list of expressions, i.e. one row of a VALUES clause.

    Individual column expressions can be read by index (``row[0]``).
    """

    template = "(%(expressions)s)"

    def as_mysql(self, compiler, connection, **extra_context):
        """
        Compile the row for the MySQL family of backends.

        MySQL requires the ROW() function to be used for tuples in VALUES
        clauses; when the connection reports MariaDB the default template is
        used instead.
        """
        if connection.mysql_is_mariadb:
            return self.as_sql(compiler, connection, **extra_context)
        # Any caller-supplied template is deliberately replaced here.
        mysql_context = {**extra_context, "template": "ROW(%(expressions)s)"}
        return self.as_sql(compiler, connection, **mysql_context)

    def __str__(self):
        joined = self.arg_joiner.join(map(str, self.source_expressions))
        return self.template % {"expressions": joined}

    def __getitem__(self, item):
        columns = self.source_expressions
        return columns[item]
class RowTupleValues(ExpressionList):
    """
    A ``VALUES <row>, <row>, ...`` expression made of RowTuple rows.

    ``expression_list`` is an iterable of rows; each row is either a RowTuple
    or an iterable of expressions, which is wrapped in a RowTuple.
    ``pk_field`` and ``field_list`` are stored unchanged on the instance for
    the code that compiles the surrounding statement.
    """

    template = "VALUES %(expressions)s"

    def __init__(self, expression_list, pk_field, field_list, **extra):
        rows = [
            row if isinstance(row, RowTuple) else RowTuple(*row)
            for row in expression_list
        ]
        self.pk_field = pk_field
        self.field_list = field_list
        super().__init__(*rows, **extra)

    def as_mysql(self, compiler, connection, **extra_context):
        """
        Compile the rows for MySQL and return ``(sql, params)``.

        MySQL doesn't support aliases in VALUES clauses. The workaround is to
        emit the first row as a SELECT that names every column and UNION the
        remaining rows onto it:

            SELECT 1 AS column1, 2 AS column2
            UNION VALUES (3,4),(5,6);

        The expression itself is left untouched, so the same node can be
        compiled any number of times with identical results.
        """
        # Unpack instead of pop(): popping would drop a row from this
        # expression on every compilation.
        first_row, *other_rows = self.source_expressions

        class MySQLJoiner:
            # Stand-in for the string arg_joiner: aliases each compiled
            # column as column1, column2, ... while joining.
            def join(self, expressions):
                return ", ".join(
                    f"{col_sql} AS column{idx}"
                    for idx, col_sql in enumerate(expressions, 1)
                )

        # as_sql() (not as_mysql()) is used on purpose: the SELECT list must
        # be bare columns, without the surrounding parentheses / ROW().
        first_sql, first_params = first_row.as_sql(
            compiler,
            connection,
            template="%(expressions)s",
            arg_joiner=MySQLJoiner(),
            **extra_context,
        )
        sql = f"SELECT {first_sql}"
        params = [*first_params]
        if other_rows:
            # Compile the remaining rows on a copy so that this expression
            # keeps all of its rows.
            remaining = self.copy()
            remaining.set_source_expressions(other_rows)
            rest_sql, rest_params = remaining.as_sql(
                compiler, connection, **extra_context
            )
            sql += f" UNION {rest_sql}"
            params.extend(rest_params)
        return sql, params

    def __str__(self):
        values = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        return self.template % {"expressions": values}

View File

@ -23,7 +23,7 @@ from django.db import (
from django.db.models import AutoField, DateField, DateTimeField, Field, sql
from django.db.models.constants import LOOKUP_SEP, OnConflict
from django.db.models.deletion import Collector
from django.db.models.expressions import Case, F, Value, When
from django.db.models.expressions import Case, F, RowTuple, RowTupleValues, Value, When
from django.db.models.functions import Cast, Trunc
from django.db.models.query_utils import FilteredRelation, Q
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
@ -890,31 +890,73 @@ class QuerySet(AltersData):
max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
requires_casting = connection.features.requires_casted_case_in_updates
requires_casting_in_values = (
connection.features.requires_casted_case_in_values_updates
)
batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size))
updates = []
has_related_fields = any(f.model is not self.model for f in fields)
has_field_references = False
for batch_objs in batches:
update_kwargs = {}
for field in fields:
when_statements = []
for obj in batch_objs:
row_tuples = []
for obj in batch_objs:
values = [Value(getattr(obj, "pk"))]
for field in fields:
attr = getattr(obj, field.attname)
if not hasattr(attr, "resolve_expression"):
if hasattr(attr, "resolve_expression"):
has_field_references = True
else:
attr = Value(attr, output_field=field)
when_statements.append(When(pk=obj.pk, then=attr))
case_statement = Case(*when_statements, output_field=field)
if requires_casting:
case_statement = Cast(case_statement, output_field=field)
update_kwargs[field.attname] = case_statement
updates.append(([obj.pk for obj in batch_objs], update_kwargs))
if requires_casting_in_values:
attr = Cast(attr, output_field=field)
values.append(attr)
row_tuples.append(RowTuple(*values))
updates.append(
RowTupleValues(
row_tuples, pk_field=self.model._meta.pk, field_list=fields
)
)
rows_updated = 0
queryset = self.using(self.db)
with transaction.atomic(using=self.db, savepoint=False):
for pks, update_kwargs in updates:
rows_updated += queryset.filter(pk__in=pks).update(**update_kwargs)
for row_tuples in updates:
if has_field_references or has_related_fields:
rows_updated += self._bulk_update_slow(
row_tuples, requires_casting=requires_casting
)
else:
query = sql.BulkUpdateQuery(self.model, row_tuples=row_tuples)
rows_updated += query.get_compiler(using=self.db).execute_sql(
result_type=CURSOR
)
return rows_updated
bulk_update.alters_data = True
def _bulk_update_slow(self, row_tuples, requires_casting=False):
    """
    Apply one batch of rows with a single filter(pk__in=...).update(...) call
    that assigns every field through a CASE/WHEN expression keyed on the pk.

    ``row_tuples`` provides the rows (``source_expressions``) and the fields
    being updated (``field_list``). When ``requires_casting`` is true each
    CASE expression is wrapped in Cast() to the field's type. Return the
    value returned by update().
    """
    rows = row_tuples.source_expressions
    # Column 0 of every row holds the primary key; the updated fields follow
    # in field_list order, hence the enumeration starting at 1 below.
    pks = [row.source_expressions[0].value for row in rows]
    update_kwargs = {}
    for column, field in enumerate(row_tuples.field_list, start=1):
        whens = (
            When(
                pk=row.source_expressions[0],
                then=row.source_expressions[column],
            )
            for row in rows
        )
        case_statement = Case(*whens, output_field=field)
        if requires_casting:
            case_statement = Cast(case_statement, output_field=field)
        update_kwargs[field.attname] = case_statement
    return self.filter(pk__in=pks).update(**update_kwargs)
async def abulk_update(self, objs, fields, batch_size=None):
return await sync_to_async(self.bulk_update)(
objs=objs,

View File

@ -2090,6 +2090,62 @@ class SQLUpdateCompiler(SQLCompiler):
self.query.reset_refcounts(refcounts_before)
class SQLBulkUpdateCompiler(SQLCompiler):
    """
    Compile an UPDATE statement whose new values come from a subquery of
    value rows:

        UPDATE <table> SET <col> = subquery.columnN, ...
        FROM (<values>) AS subquery
        WHERE <pk> = subquery.column1

    Each clause is produced by its own method so that subclasses can override
    individual pieces.
    """

    def execute_sql(self, result_type):
        """Run the statement and return the cursor's rowcount (0 if no cursor)."""
        cursor = super().execute_sql(result_type)
        if not cursor:
            return 0
        try:
            return cursor.rowcount
        finally:
            cursor.close()

    def get_update_clause(self, values):
        """Return the ``UPDATE <table>`` clause and its (empty) parameters."""
        table = self.quote_name_unless_alias(self.query.base_table)
        return "UPDATE %s " % table, ()

    def get_subquery_from_clause(self, values):
        """Return the FROM clause wrapping the compiled values, with params."""
        subquery_sql, subquery_params = self.compile(values)
        return "FROM (%s) AS subquery" % subquery_sql, tuple(subquery_params)

    def get_set_sql_clause(self, values):
        """
        Return the SET clause assigning each field in ``values.field_list``
        from the matching subquery column.
        """
        qn = self.quote_name_unless_alias
        # Numbering starts at 2: column1 is the one matched against the
        # primary key in the WHERE clause.
        assignments = ", ".join(
            "%s = %s.%s" % (qn(field.column), qn("subquery"), qn(f"column{position}"))
            for position, field in enumerate(values.field_list, 2)
        )
        return "SET %s" % assignments, ()

    def get_where_clause(self, values):
        """Return the WHERE clause matching the pk column to subquery.column1."""
        qn = self.quote_name_unless_alias
        pk_column = qn(values.pk_field.column)
        subquery_pk = f'{qn("subquery")}.{qn("column1")}'
        return f"WHERE {pk_column} = {subquery_pk}", ()

    def as_sql(self, with_limits=True, with_col_aliases=False):
        """
        Assemble the clauses into ``(sql, params)``. A FROM clause of None is
        left out of the statement.
        """
        self.pre_sql_setup()
        row_values = self.query.row_values
        values = row_values.resolve_expression(
            self.query, allow_joins=True, reuse=None
        )
        update_clause, update_params = self.get_update_clause(values)
        set_clause, set_params = self.get_set_sql_clause(values)
        from_clause, from_params = self.get_subquery_from_clause(values)
        where_clause, where_params = self.get_where_clause(values)

        optional_from = [] if from_clause is None else [from_clause]
        sql = " ".join([update_clause, set_clause, *optional_from, where_clause])
        return sql, tuple(update_params + set_params + from_params + where_params)
class SQLAggregateCompiler(SQLCompiler):
def as_sql(self):
"""

View File

@ -6,7 +6,13 @@ from django.core.exceptions import FieldError
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
from django.db.models.sql.query import Query
__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
__all__ = [
"DeleteQuery",
"UpdateQuery",
"BulkUpdateQuery",
"InsertQuery",
"AggregateQuery",
]
class DeleteQuery(Query):
@ -142,6 +148,23 @@ class UpdateQuery(Query):
return result
class BulkUpdateQuery(Query):
    """
    A query that updates many rows at once from a set of value rows.

    The rows are passed as the keyword-only ``row_tuples`` argument and kept
    on the instance as ``row_values``.
    """

    compiler = "SQLBulkUpdateCompiler"

    def __init__(self, *args, row_tuples, **kwargs):
        super().__init__(*args, **kwargs)
        self.row_values = row_tuples

    def clone(self):
        """Return a clone that carries its own copy of ``row_values``."""
        obj = super().clone()
        if self.row_values is not None:
            # The copy must be stored under row_values; assigning it to any
            # other attribute leaves the clone sharing the original rows.
            obj.row_values = self.row_values.copy()
        return obj

    def add_row_values(self, row_values):
        """Replace the value rows this query updates from."""
        self.row_values = row_values
class InsertQuery(Query):
compiler = "SQLInsertCompiler"