1
0
mirror of https://github.com/django/django.git synced 2025-10-31 09:41:08 +00:00

Fixed #373 -- Added CompositePrimaryKey.

Thanks Lily Foote and Simon Charette for reviews and mentoring
this Google Summer of Code 2024 project.

Co-authored-by: Simon Charette <charette.s@gmail.com>
Co-authored-by: Lily Foote <code@lilyf.org>
This commit is contained in:
Bendeguz Csirmaz
2024-04-07 10:32:16 +08:00
committed by Sarah Boyce
parent 86661f2449
commit 978aae4334
43 changed files with 3078 additions and 29 deletions

View File

View File

@@ -0,0 +1,75 @@
[
{
"pk": 1,
"model": "composite_pk.tenant",
"fields": {
"id": 1,
"name": "Tenant 1"
}
},
{
"pk": 2,
"model": "composite_pk.tenant",
"fields": {
"id": 2,
"name": "Tenant 2"
}
},
{
"pk": 3,
"model": "composite_pk.tenant",
"fields": {
"id": 3,
"name": "Tenant 3"
}
},
{
"pk": [1, 1],
"model": "composite_pk.user",
"fields": {
"tenant_id": 1,
"id": 1,
"email": "user0001@example.com"
}
},
{
"pk": [1, 2],
"model": "composite_pk.user",
"fields": {
"tenant_id": 1,
"id": 2,
"email": "user0002@example.com"
}
},
{
"pk": [2, 3],
"model": "composite_pk.user",
"fields": {
"email": "user0003@example.com"
}
},
{
"model": "composite_pk.user",
"fields": {
"tenant_id": 2,
"id": 4,
"email": "user0004@example.com"
}
},
{
"pk": [2, "11111111-1111-1111-1111-111111111111"],
"model": "composite_pk.post",
"fields": {
"tenant_id": 2,
"id": "11111111-1111-1111-1111-111111111111"
}
},
{
"pk": [2, "ffffffff-ffff-ffff-ffff-ffffffffffff"],
"model": "composite_pk.post",
"fields": {
"tenant_id": 2,
"id": "ffffffff-ffff-ffff-ffff-ffffffffffff"
}
}
]

View File

@@ -0,0 +1,9 @@
from .tenant import Comment, Post, Tenant, Token, User
__all__ = [
"Comment",
"Post",
"Tenant",
"Token",
"User",
]

View File

@@ -0,0 +1,50 @@
from django.db import models
class Tenant(models.Model):
name = models.CharField(max_length=10, default="", blank=True)
class Token(models.Model):
pk = models.CompositePrimaryKey("tenant_id", "id")
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE, related_name="tokens")
id = models.SmallIntegerField()
secret = models.CharField(max_length=10, default="", blank=True)
class BaseModel(models.Model):
pk = models.CompositePrimaryKey("tenant_id", "id")
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
id = models.SmallIntegerField(unique=True)
class Meta:
abstract = True
class User(BaseModel):
email = models.EmailField(unique=True)
class Comment(models.Model):
pk = models.CompositePrimaryKey("tenant", "id")
tenant = models.ForeignKey(
Tenant,
on_delete=models.CASCADE,
related_name="comments",
)
id = models.SmallIntegerField(unique=True, db_column="comment_id")
user_id = models.SmallIntegerField()
user = models.ForeignObject(
User,
on_delete=models.CASCADE,
from_fields=("tenant_id", "user_id"),
to_fields=("tenant_id", "id"),
related_name="comments",
)
text = models.TextField(default="", blank=True)
class Post(models.Model):
pk = models.CompositePrimaryKey("tenant_id", "id")
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
id = models.UUIDField()

View File

@@ -0,0 +1,139 @@
from django.db import NotSupportedError
from django.db.models import Count, Q
from django.test import TestCase
from .models import Comment, Tenant, User
class CompositePKAggregateTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1,
id=2,
email="user0002@example.com",
)
cls.user_3 = User.objects.create(
tenant=cls.tenant_2,
id=3,
email="user0003@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_2, text="foo")
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1, text="bar")
cls.comment_3 = Comment.objects.create(id=3, user=cls.user_1, text="foobar")
cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3, text="foobarbaz")
cls.comment_5 = Comment.objects.create(id=5, user=cls.user_3, text="barbaz")
cls.comment_6 = Comment.objects.create(id=6, user=cls.user_3, text="baz")
def test_users_annotated_with_comments_id_count(self):
user_1, user_2, user_3 = User.objects.annotate(Count("comments__id")).order_by(
"pk"
)
self.assertEqual(user_1, self.user_1)
self.assertEqual(user_1.comments__id__count, 2)
self.assertEqual(user_2, self.user_2)
self.assertEqual(user_2.comments__id__count, 1)
self.assertEqual(user_3, self.user_3)
self.assertEqual(user_3.comments__id__count, 3)
def test_users_annotated_with_aliased_comments_id_count(self):
user_1, user_2, user_3 = User.objects.annotate(
comments_count=Count("comments__id")
).order_by("pk")
self.assertEqual(user_1, self.user_1)
self.assertEqual(user_1.comments_count, 2)
self.assertEqual(user_2, self.user_2)
self.assertEqual(user_2.comments_count, 1)
self.assertEqual(user_3, self.user_3)
self.assertEqual(user_3.comments_count, 3)
def test_users_annotated_with_comments_count(self):
user_1, user_2, user_3 = User.objects.annotate(Count("comments")).order_by("pk")
self.assertEqual(user_1, self.user_1)
self.assertEqual(user_1.comments__count, 2)
self.assertEqual(user_2, self.user_2)
self.assertEqual(user_2.comments__count, 1)
self.assertEqual(user_3, self.user_3)
self.assertEqual(user_3.comments__count, 3)
def test_users_annotated_with_comments_count_filter(self):
user_1, user_2, user_3 = User.objects.annotate(
comments__count=Count(
"comments", filter=Q(pk__in=[self.user_1.pk, self.user_2.pk])
)
).order_by("pk")
self.assertEqual(user_1, self.user_1)
self.assertEqual(user_1.comments__count, 2)
self.assertEqual(user_2, self.user_2)
self.assertEqual(user_2.comments__count, 1)
self.assertEqual(user_3, self.user_3)
self.assertEqual(user_3.comments__count, 0)
def test_count_distinct_not_supported(self):
with self.assertRaisesMessage(
NotSupportedError, "COUNT(DISTINCT) doesn't support composite primary keys"
):
self.assertIsNone(
User.objects.annotate(comments__count=Count("comments", distinct=True))
)
def test_user_values_annotated_with_comments_id_count(self):
self.assertSequenceEqual(
User.objects.values("pk").annotate(Count("comments__id")).order_by("pk"),
(
{"pk": self.user_1.pk, "comments__id__count": 2},
{"pk": self.user_2.pk, "comments__id__count": 1},
{"pk": self.user_3.pk, "comments__id__count": 3},
),
)
def test_user_values_annotated_with_filtered_comments_id_count(self):
self.assertSequenceEqual(
User.objects.values("pk")
.annotate(
comments_count=Count(
"comments__id",
filter=Q(comments__text__icontains="foo"),
)
)
.order_by("pk"),
(
{"pk": self.user_1.pk, "comments_count": 1},
{"pk": self.user_2.pk, "comments_count": 1},
{"pk": self.user_3.pk, "comments_count": 1},
),
)
def test_filter_and_count_users_by_comments_fields(self):
users = User.objects.filter(comments__id__gt=2).order_by("pk")
self.assertEqual(users.count(), 4)
self.assertSequenceEqual(
users, (self.user_1, self.user_3, self.user_3, self.user_3)
)
users = User.objects.filter(comments__text__icontains="foo").order_by("pk")
self.assertEqual(users.count(), 3)
self.assertSequenceEqual(users, (self.user_1, self.user_2, self.user_3))
users = User.objects.filter(comments__text__icontains="baz").order_by("pk")
self.assertEqual(users.count(), 3)
self.assertSequenceEqual(users, (self.user_3, self.user_3, self.user_3))
def test_order_by_comments_id_count(self):
self.assertSequenceEqual(
User.objects.annotate(comments_count=Count("comments__id")).order_by(
"-comments_count"
),
(self.user_3, self.user_1, self.user_2),
)

View File

@@ -0,0 +1,242 @@
from django.core import checks
from django.db import connection, models
from django.db.models import F
from django.test import TestCase
from django.test.utils import isolate_apps
@isolate_apps("composite_pk")
class CompositePKChecksTests(TestCase):
maxDiff = None
def test_composite_pk_must_be_unique_strings(self):
test_cases = (
(),
(0,),
(1,),
("id", False),
("id", "id"),
(("id",),),
)
for i, args in enumerate(test_cases):
with (
self.subTest(args=args),
self.assertRaisesMessage(
ValueError, "CompositePrimaryKey args must be unique strings."
),
):
models.CompositePrimaryKey(*args)
def test_composite_pk_must_include_at_least_2_fields(self):
expected_message = "CompositePrimaryKey must include at least two fields."
with self.assertRaisesMessage(ValueError, expected_message):
models.CompositePrimaryKey("id")
def test_composite_pk_cannot_have_a_default(self):
expected_message = "CompositePrimaryKey cannot have a default."
with self.assertRaisesMessage(ValueError, expected_message):
models.CompositePrimaryKey("tenant_id", "id", default=(1, 1))
def test_composite_pk_cannot_have_a_database_default(self):
expected_message = "CompositePrimaryKey cannot have a database default."
with self.assertRaisesMessage(ValueError, expected_message):
models.CompositePrimaryKey("tenant_id", "id", db_default=models.F("id"))
def test_composite_pk_cannot_be_editable(self):
expected_message = "CompositePrimaryKey cannot be editable."
with self.assertRaisesMessage(ValueError, expected_message):
models.CompositePrimaryKey("tenant_id", "id", editable=True)
def test_composite_pk_must_be_a_primary_key(self):
expected_message = "CompositePrimaryKey must be a primary key."
with self.assertRaisesMessage(ValueError, expected_message):
models.CompositePrimaryKey("tenant_id", "id", primary_key=False)
def test_composite_pk_must_be_blank(self):
expected_message = "CompositePrimaryKey must be blank."
with self.assertRaisesMessage(ValueError, expected_message):
models.CompositePrimaryKey("tenant_id", "id", blank=False)
def test_composite_pk_must_not_have_other_pk_field(self):
class Foo(models.Model):
pk = models.CompositePrimaryKey("foo_id", "id")
foo_id = models.IntegerField()
id = models.IntegerField(primary_key=True)
self.assertEqual(
Foo.check(databases=self.databases),
[
checks.Error(
"The model cannot have more than one field with "
"'primary_key=True'.",
obj=Foo,
id="models.E026",
),
],
)
def test_composite_pk_cannot_include_nullable_field(self):
class Foo(models.Model):
pk = models.CompositePrimaryKey("foo_id", "id")
foo_id = models.IntegerField()
id = models.IntegerField(null=True)
self.assertEqual(
Foo.check(databases=self.databases),
[
checks.Error(
"'id' cannot be included in the composite primary key.",
hint="'id' field may not set 'null=True'.",
obj=Foo,
id="models.E042",
),
],
)
def test_composite_pk_can_include_fk_name(self):
class Foo(models.Model):
pass
class Bar(models.Model):
pk = models.CompositePrimaryKey("foo", "id")
foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
id = models.SmallIntegerField()
self.assertEqual(Foo.check(databases=self.databases), [])
self.assertEqual(Bar.check(databases=self.databases), [])
def test_composite_pk_cannot_include_same_field(self):
class Foo(models.Model):
pass
class Bar(models.Model):
pk = models.CompositePrimaryKey("foo", "foo_id")
foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
id = models.SmallIntegerField()
self.assertEqual(Foo.check(databases=self.databases), [])
self.assertEqual(
Bar.check(databases=self.databases),
[
checks.Error(
"'foo_id' cannot be included in the composite primary key.",
hint="'foo_id' and 'foo' are the same fields.",
obj=Bar,
id="models.E042",
),
],
)
def test_composite_pk_cannot_include_composite_pk_field(self):
class Foo(models.Model):
pk = models.CompositePrimaryKey("id", "pk")
id = models.SmallIntegerField()
self.assertEqual(
Foo.check(databases=self.databases),
[
checks.Error(
"'pk' cannot be included in the composite primary key.",
hint="'pk' field has no column.",
obj=Foo,
id="models.E042",
),
],
)
def test_composite_pk_cannot_include_db_column(self):
class Foo(models.Model):
pk = models.CompositePrimaryKey("foo", "bar")
foo = models.SmallIntegerField(db_column="foo_id")
bar = models.SmallIntegerField(db_column="bar_id")
class Bar(models.Model):
pk = models.CompositePrimaryKey("foo_id", "bar_id")
foo = models.SmallIntegerField(db_column="foo_id")
bar = models.SmallIntegerField(db_column="bar_id")
self.assertEqual(Foo.check(databases=self.databases), [])
self.assertEqual(
Bar.check(databases=self.databases),
[
checks.Error(
"'foo_id' cannot be included in the composite primary key.",
hint="'foo_id' is not a valid field.",
obj=Bar,
id="models.E042",
),
checks.Error(
"'bar_id' cannot be included in the composite primary key.",
hint="'bar_id' is not a valid field.",
obj=Bar,
id="models.E042",
),
],
)
def test_foreign_object_can_refer_composite_pk(self):
class Foo(models.Model):
pass
class Bar(models.Model):
pk = models.CompositePrimaryKey("foo_id", "id")
foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
id = models.IntegerField()
class Baz(models.Model):
pk = models.CompositePrimaryKey("foo_id", "id")
foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
id = models.IntegerField()
bar_id = models.IntegerField()
bar = models.ForeignObject(
Bar,
on_delete=models.CASCADE,
from_fields=("foo_id", "bar_id"),
to_fields=("foo_id", "id"),
)
self.assertEqual(Foo.check(databases=self.databases), [])
self.assertEqual(Bar.check(databases=self.databases), [])
self.assertEqual(Baz.check(databases=self.databases), [])
def test_composite_pk_must_be_named_pk(self):
class Foo(models.Model):
primary_key = models.CompositePrimaryKey("foo_id", "id")
foo_id = models.IntegerField()
id = models.IntegerField()
self.assertEqual(
Foo.check(databases=self.databases),
[
checks.Error(
"'CompositePrimaryKey' must be named 'pk'.",
obj=Foo._meta.get_field("primary_key"),
id="fields.E013",
),
],
)
def test_composite_pk_cannot_include_generated_field(self):
is_oracle = connection.vendor == "oracle"
class Foo(models.Model):
pk = models.CompositePrimaryKey("id", "foo")
id = models.IntegerField()
foo = models.GeneratedField(
expression=F("id"),
output_field=models.IntegerField(),
db_persist=not is_oracle,
)
self.assertEqual(
Foo.check(databases=self.databases),
[
checks.Error(
"'foo' cannot be included in the composite primary key.",
hint="'foo' field is a generated field.",
obj=Foo,
id="models.E042",
),
],
)

View File

@@ -0,0 +1,138 @@
from django.test import TestCase
from .models import Tenant, User
class CompositePKCreateTests(TestCase):
maxDiff = None
@classmethod
def setUpTestData(cls):
cls.tenant = Tenant.objects.create()
cls.user = User.objects.create(
tenant=cls.tenant,
id=1,
email="user0001@example.com",
)
def test_create_user(self):
test_cases = (
{"tenant": self.tenant, "id": 2412, "email": "user2412@example.com"},
{"tenant_id": self.tenant.id, "id": 5316, "email": "user5316@example.com"},
{"pk": (self.tenant.id, 7424), "email": "user7424@example.com"},
)
for fields in test_cases:
with self.subTest(fields=fields):
count = User.objects.count()
user = User(**fields)
obj = User.objects.create(**fields)
self.assertEqual(obj.tenant_id, self.tenant.id)
self.assertEqual(obj.id, user.id)
self.assertEqual(obj.pk, (self.tenant.id, user.id))
self.assertEqual(obj.email, user.email)
self.assertEqual(count + 1, User.objects.count())
def test_save_user(self):
test_cases = (
{"tenant": self.tenant, "id": 9241, "email": "user9241@example.com"},
{"tenant_id": self.tenant.id, "id": 5132, "email": "user5132@example.com"},
{"pk": (self.tenant.id, 3014), "email": "user3014@example.com"},
)
for fields in test_cases:
with self.subTest(fields=fields):
count = User.objects.count()
user = User(**fields)
self.assertIsNotNone(user.id)
self.assertIsNotNone(user.email)
user.save()
self.assertEqual(user.tenant_id, self.tenant.id)
self.assertEqual(user.tenant, self.tenant)
self.assertIsNotNone(user.id)
self.assertEqual(user.pk, (self.tenant.id, user.id))
self.assertEqual(user.email, fields["email"])
self.assertEqual(user.email, f"user{user.id}@example.com")
self.assertEqual(count + 1, User.objects.count())
def test_bulk_create_users(self):
objs = [
User(tenant=self.tenant, id=8291, email="user8291@example.com"),
User(tenant_id=self.tenant.id, id=4021, email="user4021@example.com"),
User(pk=(self.tenant.id, 8214), email="user8214@example.com"),
]
obj_1, obj_2, obj_3 = User.objects.bulk_create(objs)
self.assertEqual(obj_1.tenant_id, self.tenant.id)
self.assertEqual(obj_1.id, 8291)
self.assertEqual(obj_1.pk, (obj_1.tenant_id, obj_1.id))
self.assertEqual(obj_1.email, "user8291@example.com")
self.assertEqual(obj_2.tenant_id, self.tenant.id)
self.assertEqual(obj_2.id, 4021)
self.assertEqual(obj_2.pk, (obj_2.tenant_id, obj_2.id))
self.assertEqual(obj_2.email, "user4021@example.com")
self.assertEqual(obj_3.tenant_id, self.tenant.id)
self.assertEqual(obj_3.id, 8214)
self.assertEqual(obj_3.pk, (obj_3.tenant_id, obj_3.id))
self.assertEqual(obj_3.email, "user8214@example.com")
def test_get_or_create_user(self):
test_cases = (
{
"pk": (self.tenant.id, 8314),
"defaults": {"email": "user8314@example.com"},
},
{
"tenant": self.tenant,
"id": 3142,
"defaults": {"email": "user3142@example.com"},
},
{
"tenant_id": self.tenant.id,
"id": 4218,
"defaults": {"email": "user4218@example.com"},
},
)
for fields in test_cases:
with self.subTest(fields=fields):
count = User.objects.count()
user, created = User.objects.get_or_create(**fields)
self.assertIs(created, True)
self.assertIsNotNone(user.id)
self.assertEqual(user.pk, (self.tenant.id, user.id))
self.assertEqual(user.tenant_id, self.tenant.id)
self.assertEqual(user.email, fields["defaults"]["email"])
self.assertEqual(user.email, f"user{user.id}@example.com")
self.assertEqual(count + 1, User.objects.count())
def test_update_or_create_user(self):
test_cases = (
{
"pk": (self.tenant.id, 2931),
"defaults": {"email": "user2931@example.com"},
},
{
"tenant": self.tenant,
"id": 6428,
"defaults": {"email": "user6428@example.com"},
},
{
"tenant_id": self.tenant.id,
"id": 5278,
"defaults": {"email": "user5278@example.com"},
},
)
for fields in test_cases:
with self.subTest(fields=fields):
count = User.objects.count()
user, created = User.objects.update_or_create(**fields)
self.assertIs(created, True)
self.assertIsNotNone(user.id)
self.assertEqual(user.pk, (self.tenant.id, user.id))
self.assertEqual(user.tenant_id, self.tenant.id)
self.assertEqual(user.email, fields["defaults"]["email"])
self.assertEqual(user.email, f"user{user.id}@example.com")
self.assertEqual(count + 1, User.objects.count())

View File

@@ -0,0 +1,83 @@
from django.test import TestCase
from .models import Comment, Tenant, User
class CompositePKDeleteTests(TestCase):
maxDiff = None
@classmethod
def setUpTestData(cls):
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_2,
id=2,
email="user0002@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_2)
cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
def test_delete_tenant_by_pk(self):
result = Tenant.objects.filter(pk=self.tenant_1.pk).delete()
self.assertEqual(
result,
(
3,
{
"composite_pk.Comment": 1,
"composite_pk.User": 1,
"composite_pk.Tenant": 1,
},
),
)
self.assertIs(Tenant.objects.filter(pk=self.tenant_1.pk).exists(), False)
self.assertIs(Tenant.objects.filter(pk=self.tenant_2.pk).exists(), True)
self.assertIs(User.objects.filter(pk=self.user_1.pk).exists(), False)
self.assertIs(User.objects.filter(pk=self.user_2.pk).exists(), True)
self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), False)
self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), True)
self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), True)
def test_delete_user_by_pk(self):
result = User.objects.filter(pk=self.user_1.pk).delete()
self.assertEqual(
result, (2, {"composite_pk.User": 1, "composite_pk.Comment": 1})
)
self.assertIs(User.objects.filter(pk=self.user_1.pk).exists(), False)
self.assertIs(User.objects.filter(pk=self.user_2.pk).exists(), True)
self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), False)
self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), True)
self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), True)
def test_delete_comments_by_user(self):
result = Comment.objects.filter(user=self.user_2).delete()
self.assertEqual(result, (2, {"composite_pk.Comment": 2}))
self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), True)
self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), False)
self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), False)
def test_delete_without_pk(self):
msg = (
"Comment object can't be deleted because its pk attribute is set "
"to None."
)
with self.assertRaisesMessage(ValueError, msg):
Comment().delete()
with self.assertRaisesMessage(ValueError, msg):
Comment(tenant_id=1).delete()
with self.assertRaisesMessage(ValueError, msg):
Comment(id=1).delete()

View File

@@ -0,0 +1,412 @@
from django.test import TestCase
from .models import Comment, Tenant, User
class CompositePKFilterTests(TestCase):
maxDiff = None
@classmethod
def setUpTestData(cls):
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.tenant_3 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1,
id=2,
email="user0002@example.com",
)
cls.user_3 = User.objects.create(
tenant=cls.tenant_2,
id=3,
email="user0003@example.com",
)
cls.user_4 = User.objects.create(
tenant=cls.tenant_3,
id=4,
email="user0004@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3)
cls.comment_5 = Comment.objects.create(id=5, user=cls.user_1)
def test_filter_and_count_user_by_pk(self):
test_cases = (
({"pk": self.user_1.pk}, 1),
({"pk": self.user_2.pk}, 1),
({"pk": self.user_3.pk}, 1),
({"pk": (self.tenant_1.id, self.user_1.id)}, 1),
({"pk": (self.tenant_1.id, self.user_2.id)}, 1),
({"pk": (self.tenant_2.id, self.user_3.id)}, 1),
({"pk": (self.tenant_1.id, self.user_3.id)}, 0),
({"pk": (self.tenant_2.id, self.user_1.id)}, 0),
({"pk": (self.tenant_2.id, self.user_2.id)}, 0),
)
for lookup, count in test_cases:
with self.subTest(lookup=lookup, count=count):
self.assertEqual(User.objects.filter(**lookup).count(), count)
def test_order_comments_by_pk_asc(self):
self.assertSequenceEqual(
Comment.objects.order_by("pk"),
(
self.comment_1, # (1, 1)
self.comment_2, # (1, 2)
self.comment_3, # (1, 3)
self.comment_5, # (1, 5)
self.comment_4, # (2, 4)
),
)
def test_order_comments_by_pk_desc(self):
self.assertSequenceEqual(
Comment.objects.order_by("-pk"),
(
self.comment_4, # (2, 4)
self.comment_5, # (1, 5)
self.comment_3, # (1, 3)
self.comment_2, # (1, 2)
self.comment_1, # (1, 1)
),
)
def test_filter_comments_by_pk_gt(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
test_cases = (
(c11, (c12, c13, c15, c24)),
(c12, (c13, c15, c24)),
(c13, (c15, c24)),
(c15, (c24,)),
(c24, ()),
)
for obj, objs in test_cases:
with self.subTest(obj=obj, objs=objs):
self.assertSequenceEqual(
Comment.objects.filter(pk__gt=obj.pk).order_by("pk"), objs
)
def test_filter_comments_by_pk_gte(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
test_cases = (
(c11, (c11, c12, c13, c15, c24)),
(c12, (c12, c13, c15, c24)),
(c13, (c13, c15, c24)),
(c15, (c15, c24)),
(c24, (c24,)),
)
for obj, objs in test_cases:
with self.subTest(obj=obj, objs=objs):
self.assertSequenceEqual(
Comment.objects.filter(pk__gte=obj.pk).order_by("pk"), objs
)
def test_filter_comments_by_pk_lt(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
test_cases = (
(c24, (c11, c12, c13, c15)),
(c15, (c11, c12, c13)),
(c13, (c11, c12)),
(c12, (c11,)),
(c11, ()),
)
for obj, objs in test_cases:
with self.subTest(obj=obj, objs=objs):
self.assertSequenceEqual(
Comment.objects.filter(pk__lt=obj.pk).order_by("pk"), objs
)
def test_filter_comments_by_pk_lte(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
test_cases = (
(c24, (c11, c12, c13, c15, c24)),
(c15, (c11, c12, c13, c15)),
(c13, (c11, c12, c13)),
(c12, (c11, c12)),
(c11, (c11,)),
)
for obj, objs in test_cases:
with self.subTest(obj=obj, objs=objs):
self.assertSequenceEqual(
Comment.objects.filter(pk__lte=obj.pk).order_by("pk"), objs
)
def test_filter_comments_by_pk_in(self):
test_cases = (
(),
(self.comment_1,),
(self.comment_1, self.comment_4),
)
for objs in test_cases:
with self.subTest(objs=objs):
pks = [obj.pk for obj in objs]
self.assertSequenceEqual(
Comment.objects.filter(pk__in=pks).order_by("pk"), objs
)
def test_filter_comments_by_user_and_order_by_pk_asc(self):
self.assertSequenceEqual(
Comment.objects.filter(user=self.user_1).order_by("pk"),
(self.comment_1, self.comment_2, self.comment_5),
)
def test_filter_comments_by_user_and_order_by_pk_desc(self):
self.assertSequenceEqual(
Comment.objects.filter(user=self.user_1).order_by("-pk"),
(self.comment_5, self.comment_2, self.comment_1),
)
def test_filter_comments_by_user_and_exclude_by_pk(self):
self.assertSequenceEqual(
Comment.objects.filter(user=self.user_1)
.exclude(pk=self.comment_1.pk)
.order_by("pk"),
(self.comment_2, self.comment_5),
)
def test_filter_comments_by_user_and_contains(self):
self.assertIs(
Comment.objects.filter(user=self.user_1).contains(self.comment_1), True
)
def test_filter_users_by_comments_in(self):
c1, c2, c3, c4, c5 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
u1, u2, u3 = (
self.user_1,
self.user_2,
self.user_3,
)
test_cases = (
((), ()),
((c1,), (u1,)),
((c1, c2), (u1, u1)),
((c1, c2, c3), (u1, u1, u2)),
((c1, c2, c3, c4), (u1, u1, u2, u3)),
((c1, c2, c3, c4, c5), (u1, u1, u1, u2, u3)),
)
for comments, users in test_cases:
with self.subTest(comments=comments, users=users):
self.assertSequenceEqual(
User.objects.filter(comments__in=comments).order_by("pk"), users
)
def test_filter_users_by_comments_lt(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
u1, u2 = (
self.user_1,
self.user_2,
)
test_cases = (
(c11, ()),
(c12, (u1,)),
(c13, (u1, u1)),
(c15, (u1, u1, u2)),
(c24, (u1, u1, u1, u2)),
)
for comment, users in test_cases:
with self.subTest(comment=comment, users=users):
self.assertSequenceEqual(
User.objects.filter(comments__lt=comment).order_by("pk"), users
)
def test_filter_users_by_comments_lte(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
u1, u2, u3 = (
self.user_1,
self.user_2,
self.user_3,
)
test_cases = (
(c11, (u1,)),
(c12, (u1, u1)),
(c13, (u1, u1, u2)),
(c15, (u1, u1, u1, u2)),
(c24, (u1, u1, u1, u2, u3)),
)
for comment, users in test_cases:
with self.subTest(comment=comment, users=users):
self.assertSequenceEqual(
User.objects.filter(comments__lte=comment).order_by("pk"), users
)
def test_filter_users_by_comments_gt(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
u1, u2, u3 = (
self.user_1,
self.user_2,
self.user_3,
)
test_cases = (
(c11, (u1, u1, u2, u3)),
(c12, (u1, u2, u3)),
(c13, (u1, u3)),
(c15, (u3,)),
(c24, ()),
)
for comment, users in test_cases:
with self.subTest(comment=comment, users=users):
self.assertSequenceEqual(
User.objects.filter(comments__gt=comment).order_by("pk"), users
)
def test_filter_users_by_comments_gte(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
u1, u2, u3 = (
self.user_1,
self.user_2,
self.user_3,
)
test_cases = (
(c11, (u1, u1, u1, u2, u3)),
(c12, (u1, u1, u2, u3)),
(c13, (u1, u2, u3)),
(c15, (u1, u3)),
(c24, (u3,)),
)
for comment, users in test_cases:
with self.subTest(comment=comment, users=users):
self.assertSequenceEqual(
User.objects.filter(comments__gte=comment).order_by("pk"), users
)
def test_filter_users_by_comments_exact(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
u1, u2, u3 = (
self.user_1,
self.user_2,
self.user_3,
)
test_cases = (
(c11, (u1,)),
(c12, (u1,)),
(c13, (u2,)),
(c15, (u1,)),
(c24, (u3,)),
)
for comment, users in test_cases:
with self.subTest(comment=comment, users=users):
self.assertSequenceEqual(
User.objects.filter(comments=comment).order_by("pk"), users
)
def test_filter_users_by_comments_isnull(self):
u1, u2, u3, u4 = (
self.user_1,
self.user_2,
self.user_3,
self.user_4,
)
with self.subTest("comments__isnull=True"):
self.assertSequenceEqual(
User.objects.filter(comments__isnull=True).order_by("pk"),
(u4,),
)
with self.subTest("comments__isnull=False"):
self.assertSequenceEqual(
User.objects.filter(comments__isnull=False).order_by("pk"),
(u1, u1, u1, u2, u3),
)
def test_filter_comments_by_pk_isnull(self):
c11, c12, c13, c24, c15 = (
self.comment_1,
self.comment_2,
self.comment_3,
self.comment_4,
self.comment_5,
)
with self.subTest("pk__isnull=True"):
self.assertSequenceEqual(
Comment.objects.filter(pk__isnull=True).order_by("pk"),
(),
)
with self.subTest("pk__isnull=False"):
self.assertSequenceEqual(
Comment.objects.filter(pk__isnull=False).order_by("pk"),
(c11, c12, c13, c15, c24),
)
def test_filter_users_by_comments_subquery(self):
subquery = Comment.objects.filter(id=3).only("pk")
queryset = User.objects.filter(comments__in=subquery)
self.assertSequenceEqual(queryset, (self.user_2,))

View File

@@ -0,0 +1,126 @@
from django.test import TestCase
from .models import Comment, Tenant, User
class CompositePKGetTests(TestCase):
maxDiff = None
@classmethod
def setUpTestData(cls):
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1,
id=2,
email="user0002@example.com",
)
cls.user_3 = User.objects.create(
tenant=cls.tenant_2,
id=3,
email="user0003@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
def test_get_user(self):
test_cases = (
{"pk": self.user_1.pk},
{"pk": (self.tenant_1.id, self.user_1.id)},
{"id": self.user_1.id},
)
for lookup in test_cases:
with self.subTest(lookup=lookup):
self.assertEqual(User.objects.get(**lookup), self.user_1)
def test_get_comment(self):
test_cases = (
{"pk": self.comment_1.pk},
{"pk": (self.tenant_1.id, self.comment_1.id)},
{"id": self.comment_1.id},
{"user": self.user_1},
{"user_id": self.user_1.id},
{"user__id": self.user_1.id},
{"user__pk": self.user_1.pk},
{"tenant": self.tenant_1},
{"tenant_id": self.tenant_1.id},
{"tenant__id": self.tenant_1.id},
{"tenant__pk": self.tenant_1.pk},
)
for lookup in test_cases:
with self.subTest(lookup=lookup):
self.assertEqual(Comment.objects.get(**lookup), self.comment_1)
def test_get_or_create_user(self):
test_cases = (
{
"pk": self.user_1.pk,
"defaults": {"email": "user9201@example.com"},
},
{
"pk": (self.tenant_1.id, self.user_1.id),
"defaults": {"email": "user9201@example.com"},
},
{
"tenant": self.tenant_1,
"id": self.user_1.id,
"defaults": {"email": "user3512@example.com"},
},
{
"tenant_id": self.tenant_1.id,
"id": self.user_1.id,
"defaults": {"email": "user8239@example.com"},
},
)
for fields in test_cases:
with self.subTest(fields=fields):
count = User.objects.count()
user, created = User.objects.get_or_create(**fields)
self.assertIs(created, False)
self.assertEqual(user.id, self.user_1.id)
self.assertEqual(user.pk, (self.tenant_1.id, self.user_1.id))
self.assertEqual(user.tenant_id, self.tenant_1.id)
self.assertEqual(user.email, self.user_1.email)
self.assertEqual(count, User.objects.count())
def test_lookup_errors(self):
m_tuple = "'%s' lookup of 'pk' must be a tuple or a list"
m_2_elements = "'%s' lookup of 'pk' must have 2 elements"
m_tuple_collection = (
"'in' lookup of 'pk' must be a collection of tuples or lists"
)
m_2_elements_each = "'in' lookup of 'pk' must have 2 elements each"
test_cases = (
({"pk": 1}, m_tuple % "exact"),
({"pk": (1, 2, 3)}, m_2_elements % "exact"),
({"pk__exact": 1}, m_tuple % "exact"),
({"pk__exact": (1, 2, 3)}, m_2_elements % "exact"),
({"pk__in": 1}, m_tuple % "in"),
({"pk__in": (1, 2, 3)}, m_tuple_collection),
({"pk__in": ((1, 2, 3),)}, m_2_elements_each),
({"pk__gt": 1}, m_tuple % "gt"),
({"pk__gt": (1, 2, 3)}, m_2_elements % "gt"),
({"pk__gte": 1}, m_tuple % "gte"),
({"pk__gte": (1, 2, 3)}, m_2_elements % "gte"),
({"pk__lt": 1}, m_tuple % "lt"),
({"pk__lt": (1, 2, 3)}, m_2_elements % "lt"),
({"pk__lte": 1}, m_tuple % "lte"),
({"pk__lte": (1, 2, 3)}, m_2_elements % "lte"),
)
for kwargs, message in test_cases:
with (
self.subTest(kwargs=kwargs),
self.assertRaisesMessage(ValueError, message),
):
Comment.objects.get(**kwargs)
def test_get_user_by_comments(self):
self.assertEqual(User.objects.get(comments=self.comment_1), self.user_1)

View File

@@ -0,0 +1,153 @@
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.test import TestCase
from .models import Comment, Tenant, Token, User
class CompositePKModelsTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1,
id=2,
email="user0002@example.com",
)
cls.user_3 = User.objects.create(
tenant=cls.tenant_2,
id=3,
email="user0003@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3)
def test_fields(self):
# tenant_1
self.assertSequenceEqual(
self.tenant_1.user_set.order_by("pk"),
[self.user_1, self.user_2],
)
self.assertSequenceEqual(
self.tenant_1.comments.order_by("pk"),
[self.comment_1, self.comment_2, self.comment_3],
)
# tenant_2
self.assertSequenceEqual(self.tenant_2.user_set.order_by("pk"), [self.user_3])
self.assertSequenceEqual(
self.tenant_2.comments.order_by("pk"), [self.comment_4]
)
# user_1
self.assertEqual(self.user_1.id, 1)
self.assertEqual(self.user_1.tenant_id, self.tenant_1.id)
self.assertEqual(self.user_1.tenant, self.tenant_1)
self.assertEqual(self.user_1.pk, (self.tenant_1.id, self.user_1.id))
self.assertSequenceEqual(
self.user_1.comments.order_by("pk"), [self.comment_1, self.comment_2]
)
# user_2
self.assertEqual(self.user_2.id, 2)
self.assertEqual(self.user_2.tenant_id, self.tenant_1.id)
self.assertEqual(self.user_2.tenant, self.tenant_1)
self.assertEqual(self.user_2.pk, (self.tenant_1.id, self.user_2.id))
self.assertSequenceEqual(self.user_2.comments.order_by("pk"), [self.comment_3])
# comment_1
self.assertEqual(self.comment_1.id, 1)
self.assertEqual(self.comment_1.user_id, self.user_1.id)
self.assertEqual(self.comment_1.user, self.user_1)
self.assertEqual(self.comment_1.tenant_id, self.tenant_1.id)
self.assertEqual(self.comment_1.tenant, self.tenant_1)
self.assertEqual(self.comment_1.pk, (self.tenant_1.id, self.user_1.id))
def test_full_clean_success(self):
test_cases = (
# 1, 1234, {}
({"tenant": self.tenant_1, "id": 1234}, {}),
({"tenant_id": self.tenant_1.id, "id": 1234}, {}),
({"pk": (self.tenant_1.id, 1234)}, {}),
# 1, 1, {"id"}
({"tenant": self.tenant_1, "id": 1}, {"id"}),
({"tenant_id": self.tenant_1.id, "id": 1}, {"id"}),
({"pk": (self.tenant_1.id, 1)}, {"id"}),
# 1, 1, {"tenant", "id"}
({"tenant": self.tenant_1, "id": 1}, {"tenant", "id"}),
({"tenant_id": self.tenant_1.id, "id": 1}, {"tenant", "id"}),
({"pk": (self.tenant_1.id, 1)}, {"tenant", "id"}),
)
for kwargs, exclude in test_cases:
with self.subTest(kwargs):
kwargs["email"] = "user0004@example.com"
User(**kwargs).full_clean(exclude=exclude)
def test_full_clean_failure(self):
e_tenant_and_id = "User with this Tenant and Id already exists."
e_id = "User with this Id already exists."
test_cases = (
# 1, 1, {}
({"tenant": self.tenant_1, "id": 1}, {}, (e_tenant_and_id, e_id)),
({"tenant_id": self.tenant_1.id, "id": 1}, {}, (e_tenant_and_id, e_id)),
({"pk": (self.tenant_1.id, 1)}, {}, (e_tenant_and_id, e_id)),
# 2, 1, {}
({"tenant": self.tenant_2, "id": 1}, {}, (e_id,)),
({"tenant_id": self.tenant_2.id, "id": 1}, {}, (e_id,)),
({"pk": (self.tenant_2.id, 1)}, {}, (e_id,)),
# 1, 1, {"tenant"}
({"tenant": self.tenant_1, "id": 1}, {"tenant"}, (e_id,)),
({"tenant_id": self.tenant_1.id, "id": 1}, {"tenant"}, (e_id,)),
({"pk": (self.tenant_1.id, 1)}, {"tenant"}, (e_id,)),
)
for kwargs, exclude, messages in test_cases:
with self.subTest(kwargs):
with self.assertRaises(ValidationError) as ctx:
kwargs["email"] = "user0004@example.com"
User(**kwargs).full_clean(exclude=exclude)
self.assertSequenceEqual(ctx.exception.messages, messages)
def test_field_conflicts(self):
test_cases = (
({"pk": (1, 1), "id": 2}, (1, 1)),
({"id": 2, "pk": (1, 1)}, (1, 1)),
({"pk": (1, 1), "tenant_id": 2}, (1, 1)),
({"tenant_id": 2, "pk": (1, 1)}, (1, 1)),
({"pk": (2, 2), "tenant_id": 3, "id": 4}, (2, 2)),
({"tenant_id": 3, "id": 4, "pk": (2, 2)}, (2, 2)),
)
for kwargs, pk in test_cases:
with self.subTest(kwargs=kwargs):
user = User(**kwargs)
self.assertEqual(user.pk, pk)
def test_validate_unique(self):
user = User.objects.get(pk=self.user_1.pk)
user.id = None
with self.assertRaises(ValidationError) as ctx:
user.validate_unique()
self.assertSequenceEqual(
ctx.exception.messages, ("User with this Email already exists.",)
)
def test_permissions(self):
token = ContentType.objects.get_for_model(Token)
user = ContentType.objects.get_for_model(User)
comment = ContentType.objects.get_for_model(Comment)
self.assertEqual(4, token.permission_set.count())
self.assertEqual(4, user.permission_set.count())
self.assertEqual(4, comment.permission_set.count())

View File

@@ -0,0 +1,134 @@
from django.db.models.query_utils import PathInfo
from django.db.models.sql import Query
from django.test import TestCase
from .models import Comment, Tenant, User
class NamesToPathTests(TestCase):
def test_id(self):
query = Query(User)
path, final_field, targets, rest = query.names_to_path(["id"], User._meta)
self.assertEqual(path, [])
self.assertEqual(final_field, User._meta.get_field("id"))
self.assertEqual(targets, (User._meta.get_field("id"),))
self.assertEqual(rest, [])
def test_pk(self):
query = Query(User)
path, final_field, targets, rest = query.names_to_path(["pk"], User._meta)
self.assertEqual(path, [])
self.assertEqual(final_field, User._meta.get_field("pk"))
self.assertEqual(targets, (User._meta.get_field("pk"),))
self.assertEqual(rest, [])
def test_tenant_id(self):
query = Query(User)
path, final_field, targets, rest = query.names_to_path(
["tenant", "id"], User._meta
)
self.assertEqual(
path,
[
PathInfo(
from_opts=User._meta,
to_opts=Tenant._meta,
target_fields=(Tenant._meta.get_field("id"),),
join_field=User._meta.get_field("tenant"),
m2m=False,
direct=True,
filtered_relation=None,
),
],
)
self.assertEqual(final_field, Tenant._meta.get_field("id"))
self.assertEqual(targets, (Tenant._meta.get_field("id"),))
self.assertEqual(rest, [])
def test_user_id(self):
query = Query(Comment)
path, final_field, targets, rest = query.names_to_path(
["user", "id"], Comment._meta
)
self.assertEqual(
path,
[
PathInfo(
from_opts=Comment._meta,
to_opts=User._meta,
target_fields=(
User._meta.get_field("tenant"),
User._meta.get_field("id"),
),
join_field=Comment._meta.get_field("user"),
m2m=False,
direct=True,
filtered_relation=None,
),
],
)
self.assertEqual(final_field, User._meta.get_field("id"))
self.assertEqual(targets, (User._meta.get_field("id"),))
self.assertEqual(rest, [])
def test_user_tenant_id(self):
query = Query(Comment)
path, final_field, targets, rest = query.names_to_path(
["user", "tenant", "id"], Comment._meta
)
self.assertEqual(
path,
[
PathInfo(
from_opts=Comment._meta,
to_opts=User._meta,
target_fields=(
User._meta.get_field("tenant"),
User._meta.get_field("id"),
),
join_field=Comment._meta.get_field("user"),
m2m=False,
direct=True,
filtered_relation=None,
),
PathInfo(
from_opts=User._meta,
to_opts=Tenant._meta,
target_fields=(Tenant._meta.get_field("id"),),
join_field=User._meta.get_field("tenant"),
m2m=False,
direct=True,
filtered_relation=None,
),
],
)
self.assertEqual(final_field, Tenant._meta.get_field("id"))
self.assertEqual(targets, (Tenant._meta.get_field("id"),))
self.assertEqual(rest, [])
def test_comments(self):
query = Query(User)
path, final_field, targets, rest = query.names_to_path(["comments"], User._meta)
self.assertEqual(
path,
[
PathInfo(
from_opts=User._meta,
to_opts=Comment._meta,
target_fields=(Comment._meta.get_field("pk"),),
join_field=User._meta.get_field("comments"),
m2m=True,
direct=False,
filtered_relation=None,
),
],
)
self.assertEqual(final_field, User._meta.get_field("comments"))
self.assertEqual(targets, (Comment._meta.get_field("pk"),))
self.assertEqual(rest, [])

View File

@@ -0,0 +1,135 @@
from django.test import TestCase
from .models import Comment, Tenant, Token, User
class CompositePKUpdateTests(TestCase):
maxDiff = None
@classmethod
def setUpTestData(cls):
cls.tenant_1 = Tenant.objects.create(name="A")
cls.tenant_2 = Tenant.objects.create(name="B")
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1,
id=2,
email="user0002@example.com",
)
cls.user_3 = User.objects.create(
tenant=cls.tenant_2,
id=3,
email="user0003@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
cls.token_1 = Token.objects.create(id=1, tenant=cls.tenant_1)
cls.token_2 = Token.objects.create(id=2, tenant=cls.tenant_2)
cls.token_3 = Token.objects.create(id=3, tenant=cls.tenant_1)
cls.token_4 = Token.objects.create(id=4, tenant=cls.tenant_2)
def test_update_user(self):
email = "user9315@example.com"
result = User.objects.filter(pk=self.user_1.pk).update(email=email)
self.assertEqual(result, 1)
user = User.objects.get(pk=self.user_1.pk)
self.assertEqual(user.email, email)
def test_save_user(self):
count = User.objects.count()
email = "user9314@example.com"
user = User.objects.get(pk=self.user_1.pk)
user.email = email
user.save()
user.refresh_from_db()
self.assertEqual(user.email, email)
user = User.objects.get(pk=self.user_1.pk)
self.assertEqual(user.email, email)
self.assertEqual(count, User.objects.count())
def test_bulk_update_comments(self):
comment_1 = Comment.objects.get(pk=self.comment_1.pk)
comment_2 = Comment.objects.get(pk=self.comment_2.pk)
comment_3 = Comment.objects.get(pk=self.comment_3.pk)
comment_1.text = "foo"
comment_2.text = "bar"
comment_3.text = "baz"
result = Comment.objects.bulk_update(
[comment_1, comment_2, comment_3], ["text"]
)
self.assertEqual(result, 3)
comment_1 = Comment.objects.get(pk=self.comment_1.pk)
comment_2 = Comment.objects.get(pk=self.comment_2.pk)
comment_3 = Comment.objects.get(pk=self.comment_3.pk)
self.assertEqual(comment_1.text, "foo")
self.assertEqual(comment_2.text, "bar")
self.assertEqual(comment_3.text, "baz")
def test_update_or_create_user(self):
test_cases = (
{
"pk": self.user_1.pk,
"defaults": {"email": "user3914@example.com"},
},
{
"pk": (self.tenant_1.id, self.user_1.id),
"defaults": {"email": "user9375@example.com"},
},
{
"tenant": self.tenant_1,
"id": self.user_1.id,
"defaults": {"email": "user3517@example.com"},
},
{
"tenant_id": self.tenant_1.id,
"id": self.user_1.id,
"defaults": {"email": "user8391@example.com"},
},
)
for fields in test_cases:
with self.subTest(fields=fields):
count = User.objects.count()
user, created = User.objects.update_or_create(**fields)
self.assertIs(created, False)
self.assertEqual(user.id, self.user_1.id)
self.assertEqual(user.pk, (self.tenant_1.id, self.user_1.id))
self.assertEqual(user.tenant_id, self.tenant_1.id)
self.assertEqual(user.email, fields["defaults"]["email"])
self.assertEqual(count, User.objects.count())
def test_update_comment_by_user_email(self):
result = Comment.objects.filter(user__email=self.user_1.email).update(
text="foo"
)
self.assertEqual(result, 2)
comment_1 = Comment.objects.get(pk=self.comment_1.pk)
comment_2 = Comment.objects.get(pk=self.comment_2.pk)
self.assertEqual(comment_1.text, "foo")
self.assertEqual(comment_2.text, "foo")
def test_update_token_by_tenant_name(self):
result = Token.objects.filter(tenant__name="A").update(secret="bar")
self.assertEqual(result, 2)
token_1 = Token.objects.get(pk=self.token_1.pk)
self.assertEqual(token_1.secret, "bar")
token_3 = Token.objects.get(pk=self.token_3.pk)
self.assertEqual(token_3.secret, "bar")
def test_cant_update_to_unsaved_object(self):
msg = (
"Unsaved model instance <User: User object ((None, None))> cannot be used "
"in an ORM query."
)
with self.assertRaisesMessage(ValueError, msg):
Comment.objects.update(user=User())

View File

@@ -0,0 +1,212 @@
from collections import namedtuple
from uuid import UUID
from django.test import TestCase
from .models import Post, Tenant, User
class CompositePKValuesTests(TestCase):
USER_1_EMAIL = "user0001@example.com"
USER_2_EMAIL = "user0002@example.com"
USER_3_EMAIL = "user0003@example.com"
POST_1_ID = "77777777-7777-7777-7777-777777777777"
POST_2_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
POST_3_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
@classmethod
def setUpTestData(cls):
super().setUpTestData()
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1, id=1, email=cls.USER_1_EMAIL
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1, id=2, email=cls.USER_2_EMAIL
)
cls.user_3 = User.objects.create(
tenant=cls.tenant_2, id=3, email=cls.USER_3_EMAIL
)
cls.post_1 = Post.objects.create(tenant=cls.tenant_1, id=cls.POST_1_ID)
cls.post_2 = Post.objects.create(tenant=cls.tenant_1, id=cls.POST_2_ID)
cls.post_3 = Post.objects.create(tenant=cls.tenant_2, id=cls.POST_3_ID)
def test_values_list(self):
with self.subTest('User.objects.values_list("pk")'):
self.assertSequenceEqual(
User.objects.values_list("pk").order_by("pk"),
(
(self.user_1.pk,),
(self.user_2.pk,),
(self.user_3.pk,),
),
)
with self.subTest('User.objects.values_list("pk", "email")'):
self.assertSequenceEqual(
User.objects.values_list("pk", "email").order_by("pk"),
(
(self.user_1.pk, self.USER_1_EMAIL),
(self.user_2.pk, self.USER_2_EMAIL),
(self.user_3.pk, self.USER_3_EMAIL),
),
)
with self.subTest('User.objects.values_list("pk", "id")'):
self.assertSequenceEqual(
User.objects.values_list("pk", "id").order_by("pk"),
(
(self.user_1.pk, self.user_1.id),
(self.user_2.pk, self.user_2.id),
(self.user_3.pk, self.user_3.id),
),
)
with self.subTest('User.objects.values_list("pk", "tenant_id", "id")'):
self.assertSequenceEqual(
User.objects.values_list("pk", "tenant_id", "id").order_by("pk"),
(
(self.user_1.pk, self.user_1.tenant_id, self.user_1.id),
(self.user_2.pk, self.user_2.tenant_id, self.user_2.id),
(self.user_3.pk, self.user_3.tenant_id, self.user_3.id),
),
)
with self.subTest('User.objects.values_list("pk", flat=True)'):
self.assertSequenceEqual(
User.objects.values_list("pk", flat=True).order_by("pk"),
(
self.user_1.pk,
self.user_2.pk,
self.user_3.pk,
),
)
with self.subTest('Post.objects.values_list("pk", flat=True)'):
self.assertSequenceEqual(
Post.objects.values_list("pk", flat=True).order_by("pk"),
(
(self.tenant_1.id, UUID(self.POST_1_ID)),
(self.tenant_1.id, UUID(self.POST_2_ID)),
(self.tenant_2.id, UUID(self.POST_3_ID)),
),
)
with self.subTest('Post.objects.values_list("pk")'):
self.assertSequenceEqual(
Post.objects.values_list("pk").order_by("pk"),
(
((self.tenant_1.id, UUID(self.POST_1_ID)),),
((self.tenant_1.id, UUID(self.POST_2_ID)),),
((self.tenant_2.id, UUID(self.POST_3_ID)),),
),
)
with self.subTest('Post.objects.values_list("pk", "id")'):
self.assertSequenceEqual(
Post.objects.values_list("pk", "id").order_by("pk"),
(
((self.tenant_1.id, UUID(self.POST_1_ID)), UUID(self.POST_1_ID)),
((self.tenant_1.id, UUID(self.POST_2_ID)), UUID(self.POST_2_ID)),
((self.tenant_2.id, UUID(self.POST_3_ID)), UUID(self.POST_3_ID)),
),
)
with self.subTest('Post.objects.values_list("id", "pk")'):
self.assertSequenceEqual(
Post.objects.values_list("id", "pk").order_by("pk"),
(
(UUID(self.POST_1_ID), (self.tenant_1.id, UUID(self.POST_1_ID))),
(UUID(self.POST_2_ID), (self.tenant_1.id, UUID(self.POST_2_ID))),
(UUID(self.POST_3_ID), (self.tenant_2.id, UUID(self.POST_3_ID))),
),
)
with self.subTest('User.objects.values_list("pk", named=True)'):
Row = namedtuple("Row", ["pk"])
self.assertSequenceEqual(
User.objects.values_list("pk", named=True).order_by("pk"),
(
Row(pk=self.user_1.pk),
Row(pk=self.user_2.pk),
Row(pk=self.user_3.pk),
),
)
with self.subTest('User.objects.values_list("pk", "pk")'):
self.assertSequenceEqual(
User.objects.values_list("pk", "pk").order_by("pk"),
(
(self.user_1.pk,),
(self.user_2.pk,),
(self.user_3.pk,),
),
)
with self.subTest('User.objects.values_list("pk", "id", "pk", "id")'):
self.assertSequenceEqual(
User.objects.values_list("pk", "id", "pk", "id").order_by("pk"),
(
(self.user_1.pk, self.user_1.id),
(self.user_2.pk, self.user_2.id),
(self.user_3.pk, self.user_3.id),
),
)
def test_values(self):
with self.subTest('User.objects.values("pk")'):
self.assertSequenceEqual(
User.objects.values("pk").order_by("pk"),
(
{"pk": self.user_1.pk},
{"pk": self.user_2.pk},
{"pk": self.user_3.pk},
),
)
with self.subTest('User.objects.values("pk", "email")'):
self.assertSequenceEqual(
User.objects.values("pk", "email").order_by("pk"),
(
{"pk": self.user_1.pk, "email": self.USER_1_EMAIL},
{"pk": self.user_2.pk, "email": self.USER_2_EMAIL},
{"pk": self.user_3.pk, "email": self.USER_3_EMAIL},
),
)
with self.subTest('User.objects.values("pk", "id")'):
self.assertSequenceEqual(
User.objects.values("pk", "id").order_by("pk"),
(
{"pk": self.user_1.pk, "id": self.user_1.id},
{"pk": self.user_2.pk, "id": self.user_2.id},
{"pk": self.user_3.pk, "id": self.user_3.id},
),
)
with self.subTest('User.objects.values("pk", "tenant_id", "id")'):
self.assertSequenceEqual(
User.objects.values("pk", "tenant_id", "id").order_by("pk"),
(
{
"pk": self.user_1.pk,
"tenant_id": self.user_1.tenant_id,
"id": self.user_1.id,
},
{
"pk": self.user_2.pk,
"tenant_id": self.user_2.tenant_id,
"id": self.user_2.id,
},
{
"pk": self.user_3.pk,
"tenant_id": self.user_3.tenant_id,
"id": self.user_3.id,
},
),
)
with self.subTest('User.objects.values("pk", "pk")'):
self.assertSequenceEqual(
User.objects.values("pk", "pk").order_by("pk"),
(
{"pk": self.user_1.pk},
{"pk": self.user_2.pk},
{"pk": self.user_3.pk},
),
)
with self.subTest('User.objects.values("pk", "id", "pk", "id")'):
self.assertSequenceEqual(
User.objects.values("pk", "id", "pk", "id").order_by("pk"),
(
{"pk": self.user_1.pk, "id": self.user_1.id},
{"pk": self.user_2.pk, "id": self.user_2.id},
{"pk": self.user_3.pk, "id": self.user_3.id},
),
)

345
tests/composite_pk/tests.py Normal file
View File

@@ -0,0 +1,345 @@
import json
import unittest
from uuid import UUID
import yaml
from django import forms
from django.core import serializers
from django.core.exceptions import FieldError
from django.db import IntegrityError, connection
from django.db.models import CompositePrimaryKey
from django.forms import modelform_factory
from django.test import TestCase
from .models import Comment, Post, Tenant, User
class CommentForm(forms.ModelForm):
class Meta:
model = Comment
fields = "__all__"
class CompositePKTests(TestCase):
maxDiff = None
@classmethod
def setUpTestData(cls):
cls.tenant = Tenant.objects.create()
cls.user = User.objects.create(
tenant=cls.tenant,
id=1,
email="user0001@example.com",
)
cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user)
@staticmethod
def get_constraints(table):
with connection.cursor() as cursor:
return connection.introspection.get_constraints(cursor, table)
def test_pk_updated_if_field_updated(self):
user = User.objects.get(pk=self.user.pk)
self.assertEqual(user.pk, (self.tenant.id, self.user.id))
self.assertIs(user._is_pk_set(), True)
user.tenant_id = 9831
self.assertEqual(user.pk, (9831, self.user.id))
self.assertIs(user._is_pk_set(), True)
user.id = 4321
self.assertEqual(user.pk, (9831, 4321))
self.assertIs(user._is_pk_set(), True)
user.pk = (9132, 3521)
self.assertEqual(user.tenant_id, 9132)
self.assertEqual(user.id, 3521)
self.assertIs(user._is_pk_set(), True)
user.id = None
self.assertEqual(user.pk, (9132, None))
self.assertEqual(user.tenant_id, 9132)
self.assertIsNone(user.id)
self.assertIs(user._is_pk_set(), False)
def test_hash(self):
self.assertEqual(hash(User(pk=(1, 2))), hash((1, 2)))
self.assertEqual(hash(User(tenant_id=2, id=3)), hash((2, 3)))
msg = "Model instances without primary key value are unhashable"
with self.assertRaisesMessage(TypeError, msg):
hash(User())
with self.assertRaisesMessage(TypeError, msg):
hash(User(tenant_id=1))
with self.assertRaisesMessage(TypeError, msg):
hash(User(id=1))
def test_pk_must_be_list_or_tuple(self):
user = User.objects.get(pk=self.user.pk)
test_cases = [
"foo",
1000,
3.14,
True,
False,
]
for pk in test_cases:
with self.assertRaisesMessage(
ValueError, "'pk' must be a list or a tuple."
):
user.pk = pk
def test_pk_must_have_2_elements(self):
user = User.objects.get(pk=self.user.pk)
test_cases = [
(),
[],
(1000,),
[1000],
(1, 2, 3),
[1, 2, 3],
]
for pk in test_cases:
with self.assertRaisesMessage(ValueError, "'pk' must have 2 elements."):
user.pk = pk
def test_composite_pk_in_fields(self):
user_fields = {f.name for f in User._meta.get_fields()}
self.assertEqual(user_fields, {"pk", "tenant", "id", "email", "comments"})
comment_fields = {f.name for f in Comment._meta.get_fields()}
self.assertEqual(
comment_fields,
{"pk", "tenant", "id", "user_id", "user", "text"},
)
def test_pk_field(self):
pk = User._meta.get_field("pk")
self.assertIsInstance(pk, CompositePrimaryKey)
self.assertIs(User._meta.pk, pk)
def test_error_on_user_pk_conflict(self):
with self.assertRaises(IntegrityError):
User.objects.create(tenant=self.tenant, id=self.user.id)
def test_error_on_comment_pk_conflict(self):
with self.assertRaises(IntegrityError):
Comment.objects.create(tenant=self.tenant, id=self.comment.id)
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test")
def test_get_constraints_postgresql(self):
user_constraints = self.get_constraints(User._meta.db_table)
user_pk = user_constraints["composite_pk_user_pkey"]
self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
self.assertIs(user_pk["primary_key"], True)
comment_constraints = self.get_constraints(Comment._meta.db_table)
comment_pk = comment_constraints["composite_pk_comment_pkey"]
self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
self.assertIs(comment_pk["primary_key"], True)
@unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test")
def test_get_constraints_sqlite(self):
user_constraints = self.get_constraints(User._meta.db_table)
user_pk = user_constraints["__primary__"]
self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
self.assertIs(user_pk["primary_key"], True)
comment_constraints = self.get_constraints(Comment._meta.db_table)
comment_pk = comment_constraints["__primary__"]
self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
self.assertIs(comment_pk["primary_key"], True)
@unittest.skipUnless(connection.vendor == "mysql", "MySQL specific test")
def test_get_constraints_mysql(self):
user_constraints = self.get_constraints(User._meta.db_table)
user_pk = user_constraints["PRIMARY"]
self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
self.assertIs(user_pk["primary_key"], True)
comment_constraints = self.get_constraints(Comment._meta.db_table)
comment_pk = comment_constraints["PRIMARY"]
self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
self.assertIs(comment_pk["primary_key"], True)
@unittest.skipUnless(connection.vendor == "oracle", "Oracle specific test")
def test_get_constraints_oracle(self):
user_constraints = self.get_constraints(User._meta.db_table)
user_pk = next(c for c in user_constraints.values() if c["primary_key"])
self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
self.assertEqual(user_pk["primary_key"], 1)
comment_constraints = self.get_constraints(Comment._meta.db_table)
comment_pk = next(c for c in comment_constraints.values() if c["primary_key"])
self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
self.assertEqual(comment_pk["primary_key"], 1)
def test_in_bulk(self):
"""
Test the .in_bulk() method of composite_pk models.
"""
result = Comment.objects.in_bulk()
self.assertEqual(result, {self.comment.pk: self.comment})
result = Comment.objects.in_bulk([self.comment.pk])
self.assertEqual(result, {self.comment.pk: self.comment})
def test_iterator(self):
"""
Test the .iterator() method of composite_pk models.
"""
result = list(Comment.objects.iterator())
self.assertEqual(result, [self.comment])
def test_query(self):
users = User.objects.values_list("pk").order_by("pk")
self.assertNotIn('AS "pk"', str(users.query))
def test_only(self):
users = User.objects.only("pk")
self.assertSequenceEqual(users, (self.user,))
user = users[0]
with self.assertNumQueries(0):
self.assertEqual(user.pk, (self.user.tenant_id, self.user.id))
self.assertEqual(user.tenant_id, self.user.tenant_id)
self.assertEqual(user.id, self.user.id)
with self.assertNumQueries(1):
self.assertEqual(user.email, self.user.email)
def test_model_forms(self):
fields = ["tenant", "id", "user_id", "text"]
self.assertEqual(list(CommentForm.base_fields), fields)
form = modelform_factory(Comment, fields="__all__")
self.assertEqual(list(form().fields), fields)
with self.assertRaisesMessage(
FieldError, "Unknown field(s) (pk) specified for Comment"
):
self.assertIsNone(modelform_factory(Comment, fields=["pk"]))
class CompositePKFixturesTests(TestCase):
fixtures = ["tenant"]
def test_objects(self):
tenant_1, tenant_2, tenant_3 = Tenant.objects.order_by("pk")
self.assertEqual(tenant_1.id, 1)
self.assertEqual(tenant_1.name, "Tenant 1")
self.assertEqual(tenant_2.id, 2)
self.assertEqual(tenant_2.name, "Tenant 2")
self.assertEqual(tenant_3.id, 3)
self.assertEqual(tenant_3.name, "Tenant 3")
user_1, user_2, user_3, user_4 = User.objects.order_by("pk")
self.assertEqual(user_1.id, 1)
self.assertEqual(user_1.tenant_id, 1)
self.assertEqual(user_1.pk, (user_1.tenant_id, user_1.id))
self.assertEqual(user_1.email, "user0001@example.com")
self.assertEqual(user_2.id, 2)
self.assertEqual(user_2.tenant_id, 1)
self.assertEqual(user_2.pk, (user_2.tenant_id, user_2.id))
self.assertEqual(user_2.email, "user0002@example.com")
self.assertEqual(user_3.id, 3)
self.assertEqual(user_3.tenant_id, 2)
self.assertEqual(user_3.pk, (user_3.tenant_id, user_3.id))
self.assertEqual(user_3.email, "user0003@example.com")
self.assertEqual(user_4.id, 4)
self.assertEqual(user_4.tenant_id, 2)
self.assertEqual(user_4.pk, (user_4.tenant_id, user_4.id))
self.assertEqual(user_4.email, "user0004@example.com")
post_1, post_2 = Post.objects.order_by("pk")
self.assertEqual(post_1.id, UUID("11111111-1111-1111-1111-111111111111"))
self.assertEqual(post_1.tenant_id, 2)
self.assertEqual(post_1.pk, (post_1.tenant_id, post_1.id))
self.assertEqual(post_2.id, UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"))
self.assertEqual(post_2.tenant_id, 2)
self.assertEqual(post_2.pk, (post_2.tenant_id, post_2.id))
def test_serialize_user_json(self):
users = User.objects.filter(pk=(1, 1))
result = serializers.serialize("json", users)
self.assertEqual(
json.loads(result),
[
{
"model": "composite_pk.user",
"pk": [1, 1],
"fields": {
"email": "user0001@example.com",
"id": 1,
"tenant": 1,
},
}
],
)
def test_serialize_user_jsonl(self):
users = User.objects.filter(pk=(1, 2))
result = serializers.serialize("jsonl", users)
self.assertEqual(
json.loads(result),
{
"model": "composite_pk.user",
"pk": [1, 2],
"fields": {
"email": "user0002@example.com",
"id": 2,
"tenant": 1,
},
},
)
def test_serialize_user_yaml(self):
users = User.objects.filter(pk=(2, 3))
result = serializers.serialize("yaml", users)
self.assertEqual(
yaml.safe_load(result),
[
{
"model": "composite_pk.user",
"pk": [2, 3],
"fields": {
"email": "user0003@example.com",
"id": 3,
"tenant": 2,
},
},
],
)
def test_serialize_user_python(self):
users = User.objects.filter(pk=(2, 4))
result = serializers.serialize("python", users)
self.assertEqual(
result,
[
{
"model": "composite_pk.user",
"pk": [2, 4],
"fields": {
"email": "user0004@example.com",
"id": 4,
"tenant": 2,
},
},
],
)
def test_serialize_post_uuid(self):
posts = Post.objects.filter(pk=(2, "11111111-1111-1111-1111-111111111111"))
result = serializers.serialize("json", posts)
self.assertEqual(
json.loads(result),
[
{
"model": "composite_pk.post",
"pk": [2, "11111111-1111-1111-1111-111111111111"],
"fields": {
"id": "11111111-1111-1111-1111-111111111111",
"tenant": 2,
},
},
],
)