From f333e3513e8bdf5ffeb6eeb63021c230082e6f95 Mon Sep 17 00:00:00 2001 From: Jeremy Nauta Date: Thu, 6 Jul 2023 20:36:48 -0600 Subject: [PATCH] Fixed #31300 -- Added GeneratedField model field. Thanks Adam Johnson and Paolo Melchiorre for reviews. Co-Authored-By: Lily Foote Co-Authored-By: Mariusz Felisiak --- django/db/backends/base/features.py | 5 + django/db/backends/base/schema.py | 26 ++- django/db/backends/mysql/features.py | 3 + django/db/backends/oracle/features.py | 2 + django/db/backends/postgresql/features.py | 2 + django/db/backends/sqlite3/features.py | 2 + django/db/backends/sqlite3/introspection.py | 10 +- django/db/backends/sqlite3/schema.py | 2 +- django/db/models/__init__.py | 2 + django/db/models/base.py | 11 +- django/db/models/fields/__init__.py | 3 + django/db/models/fields/generated.py | 151 +++++++++++++++ django/db/models/query.py | 4 +- django/db/models/query_utils.py | 4 + django/db/models/sql/subqueries.py | 3 + docs/ref/checks.txt | 5 + docs/ref/models/fields.txt | 65 +++++++ docs/releases/5.0.txt | 7 + .../test_ordinary_fields.py | 112 +++++++++++ tests/migrations/test_operations.py | 125 +++++++++++++ tests/model_fields/models.py | 98 +++++++++- tests/model_fields/test_generatedfield.py | 176 ++++++++++++++++++ 22 files changed, 807 insertions(+), 11 deletions(-) create mode 100644 django/db/models/fields/generated.py create mode 100644 tests/model_fields/test_generatedfield.py diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 06945732a0..b1f0b9d491 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -353,6 +353,11 @@ class BaseDatabaseFeatures: # Does the backend support column comments in ADD COLUMN statements? supports_comments_inline = False + # Does the backend support stored generated columns? + supports_stored_generated_columns = False + # Does the backend support virtual generated columns? + supports_virtual_generated_columns = False + # Does the backend support the logical XOR operator? supports_logical_xor = False diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index 23d7f00f57..497008dcd6 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -332,7 +332,9 @@ class BaseDatabaseSchemaEditor: and self.connection.features.interprets_empty_strings_as_nulls ): null = True - if not null: + if field.generated: + yield self._column_generated_sql(field) + elif not null: yield "NOT NULL" elif not self.connection.features.implied_column_null: yield "NULL" @@ -422,11 +424,21 @@ class BaseDatabaseSchemaEditor: params = [] return sql % default_sql, params + def _column_generated_sql(self, field): + """Return the SQL to use in a GENERATED ALWAYS clause.""" + expression_sql, params = field.generated_sql(self.connection) + persistency_sql = "STORED" if field.db_persist else "VIRTUAL" + if params: + expression_sql = expression_sql % tuple(self.quote_value(p) for p in params) + return f"GENERATED ALWAYS AS ({expression_sql}) {persistency_sql}" + @staticmethod def _effective_default(field): # This method allows testing its logic without a connection. if field.has_default(): default = field.get_default() + elif field.generated: + default = None elif not field.null and field.blank and field.empty_strings_allowed: if field.get_internal_type() == "BinaryField": default = b"" @@ -848,6 +860,18 @@ class BaseDatabaseSchemaEditor: "(you cannot alter to or from M2M fields, or add or remove " "through= on M2M fields)" % (old_field, new_field) ) + elif old_field.generated != new_field.generated or ( + new_field.generated + and ( + old_field.db_persist != new_field.db_persist + or old_field.generated_sql(self.connection) + != new_field.generated_sql(self.connection) + ) + ): + raise ValueError( + f"Modifying GeneratedFields is not supported - the field {new_field} " + "must be removed and re-added with the new definition." + ) self._alter_field( model, diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index abeef0549a..4d68b2bf3e 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -60,6 +60,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): order_by_nulls_first = True supports_logical_xor = True + supports_stored_generated_columns = True + supports_virtual_generated_columns = True + @cached_property def minimum_database_version(self): if self.connection.mysql_is_mariadb: diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 3d4dbb0bf9..9b894d0df6 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -70,6 +70,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_ignore_conflicts = False max_query_params = 2**16 - 1 supports_partial_indexes = False + supports_stored_generated_columns = False + supports_virtual_generated_columns = True can_rename_index = True supports_slicing_ordering_in_compound = True requires_compound_order_by_subquery = True diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 12dbc71743..0a0ade4ced 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -70,6 +70,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_update_conflicts = True supports_update_conflicts_with_target = True supports_covering_indexes = True + supports_stored_generated_columns = True + supports_virtual_generated_columns = False can_rename_index = True test_collations = { "non_default": "sv-x-icu", diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index 3ae84f2646..44ace18681 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -40,6 +40,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_json_field_contains = False supports_update_conflicts = True supports_update_conflicts_with_target = True + supports_stored_generated_columns = Database.sqlite_version_info >= (3, 31, 0) + supports_virtual_generated_columns = Database.sqlite_version_info >= (3, 31, 0) test_collations = { "ci": "nocase", "cs": "binary", diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index d2fe3d8c71..79aa1934c0 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -91,7 +91,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): interface. """ cursor.execute( - "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name) + "PRAGMA table_xinfo(%s)" % self.connection.ops.quote_name(table_name) ) table_info = cursor.fetchall() if not table_info: @@ -129,7 +129,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): pk == 1, name in json_columns, ) - for cid, name, data_type, notnull, default, pk in table_info + for cid, name, data_type, notnull, default, pk, hidden in table_info + if hidden + in [ + 0, # Normal column. + 2, # Virtual generated column. + 3, # Stored generated column. + ] ] def get_sequences(self, cursor, table_name, table_fields=()): diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index ec128fd733..f311e0b745 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -135,7 +135,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # Choose a default and insert it into the copy map if ( create_field.db_default is NOT_PROVIDED - and not create_field.many_to_many + and not (create_field.many_to_many or create_field.generated) and create_field.concrete ): mapping[create_field.column] = self.prepare_default( diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index ffca81de91..9426280215 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -38,6 +38,7 @@ from django.db.models.expressions import ( from django.db.models.fields import * # NOQA from django.db.models.fields import __all__ as fields_all from django.db.models.fields.files import FileField, ImageField +from django.db.models.fields.generated import GeneratedField from django.db.models.fields.json import JSONField from django.db.models.fields.proxy import OrderWrt from django.db.models.indexes import * # NOQA @@ -92,6 +93,7 @@ __all__ += [ "WindowFrame", "FileField", "ImageField", + "GeneratedField", "JSONField", "OrderWrt", "Lookup", diff --git a/django/db/models/base.py b/django/db/models/base.py index 5c0a9c430f..80503d118a 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -508,7 +508,7 @@ class Model(AltersData, metaclass=ModelBase): for field in fields_iter: is_related_object = False # Virtual field - if field.attname not in kwargs and field.column is None: + if field.attname not in kwargs and field.column is None or field.generated: continue if kwargs: if isinstance(field.remote_field, ForeignObjectRel): @@ -1050,10 +1050,11 @@ class Model(AltersData, metaclass=ModelBase): ), )["_order__max"] ) - fields = meta.local_concrete_fields - if not pk_set: - fields = [f for f in fields if f is not meta.auto_field] - + fields = [ + f + for f in meta.local_concrete_fields + if not f.generated and (pk_set or f is not meta.auto_field) + ] returning_fields = meta.db_returning_fields results = self._do_insert( cls._base_manager, using, fields, returning_fields, raw diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 15fe9d9c9c..f1e4790568 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -165,6 +165,7 @@ class Field(RegisterLookupMixin): one_to_many = None one_to_one = None related_model = None + generated = False descriptor_class = DeferredAttribute @@ -646,6 +647,8 @@ class Field(RegisterLookupMixin): path = path.replace("django.db.models.fields.related", "django.db.models") elif path.startswith("django.db.models.fields.files"): path = path.replace("django.db.models.fields.files", "django.db.models") + elif path.startswith("django.db.models.fields.generated"): + path = path.replace("django.db.models.fields.generated", "django.db.models") elif path.startswith("django.db.models.fields.json"): path = path.replace("django.db.models.fields.json", "django.db.models") elif path.startswith("django.db.models.fields.proxy"): diff --git a/django/db/models/fields/generated.py b/django/db/models/fields/generated.py new file mode 100644 index 0000000000..0980be98af --- /dev/null +++ b/django/db/models/fields/generated.py @@ -0,0 +1,151 @@ +from django.core import checks +from django.db import connections, router +from django.db.models.sql import Query + +from . import NOT_PROVIDED, Field + +__all__ = ["GeneratedField"] + + +class GeneratedField(Field): + generated = True + db_returning = True + + _query = None + _resolved_expression = None + output_field = None + + def __init__(self, *, expression, db_persist=None, output_field=None, **kwargs): + if kwargs.setdefault("editable", False): + raise ValueError("GeneratedField cannot be editable.") + if not kwargs.setdefault("blank", True): + raise ValueError("GeneratedField must be blank.") + if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED: + raise ValueError("GeneratedField cannot have a default.") + if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED: + raise ValueError("GeneratedField cannot have a database default.") + if db_persist not in (True, False): + raise ValueError("GeneratedField.db_persist must be True or False.") + + self.expression = expression + self._output_field = output_field + self.db_persist = db_persist + super().__init__(**kwargs) + + def contribute_to_class(self, *args, **kwargs): + super().contribute_to_class(*args, **kwargs) + + self._query = Query(model=self.model, alias_cols=False) + self._resolved_expression = self.expression.resolve_expression( + self._query, allow_joins=False + ) + self.output_field = ( + self._output_field + if self._output_field is not None + else self._resolved_expression.output_field + ) + # Register lookups from the output_field class. + for lookup_name, lookup in self.output_field.get_class_lookups().items(): + self.register_lookup(lookup, lookup_name=lookup_name) + + def generated_sql(self, connection): + return self._resolved_expression.as_sql( + compiler=connection.ops.compiler("SQLCompiler")( + self._query, connection=connection, using=None + ), + connection=connection, + ) + + def check(self, **kwargs): + databases = kwargs.get("databases") or [] + return [ + *super().check(**kwargs), + *self._check_supported(databases), + *self._check_persistence(databases), + ] + + def _check_supported(self, databases): + errors = [] + for db in databases: + if not router.allow_migrate_model(db, self.model): + continue + connection = connections[db] + if ( + self.model._meta.required_db_vendor + and self.model._meta.required_db_vendor != connection.vendor + ): + continue + if not ( + connection.features.supports_virtual_generated_columns + or "supports_stored_generated_columns" + in self.model._meta.required_db_features + ) and not ( + connection.features.supports_stored_generated_columns + or "supports_virtual_generated_columns" + in self.model._meta.required_db_features + ): + errors.append( + checks.Error( + f"{connection.display_name} does not support GeneratedFields.", + obj=self, + id="fields.E220", + ) + ) + return errors + + def _check_persistence(self, databases): + errors = [] + for db in databases: + if not router.allow_migrate_model(db, self.model): + continue + connection = connections[db] + if ( + self.model._meta.required_db_vendor + and self.model._meta.required_db_vendor != connection.vendor + ): + continue + if not self.db_persist and not ( + connection.features.supports_virtual_generated_columns + or "supports_virtual_generated_columns" + in self.model._meta.required_db_features + ): + errors.append( + checks.Error( + f"{connection.display_name} does not support non-persisted " + "GeneratedFields.", + obj=self, + id="fields.E221", + hint="Set db_persist=True on the field.", + ) + ) + if self.db_persist and not ( + connection.features.supports_stored_generated_columns + or "supports_stored_generated_columns" + in self.model._meta.required_db_features + ): + errors.append( + checks.Error( + f"{connection.display_name} does not support persisted " + "GeneratedFields.", + obj=self, + id="fields.E222", + hint="Set db_persist=False on the field.", + ) + ) + return errors + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + del kwargs["blank"] + del kwargs["editable"] + kwargs["db_persist"] = self.db_persist + kwargs["expression"] = self.expression + if self._output_field is not None: + kwargs["output_field"] = self._output_field + return name, path, args, kwargs + + def get_internal_type(self): + return self.output_field.get_internal_type() + + def db_parameters(self, connection): + return self.output_field.db_parameters(connection) diff --git a/django/db/models/query.py b/django/db/models/query.py index 6472412c14..92dd31dfe5 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -689,6 +689,8 @@ class QuerySet(AltersData): obj.pk = obj._meta.pk.get_pk_value_on_save(obj) if not connection.features.supports_default_keyword_in_bulk_insert: for field in obj._meta.fields: + if field.generated: + continue value = getattr(obj, field.attname) if isinstance(value, DatabaseDefault): setattr(obj, field.attname, field.db_default) @@ -804,7 +806,7 @@ class QuerySet(AltersData): unique_fields, ) self._for_write = True - fields = opts.concrete_fields + fields = [f for f in opts.concrete_fields if not f.generated] objs = list(objs) self._prepare_for_bulk_create(objs) with transaction.atomic(using=self.db, savepoint=False): diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index fcda30b3a7..9754864eef 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -198,6 +198,10 @@ class DeferredAttribute: # might be able to reuse the already loaded value. Refs #18343. val = self._check_parent_chain(instance) if val is None: + if instance.pk is None and self.field.generated: + raise FieldError( + "Cannot read a generated field from an unsaved model." + ) instance.refresh_from_db(fields=[field_name]) else: data[field_name] = val diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index d8a246d369..f639eb8b82 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -108,6 +108,9 @@ class UpdateQuery(Query): called add_update_targets() to hint at the extra information here. """ for field, model, val in values_seq: + # Omit generated fields. + if field.generated: + continue if hasattr(val, "resolve_expression"): # Resolve expressions here so that annotations are no longer needed val = val.resolve_expression(self, allow_joins=False, for_save=True) diff --git a/docs/ref/checks.txt b/docs/ref/checks.txt index b8789ecf6f..72699ac136 100644 --- a/docs/ref/checks.txt +++ b/docs/ref/checks.txt @@ -208,6 +208,11 @@ Model fields * **fields.E180**: ```` does not support ``JSONField``\s. * **fields.E190**: ```` does not support a database collation on ````\s. +* **fields.E220**: ```` does not support ``GeneratedField``\s. +* **fields.E221**: ```` does not support non-persisted + ``GeneratedField``\s. +* **fields.E222**: ```` does not support persisted + ``GeneratedField``\s. * **fields.E900**: ``IPAddressField`` has been removed except for support in historical migrations. * **fields.W900**: ``IPAddressField`` has been deprecated. Support for it diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index fbc90e5420..a41eb7b1d2 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -1215,6 +1215,71 @@ when :attr:`~django.forms.Field.localize` is ``False`` or information on the difference between the two, see Python's documentation for the :mod:`decimal` module. +``GeneratedField`` +------------------ + +.. versionadded:: 5.0 + +.. class:: GeneratedField(expression, db_persist=None, output_field=None, **kwargs) + +A field that is always computed based on other fields in the model. This field +is managed and updated by the database itself. Uses the ``GENERATED ALWAYS`` +SQL syntax. + +There are two kinds of generated columns: stored and virtual. A stored +generated column is computed when it is written (inserted or updated) and +occupies storage as if it were a regular column. A virtual generated column +occupies no storage and is computed when it is read. Thus, a virtual generated +column is similar to a view and a stored generated column is similar to a +materialized view. + +.. attribute:: GeneratedField.expression + + An :class:`Expression` used by the database to automatically set the field + value each time the model is changed. + + The expressions should be deterministic and only reference fields within + the model (in the same database table). Generated fields cannot reference + other generated fields. Database backends can impose further restrictions. + +.. attribute:: GeneratedField.db_persist + + Determines if the database column should occupy storage as if it were a + real column. If ``False``, the column acts as a virtual column and does + not occupy database storage space. + + PostgreSQL only supports persisted columns. Oracle only supports virtual + columns. + +.. attribute:: GeneratedField.output_field + + An optional model field instance to define the field's data type. This can + be used to customize attributes like the field's collation. By default, the + output field is derived from ``expression``. + +.. admonition:: Refresh the data + + Since the database always computed the value, the object must be reloaded + to access the new value after :meth:`~Model.save()`, for example, by using + :meth:`~Model.refresh_from_db()`. + +.. admonition:: Database limitations + + There are many database-specific restrictions on generated fields that + Django doesn't validate and the database may raise an error e.g. PostgreSQL + requires functions and operators referenced in a generated columns to be + marked as ``IMMUTABLE`` . + + You should always check that ``expression`` is supported on your database. + Check out `MariaDB`_, `MySQL`_, `Oracle`_, `PostgreSQL`_, or `SQLite`_ + docs. + +.. _MariaDB: https://mariadb.com/kb/en/generated-columns/#expression-support +.. _MySQL: https://dev.mysql.com/doc/refman/en/create-table-generated-columns.html +.. _Oracle: https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/CREATE-TABLE.html#GUID-F9CE0CC3-13AE-4744-A43C-EAC7A71AAAB6__BABIIGBD +.. _PostgreSQL: https://www.postgresql.org/docs/current/ddl-generated-columns.html +.. _SQLite: https://www.sqlite.org/gencol.html#limitations + ``GenericIPAddressField`` ------------------------- diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt index 5b8abc8498..b9a871c8e9 100644 --- a/docs/releases/5.0.txt +++ b/docs/releases/5.0.txt @@ -129,6 +129,13 @@ sets a database-computed default value. For example:: created = models.DateTimeField(db_default=Now()) circumference = models.FloatField(db_default=2 * Pi()) +Database generated model field +------------------------------ + +The new :class:`~django.db.models.GeneratedField` allows creation of database +generated columns. This field can be used on all supported database backends +to create a field that is always computed from other fields. + More options for declaring field choices ---------------------------------------- diff --git a/tests/invalid_models_tests/test_ordinary_fields.py b/tests/invalid_models_tests/test_ordinary_fields.py index 6014448013..affe642ac7 100644 --- a/tests/invalid_models_tests/test_ordinary_fields.py +++ b/tests/invalid_models_tests/test_ordinary_fields.py @@ -1226,3 +1226,115 @@ class InvalidDBDefaultTests(TestCase): msg = f"{expression} cannot be used in db_default." expected_error = Error(msg=msg, obj=field, id="fields.E012") self.assertEqual(errors, [expected_error]) + + +@isolate_apps("invalid_models_tests") +class GeneratedFieldTests(TestCase): + def test_not_supported(self): + db_persist = connection.features.supports_stored_generated_columns + + class Model(models.Model): + name = models.IntegerField() + field = models.GeneratedField( + expression=models.F("name"), db_persist=db_persist + ) + + expected_errors = [] + if ( + not connection.features.supports_stored_generated_columns + and not connection.features.supports_virtual_generated_columns + ): + expected_errors.append( + Error( + f"{connection.display_name} does not support GeneratedFields.", + obj=Model._meta.get_field("field"), + id="fields.E220", + ) + ) + if ( + not db_persist + and not connection.features.supports_virtual_generated_columns + ): + expected_errors.append( + Error( + f"{connection.display_name} does not support non-persisted " + "GeneratedFields.", + obj=Model._meta.get_field("field"), + id="fields.E221", + hint="Set db_persist=True on the field.", + ), + ) + self.assertEqual( + Model._meta.get_field("field").check(databases={"default"}), + expected_errors, + ) + + def test_not_supported_stored_required_db_features(self): + class Model(models.Model): + name = models.IntegerField() + field = models.GeneratedField(expression=models.F("name"), db_persist=True) + + class Meta: + required_db_features = {"supports_stored_generated_columns"} + + self.assertEqual(Model.check(databases=self.databases), []) + + def test_not_supported_virtual_required_db_features(self): + class Model(models.Model): + name = models.IntegerField() + field = models.GeneratedField(expression=models.F("name"), db_persist=False) + + class Meta: + required_db_features = {"supports_virtual_generated_columns"} + + self.assertEqual(Model.check(databases=self.databases), []) + + @skipUnlessDBFeature("supports_stored_generated_columns") + def test_not_supported_virtual(self): + class Model(models.Model): + name = models.IntegerField() + field = models.GeneratedField(expression=models.F("name"), db_persist=False) + a = models.TextField() + + excepted_errors = ( + [] + if connection.features.supports_virtual_generated_columns + else [ + Error( + f"{connection.display_name} does not support non-persisted " + "GeneratedFields.", + obj=Model._meta.get_field("field"), + id="fields.E221", + hint="Set db_persist=True on the field.", + ), + ] + ) + self.assertEqual( + Model._meta.get_field("field").check(databases={"default"}), + excepted_errors, + ) + + @skipUnlessDBFeature("supports_virtual_generated_columns") + def test_not_supported_stored(self): + class Model(models.Model): + name = models.IntegerField() + field = models.GeneratedField(expression=models.F("name"), db_persist=True) + a = models.TextField() + + expected_errors = ( + [] + if connection.features.supports_stored_generated_columns + else [ + Error( + f"{connection.display_name} does not support persisted " + "GeneratedFields.", + obj=Model._meta.get_field("field"), + id="fields.E222", + hint="Set db_persist=False on the field.", + ), + ] + ) + self.assertEqual( + Model._meta.get_field("field").check(databases={"default"}), + expected_errors, + ) diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index 617bd3d7b0..d4fc3e855a 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -5,6 +5,7 @@ from django.db import IntegrityError, connection, migrations, models, transactio from django.db.migrations.migration import Migration from django.db.migrations.operations.fields import FieldOperation from django.db.migrations.state import ModelState, ProjectState +from django.db.models import F from django.db.models.expressions import Value from django.db.models.functions import Abs, Pi from django.db.transaction import atomic @@ -5741,6 +5742,130 @@ class OperationTests(OperationTestBase): operation.database_backwards(app_label, editor, new_state, project_state) assertModelsAndTables(after_db=False) + def _test_invalid_generated_field_changes(self, db_persist): + regular = models.IntegerField(default=1) + generated_1 = models.GeneratedField( + expression=F("pink") + F("pink"), db_persist=db_persist + ) + generated_2 = models.GeneratedField( + expression=F("pink") + F("pink") + F("pink"), db_persist=db_persist + ) + tests = [ + ("test_igfc_1", regular, generated_1), + ("test_igfc_2", generated_1, regular), + ("test_igfc_3", generated_1, generated_2), + ] + for app_label, add_field, alter_field in tests: + project_state = self.set_up_test_model(app_label) + operations = [ + migrations.AddField("Pony", "modified_pink", add_field), + migrations.AlterField("Pony", "modified_pink", alter_field), + ] + msg = ( + "Modifying GeneratedFields is not supported - the field " + f"{app_label}.Pony.modified_pink must be removed and re-added with the " + "new definition." + ) + with self.assertRaisesMessage(ValueError, msg): + self.apply_operations(app_label, project_state, operations) + + @skipUnlessDBFeature("supports_stored_generated_columns") + def test_invalid_generated_field_changes_stored(self): + self._test_invalid_generated_field_changes(db_persist=True) + + @skipUnlessDBFeature("supports_virtual_generated_columns") + def test_invalid_generated_field_changes_virtual(self): + self._test_invalid_generated_field_changes(db_persist=False) + + @skipUnlessDBFeature( + "supports_stored_generated_columns", + "supports_virtual_generated_columns", + ) + def test_invalid_generated_field_persistency_change(self): + app_label = "test_igfpc" + project_state = self.set_up_test_model(app_label) + operations = [ + migrations.AddField( + "Pony", + "modified_pink", + models.GeneratedField(expression=F("pink"), db_persist=True), + ), + migrations.AlterField( + "Pony", + "modified_pink", + models.GeneratedField(expression=F("pink"), db_persist=False), + ), + ] + msg = ( + "Modifying GeneratedFields is not supported - the field " + f"{app_label}.Pony.modified_pink must be removed and re-added with the " + "new definition." + ) + with self.assertRaisesMessage(ValueError, msg): + self.apply_operations(app_label, project_state, operations) + + def _test_add_generated_field(self, db_persist): + app_label = "test_agf" + operation = migrations.AddField( + "Pony", + "modified_pink", + models.GeneratedField( + expression=F("pink") + F("pink"), db_persist=db_persist + ), + ) + project_state, new_state = self.make_test_state(app_label, operation) + self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6) + # Add generated column. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + self.assertColumnExists(f"{app_label}_pony", "modified_pink") + Pony = new_state.apps.get_model(app_label, "Pony") + obj = Pony.objects.create(pink=5, weight=3.23) + self.assertEqual(obj.modified_pink, 10) + # Reversal. + with connection.schema_editor() as editor: + operation.database_backwards(app_label, editor, new_state, project_state) + self.assertColumnNotExists(f"{app_label}_pony", "modified_pink") + + @skipUnlessDBFeature("supports_stored_generated_columns") + def test_add_generated_field_stored(self): + self._test_add_generated_field(db_persist=True) + + @skipUnlessDBFeature("supports_virtual_generated_columns") + def test_add_generated_field_virtual(self): + self._test_add_generated_field(db_persist=False) + + def _test_remove_generated_field(self, db_persist): + app_label = "test_rgf" + operation = migrations.AddField( + "Pony", + "modified_pink", + models.GeneratedField( + expression=F("pink") + F("pink"), db_persist=db_persist + ), + ) + project_state, new_state = self.make_test_state(app_label, operation) + self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6) + # Add generated column. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + project_state = new_state + new_state = project_state.clone() + operation = migrations.RemoveField("Pony", "modified_pink") + operation.state_forwards(app_label, new_state) + # Remove generated column. + with connection.schema_editor() as editor: + operation.database_forwards(app_label, editor, project_state, new_state) + self.assertColumnNotExists(f"{app_label}_pony", "modified_pink") + + @skipUnlessDBFeature("supports_stored_generated_columns") + def test_remove_generated_field_stored(self): + self._test_remove_generated_field(db_persist=True) + + @skipUnlessDBFeature("supports_virtual_generated_columns") + def test_remove_generated_field_virtual(self): + self._test_remove_generated_field(db_persist=False) + class SwappableOperationTests(OperationTestBase): """ diff --git a/tests/model_fields/models.py b/tests/model_fields/models.py index e1a5a3872f..7fb0f8b610 100644 --- a/tests/model_fields/models.py +++ b/tests/model_fields/models.py @@ -6,8 +6,11 @@ from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelatio from django.contrib.contenttypes.models import ContentType from django.core.files.storage import FileSystemStorage from django.core.serializers.json import DjangoJSONEncoder -from django.db import models +from django.db import connection, models +from django.db.models import F, Value from django.db.models.fields.files import ImageFieldFile +from django.db.models.functions import Lower +from django.utils.functional import SimpleLazyObject from django.utils.translation import gettext_lazy as _ try: @@ -16,6 +19,11 @@ except ImportError: Image = None +test_collation = SimpleLazyObject( + lambda: connection.features.test_collations.get("non_default") +) + + class Foo(models.Model): a = models.CharField(max_length=10) d = models.DecimalField(max_digits=5, decimal_places=3) @@ -468,3 +476,91 @@ class UUIDChild(PrimaryKeyUUIDModel): class UUIDGrandchild(UUIDChild): pass + + +class GeneratedModel(models.Model): + a = models.IntegerField() + b = models.IntegerField() + field = models.GeneratedField(expression=F("a") + F("b"), db_persist=True) + + class Meta: + required_db_features = {"supports_stored_generated_columns"} + + +class GeneratedModelVirtual(models.Model): + a = models.IntegerField() + b = models.IntegerField() + field = models.GeneratedField(expression=F("a") + F("b"), db_persist=False) + + class Meta: + required_db_features = {"supports_virtual_generated_columns"} + + +class GeneratedModelParams(models.Model): + field = models.GeneratedField( + expression=Value("Constant", output_field=models.CharField(max_length=10)), + db_persist=True, + ) + + class Meta: + required_db_features = {"supports_stored_generated_columns"} + + +class GeneratedModelParamsVirtual(models.Model): + field = models.GeneratedField( + expression=Value("Constant", output_field=models.CharField(max_length=10)), + db_persist=False, + ) + + class Meta: + required_db_features = {"supports_virtual_generated_columns"} + + +class GeneratedModelOutputField(models.Model): + name = models.CharField(max_length=10) + lower_name = models.GeneratedField( + expression=Lower("name"), + output_field=models.CharField(db_collation=test_collation, max_length=11), + db_persist=True, + ) + + class Meta: + required_db_features = { + "supports_stored_generated_columns", + "supports_collation_on_charfield", + } + + +class GeneratedModelOutputFieldVirtual(models.Model): + name = models.CharField(max_length=10) + lower_name = models.GeneratedField( + expression=Lower("name"), + db_persist=False, + output_field=models.CharField(db_collation=test_collation, max_length=11), + ) + + class Meta: + required_db_features = { + "supports_virtual_generated_columns", + "supports_collation_on_charfield", + } + + +class GeneratedModelNull(models.Model): + name = models.CharField(max_length=10, null=True) + lower_name = models.GeneratedField( + expression=Lower("name"), db_persist=True, null=True + ) + + class Meta: + required_db_features = {"supports_stored_generated_columns"} + + +class GeneratedModelNullVirtual(models.Model): + name = models.CharField(max_length=10, null=True) + lower_name = models.GeneratedField( + expression=Lower("name"), db_persist=False, null=True + ) + + class Meta: + required_db_features = {"supports_virtual_generated_columns"} diff --git a/tests/model_fields/test_generatedfield.py b/tests/model_fields/test_generatedfield.py new file mode 100644 index 0000000000..e2746bdd0c --- /dev/null +++ b/tests/model_fields/test_generatedfield.py @@ -0,0 +1,176 @@ +from django.core.exceptions import FieldError +from django.db import IntegrityError, connection +from django.db.models import F, GeneratedField, IntegerField +from django.db.models.functions import Lower +from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature + +from .models import ( + GeneratedModel, + GeneratedModelNull, + GeneratedModelNullVirtual, + GeneratedModelOutputField, + GeneratedModelOutputFieldVirtual, + GeneratedModelParams, + GeneratedModelParamsVirtual, + GeneratedModelVirtual, +) + + +class BaseGeneratedFieldTests(SimpleTestCase): + def test_editable_unsupported(self): + with self.assertRaisesMessage(ValueError, "GeneratedField cannot be editable."): + GeneratedField(expression=Lower("name"), editable=True, db_persist=False) + + def test_blank_unsupported(self): + with self.assertRaisesMessage(ValueError, "GeneratedField must be blank."): + GeneratedField(expression=Lower("name"), blank=False, db_persist=False) + + def test_default_unsupported(self): + msg = "GeneratedField cannot have a default." + with self.assertRaisesMessage(ValueError, msg): + GeneratedField(expression=Lower("name"), default="", db_persist=False) + + def test_database_default_unsupported(self): + msg = "GeneratedField cannot have a database default." + with self.assertRaisesMessage(ValueError, msg): + GeneratedField(expression=Lower("name"), db_default="", db_persist=False) + + def test_db_persist_required(self): + msg = "GeneratedField.db_persist must be True or False." + with self.assertRaisesMessage(ValueError, msg): + GeneratedField(expression=Lower("name")) + with self.assertRaisesMessage(ValueError, msg): + GeneratedField(expression=Lower("name"), db_persist=None) + + def test_deconstruct(self): + field = GeneratedField(expression=F("a") + F("b"), db_persist=True) + _, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.GeneratedField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")}) + + +class GeneratedFieldTestMixin: + def _refresh_if_needed(self, m): + if not connection.features.can_return_columns_from_insert: + m.refresh_from_db() + return m + + def test_unsaved_error(self): + m = self.base_model(a=1, b=2) + msg = "Cannot read a generated field from an unsaved model." + with self.assertRaisesMessage(FieldError, msg): + m.field + + def test_create(self): + m = self.base_model.objects.create(a=1, b=2) + m = self._refresh_if_needed(m) + self.assertEqual(m.field, 3) + + def test_non_nullable_create(self): + with self.assertRaises(IntegrityError): + self.base_model.objects.create() + + def test_save(self): + # Insert. + m = self.base_model(a=2, b=4) + m.save() + m = self._refresh_if_needed(m) + self.assertEqual(m.field, 6) + # Update. + m.a = 4 + m.save() + m.refresh_from_db() + self.assertEqual(m.field, 8) + + def test_update(self): + m = self.base_model.objects.create(a=1, b=2) + self.base_model.objects.update(b=3) + m = self.base_model.objects.get(pk=m.pk) + self.assertEqual(m.field, 4) + + def test_bulk_create(self): + m = self.base_model(a=3, b=4) + (m,) = self.base_model.objects.bulk_create([m]) + if not connection.features.can_return_rows_from_bulk_insert: + m = self.base_model.objects.get() + self.assertEqual(m.field, 7) + + def test_bulk_update(self): + m = self.base_model.objects.create(a=1, b=2) + m.a = 3 + self.base_model.objects.bulk_update([m], fields=["a"]) + m = self.base_model.objects.get(pk=m.pk) + self.assertEqual(m.field, 5) + + def test_output_field_lookups(self): + """Lookups from the output_field are available on GeneratedFields.""" + internal_type = IntegerField().get_internal_type() + min_value, max_value = connection.ops.integer_field_range(internal_type) + if min_value is None: + self.skipTest("Backend doesn't define an integer min value.") + if max_value is None: + self.skipTest("Backend doesn't define an integer max value.") + + does_not_exist = self.base_model.DoesNotExist + underflow_value = min_value - 1 + with self.assertNumQueries(0), self.assertRaises(does_not_exist): + self.base_model.objects.get(field=underflow_value) + with self.assertNumQueries(0), self.assertRaises(does_not_exist): + self.base_model.objects.get(field__lt=underflow_value) + with self.assertNumQueries(0), self.assertRaises(does_not_exist): + self.base_model.objects.get(field__lte=underflow_value) + + overflow_value = max_value + 1 + with self.assertNumQueries(0), self.assertRaises(does_not_exist): + self.base_model.objects.get(field=overflow_value) + with self.assertNumQueries(0), self.assertRaises(does_not_exist): + self.base_model.objects.get(field__gt=overflow_value) + with self.assertNumQueries(0), self.assertRaises(does_not_exist): + self.base_model.objects.get(field__gte=overflow_value) + + @skipUnlessDBFeature("supports_collation_on_charfield") + def test_output_field(self): + collation = connection.features.test_collations.get("non_default") + if not collation: + self.skipTest("Language collations are not supported.") + + m = self.output_field_model.objects.create(name="NAME") + field = m._meta.get_field("lower_name") + db_parameters = field.db_parameters(connection) + self.assertEqual(db_parameters["collation"], collation) + self.assertEqual(db_parameters["type"], field.output_field.db_type(connection)) + self.assertNotEqual( + db_parameters["type"], + field._resolved_expression.output_field.db_type(connection), + ) + + def test_model_with_params(self): + m = self.params_model.objects.create() + m = self._refresh_if_needed(m) + self.assertEqual(m.field, "Constant") + + def test_nullable(self): + m1 = self.nullable_model.objects.create() + m1 = self._refresh_if_needed(m1) + none_val = "" if connection.features.interprets_empty_strings_as_nulls else None + self.assertEqual(m1.lower_name, none_val) + m2 = self.nullable_model.objects.create(name="NaMe") + m2 = self._refresh_if_needed(m2) + self.assertEqual(m2.lower_name, "name") + + +@skipUnlessDBFeature("supports_stored_generated_columns") +class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase): + base_model = GeneratedModel + nullable_model = GeneratedModelNull + output_field_model = GeneratedModelOutputField + params_model = GeneratedModelParams + + +@skipUnlessDBFeature("supports_virtual_generated_columns") +class VirtualGeneratedFieldTests(GeneratedFieldTestMixin, TestCase): + base_model = GeneratedModelVirtual + nullable_model = GeneratedModelNullVirtual + output_field_model = GeneratedModelOutputFieldVirtual + params_model = GeneratedModelParamsVirtual