1
0
mirror of https://github.com/django/django.git synced 2025-06-08 21:19:13 +00:00
django/tests/composite_pk/test_generic.py

315 lines
13 KiB
Python

from datetime import date, datetime
from unittest import skipIf
from uuid import UUID
from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.prefetch import GenericPrefetch
from django.db import connection
from django.db.models import Count
from django.test import TestCase
from django.test.utils import CaptureQueriesContext
from .models import Comment, Dummy, Post, Tag, Tenant, User
class CompositePKGenericTests(TestCase):
    """
    Tests for generic relations (``Tag.content_object``) that point at models
    whose primary key is a composite primary key.

    The expected ``object_id`` values used throughout are JSON arrays holding
    the primary key components in order, e.g. ``'[1, 2]'``.
    """

    POST_1_ID = "e1516ac0-4469-4306-b0ac-2c435e677aa4"
    DUMMY_1_UUID = "81598d5a-fa03-4800-bae4-823ca12824d5"
    DUMMY_2_UUID = "ac567f4c-9b6e-4770-86a6-1abf071a2958"

    @classmethod
    def setUpTestData(cls):
        """Create the shared fixtures and the expected ``object_id`` strings."""
        # dummy_1 uses the largest signed 16/32/64-bit integers and a
        # non-ASCII char value.
        cls.dummy_1 = Dummy.objects.create(
            small_integer=32767,
            integer=2147483647,
            big_integer=9223372036854775807,
            datetime=datetime(2024, 11, 30, 6, 26, 1),
            date=date(2024, 11, 30),
            uuid=UUID(cls.DUMMY_1_UUID),
            char="薑戈",
        )
        # dummy_2 uses the smallest signed 16/32/64-bit integers and a
        # character that is expected to serialize as a surrogate pair (see
        # dummy_2_fk below).
        cls.dummy_2 = Dummy.objects.create(
            small_integer=-32768,
            integer=-2147483648,
            big_integer=-9223372036854775808,
            datetime=datetime(2024, 12, 8, 1, 2, 3),
            date=date(2024, 12, 8),
            uuid=UUID(cls.DUMMY_2_UUID),
            char="😊",
        )
        cls.tenant_1 = Tenant.objects.create()
        cls.tenant_2 = Tenant.objects.create()
        cls.user_1 = User.objects.create(
            tenant=cls.tenant_1,
            id=1,
            email="user0001@example.com",
        )
        cls.user_2 = User.objects.create(
            tenant=cls.tenant_1,
            id=2,
            email="user0002@example.com",
        )
        cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
        cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
        cls.post_1 = Post.objects.create(tenant=cls.tenant_1, id=UUID(cls.POST_1_ID))

        # One tag per tagged object. The creation order matters for
        # test_generic_prefetch(), which orders tags by pk.
        cls.comment_1_tag = Tag.objects.create(
            name="comment_1", content_object=cls.comment_1
        )
        cls.post_1_tag = Tag.objects.create(name="post_1", content_object=cls.post_1)
        cls.dummy_1_tag = Tag.objects.create(name="dummy_1", content_object=cls.dummy_1)
        cls.dummy_2_tag = Tag.objects.create(name="dummy_2", content_object=cls.dummy_2)

        cls.comment_ct = ContentType.objects.get_for_model(Comment)
        cls.post_ct = ContentType.objects.get_for_model(Post)
        cls.dummy_ct = ContentType.objects.get_for_model(Dummy)

        # Expected serialized object_id values for each tagged object.
        cls.post_1_fk = f'[{cls.tenant_1.id}, "{cls.POST_1_ID}"]'
        cls.comment_1_fk = f"[{cls.tenant_1.id}, {cls.comment_1.id}]"
        # Non-ASCII characters are expected as \uXXXX escapes.
        cls.dummy_1_fk = (
            '[32767, 2147483647, 9223372036854775807, "2024-11-30T06:26:01", '
            f'"2024-11-30", "{cls.DUMMY_1_UUID}", "\\u8591\\u6208"]'
        )
        cls.dummy_2_fk = (
            '[-32768, -2147483648, -9223372036854775808, "2024-12-08T01:02:03", '
            f'"2024-12-08", "{cls.DUMMY_2_UUID}", "\\ud83d\\ude0a"]'
        )

    def test_fields(self):
        """content_type, object_id, and content_object are set on each tag."""
        self.assertEqual(self.comment_1_tag.content_type, self.comment_ct)
        self.assertEqual(self.comment_1_tag.object_id, self.comment_1_fk)
        self.assertEqual(self.comment_1_tag.content_object, self.comment_1)

        self.assertEqual(self.post_1_tag.content_type, self.post_ct)
        self.assertEqual(self.post_1_tag.object_id, self.post_1_fk)
        self.assertEqual(self.post_1_tag.content_object, self.post_1)

        self.assertEqual(self.dummy_1_tag.content_type, self.dummy_ct)
        self.assertEqual(self.dummy_1_tag.object_id, self.dummy_1_fk)
        self.assertEqual(self.dummy_1_tag.content_object, self.dummy_1)

        self.assertEqual(self.dummy_2_tag.content_type, self.dummy_ct)
        self.assertEqual(self.dummy_2_tag.object_id, self.dummy_2_fk)
        self.assertEqual(self.dummy_2_tag.content_object, self.dummy_2)

        # Reverse access through the "tags" manager.
        self.assertSequenceEqual(self.post_1.tags.all(), (self.post_1_tag,))
        self.assertSequenceEqual(self.dummy_1.tags.all(), (self.dummy_1_tag,))
        self.assertSequenceEqual(self.dummy_2.tags.all(), (self.dummy_2_tag,))

    def test_fields_before_save(self):
        """object_id is computed on assignment, even for an unsaved target."""
        comment = Comment(pk=(1, 2))
        tag = Tag(content_object=comment)
        self.assertEqual(tag.object_id, "[1, 2]")
        # Changing the target afterwards doesn't alter the stored object_id.
        comment.id = 3
        self.assertEqual(tag.object_id, "[1, 2]")

    def test_cascade_delete_if_generic_relation(self):
        """Deleting a Post or a Dummy also deletes its tag."""
        Post.objects.get(pk=self.post_1.pk).delete()
        self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())

        Dummy.objects.get(pk=self.dummy_1.pk).delete()
        self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists())

    def test_no_cascade_delete_if_no_generic_relation(self):
        """
        Deleting a Comment keeps its tag; the tag's content_object then
        resolves to None.
        """
        Comment.objects.get(pk=self.comment_1.pk).delete()
        comment_1_tag = Tag.objects.get(pk=self.comment_1_tag.pk)
        self.assertIsNone(comment_1_tag.content_object)

    def test_tags_clear(self):
        """tags.clear() leaves no related tags and deletes the tag rows."""
        post_1 = Post.objects.get(pk=self.post_1.pk)
        post_1.tags.clear()
        self.assertEqual(post_1.tags.count(), 0)
        self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())

        dummy_1 = Dummy.objects.get(pk=self.dummy_1.pk)
        dummy_1.tags.clear()
        self.assertEqual(dummy_1.tags.count(), 0)
        self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists())

    def test_tags_remove(self):
        """tags.remove() detaches the given tag and deletes its row."""
        post_1 = Post.objects.get(pk=self.post_1.pk)
        post_1.tags.remove(self.post_1_tag)
        self.assertEqual(post_1.tags.count(), 0)
        self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())

        dummy_1 = Dummy.objects.get(pk=self.dummy_1.pk)
        dummy_1.tags.remove(self.dummy_1_tag)
        self.assertEqual(dummy_1.tags.count(), 0)
        self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists())

    def test_tags_create(self):
        """tags.create() saves a new tag pointing at the owning Post."""
        tag_count = Tag.objects.count()

        post_1 = Post.objects.get(pk=self.post_1.pk)
        post_1.tags.create(name="foo")
        self.assertEqual(post_1.tags.count(), 2)
        self.assertEqual(Tag.objects.count(), tag_count + 1)

        tag = Tag.objects.get(name="foo")
        self.assertEqual(tag.content_type, self.post_ct)
        self.assertEqual(tag.object_id, self.post_1_fk)
        self.assertEqual(tag.content_object, post_1)

    def test_tags_add(self):
        """tags.add() attaches both an unsaved tag and an existing tag."""
        tag_count = Tag.objects.count()
        post_1 = Post.objects.get(pk=self.post_1.pk)

        # An unsaved tag, added with bulk=False.
        tag_1 = Tag(name="foo")
        post_1.tags.add(tag_1, bulk=False)
        self.assertEqual(post_1.tags.count(), 2)
        self.assertEqual(Tag.objects.count(), tag_count + 1)

        tag_1 = Tag.objects.get(name="foo")
        self.assertEqual(tag_1.content_type, self.post_ct)
        self.assertEqual(tag_1.object_id, self.post_1_fk)
        self.assertEqual(tag_1.content_object, post_1)

        # A saved tag that initially points at a Comment is re-pointed at
        # the Post.
        tag_2 = Tag.objects.create(name="bar", content_object=self.comment_2)
        post_1.tags.add(tag_2)
        self.assertEqual(post_1.tags.count(), 3)
        self.assertEqual(Tag.objects.count(), tag_count + 2)

        tag_2 = Tag.objects.get(name="bar")
        self.assertEqual(tag_2.content_type, self.post_ct)
        self.assertEqual(tag_2.object_id, self.post_1_fk)
        self.assertEqual(tag_2.content_object, post_1)

    def test_tags_set(self):
        """tags.set() replaces the Post's tags, deleting the previous one."""
        tag_count = Tag.objects.count()
        comment_1_tag = Tag.objects.get(name=self.comment_1_tag.name)
        post_1 = Post.objects.get(pk=self.post_1.pk)
        post_1.tags.set([comment_1_tag])
        self.assertEqual(post_1.tags.count(), 1)
        # The Post's original tag is gone, so the total drops by one.
        self.assertEqual(Tag.objects.count(), tag_count - 1)
        self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())

    def test_tags_get_or_create(self):
        """tags.get_or_create() finds an existing tag or creates a new one."""
        post_1 = Post.objects.get(pk=self.post_1.pk)

        # Existing tag: fetched, not created.
        tag_1, created = post_1.tags.get_or_create(name=self.post_1_tag.name)
        self.assertFalse(created)
        self.assertEqual(tag_1.pk, self.post_1_tag.pk)
        self.assertEqual(tag_1.content_type, self.post_ct)
        self.assertEqual(tag_1.object_id, self.post_1_fk)
        self.assertEqual(tag_1.content_object, post_1)

        # Unknown name: created and bound to the Post.
        tag_2, created = post_1.tags.get_or_create(name="foo")
        self.assertTrue(created)
        self.assertEqual(tag_2.content_type, self.post_ct)
        self.assertEqual(tag_2.object_id, self.post_1_fk)
        self.assertEqual(tag_2.content_object, post_1)

    def test_tags_update_or_create(self):
        """tags.update_or_create() updates an existing tag or creates one."""
        post_1 = Post.objects.get(pk=self.post_1.pk)

        # Existing tag: updated in place with the defaults.
        tag_1, created = post_1.tags.update_or_create(
            name=self.post_1_tag.name, defaults={"name": "foo"}
        )
        self.assertFalse(created)
        self.assertEqual(tag_1.pk, self.post_1_tag.pk)
        self.assertEqual(tag_1.name, "foo")
        self.assertEqual(tag_1.content_type, self.post_ct)
        self.assertEqual(tag_1.object_id, self.post_1_fk)
        self.assertEqual(tag_1.content_object, post_1)

        # Unknown name: created and bound to the Post.
        tag_2, created = post_1.tags.update_or_create(name="bar")
        self.assertTrue(created)
        self.assertEqual(tag_2.content_type, self.post_ct)
        self.assertEqual(tag_2.object_id, self.post_1_fk)
        self.assertEqual(tag_2.content_object, post_1)

    def test_filter_by_related_query_name(self):
        """Tags can be filtered through the "post" and "dummy" lookups."""
        self.assertSequenceEqual(
            Tag.objects.filter(post__id=self.post_1.id), (self.post_1_tag,)
        )
        self.assertSequenceEqual(
            Tag.objects.filter(dummy__big_integer=self.dummy_1.big_integer),
            (self.dummy_1_tag,),
        )

    @skipIf(
        connection.vendor == "mysql" and connection.mysql_is_mariadb,
        "MariaDB's JSON_UNQUOTE doesn't support surrogate pairs "
        "(https://jira.mariadb.org/browse/MDEV-21124)",
    )
    def test_aggregate(self):
        """Count("tags") aggregates across the generic relation."""
        # The last captured SQL statement is passed as the assertion message
        # to help diagnose failures.
        with self.subTest("Post"):
            with CaptureQueriesContext(connection) as ctx:
                self.assertEqual(
                    Post.objects.aggregate(Count("tags")),
                    {"tags__count": 1},
                    ctx[-1]["sql"],
                )
        with self.subTest("Dummy"):
            with CaptureQueriesContext(connection) as ctx:
                self.assertEqual(
                    Dummy.objects.aggregate(Count("tags")),
                    {"tags__count": 2},
                    ctx[-1]["sql"],
                )

    def test_generic_prefetch(self):
        """GenericPrefetch populates content_object for every tagged model."""
        tags = Tag.objects.prefetch_related(
            GenericPrefetch(
                "content_object",
                [
                    Post.objects.all(),
                    Comment.objects.all(),
                    Dummy.objects.all(),
                ],
            )
        ).order_by("pk")

        # len() evaluates the queryset (and the prefetch) once; the indexed
        # accesses below are expected to hit the cached results.
        self.assertEqual(len(tags), 4)
        self.assertEqual(tags[0], self.comment_1_tag)
        self.assertEqual(tags[1], self.post_1_tag)
        self.assertEqual(tags[2], self.dummy_1_tag)
        # Verify the identity of the last row too, so that the
        # content_object check on tags[3] below is tied to a known tag.
        self.assertEqual(tags[3], self.dummy_2_tag)

        with self.assertNumQueries(0):
            self.assertEqual(tags[0].content_object, self.comment_1)
        with self.assertNumQueries(0):
            self.assertEqual(tags[1].content_object, self.post_1)
        with self.assertNumQueries(0):
            self.assertEqual(tags[2].content_object, self.dummy_1)
        with self.assertNumQueries(0):
            self.assertEqual(tags[3].content_object, self.dummy_2)

    def test_to_json(self):
        """to_json() serializes a pk tuple; UUIDs and UUID strings both work."""
        field = Post._meta.pk
        self.assertEqual(
            field.to_json((1, "004bc28c-a085-44ca-a823-747dcf4ddfd3")),
            '[1, "004bc28c-a085-44ca-a823-747dcf4ddfd3"]',
        )
        self.assertEqual(
            field.to_json((2, UUID("a5a626e5-d34a-4197-94ea-83547029d6ed"))),
            '[2, "a5a626e5-d34a-4197-94ea-83547029d6ed"]',
        )

    def test_to_json_length_mismatch(self):
        """to_json() raises ValueError for a tuple of the wrong length."""
        field = Post._meta.pk
        msg = "CompositePrimaryKey has 2 fields but it tried to serialize %s."
        with self.assertRaisesMessage(ValueError, msg % 1):
            self.assertIsNone(field.to_json((1,)))
        with self.assertRaisesMessage(ValueError, msg % 3):
            self.assertIsNone(field.to_json((1, 2, 3)))

    def test_from_json(self):
        """from_json() returns a pk tuple, converting UUID strings to UUIDs."""
        field = Post._meta.pk
        self.assertEqual(
            field.from_json('[1, "cb6b99bc-66b0-497f-bafe-193b90af6296"]'),
            (1, UUID("cb6b99bc-66b0-497f-bafe-193b90af6296")),
        )
        # The hyphen-less UUID form is accepted as well.
        self.assertEqual(
            field.from_json('[2, "1d5c63dda5264a12a2ec929d51e04430"]'),
            (2, UUID("1d5c63dd-a526-4a12-a2ec-929d51e04430")),
        )

    def test_from_json_length_mismatch(self):
        """from_json() raises ValueError for an array of the wrong length."""
        field = Post._meta.pk
        msg = (
            "CompositePrimaryKey has 2 fields but it tried to deserialize %s. "
            "Did you change the CompositePrimaryKey fields and forgot to "
            'update the related GenericForeignKey "object_id" fields?'
        )
        with self.assertRaisesMessage(ValueError, msg % 1):
            self.assertIsNone(field.from_json("[1]"))
        with self.assertRaisesMessage(ValueError, msg % 3):
            self.assertIsNone(field.from_json("[1, 2, 3]"))