1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Fixed #35941 -- Added composite GenericForeignKey support.

This commit is contained in:
Bendeguz Csirmaz 2024-04-07 10:32:16 +08:00
parent 1860a1afc9
commit b641b6a90e
16 changed files with 657 additions and 23 deletions

View File

@ -25,6 +25,20 @@ from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.functional import cached_property
def serialize_pk(obj):
opts = obj._meta
if opts.is_composite_pk:
return opts.pk.to_json(obj.pk)
return obj.pk
def deserialize_pk(pk, ct):
opts = ct.model_class()._meta
if opts.is_composite_pk:
return opts.pk.from_json(pk)
return pk
class GenericForeignKey(FieldCacheMixin, Field):
"""
Provide a generic many-to-one relation through the ``content_type`` and
@ -195,7 +209,9 @@ class GenericForeignKey(FieldCacheMixin, Field):
if ct_id is not None:
fk_val = getattr(instance, self.fk_field)
if fk_val is not None:
fk_dict[ct_id].add(fk_val)
ct = self.get_content_type(id=ct_id)
pk_val = deserialize_pk(fk_val, ct)
fk_dict[ct_id].add(pk_val)
instance_dict[ct_id] = instance
ret_val = []
@ -225,7 +241,7 @@ class GenericForeignKey(FieldCacheMixin, Field):
return (
ret_val,
lambda obj: (obj.pk, obj.__class__),
lambda obj: (serialize_pk(obj), obj.__class__),
gfk_key,
True,
self.name,
@ -242,15 +258,15 @@ class GenericForeignKey(FieldCacheMixin, Field):
# use ContentType.objects.get_for_id(), which has a global cache.
f = self.model._meta.get_field(self.ct_field)
ct_id = getattr(instance, f.attname, None)
pk_val = getattr(instance, self.fk_field)
fk_val = getattr(instance, self.fk_field)
rel_obj = self.get_cached_value(instance, default=None)
if rel_obj is None and self.is_cached(instance):
return rel_obj
if rel_obj is not None:
ct_match = (
ct_id == self.get_content_type(obj=rel_obj, using=instance._state.db).id
)
ct = self.get_content_type(obj=rel_obj, using=instance._state.db)
ct_match = ct_id == ct.id
pk_val = deserialize_pk(fk_val, ct)
pk_match = ct_match and rel_obj._meta.pk.to_python(pk_val) == rel_obj.pk
if pk_match:
return rel_obj
@ -258,6 +274,7 @@ class GenericForeignKey(FieldCacheMixin, Field):
rel_obj = None
if ct_id is not None:
ct = self.get_content_type(id=ct_id, using=instance._state.db)
pk_val = deserialize_pk(fk_val, ct)
try:
rel_obj = ct.get_object_for_this_type(
using=instance._state.db, pk=pk_val
@ -272,7 +289,7 @@ class GenericForeignKey(FieldCacheMixin, Field):
fk = None
if value is not None:
ct = self.get_content_type(obj=value)
fk = value.pk
fk = serialize_pk(value)
setattr(instance, self.ct_field, ct)
setattr(instance, self.fk_field, fk)
@ -541,7 +558,8 @@ class GenericRelation(ForeignObject):
% self.content_type_field_name: ContentType.objects.db_manager(using)
.get_for_model(self.model, for_concrete_model=self.for_concrete_model)
.pk,
"%s__in" % self.object_id_field_name: [obj.pk for obj in objs],
"%s__in"
% self.object_id_field_name: [serialize_pk(obj) for obj in objs],
}
)
@ -589,7 +607,7 @@ def create_generic_related_manager(superclass, rel):
self.content_type_field_name = rel.field.content_type_field_name
self.object_id_field_name = rel.field.object_id_field_name
self.prefetch_cache_name = rel.field.attname
self.pk_val = instance.pk
self.pk_val = serialize_pk(instance)
self.core_filters = {
"%s__pk" % self.content_type_field_name: self.content_type.id,

View File

@ -804,3 +804,9 @@ class BaseDatabaseOperations:
rhs_expr = Col(rhs_table, rhs_field)
return lhs_expr, rhs_expr
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a "
"prepare_join_composite_pk_on_json_array() method"
)

View File

@ -3,8 +3,10 @@ import uuid
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
from django.db.models import Exists, ExpressionWrapper, Lookup
from django.db.models import Exists, ExpressionWrapper, Func, Lookup, UUIDField, Value
from django.db.models.constants import OnConflict
from django.db.models.fields.json import KeyTextTransform
from django.db.models.functions import Cast
from django.utils import timezone
from django.utils.encoding import force_str
from django.utils.regex_helper import _lazy_re_compile
@ -453,3 +455,21 @@ class DatabaseOperations(BaseDatabaseOperations):
update_fields,
unique_fields,
)
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
# e.g. `a`.`id` = CAST((`b`.`object_id` ->> '$[0]') AS signed integer)
json_array = rhs
json_element = KeyTextTransform(index, json_array)
if isinstance(lhs.field, UUIDField):
json_element = Func(
json_element,
Value("-"),
Value(""),
function="REPLACE",
output_field=UUIDField(),
)
if json_element.field != lhs.field:
json_element = Cast(json_element, lhs.field)
return lhs, json_element

View File

@ -6,8 +6,18 @@ from django.conf import settings
from django.db import DatabaseError, NotSupportedError
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
from django.db.models.expressions import RawSQL
from django.db.models import (
AutoField,
DateTimeField,
Exists,
ExpressionWrapper,
JSONField,
Lookup,
UUIDField,
)
from django.db.models.expressions import Func, RawSQL, Value
from django.db.models.fields.json import KeyTextTransform
from django.db.models.functions import Cast
from django.db.models.sql.where import WhereNode
from django.utils import timezone
from django.utils.encoding import force_bytes, force_str
@ -726,3 +736,27 @@ END;
if isinstance(expression, RawSQL) and expression.conditional:
return True
return False
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
json_array = Cast(rhs, JSONField())
json_element = KeyTextTransform(index, json_array)
if isinstance(lhs.field, UUIDField):
json_element = Func(
json_element,
Value("-"),
Value(""),
function="REPLACE",
output_field=UUIDField(),
)
if isinstance(lhs.field, DateTimeField):
json_element = Func(
json_element,
Value('YYYY-MM-DD"T"HH24:MI:SS'),
function="TO_TIMESTAMP",
output_field=DateTimeField(),
)
if json_element.field != lhs.field:
json_element = Cast(json_element, lhs.field)
return lhs, json_element

View File

@ -12,6 +12,7 @@ from django.db.backends.postgresql.psycopg_any import (
)
from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict
from django.db.models.fields.json import JSONField, KeyTextTransform
from django.db.models.functions import Cast
from django.utils.regex_helper import _lazy_re_compile
@ -413,3 +414,13 @@ class DatabaseOperations(BaseDatabaseOperations):
rhs_expr = Cast(rhs_expr, lhs_field)
return lhs_expr, rhs_expr
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
# e.g. `a`.`id` = ((`b`.`object_id`)::jsonb ->> 0)::integer
json_array = Cast(rhs, JSONField())
json_element = KeyTextTransform(index, json_array)
if json_element.field != lhs.field:
json_element = Cast(json_element, lhs.field)
return lhs, json_element

View File

@ -8,12 +8,15 @@ from django.conf import settings
from django.core.exceptions import FieldError
from django.db import DatabaseError, NotSupportedError, models
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models import DateTimeField, JSONField, UUIDField
from django.db.models.constants import OnConflict
from django.db.models.expressions import Col
from django.db.models.expressions import Col, Func, Value
from django.db.models.functions import Cast
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.functional import cached_property
from ...models.fields.json import KeyTextTransform
from .base import Database
@ -431,3 +434,24 @@ class DatabaseOperations(BaseDatabaseOperations):
def force_group_by(self):
return ["GROUP BY TRUE"] if Database.sqlite_version_info < (3, 39) else []
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
json_array = Cast(rhs, JSONField())
json_element = KeyTextTransform(index, json_array)
if isinstance(lhs.field, UUIDField):
json_element = Func(
json_element,
Value("-"),
Value(""),
function="REPLACE",
output_field=UUIDField(),
)
if json_element.field != lhs.field:
json_element = Cast(json_element, lhs.field)
if isinstance(lhs.field, DateTimeField):
# If joining DateTimeField in SQLite,
# make sure strftime is applied to the left side as well.
lhs = Cast(lhs, DateTimeField())
return lhs, json_element

View File

@ -1,3 +1,5 @@
import json
from django.core import checks
from django.db.models import NOT_PROVIDED, Field
from django.db.models.expressions import ColPairs
@ -128,6 +130,40 @@ class CompositePrimaryKey(Field):
)
]
def to_json(self, pk):
from django.core.serializers.json import DjangoJSONEncoder
fields_len = len(self.fields)
pk_len = len(pk)
if fields_len != pk_len:
raise ValueError(
f"{self.__class__.__name__} has {fields_len} fields "
f"but it tried to serialize {pk_len}."
)
return json.dumps(
[value for value in pk],
cls=DjangoJSONEncoder,
)
def from_json(self, s):
values = json.loads(s)
fields_len = len(self.fields)
values_len = len(values)
if fields_len != values_len:
raise ValueError(
f"{self.__class__.__name__} has {fields_len} fields "
f"but it tried to deserialize {values_len}. "
f"Did you change the {self.__class__.__name__} fields "
"and forgot to update the related GenericForeignKey "
'"object_id" fields?'
)
return tuple(
field.to_python(value) for (field, value) in zip(self.fields, values)
)
CompositePrimaryKey.register_lookup(TupleExact)
CompositePrimaryKey.register_lookup(TupleGreaterThan)

View File

@ -1,4 +1,5 @@
import itertools
from collections.abc import Iterable
from django.core.exceptions import EmptyResultSet
from django.db.models import Field
@ -211,9 +212,16 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
class TupleIn(TupleLookupMixin, In):
def check_rhs_is_iterable(self):
if not isinstance(self.rhs, Iterable):
lhs_str = self.get_lhs_str()
raise ValueError(
f"{self.lookup_name!r} lookup of {lhs_str} must be an iterable"
)
def get_prep_lookup(self):
if self.rhs_is_direct_value():
self.check_rhs_is_tuple_or_list()
self.check_rhs_is_iterable()
self.check_rhs_is_collection_of_tuples_or_lists()
self.check_rhs_elements_length_equals_lhs_length()
else:

View File

@ -6,6 +6,7 @@ the SQL domain.
import warnings
from django.core.exceptions import FullResultSet
from django.db import models
from django.db.models.sql.constants import INNER, LOUTER
from django.utils.deprecation import RemovedInDjango60Warning
@ -105,6 +106,39 @@ class Join:
# the branch for strings.
lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs))
rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs))
join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
elif isinstance(lhs, (models.CharField, models.TextField)) and isinstance(
rhs, models.CompositePrimaryKey
):
lhs_col = lhs.get_col(self.parent_alias)
for index, field in enumerate(rhs):
rhs_col = field.get_col(self.table_alias)
lhs_expr, rhs_expr = (
connection.ops.prepare_join_composite_pk_on_json_array(
rhs_col, lhs_col, index
)
)
lhs_sql, lhs_params = compiler.compile(lhs_expr)
rhs_sql, rhs_params = compiler.compile(rhs_expr)
join_conditions.append(f"{lhs_sql} = {rhs_sql}")
params.extend(lhs_params)
params.extend(rhs_params)
elif isinstance(lhs, models.CompositePrimaryKey) and isinstance(
rhs, (models.CharField, models.TextField)
):
rhs_col = rhs.get_col(self.table_alias)
for index, field in enumerate(lhs):
lhs_col = field.get_col(self.parent_alias)
lhs_expr, rhs_expr = (
connection.ops.prepare_join_composite_pk_on_json_array(
lhs_col, rhs_col, index
)
)
lhs_sql, lhs_params = compiler.compile(lhs_expr)
rhs_sql, rhs_params = compiler.compile(rhs_expr)
join_conditions.append(f"{lhs_sql} = {rhs_sql}")
params.extend(lhs_params)
params.extend(rhs_params)
else:
lhs, rhs = connection.ops.prepare_join_on_clause(
self.parent_alias, lhs, self.table_alias, rhs
@ -113,7 +147,7 @@ class Join:
lhs_full_name = lhs_sql % lhs_params
rhs_sql, rhs_params = compiler.compile(rhs)
rhs_full_name = rhs_sql % rhs_params
join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
# Add a single condition inside parentheses for whatever
# get_extra_restriction() returns.

View File

@ -62,8 +62,7 @@ A composite primary key can also be filtered by a ``tuple``:
1
We're still working on composite primary key support for
:ref:`relational fields <cpk-and-relations>`, including
:class:`.GenericForeignKey` fields, and the Django admin. Models with composite
:ref:`relational fields <cpk-and-relations>`, and the Django admin. Models with composite
primary keys cannot be registered in the Django admin at this time. You can
expect to see this in future releases.
@ -96,9 +95,8 @@ operation.
Composite primary keys and relations
====================================
:ref:`Relationship fields <relationship-fields>`, including
:ref:`generic relations <generic-relations>` do not support composite primary
keys.
:ref:`Relationship fields <relationship-fields>` do not support composite
primary keys.
For example, given the ``OrderLineItem`` model, the following is not
supported::
@ -131,6 +129,80 @@ database.
``ForeignObject`` is an internal API. This means it is not covered by our
:ref:`deprecation policy <internal-release-deprecation-policy>`.
Generic foreign keys & relations
================================
If a :class:`~django.contrib.contenttypes.fields.GenericForeignKey`\'s
"object_id" field is a :class:`~django.db.models.CharField` or a
:class:`~django.db.models.TextField`, Django will automatically serialize the
related ``CompositePrimaryKey`` into a JSON array::
class Tag:
content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
object_id = models.CharField(max_length=255)
content_object = GenericForeignKey("content_type", "object_id")
This means ``CompositePrimaryKey`` is backwards-compatible with any
``GenericForeignKey`` that's using a ``CharField`` or ``TextField`` "object_id".
.. code-block:: pycon
>>> item.pk
(2, "B142C")
>>> tag = Tag(content_object=item)
>>> tag.object_id
'[2, "B142C"]'
.. warning:: pk changes are not supported
Django doesn't automatically update "object_id" if "content_object.pk"
changes. If you change an object's ``pk`` in any way, make sure the related
generic foreign keys are also updated.
:class:`~django.contrib.contenttypes.fields.GenericRelation`\s on models with
a ``CompositePrimaryKey`` are also supported. When filtering on a reverse
relationship, Django performs ``JOIN``\s using the appropriate backend-specific
JSON functions::
class OrderLineItem(models.Model):
...
tags = GenericRelation("Tag", related_query_name="items")
This means ``CompositePrimaryKey``\s are backwards-compatible with
``GenericRelation``\s the same way ``GenericForeignKey``\s are.
.. code-block:: pycon
>>> item.tags.all() # OK
>>> Tag.objects.filter(items__quantity__gte=2) # OK
.. warning:: JOINs on JSON functions can be slow
While performing ``JOIN``\s on JSON functions is convenient, it can lead to
performance issues. Exercise caution when filtering on the reverse
relationship of a ``GenericRelation``.
.. warning:: MariaDB support
At the time of writing this, MariaDB's JSON_UNQUOTE function cannot process
surrogate pairs (e.g. emojis). If you're planning to use JOINs on JSON
functions on MariaDB, make sure the primary key doesn't contain any emojis.
.. admonition:: Supported fields
``JOIN``s on JSON functions may not work if the ``CompositePrimaryKey``
has unsupported fields.
The supported fields are:
- ``SmallIntegerField``
- ``IntegerField``
- ``BigIntegerField``
- ``CharField``
- ``UUIDField``
- ``DateField``
- ``DateTimeField``
Composite primary keys and database functions
=============================================

View File

@ -1,8 +1,11 @@
from .tenant import Comment, Post, Tenant, Token, User
from .generic import Dummy
from .tenant import Comment, Post, Tag, Tenant, Token, User
__all__ = [
"Comment",
"Dummy",
"Post",
"Tag",
"Tenant",
"Token",
"User",

View File

@ -0,0 +1,24 @@
from django.contrib.contenttypes.fields import GenericRelation
from django.db import models
class Dummy(models.Model):
pk = models.CompositePrimaryKey(
"small_integer",
"integer",
"big_integer",
"datetime",
"date",
"uuid",
"char",
)
small_integer = models.SmallIntegerField()
integer = models.IntegerField()
big_integer = models.BigIntegerField()
datetime = models.DateTimeField()
date = models.DateField()
uuid = models.UUIDField()
char = models.CharField(max_length=5)
tags = GenericRelation("Tag", related_query_name="dummy")

View File

@ -1,3 +1,5 @@
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.db import models
@ -48,3 +50,13 @@ class Post(models.Model):
pk = models.CompositePrimaryKey("tenant_id", "id")
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
id = models.UUIDField()
tags = GenericRelation("Tag", related_query_name="post")
class Tag(models.Model):
name = models.CharField(max_length=10)
content_type = models.ForeignKey(
ContentType, on_delete=models.CASCADE, related_name="composite_pk_tags"
)
object_id = models.CharField(max_length=255)
content_object = GenericForeignKey("content_type", "object_id")

View File

@ -0,0 +1,314 @@
from datetime import date, datetime
from unittest import skipIf
from uuid import UUID
from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.prefetch import GenericPrefetch
from django.db import connection
from django.db.models import Count
from django.test import TestCase
from django.test.utils import CaptureQueriesContext
from .models import Comment, Dummy, Post, Tag, Tenant, User
class CompositePKGenericTests(TestCase):
POST_1_ID = "e1516ac0-4469-4306-b0ac-2c435e677aa4"
DUMMY_1_UUID = "81598d5a-fa03-4800-bae4-823ca12824d5"
DUMMY_2_UUID = "ac567f4c-9b6e-4770-86a6-1abf071a2958"
@classmethod
def setUpTestData(cls):
cls.dummy_1 = Dummy.objects.create(
small_integer=32767,
integer=2147483647,
big_integer=9223372036854775807,
datetime=datetime(2024, 11, 30, 6, 26, 1),
date=date(2024, 11, 30),
uuid=UUID(cls.DUMMY_1_UUID),
char="薑戈",
)
cls.dummy_2 = Dummy.objects.create(
small_integer=-32768,
integer=-2147483648,
big_integer=-9223372036854775808,
datetime=datetime(2024, 12, 8, 1, 2, 3),
date=date(2024, 12, 8),
uuid=UUID(cls.DUMMY_2_UUID),
char="😊",
)
cls.tenant_1 = Tenant.objects.create()
cls.tenant_2 = Tenant.objects.create()
cls.user_1 = User.objects.create(
tenant=cls.tenant_1,
id=1,
email="user0001@example.com",
)
cls.user_2 = User.objects.create(
tenant=cls.tenant_1,
id=2,
email="user0002@example.com",
)
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
cls.post_1 = Post.objects.create(tenant=cls.tenant_1, id=UUID(cls.POST_1_ID))
cls.comment_1_tag = Tag.objects.create(
name="comment_1", content_object=cls.comment_1
)
cls.post_1_tag = Tag.objects.create(name="post_1", content_object=cls.post_1)
cls.dummy_1_tag = Tag.objects.create(name="dummy_1", content_object=cls.dummy_1)
cls.dummy_2_tag = Tag.objects.create(name="dummy_2", content_object=cls.dummy_2)
cls.comment_ct = ContentType.objects.get_for_model(Comment)
cls.post_ct = ContentType.objects.get_for_model(Post)
cls.dummy_ct = ContentType.objects.get_for_model(Dummy)
cls.post_1_fk = f'[{cls.tenant_1.id}, "{cls.POST_1_ID}"]'
cls.comment_1_fk = f"[{cls.tenant_1.id}, {cls.comment_1.id}]"
cls.dummy_1_fk = (
'[32767, 2147483647, 9223372036854775807, "2024-11-30T06:26:01", '
f'"2024-11-30", "{cls.DUMMY_1_UUID}", "\\u8591\\u6208"]'
)
cls.dummy_2_fk = (
'[-32768, -2147483648, -9223372036854775808, "2024-12-08T01:02:03", '
f'"2024-12-08", "{cls.DUMMY_2_UUID}", "\\ud83d\\ude0a"]'
)
def test_fields(self):
self.assertEqual(self.comment_1_tag.content_type, self.comment_ct)
self.assertEqual(self.comment_1_tag.object_id, self.comment_1_fk)
self.assertEqual(self.comment_1_tag.content_object, self.comment_1)
self.assertEqual(self.post_1_tag.content_type, self.post_ct)
self.assertEqual(self.post_1_tag.object_id, self.post_1_fk)
self.assertEqual(self.post_1_tag.content_object, self.post_1)
self.assertEqual(self.dummy_1_tag.content_type, self.dummy_ct)
self.assertEqual(self.dummy_1_tag.object_id, self.dummy_1_fk)
self.assertEqual(self.dummy_1_tag.content_object, self.dummy_1)
self.assertEqual(self.dummy_2_tag.content_type, self.dummy_ct)
self.assertEqual(self.dummy_2_tag.object_id, self.dummy_2_fk)
self.assertEqual(self.dummy_2_tag.content_object, self.dummy_2)
self.assertSequenceEqual(self.post_1.tags.all(), (self.post_1_tag,))
self.assertSequenceEqual(self.dummy_1.tags.all(), (self.dummy_1_tag,))
self.assertSequenceEqual(self.dummy_2.tags.all(), (self.dummy_2_tag,))
def test_fields_before_save(self):
comment = Comment(pk=(1, 2))
tag = Tag(content_object=comment)
self.assertEqual(tag.object_id, "[1, 2]")
comment.id = 3
self.assertEqual(tag.object_id, "[1, 2]")
def test_cascade_delete_if_generic_relation(self):
Post.objects.get(pk=self.post_1.pk).delete()
self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())
Dummy.objects.get(pk=self.dummy_1.pk).delete()
self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists())
def test_no_cascade_delete_if_no_generic_relation(self):
Comment.objects.get(pk=self.comment_1.pk).delete()
comment_1_tag = Tag.objects.get(pk=self.comment_1_tag.pk)
self.assertIsNone(comment_1_tag.content_object)
def test_tags_clear(self):
post_1 = Post.objects.get(pk=self.post_1.pk)
post_1.tags.clear()
self.assertEqual(post_1.tags.count(), 0)
self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())
dummy_1 = Dummy.objects.get(pk=self.dummy_1.pk)
dummy_1.tags.clear()
self.assertEqual(dummy_1.tags.count(), 0)
self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists())
def test_tags_remove(self):
post_1 = Post.objects.get(pk=self.post_1.pk)
post_1.tags.remove(self.post_1_tag)
self.assertEqual(post_1.tags.count(), 0)
self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())
dummy_1 = Dummy.objects.get(pk=self.dummy_1.pk)
dummy_1.tags.remove(self.dummy_1_tag)
self.assertEqual(dummy_1.tags.count(), 0)
self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists())
def test_tags_create(self):
tag_count = Tag.objects.count()
post_1 = Post.objects.get(pk=self.post_1.pk)
post_1.tags.create(name="foo")
self.assertEqual(post_1.tags.count(), 2)
self.assertEqual(Tag.objects.count(), tag_count + 1)
tag = Tag.objects.get(name="foo")
self.assertEqual(tag.content_type, self.post_ct)
self.assertEqual(tag.object_id, self.post_1_fk)
self.assertEqual(tag.content_object, post_1)
def test_tags_add(self):
tag_count = Tag.objects.count()
post_1 = Post.objects.get(pk=self.post_1.pk)
tag_1 = Tag(name="foo")
post_1.tags.add(tag_1, bulk=False)
self.assertEqual(post_1.tags.count(), 2)
self.assertEqual(Tag.objects.count(), tag_count + 1)
tag_1 = Tag.objects.get(name="foo")
self.assertEqual(tag_1.content_type, self.post_ct)
self.assertEqual(tag_1.object_id, self.post_1_fk)
self.assertEqual(tag_1.content_object, post_1)
tag_2 = Tag.objects.create(name="bar", content_object=self.comment_2)
post_1.tags.add(tag_2)
self.assertEqual(post_1.tags.count(), 3)
self.assertEqual(Tag.objects.count(), tag_count + 2)
tag_2 = Tag.objects.get(name="bar")
self.assertEqual(tag_2.content_type, self.post_ct)
self.assertEqual(tag_2.object_id, self.post_1_fk)
self.assertEqual(tag_2.content_object, post_1)
def test_tags_set(self):
tag_count = Tag.objects.count()
comment_1_tag = Tag.objects.get(name=self.comment_1_tag.name)
post_1 = Post.objects.get(pk=self.post_1.pk)
post_1.tags.set([comment_1_tag])
self.assertEqual(post_1.tags.count(), 1)
self.assertEqual(Tag.objects.count(), tag_count - 1)
self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())
def test_tags_get_or_create(self):
post_1 = Post.objects.get(pk=self.post_1.pk)
tag_1, created = post_1.tags.get_or_create(name=self.post_1_tag.name)
self.assertFalse(created)
self.assertEqual(tag_1.pk, self.post_1_tag.pk)
self.assertEqual(tag_1.content_type, self.post_ct)
self.assertEqual(tag_1.object_id, self.post_1_fk)
self.assertEqual(tag_1.content_object, post_1)
tag_2, created = post_1.tags.get_or_create(name="foo")
self.assertTrue(created)
self.assertEqual(tag_2.content_type, self.post_ct)
self.assertEqual(tag_2.object_id, self.post_1_fk)
self.assertEqual(tag_2.content_object, post_1)
def test_tags_update_or_create(self):
post_1 = Post.objects.get(pk=self.post_1.pk)
tag_1, created = post_1.tags.update_or_create(
name=self.post_1_tag.name, defaults={"name": "foo"}
)
self.assertFalse(created)
self.assertEqual(tag_1.pk, self.post_1_tag.pk)
self.assertEqual(tag_1.name, "foo")
self.assertEqual(tag_1.content_type, self.post_ct)
self.assertEqual(tag_1.object_id, self.post_1_fk)
self.assertEqual(tag_1.content_object, post_1)
tag_2, created = post_1.tags.update_or_create(name="bar")
self.assertTrue(created)
self.assertEqual(tag_2.content_type, self.post_ct)
self.assertEqual(tag_2.object_id, self.post_1_fk)
self.assertEqual(tag_2.content_object, post_1)
def test_filter_by_related_query_name(self):
self.assertSequenceEqual(
Tag.objects.filter(post__id=self.post_1.id), (self.post_1_tag,)
)
self.assertSequenceEqual(
Tag.objects.filter(dummy__big_integer=self.dummy_1.big_integer),
(self.dummy_1_tag,),
)
@skipIf(
connection.vendor == "mysql" and connection.mysql_is_mariadb,
"MariaDB's JSON_UNQUOTE doesn't support surrogate pairs "
"(https://jira.mariadb.org/browse/MDEV-21124)",
)
def test_aggregate(self):
with self.subTest("Post"):
with CaptureQueriesContext(connection) as ctx:
self.assertEqual(
Post.objects.aggregate(Count("tags")),
{"tags__count": 1},
ctx[-1]["sql"],
)
with self.subTest("Dummy"):
with CaptureQueriesContext(connection) as ctx:
self.assertEqual(
Dummy.objects.aggregate(Count("tags")),
{"tags__count": 2},
ctx[-1]["sql"],
)
def test_generic_prefetch(self):
tags = Tag.objects.prefetch_related(
GenericPrefetch(
"content_object",
[
Post.objects.all(),
Comment.objects.all(),
Dummy.objects.all(),
],
)
).order_by("pk")
self.assertEqual(len(tags), 4)
self.assertEqual(tags[0], self.comment_1_tag)
self.assertEqual(tags[1], self.post_1_tag)
self.assertEqual(tags[2], self.dummy_1_tag)
with self.assertNumQueries(0):
self.assertEqual(tags[0].content_object, self.comment_1)
with self.assertNumQueries(0):
self.assertEqual(tags[1].content_object, self.post_1)
with self.assertNumQueries(0):
self.assertEqual(tags[2].content_object, self.dummy_1)
with self.assertNumQueries(0):
self.assertEqual(tags[3].content_object, self.dummy_2)
def test_to_json(self):
field = Post._meta.pk
self.assertEqual(
field.to_json((1, "004bc28c-a085-44ca-a823-747dcf4ddfd3")),
'[1, "004bc28c-a085-44ca-a823-747dcf4ddfd3"]',
)
self.assertEqual(
field.to_json((2, UUID("a5a626e5-d34a-4197-94ea-83547029d6ed"))),
'[2, "a5a626e5-d34a-4197-94ea-83547029d6ed"]',
)
def test_to_json_length_mismatch(self):
field = Post._meta.pk
msg = "CompositePrimaryKey has 2 fields but it tried to serialize %s."
with self.assertRaisesMessage(ValueError, msg % 1):
self.assertIsNone(field.to_json((1,)))
with self.assertRaisesMessage(ValueError, msg % 3):
self.assertIsNone(field.to_json((1, 2, 3)))
def test_from_json(self):
field = Post._meta.pk
self.assertEqual(
field.from_json('[1, "cb6b99bc-66b0-497f-bafe-193b90af6296"]'),
(1, UUID("cb6b99bc-66b0-497f-bafe-193b90af6296")),
)
self.assertEqual(
field.from_json('[2, "1d5c63dda5264a12a2ec929d51e04430"]'),
(2, UUID("1d5c63dd-a526-4a12-a2ec-929d51e04430")),
)
def test_from_json_length_mismatch(self):
field = Post._meta.pk
msg = (
"CompositePrimaryKey has 2 fields but it tried to deserialize %s. "
"Did you change the CompositePrimaryKey fields and forgot to "
'update the related GenericForeignKey "object_id" fields?'
)
with self.assertRaisesMessage(ValueError, msg % 1):
self.assertIsNone(field.from_json("[1]"))
with self.assertRaisesMessage(ValueError, msg % 3):
self.assertIsNone(field.from_json("[1, 2, 3]"))

View File

@ -92,6 +92,7 @@ class CompositePKGetTests(TestCase):
def test_lookup_errors(self):
m_tuple = "'%s' lookup of 'pk' must be a tuple or a list"
m_iterable = "'%s' lookup of 'pk' must be an iterable"
m_2_elements = "'%s' lookup of 'pk' must have 2 elements"
m_tuple_collection = (
"'in' lookup of 'pk' must be a collection of tuples or lists"
@ -102,7 +103,7 @@ class CompositePKGetTests(TestCase):
({"pk": (1, 2, 3)}, m_2_elements % "exact"),
({"pk__exact": 1}, m_tuple % "exact"),
({"pk__exact": (1, 2, 3)}, m_2_elements % "exact"),
({"pk__in": 1}, m_tuple % "in"),
({"pk__in": 1}, m_iterable % "in"),
({"pk__in": (1, 2, 3)}, m_tuple_collection),
({"pk__in": ((1, 2, 3),)}, m_2_elements_each),
({"pk__gt": 1}, m_tuple % "gt"),

View File

@ -474,7 +474,6 @@ class TupleLookupsTests(TestCase):
TupleGreaterThanOrEqual,
TupleLessThan,
TupleLessThanOrEqual,
TupleIn,
),
(
0,
@ -496,6 +495,24 @@ class TupleLookupsTests(TestCase):
):
lookup_cls((F("customer_code"), F("company_code")), rhs)
def test_tuple_in_lookup_rhs_must_be_iterable(self):
test_cases = (
0,
1,
None,
True,
False,
)
for rhs in test_cases:
with self.subTest(lookup_name="in", rhs=rhs):
with self.assertRaisesMessage(
ValueError,
"'in' lookup of ('customer_code', 'company_code') "
"must be an iterable",
):
TupleIn((F("customer_code"), F("company_code")), rhs)
def test_tuple_lookup_rhs_must_have_2_elements(self):
test_cases = itertools.product(
(