diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index ef874d74db..dba8342373 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -302,6 +302,9 @@ class BaseDatabaseFeatures: # Does this backend require casting the results of CASE expressions used # in UPDATE statements to ensure the expression has the correct type? requires_casted_case_in_updates = False + # Does this backend require casting the results of a VALUES expression used + # in UPDATE statements to ensure the expression has the correct type? + requires_casted_case_in_values_updates = False # Does the backend support partial indexes (CREATE INDEX ... WHERE ...)? supports_partial_indexes = True diff --git a/django/db/backends/mysql/compiler.py b/django/db/backends/mysql/compiler.py index 2ec6bea2f1..f7b5a91285 100644 --- a/django/db/backends/mysql/compiler.py +++ b/django/db/backends/mysql/compiler.py @@ -22,6 +22,17 @@ class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): pass +class SQLBulkUpdateCompiler(compiler.SQLBulkUpdateCompiler, SQLCompiler): + def get_update_clause(self, values): + qn = self.quote_name_unless_alias + values_sql, value_params = self.compile(values) + update_sql = f"UPDATE {qn(self.query.base_table)}, ({values_sql}) as subquery" + return update_sql, tuple(value_params) + + def get_subquery_from_clause(self, values): + return None, tuple() + + class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): def as_sql(self): # Prefer the non-standard DELETE FROM syntax over the SQL generated by diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 16653a0519..f7c19ebc63 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -60,6 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): ) """ requires_casted_case_in_updates = True + requires_casted_case_in_values_updates = True supports_over_clause = True supports_frame_exclusion = True only_supports_unbounded_with_preceding_and_following = True diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 667e9f93c6..94e8647d21 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -2140,3 +2140,72 @@ class ValueRange(WindowFrame): def window_frame_start_end(self, connection, start, end): return connection.ops.window_frame_range_start_end(start, end) + + +class RowTuple(ExpressionList): + template = "(%(expressions)s)" + + def as_mysql(self, compiler, connection, **extra_context): + # MySQL requires the ROW() function to be used for tuples in VALUES clauses. + if not connection.mysql_is_mariadb: + extra_context["template"] = "ROW(%(expressions)s)" + return self.as_sql(compiler, connection, **extra_context) + + def __str__(self): + values = self.arg_joiner.join(str(arg) for arg in self.source_expressions) + return self.template % {"expressions": values} + + def __getitem__(self, item): + return self.source_expressions[item] + + +class RowTupleValues(ExpressionList): + template = "VALUES %(expressions)s" + + def __init__(self, expression_list, pk_field, field_list, **extra): + expressions = ( + ( + RowTuple(*expressions) + if not isinstance(expressions, RowTuple) + else expressions + ) + for expressions in expression_list + ) + self.pk_field = pk_field + self.field_list = field_list + super().__init__(*expressions, **extra) + + def as_mysql(self, compiler, connection, **extra_context): + # MySQL doesn't support aliases in VALUES clauses. The workaround is to use a + # UNION as the first column + # SELECT 1 AS x ,2 AS y + # UNION VALUES (3,4),(5,6); + first_row = self.source_expressions.pop() + + class MySQLJoiner: + def join(self, expressions): + return ", ".join( + f"{col_sql} AS column{idx}" + for idx, col_sql in enumerate(expressions, 1) + ) + + sql, params = first_row.as_sql( + compiler, + connection, + template="%(expressions)s", + **extra_context, + arg_joiner=MySQLJoiner(), + ) + sql = f"SELECT {sql}" + if self.source_expressions: + rest_sql, rest_params = super().as_sql( + compiler, connection, **extra_context + ) + sql += f" UNION {rest_sql}" + params += rest_params + + return sql, params + + def __str__(self): + values = self.arg_joiner.join(str(arg) for arg in self.source_expressions) + return self.template % {"expressions": values} diff --git a/django/db/models/query.py b/django/db/models/query.py index 21d5534cc9..83557772c6 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -23,7 +23,7 @@ from django.db import ( from django.db.models import AutoField, DateField, DateTimeField, Field, sql from django.db.models.constants import LOOKUP_SEP, OnConflict from django.db.models.deletion import Collector -from django.db.models.expressions import Case, F, Value, When +from django.db.models.expressions import Case, F, RowTuple, RowTupleValues, Value, When from django.db.models.functions import Cast, Trunc from django.db.models.query_utils import FilteredRelation, Q from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE @@ -890,31 +890,73 @@ class QuerySet(AltersData): max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs) batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size requires_casting = connection.features.requires_casted_case_in_updates + requires_casting_in_values = ( + connection.features.requires_casted_case_in_values_updates + ) batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size)) updates = [] + + has_related_fields = any(f.model is not self.model for f in fields) + has_field_references = False for batch_objs in batches: - update_kwargs = {} - for field in fields: - when_statements = [] - for obj in batch_objs: + row_tuples = [] + for obj in batch_objs: + values = [Value(getattr(obj, "pk"))] + for field in fields: attr = getattr(obj, field.attname) - if not hasattr(attr, "resolve_expression"): + if hasattr(attr, "resolve_expression"): + has_field_references = True + else: attr = Value(attr, output_field=field) - when_statements.append(When(pk=obj.pk, then=attr)) - case_statement = Case(*when_statements, output_field=field) - if requires_casting: - case_statement = Cast(case_statement, output_field=field) - update_kwargs[field.attname] = case_statement - updates.append(([obj.pk for obj in batch_objs], update_kwargs)) + + if requires_casting_in_values: + attr = Cast(attr, output_field=field) + values.append(attr) + + row_tuples.append(RowTuple(*values)) + updates.append( + RowTupleValues( + row_tuples, pk_field=self.model._meta.pk, field_list=fields + ) + ) + rows_updated = 0 - queryset = self.using(self.db) with transaction.atomic(using=self.db, savepoint=False): - for pks, update_kwargs in updates: - rows_updated += queryset.filter(pk__in=pks).update(**update_kwargs) + for row_tuples in updates: + if has_field_references or has_related_fields: + rows_updated += self._bulk_update_slow( + row_tuples, requires_casting=requires_casting + ) + else: + query = sql.BulkUpdateQuery(self.model, row_tuples=row_tuples) + rows_updated += query.get_compiler(using=self.db).execute_sql( + result_type=CURSOR + ) return rows_updated bulk_update.alters_data = True + def _bulk_update_slow(self, row_tuples, requires_casting=False): + pks = [ + row_tuple.source_expressions[0].value + for row_tuple in row_tuples.source_expressions + ] + update_kwargs = {} + # Skip the ID column + for field_idx, field in enumerate(row_tuples.field_list, start=1): + when_statements = [] + for row_tuple in row_tuples.source_expressions: + row_tuple: RowTuple + id_expression = row_tuple.source_expressions[0] + attr_expression = row_tuple.source_expressions[field_idx] + when_statements.append(When(pk=id_expression, then=attr_expression)) + case_statement = Case(*when_statements, output_field=field) + if requires_casting: + case_statement = Cast(case_statement, output_field=field) + update_kwargs[field.attname] = case_statement + + return self.filter(pk__in=pks).update(**update_kwargs) + async def abulk_update(self, objs, fields, batch_size=None): return await sync_to_async(self.bulk_update)( objs=objs, diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 49263d5944..29a8a75557 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -2090,6 +2090,62 @@ class SQLUpdateCompiler(SQLCompiler): self.query.reset_refcounts(refcounts_before) +class SQLBulkUpdateCompiler(SQLCompiler): + def execute_sql(self, result_type): + cursor = super().execute_sql(result_type) + try: + rows = cursor.rowcount if cursor else 0 + finally: + if cursor: + cursor.close() + return rows + + def get_update_clause(self, values): + qn = self.quote_name_unless_alias + return "UPDATE %s " % qn(self.query.base_table), tuple() + + def get_subquery_from_clause(self, values): + values_sql, value_params = self.compile(values) + return "FROM (%s) AS subquery" % values_sql, tuple(value_params) + + def get_set_sql_clause(self, values): + qn = self.quote_name_unless_alias + set_sql = [] + for idx, field in enumerate(values.field_list, 2): + name = field.column + set_sql.append( + "%s = %s.%s" % (qn(name), qn("subquery"), qn(f"column{idx}")) + ) + + return "SET %s" % ", ".join(set_sql), tuple() + + def get_where_clause(self, values): + qn = self.quote_name_unless_alias + return ( + f'WHERE {qn(values.pk_field.column)} = {qn("subquery")}.{qn("column1")}', + tuple(), + ) + + def as_sql(self, with_limits=True, with_col_aliases=False): + self.pre_sql_setup() + + values = self.query.row_values.resolve_expression( + self.query, allow_joins=True, reuse=None + ) + + update_clause, update_params = self.get_update_clause(values) + set_clause, set_params = self.get_set_sql_clause(values) + from_clause, from_params = self.get_subquery_from_clause(values) + where_clause, where_params = self.get_where_clause(values) + + result = [update_clause, set_clause] + if from_clause is not None: + result.append(from_clause) + result.append(where_clause) + params = update_params + set_params + from_params + where_params + return " ".join(result), tuple(params) + + class SQLAggregateCompiler(SQLCompiler): def as_sql(self): """ diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index f639eb8b82..a1fe860687 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -6,7 +6,13 @@ from django.core.exceptions import FieldError from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS from django.db.models.sql.query import Query -__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"] +__all__ = [ + "DeleteQuery", + "UpdateQuery", + "BulkUpdateQuery", + "InsertQuery", + "AggregateQuery", +] class DeleteQuery(Query): @@ -142,6 +148,23 @@ class UpdateQuery(Query): return result +class BulkUpdateQuery(Query): + compiler = "SQLBulkUpdateCompiler" + + def __init__(self, *args, row_tuples, **kwargs): + super().__init__(*args, **kwargs) + self.row_values = row_tuples + + def clone(self): + obj = super().clone() + if self.row_values is not None: + obj.related_updates = self.row_values.copy() + return obj + + def add_row_values(self, row_values): + self.row_values = row_values + + class InsertQuery(Query): compiler = "SQLInsertCompiler"