mirror of
https://github.com/django/django.git
synced 2025-06-05 11:39:13 +00:00
Fixed #29771 -- Support database-specific syntax for bulk_update
This commit is contained in:
parent
857b1048d5
commit
29311aae15
@ -302,6 +302,9 @@ class BaseDatabaseFeatures:
|
|||||||
# Does this backend require casting the results of CASE expressions used
|
# Does this backend require casting the results of CASE expressions used
|
||||||
# in UPDATE statements to ensure the expression has the correct type?
|
# in UPDATE statements to ensure the expression has the correct type?
|
||||||
requires_casted_case_in_updates = False
|
requires_casted_case_in_updates = False
|
||||||
|
# Does this backend require casting the results of a VALUES expression used
|
||||||
|
# in UPDATE statements to ensure the expression has the correct type?
|
||||||
|
requires_casted_case_in_values_updates = False
|
||||||
|
|
||||||
# Does the backend support partial indexes (CREATE INDEX ... WHERE ...)?
|
# Does the backend support partial indexes (CREATE INDEX ... WHERE ...)?
|
||||||
supports_partial_indexes = True
|
supports_partial_indexes = True
|
||||||
|
@ -22,6 +22,17 @@ class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SQLBulkUpdateCompiler(compiler.SQLBulkUpdateCompiler, SQLCompiler):
    """
    Bulk-update compiler that lists the compiled values subquery next to the
    target table in the UPDATE clause instead of emitting a FROM clause.
    """

    def get_update_clause(self, values):
        """
        Return ``(sql, params)`` for the UPDATE clause.

        The SQL names the query's base table followed by the compiled
        ``values`` expression aliased as ``subquery``; the params are the
        ones produced by compiling ``values``.
        """
        qn = self.quote_name_unless_alias
        values_sql, subquery_params = self.compile(values)
        clause = f"UPDATE {qn(self.query.base_table)}, ({values_sql}) as subquery"
        return clause, tuple(subquery_params)

    def get_subquery_from_clause(self, values):
        """
        Return ``(None, ())``: the subquery is already part of the UPDATE
        clause, so there is no FROM clause to add.
        """
        return None, tuple()
|
||||||
|
|
||||||
|
|
||||||
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
|
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
|
||||||
def as_sql(self):
|
def as_sql(self):
|
||||||
# Prefer the non-standard DELETE FROM syntax over the SQL generated by
|
# Prefer the non-standard DELETE FROM syntax over the SQL generated by
|
||||||
|
@ -60,6 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
requires_casted_case_in_updates = True
|
requires_casted_case_in_updates = True
|
||||||
|
requires_casted_case_in_values_updates = True
|
||||||
supports_over_clause = True
|
supports_over_clause = True
|
||||||
supports_frame_exclusion = True
|
supports_frame_exclusion = True
|
||||||
only_supports_unbounded_with_preceding_and_following = True
|
only_supports_unbounded_with_preceding_and_following = True
|
||||||
|
@ -2140,3 +2140,72 @@ class ValueRange(WindowFrame):
|
|||||||
|
|
||||||
def window_frame_start_end(self, connection, start, end):
|
def window_frame_start_end(self, connection, start, end):
|
||||||
return connection.ops.window_frame_range_start_end(start, end)
|
return connection.ops.window_frame_range_start_end(start, end)
|
||||||
|
|
||||||
|
|
||||||
|
class RowTuple(ExpressionList):
    """An expression list rendered as a single parenthesised row."""

    template = "(%(expressions)s)"

    def as_mysql(self, compiler, connection, **extra_context):
        """
        Compile the row, switching to the ``ROW(...)`` template unless the
        connection reports itself as MariaDB.
        """
        # MySQL requires the ROW() function to be used for tuples in VALUES
        # clauses.
        overrides = (
            {} if connection.mysql_is_mariadb else {"template": "ROW(%(expressions)s)"}
        )
        return self.as_sql(compiler, connection, **{**extra_context, **overrides})

    def __str__(self):
        rendered = self.arg_joiner.join(map(str, self.source_expressions))
        return self.template % {"expressions": rendered}

    def __getitem__(self, item):
        """Index directly into the row's source expressions."""
        return self.source_expressions[item]
|
||||||
|
|
||||||
|
|
||||||
|
class RowTupleValues(ExpressionList):
    """
    A ``VALUES`` list whose source expressions are RowTuple instances.

    ``pk_field`` and ``field_list`` are stored unchanged on the instance so
    that code compiling this expression can relate row positions to fields.
    """

    template = "VALUES %(expressions)s"

    class _AliasedColumnJoiner:
        """Join column SQL fragments, aliasing them column1, column2, ..."""

        def join(self, expressions):
            return ", ".join(
                f"{col_sql} AS column{idx}"
                for idx, col_sql in enumerate(expressions, 1)
            )

    def __init__(self, expression_list, pk_field, field_list, **extra):
        """
        Build the VALUES list.

        Each item of ``expression_list`` is used as-is when it already is a
        RowTuple; otherwise it is unpacked into a new RowTuple.
        """
        rows = [
            row if isinstance(row, RowTuple) else RowTuple(*row)
            for row in expression_list
        ]
        self.pk_field = pk_field
        self.field_list = field_list
        super().__init__(*rows, **extra)

    def as_mysql(self, compiler, connection, **extra_context):
        """
        Return ``(sql, params)`` with the first row emitted as an aliased
        SELECT and any remaining rows appended through UNION.

        Raises IndexError when there are no rows.
        """
        # MySQL doesn't support aliases in VALUES clauses. The workaround is
        # to use a UNION with the first row selected under explicit aliases:
        #   SELECT 1 AS column1, 2 AS column2
        #   UNION VALUES (3, 4), (5, 6);
        #
        # The rows are read without mutating self.source_expressions, so the
        # expression can be compiled more than once with the same result.
        first_row = self.source_expressions[0]
        remaining_rows = list(self.source_expressions[1:])

        sql, params = first_row.as_sql(
            compiler,
            connection,
            template="%(expressions)s",
            **extra_context,
            arg_joiner=self._AliasedColumnJoiner(),
        )
        sql = f"SELECT {sql}"
        params = tuple(params)

        if remaining_rows:
            # Compile the remaining rows through a copy that holds only
            # those rows, using the parent class's as_sql() as before.
            rest = self.copy()
            rest.source_expressions = remaining_rows
            rest_sql, rest_params = super(RowTupleValues, rest).as_sql(
                compiler, connection, **extra_context
            )
            sql += f" UNION {rest_sql}"
            params += tuple(rest_params)

        return sql, params

    def __str__(self):
        values = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        return self.template % {"expressions": values}
|
||||||
|
@ -23,7 +23,7 @@ from django.db import (
|
|||||||
from django.db.models import AutoField, DateField, DateTimeField, Field, sql
|
from django.db.models import AutoField, DateField, DateTimeField, Field, sql
|
||||||
from django.db.models.constants import LOOKUP_SEP, OnConflict
|
from django.db.models.constants import LOOKUP_SEP, OnConflict
|
||||||
from django.db.models.deletion import Collector
|
from django.db.models.deletion import Collector
|
||||||
from django.db.models.expressions import Case, F, Value, When
|
from django.db.models.expressions import Case, F, RowTuple, RowTupleValues, Value, When
|
||||||
from django.db.models.functions import Cast, Trunc
|
from django.db.models.functions import Cast, Trunc
|
||||||
from django.db.models.query_utils import FilteredRelation, Q
|
from django.db.models.query_utils import FilteredRelation, Q
|
||||||
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
|
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
|
||||||
@ -890,31 +890,73 @@ class QuerySet(AltersData):
|
|||||||
max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
|
max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
|
||||||
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
|
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
|
||||||
requires_casting = connection.features.requires_casted_case_in_updates
|
requires_casting = connection.features.requires_casted_case_in_updates
|
||||||
|
requires_casting_in_values = (
|
||||||
|
connection.features.requires_casted_case_in_values_updates
|
||||||
|
)
|
||||||
batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size))
|
batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size))
|
||||||
updates = []
|
updates = []
|
||||||
|
|
||||||
|
has_related_fields = any(f.model is not self.model for f in fields)
|
||||||
|
has_field_references = False
|
||||||
for batch_objs in batches:
|
for batch_objs in batches:
|
||||||
update_kwargs = {}
|
row_tuples = []
|
||||||
for field in fields:
|
for obj in batch_objs:
|
||||||
when_statements = []
|
values = [Value(getattr(obj, "pk"))]
|
||||||
for obj in batch_objs:
|
for field in fields:
|
||||||
attr = getattr(obj, field.attname)
|
attr = getattr(obj, field.attname)
|
||||||
if not hasattr(attr, "resolve_expression"):
|
if hasattr(attr, "resolve_expression"):
|
||||||
|
has_field_references = True
|
||||||
|
else:
|
||||||
attr = Value(attr, output_field=field)
|
attr = Value(attr, output_field=field)
|
||||||
when_statements.append(When(pk=obj.pk, then=attr))
|
|
||||||
case_statement = Case(*when_statements, output_field=field)
|
if requires_casting_in_values:
|
||||||
if requires_casting:
|
attr = Cast(attr, output_field=field)
|
||||||
case_statement = Cast(case_statement, output_field=field)
|
values.append(attr)
|
||||||
update_kwargs[field.attname] = case_statement
|
|
||||||
updates.append(([obj.pk for obj in batch_objs], update_kwargs))
|
row_tuples.append(RowTuple(*values))
|
||||||
|
updates.append(
|
||||||
|
RowTupleValues(
|
||||||
|
row_tuples, pk_field=self.model._meta.pk, field_list=fields
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
rows_updated = 0
|
rows_updated = 0
|
||||||
queryset = self.using(self.db)
|
|
||||||
with transaction.atomic(using=self.db, savepoint=False):
|
with transaction.atomic(using=self.db, savepoint=False):
|
||||||
for pks, update_kwargs in updates:
|
for row_tuples in updates:
|
||||||
rows_updated += queryset.filter(pk__in=pks).update(**update_kwargs)
|
if has_field_references or has_related_fields:
|
||||||
|
rows_updated += self._bulk_update_slow(
|
||||||
|
row_tuples, requires_casting=requires_casting
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query = sql.BulkUpdateQuery(self.model, row_tuples=row_tuples)
|
||||||
|
rows_updated += query.get_compiler(using=self.db).execute_sql(
|
||||||
|
result_type=CURSOR
|
||||||
|
)
|
||||||
return rows_updated
|
return rows_updated
|
||||||
|
|
||||||
bulk_update.alters_data = True
|
bulk_update.alters_data = True
|
||||||
|
|
||||||
|
def _bulk_update_slow(self, row_tuples, requires_casting=False):
    """
    Update one batch of rows with a single CASE/WHEN based ``update()``.

    ``row_tuples`` exposes ``source_expressions`` (one row per object) and
    ``field_list``. Element 0 of every row is read through ``.value`` as
    the primary key; elements 1..n line up with ``field_list``.

    When ``requires_casting`` is true, each CASE expression is wrapped in a
    Cast to the field being updated.

    Return whatever ``update()`` returns for the filtered queryset.
    """
    rows = row_tuples.source_expressions
    pks = [row.source_expressions[0].value for row in rows]

    update_kwargs = {}
    # Skip the ID column: field columns start at index 1.
    for field_idx, field in enumerate(row_tuples.field_list, start=1):
        # Match on the raw primary key value (the same values given to
        # pk__in below) rather than on the expression wrapping it.
        when_statements = [
            When(
                pk=row.source_expressions[0].value,
                then=row.source_expressions[field_idx],
            )
            for row in rows
        ]
        case_statement = Case(*when_statements, output_field=field)
        if requires_casting:
            case_statement = Cast(case_statement, output_field=field)
        update_kwargs[field.attname] = case_statement

    # Pin the database alias so the UPDATE runs on self.db.
    return self.using(self.db).filter(pk__in=pks).update(**update_kwargs)
|
||||||
|
|
||||||
async def abulk_update(self, objs, fields, batch_size=None):
|
async def abulk_update(self, objs, fields, batch_size=None):
|
||||||
return await sync_to_async(self.bulk_update)(
|
return await sync_to_async(self.bulk_update)(
|
||||||
objs=objs,
|
objs=objs,
|
||||||
|
@ -2090,6 +2090,62 @@ class SQLUpdateCompiler(SQLCompiler):
|
|||||||
self.query.reset_refcounts(refcounts_before)
|
self.query.reset_refcounts(refcounts_before)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLBulkUpdateCompiler(SQLCompiler):
    """
    Compile a query carrying ``row_values`` into one statement assembled
    from UPDATE, SET, optional FROM and WHERE clauses.

    Every ``get_*`` hook returns a ``(sql, params)`` pair.
    """

    def execute_sql(self, result_type):
        """Run the statement and return the cursor's rowcount (0 if no cursor)."""
        cursor = super().execute_sql(result_type)
        if not cursor:
            return 0
        try:
            return cursor.rowcount
        finally:
            cursor.close()

    def get_update_clause(self, values):
        """Return the UPDATE clause naming the query's base table."""
        quote = self.quote_name_unless_alias
        return "UPDATE %s " % quote(self.query.base_table), tuple()

    def get_subquery_from_clause(self, values):
        """Return the FROM clause wrapping the compiled values as ``subquery``."""
        compiled_sql, compiled_params = self.compile(values)
        return "FROM (%s) AS subquery" % compiled_sql, tuple(compiled_params)

    def get_set_sql_clause(self, values):
        """
        Return the SET clause assigning each field's column from the
        subquery column at the matching position.
        """
        quote = self.quote_name_unless_alias
        # Numbering starts at 2: column1 is compared in the WHERE clause.
        assignments = [
            "%s = %s.%s" % (quote(field.column), quote("subquery"), quote(f"column{idx}"))
            for idx, field in enumerate(values.field_list, 2)
        ]
        return "SET %s" % ", ".join(assignments), tuple()

    def get_where_clause(self, values):
        """Return the WHERE clause matching the pk column to ``subquery.column1``."""
        qn = self.quote_name_unless_alias
        condition = (
            f'WHERE {qn(values.pk_field.column)} = {qn("subquery")}.{qn("column1")}'
        )
        return condition, tuple()

    def as_sql(self, with_limits=True, with_col_aliases=False):
        """
        Return ``(sql, params)`` for the complete statement.

        The FROM clause is left out when ``get_subquery_from_clause()``
        returns ``None`` for it.
        """
        self.pre_sql_setup()

        values = self.query.row_values.resolve_expression(
            self.query, allow_joins=True, reuse=None
        )

        update_clause, update_params = self.get_update_clause(values)
        set_clause, set_params = self.get_set_sql_clause(values)
        from_clause, from_params = self.get_subquery_from_clause(values)
        where_clause, where_params = self.get_where_clause(values)

        clauses = [update_clause, set_clause, where_clause]
        if from_clause is not None:
            # The FROM clause sits between SET and WHERE.
            clauses.insert(2, from_clause)

        params = update_params + set_params + from_params + where_params
        return " ".join(clauses), tuple(params)
|
||||||
|
|
||||||
|
|
||||||
class SQLAggregateCompiler(SQLCompiler):
|
class SQLAggregateCompiler(SQLCompiler):
|
||||||
def as_sql(self):
|
def as_sql(self):
|
||||||
"""
|
"""
|
||||||
|
@ -6,7 +6,13 @@ from django.core.exceptions import FieldError
|
|||||||
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
|
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
|
||||||
from django.db.models.sql.query import Query
|
from django.db.models.sql.query import Query
|
||||||
|
|
||||||
__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
|
__all__ = [
|
||||||
|
"DeleteQuery",
|
||||||
|
"UpdateQuery",
|
||||||
|
"BulkUpdateQuery",
|
||||||
|
"InsertQuery",
|
||||||
|
"AggregateQuery",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class DeleteQuery(Query):
|
class DeleteQuery(Query):
|
||||||
@ -142,6 +148,23 @@ class UpdateQuery(Query):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class BulkUpdateQuery(Query):
    """A query whose data is a set of row values, compiled by SQLBulkUpdateCompiler."""

    compiler = "SQLBulkUpdateCompiler"

    def __init__(self, *args, row_tuples, **kwargs):
        """
        Initialise the query and store ``row_tuples`` as ``row_values``.

        ``row_tuples`` is keyword-only and required; all other arguments
        are passed through to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.row_values = row_tuples

    def clone(self):
        """Return a clone that carries its own ``.copy()`` of ``row_values``."""
        obj = super().clone()
        if self.row_values is not None:
            # Store the copy under row_values, the attribute this class
            # actually reads and writes.
            obj.row_values = self.row_values.copy()
        return obj

    def add_row_values(self, row_values):
        """Replace the row values carried by this query."""
        self.row_values = row_values
|
||||||
|
|
||||||
|
|
||||||
class InsertQuery(Query):
|
class InsertQuery(Query):
|
||||||
compiler = "SQLInsertCompiler"
|
compiler = "SQLInsertCompiler"
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user