1
0
mirror of https://github.com/django/django.git synced 2025-06-05 11:39:13 +00:00

Fixed #29771 -- Support database-specific syntax for bulk_update

This commit is contained in:
Thomas Forbes 2024-11-21 21:09:52 +00:00 committed by Tom Forbes
parent 857b1048d5
commit 29311aae15
7 changed files with 221 additions and 16 deletions

View File

@ -302,6 +302,9 @@ class BaseDatabaseFeatures:
# Does this backend require casting the results of CASE expressions used # Does this backend require casting the results of CASE expressions used
# in UPDATE statements to ensure the expression has the correct type? # in UPDATE statements to ensure the expression has the correct type?
requires_casted_case_in_updates = False requires_casted_case_in_updates = False
# Does this backend require casting the results of a VALUES expression used
# in UPDATE statements to ensure the expression has the correct type?
requires_casted_case_in_values_updates = False
# Does the backend support partial indexes (CREATE INDEX ... WHERE ...)? # Does the backend support partial indexes (CREATE INDEX ... WHERE ...)?
supports_partial_indexes = True supports_partial_indexes = True

View File

@ -22,6 +22,17 @@ class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass pass
class SQLBulkUpdateCompiler(compiler.SQLBulkUpdateCompiler, SQLCompiler):
    def get_update_clause(self, values):
        """Return the MySQL-specific UPDATE clause and its parameters.

        MySQL joins the VALUES subquery directly in the UPDATE target list
        instead of using a FROM clause, so the compiled subquery is embedded
        here.
        """
        quote = self.quote_name_unless_alias
        values_sql, value_params = self.compile(values)
        table_sql = quote(self.query.base_table)
        return (
            f"UPDATE {table_sql}, ({values_sql}) as subquery",
            tuple(value_params),
        )

    def get_subquery_from_clause(self, values):
        # The subquery already appears in the UPDATE clause on MySQL, so no
        # separate FROM clause (and no extra parameters) are emitted.
        return None, tuple()
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
def as_sql(self): def as_sql(self):
# Prefer the non-standard DELETE FROM syntax over the SQL generated by # Prefer the non-standard DELETE FROM syntax over the SQL generated by

View File

@ -60,6 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
) )
""" """
requires_casted_case_in_updates = True requires_casted_case_in_updates = True
requires_casted_case_in_values_updates = True
supports_over_clause = True supports_over_clause = True
supports_frame_exclusion = True supports_frame_exclusion = True
only_supports_unbounded_with_preceding_and_following = True only_supports_unbounded_with_preceding_and_following = True

View File

@ -2140,3 +2140,72 @@ class ValueRange(WindowFrame):
def window_frame_start_end(self, connection, start, end): def window_frame_start_end(self, connection, start, end):
return connection.ops.window_frame_range_start_end(start, end) return connection.ops.window_frame_range_start_end(start, end)
class RowTuple(ExpressionList):
    """An expression list rendered as a parenthesized SQL tuple."""

    template = "(%(expressions)s)"

    def as_mysql(self, compiler, connection, **extra_context):
        # MySQL proper (but not MariaDB) requires the ROW() function for
        # tuples appearing in VALUES clauses.
        if not connection.mysql_is_mariadb:
            extra_context["template"] = "ROW(%(expressions)s)"
        return self.as_sql(compiler, connection, **extra_context)

    def __str__(self):
        joined = self.arg_joiner.join(str(expr) for expr in self.source_expressions)
        return self.template % {"expressions": joined}

    def __getitem__(self, index):
        # Allow positional access to the tuple's member expressions.
        return self.source_expressions[index]
class RowTupleValues(ExpressionList):
    """A VALUES clause built from RowTuple rows, used as an UPDATE source.

    Each row's first column is the primary key; the remaining columns
    correspond to ``field_list`` in order.
    """

    template = "VALUES %(expressions)s"

    def __init__(self, expression_list, pk_field, field_list, **extra):
        # Wrap any raw expression sequences in RowTuple so every row
        # renders as a parenthesized tuple.
        expressions = (
            expr if isinstance(expr, RowTuple) else RowTuple(*expr)
            for expr in expression_list
        )
        self.pk_field = pk_field
        self.field_list = field_list
        super().__init__(*expressions, **extra)

    def as_mysql(self, compiler, connection, **extra_context):
        # MySQL doesn't support column aliases in VALUES clauses. The
        # workaround is to emit the first row as a SELECT with aliases and
        # UNION the remaining rows onto it:
        #   SELECT 1 AS column1, 2 AS column2
        #   UNION VALUES (3, 4), (5, 6);
        #
        # BUG FIX: the previous implementation called
        # self.source_expressions.pop(), mutating the expression tree so
        # each compilation silently dropped one row, and it lifted the LAST
        # row into the SELECT despite the comment documenting the first.
        # Compile without mutating self.source_expressions instead.
        first_row, *rest_rows = self.source_expressions

        class ColumnAliasJoiner:
            # Joiner aliasing each column as column1, column2, ... so the
            # UPDATE's SET clause can reference them by name.
            def join(self, expressions):
                return ", ".join(
                    f"{col_sql} AS column{idx}"
                    for idx, col_sql in enumerate(expressions, 1)
                )

        sql, params = first_row.as_sql(
            compiler,
            connection,
            template="%(expressions)s",
            arg_joiner=ColumnAliasJoiner(),
            **extra_context,
        )
        sql = f"SELECT {sql}"
        params = list(params)
        if rest_rows:
            rest_sqls = []
            for row in rest_rows:
                row_sql, row_params = compiler.compile(row)
                rest_sqls.append(row_sql)
                params.extend(row_params)
            sql += " UNION " + (
                self.template % {"expressions": self.arg_joiner.join(rest_sqls)}
            )
        return sql, params

    def __str__(self):
        joined = self.arg_joiner.join(str(expr) for expr in self.source_expressions)
        return self.template % {"expressions": joined}

View File

@ -23,7 +23,7 @@ from django.db import (
from django.db.models import AutoField, DateField, DateTimeField, Field, sql from django.db.models import AutoField, DateField, DateTimeField, Field, sql
from django.db.models.constants import LOOKUP_SEP, OnConflict from django.db.models.constants import LOOKUP_SEP, OnConflict
from django.db.models.deletion import Collector from django.db.models.deletion import Collector
from django.db.models.expressions import Case, F, Value, When from django.db.models.expressions import Case, F, RowTuple, RowTupleValues, Value, When
from django.db.models.functions import Cast, Trunc from django.db.models.functions import Cast, Trunc
from django.db.models.query_utils import FilteredRelation, Q from django.db.models.query_utils import FilteredRelation, Q
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
@ -890,31 +890,73 @@ class QuerySet(AltersData):
max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs) max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
requires_casting = connection.features.requires_casted_case_in_updates requires_casting = connection.features.requires_casted_case_in_updates
requires_casting_in_values = (
connection.features.requires_casted_case_in_values_updates
)
batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size)) batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size))
updates = [] updates = []
has_related_fields = any(f.model is not self.model for f in fields)
has_field_references = False
for batch_objs in batches: for batch_objs in batches:
update_kwargs = {} row_tuples = []
for field in fields: for obj in batch_objs:
when_statements = [] values = [Value(getattr(obj, "pk"))]
for obj in batch_objs: for field in fields:
attr = getattr(obj, field.attname) attr = getattr(obj, field.attname)
if not hasattr(attr, "resolve_expression"): if hasattr(attr, "resolve_expression"):
has_field_references = True
else:
attr = Value(attr, output_field=field) attr = Value(attr, output_field=field)
when_statements.append(When(pk=obj.pk, then=attr))
case_statement = Case(*when_statements, output_field=field) if requires_casting_in_values:
if requires_casting: attr = Cast(attr, output_field=field)
case_statement = Cast(case_statement, output_field=field) values.append(attr)
update_kwargs[field.attname] = case_statement
updates.append(([obj.pk for obj in batch_objs], update_kwargs)) row_tuples.append(RowTuple(*values))
updates.append(
RowTupleValues(
row_tuples, pk_field=self.model._meta.pk, field_list=fields
)
)
rows_updated = 0 rows_updated = 0
queryset = self.using(self.db)
with transaction.atomic(using=self.db, savepoint=False): with transaction.atomic(using=self.db, savepoint=False):
for pks, update_kwargs in updates: for row_tuples in updates:
rows_updated += queryset.filter(pk__in=pks).update(**update_kwargs) if has_field_references or has_related_fields:
rows_updated += self._bulk_update_slow(
row_tuples, requires_casting=requires_casting
)
else:
query = sql.BulkUpdateQuery(self.model, row_tuples=row_tuples)
rows_updated += query.get_compiler(using=self.db).execute_sql(
result_type=CURSOR
)
return rows_updated return rows_updated
bulk_update.alters_data = True bulk_update.alters_data = True
def _bulk_update_slow(self, row_tuples, requires_casting=False):
    """Fallback bulk_update path using per-field CASE/WHEN expressions.

    Used when the batch's row values cannot go through the VALUES-based
    fast path. ``row_tuples`` is a RowTupleValues whose rows hold the pk
    expression at index 0 followed by one expression per updated field.
    Returns the number of rows matched by the UPDATE.
    """
    pks = [
        row.source_expressions[0].value
        for row in row_tuples.source_expressions
    ]
    update_kwargs = {}
    # Column 0 of every row tuple is the pk, so field columns start at 1.
    for column, field in enumerate(row_tuples.field_list, start=1):
        whens = [
            When(
                pk=row.source_expressions[0],
                then=row.source_expressions[column],
            )
            for row in row_tuples.source_expressions
        ]
        case_expr = Case(*whens, output_field=field)
        if requires_casting:
            case_expr = Cast(case_expr, output_field=field)
        update_kwargs[field.attname] = case_expr
    return self.filter(pk__in=pks).update(**update_kwargs)
async def abulk_update(self, objs, fields, batch_size=None): async def abulk_update(self, objs, fields, batch_size=None):
return await sync_to_async(self.bulk_update)( return await sync_to_async(self.bulk_update)(
objs=objs, objs=objs,

View File

@ -2090,6 +2090,62 @@ class SQLUpdateCompiler(SQLCompiler):
self.query.reset_refcounts(refcounts_before) self.query.reset_refcounts(refcounts_before)
class SQLBulkUpdateCompiler(SQLCompiler):
    """Compile a BulkUpdateQuery into UPDATE ... SET ... FROM (VALUES ...)."""

    def execute_sql(self, result_type):
        """Execute the update and return the number of affected rows."""
        cursor = super().execute_sql(result_type)
        try:
            return cursor.rowcount if cursor else 0
        finally:
            if cursor:
                cursor.close()

    def get_update_clause(self, values):
        # Standard backends update the base table directly; the subquery is
        # attached later via the FROM clause.
        qn = self.quote_name_unless_alias
        return "UPDATE %s " % qn(self.query.base_table), tuple()

    def get_subquery_from_clause(self, values):
        """Compile the VALUES expression into a FROM (...) AS subquery clause."""
        values_sql, value_params = self.compile(values)
        return "FROM (%s) AS subquery" % values_sql, tuple(value_params)

    def get_set_sql_clause(self, values):
        """Build the SET clause assigning subquery columns to table columns."""
        qn = self.quote_name_unless_alias
        assignments = []
        # column1 holds the pk, so the updated field columns start at 2.
        for idx, field in enumerate(values.field_list, 2):
            assignments.append(
                "%s = %s.%s" % (qn(field.column), qn("subquery"), qn(f"column{idx}"))
            )
        return "SET %s" % ", ".join(assignments), tuple()

    def get_where_clause(self, values):
        # Join rows to the target table on the pk, stored in column1.
        qn = self.quote_name_unless_alias
        return (
            f'WHERE {qn(values.pk_field.column)} = {qn("subquery")}.{qn("column1")}',
            tuple(),
        )

    def as_sql(self, with_limits=True, with_col_aliases=False):
        """Assemble the full UPDATE statement and its parameters."""
        self.pre_sql_setup()
        values = self.query.row_values.resolve_expression(
            self.query, allow_joins=True, reuse=None
        )
        update_clause, update_params = self.get_update_clause(values)
        set_clause, set_params = self.get_set_sql_clause(values)
        from_clause, from_params = self.get_subquery_from_clause(values)
        where_clause, where_params = self.get_where_clause(values)
        parts = [update_clause, set_clause]
        # Backends (e.g. MySQL) that embed the subquery in the UPDATE clause
        # return None here and contribute no FROM clause.
        if from_clause is not None:
            parts.append(from_clause)
        parts.append(where_clause)
        params = update_params + set_params + from_params + where_params
        return " ".join(parts), tuple(params)
class SQLAggregateCompiler(SQLCompiler): class SQLAggregateCompiler(SQLCompiler):
def as_sql(self): def as_sql(self):
""" """

View File

@ -6,7 +6,13 @@ from django.core.exceptions import FieldError
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"] __all__ = [
"DeleteQuery",
"UpdateQuery",
"BulkUpdateQuery",
"InsertQuery",
"AggregateQuery",
]
class DeleteQuery(Query): class DeleteQuery(Query):
@ -142,6 +148,23 @@ class UpdateQuery(Query):
return result return result
class BulkUpdateQuery(Query):
    """A query holding literal row tuples for a VALUES-based bulk update."""

    compiler = "SQLBulkUpdateCompiler"

    def __init__(self, *args, row_tuples, **kwargs):
        super().__init__(*args, **kwargs)
        # RowTupleValues expression compiled into the UPDATE's VALUES source.
        self.row_values = row_tuples

    def clone(self):
        obj = super().clone()
        if self.row_values is not None:
            # BUG FIX: this previously assigned the copy to
            # obj.related_updates (an UpdateQuery attribute), leaving the
            # clone sharing row_values with the original query.
            obj.row_values = self.row_values.copy()
        return obj

    def add_row_values(self, row_values):
        self.row_values = row_values
class InsertQuery(Query): class InsertQuery(Query):
compiler = "SQLInsertCompiler" compiler = "SQLInsertCompiler"