1
0
mirror of https://github.com/django/django.git synced 2024-12-22 09:05:43 +00:00

Fixed #373 -- Added CompositePrimaryKey.

Thanks Lily Foote and Simon Charette for reviews and mentoring
this Google Summer of Code 2024 project.

Co-authored-by: Simon Charette <charette.s@gmail.com>
Co-authored-by: Lily Foote <code@lilyf.org>
This commit is contained in:
Bendeguz Csirmaz 2024-04-07 10:32:16 +08:00 committed by Sarah Boyce
parent 86661f2449
commit 978aae4334
43 changed files with 3078 additions and 29 deletions

View File

@ -113,6 +113,11 @@ class AdminSite:
"The model %s is abstract, so it cannot be registered with admin." "The model %s is abstract, so it cannot be registered with admin."
% model.__name__ % model.__name__
) )
if model._meta.is_composite_pk:
raise ImproperlyConfigured(
"The model %s has a composite primary key, so it cannot be "
"registered with admin." % model.__name__
)
if self.is_registered(model): if self.is_registered(model):
registered_admin = str(self.get_model_admin(model)) registered_admin = str(self.get_model_admin(model))

View File

@ -7,6 +7,7 @@ other serializers.
from django.apps import apps from django.apps import apps
from django.core.serializers import base from django.core.serializers import base
from django.db import DEFAULT_DB_ALIAS, models from django.db import DEFAULT_DB_ALIAS, models
from django.db.models import CompositePrimaryKey
from django.utils.encoding import is_protected_type from django.utils.encoding import is_protected_type
@ -39,6 +40,8 @@ class Serializer(base.Serializer):
return data return data
def _value_from_field(self, obj, field): def _value_from_field(self, obj, field):
if isinstance(field, CompositePrimaryKey):
return [self._value_from_field(obj, f) for f in field]
value = field.value_from_object(obj) value = field.value_from_object(obj)
# Protected types (i.e., primitives like None, numbers, dates, # Protected types (i.e., primitives like None, numbers, dates,
# and Decimals) are passed through as is. All other values are # and Decimals) are passed through as is. All other values are

View File

@ -14,6 +14,7 @@ from django.db.backends.ddl_references import (
) )
from django.db.backends.utils import names_digest, split_identifier, truncate_name from django.db.backends.utils import names_digest, split_identifier, truncate_name
from django.db.models import NOT_PROVIDED, Deferrable, Index from django.db.models import NOT_PROVIDED, Deferrable, Index
from django.db.models.fields.composite import CompositePrimaryKey
from django.db.models.sql import Query from django.db.models.sql import Query
from django.db.transaction import TransactionManagementError, atomic from django.db.transaction import TransactionManagementError, atomic
from django.utils import timezone from django.utils import timezone
@ -106,6 +107,7 @@ class BaseDatabaseSchemaEditor:
sql_check_constraint = "CHECK (%(check)s)" sql_check_constraint = "CHECK (%(check)s)"
sql_delete_constraint = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_delete_constraint = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_constraint = "CONSTRAINT %(name)s %(constraint)s" sql_constraint = "CONSTRAINT %(name)s %(constraint)s"
sql_pk_constraint = "PRIMARY KEY (%(columns)s)"
sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)" sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
sql_delete_check = sql_delete_constraint sql_delete_check = sql_delete_constraint
@ -282,6 +284,11 @@ class BaseDatabaseSchemaEditor:
constraint.constraint_sql(model, self) constraint.constraint_sql(model, self)
for constraint in model._meta.constraints for constraint in model._meta.constraints
) )
pk = model._meta.pk
if isinstance(pk, CompositePrimaryKey):
constraint_sqls.append(self._pk_constraint_sql(pk.columns))
sql = self.sql_create_table % { sql = self.sql_create_table % {
"table": self.quote_name(model._meta.db_table), "table": self.quote_name(model._meta.db_table),
"definition": ", ".join( "definition": ", ".join(
@ -1999,6 +2006,11 @@ class BaseDatabaseSchemaEditor:
result.append(name) result.append(name)
return result return result
def _pk_constraint_sql(self, columns):
return self.sql_pk_constraint % {
"columns": ", ".join(self.quote_name(column) for column in columns)
}
def _delete_primary_key(self, model, strict=False): def _delete_primary_key(self, model, strict=False):
constraint_names = self._constraint_names(model, primary_key=True) constraint_names = self._constraint_names(model, primary_key=True)
if strict and len(constraint_names) != 1: if strict and len(constraint_names) != 1:

View File

@ -211,6 +211,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
return create_index return create_index
def _is_identity_column(self, table_name, column_name): def _is_identity_column(self, table_name, column_name):
if not column_name:
return False
with self.connection.cursor() as cursor: with self.connection.cursor() as cursor:
cursor.execute( cursor.execute(
""" """

View File

@ -6,7 +6,7 @@ from django.db import NotSupportedError
from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import Statement from django.db.backends.ddl_references import Statement
from django.db.backends.utils import strip_quotes from django.db.backends.utils import strip_quotes
from django.db.models import NOT_PROVIDED, UniqueConstraint from django.db.models import NOT_PROVIDED, CompositePrimaryKey, UniqueConstraint
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
@ -104,6 +104,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
f.name: f.clone() if is_self_referential(f) else f f.name: f.clone() if is_self_referential(f) else f
for f in model._meta.local_concrete_fields for f in model._meta.local_concrete_fields
} }
# Since CompositePrimaryKey is not a concrete field (column is None),
# it's not copied by default.
pk = model._meta.pk
if isinstance(pk, CompositePrimaryKey):
body[pk.name] = pk.clone()
# Since mapping might mix column names and default values, # Since mapping might mix column names and default values,
# its values must be already quoted. # its values must be already quoted.
mapping = { mapping = {
@ -296,6 +303,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Special-case implicit M2M tables. # Special-case implicit M2M tables.
if field.many_to_many and field.remote_field.through._meta.auto_created: if field.many_to_many and field.remote_field.through._meta.auto_created:
self.create_model(field.remote_field.through) self.create_model(field.remote_field.through)
elif isinstance(field, CompositePrimaryKey):
# If a CompositePrimaryKey field was added, the existing primary key field
# had to be altered too, resulting in an AddField, AlterField migration.
# The table cannot be re-created on AddField, it would result in a
# duplicate primary key error.
return
elif ( elif (
# Primary keys and unique fields are not supported in ALTER TABLE # Primary keys and unique fields are not supported in ALTER TABLE
# ADD COLUMN. # ADD COLUMN.

View File

@ -38,6 +38,7 @@ from django.db.models.expressions import (
) )
from django.db.models.fields import * # NOQA from django.db.models.fields import * # NOQA
from django.db.models.fields import __all__ as fields_all from django.db.models.fields import __all__ as fields_all
from django.db.models.fields.composite import CompositePrimaryKey
from django.db.models.fields.files import FileField, ImageField from django.db.models.fields.files import FileField, ImageField
from django.db.models.fields.generated import GeneratedField from django.db.models.fields.generated import GeneratedField
from django.db.models.fields.json import JSONField from django.db.models.fields.json import JSONField
@ -82,6 +83,7 @@ __all__ += [
"ProtectedError", "ProtectedError",
"RestrictedError", "RestrictedError",
"Case", "Case",
"CompositePrimaryKey",
"Exists", "Exists",
"Expression", "Expression",
"ExpressionList", "ExpressionList",

View File

@ -3,7 +3,8 @@ Classes to represent the definitions of aggregate functions.
""" """
from django.core.exceptions import FieldError, FullResultSet from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Case, Func, Star, Value, When from django.db import NotSupportedError
from django.db.models.expressions import Case, ColPairs, Func, Star, Value, When
from django.db.models.fields import IntegerField from django.db.models.fields import IntegerField
from django.db.models.functions.comparison import Coalesce from django.db.models.functions.comparison import Coalesce
from django.db.models.functions.mixins import ( from django.db.models.functions.mixins import (
@ -174,6 +175,22 @@ class Count(Aggregate):
raise ValueError("Star cannot be used with filter. Please specify a field.") raise ValueError("Star cannot be used with filter. Please specify a field.")
super().__init__(expression, filter=filter, **extra) super().__init__(expression, filter=filter, **extra)
def resolve_expression(self, *args, **kwargs):
result = super().resolve_expression(*args, **kwargs)
expr = result.source_expressions[0]
# In case of composite primary keys, count the first column.
if isinstance(expr, ColPairs):
if self.distinct:
raise NotSupportedError(
"COUNT(DISTINCT) doesn't support composite primary keys"
)
cols = expr.get_cols()
return Count(cols[0], filter=result.filter)
return result
class Max(Aggregate): class Max(Aggregate):
function = "MAX" function = "MAX"

View File

@ -1,6 +1,7 @@
import copy import copy
import inspect import inspect
import warnings import warnings
from collections import defaultdict
from functools import partialmethod from functools import partialmethod
from itertools import chain from itertools import chain
@ -30,6 +31,7 @@ from django.db.models import NOT_PROVIDED, ExpressionWrapper, IntegerField, Max,
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.deletion import CASCADE, Collector from django.db.models.deletion import CASCADE, Collector
from django.db.models.expressions import DatabaseDefault from django.db.models.expressions import DatabaseDefault
from django.db.models.fields.composite import CompositePrimaryKey
from django.db.models.fields.related import ( from django.db.models.fields.related import (
ForeignObjectRel, ForeignObjectRel,
OneToOneField, OneToOneField,
@ -508,7 +510,7 @@ class Model(AltersData, metaclass=ModelBase):
for field in fields_iter: for field in fields_iter:
is_related_object = False is_related_object = False
# Virtual field # Virtual field
if field.attname not in kwargs and field.column is None or field.generated: if field.column is None or field.generated:
continue continue
if kwargs: if kwargs:
if isinstance(field.remote_field, ForeignObjectRel): if isinstance(field.remote_field, ForeignObjectRel):
@ -663,7 +665,11 @@ class Model(AltersData, metaclass=ModelBase):
pk = property(_get_pk_val, _set_pk_val) pk = property(_get_pk_val, _set_pk_val)
def _is_pk_set(self, meta=None): def _is_pk_set(self, meta=None):
return self._get_pk_val(meta) is not None pk_val = self._get_pk_val(meta)
return not (
pk_val is None
or (isinstance(pk_val, tuple) and any(f is None for f in pk_val))
)
def get_deferred_fields(self): def get_deferred_fields(self):
""" """
@ -1454,6 +1460,11 @@ class Model(AltersData, metaclass=ModelBase):
name = f.name name = f.name
if name in exclude: if name in exclude:
continue continue
if isinstance(f, CompositePrimaryKey):
names = tuple(field.name for field in f.fields)
if exclude.isdisjoint(names):
unique_checks.append((model_class, names))
continue
if f.unique: if f.unique:
unique_checks.append((model_class, (name,))) unique_checks.append((model_class, (name,)))
if f.unique_for_date and f.unique_for_date not in exclude: if f.unique_for_date and f.unique_for_date not in exclude:
@ -1728,6 +1739,7 @@ class Model(AltersData, metaclass=ModelBase):
*cls._check_constraints(databases), *cls._check_constraints(databases),
*cls._check_default_pk(), *cls._check_default_pk(),
*cls._check_db_table_comment(databases), *cls._check_db_table_comment(databases),
*cls._check_composite_pk(),
] ]
return errors return errors
@ -1764,6 +1776,63 @@ class Model(AltersData, metaclass=ModelBase):
] ]
return [] return []
@classmethod
def _check_composite_pk(cls):
errors = []
meta = cls._meta
pk = meta.pk
if not isinstance(pk, CompositePrimaryKey):
return errors
seen_columns = defaultdict(list)
for field_name in pk.field_names:
hint = None
try:
field = meta.get_field(field_name)
except FieldDoesNotExist:
field = None
if not field:
hint = f"{field_name!r} is not a valid field."
elif not field.column:
hint = f"{field_name!r} field has no column."
elif field.null:
hint = f"{field_name!r} field may not set 'null=True'."
elif field.generated:
hint = f"{field_name!r} field is a generated field."
else:
seen_columns[field.column].append(field_name)
if hint:
errors.append(
checks.Error(
f"{field_name!r} cannot be included in the composite primary "
"key.",
hint=hint,
obj=cls,
id="models.E042",
)
)
for column, field_names in seen_columns.items():
if len(field_names) > 1:
field_name, *rest = field_names
duplicates = ", ".join(repr(field) for field in rest)
errors.append(
checks.Error(
f"{duplicates} cannot be included in the composite primary "
"key.",
hint=f"{duplicates} and {field_name!r} are the same fields.",
obj=cls,
id="models.E042",
)
)
return errors
@classmethod @classmethod
def _check_db_table_comment(cls, databases): def _check_db_table_comment(cls, databases):
if not cls._meta.db_table_comment: if not cls._meta.db_table_comment:

View File

@ -656,6 +656,8 @@ class Field(RegisterLookupMixin):
path = path.replace("django.db.models.fields.json", "django.db.models") path = path.replace("django.db.models.fields.json", "django.db.models")
elif path.startswith("django.db.models.fields.proxy"): elif path.startswith("django.db.models.fields.proxy"):
path = path.replace("django.db.models.fields.proxy", "django.db.models") path = path.replace("django.db.models.fields.proxy", "django.db.models")
elif path.startswith("django.db.models.fields.composite"):
path = path.replace("django.db.models.fields.composite", "django.db.models")
elif path.startswith("django.db.models.fields"): elif path.startswith("django.db.models.fields"):
path = path.replace("django.db.models.fields", "django.db.models") path = path.replace("django.db.models.fields", "django.db.models")
# Return basic info - other fields should override this. # Return basic info - other fields should override this.

View File

@ -0,0 +1,150 @@
from django.core import checks
from django.db.models import NOT_PROVIDED, Field
from django.db.models.expressions import ColPairs
from django.db.models.fields.tuple_lookups import (
TupleExact,
TupleGreaterThan,
TupleGreaterThanOrEqual,
TupleIn,
TupleIsNull,
TupleLessThan,
TupleLessThanOrEqual,
)
from django.utils.functional import cached_property
class CompositeAttribute:
def __init__(self, field):
self.field = field
@property
def attnames(self):
return [field.attname for field in self.field.fields]
def __get__(self, instance, cls=None):
return tuple(getattr(instance, attname) for attname in self.attnames)
def __set__(self, instance, values):
attnames = self.attnames
length = len(attnames)
if values is None:
values = (None,) * length
if not isinstance(values, (list, tuple)):
raise ValueError(f"{self.field.name!r} must be a list or a tuple.")
if length != len(values):
raise ValueError(f"{self.field.name!r} must have {length} elements.")
for attname, value in zip(attnames, values):
setattr(instance, attname, value)
class CompositePrimaryKey(Field):
descriptor_class = CompositeAttribute
def __init__(self, *args, **kwargs):
if (
not args
or not all(isinstance(field, str) for field in args)
or len(set(args)) != len(args)
):
raise ValueError("CompositePrimaryKey args must be unique strings.")
if len(args) == 1:
raise ValueError("CompositePrimaryKey must include at least two fields.")
if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
raise ValueError("CompositePrimaryKey cannot have a default.")
if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
raise ValueError("CompositePrimaryKey cannot have a database default.")
if kwargs.setdefault("editable", False):
raise ValueError("CompositePrimaryKey cannot be editable.")
if not kwargs.setdefault("primary_key", True):
raise ValueError("CompositePrimaryKey must be a primary key.")
if not kwargs.setdefault("blank", True):
raise ValueError("CompositePrimaryKey must be blank.")
self.field_names = args
super().__init__(**kwargs)
def deconstruct(self):
# args is always [] so it can be ignored.
name, path, _, kwargs = super().deconstruct()
return name, path, self.field_names, kwargs
@cached_property
def fields(self):
meta = self.model._meta
return tuple(meta.get_field(field_name) for field_name in self.field_names)
@cached_property
def columns(self):
return tuple(field.column for field in self.fields)
def contribute_to_class(self, cls, name, private_only=False):
super().contribute_to_class(cls, name, private_only=private_only)
cls._meta.pk = self
setattr(cls, self.attname, self.descriptor_class(self))
def get_attname_column(self):
return self.get_attname(), None
def __iter__(self):
return iter(self.fields)
def __len__(self):
return len(self.field_names)
@cached_property
def cached_col(self):
return ColPairs(self.model._meta.db_table, self.fields, self.fields, self)
def get_col(self, alias, output_field=None):
if alias == self.model._meta.db_table and (
output_field is None or output_field == self
):
return self.cached_col
return ColPairs(alias, self.fields, self.fields, output_field)
def get_pk_value_on_save(self, instance):
values = []
for field in self.fields:
value = field.value_from_object(instance)
if value is None:
value = field.get_pk_value_on_save(instance)
values.append(value)
return tuple(values)
def _check_field_name(self):
if self.name == "pk":
return []
return [
checks.Error(
"'CompositePrimaryKey' must be named 'pk'.",
obj=self,
id="fields.E013",
)
]
CompositePrimaryKey.register_lookup(TupleExact)
CompositePrimaryKey.register_lookup(TupleGreaterThan)
CompositePrimaryKey.register_lookup(TupleGreaterThanOrEqual)
CompositePrimaryKey.register_lookup(TupleLessThan)
CompositePrimaryKey.register_lookup(TupleLessThanOrEqual)
CompositePrimaryKey.register_lookup(TupleIn)
CompositePrimaryKey.register_lookup(TupleIsNull)
def unnest(fields):
result = []
for field in fields:
if isinstance(field, CompositePrimaryKey):
result.extend(field.fields)
else:
result.append(field)
return result

View File

@ -624,11 +624,21 @@ class ForeignObject(RelatedField):
if not has_unique_constraint: if not has_unique_constraint:
foreign_fields = {f.name for f in self.foreign_related_fields} foreign_fields = {f.name for f in self.foreign_related_fields}
remote_opts = self.remote_field.model._meta remote_opts = self.remote_field.model._meta
has_unique_constraint = any( has_unique_constraint = (
frozenset(ut) <= foreign_fields for ut in remote_opts.unique_together any(
) or any( frozenset(ut) <= foreign_fields
frozenset(uc.fields) <= foreign_fields for ut in remote_opts.unique_together
for uc in remote_opts.total_unique_constraints )
or any(
frozenset(uc.fields) <= foreign_fields
for uc in remote_opts.total_unique_constraints
)
# If the model defines a composite primary key and the foreign key
# refers to it, the target is unique.
or (
frozenset(field.name for field in remote_opts.pk_fields)
== foreign_fields
)
) )
if not has_unique_constraint: if not has_unique_constraint:

View File

@ -1,5 +1,6 @@
from django.db import NotSupportedError from django.db import NotSupportedError
from django.db.models.expressions import ColPairs from django.db.models.expressions import ColPairs
from django.db.models.fields import composite
from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups
from django.db.models.lookups import ( from django.db.models.lookups import (
Exact, Exact,
@ -19,7 +20,7 @@ def get_normalized_value(value, lhs):
if not value._is_pk_set(): if not value._is_pk_set():
raise ValueError("Model instances passed to related filters must be saved.") raise ValueError("Model instances passed to related filters must be saved.")
value_list = [] value_list = []
sources = lhs.output_field.path_infos[-1].target_fields sources = composite.unnest(lhs.output_field.path_infos[-1].target_fields)
for source in sources: for source in sources:
while not isinstance(value, source.model) and source.remote_field: while not isinstance(value, source.model) and source.remote_field:
source = source.remote_field.model._meta.get_field( source = source.remote_field.model._meta.get_field(
@ -30,7 +31,8 @@ def get_normalized_value(value, lhs):
except AttributeError: except AttributeError:
# A case like Restaurant.objects.filter(place=restaurant_instance), # A case like Restaurant.objects.filter(place=restaurant_instance),
# where place is a OneToOneField and the primary key of Restaurant. # where place is a OneToOneField and the primary key of Restaurant.
return (value.pk,) pk = value.pk
return pk if isinstance(pk, tuple) else (pk,)
return tuple(value_list) return tuple(value_list)
if not isinstance(value, tuple): if not isinstance(value, tuple):
return (value,) return (value,)

View File

@ -250,6 +250,8 @@ class TupleIn(TupleLookupMixin, In):
def check_rhs_select_length_equals_lhs_length(self): def check_rhs_select_length_equals_lhs_length(self):
len_rhs = len(self.rhs.select) len_rhs = len(self.rhs.select)
if len_rhs == 1 and isinstance(self.rhs.select[0], ColPairs):
len_rhs = len(self.rhs.select[0])
len_lhs = len(self.lhs) len_lhs = len(self.lhs)
if len_rhs != len_lhs: if len_rhs != len_lhs:
lhs_str = self.get_lhs_str() lhs_str = self.get_lhs_str()
@ -304,7 +306,13 @@ class TupleIn(TupleLookupMixin, In):
return root.as_sql(compiler, connection) return root.as_sql(compiler, connection)
def as_subquery(self, compiler, connection): def as_subquery(self, compiler, connection):
return compiler.compile(In(self.lhs, self.rhs)) lhs = self.lhs
rhs = self.rhs
if isinstance(lhs, ColPairs):
rhs = rhs.clone()
rhs.set_values([source.name for source in lhs.sources])
lhs = Tuple(lhs)
return compiler.compile(In(lhs, rhs))
tuple_lookups = { tuple_lookups = {

View File

@ -7,7 +7,14 @@ from django.conf import settings
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
from django.core.signals import setting_changed from django.core.signals import setting_changed
from django.db import connections from django.db import connections
from django.db.models import AutoField, Manager, OrderWrt, UniqueConstraint from django.db.models import (
AutoField,
CompositePrimaryKey,
Manager,
OrderWrt,
UniqueConstraint,
)
from django.db.models.fields import composite
from django.db.models.query_utils import PathInfo from django.db.models.query_utils import PathInfo
from django.utils.datastructures import ImmutableList, OrderedSet from django.utils.datastructures import ImmutableList, OrderedSet
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -973,6 +980,14 @@ class Options:
) )
] ]
@cached_property
def pk_fields(self):
return composite.unnest([self.pk])
@property
def is_composite_pk(self):
return isinstance(self.pk, CompositePrimaryKey)
@cached_property @cached_property
def _property_names(self): def _property_names(self):
"""Return a set of the names of the properties defined on the model.""" """Return a set of the names of the properties defined on the model."""

View File

@ -171,11 +171,14 @@ class RawModelIterable(BaseIterable):
"Raw query must include the primary key" "Raw query must include the primary key"
) )
fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns] fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
converters = compiler.get_converters( cols = [f.get_col(f.model._meta.db_table) if f else None for f in fields]
[f.get_col(f.model._meta.db_table) if f else None for f in fields] converters = compiler.get_converters(cols)
)
if converters: if converters:
query_iterator = compiler.apply_converters(query_iterator, converters) query_iterator = compiler.apply_converters(query_iterator, converters)
if compiler.has_composite_fields(cols):
query_iterator = compiler.composite_fields_to_tuples(
query_iterator, cols
)
for values in query_iterator: for values in query_iterator:
# Associate fields to values # Associate fields to values
model_init_values = [values[pos] for pos in model_init_pos] model_init_values = [values[pos] for pos in model_init_pos]

View File

@ -7,7 +7,9 @@ from itertools import chain
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import DatabaseError, NotSupportedError from django.db import DatabaseError, NotSupportedError
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value from django.db.models.expressions import ColPairs, F, OrderBy, RawSQL, Ref, Value
from django.db.models.fields import composite
from django.db.models.fields.composite import CompositePrimaryKey
from django.db.models.functions import Cast, Random from django.db.models.functions import Cast, Random
from django.db.models.lookups import Lookup from django.db.models.lookups import Lookup
from django.db.models.query_utils import select_related_descend from django.db.models.query_utils import select_related_descend
@ -283,6 +285,9 @@ class SQLCompiler:
# Reference to a column. # Reference to a column.
elif isinstance(expression, int): elif isinstance(expression, int):
expression = cols[expression] expression = cols[expression]
# ColPairs cannot be aliased.
if isinstance(expression, ColPairs):
alias = None
selected.append((alias, expression)) selected.append((alias, expression))
for select_idx, (alias, expression) in enumerate(selected): for select_idx, (alias, expression) in enumerate(selected):
@ -997,6 +1002,7 @@ class SQLCompiler:
# alias for a given field. This also includes None -> start_alias to # alias for a given field. This also includes None -> start_alias to
# be used by local fields. # be used by local fields.
seen_models = {None: start_alias} seen_models = {None: start_alias}
select_mask_fields = set(composite.unnest(select_mask))
for field in opts.concrete_fields: for field in opts.concrete_fields:
model = field.model._meta.concrete_model model = field.model._meta.concrete_model
@ -1017,7 +1023,7 @@ class SQLCompiler:
# parent model data is already present in the SELECT clause, # parent model data is already present in the SELECT clause,
# and we want to avoid reloading the same data again. # and we want to avoid reloading the same data again.
continue continue
if select_mask and field not in select_mask: if select_mask and field not in select_mask_fields:
continue continue
alias = self.query.join_parent_model(opts, model, start_alias, seen_models) alias = self.query.join_parent_model(opts, model, start_alias, seen_models)
column = field.get_col(alias) column = field.get_col(alias)
@ -1110,9 +1116,10 @@ class SQLCompiler:
) )
return results return results
targets, alias, _ = self.query.trim_joins(targets, joins, path) targets, alias, _ = self.query.trim_joins(targets, joins, path)
target_fields = composite.unnest(targets)
return [ return [
(OrderBy(transform_function(t, alias), descending=descending), False) (OrderBy(transform_function(t, alias), descending=descending), False)
for t in targets for t in target_fields
] ]
def _setup_joins(self, pieces, opts, alias): def _setup_joins(self, pieces, opts, alias):
@ -1504,13 +1511,25 @@ class SQLCompiler:
return result return result
def get_converters(self, expressions): def get_converters(self, expressions):
i = 0
converters = {} converters = {}
for i, expression in enumerate(expressions):
if expression: for expression in expressions:
if isinstance(expression, ColPairs):
cols = expression.get_source_expressions()
cols_converters = self.get_converters(cols)
for j, (convs, col) in cols_converters.items():
converters[i + j] = (convs, col)
i += len(expression)
elif expression:
backend_converters = self.connection.ops.get_db_converters(expression) backend_converters = self.connection.ops.get_db_converters(expression)
field_converters = expression.get_db_converters(self.connection) field_converters = expression.get_db_converters(self.connection)
if backend_converters or field_converters: if backend_converters or field_converters:
converters[i] = (backend_converters + field_converters, expression) converters[i] = (backend_converters + field_converters, expression)
i += 1
else:
i += 1
return converters return converters
def apply_converters(self, rows, converters): def apply_converters(self, rows, converters):
@ -1524,6 +1543,24 @@ class SQLCompiler:
row[pos] = value row[pos] = value
yield row yield row
def has_composite_fields(self, expressions):
# Check for composite fields before calling the relatively costly
# composite_fields_to_tuples.
return any(isinstance(expression, ColPairs) for expression in expressions)
def composite_fields_to_tuples(self, rows, expressions):
col_pair_slices = [
slice(i, i + len(expression))
for i, expression in enumerate(expressions)
if isinstance(expression, ColPairs)
]
for row in map(list, rows):
for pos in col_pair_slices:
row[pos] = (tuple(row[pos]),)
yield row
def results_iter( def results_iter(
self, self,
results=None, results=None,
@ -1541,8 +1578,10 @@ class SQLCompiler:
rows = chain.from_iterable(results) rows = chain.from_iterable(results)
if converters: if converters:
rows = self.apply_converters(rows, converters) rows = self.apply_converters(rows, converters)
if tuple_expected: if self.has_composite_fields(fields):
rows = map(tuple, rows) rows = self.composite_fields_to_tuples(rows, fields)
if tuple_expected:
rows = map(tuple, rows)
return rows return rows
def has_results(self): def has_results(self):
@ -1863,6 +1902,18 @@ class SQLInsertCompiler(SQLCompiler):
) )
] ]
cols = [field.get_col(opts.db_table) for field in self.returning_fields] cols = [field.get_col(opts.db_table) for field in self.returning_fields]
elif isinstance(opts.pk, CompositePrimaryKey):
returning_field = returning_fields[0]
cols = [returning_field.get_col(opts.db_table)]
rows = [
(
self.connection.ops.last_insert_id(
cursor,
opts.db_table,
returning_field.column,
),
)
]
else: else:
cols = [opts.pk.get_col(opts.db_table)] cols = [opts.pk.get_col(opts.db_table)]
rows = [ rows = [
@ -1876,8 +1927,10 @@ class SQLInsertCompiler(SQLCompiler):
] ]
converters = self.get_converters(cols) converters = self.get_converters(cols)
if converters: if converters:
rows = list(self.apply_converters(rows, converters)) rows = self.apply_converters(rows, converters)
return rows if self.has_composite_fields(cols):
rows = self.composite_fields_to_tuples(rows, cols)
return list(rows)
class SQLDeleteCompiler(SQLCompiler): class SQLDeleteCompiler(SQLCompiler):
@ -2065,6 +2118,7 @@ class SQLUpdateCompiler(SQLCompiler):
query.add_fields(fields) query.add_fields(fields)
super().pre_sql_setup() super().pre_sql_setup()
is_composite_pk = meta.is_composite_pk
must_pre_select = ( must_pre_select = (
count > 1 and not self.connection.features.update_can_self_select count > 1 and not self.connection.features.update_can_self_select
) )
@ -2079,7 +2133,8 @@ class SQLUpdateCompiler(SQLCompiler):
idents = [] idents = []
related_ids = collections.defaultdict(list) related_ids = collections.defaultdict(list)
for rows in query.get_compiler(self.using).execute_sql(MULTI): for rows in query.get_compiler(self.using).execute_sql(MULTI):
idents.extend(r[0] for r in rows) pks = [row if is_composite_pk else row[0] for row in rows]
idents.extend(pks)
for parent, index in related_ids_index: for parent, index in related_ids_index:
related_ids[parent].extend(r[index] for r in rows) related_ids[parent].extend(r[index] for r in rows)
self.query.add_filter("pk__in", idents) self.query.add_filter("pk__in", idents)

View File

@ -627,8 +627,12 @@ class Query(BaseExpression):
if result is None: if result is None:
result = empty_set_result result = empty_set_result
else: else:
converters = compiler.get_converters(outer_query.annotation_select.values()) cols = outer_query.annotation_select.values()
result = next(compiler.apply_converters((result,), converters)) converters = compiler.get_converters(cols)
rows = compiler.apply_converters((result,), converters)
if compiler.has_composite_fields(cols):
rows = compiler.composite_fields_to_tuples(rows, cols)
result = next(rows)
return dict(zip(outer_query.annotation_select, result)) return dict(zip(outer_query.annotation_select, result))

View File

@ -181,6 +181,7 @@ Model fields
* **fields.E011**: ``<database>`` does not support default database values with * **fields.E011**: ``<database>`` does not support default database values with
expressions (``db_default``). expressions (``db_default``).
* **fields.E012**: ``<expression>`` cannot be used in ``db_default``. * **fields.E012**: ``<expression>`` cannot be used in ``db_default``.
* **fields.E013**: ``CompositePrimaryKey`` must be named ``pk``.
* **fields.E100**: ``AutoField``\s must set primary_key=True. * **fields.E100**: ``AutoField``\s must set primary_key=True.
* **fields.E110**: ``BooleanField``\s do not accept null values. *This check * **fields.E110**: ``BooleanField``\s do not accept null values. *This check
appeared before support for null values was added in Django 2.1.* appeared before support for null values was added in Django 2.1.*
@ -417,6 +418,8 @@ Models
* **models.W040**: ``<database>`` does not support indexes with non-key * **models.W040**: ``<database>`` does not support indexes with non-key
columns. columns.
* **models.E041**: ``constraints`` refers to the joined field ``<field name>``. * **models.E041**: ``constraints`` refers to the joined field ``<field name>``.
* **models.E042**: ``<field name>`` cannot be included in the composite
primary key.
* **models.W042**: Auto-created primary key used when not defining a primary * **models.W042**: Auto-created primary key used when not defining a primary
key type, by default ``django.db.models.AutoField``. key type, by default ``django.db.models.AutoField``.
* **models.W043**: ``<database>`` does not support indexes on expressions. * **models.W043**: ``<database>`` does not support indexes on expressions.

View File

@ -707,6 +707,23 @@ or :class:`~django.forms.NullBooleanSelect` if :attr:`null=True <Field.null>`.
The default value of ``BooleanField`` is ``None`` when :attr:`Field.default` The default value of ``BooleanField`` is ``None`` when :attr:`Field.default`
isn't defined. isn't defined.
``CompositePrimaryKey``
-----------------------
.. versionadded:: 5.2
.. class:: CompositePrimaryKey(*field_names, **options)
A virtual field used for defining a composite primary key.
This field must be defined as the model's ``pk`` field. If present, Django will
create the underlying model table with a composite primary key.
The ``*field_names`` argument is a list of positional field names that compose
the primary key.
See :doc:`/topics/composite-primary-key` for more details.
``CharField`` ``CharField``
------------- -------------
@ -1615,6 +1632,8 @@ not an instance of ``UUID``.
hyphens, because PostgreSQL and MariaDB 10.7+ store them in a hyphenated hyphens, because PostgreSQL and MariaDB 10.7+ store them in a hyphenated
uuid datatype type. uuid datatype type.
.. _relationship-fields:
Relationship fields Relationship fields
=================== ===================

View File

@ -31,6 +31,25 @@ and only officially support the latest release of each series.
What's new in Django 5.2 What's new in Django 5.2
======================== ========================
Composite Primary Keys
----------------------
The new :class:`django.db.models.CompositePrimaryKey` allows tables to be
created with a primary key consisting of multiple fields.
To use a composite primary key, when creating a model set the ``pk`` field to
be a ``CompositePrimaryKey``::
from django.db import models
class Release(models.Model):
pk = models.CompositePrimaryKey("version", "name")
version = models.IntegerField()
name = models.CharField(max_length=20)
See :doc:`/topics/composite-primary-key` for more details.
Minor features Minor features
-------------- --------------

View File

@ -0,0 +1,183 @@
======================
Composite primary keys
======================
.. versionadded:: 5.2
In Django, each model has a primary key. By default, this primary key consists
of a single field.
In most cases, a single primary key should suffice. In database design,
however, defining a primary key consisting of multiple fields is sometimes
necessary.
To use a composite primary key, when creating a model set the ``pk`` field to
be a :class:`.CompositePrimaryKey`::
class Product(models.Model):
name = models.CharField(max_length=100)
class Order(models.Model):
reference = models.CharField(max_length=20, primary_key=True)
class OrderLineItem(models.Model):
pk = models.CompositePrimaryKey("product_id", "order_id")
product = models.ForeignKey(Product, on_delete=models.CASCADE)
order = models.ForeignKey(Order, on_delete=models.CASCADE)
quantity = models.IntegerField()
This will instruct Django to create a composite primary key
(``PRIMARY KEY (product_id, order_id)``) when creating the table.
A composite primary key is represented by a ``tuple``:
.. code-block:: pycon
>>> product = Product.objects.create(name="apple")
>>> order = Order.objects.create(reference="A755H")
>>> item = OrderLineItem.objects.create(product=product, order=order, quantity=1)
>>> item.pk
(1, "A755H")
You can assign a ``tuple`` to a composite primary key. This sets the associated
field values.
.. code-block:: pycon
>>> item = OrderLineItem(pk=(2, "B142C"))
>>> item.pk
(2, "B142C")
>>> item.product_id
2
>>> item.order_id
"B142C"
A composite primary key can also be filtered by a ``tuple``:
.. code-block:: pycon
>>> OrderLineItem.objects.filter(pk=(1, "A755H")).count()
1
We're still working on composite primary key support for
:ref:`relational fields <cpk-and-relations>`, including
:class:`.GenericForeignKey` fields, and the Django admin. Models with composite
primary keys cannot be registered in the Django admin at this time. You can
expect to see this in future releases.
Migrating to a composite primary key
====================================
Django doesn't support migrating to, or from, a composite primary key after the
table is created. It also doesn't support adding or removing fields from the
composite primary key.
If you would like to migrate an existing table from a single primary key to a
composite primary key, follow your database backend's instructions to do so.
Once the composite primary key is in place, add the ``CompositePrimaryKey``
field to your model. This allows Django to recognize and handle the composite
primary key appropriately.
While migration operations (e.g. ``AddField``, ``AlterField``) on primary key
fields are not supported, ``makemigrations`` will still detect changes.
In order to avoid errors, it's recommended to apply such migrations with
``--fake``.
Alternatively, :class:`.SeparateDatabaseAndState` may be used to execute the
backend-specific migrations and Django-generated migrations in a single
operation.
.. _cpk-and-relations:
Composite primary keys and relations
====================================
:ref:`Relationship fields <relationship-fields>`, including
:ref:`generic relations <generic-relations>` do not support composite primary
keys.
For example, given the ``OrderLineItem`` model, the following is not
supported::
class Foo(models.Model):
item = models.ForeignKey(OrderLineItem, on_delete=models.CASCADE)
Because ``ForeignKey`` currently cannot reference models with composite primary
keys.
To work around this limitation, ``ForeignObject`` can be used as an
alternative::
class Foo(models.Model):
item_order_id = models.IntegerField()
item_product_id = models.CharField(max_length=20)
item = models.ForeignObject(
OrderLineItem,
on_delete=models.CASCADE,
from_fields=("item_order_id", "item_product_id"),
to_fields=("order_id", "product_id"),
)
``ForeignObject`` is much like ``ForeignKey``, except that it doesn't create
any columns (e.g. ``item_id``), foreign key constraints or indexes in the
database.
.. warning::
``ForeignObject`` is an internal API. This means it is not covered by our
:ref:`deprecation policy <internal-release-deprecation-policy>`.
Composite primary keys and database functions
=============================================
Many database functions only accept a single expression.
.. code-block:: sql
MAX("order_id") -- OK
MAX("product_id", "order_id") -- ERROR
As a consequence, they cannot be used with composite primary key references as
they are composed of multiple column expressions.
.. code-block:: python
Max("order_id") # OK
Max("pk") # ERROR
Composite primary keys in forms
===============================
As a composite primary key is a virtual field, a field which doesn't represent
a single database column, this field is excluded from ModelForms.
For example, take the following form::
class OrderLineItemForm(forms.ModelForm):
class Meta:
model = OrderLineItem
fields = "__all__"
This form does not have a form field ``pk`` for the composite primary key:
.. code-block:: pycon
>>> OrderLineItemForm()
<OrderLineItemForm bound=False, valid=Unknown, fields=(product;order;quantity)>
Setting the primary composite field ``pk`` as a form field raises an unknown
field :exc:`.FieldError`.
.. admonition:: Primary key fields are read only
If you change the value of a primary key on an existing object and then
save it, a new object will be created alongside the old one (see
:attr:`.Field.primary_key`).
This is also true of composite primary keys. Hence, you may want to set
:attr:`.Field.editable` to ``False`` on all primary key fields to exclude
them from ModelForms.

View File

@ -19,6 +19,7 @@ Introductions to all the key parts of Django you'll need to know:
auth/index auth/index
cache cache
conditional-view-processing conditional-view-processing
composite-primary-key
signing signing
email email
i18n/index i18n/index

View File

@ -20,3 +20,9 @@ class Location(models.Model):
class Place(Location): class Place(Location):
name = models.CharField(max_length=200) name = models.CharField(max_length=200)
class Guest(models.Model):
pk = models.CompositePrimaryKey("traveler", "place")
traveler = models.ForeignKey(Traveler, on_delete=models.CASCADE)
place = models.ForeignKey(Place, on_delete=models.CASCADE)

View File

@ -5,7 +5,7 @@ from django.contrib.admin.sites import site
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.test import SimpleTestCase from django.test import SimpleTestCase
from .models import Location, Person, Place, Traveler from .models import Guest, Location, Person, Place, Traveler
class NameAdmin(admin.ModelAdmin): class NameAdmin(admin.ModelAdmin):
@ -92,6 +92,14 @@ class TestRegistration(SimpleTestCase):
with self.assertRaisesMessage(ImproperlyConfigured, msg): with self.assertRaisesMessage(ImproperlyConfigured, msg):
self.site.register(Location) self.site.register(Location)
def test_composite_pk_model(self):
msg = (
"The model Guest has a composite primary key, so it cannot be registered "
"with admin."
)
with self.assertRaisesMessage(ImproperlyConfigured, msg):
self.site.register(Guest)
def test_is_registered_model(self): def test_is_registered_model(self):
"Checks for registered models should return true." "Checks for registered models should return true."
self.site.register(Person) self.site.register(Person)

View File

View File

@ -0,0 +1,75 @@
[
{
"pk": 1,
"model": "composite_pk.tenant",
"fields": {
"id": 1,
"name": "Tenant 1"
}
},
{
"pk": 2,
"model": "composite_pk.tenant",
"fields": {
"id": 2,
"name": "Tenant 2"
}
},
{
"pk": 3,
"model": "composite_pk.tenant",
"fields": {
"id": 3,
"name": "Tenant 3"
}
},
{
"pk": [1, 1],
"model": "composite_pk.user",
"fields": {
"tenant_id": 1,
"id": 1,
"email": "user0001@example.com"
}
},
{
"pk": [1, 2],
"model": "composite_pk.user",
"fields": {
"tenant_id": 1,
"id": 2,
"email": "user0002@example.com"
}
},
{
"pk": [2, 3],
"model": "composite_pk.user",
"fields": {
"email": "user0003@example.com"
}
},
{
"model": "composite_pk.user",
"fields": {
"tenant_id": 2,
"id": 4,
"email": "user0004@example.com"
}
},
{
"pk": [2, "11111111-1111-1111-1111-111111111111"],
"model": "composite_pk.post",
"fields": {
"tenant_id": 2,
"id": "11111111-1111-1111-1111-111111111111"
}
},
{
"pk": [2, "ffffffff-ffff-ffff-ffff-ffffffffffff"],
"model": "composite_pk.post",
"fields": {
"tenant_id": 2,
"id": "ffffffff-ffff-ffff-ffff-ffffffffffff"
}
}
]

View File

@ -0,0 +1,9 @@
from .tenant import Comment, Post, Tenant, Token, User
__all__ = [
"Comment",
"Post",
"Tenant",
"Token",
"User",
]

View File

@ -0,0 +1,50 @@
from django.db import models
class Tenant(models.Model):
name = models.CharField(max_length=10, default="", blank=True)
class Token(models.Model):
pk = models.CompositePrimaryKey("tenant_id", "id")
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE, related_name="tokens")
id = models.SmallIntegerField()
secret = models.CharField(max_length=10, default="", blank=True)
class BaseModel(models.Model):
pk = models.CompositePrimaryKey("tenant_id", "id")
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
id = models.SmallIntegerField(unique=True)
class Meta:
abstract = True
class User(BaseModel):
email = models.EmailField(unique=True)
class Comment(models.Model):
pk = models.CompositePrimaryKey("tenant", "id")
tenant = models.ForeignKey(
Tenant,
on_delete=models.CASCADE,
related_name="comments",
)
id = models.SmallIntegerField(unique=True, db_column="comment_id")
user_id = models.SmallIntegerField()
user = models.ForeignObject(
User,
on_delete=models.CASCADE,
from_fields=("tenant_id", "user_id"),
to_fields=("tenant_id", "id"),
related_name="comments",
)
text = models.TextField(default="", blank=True)
class Post(models.Model):
pk = models.CompositePrimaryKey("tenant_id", "id")
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
id = models.UUIDField()

View File

@ -0,0 +1,139 @@
from django.db import NotSupportedError
from django.db.models import Count, Q
from django.test import TestCase
from .models import Comment, Tenant, User
class CompositePKAggregateTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1,
id=2,
email="user0002@example.com",
)
cls.user_3 = User.objects.create(
tenant=cls.tenant_2,
id=3,
email="user0003@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_2, text="foo")
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1, text="bar")
cls.comment_3 = Comment.objects.create(id=3, user=cls.user_1, text="foobar")
cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3, text="foobarbaz")
cls.comment_5 = Comment.objects.create(id=5, user=cls.user_3, text="barbaz")
cls.comment_6 = Comment.objects.create(id=6, user=cls.user_3, text="baz")
def test_users_annotated_with_comments_id_count(self):
user_1, user_2, user_3 = User.objects.annotate(Count("comments__id")).order_by(
"pk"
)
self.assertEqual(user_1, self.user_1)
self.assertEqual(user_1.comments__id__count, 2)
self.assertEqual(user_2, self.user_2)
self.assertEqual(user_2.comments__id__count, 1)
self.assertEqual(user_3, self.user_3)
self.assertEqual(user_3.comments__id__count, 3)
def test_users_annotated_with_aliased_comments_id_count(self):
user_1, user_2, user_3 = User.objects.annotate(
comments_count=Count("comments__id")
).order_by("pk")
self.assertEqual(user_1, self.user_1)
self.assertEqual(user_1.comments_count, 2)
self.assertEqual(user_2, self.user_2)
self.assertEqual(user_2.comments_count, 1)
self.assertEqual(user_3, self.user_3)
self.assertEqual(user_3.comments_count, 3)
def test_users_annotated_with_comments_count(self):
user_1, user_2, user_3 = User.objects.annotate(Count("comments")).order_by("pk")
self.assertEqual(user_1, self.user_1)
self.assertEqual(user_1.comments__count, 2)
self.assertEqual(user_2, self.user_2)
self.assertEqual(user_2.comments__count, 1)
self.assertEqual(user_3, self.user_3)
self.assertEqual(user_3.comments__count, 3)
def test_users_annotated_with_comments_count_filter(self):
user_1, user_2, user_3 = User.objects.annotate(
comments__count=Count(
"comments", filter=Q(pk__in=[self.user_1.pk, self.user_2.pk])
)
).order_by("pk")
self.assertEqual(user_1, self.user_1)
self.assertEqual(user_1.comments__count, 2)
self.assertEqual(user_2, self.user_2)
self.assertEqual(user_2.comments__count, 1)
self.assertEqual(user_3, self.user_3)
self.assertEqual(user_3.comments__count, 0)
def test_count_distinct_not_supported(self):
with self.assertRaisesMessage(
NotSupportedError, "COUNT(DISTINCT) doesn't support composite primary keys"
):
self.assertIsNone(
User.objects.annotate(comments__count=Count("comments", distinct=True))
)
def test_user_values_annotated_with_comments_id_count(self):
self.assertSequenceEqual(
User.objects.values("pk").annotate(Count("comments__id")).order_by("pk"),
(
{"pk": self.user_1.pk, "comments__id__count": 2},
{"pk": self.user_2.pk, "comments__id__count": 1},
{"pk": self.user_3.pk, "comments__id__count": 3},
),
)
def test_user_values_annotated_with_filtered_comments_id_count(self):
self.assertSequenceEqual(
User.objects.values("pk")
.annotate(
comments_count=Count(
"comments__id",
filter=Q(comments__text__icontains="foo"),
)
)
.order_by("pk"),
(
{"pk": self.user_1.pk, "comments_count": 1},
{"pk": self.user_2.pk, "comments_count": 1},
{"pk": self.user_3.pk, "comments_count": 1},
),
)
def test_filter_and_count_users_by_comments_fields(self):
users = User.objects.filter(comments__id__gt=2).order_by("pk")
self.assertEqual(users.count(), 4)
self.assertSequenceEqual(
users, (self.user_1, self.user_3, self.user_3, self.user_3)
)
users = User.objects.filter(comments__text__icontains="foo").order_by("pk")
self.assertEqual(users.count(), 3)
self.assertSequenceEqual(users, (self.user_1, self.user_2, self.user_3))
users = User.objects.filter(comments__text__icontains="baz").order_by("pk")
self.assertEqual(users.count(), 3)
self.assertSequenceEqual(users, (self.user_3, self.user_3, self.user_3))
def test_order_by_comments_id_count(self):
self.assertSequenceEqual(
User.objects.annotate(comments_count=Count("comments__id")).order_by(
"-comments_count"
),
(self.user_3, self.user_1, self.user_2),
)

View File

@ -0,0 +1,242 @@
from django.core import checks
from django.db import connection, models
from django.db.models import F
from django.test import TestCase
from django.test.utils import isolate_apps
@isolate_apps("composite_pk")
class CompositePKChecksTests(TestCase):
maxDiff = None
def test_composite_pk_must_be_unique_strings(self):
test_cases = (
(),
(0,),
(1,),
("id", False),
("id", "id"),
(("id",),),
)
for i, args in enumerate(test_cases):
with (
self.subTest(args=args),
self.assertRaisesMessage(
ValueError, "CompositePrimaryKey args must be unique strings."
),
):
models.CompositePrimaryKey(*args)
def test_composite_pk_must_include_at_least_2_fields(self):
expected_message = "CompositePrimaryKey must include at least two fields."
with self.assertRaisesMessage(ValueError, expected_message):
models.CompositePrimaryKey("id")
def test_composite_pk_cannot_have_a_default(self):
expected_message = "CompositePrimaryKey cannot have a default."
with self.assertRaisesMessage(ValueError, expected_message):
models.CompositePrimaryKey("tenant_id", "id", default=(1, 1))
def test_composite_pk_cannot_have_a_database_default(self):
expected_message = "CompositePrimaryKey cannot have a database default."
with self.assertRaisesMessage(ValueError, expected_message):
models.CompositePrimaryKey("tenant_id", "id", db_default=models.F("id"))
def test_composite_pk_cannot_be_editable(self):
expected_message = "CompositePrimaryKey cannot be editable."
with self.assertRaisesMessage(ValueError, expected_message):
models.CompositePrimaryKey("tenant_id", "id", editable=True)
def test_composite_pk_must_be_a_primary_key(self):
expected_message = "CompositePrimaryKey must be a primary key."
with self.assertRaisesMessage(ValueError, expected_message):
models.CompositePrimaryKey("tenant_id", "id", primary_key=False)
def test_composite_pk_must_be_blank(self):
expected_message = "CompositePrimaryKey must be blank."
with self.assertRaisesMessage(ValueError, expected_message):
models.CompositePrimaryKey("tenant_id", "id", blank=False)
def test_composite_pk_must_not_have_other_pk_field(self):
class Foo(models.Model):
pk = models.CompositePrimaryKey("foo_id", "id")
foo_id = models.IntegerField()
id = models.IntegerField(primary_key=True)
self.assertEqual(
Foo.check(databases=self.databases),
[
checks.Error(
"The model cannot have more than one field with "
"'primary_key=True'.",
obj=Foo,
id="models.E026",
),
],
)
def test_composite_pk_cannot_include_nullable_field(self):
class Foo(models.Model):
pk = models.CompositePrimaryKey("foo_id", "id")
foo_id = models.IntegerField()
id = models.IntegerField(null=True)
self.assertEqual(
Foo.check(databases=self.databases),
[
checks.Error(
"'id' cannot be included in the composite primary key.",
hint="'id' field may not set 'null=True'.",
obj=Foo,
id="models.E042",
),
],
)
def test_composite_pk_can_include_fk_name(self):
class Foo(models.Model):
pass
class Bar(models.Model):
pk = models.CompositePrimaryKey("foo", "id")
foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
id = models.SmallIntegerField()
self.assertEqual(Foo.check(databases=self.databases), [])
self.assertEqual(Bar.check(databases=self.databases), [])
def test_composite_pk_cannot_include_same_field(self):
class Foo(models.Model):
pass
class Bar(models.Model):
pk = models.CompositePrimaryKey("foo", "foo_id")
foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
id = models.SmallIntegerField()
self.assertEqual(Foo.check(databases=self.databases), [])
self.assertEqual(
Bar.check(databases=self.databases),
[
checks.Error(
"'foo_id' cannot be included in the composite primary key.",
hint="'foo_id' and 'foo' are the same fields.",
obj=Bar,
id="models.E042",
),
],
)
def test_composite_pk_cannot_include_composite_pk_field(self):
class Foo(models.Model):
pk = models.CompositePrimaryKey("id", "pk")
id = models.SmallIntegerField()
self.assertEqual(
Foo.check(databases=self.databases),
[
checks.Error(
"'pk' cannot be included in the composite primary key.",
hint="'pk' field has no column.",
obj=Foo,
id="models.E042",
),
],
)
def test_composite_pk_cannot_include_db_column(self):
class Foo(models.Model):
pk = models.CompositePrimaryKey("foo", "bar")
foo = models.SmallIntegerField(db_column="foo_id")
bar = models.SmallIntegerField(db_column="bar_id")
class Bar(models.Model):
pk = models.CompositePrimaryKey("foo_id", "bar_id")
foo = models.SmallIntegerField(db_column="foo_id")
bar = models.SmallIntegerField(db_column="bar_id")
self.assertEqual(Foo.check(databases=self.databases), [])
self.assertEqual(
Bar.check(databases=self.databases),
[
checks.Error(
"'foo_id' cannot be included in the composite primary key.",
hint="'foo_id' is not a valid field.",
obj=Bar,
id="models.E042",
),
checks.Error(
"'bar_id' cannot be included in the composite primary key.",
hint="'bar_id' is not a valid field.",
obj=Bar,
id="models.E042",
),
],
)
def test_foreign_object_can_refer_composite_pk(self):
class Foo(models.Model):
pass
class Bar(models.Model):
pk = models.CompositePrimaryKey("foo_id", "id")
foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
id = models.IntegerField()
class Baz(models.Model):
pk = models.CompositePrimaryKey("foo_id", "id")
foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
id = models.IntegerField()
bar_id = models.IntegerField()
bar = models.ForeignObject(
Bar,
on_delete=models.CASCADE,
from_fields=("foo_id", "bar_id"),
to_fields=("foo_id", "id"),
)
self.assertEqual(Foo.check(databases=self.databases), [])
self.assertEqual(Bar.check(databases=self.databases), [])
self.assertEqual(Baz.check(databases=self.databases), [])
def test_composite_pk_must_be_named_pk(self):
class Foo(models.Model):
primary_key = models.CompositePrimaryKey("foo_id", "id")
foo_id = models.IntegerField()
id = models.IntegerField()
self.assertEqual(
Foo.check(databases=self.databases),
[
checks.Error(
"'CompositePrimaryKey' must be named 'pk'.",
obj=Foo._meta.get_field("primary_key"),
id="fields.E013",
),
],
)
def test_composite_pk_cannot_include_generated_field(self):
is_oracle = connection.vendor == "oracle"
class Foo(models.Model):
pk = models.CompositePrimaryKey("id", "foo")
id = models.IntegerField()
foo = models.GeneratedField(
expression=F("id"),
output_field=models.IntegerField(),
db_persist=not is_oracle,
)
self.assertEqual(
Foo.check(databases=self.databases),
[
checks.Error(
"'foo' cannot be included in the composite primary key.",
hint="'foo' field is a generated field.",
obj=Foo,
id="models.E042",
),
],
)

View File

@ -0,0 +1,138 @@
from django.test import TestCase
from .models import Tenant, User
class CompositePKCreateTests(TestCase):
maxDiff = None
@classmethod
def setUpTestData(cls):
cls.tenant = Tenant.objects.create()
cls.user = User.objects.create(
tenant=cls.tenant,
id=1,
email="user0001@example.com",
)
def test_create_user(self):
test_cases = (
{"tenant": self.tenant, "id": 2412, "email": "user2412@example.com"},
{"tenant_id": self.tenant.id, "id": 5316, "email": "user5316@example.com"},
{"pk": (self.tenant.id, 7424), "email": "user7424@example.com"},
)
for fields in test_cases:
with self.subTest(fields=fields):
count = User.objects.count()
user = User(**fields)
obj = User.objects.create(**fields)
self.assertEqual(obj.tenant_id, self.tenant.id)
self.assertEqual(obj.id, user.id)
self.assertEqual(obj.pk, (self.tenant.id, user.id))
self.assertEqual(obj.email, user.email)
self.assertEqual(count + 1, User.objects.count())
def test_save_user(self):
test_cases = (
{"tenant": self.tenant, "id": 9241, "email": "user9241@example.com"},
{"tenant_id": self.tenant.id, "id": 5132, "email": "user5132@example.com"},
{"pk": (self.tenant.id, 3014), "email": "user3014@example.com"},
)
for fields in test_cases:
with self.subTest(fields=fields):
count = User.objects.count()
user = User(**fields)
self.assertIsNotNone(user.id)
self.assertIsNotNone(user.email)
user.save()
self.assertEqual(user.tenant_id, self.tenant.id)
self.assertEqual(user.tenant, self.tenant)
self.assertIsNotNone(user.id)
self.assertEqual(user.pk, (self.tenant.id, user.id))
self.assertEqual(user.email, fields["email"])
self.assertEqual(user.email, f"user{user.id}@example.com")
self.assertEqual(count + 1, User.objects.count())
def test_bulk_create_users(self):
objs = [
User(tenant=self.tenant, id=8291, email="user8291@example.com"),
User(tenant_id=self.tenant.id, id=4021, email="user4021@example.com"),
User(pk=(self.tenant.id, 8214), email="user8214@example.com"),
]
obj_1, obj_2, obj_3 = User.objects.bulk_create(objs)
self.assertEqual(obj_1.tenant_id, self.tenant.id)
self.assertEqual(obj_1.id, 8291)
self.assertEqual(obj_1.pk, (obj_1.tenant_id, obj_1.id))
self.assertEqual(obj_1.email, "user8291@example.com")
self.assertEqual(obj_2.tenant_id, self.tenant.id)
self.assertEqual(obj_2.id, 4021)
self.assertEqual(obj_2.pk, (obj_2.tenant_id, obj_2.id))
self.assertEqual(obj_2.email, "user4021@example.com")
self.assertEqual(obj_3.tenant_id, self.tenant.id)
self.assertEqual(obj_3.id, 8214)
self.assertEqual(obj_3.pk, (obj_3.tenant_id, obj_3.id))
self.assertEqual(obj_3.email, "user8214@example.com")
def test_get_or_create_user(self):
test_cases = (
{
"pk": (self.tenant.id, 8314),
"defaults": {"email": "user8314@example.com"},
},
{
"tenant": self.tenant,
"id": 3142,
"defaults": {"email": "user3142@example.com"},
},
{
"tenant_id": self.tenant.id,
"id": 4218,
"defaults": {"email": "user4218@example.com"},
},
)
for fields in test_cases:
with self.subTest(fields=fields):
count = User.objects.count()
user, created = User.objects.get_or_create(**fields)
self.assertIs(created, True)
self.assertIsNotNone(user.id)
self.assertEqual(user.pk, (self.tenant.id, user.id))
self.assertEqual(user.tenant_id, self.tenant.id)
self.assertEqual(user.email, fields["defaults"]["email"])
self.assertEqual(user.email, f"user{user.id}@example.com")
self.assertEqual(count + 1, User.objects.count())
def test_update_or_create_user(self):
test_cases = (
{
"pk": (self.tenant.id, 2931),
"defaults": {"email": "user2931@example.com"},
},
{
"tenant": self.tenant,
"id": 6428,
"defaults": {"email": "user6428@example.com"},
},
{
"tenant_id": self.tenant.id,
"id": 5278,
"defaults": {"email": "user5278@example.com"},
},
)
for fields in test_cases:
with self.subTest(fields=fields):
count = User.objects.count()
user, created = User.objects.update_or_create(**fields)
self.assertIs(created, True)
self.assertIsNotNone(user.id)
self.assertEqual(user.pk, (self.tenant.id, user.id))
self.assertEqual(user.tenant_id, self.tenant.id)
self.assertEqual(user.email, fields["defaults"]["email"])
self.assertEqual(user.email, f"user{user.id}@example.com")
self.assertEqual(count + 1, User.objects.count())

View File

@ -0,0 +1,83 @@
from django.test import TestCase
from .models import Comment, Tenant, User
class CompositePKDeleteTests(TestCase):
maxDiff = None
@classmethod
def setUpTestData(cls):
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_2,
id=2,
email="user0002@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_2)
cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
def test_delete_tenant_by_pk(self):
result = Tenant.objects.filter(pk=self.tenant_1.pk).delete()
self.assertEqual(
result,
(
3,
{
"composite_pk.Comment": 1,
"composite_pk.User": 1,
"composite_pk.Tenant": 1,
},
),
)
self.assertIs(Tenant.objects.filter(pk=self.tenant_1.pk).exists(), False)
self.assertIs(Tenant.objects.filter(pk=self.tenant_2.pk).exists(), True)
self.assertIs(User.objects.filter(pk=self.user_1.pk).exists(), False)
self.assertIs(User.objects.filter(pk=self.user_2.pk).exists(), True)
self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), False)
self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), True)
self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), True)
def test_delete_user_by_pk(self):
result = User.objects.filter(pk=self.user_1.pk).delete()
self.assertEqual(
result, (2, {"composite_pk.User": 1, "composite_pk.Comment": 1})
)
self.assertIs(User.objects.filter(pk=self.user_1.pk).exists(), False)
self.assertIs(User.objects.filter(pk=self.user_2.pk).exists(), True)
self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), False)
self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), True)
self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), True)
def test_delete_comments_by_user(self):
result = Comment.objects.filter(user=self.user_2).delete()
self.assertEqual(result, (2, {"composite_pk.Comment": 2}))
self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), True)
self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), False)
self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), False)
def test_delete_without_pk(self):
msg = (
"Comment object can't be deleted because its pk attribute is set "
"to None."
)
with self.assertRaisesMessage(ValueError, msg):
Comment().delete()
with self.assertRaisesMessage(ValueError, msg):
Comment(tenant_id=1).delete()
with self.assertRaisesMessage(ValueError, msg):
Comment(id=1).delete()

View File

@ -0,0 +1,412 @@
from django.test import TestCase
from .models import Comment, Tenant, User
class CompositePKFilterTests(TestCase):
maxDiff = None
@classmethod
def setUpTestData(cls):
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.tenant_3 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1,
id=2,
email="user0002@example.com",
)
cls.user_3 = User.objects.create(
tenant=cls.tenant_2,
id=3,
email="user0003@example.com",
)
cls.user_4 = User.objects.create(
tenant=cls.tenant_3,
id=4,
email="user0004@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3)
cls.comment_5 = Comment.objects.create(id=5, user=cls.user_1)
def test_filter_and_count_user_by_pk(self):
test_cases = (
({"pk": self.user_1.pk}, 1),
({"pk": self.user_2.pk}, 1),
({"pk": self.user_3.pk}, 1),
({"pk": (self.tenant_1.id, self.user_1.id)}, 1),
({"pk": (self.tenant_1.id, self.user_2.id)}, 1),
({"pk": (self.tenant_2.id, self.user_3.id)}, 1),
({"pk": (self.tenant_1.id, self.user_3.id)}, 0),
({"pk": (self.tenant_2.id, self.user_1.id)}, 0),
({"pk": (self.tenant_2.id, self.user_2.id)}, 0),
)
for lookup, count in test_cases:
with self.subTest(lookup=lookup, count=count):
self.assertEqual(User.objects.filter(**lookup).count(), count)
def test_order_comments_by_pk_asc(self):
self.assertSequenceEqual(
Comment.objects.order_by("pk"),
(
self.comment_1, # (1, 1)
self.comment_2, # (1, 2)
self.comment_3, # (1, 3)
self.comment_5, # (1, 5)
self.comment_4, # (2, 4)
),
)
def test_order_comments_by_pk_desc(self):
self.assertSequenceEqual(
Comment.objects.order_by("-pk"),
(
self.comment_4, # (2, 4)
self.comment_5, # (1, 5)
self.comment_3, # (1, 3)
self.comment_2, # (1, 2)
self.comment_1, # (1, 1)
),
)
def test_filter_comments_by_pk_gt(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
test_cases = (
(c11, (c12, c13, c15, c24)),
(c12, (c13, c15, c24)),
(c13, (c15, c24)),
(c15, (c24,)),
(c24, ()),
)
for obj, objs in test_cases:
with self.subTest(obj=obj, objs=objs):
self.assertSequenceEqual(
Comment.objects.filter(pk__gt=obj.pk).order_by("pk"), objs
)
def test_filter_comments_by_pk_gte(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
test_cases = (
(c11, (c11, c12, c13, c15, c24)),
(c12, (c12, c13, c15, c24)),
(c13, (c13, c15, c24)),
(c15, (c15, c24)),
(c24, (c24,)),
)
for obj, objs in test_cases:
with self.subTest(obj=obj, objs=objs):
self.assertSequenceEqual(
Comment.objects.filter(pk__gte=obj.pk).order_by("pk"), objs
)
def test_filter_comments_by_pk_lt(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
test_cases = (
(c24, (c11, c12, c13, c15)),
(c15, (c11, c12, c13)),
(c13, (c11, c12)),
(c12, (c11,)),
(c11, ()),
)
for obj, objs in test_cases:
with self.subTest(obj=obj, objs=objs):
self.assertSequenceEqual(
Comment.objects.filter(pk__lt=obj.pk).order_by("pk"), objs
)
def test_filter_comments_by_pk_lte(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
test_cases = (
(c24, (c11, c12, c13, c15, c24)),
(c15, (c11, c12, c13, c15)),
(c13, (c11, c12, c13)),
(c12, (c11, c12)),
(c11, (c11,)),
)
for obj, objs in test_cases:
with self.subTest(obj=obj, objs=objs):
self.assertSequenceEqual(
Comment.objects.filter(pk__lte=obj.pk).order_by("pk"), objs
)
def test_filter_comments_by_pk_in(self):
test_cases = (
(),
(self.comment_1,),
(self.comment_1, self.comment_4),
)
for objs in test_cases:
with self.subTest(objs=objs):
pks = [obj.pk for obj in objs]
self.assertSequenceEqual(
Comment.objects.filter(pk__in=pks).order_by("pk"), objs
)
def test_filter_comments_by_user_and_order_by_pk_asc(self):
self.assertSequenceEqual(
Comment.objects.filter(user=self.user_1).order_by("pk"),
(self.comment_1, self.comment_2, self.comment_5),
)
def test_filter_comments_by_user_and_order_by_pk_desc(self):
self.assertSequenceEqual(
Comment.objects.filter(user=self.user_1).order_by("-pk"),
(self.comment_5, self.comment_2, self.comment_1),
)
def test_filter_comments_by_user_and_exclude_by_pk(self):
self.assertSequenceEqual(
Comment.objects.filter(user=self.user_1)
.exclude(pk=self.comment_1.pk)
.order_by("pk"),
(self.comment_2, self.comment_5),
)
def test_filter_comments_by_user_and_contains(self):
self.assertIs(
Comment.objects.filter(user=self.user_1).contains(self.comment_1), True
)
def test_filter_users_by_comments_in(self):
c1, c2, c3, c4, c5 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
u1, u2, u3 = (
self.user_1,
self.user_2,
self.user_3,
)
test_cases = (
((), ()),
((c1,), (u1,)),
((c1, c2), (u1, u1)),
((c1, c2, c3), (u1, u1, u2)),
((c1, c2, c3, c4), (u1, u1, u2, u3)),
((c1, c2, c3, c4, c5), (u1, u1, u1, u2, u3)),
)
for comments, users in test_cases:
with self.subTest(comments=comments, users=users):
self.assertSequenceEqual(
User.objects.filter(comments__in=comments).order_by("pk"), users
)
def test_filter_users_by_comments_lt(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
u1, u2 = (
self.user_1,
self.user_2,
)
test_cases = (
(c11, ()),
(c12, (u1,)),
(c13, (u1, u1)),
(c15, (u1, u1, u2)),
(c24, (u1, u1, u1, u2)),
)
for comment, users in test_cases:
with self.subTest(comment=comment, users=users):
self.assertSequenceEqual(
User.objects.filter(comments__lt=comment).order_by("pk"), users
)
def test_filter_users_by_comments_lte(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
u1, u2, u3 = (
self.user_1,
self.user_2,
self.user_3,
)
test_cases = (
(c11, (u1,)),
(c12, (u1, u1)),
(c13, (u1, u1, u2)),
(c15, (u1, u1, u1, u2)),
(c24, (u1, u1, u1, u2, u3)),
)
for comment, users in test_cases:
with self.subTest(comment=comment, users=users):
self.assertSequenceEqual(
User.objects.filter(comments__lte=comment).order_by("pk"), users
)
def test_filter_users_by_comments_gt(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
u1, u2, u3 = (
self.user_1,
self.user_2,
self.user_3,
)
test_cases = (
(c11, (u1, u1, u2, u3)),
(c12, (u1, u2, u3)),
(c13, (u1, u3)),
(c15, (u3,)),
(c24, ()),
)
for comment, users in test_cases:
with self.subTest(comment=comment, users=users):
self.assertSequenceEqual(
User.objects.filter(comments__gt=comment).order_by("pk"), users
)
def test_filter_users_by_comments_gte(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
u1, u2, u3 = (
self.user_1,
self.user_2,
self.user_3,
)
test_cases = (
(c11, (u1, u1, u1, u2, u3)),
(c12, (u1, u1, u2, u3)),
(c13, (u1, u2, u3)),
(c15, (u1, u3)),
(c24, (u3,)),
)
for comment, users in test_cases:
with self.subTest(comment=comment, users=users):
self.assertSequenceEqual(
User.objects.filter(comments__gte=comment).order_by("pk"), users
)
def test_filter_users_by_comments_exact(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
u1, u2, u3 = (
self.user_1,
self.user_2,
self.user_3,
)
test_cases = (
(c11, (u1,)),
(c12, (u1,)),
(c13, (u2,)),
(c15, (u1,)),
(c24, (u3,)),
)
for comment, users in test_cases:
with self.subTest(comment=comment, users=users):
self.assertSequenceEqual(
User.objects.filter(comments=comment).order_by("pk"), users
)
def test_filter_users_by_comments_isnull(self):
u1, u2, u3, u4 = (
self.user_1,
self.user_2,
self.user_3,
self.user_4,
)
with self.subTest("comments__isnull=True"):
self.assertSequenceEqual(
User.objects.filter(comments__isnull=True).order_by("pk"),
(u4,),
)
with self.subTest("comments__isnull=False"):
self.assertSequenceEqual(
User.objects.filter(comments__isnull=False).order_by("pk"),
(u1, u1, u1, u2, u3),
)
def test_filter_comments_by_pk_isnull(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
with self.subTest("pk__isnull=True"):
self.assertSequenceEqual(
Comment.objects.filter(pk__isnull=True).order_by("pk"),
(),
)
with self.subTest("pk__isnull=False"):
self.assertSequenceEqual(
Comment.objects.filter(pk__isnull=False).order_by("pk"),
(c11, c12, c13, c15, c24),
)
def test_filter_users_by_comments_subquery(self):
subquery = Comment.objects.filter(id=3).only("pk")
queryset = User.objects.filter(comments__in=subquery)
self.assertSequenceEqual(queryset, (self.user_2,))

View File

@ -0,0 +1,126 @@
from django.test import TestCase
from .models import Comment, Tenant, User
class CompositePKGetTests(TestCase):
maxDiff = None
@classmethod
def setUpTestData(cls):
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1,
id=2,
email="user0002@example.com",
)
cls.user_3 = User.objects.create(
tenant=cls.tenant_2,
id=3,
email="user0003@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
def test_get_user(self):
test_cases = (
{"pk": self.user_1.pk},
{"pk": (self.tenant_1.id, self.user_1.id)},
{"id": self.user_1.id},
)
for lookup in test_cases:
with self.subTest(lookup=lookup):
self.assertEqual(User.objects.get(**lookup), self.user_1)
def test_get_comment(self):
test_cases = (
{"pk": self.comment_1.pk},
{"pk": (self.tenant_1.id, self.comment_1.id)},
{"id": self.comment_1.id},
{"user": self.user_1},
{"user_id": self.user_1.id},
{"user__id": self.user_1.id},
{"user__pk": self.user_1.pk},
{"tenant": self.tenant_1},
{"tenant_id": self.tenant_1.id},
{"tenant__id": self.tenant_1.id},
{"tenant__pk": self.tenant_1.pk},
)
for lookup in test_cases:
with self.subTest(lookup=lookup):
self.assertEqual(Comment.objects.get(**lookup), self.comment_1)
def test_get_or_create_user(self):
test_cases = (
{
"pk": self.user_1.pk,
"defaults": {"email": "user9201@example.com"},
},
{
"pk": (self.tenant_1.id, self.user_1.id),
"defaults": {"email": "user9201@example.com"},
},
{
"tenant": self.tenant_1,
"id": self.user_1.id,
"defaults": {"email": "user3512@example.com"},
},
{
"tenant_id": self.tenant_1.id,
"id": self.user_1.id,
"defaults": {"email": "user8239@example.com"},
},
)
for fields in test_cases:
with self.subTest(fields=fields):
count = User.objects.count()
user, created = User.objects.get_or_create(**fields)
self.assertIs(created, False)
self.assertEqual(user.id, self.user_1.id)
self.assertEqual(user.pk, (self.tenant_1.id, self.user_1.id))
self.assertEqual(user.tenant_id, self.tenant_1.id)
self.assertEqual(user.email, self.user_1.email)
self.assertEqual(count, User.objects.count())
def test_lookup_errors(self):
m_tuple = "'%s' lookup of 'pk' must be a tuple or a list"
m_2_elements = "'%s' lookup of 'pk' must have 2 elements"
m_tuple_collection = (
"'in' lookup of 'pk' must be a collection of tuples or lists"
)
m_2_elements_each = "'in' lookup of 'pk' must have 2 elements each"
test_cases = (
({"pk": 1}, m_tuple % "exact"),
({"pk": (1, 2, 3)}, m_2_elements % "exact"),
({"pk__exact": 1}, m_tuple % "exact"),
({"pk__exact": (1, 2, 3)}, m_2_elements % "exact"),
({"pk__in": 1}, m_tuple % "in"),
({"pk__in": (1, 2, 3)}, m_tuple_collection),
({"pk__in": ((1, 2, 3),)}, m_2_elements_each),
({"pk__gt": 1}, m_tuple % "gt"),
({"pk__gt": (1, 2, 3)}, m_2_elements % "gt"),
({"pk__gte": 1}, m_tuple % "gte"),
({"pk__gte": (1, 2, 3)}, m_2_elements % "gte"),
({"pk__lt": 1}, m_tuple % "lt"),
({"pk__lt": (1, 2, 3)}, m_2_elements % "lt"),
({"pk__lte": 1}, m_tuple % "lte"),
({"pk__lte": (1, 2, 3)}, m_2_elements % "lte"),
)
for kwargs, message in test_cases:
with (
self.subTest(kwargs=kwargs),
self.assertRaisesMessage(ValueError, message),
):
Comment.objects.get(**kwargs)
def test_get_user_by_comments(self):
self.assertEqual(User.objects.get(comments=self.comment_1), self.user_1)

View File

@ -0,0 +1,153 @@
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.test import TestCase
from .models import Comment, Tenant, Token, User
class CompositePKModelsTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1,
id=2,
email="user0002@example.com",
)
cls.user_3 = User.objects.create(
tenant=cls.tenant_2,
id=3,
email="user0003@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3)
def test_fields(self):
# tenant_1
self.assertSequenceEqual(
self.tenant_1.user_set.order_by("pk"),
[self.user_1, self.user_2],
)
self.assertSequenceEqual(
self.tenant_1.comments.order_by("pk"),
[self.comment_1, self.comment_2, self.comment_3],
)
# tenant_2
self.assertSequenceEqual(self.tenant_2.user_set.order_by("pk"), [self.user_3])
self.assertSequenceEqual(
self.tenant_2.comments.order_by("pk"), [self.comment_4]
)
# user_1
self.assertEqual(self.user_1.id, 1)
self.assertEqual(self.user_1.tenant_id, self.tenant_1.id)
self.assertEqual(self.user_1.tenant, self.tenant_1)
self.assertEqual(self.user_1.pk, (self.tenant_1.id, self.user_1.id))
self.assertSequenceEqual(
self.user_1.comments.order_by("pk"), [self.comment_1, self.comment_2]
)
# user_2
self.assertEqual(self.user_2.id, 2)
self.assertEqual(self.user_2.tenant_id, self.tenant_1.id)
self.assertEqual(self.user_2.tenant, self.tenant_1)
self.assertEqual(self.user_2.pk, (self.tenant_1.id, self.user_2.id))
self.assertSequenceEqual(self.user_2.comments.order_by("pk"), [self.comment_3])
# comment_1
self.assertEqual(self.comment_1.id, 1)
self.assertEqual(self.comment_1.user_id, self.user_1.id)
self.assertEqual(self.comment_1.user, self.user_1)
self.assertEqual(self.comment_1.tenant_id, self.tenant_1.id)
self.assertEqual(self.comment_1.tenant, self.tenant_1)
self.assertEqual(self.comment_1.pk, (self.tenant_1.id, self.user_1.id))
def test_full_clean_success(self):
test_cases = (
# 1, 1234, {}
({"tenant": self.tenant_1, "id": 1234}, {}),
({"tenant_id": self.tenant_1.id, "id": 1234}, {}),
({"pk": (self.tenant_1.id, 1234)}, {}),
# 1, 1, {"id"}
({"tenant": self.tenant_1, "id": 1}, {"id"}),
({"tenant_id": self.tenant_1.id, "id": 1}, {"id"}),
({"pk": (self.tenant_1.id, 1)}, {"id"}),
# 1, 1, {"tenant", "id"}
({"tenant": self.tenant_1, "id": 1}, {"tenant", "id"}),
({"tenant_id": self.tenant_1.id, "id": 1}, {"tenant", "id"}),
({"pk": (self.tenant_1.id, 1)}, {"tenant", "id"}),
)
for kwargs, exclude in test_cases:
with self.subTest(kwargs):
kwargs["email"] = "user0004@example.com"
User(**kwargs).full_clean(exclude=exclude)
def test_full_clean_failure(self):
e_tenant_and_id = "User with this Tenant and Id already exists."
e_id = "User with this Id already exists."
test_cases = (
# 1, 1, {}
({"tenant": self.tenant_1, "id": 1}, {}, (e_tenant_and_id, e_id)),
({"tenant_id": self.tenant_1.id, "id": 1}, {}, (e_tenant_and_id, e_id)),
({"pk": (self.tenant_1.id, 1)}, {}, (e_tenant_and_id, e_id)),
# 2, 1, {}
({"tenant": self.tenant_2, "id": 1}, {}, (e_id,)),
({"tenant_id": self.tenant_2.id, "id": 1}, {}, (e_id,)),
({"pk": (self.tenant_2.id, 1)}, {}, (e_id,)),
# 1, 1, {"tenant"}
({"tenant": self.tenant_1, "id": 1}, {"tenant"}, (e_id,)),
({"tenant_id": self.tenant_1.id, "id": 1}, {"tenant"}, (e_id,)),
({"pk": (self.tenant_1.id, 1)}, {"tenant"}, (e_id,)),
)
for kwargs, exclude, messages in test_cases:
with self.subTest(kwargs):
with self.assertRaises(ValidationError) as ctx:
kwargs["email"] = "user0004@example.com"
User(**kwargs).full_clean(exclude=exclude)
self.assertSequenceEqual(ctx.exception.messages, messages)
def test_field_conflicts(self):
test_cases = (
({"pk": (1, 1), "id": 2}, (1, 1)),
({"id": 2, "pk": (1, 1)}, (1, 1)),
({"pk": (1, 1), "tenant_id": 2}, (1, 1)),
({"tenant_id": 2, "pk": (1, 1)}, (1, 1)),
({"pk": (2, 2), "tenant_id": 3, "id": 4}, (2, 2)),
({"tenant_id": 3, "id": 4, "pk": (2, 2)}, (2, 2)),
)
for kwargs, pk in test_cases:
with self.subTest(kwargs=kwargs):
user = User(**kwargs)
self.assertEqual(user.pk, pk)
def test_validate_unique(self):
user = User.objects.get(pk=self.user_1.pk)
user.id = None
with self.assertRaises(ValidationError) as ctx:
user.validate_unique()
self.assertSequenceEqual(
ctx.exception.messages, ("User with this Email already exists.",)
)
def test_permissions(self):
token = ContentType.objects.get_for_model(Token)
user = ContentType.objects.get_for_model(User)
comment = ContentType.objects.get_for_model(Comment)
self.assertEqual(4, token.permission_set.count())
self.assertEqual(4, user.permission_set.count())
self.assertEqual(4, comment.permission_set.count())

View File

@ -0,0 +1,134 @@
from django.db.models.query_utils import PathInfo
from django.db.models.sql import Query
from django.test import TestCase
from .models import Comment, Tenant, User
class NamesToPathTests(TestCase):
def test_id(self):
query = Query(User)
path, final_field, targets, rest = query.names_to_path(["id"], User._meta)
self.assertEqual(path, [])
self.assertEqual(final_field, User._meta.get_field("id"))
self.assertEqual(targets, (User._meta.get_field("id"),))
self.assertEqual(rest, [])
def test_pk(self):
query = Query(User)
path, final_field, targets, rest = query.names_to_path(["pk"], User._meta)
self.assertEqual(path, [])
self.assertEqual(final_field, User._meta.get_field("pk"))
self.assertEqual(targets, (User._meta.get_field("pk"),))
self.assertEqual(rest, [])
def test_tenant_id(self):
query = Query(User)
path, final_field, targets, rest = query.names_to_path(
["tenant", "id"], User._meta
)
self.assertEqual(
path,
[
PathInfo(
from_opts=User._meta,
to_opts=Tenant._meta,
target_fields=(Tenant._meta.get_field("id"),),
join_field=User._meta.get_field("tenant"),
m2m=False,
direct=True,
filtered_relation=None,
),
],
)
self.assertEqual(final_field, Tenant._meta.get_field("id"))
self.assertEqual(targets, (Tenant._meta.get_field("id"),))
self.assertEqual(rest, [])
def test_user_id(self):
query = Query(Comment)
path, final_field, targets, rest = query.names_to_path(
["user", "id"], Comment._meta
)
self.assertEqual(
path,
[
PathInfo(
from_opts=Comment._meta,
to_opts=User._meta,
target_fields=(
User._meta.get_field("tenant"),
User._meta.get_field("id"),
),
join_field=Comment._meta.get_field("user"),
m2m=False,
direct=True,
filtered_relation=None,
),
],
)
self.assertEqual(final_field, User._meta.get_field("id"))
self.assertEqual(targets, (User._meta.get_field("id"),))
self.assertEqual(rest, [])
def test_user_tenant_id(self):
query = Query(Comment)
path, final_field, targets, rest = query.names_to_path(
["user", "tenant", "id"], Comment._meta
)
self.assertEqual(
path,
[
PathInfo(
from_opts=Comment._meta,
to_opts=User._meta,
target_fields=(
User._meta.get_field("tenant"),
User._meta.get_field("id"),
),
join_field=Comment._meta.get_field("user"),
m2m=False,
direct=True,
filtered_relation=None,
),
PathInfo(
from_opts=User._meta,
to_opts=Tenant._meta,
target_fields=(Tenant._meta.get_field("id"),),
join_field=User._meta.get_field("tenant"),
m2m=False,
direct=True,
filtered_relation=None,
),
],
)
self.assertEqual(final_field, Tenant._meta.get_field("id"))
self.assertEqual(targets, (Tenant._meta.get_field("id"),))
self.assertEqual(rest, [])
def test_comments(self):
query = Query(User)
path, final_field, targets, rest = query.names_to_path(["comments"], User._meta)
self.assertEqual(
path,
[
PathInfo(
from_opts=User._meta,
to_opts=Comment._meta,
target_fields=(Comment._meta.get_field("pk"),),
join_field=User._meta.get_field("comments"),
m2m=True,
direct=False,
filtered_relation=None,
),
],
)
self.assertEqual(final_field, User._meta.get_field("comments"))
self.assertEqual(targets, (Comment._meta.get_field("pk"),))
self.assertEqual(rest, [])

View File

@ -0,0 +1,135 @@
from django.test import TestCase
from .models import Comment, Tenant, Token, User
class CompositePKUpdateTests(TestCase):
maxDiff = None
@classmethod
def setUpTestData(cls):
cls.tenant_1 = Tenant.objects.create(name="A")
cls.tenant_2 = Tenant.objects.create(name="B")
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1,
id=2,
email="user0002@example.com",
)
cls.user_3 = User.objects.create(
tenant=cls.tenant_2,
id=3,
email="user0003@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
cls.token_1 = Token.objects.create(id=1, tenant=cls.tenant_1)
cls.token_2 = Token.objects.create(id=2, tenant=cls.tenant_2)
cls.token_3 = Token.objects.create(id=3, tenant=cls.tenant_1)
cls.token_4 = Token.objects.create(id=4, tenant=cls.tenant_2)
def test_update_user(self):
email = "user9315@example.com"
result = User.objects.filter(pk=self.user_1.pk).update(email=email)
self.assertEqual(result, 1)
user = User.objects.get(pk=self.user_1.pk)
self.assertEqual(user.email, email)
def test_save_user(self):
count = User.objects.count()
email = "user9314@example.com"
user = User.objects.get(pk=self.user_1.pk)
user.email = email
user.save()
user.refresh_from_db()
self.assertEqual(user.email, email)
user = User.objects.get(pk=self.user_1.pk)
self.assertEqual(user.email, email)
self.assertEqual(count, User.objects.count())
def test_bulk_update_comments(self):
comment_1 = Comment.objects.get(pk=self.comment_1.pk)
comment_2 = Comment.objects.get(pk=self.comment_2.pk)
comment_3 = Comment.objects.get(pk=self.comment_3.pk)
comment_1.text = "foo"
comment_2.text = "bar"
comment_3.text = "baz"
result = Comment.objects.bulk_update(
[comment_1, comment_2, comment_3], ["text"]
)
self.assertEqual(result, 3)
comment_1 = Comment.objects.get(pk=self.comment_1.pk)
comment_2 = Comment.objects.get(pk=self.comment_2.pk)
comment_3 = Comment.objects.get(pk=self.comment_3.pk)
self.assertEqual(comment_1.text, "foo")
self.assertEqual(comment_2.text, "bar")
self.assertEqual(comment_3.text, "baz")
def test_update_or_create_user(self):
test_cases = (
{
"pk": self.user_1.pk,
"defaults": {"email": "user3914@example.com"},
},
{
"pk": (self.tenant_1.id, self.user_1.id),
"defaults": {"email": "user9375@example.com"},
},
{
"tenant": self.tenant_1,
"id": self.user_1.id,
"defaults": {"email": "user3517@example.com"},
},
{
"tenant_id": self.tenant_1.id,
"id": self.user_1.id,
"defaults": {"email": "user8391@example.com"},
},
)
for fields in test_cases:
with self.subTest(fields=fields):
count = User.objects.count()
user, created = User.objects.update_or_create(**fields)
self.assertIs(created, False)
self.assertEqual(user.id, self.user_1.id)
self.assertEqual(user.pk, (self.tenant_1.id, self.user_1.id))
self.assertEqual(user.tenant_id, self.tenant_1.id)
self.assertEqual(user.email, fields["defaults"]["email"])
self.assertEqual(count, User.objects.count())
def test_update_comment_by_user_email(self):
result = Comment.objects.filter(user__email=self.user_1.email).update(
text="foo"
)
self.assertEqual(result, 2)
comment_1 = Comment.objects.get(pk=self.comment_1.pk)
comment_2 = Comment.objects.get(pk=self.comment_2.pk)
self.assertEqual(comment_1.text, "foo")
self.assertEqual(comment_2.text, "foo")
def test_update_token_by_tenant_name(self):
result = Token.objects.filter(tenant__name="A").update(secret="bar")
self.assertEqual(result, 2)
token_1 = Token.objects.get(pk=self.token_1.pk)
self.assertEqual(token_1.secret, "bar")
token_3 = Token.objects.get(pk=self.token_3.pk)
self.assertEqual(token_3.secret, "bar")
def test_cant_update_to_unsaved_object(self):
msg = (
"Unsaved model instance <User: User object ((None, None))> cannot be used "
"in an ORM query."
)
with self.assertRaisesMessage(ValueError, msg):
Comment.objects.update(user=User())

View File

@ -0,0 +1,212 @@
from collections import namedtuple
from uuid import UUID
from django.test import TestCase
from .models import Post, Tenant, User
class CompositePKValuesTests(TestCase):
USER_1_EMAIL = "user0001@example.com"
USER_2_EMAIL = "user0002@example.com"
USER_3_EMAIL = "user0003@example.com"
POST_1_ID = "77777777-7777-7777-7777-777777777777"
POST_2_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
POST_3_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
@classmethod
def setUpTestData(cls):
super().setUpTestData()
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1, id=1, email=cls.USER_1_EMAIL
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1, id=2, email=cls.USER_2_EMAIL
)
cls.user_3 = User.objects.create(
tenant=cls.tenant_2, id=3, email=cls.USER_3_EMAIL
)
cls.post_1 = Post.objects.create(tenant=cls.tenant_1, id=cls.POST_1_ID)
cls.post_2 = Post.objects.create(tenant=cls.tenant_1, id=cls.POST_2_ID)
cls.post_3 = Post.objects.create(tenant=cls.tenant_2, id=cls.POST_3_ID)
def test_values_list(self):
with self.subTest('User.objects.values_list("pk")'):
self.assertSequenceEqual(
User.objects.values_list("pk").order_by("pk"),
(
(self.user_1.pk,),
(self.user_2.pk,),
(self.user_3.pk,),
),
)
with self.subTest('User.objects.values_list("pk", "email")'):
self.assertSequenceEqual(
User.objects.values_list("pk", "email").order_by("pk"),
(
(self.user_1.pk, self.USER_1_EMAIL),
(self.user_2.pk, self.USER_2_EMAIL),
(self.user_3.pk, self.USER_3_EMAIL),
),
)
with self.subTest('User.objects.values_list("pk", "id")'):
self.assertSequenceEqual(
User.objects.values_list("pk", "id").order_by("pk"),
(
(self.user_1.pk, self.user_1.id),
(self.user_2.pk, self.user_2.id),
(self.user_3.pk, self.user_3.id),
),
)
with self.subTest('User.objects.values_list("pk", "tenant_id", "id")'):
self.assertSequenceEqual(
User.objects.values_list("pk", "tenant_id", "id").order_by("pk"),
(
(self.user_1.pk, self.user_1.tenant_id, self.user_1.id),
(self.user_2.pk, self.user_2.tenant_id, self.user_2.id),
(self.user_3.pk, self.user_3.tenant_id, self.user_3.id),
),
)
with self.subTest('User.objects.values_list("pk", flat=True)'):
self.assertSequenceEqual(
User.objects.values_list("pk", flat=True).order_by("pk"),
(
self.user_1.pk,
self.user_2.pk,
self.user_3.pk,
),
)
with self.subTest('Post.objects.values_list("pk", flat=True)'):
self.assertSequenceEqual(
Post.objects.values_list("pk", flat=True).order_by("pk"),
(
(self.tenant_1.id, UUID(self.POST_1_ID)),
(self.tenant_1.id, UUID(self.POST_2_ID)),
(self.tenant_2.id, UUID(self.POST_3_ID)),
),
)
with self.subTest('Post.objects.values_list("pk")'):
self.assertSequenceEqual(
Post.objects.values_list("pk").order_by("pk"),
(
((self.tenant_1.id, UUID(self.POST_1_ID)),),
((self.tenant_1.id, UUID(self.POST_2_ID)),),
((self.tenant_2.id, UUID(self.POST_3_ID)),),
),
)
with self.subTest('Post.objects.values_list("pk", "id")'):
self.assertSequenceEqual(
Post.objects.values_list("pk", "id").order_by("pk"),
(
((self.tenant_1.id, UUID(self.POST_1_ID)), UUID(self.POST_1_ID)),
((self.tenant_1.id, UUID(self.POST_2_ID)), UUID(self.POST_2_ID)),
((self.tenant_2.id, UUID(self.POST_3_ID)), UUID(self.POST_3_ID)),
),
)
with self.subTest('Post.objects.values_list("id", "pk")'):
self.assertSequenceEqual(
Post.objects.values_list("id", "pk").order_by("pk"),
(
(UUID(self.POST_1_ID), (self.tenant_1.id, UUID(self.POST_1_ID))),
(UUID(self.POST_2_ID), (self.tenant_1.id, UUID(self.POST_2_ID))),
(UUID(self.POST_3_ID), (self.tenant_2.id, UUID(self.POST_3_ID))),
),
)
with self.subTest('User.objects.values_list("pk", named=True)'):
Row = namedtuple("Row", ["pk"])
self.assertSequenceEqual(
User.objects.values_list("pk", named=True).order_by("pk"),
(
Row(pk=self.user_1.pk),
Row(pk=self.user_2.pk),
Row(pk=self.user_3.pk),
),
)
with self.subTest('User.objects.values_list("pk", "pk")'):
self.assertSequenceEqual(
User.objects.values_list("pk", "pk").order_by("pk"),
(
(self.user_1.pk,),
(self.user_2.pk,),
(self.user_3.pk,),
),
)
with self.subTest('User.objects.values_list("pk", "id", "pk", "id")'):
self.assertSequenceEqual(
User.objects.values_list("pk", "id", "pk", "id").order_by("pk"),
(
(self.user_1.pk, self.user_1.id),
(self.user_2.pk, self.user_2.id),
(self.user_3.pk, self.user_3.id),
),
)
def test_values(self):
with self.subTest('User.objects.values("pk")'):
self.assertSequenceEqual(
User.objects.values("pk").order_by("pk"),
(
{"pk": self.user_1.pk},
{"pk": self.user_2.pk},
{"pk": self.user_3.pk},
),
)
with self.subTest('User.objects.values("pk", "email")'):
self.assertSequenceEqual(
User.objects.values("pk", "email").order_by("pk"),
(
{"pk": self.user_1.pk, "email": self.USER_1_EMAIL},
{"pk": self.user_2.pk, "email": self.USER_2_EMAIL},
{"pk": self.user_3.pk, "email": self.USER_3_EMAIL},
),
)
with self.subTest('User.objects.values("pk", "id")'):
self.assertSequenceEqual(
User.objects.values("pk", "id").order_by("pk"),
(
{"pk": self.user_1.pk, "id": self.user_1.id},
{"pk": self.user_2.pk, "id": self.user_2.id},
{"pk": self.user_3.pk, "id": self.user_3.id},
),
)
with self.subTest('User.objects.values("pk", "tenant_id", "id")'):
self.assertSequenceEqual(
User.objects.values("pk", "tenant_id", "id").order_by("pk"),
(
{
"pk": self.user_1.pk,
"tenant_id": self.user_1.tenant_id,
"id": self.user_1.id,
},
{
"pk": self.user_2.pk,
"tenant_id": self.user_2.tenant_id,
"id": self.user_2.id,
},
{
"pk": self.user_3.pk,
"tenant_id": self.user_3.tenant_id,
"id": self.user_3.id,
},
),
)
with self.subTest('User.objects.values("pk", "pk")'):
self.assertSequenceEqual(
User.objects.values("pk", "pk").order_by("pk"),
(
{"pk": self.user_1.pk},
{"pk": self.user_2.pk},
{"pk": self.user_3.pk},
),
)
with self.subTest('User.objects.values("pk", "id", "pk", "id")'):
self.assertSequenceEqual(
User.objects.values("pk", "id", "pk", "id").order_by("pk"),
(
{"pk": self.user_1.pk, "id": self.user_1.id},
{"pk": self.user_2.pk, "id": self.user_2.id},
{"pk": self.user_3.pk, "id": self.user_3.id},
),
)

345
tests/composite_pk/tests.py Normal file
View File

@ -0,0 +1,345 @@
import json
import unittest
from uuid import UUID
import yaml
from django import forms
from django.core import serializers
from django.core.exceptions import FieldError
from django.db import IntegrityError, connection
from django.db.models import CompositePrimaryKey
from django.forms import modelform_factory
from django.test import TestCase
from .models import Comment, Post, Tenant, User
class CommentForm(forms.ModelForm):
class Meta:
model = Comment
fields = "__all__"
class CompositePKTests(TestCase):
maxDiff = None
@classmethod
def setUpTestData(cls):
cls.tenant = Tenant.objects.create()
cls.user = User.objects.create(
tenant=cls.tenant,
id=1,
email="user0001@example.com",
)
cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user)
@staticmethod
def get_constraints(table):
with connection.cursor() as cursor:
return connection.introspection.get_constraints(cursor, table)
def test_pk_updated_if_field_updated(self):
user = User.objects.get(pk=self.user.pk)
self.assertEqual(user.pk, (self.tenant.id, self.user.id))
self.assertIs(user._is_pk_set(), True)
user.tenant_id = 9831
self.assertEqual(user.pk, (9831, self.user.id))
self.assertIs(user._is_pk_set(), True)
user.id = 4321
self.assertEqual(user.pk, (9831, 4321))
self.assertIs(user._is_pk_set(), True)
user.pk = (9132, 3521)
self.assertEqual(user.tenant_id, 9132)
self.assertEqual(user.id, 3521)
self.assertIs(user._is_pk_set(), True)
user.id = None
self.assertEqual(user.pk, (9132, None))
self.assertEqual(user.tenant_id, 9132)
self.assertIsNone(user.id)
self.assertIs(user._is_pk_set(), False)
def test_hash(self):
self.assertEqual(hash(User(pk=(1, 2))), hash((1, 2)))
self.assertEqual(hash(User(tenant_id=2, id=3)), hash((2, 3)))
msg = "Model instances without primary key value are unhashable"
with self.assertRaisesMessage(TypeError, msg):
hash(User())
with self.assertRaisesMessage(TypeError, msg):
hash(User(tenant_id=1))
with self.assertRaisesMessage(TypeError, msg):
hash(User(id=1))
def test_pk_must_be_list_or_tuple(self):
user = User.objects.get(pk=self.user.pk)
test_cases = [
"foo",
1000,
3.14,
True,
False,
]
for pk in test_cases:
with self.assertRaisesMessage(
ValueError, "'pk' must be a list or a tuple."
):
user.pk = pk
def test_pk_must_have_2_elements(self):
user = User.objects.get(pk=self.user.pk)
test_cases = [
(),
[],
(1000,),
[1000],
(1, 2, 3),
[1, 2, 3],
]
for pk in test_cases:
with self.assertRaisesMessage(ValueError, "'pk' must have 2 elements."):
user.pk = pk
def test_composite_pk_in_fields(self):
user_fields = {f.name for f in User._meta.get_fields()}
self.assertEqual(user_fields, {"pk", "tenant", "id", "email", "comments"})
comment_fields = {f.name for f in Comment._meta.get_fields()}
self.assertEqual(
comment_fields,
{"pk", "tenant", "id", "user_id", "user", "text"},
)
def test_pk_field(self):
pk = User._meta.get_field("pk")
self.assertIsInstance(pk, CompositePrimaryKey)
self.assertIs(User._meta.pk, pk)
def test_error_on_user_pk_conflict(self):
with self.assertRaises(IntegrityError):
User.objects.create(tenant=self.tenant, id=self.user.id)
def test_error_on_comment_pk_conflict(self):
with self.assertRaises(IntegrityError):
Comment.objects.create(tenant=self.tenant, id=self.comment.id)
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test")
def test_get_constraints_postgresql(self):
user_constraints = self.get_constraints(User._meta.db_table)
user_pk = user_constraints["composite_pk_user_pkey"]
self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
self.assertIs(user_pk["primary_key"], True)
comment_constraints = self.get_constraints(Comment._meta.db_table)
comment_pk = comment_constraints["composite_pk_comment_pkey"]
self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
self.assertIs(comment_pk["primary_key"], True)
@unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test")
def test_get_constraints_sqlite(self):
user_constraints = self.get_constraints(User._meta.db_table)
user_pk = user_constraints["__primary__"]
self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
self.assertIs(user_pk["primary_key"], True)
comment_constraints = self.get_constraints(Comment._meta.db_table)
comment_pk = comment_constraints["__primary__"]
self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
self.assertIs(comment_pk["primary_key"], True)
@unittest.skipUnless(connection.vendor == "mysql", "MySQL specific test")
def test_get_constraints_mysql(self):
user_constraints = self.get_constraints(User._meta.db_table)
user_pk = user_constraints["PRIMARY"]
self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
self.assertIs(user_pk["primary_key"], True)
comment_constraints = self.get_constraints(Comment._meta.db_table)
comment_pk = comment_constraints["PRIMARY"]
self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
self.assertIs(comment_pk["primary_key"], True)
@unittest.skipUnless(connection.vendor == "oracle", "Oracle specific test")
def test_get_constraints_oracle(self):
user_constraints = self.get_constraints(User._meta.db_table)
user_pk = next(c for c in user_constraints.values() if c["primary_key"])
self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
self.assertEqual(user_pk["primary_key"], 1)
comment_constraints = self.get_constraints(Comment._meta.db_table)
comment_pk = next(c for c in comment_constraints.values() if c["primary_key"])
self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
self.assertEqual(comment_pk["primary_key"], 1)
def test_in_bulk(self):
"""
Test the .in_bulk() method of composite_pk models.
"""
result = Comment.objects.in_bulk()
self.assertEqual(result, {self.comment.pk: self.comment})
result = Comment.objects.in_bulk([self.comment.pk])
self.assertEqual(result, {self.comment.pk: self.comment})
def test_iterator(self):
"""
Test the .iterator() method of composite_pk models.
"""
result = list(Comment.objects.iterator())
self.assertEqual(result, [self.comment])
def test_query(self):
users = User.objects.values_list("pk").order_by("pk")
self.assertNotIn('AS "pk"', str(users.query))
def test_only(self):
users = User.objects.only("pk")
self.assertSequenceEqual(users, (self.user,))
user = users[0]
with self.assertNumQueries(0):
self.assertEqual(user.pk, (self.user.tenant_id, self.user.id))
self.assertEqual(user.tenant_id, self.user.tenant_id)
self.assertEqual(user.id, self.user.id)
with self.assertNumQueries(1):
self.assertEqual(user.email, self.user.email)
def test_model_forms(self):
fields = ["tenant", "id", "user_id", "text"]
self.assertEqual(list(CommentForm.base_fields), fields)
form = modelform_factory(Comment, fields="__all__")
self.assertEqual(list(form().fields), fields)
with self.assertRaisesMessage(
FieldError, "Unknown field(s) (pk) specified for Comment"
):
self.assertIsNone(modelform_factory(Comment, fields=["pk"]))
class CompositePKFixturesTests(TestCase):
fixtures = ["tenant"]
def test_objects(self):
tenant_1, tenant_2, tenant_3 = Tenant.objects.order_by("pk")
self.assertEqual(tenant_1.id, 1)
self.assertEqual(tenant_1.name, "Tenant 1")
self.assertEqual(tenant_2.id, 2)
self.assertEqual(tenant_2.name, "Tenant 2")
self.assertEqual(tenant_3.id, 3)
self.assertEqual(tenant_3.name, "Tenant 3")
user_1, user_2, user_3, user_4 = User.objects.order_by("pk")
self.assertEqual(user_1.id, 1)
self.assertEqual(user_1.tenant_id, 1)
self.assertEqual(user_1.pk, (user_1.tenant_id, user_1.id))
self.assertEqual(user_1.email, "user0001@example.com")
self.assertEqual(user_2.id, 2)
self.assertEqual(user_2.tenant_id, 1)
self.assertEqual(user_2.pk, (user_2.tenant_id, user_2.id))
self.assertEqual(user_2.email, "user0002@example.com")
self.assertEqual(user_3.id, 3)
self.assertEqual(user_3.tenant_id, 2)
self.assertEqual(user_3.pk, (user_3.tenant_id, user_3.id))
self.assertEqual(user_3.email, "user0003@example.com")
self.assertEqual(user_4.id, 4)
self.assertEqual(user_4.tenant_id, 2)
self.assertEqual(user_4.pk, (user_4.tenant_id, user_4.id))
self.assertEqual(user_4.email, "user0004@example.com")
post_1, post_2 = Post.objects.order_by("pk")
self.assertEqual(post_1.id, UUID("11111111-1111-1111-1111-111111111111"))
self.assertEqual(post_1.tenant_id, 2)
self.assertEqual(post_1.pk, (post_1.tenant_id, post_1.id))
self.assertEqual(post_2.id, UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"))
self.assertEqual(post_2.tenant_id, 2)
self.assertEqual(post_2.pk, (post_2.tenant_id, post_2.id))
def test_serialize_user_json(self):
users = User.objects.filter(pk=(1, 1))
result = serializers.serialize("json", users)
self.assertEqual(
json.loads(result),
[
{
"model": "composite_pk.user",
"pk": [1, 1],
"fields": {
"email": "user0001@example.com",
"id": 1,
"tenant": 1,
},
}
],
)
def test_serialize_user_jsonl(self):
users = User.objects.filter(pk=(1, 2))
result = serializers.serialize("jsonl", users)
self.assertEqual(
json.loads(result),
{
"model": "composite_pk.user",
"pk": [1, 2],
"fields": {
"email": "user0002@example.com",
"id": 2,
"tenant": 1,
},
},
)
def test_serialize_user_yaml(self):
users = User.objects.filter(pk=(2, 3))
result = serializers.serialize("yaml", users)
self.assertEqual(
yaml.safe_load(result),
[
{
"model": "composite_pk.user",
"pk": [2, 3],
"fields": {
"email": "user0003@example.com",
"id": 3,
"tenant": 2,
},
},
],
)
def test_serialize_user_python(self):
users = User.objects.filter(pk=(2, 4))
result = serializers.serialize("python", users)
self.assertEqual(
result,
[
{
"model": "composite_pk.user",
"pk": [2, 4],
"fields": {
"email": "user0004@example.com",
"id": 4,
"tenant": 2,
},
},
],
)
def test_serialize_post_uuid(self):
posts = Post.objects.filter(pk=(2, "11111111-1111-1111-1111-111111111111"))
result = serializers.serialize("json", posts)
self.assertEqual(
json.loads(result),
[
{
"model": "composite_pk.post",
"pk": [2, "11111111-1111-1111-1111-111111111111"],
"fields": {
"id": "11111111-1111-1111-1111-111111111111",
"tenant": 2,
},
},
],
)

View File

@ -5059,6 +5059,95 @@ class AutodetectorTests(BaseAutodetectorTests):
self.assertOperationTypes(changes, "testapp", 0, ["CreateModel"]) self.assertOperationTypes(changes, "testapp", 0, ["CreateModel"])
self.assertOperationAttributes(changes, "testapp", 0, 0, name="Book") self.assertOperationAttributes(changes, "testapp", 0, 0, name="Book")
@mock.patch(
"django.db.migrations.questioner.MigrationQuestioner.ask_not_null_addition"
)
def test_add_composite_pk(self, mocked_ask_method):
before = [
ModelState(
"app",
"foo",
[
("id", models.AutoField(primary_key=True)),
],
),
]
after = [
ModelState(
"app",
"foo",
[
("pk", models.CompositePrimaryKey("foo_id", "bar_id")),
("id", models.IntegerField()),
],
),
]
changes = self.get_changes(before, after)
self.assertEqual(mocked_ask_method.call_count, 0)
self.assertNumberMigrations(changes, "app", 1)
self.assertOperationTypes(changes, "app", 0, ["AddField", "AlterField"])
self.assertOperationAttributes(
changes,
"app",
0,
0,
name="pk",
model_name="foo",
preserve_default=True,
)
self.assertOperationAttributes(
changes,
"app",
0,
1,
name="id",
model_name="foo",
preserve_default=True,
)
def test_remove_composite_pk(self):
before = [
ModelState(
"app",
"foo",
[
("pk", models.CompositePrimaryKey("foo_id", "bar_id")),
("id", models.IntegerField()),
],
),
]
after = [
ModelState(
"app",
"foo",
[
("id", models.AutoField(primary_key=True)),
],
),
]
changes = self.get_changes(before, after)
self.assertNumberMigrations(changes, "app", 1)
self.assertOperationTypes(changes, "app", 0, ["RemoveField", "AlterField"])
self.assertOperationAttributes(
changes,
"app",
0,
0,
name="pk",
model_name="foo",
)
self.assertOperationAttributes(
changes,
"app",
0,
1,
name="id",
model_name="foo",
preserve_default=True,
)
class MigrationSuggestNameTests(SimpleTestCase): class MigrationSuggestNameTests(SimpleTestCase):
def test_no_operations(self): def test_no_operations(self):

View File

@ -6287,6 +6287,61 @@ class OperationTests(OperationTestBase):
self.assertEqual(pony_new.generated, 1) self.assertEqual(pony_new.generated, 1)
self.assertEqual(pony_new.static, 2) self.assertEqual(pony_new.static, 2)
def test_composite_pk_operations(self):
app_label = "test_d8d90af6"
project_state = self.set_up_test_model(app_label)
operation_1 = migrations.AddField(
"Pony", "pk", models.CompositePrimaryKey("id", "pink")
)
operation_2 = migrations.AlterField("Pony", "id", models.IntegerField())
operation_3 = migrations.RemoveField("Pony", "pk")
table_name = f"{app_label}_pony"
# 1. Add field (pk).
new_state = project_state.clone()
operation_1.state_forwards(app_label, new_state)
with connection.schema_editor() as editor:
operation_1.database_forwards(app_label, editor, project_state, new_state)
self.assertColumnNotExists(table_name, "pk")
Pony = new_state.apps.get_model(app_label, "pony")
obj_1 = Pony.objects.create(weight=1)
msg = (
f"obj_1={obj_1}, "
f"obj_1.id={obj_1.id}, "
f"obj_1.pink={obj_1.pink}, "
f"obj_1.pk={obj_1.pk}, "
f"Pony._meta.pk={repr(Pony._meta.pk)}, "
f"Pony._meta.get_field('id')={repr(Pony._meta.get_field('id'))}"
)
self.assertEqual(obj_1.pink, 3, msg)
self.assertEqual(obj_1.pk, (obj_1.id, obj_1.pink), msg)
# 2. Alter field (id -> IntegerField()).
project_state, new_state = new_state, new_state.clone()
operation_2.state_forwards(app_label, new_state)
with connection.schema_editor() as editor:
operation_2.database_forwards(app_label, editor, project_state, new_state)
Pony = new_state.apps.get_model(app_label, "pony")
obj_1 = Pony.objects.get(id=obj_1.id)
self.assertEqual(obj_1.pink, 3)
self.assertEqual(obj_1.pk, (obj_1.id, obj_1.pink))
obj_2 = Pony.objects.create(id=2, weight=2)
self.assertEqual(obj_2.id, 2)
self.assertEqual(obj_2.pink, 3)
self.assertEqual(obj_2.pk, (obj_2.id, obj_2.pink))
# 3. Remove field (pk).
project_state, new_state = new_state, new_state.clone()
operation_3.state_forwards(app_label, new_state)
with connection.schema_editor() as editor:
operation_3.database_forwards(app_label, editor, project_state, new_state)
Pony = new_state.apps.get_model(app_label, "pony")
obj_1 = Pony.objects.get(id=obj_1.id)
self.assertEqual(obj_1.pk, obj_1.id)
obj_2 = Pony.objects.get(id=obj_2.id)
self.assertEqual(obj_2.id, 2)
self.assertEqual(obj_2.pk, obj_2.id)
class SwappableOperationTests(OperationTestBase): class SwappableOperationTests(OperationTestBase):
""" """

View File

@ -1206,6 +1206,28 @@ class StateTests(SimpleTestCase):
choices_field = Author._meta.get_field("choice") choices_field = Author._meta.get_field("choice")
self.assertEqual(list(choices_field.choices), choices) self.assertEqual(list(choices_field.choices), choices)
def test_composite_pk_state(self):
new_apps = Apps(["migrations"])
class Foo(models.Model):
pk = models.CompositePrimaryKey("account_id", "id")
account_id = models.SmallIntegerField()
id = models.SmallIntegerField()
class Meta:
app_label = "migrations"
apps = new_apps
project_state = ProjectState.from_apps(new_apps)
model_state = project_state.models["migrations", "foo"]
self.assertEqual(len(model_state.options), 2)
self.assertEqual(model_state.options["constraints"], [])
self.assertEqual(model_state.options["indexes"], [])
self.assertEqual(len(model_state.fields), 3)
self.assertIn("pk", model_state.fields)
self.assertIn("account_id", model_state.fields)
self.assertIn("id", model_state.fields)
class StateRelationsTests(SimpleTestCase): class StateRelationsTests(SimpleTestCase):
def get_base_project_state(self): def get_base_project_state(self):

View File

@ -1138,3 +1138,22 @@ class WriterTests(SimpleTestCase):
ValueError, "'TestModel1' must inherit from 'BaseSerializer'." ValueError, "'TestModel1' must inherit from 'BaseSerializer'."
): ):
MigrationWriter.register_serializer(complex, TestModel1) MigrationWriter.register_serializer(complex, TestModel1)
def test_composite_pk_import(self):
migration = type(
"Migration",
(migrations.Migration,),
{
"operations": [
migrations.AddField(
"foo",
"bar",
models.CompositePrimaryKey("foo_id", "bar_id"),
),
],
},
)
writer = MigrationWriter(migration)
output = writer.as_string()
self.assertEqual(output.count("import"), 1)
self.assertIn("from django.db import migrations, models", output)