From 7414704e88d73dafbcfbb85f9bc54cb6111439d3 Mon Sep 17 00:00:00 2001 From: Ian Foote Date: Sun, 22 Nov 2020 22:27:57 +0000 Subject: [PATCH] Fixed #470 -- Added support for database defaults on fields. Special thanks to Hannes Ljungberg for finding multiple implementation gaps. Thanks also to Simon Charette, Adam Johnson, and Mariusz Felisiak for reviews. --- AUTHORS | 1 + django/db/backends/base/features.py | 12 + django/db/backends/base/schema.py | 88 ++++- django/db/backends/mysql/features.py | 7 + django/db/backends/mysql/schema.py | 29 +- django/db/backends/oracle/features.py | 4 + django/db/backends/oracle/introspection.py | 2 +- django/db/backends/oracle/schema.py | 4 +- django/db/backends/postgresql/features.py | 1 + django/db/backends/sqlite3/features.py | 2 + django/db/backends/sqlite3/schema.py | 24 +- django/db/migrations/autodetector.py | 2 + django/db/models/base.py | 6 +- django/db/models/expressions.py | 37 ++ django/db/models/fields/__init__.py | 61 +++- django/db/models/functions/comparison.py | 1 + django/db/models/lookups.py | 4 + django/db/models/query.py | 9 + docs/ref/checks.txt | 3 + docs/ref/models/expressions.txt | 7 + docs/ref/models/fields.txt | 35 ++ docs/ref/models/instances.txt | 14 +- docs/releases/5.0.txt | 26 +- tests/basic/models.py | 4 + tests/basic/tests.py | 6 + tests/field_defaults/models.py | 44 +++ tests/field_defaults/tests.py | 192 ++++++++++- .../test_ordinary_fields.py | 107 ++++++ tests/migrations/test_autodetector.py | 40 +++ tests/migrations/test_base.py | 7 + tests/migrations/test_operations.py | 317 +++++++++++++++++- tests/schema/tests.py | 27 ++ 32 files changed, 1089 insertions(+), 34 deletions(-) diff --git a/AUTHORS b/AUTHORS index ea321038ac..291b5da657 100644 --- a/AUTHORS +++ b/AUTHORS @@ -587,6 +587,7 @@ answer newbie questions, and generally made Django that much better: lerouxb@gmail.com Lex Berezhny Liang Feng + Lily Foote limodou Lincoln Smith Liu Yijie <007gzs@gmail.com> diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 11fa807c1b..11dd079110 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -201,6 +201,15 @@ class BaseDatabaseFeatures: # Does the backend require literal defaults, rather than parameterized ones? requires_literal_defaults = False + # Does the backend support functions in defaults? + supports_expression_defaults = True + + # Does the backend support the DEFAULT keyword in insert queries? + supports_default_keyword_in_insert = True + + # Does the backend support the DEFAULT keyword in bulk insert queries? + supports_default_keyword_in_bulk_insert = True + # Does the backend require a connection reset after each material schema change? connection_persists_old_columns = False @@ -361,6 +370,9 @@ class BaseDatabaseFeatures: # SQL template override for tests.aggregation.tests.NowUTC test_now_utc_template = None + # SQL to create a model instance using the database defaults. + insert_test_table_with_defaults = None + # A set of dotted paths to tests in Django's test suite that are expected # to fail on this database. django_test_expected_failures = set() diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index 6b03450e2f..01b56151be 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -12,7 +12,7 @@ from django.db.backends.ddl_references import ( Table, ) from django.db.backends.utils import names_digest, split_identifier, truncate_name -from django.db.models import Deferrable, Index +from django.db.models import NOT_PROVIDED, Deferrable, Index from django.db.models.sql import Query from django.db.transaction import TransactionManagementError, atomic from django.utils import timezone @@ -296,6 +296,12 @@ class BaseDatabaseSchemaEditor: yield self._comment_sql(field.db_comment) # Work out nullability. null = field.null + # Add database default. + if field.db_default is not NOT_PROVIDED: + default_sql, default_params = self.db_default_sql(field) + yield f"DEFAULT {default_sql}" + params.extend(default_params) + include_default = False # Include a default value, if requested. include_default = ( include_default @@ -400,6 +406,22 @@ class BaseDatabaseSchemaEditor: """ return "%s" + def db_default_sql(self, field): + """Return the sql and params for the field's database default.""" + from django.db.models.expressions import Value + + sql = "%s" if isinstance(field.db_default, Value) else "(%s)" + query = Query(model=field.model) + compiler = query.get_compiler(connection=self.connection) + default_sql, params = compiler.compile(field.db_default) + if self.connection.features.requires_literal_defaults: + # Some databases doesn't support parameterized defaults (Oracle, + # SQLite). If this is the case, the individual schema backend + # should implement prepare_default(). + default_sql %= tuple(self.prepare_default(p) for p in params) + params = [] + return sql % default_sql, params + @staticmethod def _effective_default(field): # This method allows testing its logic without a connection. @@ -1025,6 +1047,21 @@ class BaseDatabaseSchemaEditor: ) actions.append(fragment) post_actions.extend(other_actions) + + if new_field.db_default is not NOT_PROVIDED: + if ( + old_field.db_default is NOT_PROVIDED + or new_field.db_default != old_field.db_default + ): + actions.append( + self._alter_column_database_default_sql(model, old_field, new_field) + ) + elif old_field.db_default is not NOT_PROVIDED: + actions.append( + self._alter_column_database_default_sql( + model, old_field, new_field, drop=True + ) + ) # When changing a column NULL constraint to NOT NULL with a given # default value, we need to perform 4 steps: # 1. Add a default for new incoming writes @@ -1033,7 +1070,11 @@ class BaseDatabaseSchemaEditor: # 4. Drop the default again. # Default change? needs_database_default = False - if old_field.null and not new_field.null: + if ( + old_field.null + and not new_field.null + and new_field.db_default is NOT_PROVIDED + ): old_default = self.effective_default(old_field) new_default = self.effective_default(new_field) if ( @@ -1051,9 +1092,9 @@ class BaseDatabaseSchemaEditor: if fragment: null_actions.append(fragment) # Only if we have a default and there is a change from NULL to NOT NULL - four_way_default_alteration = new_field.has_default() and ( - old_field.null and not new_field.null - ) + four_way_default_alteration = ( + new_field.has_default() or new_field.db_default is not NOT_PROVIDED + ) and (old_field.null and not new_field.null) if actions or null_actions: if not four_way_default_alteration: # If we don't have to do a 4-way default alteration we can @@ -1074,15 +1115,20 @@ class BaseDatabaseSchemaEditor: params, ) if four_way_default_alteration: + if new_field.db_default is NOT_PROVIDED: + default_sql = "%s" + params = [new_default] + else: + default_sql, params = self.db_default_sql(new_field) # Update existing rows with default value self.execute( self.sql_update_with_default % { "table": self.quote_name(model._meta.db_table), "column": self.quote_name(new_field.column), - "default": "%s", + "default": default_sql, }, - [new_default], + params, ) # Since we didn't run a NOT NULL change before we need to do it # now @@ -1264,6 +1310,34 @@ class BaseDatabaseSchemaEditor: params, ) + def _alter_column_database_default_sql( + self, model, old_field, new_field, drop=False + ): + """ + Hook to specialize column database default alteration. + + Return a (sql, params) fragment to add or drop (depending on the drop + argument) a default to new_field's column. + """ + if drop: + sql = self.sql_alter_column_no_default + default_sql = "" + params = [] + else: + sql = self.sql_alter_column_default + default_sql, params = self.db_default_sql(new_field) + + new_db_params = new_field.db_parameters(connection=self.connection) + return ( + sql + % { + "column": self.quote_name(new_field.column), + "type": new_db_params["type"], + "default": default_sql, + }, + params, + ) + def _alter_column_type_sql( self, model, old_field, new_field, new_type, old_collation, new_collation ): diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index 9e17d33e93..0bb0f91f55 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -51,6 +51,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): # COLLATE must be wrapped in parentheses because MySQL treats COLLATE as an # indexed expression. collate_as_index_expression = True + insert_test_table_with_defaults = "INSERT INTO {} () VALUES ()" supports_order_by_nulls_modifier = False order_by_nulls_first = True @@ -342,3 +343,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): if self.connection.mysql_is_mariadb: return self.connection.mysql_version >= (10, 5, 2) return True + + @cached_property + def supports_expression_defaults(self): + if self.connection.mysql_is_mariadb: + return True + return self.connection.mysql_version >= (8, 0, 13) diff --git a/django/db/backends/mysql/schema.py b/django/db/backends/mysql/schema.py index 31829506c1..bfe5a2e805 100644 --- a/django/db/backends/mysql/schema.py +++ b/django/db/backends/mysql/schema.py @@ -209,11 +209,15 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): self._create_missing_fk_index(model, fields=fields) return super()._delete_composed_index(model, fields, *args) - def _set_field_new_type_null_status(self, field, new_type): + def _set_field_new_type(self, field, new_type): """ - Keep the null property of the old field. If it has changed, it will be - handled separately. + Keep the NULL and DEFAULT properties of the old field. If it has + changed, it will be handled separately. """ + if field.db_default is not NOT_PROVIDED: + default_sql, params = self.db_default_sql(field) + default_sql %= tuple(self.quote_value(p) for p in params) + new_type += f" DEFAULT {default_sql}" if field.null: new_type += " NULL" else: @@ -223,7 +227,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def _alter_column_type_sql( self, model, old_field, new_field, new_type, old_collation, new_collation ): - new_type = self._set_field_new_type_null_status(old_field, new_type) + new_type = self._set_field_new_type(old_field, new_type) return super()._alter_column_type_sql( model, old_field, new_field, new_type, old_collation, new_collation ) @@ -242,7 +246,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): return field_db_params["check"] def _rename_field_sql(self, table, old_field, new_field, new_type): - new_type = self._set_field_new_type_null_status(old_field, new_type) + new_type = self._set_field_new_type(old_field, new_type) return super()._rename_field_sql(table, old_field, new_field, new_type) def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment): @@ -252,3 +256,18 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def _comment_sql(self, comment): comment_sql = super()._comment_sql(comment) return f" COMMENT {comment_sql}" + + def _alter_column_null_sql(self, model, old_field, new_field): + if new_field.db_default is NOT_PROVIDED: + return super()._alter_column_null_sql(model, old_field, new_field) + + new_db_params = new_field.db_parameters(connection=self.connection) + type_sql = self._set_field_new_type(new_field, new_db_params["type"]) + return ( + "MODIFY %(column)s %(type)s" + % { + "column": self.quote_name(new_field.column), + "type": type_sql, + }, + [], + ) diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 05dc552a98..2ef9e4300c 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -32,6 +32,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): atomic_transactions = False nulls_order_largest = True requires_literal_defaults = True + supports_default_keyword_in_bulk_insert = False closed_cursor_error_class = InterfaceError bare_select_suffix = " FROM DUAL" # Select for update with limit can be achieved on Oracle, but not with the @@ -130,6 +131,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): "annotations.tests.NonAggregateAnnotationTestCase." "test_custom_functions_can_ref_other_functions", } + insert_test_table_with_defaults = ( + "INSERT INTO {} VALUES (DEFAULT, DEFAULT, DEFAULT)" + ) @cached_property def introspected_field_types(self): diff --git a/django/db/backends/oracle/introspection.py b/django/db/backends/oracle/introspection.py index 5d1e3e6761..c4a734f7ec 100644 --- a/django/db/backends/oracle/introspection.py +++ b/django/db/backends/oracle/introspection.py @@ -156,7 +156,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): field_map = { column: ( display_size, - default if default != "NULL" else None, + default.rstrip() if default and default != "NULL" else None, collation, is_autofield, is_json, diff --git a/django/db/backends/oracle/schema.py b/django/db/backends/oracle/schema.py index 0d70522a2a..c8dd64650f 100644 --- a/django/db/backends/oracle/schema.py +++ b/django/db/backends/oracle/schema.py @@ -198,7 +198,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): return self.normalize_name(for_name + "_" + suffix) def prepare_default(self, value): - return self.quote_value(value) + # Replace % with %% as %-formatting is applied in + # FormatStylePlaceholderCursor._fix_for_params(). + return self.quote_value(value).replace("%", "%%") def _field_should_be_indexed(self, model, field): create_index = super()._field_should_be_indexed(model, field) diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 732b30b0a4..29b6a4f6c5 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -76,6 +76,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): "swedish_ci": "sv-x-icu", } test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'" + insert_test_table_with_defaults = "INSERT INTO {} DEFAULT VALUES" django_test_skips = { "opclasses are PostgreSQL only.": { diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index 7dd1c39702..f471b72cb2 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -59,6 +59,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): PRIMARY KEY(column_1, column_2) ) """ + insert_test_table_with_defaults = 'INSERT INTO {} ("null") VALUES (1)' + supports_default_keyword_in_insert = False @cached_property def django_test_skips(self): diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index 2ca9a01855..46ba07092d 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -6,7 +6,7 @@ from django.db import NotSupportedError from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.backends.ddl_references import Statement from django.db.backends.utils import strip_quotes -from django.db.models import UniqueConstraint +from django.db.models import NOT_PROVIDED, UniqueConstraint from django.db.transaction import atomic @@ -233,9 +233,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): if create_field: body[create_field.name] = create_field # Choose a default and insert it into the copy map - if not create_field.many_to_many and create_field.concrete: + if ( + create_field.db_default is NOT_PROVIDED + and not create_field.many_to_many + and create_field.concrete + ): mapping[create_field.column] = self.prepare_default( - self.effective_default(create_field), + self.effective_default(create_field) ) # Add in any altered fields for alter_field in alter_fields: @@ -244,9 +248,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): mapping.pop(old_field.column, None) body[new_field.name] = new_field if old_field.null and not new_field.null: + if new_field.db_default is NOT_PROVIDED: + default = self.prepare_default(self.effective_default(new_field)) + else: + default, _ = self.db_default_sql(new_field) case_sql = "coalesce(%(col)s, %(default)s)" % { "col": self.quote_name(old_field.column), - "default": self.prepare_default(self.effective_default(new_field)), + "default": default, } mapping[new_field.column] = case_sql else: @@ -381,6 +389,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def add_field(self, model, field): """Create a field on a model.""" + from django.db.models.expressions import Value + # Special-case implicit M2M tables. if field.many_to_many and field.remote_field.through._meta.auto_created: self.create_model(field.remote_field.through) @@ -394,6 +404,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # COLUMN statement because DROP DEFAULT is not supported in # ALTER TABLE. or self.effective_default(field) is not None + # Fields with non-constant defaults cannot by handled by ALTER + # TABLE ADD COLUMN statement. + or ( + field.db_default is not NOT_PROVIDED + and not isinstance(field.db_default, Value) + ) ): self._remake_table(model, create_field=field) else: diff --git a/django/db/migrations/autodetector.py b/django/db/migrations/autodetector.py index 23c97e5474..154ac44419 100644 --- a/django/db/migrations/autodetector.py +++ b/django/db/migrations/autodetector.py @@ -1040,6 +1040,7 @@ class MigrationAutodetector: preserve_default = ( field.null or field.has_default() + or field.db_default is not models.NOT_PROVIDED or field.many_to_many or (field.blank and field.empty_strings_allowed) or (isinstance(field, time_fields) and field.auto_now) @@ -1187,6 +1188,7 @@ class MigrationAutodetector: old_field.null and not new_field.null and not new_field.has_default() + and new_field.db_default is models.NOT_PROVIDED and not new_field.many_to_many ): field = new_field.clone() diff --git a/django/db/models/base.py b/django/db/models/base.py index 344508e0e2..7aabe0b667 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -971,8 +971,10 @@ class Model(AltersData, metaclass=ModelBase): not raw and not force_insert and self._state.adding - and meta.pk.default - and meta.pk.default is not NOT_PROVIDED + and ( + (meta.pk.default and meta.pk.default is not NOT_PROVIDED) + or (meta.pk.db_default and meta.pk.db_default is not NOT_PROVIDED) + ) ): force_insert = True # If possible, try an UPDATE. If that doesn't update anything, do an INSERT. diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index d412e7657e..e1861759c4 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -176,6 +176,8 @@ class BaseExpression: filterable = True # Can the expression can be used as a source expression in Window? window_compatible = False + # Can the expression be used as a database default value? + allowed_default = False def __init__(self, output_field=None): if output_field is not None: @@ -733,6 +735,10 @@ class CombinedExpression(SQLiteNumericMixin, Expression): c.rhs = rhs return c + @cached_property + def allowed_default(self): + return self.lhs.allowed_default and self.rhs.allowed_default + class DurationExpression(CombinedExpression): def compile(self, side, compiler, connection): @@ -804,6 +810,8 @@ class TemporalSubtraction(CombinedExpression): class F(Combinable): """An object capable of resolving references to existing query objects.""" + allowed_default = False + def __init__(self, name): """ Arguments: @@ -987,6 +995,10 @@ class Func(SQLiteNumericMixin, Expression): copy.extra = self.extra.copy() return copy + @cached_property + def allowed_default(self): + return all(expression.allowed_default for expression in self.source_expressions) + @deconstructible(path="django.db.models.Value") class Value(SQLiteNumericMixin, Expression): @@ -995,6 +1007,7 @@ class Value(SQLiteNumericMixin, Expression): # Provide a default value for `for_save` in order to allow unresolved # instances to be compiled until a decision is taken in #25425. for_save = False + allowed_default = True def __init__(self, value, output_field=None): """ @@ -1069,6 +1082,8 @@ class Value(SQLiteNumericMixin, Expression): class RawSQL(Expression): + allowed_default = True + def __init__(self, sql, params, output_field=None): if output_field is None: output_field = fields.Field() @@ -1110,6 +1125,13 @@ class Star(Expression): return "*", [] +class DatabaseDefault(Expression): + """Placeholder expression for the database default in an insert query.""" + + def as_sql(self, compiler, connection): + return "DEFAULT", [] + + class Col(Expression): contains_column_references = True possibly_multivalued = False @@ -1213,6 +1235,7 @@ class ExpressionList(Func): class OrderByList(Func): + allowed_default = False template = "ORDER BY %(expressions)s" def __init__(self, *expressions, **extra): @@ -1270,6 +1293,10 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression): def __repr__(self): return "{}({})".format(self.__class__.__name__, self.expression) + @property + def allowed_default(self): + return self.expression.allowed_default + class NegatedExpression(ExpressionWrapper): """The logical negation of a conditional expression.""" @@ -1397,6 +1424,10 @@ class When(Expression): cols.extend(source.get_group_by_cols()) return cols + @cached_property + def allowed_default(self): + return self.condition.allowed_default and self.result.allowed_default + @deconstructible(path="django.db.models.Case") class Case(SQLiteNumericMixin, Expression): @@ -1494,6 +1525,12 @@ class Case(SQLiteNumericMixin, Expression): return self.default.get_group_by_cols() return super().get_group_by_cols() + @cached_property + def allowed_default(self): + return self.default.allowed_default and all( + case_.allowed_default for case_ in self.cases + ) + class Subquery(BaseExpression, Combinable): """ diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 4416898d80..18b48c0e72 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -202,6 +202,7 @@ class Field(RegisterLookupMixin): validators=(), error_messages=None, db_comment=None, + db_default=NOT_PROVIDED, ): self.name = name self.verbose_name = verbose_name # May be set by set_attributes_from_name @@ -212,6 +213,13 @@ class Field(RegisterLookupMixin): self.remote_field = rel self.is_relation = self.remote_field is not None self.default = default + if db_default is not NOT_PROVIDED and not hasattr( + db_default, "resolve_expression" + ): + from django.db.models.expressions import Value + + db_default = Value(db_default) + self.db_default = db_default self.editable = editable self.serialize = serialize self.unique_for_date = unique_for_date @@ -263,6 +271,7 @@ class Field(RegisterLookupMixin): return [ *self._check_field_name(), *self._check_choices(), + *self._check_db_default(**kwargs), *self._check_db_index(), *self._check_db_comment(**kwargs), *self._check_null_allowed_for_primary_keys(), @@ -379,6 +388,39 @@ class Field(RegisterLookupMixin): ) ] + def _check_db_default(self, databases=None, **kwargs): + from django.db.models.expressions import Value + + if ( + self.db_default is NOT_PROVIDED + or isinstance(self.db_default, Value) + or databases is None + ): + return [] + errors = [] + for db in databases: + if not router.allow_migrate_model(db, self.model): + continue + connection = connections[db] + + if not getattr(self.db_default, "allowed_default", False) and ( + connection.features.supports_expression_defaults + ): + msg = f"{self.db_default} cannot be used in db_default." + errors.append(checks.Error(msg, obj=self, id="fields.E012")) + + if not ( + connection.features.supports_expression_defaults + or "supports_expression_defaults" + in self.model._meta.required_db_features + ): + msg = ( + f"{connection.display_name} does not support default database " + "values with expressions (db_default)." + ) + errors.append(checks.Error(msg, obj=self, id="fields.E011")) + return errors + def _check_db_index(self): if self.db_index not in (None, True, False): return [ @@ -558,6 +600,7 @@ class Field(RegisterLookupMixin): "null": False, "db_index": False, "default": NOT_PROVIDED, + "db_default": NOT_PROVIDED, "editable": True, "serialize": True, "unique_for_date": None, @@ -876,7 +919,10 @@ class Field(RegisterLookupMixin): @property def db_returning(self): """Private API intended only to be used by Django itself.""" - return False + return ( + self.db_default is not NOT_PROVIDED + and connection.features.can_return_columns_from_insert + ) def set_attributes_from_name(self, name): self.name = self.name or name @@ -929,7 +975,13 @@ class Field(RegisterLookupMixin): def pre_save(self, model_instance, add): """Return field's value just before saving.""" - return getattr(model_instance, self.attname) + value = getattr(model_instance, self.attname) + if not connection.features.supports_default_keyword_in_insert: + from django.db.models.expressions import DatabaseDefault + + if isinstance(value, DatabaseDefault): + return self.db_default + return value def get_prep_value(self, value): """Perform preliminary non-db specific value checks and conversions.""" @@ -968,6 +1020,11 @@ class Field(RegisterLookupMixin): return self.default return lambda: self.default + if self.db_default is not NOT_PROVIDED: + from django.db.models.expressions import DatabaseDefault + + return DatabaseDefault + if ( not self.empty_strings_allowed or self.null diff --git a/django/db/models/functions/comparison.py b/django/db/models/functions/comparison.py index de7eef4cdc..108d904712 100644 --- a/django/db/models/functions/comparison.py +++ b/django/db/models/functions/comparison.py @@ -105,6 +105,7 @@ class Coalesce(Func): class Collate(Func): function = "COLLATE" template = "%(expressions)s %(function)s %(collation)s" + allowed_default = False # Inspired from # https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS collation_re = _lazy_re_compile(r"^[\w\-]+$") diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 46ebe3f3a2..91342a864a 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -185,6 +185,10 @@ class Lookup(Expression): sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END" return sql, params + @cached_property + def allowed_default(self): + return self.lhs.allowed_default and self.rhs.allowed_default + class Transform(RegisterLookupMixin, Func): """ diff --git a/django/db/models/query.py b/django/db/models/query.py index 56ad4d5c20..a5b0f464a9 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -654,10 +654,19 @@ class QuerySet(AltersData): return await sync_to_async(self.create)(**kwargs) def _prepare_for_bulk_create(self, objs): + from django.db.models.expressions import DatabaseDefault + + connection = connections[self.db] for obj in objs: if obj.pk is None: # Populate new PK values. obj.pk = obj._meta.pk.get_pk_value_on_save(obj) + if not connection.features.supports_default_keyword_in_bulk_insert: + for field in obj._meta.fields: + value = getattr(obj, field.attname) + if isinstance(value, DatabaseDefault): + setattr(obj, field.attname, field.db_default) + obj._prepare_related_fields_for_save(operation_name="bulk_create") def _check_bulk_create_options( diff --git a/docs/ref/checks.txt b/docs/ref/checks.txt index 9e350a3ff3..df0adbef63 100644 --- a/docs/ref/checks.txt +++ b/docs/ref/checks.txt @@ -175,6 +175,9 @@ Model fields ``choices`` (```` characters). * **fields.E010**: ```` default should be a callable instead of an instance so that it's not shared between all field instances. +* **fields.E011**: ```` does not support default database values with + expressions (``db_default``). +* **fields.E012**: ```` cannot be used in ``db_default``. * **fields.E100**: ``AutoField``\s must set primary_key=True. * **fields.E110**: ``BooleanField``\s do not accept null values. *This check appeared before support for null values was added in Django 2.1.* diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 52a1022771..50560dfa9b 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -996,6 +996,13 @@ calling the appropriate methods on the wrapped expression. .. class:: Expression + .. attribute:: allowed_default + + .. versionadded:: 5.0 + + Tells Django that this expression can be used in + :attr:`Field.db_default`. Defaults to ``False``. + .. attribute:: contains_aggregate Tells Django that this expression contains an aggregate and that a diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index 27b87c1f53..344cc45280 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -351,6 +351,38 @@ looking at your Django code. For example:: db_comment="Date and time when the article was published", ) +``db_default`` +-------------- + +.. versionadded:: 5.0 + +.. attribute:: Field.db_default + +The database-computed default value for this field. This can be a literal value +or a database function, such as :class:`~django.db.models.functions.Now`:: + + created = models.DateTimeField(db_default=Now()) + +More complex expressions can be used, as long as they are made from literals +and database functions:: + + month_due = models.DateField( + db_default=TruncMonth( + Now() + timedelta(days=90), + output_field=models.DateField(), + ) + ) + +Database defaults cannot reference other fields or models. For example, this is +invalid:: + + end = models.IntegerField(db_default=F("start") + 50) + +If both ``db_default`` and :attr:`Field.default` are set, ``default`` will take +precedence when creating instances in Python code. ``db_default`` will still be +set at the database level and will be used when inserting rows outside of the +ORM or when adding a new field in a migration. + ``db_index`` ------------ @@ -408,6 +440,9 @@ The default value is used when new model instances are created and a value isn't provided for the field. When the field is a primary key, the default is also used when the field is set to ``None``. +The default value can also be set at the database level with +:attr:`Field.db_default`. + ``editable`` ------------ diff --git a/docs/ref/models/instances.txt b/docs/ref/models/instances.txt index d03b05577c..346ae55130 100644 --- a/docs/ref/models/instances.txt +++ b/docs/ref/models/instances.txt @@ -541,7 +541,8 @@ You may have noticed Django database objects use the same ``save()`` method for creating and changing objects. Django abstracts the need to use ``INSERT`` or ``UPDATE`` SQL statements. Specifically, when you call ``save()`` and the object's primary key attribute does **not** define a -:attr:`~django.db.models.Field.default`, Django follows this algorithm: +:attr:`~django.db.models.Field.default` or +:attr:`~django.db.models.Field.db_default`, Django follows this algorithm: * If the object's primary key attribute is set to a value that evaluates to ``True`` (i.e., a value other than ``None`` or the empty string), Django @@ -551,9 +552,10 @@ object's primary key attribute does **not** define a exist in the database), Django executes an ``INSERT``. If the object's primary key attribute defines a -:attr:`~django.db.models.Field.default` then Django executes an ``UPDATE`` if -it is an existing model instance and primary key is set to a value that exists -in the database. Otherwise, Django executes an ``INSERT``. +:attr:`~django.db.models.Field.default` or +:attr:`~django.db.models.Field.db_default` then Django executes an ``UPDATE`` +if it is an existing model instance and primary key is set to a value that +exists in the database. Otherwise, Django executes an ``INSERT``. The one gotcha here is that you should be careful not to specify a primary-key value explicitly when saving new objects, if you cannot guarantee the @@ -570,6 +572,10 @@ which returns ``NULL``. In such cases it is possible to revert to the old algorithm by setting the :attr:`~django.db.models.Options.select_on_save` option to ``True``. +.. versionchanged:: 5.0 + + The ``Field.db_default`` parameter was added. + .. _ref-models-force-insert: Forcing an INSERT or UPDATE diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt index f23f39b014..d40cd6a4f0 100644 --- a/docs/releases/5.0.txt +++ b/docs/releases/5.0.txt @@ -108,6 +108,21 @@ Can now be simplified to: per-project, per-field, or per-request basis. See :ref:`reusable-field-group-templates`. +Database-computed default values +-------------------------------- + +The new :attr:`Field.db_default ` parameter +sets a database-computed default value. For example:: + + from django.db import models + from django.db.models.functions import Now, Pi + + + class MyModel(models.Model): + age = models.IntegerField(db_default=18) + created = models.DateTimeField(db_default=Now()) + circumference = models.FloatField(db_default=2 * Pi()) + Minor features -------------- @@ -355,7 +370,16 @@ Database backend API This section describes changes that may be needed in third-party database backends. -* ... +* ``DatabaseFeatures.supports_expression_defaults`` should be set to ``False`` + if the database doesn't support using database functions as defaults. + +* ``DatabaseFeatures.supports_default_keyword_in_insert`` should be set to + ``False`` if the database doesn't support the ``DEFAULT`` keyword in + ``INSERT`` queries. + +* ``DatabaseFeatures.supports_default_keyword_in_bulk insert`` should be set to + ``False`` if the database doesn't support the ``DEFAULT`` keyword in bulk + ``INSERT`` queries. Using ``create_defaults__exact`` may now be required with ``QuerySet.update_or_create()`` ----------------------------------------------------------------------------------------- diff --git a/tests/basic/models.py b/tests/basic/models.py index 59a6a8d67f..b71b60a213 100644 --- a/tests/basic/models.py +++ b/tests/basic/models.py @@ -49,5 +49,9 @@ class PrimaryKeyWithDefault(models.Model): uuid = models.UUIDField(primary_key=True, default=uuid.uuid4) +class PrimaryKeyWithDbDefault(models.Model): + uuid = models.IntegerField(primary_key=True, db_default=1) + + class ChildPrimaryKeyWithDefault(PrimaryKeyWithDefault): pass diff --git a/tests/basic/tests.py b/tests/basic/tests.py index ea9228376c..3c2d1dead9 100644 --- a/tests/basic/tests.py +++ b/tests/basic/tests.py @@ -20,6 +20,7 @@ from .models import ( ArticleSelectOnSave, ChildPrimaryKeyWithDefault, FeaturedArticle, + PrimaryKeyWithDbDefault, PrimaryKeyWithDefault, SelfRef, ) @@ -175,6 +176,11 @@ class ModelInstanceCreationTests(TestCase): with self.assertNumQueries(1): PrimaryKeyWithDefault().save() + def test_save_primary_with_db_default(self): + # An UPDATE attempt is skipped when a primary key has db_default. + with self.assertNumQueries(1): + PrimaryKeyWithDbDefault().save() + def test_save_parent_primary_with_default(self): # An UPDATE attempt is skipped when an inherited primary key has # default. diff --git a/tests/field_defaults/models.py b/tests/field_defaults/models.py index b95005192a..5f9c38a5a4 100644 --- a/tests/field_defaults/models.py +++ b/tests/field_defaults/models.py @@ -12,6 +12,8 @@ field. from datetime import datetime from django.db import models +from django.db.models.functions import Coalesce, ExtractYear, Now, Pi +from django.db.models.lookups import GreaterThan class Article(models.Model): @@ -20,3 +22,45 @@ class Article(models.Model): def __str__(self): return self.headline + + +class DBArticle(models.Model): + """ + Values or expressions can be passed as the db_default parameter to a field. + When the object is created without an explicit value passed in, the + database will insert the default value automatically. + """ + + headline = models.CharField(max_length=100, db_default="Default headline") + pub_date = models.DateTimeField(db_default=Now()) + + class Meta: + required_db_features = {"supports_expression_defaults"} + + +class DBDefaults(models.Model): + both = models.IntegerField(default=1, db_default=2) + null = models.FloatField(null=True, db_default=1.1) + + +class DBDefaultsFunction(models.Model): + number = models.FloatField(db_default=Pi()) + year = models.IntegerField(db_default=ExtractYear(Now())) + added = models.FloatField(db_default=Pi() + 4.5) + multiple_subfunctions = models.FloatField(db_default=Coalesce(4.5, Pi())) + case_when = models.IntegerField( + db_default=models.Case(models.When(GreaterThan(2, 1), then=3), default=4) + ) + + class Meta: + required_db_features = {"supports_expression_defaults"} + + +class DBDefaultsPK(models.Model): + language_code = models.CharField(primary_key=True, max_length=2, db_default="en") + + +class DBDefaultsFK(models.Model): + language_code = models.ForeignKey( + DBDefaultsPK, db_default="fr", on_delete=models.CASCADE + ) diff --git a/tests/field_defaults/tests.py b/tests/field_defaults/tests.py index 19b05aa537..76d01f7a5a 100644 --- a/tests/field_defaults/tests.py +++ b/tests/field_defaults/tests.py @@ -1,8 +1,28 @@ from datetime import datetime +from math import pi -from django.test import TestCase +from django.db import connection +from django.db.models import Case, F, FloatField, Value, When +from django.db.models.expressions import ( + Expression, + ExpressionList, + ExpressionWrapper, + Func, + OrderByList, + RawSQL, +) +from django.db.models.functions import Collate +from django.db.models.lookups import GreaterThan +from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature -from .models import Article +from .models import ( + Article, + DBArticle, + DBDefaults, + DBDefaultsFK, + DBDefaultsFunction, + DBDefaultsPK, +) class DefaultTests(TestCase): @@ -14,3 +34,171 @@ class DefaultTests(TestCase): self.assertIsInstance(a.id, int) self.assertEqual(a.headline, "Default headline") self.assertLess((now - a.pub_date).seconds, 5) + + @skipUnlessDBFeature( + "can_return_columns_from_insert", "supports_expression_defaults" + ) + def test_field_db_defaults_returning(self): + a = DBArticle() + a.save() + self.assertIsInstance(a.id, int) + self.assertEqual(a.headline, "Default headline") + self.assertIsInstance(a.pub_date, datetime) + + @skipIfDBFeature("can_return_columns_from_insert") + @skipUnlessDBFeature("supports_expression_defaults") + def test_field_db_defaults_refresh(self): + a = DBArticle() + a.save() + a.refresh_from_db() + self.assertIsInstance(a.id, int) + self.assertEqual(a.headline, "Default headline") + self.assertIsInstance(a.pub_date, datetime) + + def test_null_db_default(self): + obj1 = DBDefaults.objects.create() + if not connection.features.can_return_columns_from_insert: + obj1.refresh_from_db() + self.assertEqual(obj1.null, 1.1) + + obj2 = DBDefaults.objects.create(null=None) + self.assertIsNone(obj2.null) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_db_default_function(self): + m = DBDefaultsFunction.objects.create() + if not connection.features.can_return_columns_from_insert: + m.refresh_from_db() + self.assertAlmostEqual(m.number, pi) + self.assertEqual(m.year, datetime.now().year) + self.assertAlmostEqual(m.added, pi + 4.5) + self.assertEqual(m.multiple_subfunctions, 4.5) + + @skipUnlessDBFeature("insert_test_table_with_defaults") + def test_both_default(self): + create_sql = connection.features.insert_test_table_with_defaults + with connection.cursor() as cursor: + cursor.execute(create_sql.format(DBDefaults._meta.db_table)) + obj1 = DBDefaults.objects.get() + self.assertEqual(obj1.both, 2) + + obj2 = DBDefaults.objects.create() + self.assertEqual(obj2.both, 1) + + def test_pk_db_default(self): + obj1 = DBDefaultsPK.objects.create() + if not connection.features.can_return_columns_from_insert: + # refresh_from_db() cannot be used because that needs the pk to + # already be known to Django. + obj1 = DBDefaultsPK.objects.get(pk="en") + self.assertEqual(obj1.pk, "en") + self.assertEqual(obj1.language_code, "en") + + obj2 = DBDefaultsPK.objects.create(language_code="de") + self.assertEqual(obj2.pk, "de") + self.assertEqual(obj2.language_code, "de") + + def test_foreign_key_db_default(self): + parent1 = DBDefaultsPK.objects.create(language_code="fr") + child1 = DBDefaultsFK.objects.create() + if not connection.features.can_return_columns_from_insert: + child1.refresh_from_db() + self.assertEqual(child1.language_code, parent1) + + parent2 = DBDefaultsPK.objects.create() + if not connection.features.can_return_columns_from_insert: + # refresh_from_db() cannot be used because that needs the pk to + # already be known to Django. + parent2 = DBDefaultsPK.objects.get(pk="en") + child2 = DBDefaultsFK.objects.create(language_code=parent2) + self.assertEqual(child2.language_code, parent2) + + @skipUnlessDBFeature( + "can_return_columns_from_insert", "supports_expression_defaults" + ) + def test_case_when_db_default_returning(self): + m = DBDefaultsFunction.objects.create() + self.assertEqual(m.case_when, 3) + + @skipIfDBFeature("can_return_columns_from_insert") + @skipUnlessDBFeature("supports_expression_defaults") + def test_case_when_db_default_no_returning(self): + m = DBDefaultsFunction.objects.create() + m.refresh_from_db() + self.assertEqual(m.case_when, 3) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_bulk_create_all_db_defaults(self): + articles = [DBArticle(), DBArticle()] + DBArticle.objects.bulk_create(articles) + + headlines = DBArticle.objects.values_list("headline", flat=True) + self.assertSequenceEqual(headlines, ["Default headline", "Default headline"]) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_bulk_create_all_db_defaults_one_field(self): + pub_date = datetime.now() + articles = [DBArticle(pub_date=pub_date), DBArticle(pub_date=pub_date)] + DBArticle.objects.bulk_create(articles) + + headlines = DBArticle.objects.values_list("headline", "pub_date") + self.assertSequenceEqual( + headlines, + [ + ("Default headline", pub_date), + ("Default headline", pub_date), + ], + ) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_bulk_create_mixed_db_defaults(self): + articles = [DBArticle(), DBArticle(headline="Something else")] + DBArticle.objects.bulk_create(articles) + + headlines = DBArticle.objects.values_list("headline", flat=True) + self.assertCountEqual(headlines, ["Default headline", "Something else"]) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_bulk_create_mixed_db_defaults_function(self): + instances = [DBDefaultsFunction(), DBDefaultsFunction(year=2000)] + DBDefaultsFunction.objects.bulk_create(instances) + + years = DBDefaultsFunction.objects.values_list("year", flat=True) + self.assertCountEqual(years, [2000, datetime.now().year]) + + +class AllowedDefaultTests(SimpleTestCase): + def test_allowed(self): + class Max(Func): + function = "MAX" + + tests = [ + Value(10), + Max(1, 2), + RawSQL("Now()", ()), + Value(10) + Value(7), # Combined expression. + ExpressionList(Value(1), Value(2)), + ExpressionWrapper(Value(1), output_field=FloatField()), + Case(When(GreaterThan(2, 1), then=3), default=4), + ] + for expression in tests: + with self.subTest(expression=expression): + self.assertIs(expression.allowed_default, True) + + def test_disallowed(self): + class Max(Func): + function = "MAX" + + tests = [ + Expression(), + F("field"), + Max(F("count"), 1), + Value(10) + F("count"), # Combined expression. + ExpressionList(F("count"), Value(2)), + ExpressionWrapper(F("count"), output_field=FloatField()), + Collate(Value("John"), "nocase"), + OrderByList("field"), + ] + for expression in tests: + with self.subTest(expression=expression): + self.assertIs(expression.allowed_default, False) diff --git a/tests/invalid_models_tests/test_ordinary_fields.py b/tests/invalid_models_tests/test_ordinary_fields.py index 4e37c48286..e9e8a702e0 100644 --- a/tests/invalid_models_tests/test_ordinary_fields.py +++ b/tests/invalid_models_tests/test_ordinary_fields.py @@ -4,6 +4,7 @@ import uuid from django.core.checks import Error from django.core.checks import Warning as DjangoWarning from django.db import connection, models +from django.db.models.functions import Coalesce, Pi from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.utils import isolate_apps, override_settings from django.utils.functional import lazy @@ -1057,3 +1058,109 @@ class DbCommentTests(TestCase): errors = Model._meta.get_field("field").check(databases=self.databases) self.assertEqual(errors, []) + + +@isolate_apps("invalid_models_tests") +class InvalidDBDefaultTests(TestCase): + def test_db_default(self): + class Model(models.Model): + field = models.FloatField(db_default=Pi()) + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + + if connection.features.supports_expression_defaults: + expected_errors = [] + else: + msg = ( + f"{connection.display_name} does not support default database values " + "with expressions (db_default)." + ) + expected_errors = [Error(msg=msg, obj=field, id="fields.E011")] + self.assertEqual(errors, expected_errors) + + def test_db_default_literal(self): + class Model(models.Model): + field = models.IntegerField(db_default=1) + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + self.assertEqual(errors, []) + + def test_db_default_required_db_features(self): + class Model(models.Model): + field = models.FloatField(db_default=Pi()) + + class Meta: + required_db_features = {"supports_expression_defaults"} + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + self.assertEqual(errors, []) + + def test_db_default_expression_invalid(self): + expression = models.F("field_name") + + class Model(models.Model): + field = models.FloatField(db_default=expression) + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + + if connection.features.supports_expression_defaults: + msg = f"{expression} cannot be used in db_default." + expected_errors = [Error(msg=msg, obj=field, id="fields.E012")] + else: + msg = ( + f"{connection.display_name} does not support default database values " + "with expressions (db_default)." + ) + expected_errors = [Error(msg=msg, obj=field, id="fields.E011")] + self.assertEqual(errors, expected_errors) + + def test_db_default_expression_required_db_features(self): + expression = models.F("field_name") + + class Model(models.Model): + field = models.FloatField(db_default=expression) + + class Meta: + required_db_features = {"supports_expression_defaults"} + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + + if connection.features.supports_expression_defaults: + msg = f"{expression} cannot be used in db_default." + expected_errors = [Error(msg=msg, obj=field, id="fields.E012")] + else: + expected_errors = [] + self.assertEqual(errors, expected_errors) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_db_default_combined_invalid(self): + expression = models.Value(4.5) + models.F("field_name") + + class Model(models.Model): + field = models.FloatField(db_default=expression) + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + + msg = f"{expression} cannot be used in db_default." + expected_error = Error(msg=msg, obj=field, id="fields.E012") + self.assertEqual(errors, [expected_error]) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_db_default_function_arguments_invalid(self): + expression = Coalesce(models.Value(4.5), models.F("field_name")) + + class Model(models.Model): + field = models.FloatField(db_default=expression) + + field = Model._meta.get_field("field") + errors = field.check(databases=self.databases) + + msg = f"{expression} cannot be used in db_default." + expected_error = Error(msg=msg, obj=field, id="fields.E012") + self.assertEqual(errors, [expected_error]) diff --git a/tests/migrations/test_autodetector.py b/tests/migrations/test_autodetector.py index ee199fea68..74892bbf3d 100644 --- a/tests/migrations/test_autodetector.py +++ b/tests/migrations/test_autodetector.py @@ -269,6 +269,14 @@ class AutodetectorTests(BaseAutodetectorTests): ("name", models.CharField(max_length=200, default="Ada Lovelace")), ], ) + author_name_db_default = ModelState( + "testapp", + "Author", + [ + ("id", models.AutoField(primary_key=True)), + ("name", models.CharField(max_length=200, db_default="Ada Lovelace")), + ], + ) author_name_check_constraint = ModelState( "testapp", "Author", @@ -1289,6 +1297,21 @@ class AutodetectorTests(BaseAutodetectorTests): self.assertOperationTypes(changes, "testapp", 0, ["AddField"]) self.assertOperationAttributes(changes, "testapp", 0, 0, name="name") + @mock.patch( + "django.db.migrations.questioner.MigrationQuestioner.ask_not_null_addition", + side_effect=AssertionError("Should not have prompted for not null addition"), + ) + def test_add_not_null_field_with_db_default(self, mocked_ask_method): + changes = self.get_changes([self.author_empty], [self.author_name_db_default]) + self.assertNumberMigrations(changes, "testapp", 1) + self.assertOperationTypes(changes, "testapp", 0, ["AddField"]) + self.assertOperationAttributes( + changes, "testapp", 0, 0, name="name", preserve_default=True + ) + self.assertOperationFieldAttributes( + changes, "testapp", 0, 0, db_default=models.Value("Ada Lovelace") + ) + @mock.patch( "django.db.migrations.questioner.MigrationQuestioner.ask_not_null_addition", side_effect=AssertionError("Should not have prompted for not null addition"), @@ -1478,6 +1501,23 @@ class AutodetectorTests(BaseAutodetectorTests): changes, "testapp", 0, 0, default="Ada Lovelace" ) + @mock.patch( + "django.db.migrations.questioner.MigrationQuestioner.ask_not_null_alteration", + side_effect=AssertionError("Should not have prompted for not null alteration"), + ) + def test_alter_field_to_not_null_with_db_default(self, mocked_ask_method): + changes = self.get_changes( + [self.author_name_null], [self.author_name_db_default] + ) + self.assertNumberMigrations(changes, "testapp", 1) + self.assertOperationTypes(changes, "testapp", 0, ["AlterField"]) + self.assertOperationAttributes( + changes, "testapp", 0, 0, name="name", preserve_default=True + ) + self.assertOperationFieldAttributes( + changes, "testapp", 0, 0, db_default=models.Value("Ada Lovelace") + ) + @mock.patch( "django.db.migrations.questioner.MigrationQuestioner.ask_not_null_alteration", return_value=models.NOT_PROVIDED, diff --git a/tests/migrations/test_base.py b/tests/migrations/test_base.py index f038cd7605..b5228ad445 100644 --- a/tests/migrations/test_base.py +++ b/tests/migrations/test_base.py @@ -292,6 +292,13 @@ class OperationTestBase(MigrationTestBase): ("id", models.AutoField(primary_key=True)), ("pink", models.IntegerField(default=3)), ("weight", models.FloatField()), + ("green", models.IntegerField(null=True)), + ( + "yellow", + models.CharField( + blank=True, null=True, db_default="Yellow", max_length=20 + ), + ), ], options=model_options, ) diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index b67a871bc8..e377e4ca64 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -1,14 +1,18 @@ +import math + from django.core.exceptions import FieldDoesNotExist from django.db import IntegrityError, connection, migrations, models, transaction from django.db.migrations.migration import Migration from django.db.migrations.operations.fields import FieldOperation from django.db.migrations.state import ModelState, ProjectState -from django.db.models.functions import Abs +from django.db.models.expressions import Value +from django.db.models.functions import Abs, Pi from django.db.transaction import atomic from django.test import ( SimpleTestCase, ignore_warnings, override_settings, + skipIfDBFeature, skipUnlessDBFeature, ) from django.test.utils import CaptureQueriesContext @@ -1340,7 +1344,7 @@ class OperationTests(OperationTestBase): self.assertEqual(operation.describe(), "Add field height to Pony") self.assertEqual(operation.migration_name_fragment, "pony_height") project_state, new_state = self.make_test_state("test_adfl", operation) - self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 4) + self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 6) field = new_state.models["test_adfl", "pony"].fields["height"] self.assertEqual(field.default, 5) # Test the database alteration @@ -1528,7 +1532,7 @@ class OperationTests(OperationTestBase): ) new_state = project_state.clone() operation.state_forwards("test_adflpd", new_state) - self.assertEqual(len(new_state.models["test_adflpd", "pony"].fields), 4) + self.assertEqual(len(new_state.models["test_adflpd", "pony"].fields), 6) field = new_state.models["test_adflpd", "pony"].fields["height"] self.assertEqual(field.default, models.NOT_PROVIDED) # Test the database alteration @@ -1547,6 +1551,169 @@ class OperationTests(OperationTestBase): sorted(definition[2]), ["field", "model_name", "name", "preserve_default"] ) + def test_add_field_database_default(self): + """The AddField operation can set and unset a database default.""" + app_label = "test_adfldd" + table_name = f"{app_label}_pony" + project_state = self.set_up_test_model(app_label) + operation = migrations.AddField( + "Pony", "height", models.FloatField(null=True, db_default=4) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6) + field = new_state.models[app_label, "pony"].fields["height"] + self.assertEqual(field.default, models.NOT_PROVIDED) + self.assertEqual(field.db_default, Value(4)) + project_state.apps.get_model(app_label, "pony").objects.create(weight=4) + self.assertColumnNotExists(table_name, "height") + # Add field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + self.assertColumnExists(table_name, "height") + new_model = new_state.apps.get_model(app_label, "pony") + old_pony = new_model.objects.get() + self.assertEqual(old_pony.height, 4) + new_pony = new_model.objects.create(weight=5) + if not connection.features.can_return_columns_from_insert: + new_pony.refresh_from_db() + self.assertEqual(new_pony.height, 4) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + self.assertColumnNotExists(table_name, "height") + # Deconstruction. + definition = operation.deconstruct() + self.assertEqual(definition[0], "AddField") + self.assertEqual(definition[1], []) + self.assertEqual( + definition[2], + { + "field": field, + "model_name": "Pony", + "name": "height", + }, + ) + + def test_add_field_database_default_special_char_escaping(self): + app_label = "test_adflddsce" + table_name = f"{app_label}_pony" + project_state = self.set_up_test_model(app_label) + old_pony_pk = ( + project_state.apps.get_model(app_label, "pony").objects.create(weight=4).pk + ) + tests = ["%", "'", '"'] + for db_default in tests: + with self.subTest(db_default=db_default): + operation = migrations.AddField( + "Pony", + "special_char", + models.CharField(max_length=1, db_default=db_default), + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6) + field = new_state.models[app_label, "pony"].fields["special_char"] + self.assertEqual(field.default, models.NOT_PROVIDED) + self.assertEqual(field.db_default, Value(db_default)) + self.assertColumnNotExists(table_name, "special_char") + with connection.schema_editor() as editor: + operation.database_forwards( + app_label, editor, project_state, new_state + ) + self.assertColumnExists(table_name, "special_char") + new_model = new_state.apps.get_model(app_label, "pony") + try: + new_pony = new_model.objects.create(weight=5) + if not connection.features.can_return_columns_from_insert: + new_pony.refresh_from_db() + self.assertEqual(new_pony.special_char, db_default) + + old_pony = new_model.objects.get(pk=old_pony_pk) + if connection.vendor != "oracle" or db_default != "'": + # The single quotation mark ' is properly quoted and is + # set for new rows on Oracle, however it is not set on + # existing rows. Skip the assertion as it's probably a + # bug in Oracle. + self.assertEqual(old_pony.special_char, db_default) + finally: + with connection.schema_editor() as editor: + operation.database_backwards( + app_label, editor, new_state, project_state + ) + + @skipUnlessDBFeature("supports_expression_defaults") + def test_add_field_database_default_function(self): + app_label = "test_adflddf" + table_name = f"{app_label}_pony" + project_state = self.set_up_test_model(app_label) + operation = migrations.AddField( + "Pony", "height", models.FloatField(db_default=Pi()) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6) + field = new_state.models[app_label, "pony"].fields["height"] + self.assertEqual(field.default, models.NOT_PROVIDED) + self.assertEqual(field.db_default, Pi()) + project_state.apps.get_model(app_label, "pony").objects.create(weight=4) + self.assertColumnNotExists(table_name, "height") + # Add field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + self.assertColumnExists(table_name, "height") + new_model = new_state.apps.get_model(app_label, "pony") + old_pony = new_model.objects.get() + self.assertAlmostEqual(old_pony.height, math.pi) + new_pony = new_model.objects.create(weight=5) + if not connection.features.can_return_columns_from_insert: + new_pony.refresh_from_db() + self.assertAlmostEqual(old_pony.height, math.pi) + + def test_add_field_both_defaults(self): + """The AddField operation with both default and db_default.""" + app_label = "test_adflbddd" + table_name = f"{app_label}_pony" + project_state = self.set_up_test_model(app_label) + operation = migrations.AddField( + "Pony", "height", models.FloatField(default=3, db_default=4) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6) + field = new_state.models[app_label, "pony"].fields["height"] + self.assertEqual(field.default, 3) + self.assertEqual(field.db_default, Value(4)) + project_state.apps.get_model(app_label, "pony").objects.create(weight=4) + self.assertColumnNotExists(table_name, "height") + # Add field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + self.assertColumnExists(table_name, "height") + new_model = new_state.apps.get_model(app_label, "pony") + old_pony = new_model.objects.get() + self.assertEqual(old_pony.height, 4) + new_pony = new_model.objects.create(weight=5) + if not connection.features.can_return_columns_from_insert: + new_pony.refresh_from_db() + self.assertEqual(new_pony.height, 3) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + self.assertColumnNotExists(table_name, "height") + # Deconstruction. + definition = operation.deconstruct() + self.assertEqual(definition[0], "AddField") + self.assertEqual(definition[1], []) + self.assertEqual( + definition[2], + { + "field": field, + "model_name": "Pony", + "name": "height", + }, + ) + def test_add_field_m2m(self): """ Tests the AddField operation with a ManyToManyField. @@ -1558,7 +1725,7 @@ class OperationTests(OperationTestBase): ) new_state = project_state.clone() operation.state_forwards("test_adflmm", new_state) - self.assertEqual(len(new_state.models["test_adflmm", "pony"].fields), 4) + self.assertEqual(len(new_state.models["test_adflmm", "pony"].fields), 6) # Test the database alteration self.assertTableNotExists("test_adflmm_pony_stables") with connection.schema_editor() as editor: @@ -1727,7 +1894,7 @@ class OperationTests(OperationTestBase): self.assertEqual(operation.migration_name_fragment, "remove_pony_pink") new_state = project_state.clone() operation.state_forwards("test_rmfl", new_state) - self.assertEqual(len(new_state.models["test_rmfl", "pony"].fields), 2) + self.assertEqual(len(new_state.models["test_rmfl", "pony"].fields), 4) # Test the database alteration self.assertColumnExists("test_rmfl_pony", "pink") with connection.schema_editor() as editor: @@ -1934,6 +2101,146 @@ class OperationTests(OperationTestBase): self.assertEqual(definition[1], []) self.assertEqual(sorted(definition[2]), ["field", "model_name", "name"]) + def test_alter_field_add_database_default(self): + app_label = "test_alfladd" + project_state = self.set_up_test_model(app_label) + operation = migrations.AlterField( + "Pony", "weight", models.FloatField(db_default=4.5) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + old_weight = project_state.models[app_label, "pony"].fields["weight"] + self.assertIs(old_weight.db_default, models.NOT_PROVIDED) + new_weight = new_state.models[app_label, "pony"].fields["weight"] + self.assertEqual(new_weight.db_default, Value(4.5)) + with self.assertRaises(IntegrityError), transaction.atomic(): + project_state.apps.get_model(app_label, "pony").objects.create() + # Alter field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + pony = new_state.apps.get_model(app_label, "pony").objects.create() + if not connection.features.can_return_columns_from_insert: + pony.refresh_from_db() + self.assertEqual(pony.weight, 4.5) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + with self.assertRaises(IntegrityError), transaction.atomic(): + project_state.apps.get_model(app_label, "pony").objects.create() + # Deconstruction. + definition = operation.deconstruct() + self.assertEqual(definition[0], "AlterField") + self.assertEqual(definition[1], []) + self.assertEqual( + definition[2], + { + "field": new_weight, + "model_name": "Pony", + "name": "weight", + }, + ) + + def test_alter_field_change_default_to_database_default(self): + """The AlterField operation changing default to db_default.""" + app_label = "test_alflcdtdd" + project_state = self.set_up_test_model(app_label) + operation = migrations.AlterField( + "Pony", "pink", models.IntegerField(db_default=4) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + old_pink = project_state.models[app_label, "pony"].fields["pink"] + self.assertEqual(old_pink.default, 3) + self.assertIs(old_pink.db_default, models.NOT_PROVIDED) + new_pink = new_state.models[app_label, "pony"].fields["pink"] + self.assertIs(new_pink.default, models.NOT_PROVIDED) + self.assertEqual(new_pink.db_default, Value(4)) + pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1) + self.assertEqual(pony.pink, 3) + # Alter field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + pony = new_state.apps.get_model(app_label, "pony").objects.create(weight=1) + if not connection.features.can_return_columns_from_insert: + pony.refresh_from_db() + self.assertEqual(pony.pink, 4) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1) + self.assertEqual(pony.pink, 3) + + def test_alter_field_change_nullable_to_database_default_not_null(self): + """ + The AlterField operation changing a null field to db_default. + """ + app_label = "test_alflcntddnn" + project_state = self.set_up_test_model(app_label) + operation = migrations.AlterField( + "Pony", "green", models.IntegerField(db_default=4) + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + old_green = project_state.models[app_label, "pony"].fields["green"] + self.assertIs(old_green.db_default, models.NOT_PROVIDED) + new_green = new_state.models[app_label, "pony"].fields["green"] + self.assertEqual(new_green.db_default, Value(4)) + old_pony = project_state.apps.get_model(app_label, "pony").objects.create( + weight=1 + ) + self.assertIsNone(old_pony.green) + # Alter field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + old_pony.refresh_from_db() + self.assertEqual(old_pony.green, 4) + pony = new_state.apps.get_model(app_label, "pony").objects.create(weight=1) + if not connection.features.can_return_columns_from_insert: + pony.refresh_from_db() + self.assertEqual(pony.green, 4) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1) + self.assertIsNone(pony.green) + + @skipIfDBFeature("interprets_empty_strings_as_nulls") + def test_alter_field_change_blank_nullable_database_default_to_not_null(self): + app_label = "test_alflcbnddnn" + table_name = f"{app_label}_pony" + project_state = self.set_up_test_model(app_label) + default = "Yellow" + operation = migrations.AlterField( + "Pony", + "yellow", + models.CharField(blank=True, db_default=default, max_length=20), + ) + new_state = project_state.clone() + operation.state_forwards(app_label, new_state) + self.assertColumnNull(table_name, "yellow") + pony = project_state.apps.get_model(app_label, "pony").objects.create( + weight=1, yellow=None + ) + self.assertIsNone(pony.yellow) + # Alter field. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + self.assertColumnNotNull(table_name, "yellow") + pony.refresh_from_db() + self.assertEqual(pony.yellow, default) + pony = new_state.apps.get_model(app_label, "pony").objects.create(weight=1) + if not connection.features.can_return_columns_from_insert: + pony.refresh_from_db() + self.assertEqual(pony.yellow, default) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + self.assertColumnNull(table_name, "yellow") + pony = project_state.apps.get_model(app_label, "pony").objects.create( + weight=1, yellow=None + ) + self.assertIsNone(pony.yellow) + def test_alter_field_add_db_column_noop(self): """ AlterField operation is a noop when adding only a db_column and the diff --git a/tests/schema/tests.py b/tests/schema/tests.py index d81a01b41d..688a9f1fcf 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -2102,6 +2102,33 @@ class SchemaTests(TransactionTestCase): with self.assertRaises(IntegrityError): NoteRename.objects.create(detail_info=None) + @isolate_apps("schema") + def test_rename_keep_db_default(self): + """Renaming a field shouldn't affect a database default.""" + + class AuthorDbDefault(Model): + birth_year = IntegerField(db_default=1985) + + class Meta: + app_label = "schema" + + self.isolated_local_models = [AuthorDbDefault] + with connection.schema_editor() as editor: + editor.create_model(AuthorDbDefault) + columns = self.column_classes(AuthorDbDefault) + self.assertEqual(columns["birth_year"][1].default, "1985") + + old_field = AuthorDbDefault._meta.get_field("birth_year") + new_field = IntegerField(db_default=1985) + new_field.set_attributes_from_name("renamed_year") + new_field.model = AuthorDbDefault + with connection.schema_editor( + atomic=connection.features.supports_atomic_references_rename + ) as editor: + editor.alter_field(AuthorDbDefault, old_field, new_field, strict=True) + columns = self.column_classes(AuthorDbDefault) + self.assertEqual(columns["renamed_year"][1].default, "1985") + @skipUnlessDBFeature( "supports_column_check_constraints", "can_introspect_check_constraints" )