From 978aae4334fa71ba78a3e94408f0f3aebde8d07c Mon Sep 17 00:00:00 2001 From: Bendeguz Csirmaz Date: Sun, 7 Apr 2024 10:32:16 +0800 Subject: [PATCH] Fixed #373 -- Added CompositePrimaryKey. Thanks Lily Foote and Simon Charette for reviews and mentoring this Google Summer of Code 2024 project. Co-authored-by: Simon Charette Co-authored-by: Lily Foote --- django/contrib/admin/sites.py | 5 + django/core/serializers/python.py | 3 + django/db/backends/base/schema.py | 12 + django/db/backends/oracle/schema.py | 2 + django/db/backends/sqlite3/schema.py | 15 +- django/db/models/__init__.py | 2 + django/db/models/aggregates.py | 19 +- django/db/models/base.py | 73 +++- django/db/models/fields/__init__.py | 2 + django/db/models/fields/composite.py | 150 ++++++++ django/db/models/fields/related.py | 20 +- django/db/models/fields/related_lookups.py | 6 +- django/db/models/fields/tuple_lookups.py | 10 +- django/db/models/options.py | 17 +- django/db/models/query.py | 9 +- django/db/models/sql/compiler.py | 75 +++- django/db/models/sql/query.py | 8 +- docs/ref/checks.txt | 3 + docs/ref/models/fields.txt | 19 + docs/releases/5.2.txt | 19 + docs/topics/composite-primary-key.txt | 183 +++++++++ docs/topics/index.txt | 1 + tests/admin_registration/models.py | 6 + tests/admin_registration/tests.py | 10 +- tests/composite_pk/__init__.py | 0 tests/composite_pk/fixtures/tenant.json | 75 ++++ tests/composite_pk/models/__init__.py | 9 + tests/composite_pk/models/tenant.py | 50 +++ tests/composite_pk/test_aggregate.py | 139 +++++++ tests/composite_pk/test_checks.py | 242 ++++++++++++ tests/composite_pk/test_create.py | 138 +++++++ tests/composite_pk/test_delete.py | 83 +++++ tests/composite_pk/test_filter.py | 412 +++++++++++++++++++++ tests/composite_pk/test_get.py | 126 +++++++ tests/composite_pk/test_models.py | 153 ++++++++ tests/composite_pk/test_names_to_path.py | 134 +++++++ tests/composite_pk/test_update.py | 135 +++++++ tests/composite_pk/test_values.py | 212 +++++++++++ tests/composite_pk/tests.py | 345 +++++++++++++++++ tests/migrations/test_autodetector.py | 89 +++++ tests/migrations/test_operations.py | 55 +++ tests/migrations/test_state.py | 22 ++ tests/migrations/test_writer.py | 19 + 43 files changed, 3078 insertions(+), 29 deletions(-) create mode 100644 django/db/models/fields/composite.py create mode 100644 docs/topics/composite-primary-key.txt create mode 100644 tests/composite_pk/__init__.py create mode 100644 tests/composite_pk/fixtures/tenant.json create mode 100644 tests/composite_pk/models/__init__.py create mode 100644 tests/composite_pk/models/tenant.py create mode 100644 tests/composite_pk/test_aggregate.py create mode 100644 tests/composite_pk/test_checks.py create mode 100644 tests/composite_pk/test_create.py create mode 100644 tests/composite_pk/test_delete.py create mode 100644 tests/composite_pk/test_filter.py create mode 100644 tests/composite_pk/test_get.py create mode 100644 tests/composite_pk/test_models.py create mode 100644 tests/composite_pk/test_names_to_path.py create mode 100644 tests/composite_pk/test_update.py create mode 100644 tests/composite_pk/test_values.py create mode 100644 tests/composite_pk/tests.py diff --git a/django/contrib/admin/sites.py b/django/contrib/admin/sites.py index 3399bd87b8..201f28ef37 100644 --- a/django/contrib/admin/sites.py +++ b/django/contrib/admin/sites.py @@ -113,6 +113,11 @@ class AdminSite: "The model %s is abstract, so it cannot be registered with admin." % model.__name__ ) + if model._meta.is_composite_pk: + raise ImproperlyConfigured( + "The model %s has a composite primary key, so it cannot be " + "registered with admin." % model.__name__ + ) if self.is_registered(model): registered_admin = str(self.get_model_admin(model)) diff --git a/django/core/serializers/python.py b/django/core/serializers/python.py index 46ef9f0771..57edebbb70 100644 --- a/django/core/serializers/python.py +++ b/django/core/serializers/python.py @@ -7,6 +7,7 @@ other serializers. from django.apps import apps from django.core.serializers import base from django.db import DEFAULT_DB_ALIAS, models +from django.db.models import CompositePrimaryKey from django.utils.encoding import is_protected_type @@ -39,6 +40,8 @@ class Serializer(base.Serializer): return data def _value_from_field(self, obj, field): + if isinstance(field, CompositePrimaryKey): + return [self._value_from_field(obj, f) for f in field] value = field.value_from_object(obj) # Protected types (i.e., primitives like None, numbers, dates, # and Decimals) are passed through as is. All other values are diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index 3e38c56d50..de4886837e 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -14,6 +14,7 @@ from django.db.backends.ddl_references import ( ) from django.db.backends.utils import names_digest, split_identifier, truncate_name from django.db.models import NOT_PROVIDED, Deferrable, Index +from django.db.models.fields.composite import CompositePrimaryKey from django.db.models.sql import Query from django.db.transaction import TransactionManagementError, atomic from django.utils import timezone @@ -106,6 +107,7 @@ class BaseDatabaseSchemaEditor: sql_check_constraint = "CHECK (%(check)s)" sql_delete_constraint = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_constraint = "CONSTRAINT %(name)s %(constraint)s" + sql_pk_constraint = "PRIMARY KEY (%(columns)s)" sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)" sql_delete_check = sql_delete_constraint @@ -282,6 +284,11 @@ class BaseDatabaseSchemaEditor: constraint.constraint_sql(model, self) for constraint in model._meta.constraints ) + + pk = model._meta.pk + if isinstance(pk, CompositePrimaryKey): + constraint_sqls.append(self._pk_constraint_sql(pk.columns)) + sql = self.sql_create_table % { "table": self.quote_name(model._meta.db_table), "definition": ", ".join( @@ -1999,6 +2006,11 @@ class BaseDatabaseSchemaEditor: result.append(name) return result + def _pk_constraint_sql(self, columns): + return self.sql_pk_constraint % { + "columns": ", ".join(self.quote_name(column) for column in columns) + } + def _delete_primary_key(self, model, strict=False): constraint_names = self._constraint_names(model, primary_key=True) if strict and len(constraint_names) != 1: diff --git a/django/db/backends/oracle/schema.py b/django/db/backends/oracle/schema.py index 0d70522a2a..ba3c4778d3 100644 --- a/django/db/backends/oracle/schema.py +++ b/django/db/backends/oracle/schema.py @@ -211,6 +211,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): return create_index def _is_identity_column(self, table_name, column_name): + if not column_name: + return False with self.connection.cursor() as cursor: cursor.execute( """ diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index c5b428fc67..6da9852282 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -6,7 +6,7 @@ from django.db import NotSupportedError from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.backends.ddl_references import Statement from django.db.backends.utils import strip_quotes -from django.db.models import NOT_PROVIDED, UniqueConstraint +from django.db.models import NOT_PROVIDED, CompositePrimaryKey, UniqueConstraint class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): @@ -104,6 +104,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): f.name: f.clone() if is_self_referential(f) else f for f in model._meta.local_concrete_fields } + + # Since CompositePrimaryKey is not a concrete field (column is None), + # it's not copied by default. + pk = model._meta.pk + if isinstance(pk, CompositePrimaryKey): + body[pk.name] = pk.clone() + # Since mapping might mix column names and default values, # its values must be already quoted. mapping = { @@ -296,6 +303,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # Special-case implicit M2M tables. if field.many_to_many and field.remote_field.through._meta.auto_created: self.create_model(field.remote_field.through) + elif isinstance(field, CompositePrimaryKey): + # If a CompositePrimaryKey field was added, the existing primary key field + # had to be altered too, resulting in an AddField, AlterField migration. + # The table cannot be re-created on AddField, it would result in a + # duplicate primary key error. + return elif ( # Primary keys and unique fields are not supported in ALTER TABLE # ADD COLUMN. diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index fe81d92d36..ec54b65240 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -38,6 +38,7 @@ from django.db.models.expressions import ( ) from django.db.models.fields import * # NOQA from django.db.models.fields import __all__ as fields_all +from django.db.models.fields.composite import CompositePrimaryKey from django.db.models.fields.files import FileField, ImageField from django.db.models.fields.generated import GeneratedField from django.db.models.fields.json import JSONField @@ -82,6 +83,7 @@ __all__ += [ "ProtectedError", "RestrictedError", "Case", + "CompositePrimaryKey", "Exists", "Expression", "ExpressionList", diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index bf94decab7..73f03a4916 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -3,7 +3,8 @@ Classes to represent the definitions of aggregate functions. """ from django.core.exceptions import FieldError, FullResultSet -from django.db.models.expressions import Case, Func, Star, Value, When +from django.db import NotSupportedError +from django.db.models.expressions import Case, ColPairs, Func, Star, Value, When from django.db.models.fields import IntegerField from django.db.models.functions.comparison import Coalesce from django.db.models.functions.mixins import ( @@ -174,6 +175,22 @@ class Count(Aggregate): raise ValueError("Star cannot be used with filter. Please specify a field.") super().__init__(expression, filter=filter, **extra) + def resolve_expression(self, *args, **kwargs): + result = super().resolve_expression(*args, **kwargs) + expr = result.source_expressions[0] + + # In case of composite primary keys, count the first column. + if isinstance(expr, ColPairs): + if self.distinct: + raise NotSupportedError( + "COUNT(DISTINCT) doesn't support composite primary keys" + ) + + cols = expr.get_cols() + return Count(cols[0], filter=result.filter) + + return result + class Max(Aggregate): function = "MAX" diff --git a/django/db/models/base.py b/django/db/models/base.py index 5b819b1406..a20e88749f 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -1,6 +1,7 @@ import copy import inspect import warnings +from collections import defaultdict from functools import partialmethod from itertools import chain @@ -30,6 +31,7 @@ from django.db.models import NOT_PROVIDED, ExpressionWrapper, IntegerField, Max, from django.db.models.constants import LOOKUP_SEP from django.db.models.deletion import CASCADE, Collector from django.db.models.expressions import DatabaseDefault +from django.db.models.fields.composite import CompositePrimaryKey from django.db.models.fields.related import ( ForeignObjectRel, OneToOneField, @@ -508,7 +510,7 @@ class Model(AltersData, metaclass=ModelBase): for field in fields_iter: is_related_object = False # Virtual field - if field.attname not in kwargs and field.column is None or field.generated: + if field.column is None or field.generated: continue if kwargs: if isinstance(field.remote_field, ForeignObjectRel): @@ -663,7 +665,11 @@ class Model(AltersData, metaclass=ModelBase): pk = property(_get_pk_val, _set_pk_val) def _is_pk_set(self, meta=None): - return self._get_pk_val(meta) is not None + pk_val = self._get_pk_val(meta) + return not ( + pk_val is None + or (isinstance(pk_val, tuple) and any(f is None for f in pk_val)) + ) def get_deferred_fields(self): """ @@ -1454,6 +1460,11 @@ class Model(AltersData, metaclass=ModelBase): name = f.name if name in exclude: continue + if isinstance(f, CompositePrimaryKey): + names = tuple(field.name for field in f.fields) + if exclude.isdisjoint(names): + unique_checks.append((model_class, names)) + continue if f.unique: unique_checks.append((model_class, (name,))) if f.unique_for_date and f.unique_for_date not in exclude: @@ -1728,6 +1739,7 @@ class Model(AltersData, metaclass=ModelBase): *cls._check_constraints(databases), *cls._check_default_pk(), *cls._check_db_table_comment(databases), + *cls._check_composite_pk(), ] return errors @@ -1764,6 +1776,63 @@ class Model(AltersData, metaclass=ModelBase): ] return [] + @classmethod + def _check_composite_pk(cls): + errors = [] + meta = cls._meta + pk = meta.pk + + if not isinstance(pk, CompositePrimaryKey): + return errors + + seen_columns = defaultdict(list) + + for field_name in pk.field_names: + hint = None + + try: + field = meta.get_field(field_name) + except FieldDoesNotExist: + field = None + + if not field: + hint = f"{field_name!r} is not a valid field." + elif not field.column: + hint = f"{field_name!r} field has no column." + elif field.null: + hint = f"{field_name!r} field may not set 'null=True'." + elif field.generated: + hint = f"{field_name!r} field is a generated field." + else: + seen_columns[field.column].append(field_name) + + if hint: + errors.append( + checks.Error( + f"{field_name!r} cannot be included in the composite primary " + "key.", + hint=hint, + obj=cls, + id="models.E042", + ) + ) + + for column, field_names in seen_columns.items(): + if len(field_names) > 1: + field_name, *rest = field_names + duplicates = ", ".join(repr(field) for field in rest) + errors.append( + checks.Error( + f"{duplicates} cannot be included in the composite primary " + "key.", + hint=f"{duplicates} and {field_name!r} are the same fields.", + obj=cls, + id="models.E042", + ) + ) + + return errors + @classmethod def _check_db_table_comment(cls, databases): if not cls._meta.db_table_comment: diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index f9cafdb4bb..855e8cc28d 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -656,6 +656,8 @@ class Field(RegisterLookupMixin): path = path.replace("django.db.models.fields.json", "django.db.models") elif path.startswith("django.db.models.fields.proxy"): path = path.replace("django.db.models.fields.proxy", "django.db.models") + elif path.startswith("django.db.models.fields.composite"): + path = path.replace("django.db.models.fields.composite", "django.db.models") elif path.startswith("django.db.models.fields"): path = path.replace("django.db.models.fields", "django.db.models") # Return basic info - other fields should override this. diff --git a/django/db/models/fields/composite.py b/django/db/models/fields/composite.py new file mode 100644 index 0000000000..550a440dcf --- /dev/null +++ b/django/db/models/fields/composite.py @@ -0,0 +1,150 @@ +from django.core import checks +from django.db.models import NOT_PROVIDED, Field +from django.db.models.expressions import ColPairs +from django.db.models.fields.tuple_lookups import ( + TupleExact, + TupleGreaterThan, + TupleGreaterThanOrEqual, + TupleIn, + TupleIsNull, + TupleLessThan, + TupleLessThanOrEqual, +) +from django.utils.functional import cached_property + + +class CompositeAttribute: + def __init__(self, field): + self.field = field + + @property + def attnames(self): + return [field.attname for field in self.field.fields] + + def __get__(self, instance, cls=None): + return tuple(getattr(instance, attname) for attname in self.attnames) + + def __set__(self, instance, values): + attnames = self.attnames + length = len(attnames) + + if values is None: + values = (None,) * length + + if not isinstance(values, (list, tuple)): + raise ValueError(f"{self.field.name!r} must be a list or a tuple.") + if length != len(values): + raise ValueError(f"{self.field.name!r} must have {length} elements.") + + for attname, value in zip(attnames, values): + setattr(instance, attname, value) + + +class CompositePrimaryKey(Field): + descriptor_class = CompositeAttribute + + def __init__(self, *args, **kwargs): + if ( + not args + or not all(isinstance(field, str) for field in args) + or len(set(args)) != len(args) + ): + raise ValueError("CompositePrimaryKey args must be unique strings.") + if len(args) == 1: + raise ValueError("CompositePrimaryKey must include at least two fields.") + if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED: + raise ValueError("CompositePrimaryKey cannot have a default.") + if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED: + raise ValueError("CompositePrimaryKey cannot have a database default.") + if kwargs.setdefault("editable", False): + raise ValueError("CompositePrimaryKey cannot be editable.") + if not kwargs.setdefault("primary_key", True): + raise ValueError("CompositePrimaryKey must be a primary key.") + if not kwargs.setdefault("blank", True): + raise ValueError("CompositePrimaryKey must be blank.") + + self.field_names = args + super().__init__(**kwargs) + + def deconstruct(self): + # args is always [] so it can be ignored. + name, path, _, kwargs = super().deconstruct() + return name, path, self.field_names, kwargs + + @cached_property + def fields(self): + meta = self.model._meta + return tuple(meta.get_field(field_name) for field_name in self.field_names) + + @cached_property + def columns(self): + return tuple(field.column for field in self.fields) + + def contribute_to_class(self, cls, name, private_only=False): + super().contribute_to_class(cls, name, private_only=private_only) + cls._meta.pk = self + setattr(cls, self.attname, self.descriptor_class(self)) + + def get_attname_column(self): + return self.get_attname(), None + + def __iter__(self): + return iter(self.fields) + + def __len__(self): + return len(self.field_names) + + @cached_property + def cached_col(self): + return ColPairs(self.model._meta.db_table, self.fields, self.fields, self) + + def get_col(self, alias, output_field=None): + if alias == self.model._meta.db_table and ( + output_field is None or output_field == self + ): + return self.cached_col + + return ColPairs(alias, self.fields, self.fields, output_field) + + def get_pk_value_on_save(self, instance): + values = [] + + for field in self.fields: + value = field.value_from_object(instance) + if value is None: + value = field.get_pk_value_on_save(instance) + values.append(value) + + return tuple(values) + + def _check_field_name(self): + if self.name == "pk": + return [] + return [ + checks.Error( + "'CompositePrimaryKey' must be named 'pk'.", + obj=self, + id="fields.E013", + ) + ] + + +CompositePrimaryKey.register_lookup(TupleExact) +CompositePrimaryKey.register_lookup(TupleGreaterThan) +CompositePrimaryKey.register_lookup(TupleGreaterThanOrEqual) +CompositePrimaryKey.register_lookup(TupleLessThan) +CompositePrimaryKey.register_lookup(TupleLessThanOrEqual) +CompositePrimaryKey.register_lookup(TupleIn) +CompositePrimaryKey.register_lookup(TupleIsNull) + + +def unnest(fields): + result = [] + + for field in fields: + if isinstance(field, CompositePrimaryKey): + result.extend(field.fields) + else: + result.append(field) + + return result diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index b672a4b488..9ef2d29024 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -624,11 +624,21 @@ class ForeignObject(RelatedField): if not has_unique_constraint: foreign_fields = {f.name for f in self.foreign_related_fields} remote_opts = self.remote_field.model._meta - has_unique_constraint = any( - frozenset(ut) <= foreign_fields for ut in remote_opts.unique_together - ) or any( - frozenset(uc.fields) <= foreign_fields - for uc in remote_opts.total_unique_constraints + has_unique_constraint = ( + any( + frozenset(ut) <= foreign_fields + for ut in remote_opts.unique_together + ) + or any( + frozenset(uc.fields) <= foreign_fields + for uc in remote_opts.total_unique_constraints + ) + # If the model defines a composite primary key and the foreign key + # refers to it, the target is unique. + or ( + frozenset(field.name for field in remote_opts.pk_fields) + == foreign_fields + ) ) if not has_unique_constraint: diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py index 8b5663dfea..6992d75833 100644 --- a/django/db/models/fields/related_lookups.py +++ b/django/db/models/fields/related_lookups.py @@ -1,5 +1,6 @@ from django.db import NotSupportedError from django.db.models.expressions import ColPairs +from django.db.models.fields import composite from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups from django.db.models.lookups import ( Exact, @@ -19,7 +20,7 @@ def get_normalized_value(value, lhs): if not value._is_pk_set(): raise ValueError("Model instances passed to related filters must be saved.") value_list = [] - sources = lhs.output_field.path_infos[-1].target_fields + sources = composite.unnest(lhs.output_field.path_infos[-1].target_fields) for source in sources: while not isinstance(value, source.model) and source.remote_field: source = source.remote_field.model._meta.get_field( @@ -30,7 +31,8 @@ def get_normalized_value(value, lhs): except AttributeError: # A case like Restaurant.objects.filter(place=restaurant_instance), # where place is a OneToOneField and the primary key of Restaurant. - return (value.pk,) + pk = value.pk + return pk if isinstance(pk, tuple) else (pk,) return tuple(value_list) if not isinstance(value, tuple): return (value,) diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py index 6342937cd6..e515e971b4 100644 --- a/django/db/models/fields/tuple_lookups.py +++ b/django/db/models/fields/tuple_lookups.py @@ -250,6 +250,8 @@ class TupleIn(TupleLookupMixin, In): def check_rhs_select_length_equals_lhs_length(self): len_rhs = len(self.rhs.select) + if len_rhs == 1 and isinstance(self.rhs.select[0], ColPairs): + len_rhs = len(self.rhs.select[0]) len_lhs = len(self.lhs) if len_rhs != len_lhs: lhs_str = self.get_lhs_str() @@ -304,7 +306,13 @@ class TupleIn(TupleLookupMixin, In): return root.as_sql(compiler, connection) def as_subquery(self, compiler, connection): - return compiler.compile(In(self.lhs, self.rhs)) + lhs = self.lhs + rhs = self.rhs + if isinstance(lhs, ColPairs): + rhs = rhs.clone() + rhs.set_values([source.name for source in lhs.sources]) + lhs = Tuple(lhs) + return compiler.compile(In(lhs, rhs)) tuple_lookups = { diff --git a/django/db/models/options.py b/django/db/models/options.py index 68a7228cbe..7c4cf2229a 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -7,7 +7,14 @@ from django.conf import settings from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured from django.core.signals import setting_changed from django.db import connections -from django.db.models import AutoField, Manager, OrderWrt, UniqueConstraint +from django.db.models import ( + AutoField, + CompositePrimaryKey, + Manager, + OrderWrt, + UniqueConstraint, +) +from django.db.models.fields import composite from django.db.models.query_utils import PathInfo from django.utils.datastructures import ImmutableList, OrderedSet from django.utils.functional import cached_property @@ -973,6 +980,14 @@ class Options: ) ] + @cached_property + def pk_fields(self): + return composite.unnest([self.pk]) + + @property + def is_composite_pk(self): + return isinstance(self.pk, CompositePrimaryKey) + @cached_property def _property_names(self): """Return a set of the names of the properties defined on the model.""" diff --git a/django/db/models/query.py b/django/db/models/query.py index 21d5534cc9..ea8cc179f3 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -171,11 +171,14 @@ class RawModelIterable(BaseIterable): "Raw query must include the primary key" ) fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns] - converters = compiler.get_converters( - [f.get_col(f.model._meta.db_table) if f else None for f in fields] - ) + cols = [f.get_col(f.model._meta.db_table) if f else None for f in fields] + converters = compiler.get_converters(cols) if converters: query_iterator = compiler.apply_converters(query_iterator, converters) + if compiler.has_composite_fields(cols): + query_iterator = compiler.composite_fields_to_tuples( + query_iterator, cols + ) for values in query_iterator: # Associate fields to values model_init_values = [values[pos] for pos in model_init_pos] diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 49263d5944..053bdc09d5 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -7,7 +7,9 @@ from itertools import chain from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet from django.db import DatabaseError, NotSupportedError from django.db.models.constants import LOOKUP_SEP -from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value +from django.db.models.expressions import ColPairs, F, OrderBy, RawSQL, Ref, Value +from django.db.models.fields import composite +from django.db.models.fields.composite import CompositePrimaryKey from django.db.models.functions import Cast, Random from django.db.models.lookups import Lookup from django.db.models.query_utils import select_related_descend @@ -283,6 +285,9 @@ class SQLCompiler: # Reference to a column. elif isinstance(expression, int): expression = cols[expression] + # ColPairs cannot be aliased. + if isinstance(expression, ColPairs): + alias = None selected.append((alias, expression)) for select_idx, (alias, expression) in enumerate(selected): @@ -997,6 +1002,7 @@ class SQLCompiler: # alias for a given field. This also includes None -> start_alias to # be used by local fields. seen_models = {None: start_alias} + select_mask_fields = set(composite.unnest(select_mask)) for field in opts.concrete_fields: model = field.model._meta.concrete_model @@ -1017,7 +1023,7 @@ class SQLCompiler: # parent model data is already present in the SELECT clause, # and we want to avoid reloading the same data again. continue - if select_mask and field not in select_mask: + if select_mask and field not in select_mask_fields: continue alias = self.query.join_parent_model(opts, model, start_alias, seen_models) column = field.get_col(alias) @@ -1110,9 +1116,10 @@ class SQLCompiler: ) return results targets, alias, _ = self.query.trim_joins(targets, joins, path) + target_fields = composite.unnest(targets) return [ (OrderBy(transform_function(t, alias), descending=descending), False) - for t in targets + for t in target_fields ] def _setup_joins(self, pieces, opts, alias): @@ -1504,13 +1511,25 @@ class SQLCompiler: return result def get_converters(self, expressions): + i = 0 converters = {} - for i, expression in enumerate(expressions): - if expression: + + for expression in expressions: + if isinstance(expression, ColPairs): + cols = expression.get_source_expressions() + cols_converters = self.get_converters(cols) + for j, (convs, col) in cols_converters.items(): + converters[i + j] = (convs, col) + i += len(expression) + elif expression: backend_converters = self.connection.ops.get_db_converters(expression) field_converters = expression.get_db_converters(self.connection) if backend_converters or field_converters: converters[i] = (backend_converters + field_converters, expression) + i += 1 + else: + i += 1 + return converters def apply_converters(self, rows, converters): @@ -1524,6 +1543,24 @@ class SQLCompiler: row[pos] = value yield row + def has_composite_fields(self, expressions): + # Check for composite fields before calling the relatively costly + # composite_fields_to_tuples. + return any(isinstance(expression, ColPairs) for expression in expressions) + + def composite_fields_to_tuples(self, rows, expressions): + col_pair_slices = [ + slice(i, i + len(expression)) + for i, expression in enumerate(expressions) + if isinstance(expression, ColPairs) + ] + + for row in map(list, rows): + for pos in col_pair_slices: + row[pos] = (tuple(row[pos]),) + + yield row + def results_iter( self, results=None, @@ -1541,8 +1578,10 @@ class SQLCompiler: rows = chain.from_iterable(results) if converters: rows = self.apply_converters(rows, converters) - if tuple_expected: - rows = map(tuple, rows) + if self.has_composite_fields(fields): + rows = self.composite_fields_to_tuples(rows, fields) + if tuple_expected: + rows = map(tuple, rows) return rows def has_results(self): @@ -1863,6 +1902,18 @@ class SQLInsertCompiler(SQLCompiler): ) ] cols = [field.get_col(opts.db_table) for field in self.returning_fields] + elif isinstance(opts.pk, CompositePrimaryKey): + returning_field = returning_fields[0] + cols = [returning_field.get_col(opts.db_table)] + rows = [ + ( + self.connection.ops.last_insert_id( + cursor, + opts.db_table, + returning_field.column, + ), + ) + ] else: cols = [opts.pk.get_col(opts.db_table)] rows = [ @@ -1876,8 +1927,10 @@ class SQLInsertCompiler(SQLCompiler): ] converters = self.get_converters(cols) if converters: - rows = list(self.apply_converters(rows, converters)) - return rows + rows = self.apply_converters(rows, converters) + if self.has_composite_fields(cols): + rows = self.composite_fields_to_tuples(rows, cols) + return list(rows) class SQLDeleteCompiler(SQLCompiler): @@ -2065,6 +2118,7 @@ class SQLUpdateCompiler(SQLCompiler): query.add_fields(fields) super().pre_sql_setup() + is_composite_pk = meta.is_composite_pk must_pre_select = ( count > 1 and not self.connection.features.update_can_self_select ) @@ -2079,7 +2133,8 @@ class SQLUpdateCompiler(SQLCompiler): idents = [] related_ids = collections.defaultdict(list) for rows in query.get_compiler(self.using).execute_sql(MULTI): - idents.extend(r[0] for r in rows) + pks = [row if is_composite_pk else row[0] for row in rows] + idents.extend(pks) for parent, index in related_ids_index: related_ids[parent].extend(r[index] for r in rows) self.query.add_filter("pk__in", idents) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index b7b93c235a..cca11bfcc2 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -627,8 +627,12 @@ class Query(BaseExpression): if result is None: result = empty_set_result else: - converters = compiler.get_converters(outer_query.annotation_select.values()) - result = next(compiler.apply_converters((result,), converters)) + cols = outer_query.annotation_select.values() + converters = compiler.get_converters(cols) + rows = compiler.apply_converters((result,), converters) + if compiler.has_composite_fields(cols): + rows = compiler.composite_fields_to_tuples(rows, cols) + result = next(rows) return dict(zip(outer_query.annotation_select, result)) diff --git a/docs/ref/checks.txt b/docs/ref/checks.txt index 2308a854c7..b0a98bde28 100644 --- a/docs/ref/checks.txt +++ b/docs/ref/checks.txt @@ -181,6 +181,7 @@ Model fields * **fields.E011**: ```` does not support default database values with expressions (``db_default``). * **fields.E012**: ```` cannot be used in ``db_default``. +* **fields.E013**: ``CompositePrimaryKey`` must be named ``pk``. * **fields.E100**: ``AutoField``\s must set primary_key=True. * **fields.E110**: ``BooleanField``\s do not accept null values. *This check appeared before support for null values was added in Django 2.1.* @@ -417,6 +418,8 @@ Models * **models.W040**: ```` does not support indexes with non-key columns. * **models.E041**: ``constraints`` refers to the joined field ````. +* **models.E042**: ```` cannot be included in the composite + primary key. * **models.W042**: Auto-created primary key used when not defining a primary key type, by default ``django.db.models.AutoField``. * **models.W043**: ```` does not support indexes on expressions. diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index 07e86785d9..5b0f127c6f 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -707,6 +707,23 @@ or :class:`~django.forms.NullBooleanSelect` if :attr:`null=True `. The default value of ``BooleanField`` is ``None`` when :attr:`Field.default` isn't defined. +``CompositePrimaryKey`` +----------------------- + +.. versionadded:: 5.2 + +.. class:: CompositePrimaryKey(*field_names, **options) + +A virtual field used for defining a composite primary key. + +This field must be defined as the model's ``pk`` field. If present, Django will +create the underlying model table with a composite primary key. + +The ``*field_names`` argument is a list of positional field names that compose +the primary key. + +See :doc:`/topics/composite-primary-key` for more details. + ``CharField`` ------------- @@ -1615,6 +1632,8 @@ not an instance of ``UUID``. hyphens, because PostgreSQL and MariaDB 10.7+ store them in a hyphenated uuid datatype type. +.. _relationship-fields: + Relationship fields =================== diff --git a/docs/releases/5.2.txt b/docs/releases/5.2.txt index 7af0b955f6..4b05fd3279 100644 --- a/docs/releases/5.2.txt +++ b/docs/releases/5.2.txt @@ -31,6 +31,25 @@ and only officially support the latest release of each series. What's new in Django 5.2 ======================== +Composite Primary Keys +---------------------- + +The new :class:`django.db.models.CompositePrimaryKey` allows tables to be +created with a primary key consisting of multiple fields. + +To use a composite primary key, when creating a model set the ``pk`` field to +be a ``CompositePrimaryKey``:: + + from django.db import models + + + class Release(models.Model): + pk = models.CompositePrimaryKey("version", "name") + version = models.IntegerField() + name = models.CharField(max_length=20) + +See :doc:`/topics/composite-primary-key` for more details. + Minor features -------------- diff --git a/docs/topics/composite-primary-key.txt b/docs/topics/composite-primary-key.txt new file mode 100644 index 0000000000..9e5234ca9f --- /dev/null +++ b/docs/topics/composite-primary-key.txt @@ -0,0 +1,183 @@ +====================== +Composite primary keys +====================== + +.. versionadded:: 5.2 + +In Django, each model has a primary key. By default, this primary key consists +of a single field. + +In most cases, a single primary key should suffice. In database design, +however, defining a primary key consisting of multiple fields is sometimes +necessary. + +To use a composite primary key, when creating a model set the ``pk`` field to +be a :class:`.CompositePrimaryKey`:: + + class Product(models.Model): + name = models.CharField(max_length=100) + + + class Order(models.Model): + reference = models.CharField(max_length=20, primary_key=True) + + + class OrderLineItem(models.Model): + pk = models.CompositePrimaryKey("product_id", "order_id") + product = models.ForeignKey(Product, on_delete=models.CASCADE) + order = models.ForeignKey(Order, on_delete=models.CASCADE) + quantity = models.IntegerField() + +This will instruct Django to create a composite primary key +(``PRIMARY KEY (product_id, order_id)``) when creating the table. + +A composite primary key is represented by a ``tuple``: + +.. code-block:: pycon + + >>> product = Product.objects.create(name="apple") + >>> order = Order.objects.create(reference="A755H") + >>> item = OrderLineItem.objects.create(product=product, order=order, quantity=1) + >>> item.pk + (1, "A755H") + +You can assign a ``tuple`` to a composite primary key. This sets the associated +field values. + +.. code-block:: pycon + + >>> item = OrderLineItem(pk=(2, "B142C")) + >>> item.pk + (2, "B142C") + >>> item.product_id + 2 + >>> item.order_id + "B142C" + +A composite primary key can also be filtered by a ``tuple``: + +.. code-block:: pycon + + >>> OrderLineItem.objects.filter(pk=(1, "A755H")).count() + 1 + +We're still working on composite primary key support for +:ref:`relational fields `, including +:class:`.GenericForeignKey` fields, and the Django admin. Models with composite +primary keys cannot be registered in the Django admin at this time. You can +expect to see this in future releases. + +Migrating to a composite primary key +==================================== + +Django doesn't support migrating to, or from, a composite primary key after the +table is created. It also doesn't support adding or removing fields from the +composite primary key. + +If you would like to migrate an existing table from a single primary key to a +composite primary key, follow your database backend's instructions to do so. + +Once the composite primary key is in place, add the ``CompositePrimaryKey`` +field to your model. This allows Django to recognize and handle the composite +primary key appropriately. + +While migration operations (e.g. ``AddField``, ``AlterField``) on primary key +fields are not supported, ``makemigrations`` will still detect changes. + +In order to avoid errors, it's recommended to apply such migrations with +``--fake``. + +Alternatively, :class:`.SeparateDatabaseAndState` may be used to execute the +backend-specific migrations and Django-generated migrations in a single +operation. + +.. _cpk-and-relations: + +Composite primary keys and relations +==================================== + +:ref:`Relationship fields `, including +:ref:`generic relations ` do not support composite primary +keys. + +For example, given the ``OrderLineItem`` model, the following is not +supported:: + + class Foo(models.Model): + item = models.ForeignKey(OrderLineItem, on_delete=models.CASCADE) + +Because ``ForeignKey`` currently cannot reference models with composite primary +keys. + +To work around this limitation, ``ForeignObject`` can be used as an +alternative:: + + class Foo(models.Model): + item_order_id = models.IntegerField() + item_product_id = models.CharField(max_length=20) + item = models.ForeignObject( + OrderLineItem, + on_delete=models.CASCADE, + from_fields=("item_order_id", "item_product_id"), + to_fields=("order_id", "product_id"), + ) + +``ForeignObject`` is much like ``ForeignKey``, except that it doesn't create +any columns (e.g. ``item_id``), foreign key constraints or indexes in the +database. + +.. warning:: + + ``ForeignObject`` is an internal API. This means it is not covered by our + :ref:`deprecation policy `. + +Composite primary keys and database functions +============================================= + +Many database functions only accept a single expression. + +.. code-block:: sql + + MAX("order_id") -- OK + MAX("product_id", "order_id") -- ERROR + +As a consequence, they cannot be used with composite primary key references as +they are composed of multiple column expressions. + +.. code-block:: python + + Max("order_id") # OK + Max("pk") # ERROR + +Composite primary keys in forms +=============================== + +As a composite primary key is a virtual field, a field which doesn't represent +a single database column, this field is excluded from ModelForms. + +For example, take the following form:: + + class OrderLineItemForm(forms.ModelForm): + class Meta: + model = OrderLineItem + fields = "__all__" + +This form does not have a form field ``pk`` for the composite primary key: + +.. code-block:: pycon + + >>> OrderLineItemForm() + + +Setting the primary composite field ``pk`` as a form field raises an unknown +field :exc:`.FieldError`. + +.. admonition:: Primary key fields are read only + + If you change the value of a primary key on an existing object and then + save it, a new object will be created alongside the old one (see + :attr:`.Field.primary_key`). + + This is also true of composite primary keys. Hence, you may want to set + :attr:`.Field.editable` to ``False`` on all primary key fields to exclude + them from ModelForms. diff --git a/docs/topics/index.txt b/docs/topics/index.txt index ffb9fa9d92..4f837c81e2 100644 --- a/docs/topics/index.txt +++ b/docs/topics/index.txt @@ -19,6 +19,7 @@ Introductions to all the key parts of Django you'll need to know: auth/index cache conditional-view-processing + composite-primary-key signing email i18n/index diff --git a/tests/admin_registration/models.py b/tests/admin_registration/models.py index 0ae9251133..2231c236de 100644 --- a/tests/admin_registration/models.py +++ b/tests/admin_registration/models.py @@ -20,3 +20,9 @@ class Location(models.Model): class Place(Location): name = models.CharField(max_length=200) + + +class Guest(models.Model): + pk = models.CompositePrimaryKey("traveler", "place") + traveler = models.ForeignKey(Traveler, on_delete=models.CASCADE) + place = models.ForeignKey(Place, on_delete=models.CASCADE) diff --git a/tests/admin_registration/tests.py b/tests/admin_registration/tests.py index 3b0e656f5f..0a881caf65 100644 --- a/tests/admin_registration/tests.py +++ b/tests/admin_registration/tests.py @@ -5,7 +5,7 @@ from django.contrib.admin.sites import site from django.core.exceptions import ImproperlyConfigured from django.test import SimpleTestCase -from .models import Location, Person, Place, Traveler +from .models import Guest, Location, Person, Place, Traveler class NameAdmin(admin.ModelAdmin): @@ -92,6 +92,14 @@ class TestRegistration(SimpleTestCase): with self.assertRaisesMessage(ImproperlyConfigured, msg): self.site.register(Location) + def test_composite_pk_model(self): + msg = ( + "The model Guest has a composite primary key, so it cannot be registered " + "with admin." + ) + with self.assertRaisesMessage(ImproperlyConfigured, msg): + self.site.register(Guest) + def test_is_registered_model(self): "Checks for registered models should return true." self.site.register(Person) diff --git a/tests/composite_pk/__init__.py b/tests/composite_pk/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/composite_pk/fixtures/tenant.json b/tests/composite_pk/fixtures/tenant.json new file mode 100644 index 0000000000..3eeff42fef --- /dev/null +++ b/tests/composite_pk/fixtures/tenant.json @@ -0,0 +1,75 @@ +[ + { + "pk": 1, + "model": "composite_pk.tenant", + "fields": { + "id": 1, + "name": "Tenant 1" + } + }, + { + "pk": 2, + "model": "composite_pk.tenant", + "fields": { + "id": 2, + "name": "Tenant 2" + } + }, + { + "pk": 3, + "model": "composite_pk.tenant", + "fields": { + "id": 3, + "name": "Tenant 3" + } + }, + { + "pk": [1, 1], + "model": "composite_pk.user", + "fields": { + "tenant_id": 1, + "id": 1, + "email": "user0001@example.com" + } + }, + { + "pk": [1, 2], + "model": "composite_pk.user", + "fields": { + "tenant_id": 1, + "id": 2, + "email": "user0002@example.com" + } + }, + { + "pk": [2, 3], + "model": "composite_pk.user", + "fields": { + "email": "user0003@example.com" + } + }, + { + "model": "composite_pk.user", + "fields": { + "tenant_id": 2, + "id": 4, + "email": "user0004@example.com" + } + }, + { + "pk": [2, "11111111-1111-1111-1111-111111111111"], + "model": "composite_pk.post", + "fields": { + "tenant_id": 2, + "id": "11111111-1111-1111-1111-111111111111" + } + }, + { + "pk": [2, "ffffffff-ffff-ffff-ffff-ffffffffffff"], + "model": "composite_pk.post", + "fields": { + "tenant_id": 2, + "id": "ffffffff-ffff-ffff-ffff-ffffffffffff" + } + } +] diff --git a/tests/composite_pk/models/__init__.py b/tests/composite_pk/models/__init__.py new file mode 100644 index 0000000000..35c3943716 --- /dev/null +++ b/tests/composite_pk/models/__init__.py @@ -0,0 +1,9 @@ +from .tenant import Comment, Post, Tenant, Token, User + +__all__ = [ + "Comment", + "Post", + "Tenant", + "Token", + "User", +] diff --git a/tests/composite_pk/models/tenant.py b/tests/composite_pk/models/tenant.py new file mode 100644 index 0000000000..ac0b3d9715 --- /dev/null +++ b/tests/composite_pk/models/tenant.py @@ -0,0 +1,50 @@ +from django.db import models + + +class Tenant(models.Model): + name = models.CharField(max_length=10, default="", blank=True) + + +class Token(models.Model): + pk = models.CompositePrimaryKey("tenant_id", "id") + tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE, related_name="tokens") + id = models.SmallIntegerField() + secret = models.CharField(max_length=10, default="", blank=True) + + +class BaseModel(models.Model): + pk = models.CompositePrimaryKey("tenant_id", "id") + tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) + id = models.SmallIntegerField(unique=True) + + class Meta: + abstract = True + + +class User(BaseModel): + email = models.EmailField(unique=True) + + +class Comment(models.Model): + pk = models.CompositePrimaryKey("tenant", "id") + tenant = models.ForeignKey( + Tenant, + on_delete=models.CASCADE, + related_name="comments", + ) + id = models.SmallIntegerField(unique=True, db_column="comment_id") + user_id = models.SmallIntegerField() + user = models.ForeignObject( + User, + on_delete=models.CASCADE, + from_fields=("tenant_id", "user_id"), + to_fields=("tenant_id", "id"), + related_name="comments", + ) + text = models.TextField(default="", blank=True) + + +class Post(models.Model): + pk = models.CompositePrimaryKey("tenant_id", "id") + tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) + id = models.UUIDField() diff --git a/tests/composite_pk/test_aggregate.py b/tests/composite_pk/test_aggregate.py new file mode 100644 index 0000000000..b5474c5218 --- /dev/null +++ b/tests/composite_pk/test_aggregate.py @@ -0,0 +1,139 @@ +from django.db import NotSupportedError +from django.db.models import Count, Q +from django.test import TestCase + +from .models import Comment, Tenant, User + + +class CompositePKAggregateTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.tenant_1 = Tenant.objects.create() + cls.tenant_2 = Tenant.objects.create() + cls.user_1 = User.objects.create( + tenant=cls.tenant_1, + id=1, + email="user0001@example.com", + ) + cls.user_2 = User.objects.create( + tenant=cls.tenant_1, + id=2, + email="user0002@example.com", + ) + cls.user_3 = User.objects.create( + tenant=cls.tenant_2, + id=3, + email="user0003@example.com", + ) + cls.comment_1 = Comment.objects.create(id=1, user=cls.user_2, text="foo") + cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1, text="bar") + cls.comment_3 = Comment.objects.create(id=3, user=cls.user_1, text="foobar") + cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3, text="foobarbaz") + cls.comment_5 = Comment.objects.create(id=5, user=cls.user_3, text="barbaz") + cls.comment_6 = Comment.objects.create(id=6, user=cls.user_3, text="baz") + + def test_users_annotated_with_comments_id_count(self): + user_1, user_2, user_3 = User.objects.annotate(Count("comments__id")).order_by( + "pk" + ) + + self.assertEqual(user_1, self.user_1) + self.assertEqual(user_1.comments__id__count, 2) + self.assertEqual(user_2, self.user_2) + self.assertEqual(user_2.comments__id__count, 1) + self.assertEqual(user_3, self.user_3) + self.assertEqual(user_3.comments__id__count, 3) + + def test_users_annotated_with_aliased_comments_id_count(self): + user_1, user_2, user_3 = User.objects.annotate( + comments_count=Count("comments__id") + ).order_by("pk") + + self.assertEqual(user_1, self.user_1) + self.assertEqual(user_1.comments_count, 2) + self.assertEqual(user_2, self.user_2) + self.assertEqual(user_2.comments_count, 1) + self.assertEqual(user_3, self.user_3) + self.assertEqual(user_3.comments_count, 3) + + def test_users_annotated_with_comments_count(self): + user_1, user_2, user_3 = User.objects.annotate(Count("comments")).order_by("pk") + + self.assertEqual(user_1, self.user_1) + self.assertEqual(user_1.comments__count, 2) + self.assertEqual(user_2, self.user_2) + self.assertEqual(user_2.comments__count, 1) + self.assertEqual(user_3, self.user_3) + self.assertEqual(user_3.comments__count, 3) + + def test_users_annotated_with_comments_count_filter(self): + user_1, user_2, user_3 = User.objects.annotate( + comments__count=Count( + "comments", filter=Q(pk__in=[self.user_1.pk, self.user_2.pk]) + ) + ).order_by("pk") + + self.assertEqual(user_1, self.user_1) + self.assertEqual(user_1.comments__count, 2) + self.assertEqual(user_2, self.user_2) + self.assertEqual(user_2.comments__count, 1) + self.assertEqual(user_3, self.user_3) + self.assertEqual(user_3.comments__count, 0) + + def test_count_distinct_not_supported(self): + with self.assertRaisesMessage( + NotSupportedError, "COUNT(DISTINCT) doesn't support composite primary keys" + ): + self.assertIsNone( + User.objects.annotate(comments__count=Count("comments", distinct=True)) + ) + + def test_user_values_annotated_with_comments_id_count(self): + self.assertSequenceEqual( + User.objects.values("pk").annotate(Count("comments__id")).order_by("pk"), + ( + {"pk": self.user_1.pk, "comments__id__count": 2}, + {"pk": self.user_2.pk, "comments__id__count": 1}, + {"pk": self.user_3.pk, "comments__id__count": 3}, + ), + ) + + def test_user_values_annotated_with_filtered_comments_id_count(self): + self.assertSequenceEqual( + User.objects.values("pk") + .annotate( + comments_count=Count( + "comments__id", + filter=Q(comments__text__icontains="foo"), + ) + ) + .order_by("pk"), + ( + {"pk": self.user_1.pk, "comments_count": 1}, + {"pk": self.user_2.pk, "comments_count": 1}, + {"pk": self.user_3.pk, "comments_count": 1}, + ), + ) + + def test_filter_and_count_users_by_comments_fields(self): + users = User.objects.filter(comments__id__gt=2).order_by("pk") + self.assertEqual(users.count(), 4) + self.assertSequenceEqual( + users, (self.user_1, self.user_3, self.user_3, self.user_3) + ) + + users = User.objects.filter(comments__text__icontains="foo").order_by("pk") + self.assertEqual(users.count(), 3) + self.assertSequenceEqual(users, (self.user_1, self.user_2, self.user_3)) + + users = User.objects.filter(comments__text__icontains="baz").order_by("pk") + self.assertEqual(users.count(), 3) + self.assertSequenceEqual(users, (self.user_3, self.user_3, self.user_3)) + + def test_order_by_comments_id_count(self): + self.assertSequenceEqual( + User.objects.annotate(comments_count=Count("comments__id")).order_by( + "-comments_count" + ), + (self.user_3, self.user_1, self.user_2), + ) diff --git a/tests/composite_pk/test_checks.py b/tests/composite_pk/test_checks.py new file mode 100644 index 0000000000..02a162c31d --- /dev/null +++ b/tests/composite_pk/test_checks.py @@ -0,0 +1,242 @@ +from django.core import checks +from django.db import connection, models +from django.db.models import F +from django.test import TestCase +from django.test.utils import isolate_apps + + +@isolate_apps("composite_pk") +class CompositePKChecksTests(TestCase): + maxDiff = None + + def test_composite_pk_must_be_unique_strings(self): + test_cases = ( + (), + (0,), + (1,), + ("id", False), + ("id", "id"), + (("id",),), + ) + + for i, args in enumerate(test_cases): + with ( + self.subTest(args=args), + self.assertRaisesMessage( + ValueError, "CompositePrimaryKey args must be unique strings." + ), + ): + models.CompositePrimaryKey(*args) + + def test_composite_pk_must_include_at_least_2_fields(self): + expected_message = "CompositePrimaryKey must include at least two fields." + with self.assertRaisesMessage(ValueError, expected_message): + models.CompositePrimaryKey("id") + + def test_composite_pk_cannot_have_a_default(self): + expected_message = "CompositePrimaryKey cannot have a default." + with self.assertRaisesMessage(ValueError, expected_message): + models.CompositePrimaryKey("tenant_id", "id", default=(1, 1)) + + def test_composite_pk_cannot_have_a_database_default(self): + expected_message = "CompositePrimaryKey cannot have a database default." + with self.assertRaisesMessage(ValueError, expected_message): + models.CompositePrimaryKey("tenant_id", "id", db_default=models.F("id")) + + def test_composite_pk_cannot_be_editable(self): + expected_message = "CompositePrimaryKey cannot be editable." + with self.assertRaisesMessage(ValueError, expected_message): + models.CompositePrimaryKey("tenant_id", "id", editable=True) + + def test_composite_pk_must_be_a_primary_key(self): + expected_message = "CompositePrimaryKey must be a primary key." + with self.assertRaisesMessage(ValueError, expected_message): + models.CompositePrimaryKey("tenant_id", "id", primary_key=False) + + def test_composite_pk_must_be_blank(self): + expected_message = "CompositePrimaryKey must be blank." + with self.assertRaisesMessage(ValueError, expected_message): + models.CompositePrimaryKey("tenant_id", "id", blank=False) + + def test_composite_pk_must_not_have_other_pk_field(self): + class Foo(models.Model): + pk = models.CompositePrimaryKey("foo_id", "id") + foo_id = models.IntegerField() + id = models.IntegerField(primary_key=True) + + self.assertEqual( + Foo.check(databases=self.databases), + [ + checks.Error( + "The model cannot have more than one field with " + "'primary_key=True'.", + obj=Foo, + id="models.E026", + ), + ], + ) + + def test_composite_pk_cannot_include_nullable_field(self): + class Foo(models.Model): + pk = models.CompositePrimaryKey("foo_id", "id") + foo_id = models.IntegerField() + id = models.IntegerField(null=True) + + self.assertEqual( + Foo.check(databases=self.databases), + [ + checks.Error( + "'id' cannot be included in the composite primary key.", + hint="'id' field may not set 'null=True'.", + obj=Foo, + id="models.E042", + ), + ], + ) + + def test_composite_pk_can_include_fk_name(self): + class Foo(models.Model): + pass + + class Bar(models.Model): + pk = models.CompositePrimaryKey("foo", "id") + foo = models.ForeignKey(Foo, on_delete=models.CASCADE) + id = models.SmallIntegerField() + + self.assertEqual(Foo.check(databases=self.databases), []) + self.assertEqual(Bar.check(databases=self.databases), []) + + def test_composite_pk_cannot_include_same_field(self): + class Foo(models.Model): + pass + + class Bar(models.Model): + pk = models.CompositePrimaryKey("foo", "foo_id") + foo = models.ForeignKey(Foo, on_delete=models.CASCADE) + id = models.SmallIntegerField() + + self.assertEqual(Foo.check(databases=self.databases), []) + self.assertEqual( + Bar.check(databases=self.databases), + [ + checks.Error( + "'foo_id' cannot be included in the composite primary key.", + hint="'foo_id' and 'foo' are the same fields.", + obj=Bar, + id="models.E042", + ), + ], + ) + + def test_composite_pk_cannot_include_composite_pk_field(self): + class Foo(models.Model): + pk = models.CompositePrimaryKey("id", "pk") + id = models.SmallIntegerField() + + self.assertEqual( + Foo.check(databases=self.databases), + [ + checks.Error( + "'pk' cannot be included in the composite primary key.", + hint="'pk' field has no column.", + obj=Foo, + id="models.E042", + ), + ], + ) + + def test_composite_pk_cannot_include_db_column(self): + class Foo(models.Model): + pk = models.CompositePrimaryKey("foo", "bar") + foo = models.SmallIntegerField(db_column="foo_id") + bar = models.SmallIntegerField(db_column="bar_id") + + class Bar(models.Model): + pk = models.CompositePrimaryKey("foo_id", "bar_id") + foo = models.SmallIntegerField(db_column="foo_id") + bar = models.SmallIntegerField(db_column="bar_id") + + self.assertEqual(Foo.check(databases=self.databases), []) + self.assertEqual( + Bar.check(databases=self.databases), + [ + checks.Error( + "'foo_id' cannot be included in the composite primary key.", + hint="'foo_id' is not a valid field.", + obj=Bar, + id="models.E042", + ), + checks.Error( + "'bar_id' cannot be included in the composite primary key.", + hint="'bar_id' is not a valid field.", + obj=Bar, + id="models.E042", + ), + ], + ) + + def test_foreign_object_can_refer_composite_pk(self): + class Foo(models.Model): + pass + + class Bar(models.Model): + pk = models.CompositePrimaryKey("foo_id", "id") + foo = models.ForeignKey(Foo, on_delete=models.CASCADE) + id = models.IntegerField() + + class Baz(models.Model): + pk = models.CompositePrimaryKey("foo_id", "id") + foo = models.ForeignKey(Foo, on_delete=models.CASCADE) + id = models.IntegerField() + bar_id = models.IntegerField() + bar = models.ForeignObject( + Bar, + on_delete=models.CASCADE, + from_fields=("foo_id", "bar_id"), + to_fields=("foo_id", "id"), + ) + + self.assertEqual(Foo.check(databases=self.databases), []) + self.assertEqual(Bar.check(databases=self.databases), []) + self.assertEqual(Baz.check(databases=self.databases), []) + + def test_composite_pk_must_be_named_pk(self): + class Foo(models.Model): + primary_key = models.CompositePrimaryKey("foo_id", "id") + foo_id = models.IntegerField() + id = models.IntegerField() + + self.assertEqual( + Foo.check(databases=self.databases), + [ + checks.Error( + "'CompositePrimaryKey' must be named 'pk'.", + obj=Foo._meta.get_field("primary_key"), + id="fields.E013", + ), + ], + ) + + def test_composite_pk_cannot_include_generated_field(self): + is_oracle = connection.vendor == "oracle" + + class Foo(models.Model): + pk = models.CompositePrimaryKey("id", "foo") + id = models.IntegerField() + foo = models.GeneratedField( + expression=F("id"), + output_field=models.IntegerField(), + db_persist=not is_oracle, + ) + + self.assertEqual( + Foo.check(databases=self.databases), + [ + checks.Error( + "'foo' cannot be included in the composite primary key.", + hint="'foo' field is a generated field.", + obj=Foo, + id="models.E042", + ), + ], + ) diff --git a/tests/composite_pk/test_create.py b/tests/composite_pk/test_create.py new file mode 100644 index 0000000000..7c9925b946 --- /dev/null +++ b/tests/composite_pk/test_create.py @@ -0,0 +1,138 @@ +from django.test import TestCase + +from .models import Tenant, User + + +class CompositePKCreateTests(TestCase): + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant = Tenant.objects.create() + cls.user = User.objects.create( + tenant=cls.tenant, + id=1, + email="user0001@example.com", + ) + + def test_create_user(self): + test_cases = ( + {"tenant": self.tenant, "id": 2412, "email": "user2412@example.com"}, + {"tenant_id": self.tenant.id, "id": 5316, "email": "user5316@example.com"}, + {"pk": (self.tenant.id, 7424), "email": "user7424@example.com"}, + ) + + for fields in test_cases: + with self.subTest(fields=fields): + count = User.objects.count() + user = User(**fields) + obj = User.objects.create(**fields) + self.assertEqual(obj.tenant_id, self.tenant.id) + self.assertEqual(obj.id, user.id) + self.assertEqual(obj.pk, (self.tenant.id, user.id)) + self.assertEqual(obj.email, user.email) + self.assertEqual(count + 1, User.objects.count()) + + def test_save_user(self): + test_cases = ( + {"tenant": self.tenant, "id": 9241, "email": "user9241@example.com"}, + {"tenant_id": self.tenant.id, "id": 5132, "email": "user5132@example.com"}, + {"pk": (self.tenant.id, 3014), "email": "user3014@example.com"}, + ) + + for fields in test_cases: + with self.subTest(fields=fields): + count = User.objects.count() + user = User(**fields) + self.assertIsNotNone(user.id) + self.assertIsNotNone(user.email) + user.save() + self.assertEqual(user.tenant_id, self.tenant.id) + self.assertEqual(user.tenant, self.tenant) + self.assertIsNotNone(user.id) + self.assertEqual(user.pk, (self.tenant.id, user.id)) + self.assertEqual(user.email, fields["email"]) + self.assertEqual(user.email, f"user{user.id}@example.com") + self.assertEqual(count + 1, User.objects.count()) + + def test_bulk_create_users(self): + objs = [ + User(tenant=self.tenant, id=8291, email="user8291@example.com"), + User(tenant_id=self.tenant.id, id=4021, email="user4021@example.com"), + User(pk=(self.tenant.id, 8214), email="user8214@example.com"), + ] + + obj_1, obj_2, obj_3 = User.objects.bulk_create(objs) + + self.assertEqual(obj_1.tenant_id, self.tenant.id) + self.assertEqual(obj_1.id, 8291) + self.assertEqual(obj_1.pk, (obj_1.tenant_id, obj_1.id)) + self.assertEqual(obj_1.email, "user8291@example.com") + self.assertEqual(obj_2.tenant_id, self.tenant.id) + self.assertEqual(obj_2.id, 4021) + self.assertEqual(obj_2.pk, (obj_2.tenant_id, obj_2.id)) + self.assertEqual(obj_2.email, "user4021@example.com") + self.assertEqual(obj_3.tenant_id, self.tenant.id) + self.assertEqual(obj_3.id, 8214) + self.assertEqual(obj_3.pk, (obj_3.tenant_id, obj_3.id)) + self.assertEqual(obj_3.email, "user8214@example.com") + + def test_get_or_create_user(self): + test_cases = ( + { + "pk": (self.tenant.id, 8314), + "defaults": {"email": "user8314@example.com"}, + }, + { + "tenant": self.tenant, + "id": 3142, + "defaults": {"email": "user3142@example.com"}, + }, + { + "tenant_id": self.tenant.id, + "id": 4218, + "defaults": {"email": "user4218@example.com"}, + }, + ) + + for fields in test_cases: + with self.subTest(fields=fields): + count = User.objects.count() + user, created = User.objects.get_or_create(**fields) + self.assertIs(created, True) + self.assertIsNotNone(user.id) + self.assertEqual(user.pk, (self.tenant.id, user.id)) + self.assertEqual(user.tenant_id, self.tenant.id) + self.assertEqual(user.email, fields["defaults"]["email"]) + self.assertEqual(user.email, f"user{user.id}@example.com") + self.assertEqual(count + 1, User.objects.count()) + + def test_update_or_create_user(self): + test_cases = ( + { + "pk": (self.tenant.id, 2931), + "defaults": {"email": "user2931@example.com"}, + }, + { + "tenant": self.tenant, + "id": 6428, + "defaults": {"email": "user6428@example.com"}, + }, + { + "tenant_id": self.tenant.id, + "id": 5278, + "defaults": {"email": "user5278@example.com"}, + }, + ) + + for fields in test_cases: + with self.subTest(fields=fields): + count = User.objects.count() + user, created = User.objects.update_or_create(**fields) + self.assertIs(created, True) + self.assertIsNotNone(user.id) + self.assertEqual(user.pk, (self.tenant.id, user.id)) + self.assertEqual(user.tenant_id, self.tenant.id) + self.assertEqual(user.email, fields["defaults"]["email"]) + self.assertEqual(user.email, f"user{user.id}@example.com") + self.assertEqual(count + 1, User.objects.count()) diff --git a/tests/composite_pk/test_delete.py b/tests/composite_pk/test_delete.py new file mode 100644 index 0000000000..9a14deb813 --- /dev/null +++ b/tests/composite_pk/test_delete.py @@ -0,0 +1,83 @@ +from django.test import TestCase + +from .models import Comment, Tenant, User + + +class CompositePKDeleteTests(TestCase): + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant_1 = Tenant.objects.create() + cls.tenant_2 = Tenant.objects.create() + cls.user_1 = User.objects.create( + tenant=cls.tenant_1, + id=1, + email="user0001@example.com", + ) + cls.user_2 = User.objects.create( + tenant=cls.tenant_2, + id=2, + email="user0002@example.com", + ) + cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1) + cls.comment_2 = Comment.objects.create(id=2, user=cls.user_2) + cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2) + + def test_delete_tenant_by_pk(self): + result = Tenant.objects.filter(pk=self.tenant_1.pk).delete() + + self.assertEqual( + result, + ( + 3, + { + "composite_pk.Comment": 1, + "composite_pk.User": 1, + "composite_pk.Tenant": 1, + }, + ), + ) + + self.assertIs(Tenant.objects.filter(pk=self.tenant_1.pk).exists(), False) + self.assertIs(Tenant.objects.filter(pk=self.tenant_2.pk).exists(), True) + self.assertIs(User.objects.filter(pk=self.user_1.pk).exists(), False) + self.assertIs(User.objects.filter(pk=self.user_2.pk).exists(), True) + self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), False) + self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), True) + self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), True) + + def test_delete_user_by_pk(self): + result = User.objects.filter(pk=self.user_1.pk).delete() + + self.assertEqual( + result, (2, {"composite_pk.User": 1, "composite_pk.Comment": 1}) + ) + + self.assertIs(User.objects.filter(pk=self.user_1.pk).exists(), False) + self.assertIs(User.objects.filter(pk=self.user_2.pk).exists(), True) + self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), False) + self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), True) + self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), True) + + def test_delete_comments_by_user(self): + result = Comment.objects.filter(user=self.user_2).delete() + + self.assertEqual(result, (2, {"composite_pk.Comment": 2})) + + self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), True) + self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), False) + self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), False) + + def test_delete_without_pk(self): + msg = ( + "Comment object can't be deleted because its pk attribute is set " + "to None." + ) + + with self.assertRaisesMessage(ValueError, msg): + Comment().delete() + with self.assertRaisesMessage(ValueError, msg): + Comment(tenant_id=1).delete() + with self.assertRaisesMessage(ValueError, msg): + Comment(id=1).delete() diff --git a/tests/composite_pk/test_filter.py b/tests/composite_pk/test_filter.py new file mode 100644 index 0000000000..7e361c5925 --- /dev/null +++ b/tests/composite_pk/test_filter.py @@ -0,0 +1,412 @@ +from django.test import TestCase + +from .models import Comment, Tenant, User + + +class CompositePKFilterTests(TestCase): + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant_1 = Tenant.objects.create() + cls.tenant_2 = Tenant.objects.create() + cls.tenant_3 = Tenant.objects.create() + cls.user_1 = User.objects.create( + tenant=cls.tenant_1, + id=1, + email="user0001@example.com", + ) + cls.user_2 = User.objects.create( + tenant=cls.tenant_1, + id=2, + email="user0002@example.com", + ) + cls.user_3 = User.objects.create( + tenant=cls.tenant_2, + id=3, + email="user0003@example.com", + ) + cls.user_4 = User.objects.create( + tenant=cls.tenant_3, + id=4, + email="user0004@example.com", + ) + cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1) + cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1) + cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2) + cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3) + cls.comment_5 = Comment.objects.create(id=5, user=cls.user_1) + + def test_filter_and_count_user_by_pk(self): + test_cases = ( + ({"pk": self.user_1.pk}, 1), + ({"pk": self.user_2.pk}, 1), + ({"pk": self.user_3.pk}, 1), + ({"pk": (self.tenant_1.id, self.user_1.id)}, 1), + ({"pk": (self.tenant_1.id, self.user_2.id)}, 1), + ({"pk": (self.tenant_2.id, self.user_3.id)}, 1), + ({"pk": (self.tenant_1.id, self.user_3.id)}, 0), + ({"pk": (self.tenant_2.id, self.user_1.id)}, 0), + ({"pk": (self.tenant_2.id, self.user_2.id)}, 0), + ) + + for lookup, count in test_cases: + with self.subTest(lookup=lookup, count=count): + self.assertEqual(User.objects.filter(**lookup).count(), count) + + def test_order_comments_by_pk_asc(self): + self.assertSequenceEqual( + Comment.objects.order_by("pk"), + ( + self.comment_1, # (1, 1) + self.comment_2, # (1, 2) + self.comment_3, # (1, 3) + self.comment_5, # (1, 5) + self.comment_4, # (2, 4) + ), + ) + + def test_order_comments_by_pk_desc(self): + self.assertSequenceEqual( + Comment.objects.order_by("-pk"), + ( + self.comment_4, # (2, 4) + self.comment_5, # (1, 5) + self.comment_3, # (1, 3) + self.comment_2, # (1, 2) + self.comment_1, # (1, 1) + ), + ) + + def test_filter_comments_by_pk_gt(self): + c11, c12, c13, c24, c15 = ( + self.comment_1, + self.comment_2, + self.comment_3, + self.comment_4, + self.comment_5, + ) + test_cases = ( + (c11, (c12, c13, c15, c24)), + (c12, (c13, c15, c24)), + (c13, (c15, c24)), + (c15, (c24,)), + (c24, ()), + ) + + for obj, objs in test_cases: + with self.subTest(obj=obj, objs=objs): + self.assertSequenceEqual( + Comment.objects.filter(pk__gt=obj.pk).order_by("pk"), objs + ) + + def test_filter_comments_by_pk_gte(self): + c11, c12, c13, c24, c15 = ( + self.comment_1, + self.comment_2, + self.comment_3, + self.comment_4, + self.comment_5, + ) + test_cases = ( + (c11, (c11, c12, c13, c15, c24)), + (c12, (c12, c13, c15, c24)), + (c13, (c13, c15, c24)), + (c15, (c15, c24)), + (c24, (c24,)), + ) + + for obj, objs in test_cases: + with self.subTest(obj=obj, objs=objs): + self.assertSequenceEqual( + Comment.objects.filter(pk__gte=obj.pk).order_by("pk"), objs + ) + + def test_filter_comments_by_pk_lt(self): + c11, c12, c13, c24, c15 = ( + self.comment_1, + self.comment_2, + self.comment_3, + self.comment_4, + self.comment_5, + ) + test_cases = ( + (c24, (c11, c12, c13, c15)), + (c15, (c11, c12, c13)), + (c13, (c11, c12)), + (c12, (c11,)), + (c11, ()), + ) + + for obj, objs in test_cases: + with self.subTest(obj=obj, objs=objs): + self.assertSequenceEqual( + Comment.objects.filter(pk__lt=obj.pk).order_by("pk"), objs + ) + + def test_filter_comments_by_pk_lte(self): + c11, c12, c13, c24, c15 = ( + self.comment_1, + self.comment_2, + self.comment_3, + self.comment_4, + self.comment_5, + ) + test_cases = ( + (c24, (c11, c12, c13, c15, c24)), + (c15, (c11, c12, c13, c15)), + (c13, (c11, c12, c13)), + (c12, (c11, c12)), + (c11, (c11,)), + ) + + for obj, objs in test_cases: + with self.subTest(obj=obj, objs=objs): + self.assertSequenceEqual( + Comment.objects.filter(pk__lte=obj.pk).order_by("pk"), objs + ) + + def test_filter_comments_by_pk_in(self): + test_cases = ( + (), + (self.comment_1,), + (self.comment_1, self.comment_4), + ) + + for objs in test_cases: + with self.subTest(objs=objs): + pks = [obj.pk for obj in objs] + self.assertSequenceEqual( + Comment.objects.filter(pk__in=pks).order_by("pk"), objs + ) + + def test_filter_comments_by_user_and_order_by_pk_asc(self): + self.assertSequenceEqual( + Comment.objects.filter(user=self.user_1).order_by("pk"), + (self.comment_1, self.comment_2, self.comment_5), + ) + + def test_filter_comments_by_user_and_order_by_pk_desc(self): + self.assertSequenceEqual( + Comment.objects.filter(user=self.user_1).order_by("-pk"), + (self.comment_5, self.comment_2, self.comment_1), + ) + + def test_filter_comments_by_user_and_exclude_by_pk(self): + self.assertSequenceEqual( + Comment.objects.filter(user=self.user_1) + .exclude(pk=self.comment_1.pk) + .order_by("pk"), + (self.comment_2, self.comment_5), + ) + + def test_filter_comments_by_user_and_contains(self): + self.assertIs( + Comment.objects.filter(user=self.user_1).contains(self.comment_1), True + ) + + def test_filter_users_by_comments_in(self): + c1, c2, c3, c4, c5 = ( + self.comment_1, + self.comment_2, + self.comment_3, + self.comment_4, + self.comment_5, + ) + u1, u2, u3 = ( + self.user_1, + self.user_2, + self.user_3, + ) + test_cases = ( + ((), ()), + ((c1,), (u1,)), + ((c1, c2), (u1, u1)), + ((c1, c2, c3), (u1, u1, u2)), + ((c1, c2, c3, c4), (u1, u1, u2, u3)), + ((c1, c2, c3, c4, c5), (u1, u1, u1, u2, u3)), + ) + + for comments, users in test_cases: + with self.subTest(comments=comments, users=users): + self.assertSequenceEqual( + User.objects.filter(comments__in=comments).order_by("pk"), users + ) + + def test_filter_users_by_comments_lt(self): + c11, c12, c13, c24, c15 = ( + self.comment_1, + self.comment_2, + self.comment_3, + self.comment_4, + self.comment_5, + ) + u1, u2 = ( + self.user_1, + self.user_2, + ) + test_cases = ( + (c11, ()), + (c12, (u1,)), + (c13, (u1, u1)), + (c15, (u1, u1, u2)), + (c24, (u1, u1, u1, u2)), + ) + + for comment, users in test_cases: + with self.subTest(comment=comment, users=users): + self.assertSequenceEqual( + User.objects.filter(comments__lt=comment).order_by("pk"), users + ) + + def test_filter_users_by_comments_lte(self): + c11, c12, c13, c24, c15 = ( + self.comment_1, + self.comment_2, + self.comment_3, + self.comment_4, + self.comment_5, + ) + u1, u2, u3 = ( + self.user_1, + self.user_2, + self.user_3, + ) + test_cases = ( + (c11, (u1,)), + (c12, (u1, u1)), + (c13, (u1, u1, u2)), + (c15, (u1, u1, u1, u2)), + (c24, (u1, u1, u1, u2, u3)), + ) + + for comment, users in test_cases: + with self.subTest(comment=comment, users=users): + self.assertSequenceEqual( + User.objects.filter(comments__lte=comment).order_by("pk"), users + ) + + def test_filter_users_by_comments_gt(self): + c11, c12, c13, c24, c15 = ( + self.comment_1, + self.comment_2, + self.comment_3, + self.comment_4, + self.comment_5, + ) + u1, u2, u3 = ( + self.user_1, + self.user_2, + self.user_3, + ) + test_cases = ( + (c11, (u1, u1, u2, u3)), + (c12, (u1, u2, u3)), + (c13, (u1, u3)), + (c15, (u3,)), + (c24, ()), + ) + + for comment, users in test_cases: + with self.subTest(comment=comment, users=users): + self.assertSequenceEqual( + User.objects.filter(comments__gt=comment).order_by("pk"), users + ) + + def test_filter_users_by_comments_gte(self): + c11, c12, c13, c24, c15 = ( + self.comment_1, + self.comment_2, + self.comment_3, + self.comment_4, + self.comment_5, + ) + u1, u2, u3 = ( + self.user_1, + self.user_2, + self.user_3, + ) + test_cases = ( + (c11, (u1, u1, u1, u2, u3)), + (c12, (u1, u1, u2, u3)), + (c13, (u1, u2, u3)), + (c15, (u1, u3)), + (c24, (u3,)), + ) + + for comment, users in test_cases: + with self.subTest(comment=comment, users=users): + self.assertSequenceEqual( + User.objects.filter(comments__gte=comment).order_by("pk"), users + ) + + def test_filter_users_by_comments_exact(self): + c11, c12, c13, c24, c15 = ( + self.comment_1, + self.comment_2, + self.comment_3, + self.comment_4, + self.comment_5, + ) + u1, u2, u3 = ( + self.user_1, + self.user_2, + self.user_3, + ) + test_cases = ( + (c11, (u1,)), + (c12, (u1,)), + (c13, (u2,)), + (c15, (u1,)), + (c24, (u3,)), + ) + + for comment, users in test_cases: + with self.subTest(comment=comment, users=users): + self.assertSequenceEqual( + User.objects.filter(comments=comment).order_by("pk"), users + ) + + def test_filter_users_by_comments_isnull(self): + u1, u2, u3, u4 = ( + self.user_1, + self.user_2, + self.user_3, + self.user_4, + ) + + with self.subTest("comments__isnull=True"): + self.assertSequenceEqual( + User.objects.filter(comments__isnull=True).order_by("pk"), + (u4,), + ) + with self.subTest("comments__isnull=False"): + self.assertSequenceEqual( + User.objects.filter(comments__isnull=False).order_by("pk"), + (u1, u1, u1, u2, u3), + ) + + def test_filter_comments_by_pk_isnull(self): + c11, c12, c13, c24, c15 = ( + self.comment_1, + self.comment_2, + self.comment_3, + self.comment_4, + self.comment_5, + ) + + with self.subTest("pk__isnull=True"): + self.assertSequenceEqual( + Comment.objects.filter(pk__isnull=True).order_by("pk"), + (), + ) + with self.subTest("pk__isnull=False"): + self.assertSequenceEqual( + Comment.objects.filter(pk__isnull=False).order_by("pk"), + (c11, c12, c13, c15, c24), + ) + + def test_filter_users_by_comments_subquery(self): + subquery = Comment.objects.filter(id=3).only("pk") + queryset = User.objects.filter(comments__in=subquery) + self.assertSequenceEqual(queryset, (self.user_2,)) diff --git a/tests/composite_pk/test_get.py b/tests/composite_pk/test_get.py new file mode 100644 index 0000000000..c896ec26ed --- /dev/null +++ b/tests/composite_pk/test_get.py @@ -0,0 +1,126 @@ +from django.test import TestCase + +from .models import Comment, Tenant, User + + +class CompositePKGetTests(TestCase): + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant_1 = Tenant.objects.create() + cls.tenant_2 = Tenant.objects.create() + cls.user_1 = User.objects.create( + tenant=cls.tenant_1, + id=1, + email="user0001@example.com", + ) + cls.user_2 = User.objects.create( + tenant=cls.tenant_1, + id=2, + email="user0002@example.com", + ) + cls.user_3 = User.objects.create( + tenant=cls.tenant_2, + id=3, + email="user0003@example.com", + ) + cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1) + + def test_get_user(self): + test_cases = ( + {"pk": self.user_1.pk}, + {"pk": (self.tenant_1.id, self.user_1.id)}, + {"id": self.user_1.id}, + ) + + for lookup in test_cases: + with self.subTest(lookup=lookup): + self.assertEqual(User.objects.get(**lookup), self.user_1) + + def test_get_comment(self): + test_cases = ( + {"pk": self.comment_1.pk}, + {"pk": (self.tenant_1.id, self.comment_1.id)}, + {"id": self.comment_1.id}, + {"user": self.user_1}, + {"user_id": self.user_1.id}, + {"user__id": self.user_1.id}, + {"user__pk": self.user_1.pk}, + {"tenant": self.tenant_1}, + {"tenant_id": self.tenant_1.id}, + {"tenant__id": self.tenant_1.id}, + {"tenant__pk": self.tenant_1.pk}, + ) + + for lookup in test_cases: + with self.subTest(lookup=lookup): + self.assertEqual(Comment.objects.get(**lookup), self.comment_1) + + def test_get_or_create_user(self): + test_cases = ( + { + "pk": self.user_1.pk, + "defaults": {"email": "user9201@example.com"}, + }, + { + "pk": (self.tenant_1.id, self.user_1.id), + "defaults": {"email": "user9201@example.com"}, + }, + { + "tenant": self.tenant_1, + "id": self.user_1.id, + "defaults": {"email": "user3512@example.com"}, + }, + { + "tenant_id": self.tenant_1.id, + "id": self.user_1.id, + "defaults": {"email": "user8239@example.com"}, + }, + ) + + for fields in test_cases: + with self.subTest(fields=fields): + count = User.objects.count() + user, created = User.objects.get_or_create(**fields) + self.assertIs(created, False) + self.assertEqual(user.id, self.user_1.id) + self.assertEqual(user.pk, (self.tenant_1.id, self.user_1.id)) + self.assertEqual(user.tenant_id, self.tenant_1.id) + self.assertEqual(user.email, self.user_1.email) + self.assertEqual(count, User.objects.count()) + + def test_lookup_errors(self): + m_tuple = "'%s' lookup of 'pk' must be a tuple or a list" + m_2_elements = "'%s' lookup of 'pk' must have 2 elements" + m_tuple_collection = ( + "'in' lookup of 'pk' must be a collection of tuples or lists" + ) + m_2_elements_each = "'in' lookup of 'pk' must have 2 elements each" + test_cases = ( + ({"pk": 1}, m_tuple % "exact"), + ({"pk": (1, 2, 3)}, m_2_elements % "exact"), + ({"pk__exact": 1}, m_tuple % "exact"), + ({"pk__exact": (1, 2, 3)}, m_2_elements % "exact"), + ({"pk__in": 1}, m_tuple % "in"), + ({"pk__in": (1, 2, 3)}, m_tuple_collection), + ({"pk__in": ((1, 2, 3),)}, m_2_elements_each), + ({"pk__gt": 1}, m_tuple % "gt"), + ({"pk__gt": (1, 2, 3)}, m_2_elements % "gt"), + ({"pk__gte": 1}, m_tuple % "gte"), + ({"pk__gte": (1, 2, 3)}, m_2_elements % "gte"), + ({"pk__lt": 1}, m_tuple % "lt"), + ({"pk__lt": (1, 2, 3)}, m_2_elements % "lt"), + ({"pk__lte": 1}, m_tuple % "lte"), + ({"pk__lte": (1, 2, 3)}, m_2_elements % "lte"), + ) + + for kwargs, message in test_cases: + with ( + self.subTest(kwargs=kwargs), + self.assertRaisesMessage(ValueError, message), + ): + Comment.objects.get(**kwargs) + + def test_get_user_by_comments(self): + self.assertEqual(User.objects.get(comments=self.comment_1), self.user_1) diff --git a/tests/composite_pk/test_models.py b/tests/composite_pk/test_models.py new file mode 100644 index 0000000000..ca6ad8b5dc --- /dev/null +++ b/tests/composite_pk/test_models.py @@ -0,0 +1,153 @@ +from django.contrib.contenttypes.models import ContentType +from django.core.exceptions import ValidationError +from django.test import TestCase + +from .models import Comment, Tenant, Token, User + + +class CompositePKModelsTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.tenant_1 = Tenant.objects.create() + cls.tenant_2 = Tenant.objects.create() + cls.user_1 = User.objects.create( + tenant=cls.tenant_1, + id=1, + email="user0001@example.com", + ) + cls.user_2 = User.objects.create( + tenant=cls.tenant_1, + id=2, + email="user0002@example.com", + ) + cls.user_3 = User.objects.create( + tenant=cls.tenant_2, + id=3, + email="user0003@example.com", + ) + cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1) + cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1) + cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2) + cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3) + + def test_fields(self): + # tenant_1 + self.assertSequenceEqual( + self.tenant_1.user_set.order_by("pk"), + [self.user_1, self.user_2], + ) + self.assertSequenceEqual( + self.tenant_1.comments.order_by("pk"), + [self.comment_1, self.comment_2, self.comment_3], + ) + + # tenant_2 + self.assertSequenceEqual(self.tenant_2.user_set.order_by("pk"), [self.user_3]) + self.assertSequenceEqual( + self.tenant_2.comments.order_by("pk"), [self.comment_4] + ) + + # user_1 + self.assertEqual(self.user_1.id, 1) + self.assertEqual(self.user_1.tenant_id, self.tenant_1.id) + self.assertEqual(self.user_1.tenant, self.tenant_1) + self.assertEqual(self.user_1.pk, (self.tenant_1.id, self.user_1.id)) + self.assertSequenceEqual( + self.user_1.comments.order_by("pk"), [self.comment_1, self.comment_2] + ) + + # user_2 + self.assertEqual(self.user_2.id, 2) + self.assertEqual(self.user_2.tenant_id, self.tenant_1.id) + self.assertEqual(self.user_2.tenant, self.tenant_1) + self.assertEqual(self.user_2.pk, (self.tenant_1.id, self.user_2.id)) + self.assertSequenceEqual(self.user_2.comments.order_by("pk"), [self.comment_3]) + + # comment_1 + self.assertEqual(self.comment_1.id, 1) + self.assertEqual(self.comment_1.user_id, self.user_1.id) + self.assertEqual(self.comment_1.user, self.user_1) + self.assertEqual(self.comment_1.tenant_id, self.tenant_1.id) + self.assertEqual(self.comment_1.tenant, self.tenant_1) + self.assertEqual(self.comment_1.pk, (self.tenant_1.id, self.user_1.id)) + + def test_full_clean_success(self): + test_cases = ( + # 1, 1234, {} + ({"tenant": self.tenant_1, "id": 1234}, {}), + ({"tenant_id": self.tenant_1.id, "id": 1234}, {}), + ({"pk": (self.tenant_1.id, 1234)}, {}), + # 1, 1, {"id"} + ({"tenant": self.tenant_1, "id": 1}, {"id"}), + ({"tenant_id": self.tenant_1.id, "id": 1}, {"id"}), + ({"pk": (self.tenant_1.id, 1)}, {"id"}), + # 1, 1, {"tenant", "id"} + ({"tenant": self.tenant_1, "id": 1}, {"tenant", "id"}), + ({"tenant_id": self.tenant_1.id, "id": 1}, {"tenant", "id"}), + ({"pk": (self.tenant_1.id, 1)}, {"tenant", "id"}), + ) + + for kwargs, exclude in test_cases: + with self.subTest(kwargs): + kwargs["email"] = "user0004@example.com" + User(**kwargs).full_clean(exclude=exclude) + + def test_full_clean_failure(self): + e_tenant_and_id = "User with this Tenant and Id already exists." + e_id = "User with this Id already exists." + test_cases = ( + # 1, 1, {} + ({"tenant": self.tenant_1, "id": 1}, {}, (e_tenant_and_id, e_id)), + ({"tenant_id": self.tenant_1.id, "id": 1}, {}, (e_tenant_and_id, e_id)), + ({"pk": (self.tenant_1.id, 1)}, {}, (e_tenant_and_id, e_id)), + # 2, 1, {} + ({"tenant": self.tenant_2, "id": 1}, {}, (e_id,)), + ({"tenant_id": self.tenant_2.id, "id": 1}, {}, (e_id,)), + ({"pk": (self.tenant_2.id, 1)}, {}, (e_id,)), + # 1, 1, {"tenant"} + ({"tenant": self.tenant_1, "id": 1}, {"tenant"}, (e_id,)), + ({"tenant_id": self.tenant_1.id, "id": 1}, {"tenant"}, (e_id,)), + ({"pk": (self.tenant_1.id, 1)}, {"tenant"}, (e_id,)), + ) + + for kwargs, exclude, messages in test_cases: + with self.subTest(kwargs): + with self.assertRaises(ValidationError) as ctx: + kwargs["email"] = "user0004@example.com" + User(**kwargs).full_clean(exclude=exclude) + + self.assertSequenceEqual(ctx.exception.messages, messages) + + def test_field_conflicts(self): + test_cases = ( + ({"pk": (1, 1), "id": 2}, (1, 1)), + ({"id": 2, "pk": (1, 1)}, (1, 1)), + ({"pk": (1, 1), "tenant_id": 2}, (1, 1)), + ({"tenant_id": 2, "pk": (1, 1)}, (1, 1)), + ({"pk": (2, 2), "tenant_id": 3, "id": 4}, (2, 2)), + ({"tenant_id": 3, "id": 4, "pk": (2, 2)}, (2, 2)), + ) + + for kwargs, pk in test_cases: + with self.subTest(kwargs=kwargs): + user = User(**kwargs) + self.assertEqual(user.pk, pk) + + def test_validate_unique(self): + user = User.objects.get(pk=self.user_1.pk) + user.id = None + + with self.assertRaises(ValidationError) as ctx: + user.validate_unique() + + self.assertSequenceEqual( + ctx.exception.messages, ("User with this Email already exists.",) + ) + + def test_permissions(self): + token = ContentType.objects.get_for_model(Token) + user = ContentType.objects.get_for_model(User) + comment = ContentType.objects.get_for_model(Comment) + self.assertEqual(4, token.permission_set.count()) + self.assertEqual(4, user.permission_set.count()) + self.assertEqual(4, comment.permission_set.count()) diff --git a/tests/composite_pk/test_names_to_path.py b/tests/composite_pk/test_names_to_path.py new file mode 100644 index 0000000000..de4a04f4cb --- /dev/null +++ b/tests/composite_pk/test_names_to_path.py @@ -0,0 +1,134 @@ +from django.db.models.query_utils import PathInfo +from django.db.models.sql import Query +from django.test import TestCase + +from .models import Comment, Tenant, User + + +class NamesToPathTests(TestCase): + def test_id(self): + query = Query(User) + path, final_field, targets, rest = query.names_to_path(["id"], User._meta) + + self.assertEqual(path, []) + self.assertEqual(final_field, User._meta.get_field("id")) + self.assertEqual(targets, (User._meta.get_field("id"),)) + self.assertEqual(rest, []) + + def test_pk(self): + query = Query(User) + path, final_field, targets, rest = query.names_to_path(["pk"], User._meta) + + self.assertEqual(path, []) + self.assertEqual(final_field, User._meta.get_field("pk")) + self.assertEqual(targets, (User._meta.get_field("pk"),)) + self.assertEqual(rest, []) + + def test_tenant_id(self): + query = Query(User) + path, final_field, targets, rest = query.names_to_path( + ["tenant", "id"], User._meta + ) + + self.assertEqual( + path, + [ + PathInfo( + from_opts=User._meta, + to_opts=Tenant._meta, + target_fields=(Tenant._meta.get_field("id"),), + join_field=User._meta.get_field("tenant"), + m2m=False, + direct=True, + filtered_relation=None, + ), + ], + ) + self.assertEqual(final_field, Tenant._meta.get_field("id")) + self.assertEqual(targets, (Tenant._meta.get_field("id"),)) + self.assertEqual(rest, []) + + def test_user_id(self): + query = Query(Comment) + path, final_field, targets, rest = query.names_to_path( + ["user", "id"], Comment._meta + ) + + self.assertEqual( + path, + [ + PathInfo( + from_opts=Comment._meta, + to_opts=User._meta, + target_fields=( + User._meta.get_field("tenant"), + User._meta.get_field("id"), + ), + join_field=Comment._meta.get_field("user"), + m2m=False, + direct=True, + filtered_relation=None, + ), + ], + ) + self.assertEqual(final_field, User._meta.get_field("id")) + self.assertEqual(targets, (User._meta.get_field("id"),)) + self.assertEqual(rest, []) + + def test_user_tenant_id(self): + query = Query(Comment) + path, final_field, targets, rest = query.names_to_path( + ["user", "tenant", "id"], Comment._meta + ) + + self.assertEqual( + path, + [ + PathInfo( + from_opts=Comment._meta, + to_opts=User._meta, + target_fields=( + User._meta.get_field("tenant"), + User._meta.get_field("id"), + ), + join_field=Comment._meta.get_field("user"), + m2m=False, + direct=True, + filtered_relation=None, + ), + PathInfo( + from_opts=User._meta, + to_opts=Tenant._meta, + target_fields=(Tenant._meta.get_field("id"),), + join_field=User._meta.get_field("tenant"), + m2m=False, + direct=True, + filtered_relation=None, + ), + ], + ) + self.assertEqual(final_field, Tenant._meta.get_field("id")) + self.assertEqual(targets, (Tenant._meta.get_field("id"),)) + self.assertEqual(rest, []) + + def test_comments(self): + query = Query(User) + path, final_field, targets, rest = query.names_to_path(["comments"], User._meta) + + self.assertEqual( + path, + [ + PathInfo( + from_opts=User._meta, + to_opts=Comment._meta, + target_fields=(Comment._meta.get_field("pk"),), + join_field=User._meta.get_field("comments"), + m2m=True, + direct=False, + filtered_relation=None, + ), + ], + ) + self.assertEqual(final_field, User._meta.get_field("comments")) + self.assertEqual(targets, (Comment._meta.get_field("pk"),)) + self.assertEqual(rest, []) diff --git a/tests/composite_pk/test_update.py b/tests/composite_pk/test_update.py new file mode 100644 index 0000000000..e711745447 --- /dev/null +++ b/tests/composite_pk/test_update.py @@ -0,0 +1,135 @@ +from django.test import TestCase + +from .models import Comment, Tenant, Token, User + + +class CompositePKUpdateTests(TestCase): + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant_1 = Tenant.objects.create(name="A") + cls.tenant_2 = Tenant.objects.create(name="B") + cls.user_1 = User.objects.create( + tenant=cls.tenant_1, + id=1, + email="user0001@example.com", + ) + cls.user_2 = User.objects.create( + tenant=cls.tenant_1, + id=2, + email="user0002@example.com", + ) + cls.user_3 = User.objects.create( + tenant=cls.tenant_2, + id=3, + email="user0003@example.com", + ) + cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1) + cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1) + cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2) + cls.token_1 = Token.objects.create(id=1, tenant=cls.tenant_1) + cls.token_2 = Token.objects.create(id=2, tenant=cls.tenant_2) + cls.token_3 = Token.objects.create(id=3, tenant=cls.tenant_1) + cls.token_4 = Token.objects.create(id=4, tenant=cls.tenant_2) + + def test_update_user(self): + email = "user9315@example.com" + result = User.objects.filter(pk=self.user_1.pk).update(email=email) + self.assertEqual(result, 1) + user = User.objects.get(pk=self.user_1.pk) + self.assertEqual(user.email, email) + + def test_save_user(self): + count = User.objects.count() + email = "user9314@example.com" + user = User.objects.get(pk=self.user_1.pk) + user.email = email + user.save() + user.refresh_from_db() + self.assertEqual(user.email, email) + user = User.objects.get(pk=self.user_1.pk) + self.assertEqual(user.email, email) + self.assertEqual(count, User.objects.count()) + + def test_bulk_update_comments(self): + comment_1 = Comment.objects.get(pk=self.comment_1.pk) + comment_2 = Comment.objects.get(pk=self.comment_2.pk) + comment_3 = Comment.objects.get(pk=self.comment_3.pk) + comment_1.text = "foo" + comment_2.text = "bar" + comment_3.text = "baz" + + result = Comment.objects.bulk_update( + [comment_1, comment_2, comment_3], ["text"] + ) + + self.assertEqual(result, 3) + comment_1 = Comment.objects.get(pk=self.comment_1.pk) + comment_2 = Comment.objects.get(pk=self.comment_2.pk) + comment_3 = Comment.objects.get(pk=self.comment_3.pk) + self.assertEqual(comment_1.text, "foo") + self.assertEqual(comment_2.text, "bar") + self.assertEqual(comment_3.text, "baz") + + def test_update_or_create_user(self): + test_cases = ( + { + "pk": self.user_1.pk, + "defaults": {"email": "user3914@example.com"}, + }, + { + "pk": (self.tenant_1.id, self.user_1.id), + "defaults": {"email": "user9375@example.com"}, + }, + { + "tenant": self.tenant_1, + "id": self.user_1.id, + "defaults": {"email": "user3517@example.com"}, + }, + { + "tenant_id": self.tenant_1.id, + "id": self.user_1.id, + "defaults": {"email": "user8391@example.com"}, + }, + ) + + for fields in test_cases: + with self.subTest(fields=fields): + count = User.objects.count() + user, created = User.objects.update_or_create(**fields) + self.assertIs(created, False) + self.assertEqual(user.id, self.user_1.id) + self.assertEqual(user.pk, (self.tenant_1.id, self.user_1.id)) + self.assertEqual(user.tenant_id, self.tenant_1.id) + self.assertEqual(user.email, fields["defaults"]["email"]) + self.assertEqual(count, User.objects.count()) + + def test_update_comment_by_user_email(self): + result = Comment.objects.filter(user__email=self.user_1.email).update( + text="foo" + ) + + self.assertEqual(result, 2) + comment_1 = Comment.objects.get(pk=self.comment_1.pk) + comment_2 = Comment.objects.get(pk=self.comment_2.pk) + self.assertEqual(comment_1.text, "foo") + self.assertEqual(comment_2.text, "foo") + + def test_update_token_by_tenant_name(self): + result = Token.objects.filter(tenant__name="A").update(secret="bar") + + self.assertEqual(result, 2) + token_1 = Token.objects.get(pk=self.token_1.pk) + self.assertEqual(token_1.secret, "bar") + token_3 = Token.objects.get(pk=self.token_3.pk) + self.assertEqual(token_3.secret, "bar") + + def test_cant_update_to_unsaved_object(self): + msg = ( + "Unsaved model instance cannot be used " + "in an ORM query." + ) + + with self.assertRaisesMessage(ValueError, msg): + Comment.objects.update(user=User()) diff --git a/tests/composite_pk/test_values.py b/tests/composite_pk/test_values.py new file mode 100644 index 0000000000..a3c7a589cc --- /dev/null +++ b/tests/composite_pk/test_values.py @@ -0,0 +1,212 @@ +from collections import namedtuple +from uuid import UUID + +from django.test import TestCase + +from .models import Post, Tenant, User + + +class CompositePKValuesTests(TestCase): + USER_1_EMAIL = "user0001@example.com" + USER_2_EMAIL = "user0002@example.com" + USER_3_EMAIL = "user0003@example.com" + POST_1_ID = "77777777-7777-7777-7777-777777777777" + POST_2_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" + POST_3_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.tenant_1 = Tenant.objects.create() + cls.tenant_2 = Tenant.objects.create() + cls.user_1 = User.objects.create( + tenant=cls.tenant_1, id=1, email=cls.USER_1_EMAIL + ) + cls.user_2 = User.objects.create( + tenant=cls.tenant_1, id=2, email=cls.USER_2_EMAIL + ) + cls.user_3 = User.objects.create( + tenant=cls.tenant_2, id=3, email=cls.USER_3_EMAIL + ) + cls.post_1 = Post.objects.create(tenant=cls.tenant_1, id=cls.POST_1_ID) + cls.post_2 = Post.objects.create(tenant=cls.tenant_1, id=cls.POST_2_ID) + cls.post_3 = Post.objects.create(tenant=cls.tenant_2, id=cls.POST_3_ID) + + def test_values_list(self): + with self.subTest('User.objects.values_list("pk")'): + self.assertSequenceEqual( + User.objects.values_list("pk").order_by("pk"), + ( + (self.user_1.pk,), + (self.user_2.pk,), + (self.user_3.pk,), + ), + ) + with self.subTest('User.objects.values_list("pk", "email")'): + self.assertSequenceEqual( + User.objects.values_list("pk", "email").order_by("pk"), + ( + (self.user_1.pk, self.USER_1_EMAIL), + (self.user_2.pk, self.USER_2_EMAIL), + (self.user_3.pk, self.USER_3_EMAIL), + ), + ) + with self.subTest('User.objects.values_list("pk", "id")'): + self.assertSequenceEqual( + User.objects.values_list("pk", "id").order_by("pk"), + ( + (self.user_1.pk, self.user_1.id), + (self.user_2.pk, self.user_2.id), + (self.user_3.pk, self.user_3.id), + ), + ) + with self.subTest('User.objects.values_list("pk", "tenant_id", "id")'): + self.assertSequenceEqual( + User.objects.values_list("pk", "tenant_id", "id").order_by("pk"), + ( + (self.user_1.pk, self.user_1.tenant_id, self.user_1.id), + (self.user_2.pk, self.user_2.tenant_id, self.user_2.id), + (self.user_3.pk, self.user_3.tenant_id, self.user_3.id), + ), + ) + with self.subTest('User.objects.values_list("pk", flat=True)'): + self.assertSequenceEqual( + User.objects.values_list("pk", flat=True).order_by("pk"), + ( + self.user_1.pk, + self.user_2.pk, + self.user_3.pk, + ), + ) + with self.subTest('Post.objects.values_list("pk", flat=True)'): + self.assertSequenceEqual( + Post.objects.values_list("pk", flat=True).order_by("pk"), + ( + (self.tenant_1.id, UUID(self.POST_1_ID)), + (self.tenant_1.id, UUID(self.POST_2_ID)), + (self.tenant_2.id, UUID(self.POST_3_ID)), + ), + ) + with self.subTest('Post.objects.values_list("pk")'): + self.assertSequenceEqual( + Post.objects.values_list("pk").order_by("pk"), + ( + ((self.tenant_1.id, UUID(self.POST_1_ID)),), + ((self.tenant_1.id, UUID(self.POST_2_ID)),), + ((self.tenant_2.id, UUID(self.POST_3_ID)),), + ), + ) + with self.subTest('Post.objects.values_list("pk", "id")'): + self.assertSequenceEqual( + Post.objects.values_list("pk", "id").order_by("pk"), + ( + ((self.tenant_1.id, UUID(self.POST_1_ID)), UUID(self.POST_1_ID)), + ((self.tenant_1.id, UUID(self.POST_2_ID)), UUID(self.POST_2_ID)), + ((self.tenant_2.id, UUID(self.POST_3_ID)), UUID(self.POST_3_ID)), + ), + ) + with self.subTest('Post.objects.values_list("id", "pk")'): + self.assertSequenceEqual( + Post.objects.values_list("id", "pk").order_by("pk"), + ( + (UUID(self.POST_1_ID), (self.tenant_1.id, UUID(self.POST_1_ID))), + (UUID(self.POST_2_ID), (self.tenant_1.id, UUID(self.POST_2_ID))), + (UUID(self.POST_3_ID), (self.tenant_2.id, UUID(self.POST_3_ID))), + ), + ) + with self.subTest('User.objects.values_list("pk", named=True)'): + Row = namedtuple("Row", ["pk"]) + self.assertSequenceEqual( + User.objects.values_list("pk", named=True).order_by("pk"), + ( + Row(pk=self.user_1.pk), + Row(pk=self.user_2.pk), + Row(pk=self.user_3.pk), + ), + ) + with self.subTest('User.objects.values_list("pk", "pk")'): + self.assertSequenceEqual( + User.objects.values_list("pk", "pk").order_by("pk"), + ( + (self.user_1.pk,), + (self.user_2.pk,), + (self.user_3.pk,), + ), + ) + with self.subTest('User.objects.values_list("pk", "id", "pk", "id")'): + self.assertSequenceEqual( + User.objects.values_list("pk", "id", "pk", "id").order_by("pk"), + ( + (self.user_1.pk, self.user_1.id), + (self.user_2.pk, self.user_2.id), + (self.user_3.pk, self.user_3.id), + ), + ) + + def test_values(self): + with self.subTest('User.objects.values("pk")'): + self.assertSequenceEqual( + User.objects.values("pk").order_by("pk"), + ( + {"pk": self.user_1.pk}, + {"pk": self.user_2.pk}, + {"pk": self.user_3.pk}, + ), + ) + with self.subTest('User.objects.values("pk", "email")'): + self.assertSequenceEqual( + User.objects.values("pk", "email").order_by("pk"), + ( + {"pk": self.user_1.pk, "email": self.USER_1_EMAIL}, + {"pk": self.user_2.pk, "email": self.USER_2_EMAIL}, + {"pk": self.user_3.pk, "email": self.USER_3_EMAIL}, + ), + ) + with self.subTest('User.objects.values("pk", "id")'): + self.assertSequenceEqual( + User.objects.values("pk", "id").order_by("pk"), + ( + {"pk": self.user_1.pk, "id": self.user_1.id}, + {"pk": self.user_2.pk, "id": self.user_2.id}, + {"pk": self.user_3.pk, "id": self.user_3.id}, + ), + ) + with self.subTest('User.objects.values("pk", "tenant_id", "id")'): + self.assertSequenceEqual( + User.objects.values("pk", "tenant_id", "id").order_by("pk"), + ( + { + "pk": self.user_1.pk, + "tenant_id": self.user_1.tenant_id, + "id": self.user_1.id, + }, + { + "pk": self.user_2.pk, + "tenant_id": self.user_2.tenant_id, + "id": self.user_2.id, + }, + { + "pk": self.user_3.pk, + "tenant_id": self.user_3.tenant_id, + "id": self.user_3.id, + }, + ), + ) + with self.subTest('User.objects.values("pk", "pk")'): + self.assertSequenceEqual( + User.objects.values("pk", "pk").order_by("pk"), + ( + {"pk": self.user_1.pk}, + {"pk": self.user_2.pk}, + {"pk": self.user_3.pk}, + ), + ) + with self.subTest('User.objects.values("pk", "id", "pk", "id")'): + self.assertSequenceEqual( + User.objects.values("pk", "id", "pk", "id").order_by("pk"), + ( + {"pk": self.user_1.pk, "id": self.user_1.id}, + {"pk": self.user_2.pk, "id": self.user_2.id}, + {"pk": self.user_3.pk, "id": self.user_3.id}, + ), + ) diff --git a/tests/composite_pk/tests.py b/tests/composite_pk/tests.py new file mode 100644 index 0000000000..71522cb836 --- /dev/null +++ b/tests/composite_pk/tests.py @@ -0,0 +1,345 @@ +import json +import unittest +from uuid import UUID + +import yaml + +from django import forms +from django.core import serializers +from django.core.exceptions import FieldError +from django.db import IntegrityError, connection +from django.db.models import CompositePrimaryKey +from django.forms import modelform_factory +from django.test import TestCase + +from .models import Comment, Post, Tenant, User + + +class CommentForm(forms.ModelForm): + class Meta: + model = Comment + fields = "__all__" + + +class CompositePKTests(TestCase): + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant = Tenant.objects.create() + cls.user = User.objects.create( + tenant=cls.tenant, + id=1, + email="user0001@example.com", + ) + cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user) + + @staticmethod + def get_constraints(table): + with connection.cursor() as cursor: + return connection.introspection.get_constraints(cursor, table) + + def test_pk_updated_if_field_updated(self): + user = User.objects.get(pk=self.user.pk) + self.assertEqual(user.pk, (self.tenant.id, self.user.id)) + self.assertIs(user._is_pk_set(), True) + user.tenant_id = 9831 + self.assertEqual(user.pk, (9831, self.user.id)) + self.assertIs(user._is_pk_set(), True) + user.id = 4321 + self.assertEqual(user.pk, (9831, 4321)) + self.assertIs(user._is_pk_set(), True) + user.pk = (9132, 3521) + self.assertEqual(user.tenant_id, 9132) + self.assertEqual(user.id, 3521) + self.assertIs(user._is_pk_set(), True) + user.id = None + self.assertEqual(user.pk, (9132, None)) + self.assertEqual(user.tenant_id, 9132) + self.assertIsNone(user.id) + self.assertIs(user._is_pk_set(), False) + + def test_hash(self): + self.assertEqual(hash(User(pk=(1, 2))), hash((1, 2))) + self.assertEqual(hash(User(tenant_id=2, id=3)), hash((2, 3))) + msg = "Model instances without primary key value are unhashable" + + with self.assertRaisesMessage(TypeError, msg): + hash(User()) + with self.assertRaisesMessage(TypeError, msg): + hash(User(tenant_id=1)) + with self.assertRaisesMessage(TypeError, msg): + hash(User(id=1)) + + def test_pk_must_be_list_or_tuple(self): + user = User.objects.get(pk=self.user.pk) + test_cases = [ + "foo", + 1000, + 3.14, + True, + False, + ] + + for pk in test_cases: + with self.assertRaisesMessage( + ValueError, "'pk' must be a list or a tuple." + ): + user.pk = pk + + def test_pk_must_have_2_elements(self): + user = User.objects.get(pk=self.user.pk) + test_cases = [ + (), + [], + (1000,), + [1000], + (1, 2, 3), + [1, 2, 3], + ] + + for pk in test_cases: + with self.assertRaisesMessage(ValueError, "'pk' must have 2 elements."): + user.pk = pk + + def test_composite_pk_in_fields(self): + user_fields = {f.name for f in User._meta.get_fields()} + self.assertEqual(user_fields, {"pk", "tenant", "id", "email", "comments"}) + + comment_fields = {f.name for f in Comment._meta.get_fields()} + self.assertEqual( + comment_fields, + {"pk", "tenant", "id", "user_id", "user", "text"}, + ) + + def test_pk_field(self): + pk = User._meta.get_field("pk") + self.assertIsInstance(pk, CompositePrimaryKey) + self.assertIs(User._meta.pk, pk) + + def test_error_on_user_pk_conflict(self): + with self.assertRaises(IntegrityError): + User.objects.create(tenant=self.tenant, id=self.user.id) + + def test_error_on_comment_pk_conflict(self): + with self.assertRaises(IntegrityError): + Comment.objects.create(tenant=self.tenant, id=self.comment.id) + + @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test") + def test_get_constraints_postgresql(self): + user_constraints = self.get_constraints(User._meta.db_table) + user_pk = user_constraints["composite_pk_user_pkey"] + self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) + self.assertIs(user_pk["primary_key"], True) + + comment_constraints = self.get_constraints(Comment._meta.db_table) + comment_pk = comment_constraints["composite_pk_comment_pkey"] + self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"]) + self.assertIs(comment_pk["primary_key"], True) + + @unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test") + def test_get_constraints_sqlite(self): + user_constraints = self.get_constraints(User._meta.db_table) + user_pk = user_constraints["__primary__"] + self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) + self.assertIs(user_pk["primary_key"], True) + + comment_constraints = self.get_constraints(Comment._meta.db_table) + comment_pk = comment_constraints["__primary__"] + self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"]) + self.assertIs(comment_pk["primary_key"], True) + + @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific test") + def test_get_constraints_mysql(self): + user_constraints = self.get_constraints(User._meta.db_table) + user_pk = user_constraints["PRIMARY"] + self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) + self.assertIs(user_pk["primary_key"], True) + + comment_constraints = self.get_constraints(Comment._meta.db_table) + comment_pk = comment_constraints["PRIMARY"] + self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"]) + self.assertIs(comment_pk["primary_key"], True) + + @unittest.skipUnless(connection.vendor == "oracle", "Oracle specific test") + def test_get_constraints_oracle(self): + user_constraints = self.get_constraints(User._meta.db_table) + user_pk = next(c for c in user_constraints.values() if c["primary_key"]) + self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) + self.assertEqual(user_pk["primary_key"], 1) + + comment_constraints = self.get_constraints(Comment._meta.db_table) + comment_pk = next(c for c in comment_constraints.values() if c["primary_key"]) + self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"]) + self.assertEqual(comment_pk["primary_key"], 1) + + def test_in_bulk(self): + """ + Test the .in_bulk() method of composite_pk models. + """ + result = Comment.objects.in_bulk() + self.assertEqual(result, {self.comment.pk: self.comment}) + + result = Comment.objects.in_bulk([self.comment.pk]) + self.assertEqual(result, {self.comment.pk: self.comment}) + + def test_iterator(self): + """ + Test the .iterator() method of composite_pk models. + """ + result = list(Comment.objects.iterator()) + self.assertEqual(result, [self.comment]) + + def test_query(self): + users = User.objects.values_list("pk").order_by("pk") + self.assertNotIn('AS "pk"', str(users.query)) + + def test_only(self): + users = User.objects.only("pk") + self.assertSequenceEqual(users, (self.user,)) + user = users[0] + + with self.assertNumQueries(0): + self.assertEqual(user.pk, (self.user.tenant_id, self.user.id)) + self.assertEqual(user.tenant_id, self.user.tenant_id) + self.assertEqual(user.id, self.user.id) + with self.assertNumQueries(1): + self.assertEqual(user.email, self.user.email) + + def test_model_forms(self): + fields = ["tenant", "id", "user_id", "text"] + self.assertEqual(list(CommentForm.base_fields), fields) + + form = modelform_factory(Comment, fields="__all__") + self.assertEqual(list(form().fields), fields) + + with self.assertRaisesMessage( + FieldError, "Unknown field(s) (pk) specified for Comment" + ): + self.assertIsNone(modelform_factory(Comment, fields=["pk"])) + + +class CompositePKFixturesTests(TestCase): + fixtures = ["tenant"] + + def test_objects(self): + tenant_1, tenant_2, tenant_3 = Tenant.objects.order_by("pk") + self.assertEqual(tenant_1.id, 1) + self.assertEqual(tenant_1.name, "Tenant 1") + self.assertEqual(tenant_2.id, 2) + self.assertEqual(tenant_2.name, "Tenant 2") + self.assertEqual(tenant_3.id, 3) + self.assertEqual(tenant_3.name, "Tenant 3") + + user_1, user_2, user_3, user_4 = User.objects.order_by("pk") + self.assertEqual(user_1.id, 1) + self.assertEqual(user_1.tenant_id, 1) + self.assertEqual(user_1.pk, (user_1.tenant_id, user_1.id)) + self.assertEqual(user_1.email, "user0001@example.com") + self.assertEqual(user_2.id, 2) + self.assertEqual(user_2.tenant_id, 1) + self.assertEqual(user_2.pk, (user_2.tenant_id, user_2.id)) + self.assertEqual(user_2.email, "user0002@example.com") + self.assertEqual(user_3.id, 3) + self.assertEqual(user_3.tenant_id, 2) + self.assertEqual(user_3.pk, (user_3.tenant_id, user_3.id)) + self.assertEqual(user_3.email, "user0003@example.com") + self.assertEqual(user_4.id, 4) + self.assertEqual(user_4.tenant_id, 2) + self.assertEqual(user_4.pk, (user_4.tenant_id, user_4.id)) + self.assertEqual(user_4.email, "user0004@example.com") + + post_1, post_2 = Post.objects.order_by("pk") + self.assertEqual(post_1.id, UUID("11111111-1111-1111-1111-111111111111")) + self.assertEqual(post_1.tenant_id, 2) + self.assertEqual(post_1.pk, (post_1.tenant_id, post_1.id)) + self.assertEqual(post_2.id, UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")) + self.assertEqual(post_2.tenant_id, 2) + self.assertEqual(post_2.pk, (post_2.tenant_id, post_2.id)) + + def test_serialize_user_json(self): + users = User.objects.filter(pk=(1, 1)) + result = serializers.serialize("json", users) + self.assertEqual( + json.loads(result), + [ + { + "model": "composite_pk.user", + "pk": [1, 1], + "fields": { + "email": "user0001@example.com", + "id": 1, + "tenant": 1, + }, + } + ], + ) + + def test_serialize_user_jsonl(self): + users = User.objects.filter(pk=(1, 2)) + result = serializers.serialize("jsonl", users) + self.assertEqual( + json.loads(result), + { + "model": "composite_pk.user", + "pk": [1, 2], + "fields": { + "email": "user0002@example.com", + "id": 2, + "tenant": 1, + }, + }, + ) + + def test_serialize_user_yaml(self): + users = User.objects.filter(pk=(2, 3)) + result = serializers.serialize("yaml", users) + self.assertEqual( + yaml.safe_load(result), + [ + { + "model": "composite_pk.user", + "pk": [2, 3], + "fields": { + "email": "user0003@example.com", + "id": 3, + "tenant": 2, + }, + }, + ], + ) + + def test_serialize_user_python(self): + users = User.objects.filter(pk=(2, 4)) + result = serializers.serialize("python", users) + self.assertEqual( + result, + [ + { + "model": "composite_pk.user", + "pk": [2, 4], + "fields": { + "email": "user0004@example.com", + "id": 4, + "tenant": 2, + }, + }, + ], + ) + + def test_serialize_post_uuid(self): + posts = Post.objects.filter(pk=(2, "11111111-1111-1111-1111-111111111111")) + result = serializers.serialize("json", posts) + self.assertEqual( + json.loads(result), + [ + { + "model": "composite_pk.post", + "pk": [2, "11111111-1111-1111-1111-111111111111"], + "fields": { + "id": "11111111-1111-1111-1111-111111111111", + "tenant": 2, + }, + }, + ], + ) diff --git a/tests/migrations/test_autodetector.py b/tests/migrations/test_autodetector.py index de62170eb3..33196ea6f4 100644 --- a/tests/migrations/test_autodetector.py +++ b/tests/migrations/test_autodetector.py @@ -5059,6 +5059,95 @@ class AutodetectorTests(BaseAutodetectorTests): self.assertOperationTypes(changes, "testapp", 0, ["CreateModel"]) self.assertOperationAttributes(changes, "testapp", 0, 0, name="Book") + @mock.patch( + "django.db.migrations.questioner.MigrationQuestioner.ask_not_null_addition" + ) + def test_add_composite_pk(self, mocked_ask_method): + before = [ + ModelState( + "app", + "foo", + [ + ("id", models.AutoField(primary_key=True)), + ], + ), + ] + after = [ + ModelState( + "app", + "foo", + [ + ("pk", models.CompositePrimaryKey("foo_id", "bar_id")), + ("id", models.IntegerField()), + ], + ), + ] + + changes = self.get_changes(before, after) + self.assertEqual(mocked_ask_method.call_count, 0) + self.assertNumberMigrations(changes, "app", 1) + self.assertOperationTypes(changes, "app", 0, ["AddField", "AlterField"]) + self.assertOperationAttributes( + changes, + "app", + 0, + 0, + name="pk", + model_name="foo", + preserve_default=True, + ) + self.assertOperationAttributes( + changes, + "app", + 0, + 1, + name="id", + model_name="foo", + preserve_default=True, + ) + + def test_remove_composite_pk(self): + before = [ + ModelState( + "app", + "foo", + [ + ("pk", models.CompositePrimaryKey("foo_id", "bar_id")), + ("id", models.IntegerField()), + ], + ), + ] + after = [ + ModelState( + "app", + "foo", + [ + ("id", models.AutoField(primary_key=True)), + ], + ), + ] + + changes = self.get_changes(before, after) + self.assertNumberMigrations(changes, "app", 1) + self.assertOperationTypes(changes, "app", 0, ["RemoveField", "AlterField"]) + self.assertOperationAttributes( + changes, + "app", + 0, + 0, + name="pk", + model_name="foo", + ) + self.assertOperationAttributes( + changes, + "app", + 0, + 1, + name="id", + model_name="foo", + preserve_default=True, + ) + class MigrationSuggestNameTests(SimpleTestCase): def test_no_operations(self): diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index da0ec93dcd..6312a7d4a2 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -6287,6 +6287,61 @@ class OperationTests(OperationTestBase): self.assertEqual(pony_new.generated, 1) self.assertEqual(pony_new.static, 2) + def test_composite_pk_operations(self): + app_label = "test_d8d90af6" + project_state = self.set_up_test_model(app_label) + operation_1 = migrations.AddField( + "Pony", "pk", models.CompositePrimaryKey("id", "pink") + ) + operation_2 = migrations.AlterField("Pony", "id", models.IntegerField()) + operation_3 = migrations.RemoveField("Pony", "pk") + table_name = f"{app_label}_pony" + + # 1. Add field (pk). + new_state = project_state.clone() + operation_1.state_forwards(app_label, new_state) + with connection.schema_editor() as editor: + operation_1.database_forwards(app_label, editor, project_state, new_state) + self.assertColumnNotExists(table_name, "pk") + Pony = new_state.apps.get_model(app_label, "pony") + obj_1 = Pony.objects.create(weight=1) + msg = ( + f"obj_1={obj_1}, " + f"obj_1.id={obj_1.id}, " + f"obj_1.pink={obj_1.pink}, " + f"obj_1.pk={obj_1.pk}, " + f"Pony._meta.pk={repr(Pony._meta.pk)}, " + f"Pony._meta.get_field('id')={repr(Pony._meta.get_field('id'))}" + ) + self.assertEqual(obj_1.pink, 3, msg) + self.assertEqual(obj_1.pk, (obj_1.id, obj_1.pink), msg) + + # 2. Alter field (id -> IntegerField()). + project_state, new_state = new_state, new_state.clone() + operation_2.state_forwards(app_label, new_state) + with connection.schema_editor() as editor: + operation_2.database_forwards(app_label, editor, project_state, new_state) + Pony = new_state.apps.get_model(app_label, "pony") + obj_1 = Pony.objects.get(id=obj_1.id) + self.assertEqual(obj_1.pink, 3) + self.assertEqual(obj_1.pk, (obj_1.id, obj_1.pink)) + obj_2 = Pony.objects.create(id=2, weight=2) + self.assertEqual(obj_2.id, 2) + self.assertEqual(obj_2.pink, 3) + self.assertEqual(obj_2.pk, (obj_2.id, obj_2.pink)) + + # 3. Remove field (pk). + project_state, new_state = new_state, new_state.clone() + operation_3.state_forwards(app_label, new_state) + with connection.schema_editor() as editor: + operation_3.database_forwards(app_label, editor, project_state, new_state) + Pony = new_state.apps.get_model(app_label, "pony") + obj_1 = Pony.objects.get(id=obj_1.id) + self.assertEqual(obj_1.pk, obj_1.id) + obj_2 = Pony.objects.get(id=obj_2.id) + self.assertEqual(obj_2.id, 2) + self.assertEqual(obj_2.pk, obj_2.id) + class SwappableOperationTests(OperationTestBase): """ diff --git a/tests/migrations/test_state.py b/tests/migrations/test_state.py index dbbdf77734..d6ecaa1c5d 100644 --- a/tests/migrations/test_state.py +++ b/tests/migrations/test_state.py @@ -1206,6 +1206,28 @@ class StateTests(SimpleTestCase): choices_field = Author._meta.get_field("choice") self.assertEqual(list(choices_field.choices), choices) + def test_composite_pk_state(self): + new_apps = Apps(["migrations"]) + + class Foo(models.Model): + pk = models.CompositePrimaryKey("account_id", "id") + account_id = models.SmallIntegerField() + id = models.SmallIntegerField() + + class Meta: + app_label = "migrations" + apps = new_apps + + project_state = ProjectState.from_apps(new_apps) + model_state = project_state.models["migrations", "foo"] + self.assertEqual(len(model_state.options), 2) + self.assertEqual(model_state.options["constraints"], []) + self.assertEqual(model_state.options["indexes"], []) + self.assertEqual(len(model_state.fields), 3) + self.assertIn("pk", model_state.fields) + self.assertIn("account_id", model_state.fields) + self.assertIn("id", model_state.fields) + class StateRelationsTests(SimpleTestCase): def get_base_project_state(self): diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index 51783b7346..953a3cdb6c 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -1138,3 +1138,22 @@ class WriterTests(SimpleTestCase): ValueError, "'TestModel1' must inherit from 'BaseSerializer'." ): MigrationWriter.register_serializer(complex, TestModel1) + + def test_composite_pk_import(self): + migration = type( + "Migration", + (migrations.Migration,), + { + "operations": [ + migrations.AddField( + "foo", + "bar", + models.CompositePrimaryKey("foo_id", "bar_id"), + ), + ], + }, + ) + writer = MigrationWriter(migration) + output = writer.as_string() + self.assertEqual(output.count("import"), 1) + self.assertIn("from django.db import migrations, models", output)