From b641b6a90eb0350d46103c8e05a4f5fb19350295 Mon Sep 17 00:00:00 2001 From: Bendeguz Csirmaz Date: Sun, 7 Apr 2024 10:32:16 +0800 Subject: [PATCH] Fixed #35941 -- Added composite GenericForeignKey support. --- django/contrib/contenttypes/fields.py | 36 ++- django/db/backends/base/operations.py | 6 + django/db/backends/mysql/operations.py | 22 +- django/db/backends/oracle/operations.py | 38 ++- django/db/backends/postgresql/operations.py | 11 + django/db/backends/sqlite3/operations.py | 26 +- django/db/models/fields/composite.py | 36 +++ django/db/models/fields/tuple_lookups.py | 10 +- django/db/models/sql/datastructures.py | 36 ++- docs/topics/composite-primary-key.txt | 82 ++++- tests/composite_pk/models/__init__.py | 5 +- tests/composite_pk/models/generic.py | 24 ++ tests/composite_pk/models/tenant.py | 12 + tests/composite_pk/test_generic.py | 314 ++++++++++++++++++++ tests/composite_pk/test_get.py | 3 +- tests/foreign_object/test_tuple_lookups.py | 19 +- 16 files changed, 657 insertions(+), 23 deletions(-) create mode 100644 tests/composite_pk/models/generic.py create mode 100644 tests/composite_pk/test_generic.py diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index a3e87f6ed4..1d5ca6d906 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -25,6 +25,20 @@ from django.utils.deprecation import RemovedInDjango60Warning from django.utils.functional import cached_property +def serialize_pk(obj): + opts = obj._meta + if opts.is_composite_pk: + return opts.pk.to_json(obj.pk) + return obj.pk + + +def deserialize_pk(pk, ct): + opts = ct.model_class()._meta + if opts.is_composite_pk: + return opts.pk.from_json(pk) + return pk + + class GenericForeignKey(FieldCacheMixin, Field): """ Provide a generic many-to-one relation through the ``content_type`` and @@ -195,7 +209,9 @@ class GenericForeignKey(FieldCacheMixin, Field): if ct_id is not None: fk_val = getattr(instance, self.fk_field) if fk_val is not None: - fk_dict[ct_id].add(fk_val) + ct = self.get_content_type(id=ct_id) + pk_val = deserialize_pk(fk_val, ct) + fk_dict[ct_id].add(pk_val) instance_dict[ct_id] = instance ret_val = [] @@ -225,7 +241,7 @@ class GenericForeignKey(FieldCacheMixin, Field): return ( ret_val, - lambda obj: (obj.pk, obj.__class__), + lambda obj: (serialize_pk(obj), obj.__class__), gfk_key, True, self.name, @@ -242,15 +258,15 @@ class GenericForeignKey(FieldCacheMixin, Field): # use ContentType.objects.get_for_id(), which has a global cache. f = self.model._meta.get_field(self.ct_field) ct_id = getattr(instance, f.attname, None) - pk_val = getattr(instance, self.fk_field) + fk_val = getattr(instance, self.fk_field) rel_obj = self.get_cached_value(instance, default=None) if rel_obj is None and self.is_cached(instance): return rel_obj if rel_obj is not None: - ct_match = ( - ct_id == self.get_content_type(obj=rel_obj, using=instance._state.db).id - ) + ct = self.get_content_type(obj=rel_obj, using=instance._state.db) + ct_match = ct_id == ct.id + pk_val = deserialize_pk(fk_val, ct) pk_match = ct_match and rel_obj._meta.pk.to_python(pk_val) == rel_obj.pk if pk_match: return rel_obj @@ -258,6 +274,7 @@ class GenericForeignKey(FieldCacheMixin, Field): rel_obj = None if ct_id is not None: ct = self.get_content_type(id=ct_id, using=instance._state.db) + pk_val = deserialize_pk(fk_val, ct) try: rel_obj = ct.get_object_for_this_type( using=instance._state.db, pk=pk_val @@ -272,7 +289,7 @@ class GenericForeignKey(FieldCacheMixin, Field): fk = None if value is not None: ct = self.get_content_type(obj=value) - fk = value.pk + fk = serialize_pk(value) setattr(instance, self.ct_field, ct) setattr(instance, self.fk_field, fk) @@ -541,7 +558,8 @@ class GenericRelation(ForeignObject): % self.content_type_field_name: ContentType.objects.db_manager(using) .get_for_model(self.model, for_concrete_model=self.for_concrete_model) .pk, - "%s__in" % self.object_id_field_name: [obj.pk for obj in objs], + "%s__in" + % self.object_id_field_name: [serialize_pk(obj) for obj in objs], } ) @@ -589,7 +607,7 @@ def create_generic_related_manager(superclass, rel): self.content_type_field_name = rel.field.content_type_field_name self.object_id_field_name = rel.field.object_id_field_name self.prefetch_cache_name = rel.field.attname - self.pk_val = instance.pk + self.pk_val = serialize_pk(instance) self.core_filters = { "%s__pk" % self.content_type_field_name: self.content_type.id, diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 60de2d6c79..d06fc933cd 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -804,3 +804,9 @@ class BaseDatabaseOperations: rhs_expr = Col(rhs_table, rhs_field) return lhs_expr, rhs_expr + + def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index): + raise NotImplementedError( + "subclasses of BaseDatabaseOperations may require a " + "prepare_join_composite_pk_on_json_array() method" + ) diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index 9806303539..e26569d995 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -3,8 +3,10 @@ import uuid from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.utils import split_tzname_delta -from django.db.models import Exists, ExpressionWrapper, Lookup +from django.db.models import Exists, ExpressionWrapper, Func, Lookup, UUIDField, Value from django.db.models.constants import OnConflict +from django.db.models.fields.json import KeyTextTransform +from django.db.models.functions import Cast from django.utils import timezone from django.utils.encoding import force_str from django.utils.regex_helper import _lazy_re_compile @@ -453,3 +455,21 @@ class DatabaseOperations(BaseDatabaseOperations): update_fields, unique_fields, ) + + def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index): + # e.g. `a`.`id` = CAST((`b`.`object_id` ->> '$[0]') AS signed integer) + json_array = rhs + json_element = KeyTextTransform(index, json_array) + + if isinstance(lhs.field, UUIDField): + json_element = Func( + json_element, + Value("-"), + Value(""), + function="REPLACE", + output_field=UUIDField(), + ) + if json_element.field != lhs.field: + json_element = Cast(json_element, lhs.field) + + return lhs, json_element diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index 79c6da994e..b572956d69 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -6,8 +6,18 @@ from django.conf import settings from django.db import DatabaseError, NotSupportedError from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name -from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup -from django.db.models.expressions import RawSQL +from django.db.models import ( + AutoField, + DateTimeField, + Exists, + ExpressionWrapper, + JSONField, + Lookup, + UUIDField, +) +from django.db.models.expressions import Func, RawSQL, Value +from django.db.models.fields.json import KeyTextTransform +from django.db.models.functions import Cast from django.db.models.sql.where import WhereNode from django.utils import timezone from django.utils.encoding import force_bytes, force_str @@ -726,3 +736,27 @@ END; if isinstance(expression, RawSQL) and expression.conditional: return True return False + + def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index): + json_array = Cast(rhs, JSONField()) + json_element = KeyTextTransform(index, json_array) + + if isinstance(lhs.field, UUIDField): + json_element = Func( + json_element, + Value("-"), + Value(""), + function="REPLACE", + output_field=UUIDField(), + ) + if isinstance(lhs.field, DateTimeField): + json_element = Func( + json_element, + Value('YYYY-MM-DD"T"HH24:MI:SS'), + function="TO_TIMESTAMP", + output_field=DateTimeField(), + ) + if json_element.field != lhs.field: + json_element = Cast(json_element, lhs.field) + + return lhs, json_element diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 8a0ca36a29..d48964513d 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -12,6 +12,7 @@ from django.db.backends.postgresql.psycopg_any import ( ) from django.db.backends.utils import split_tzname_delta from django.db.models.constants import OnConflict +from django.db.models.fields.json import JSONField, KeyTextTransform from django.db.models.functions import Cast from django.utils.regex_helper import _lazy_re_compile @@ -413,3 +414,13 @@ class DatabaseOperations(BaseDatabaseOperations): rhs_expr = Cast(rhs_expr, lhs_field) return lhs_expr, rhs_expr + + def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index): + # e.g. `a`.`id` = ((`b`.`object_id`)::jsonb ->> 0)::integer + json_array = Cast(rhs, JSONField()) + json_element = KeyTextTransform(index, json_array) + + if json_element.field != lhs.field: + json_element = Cast(json_element, lhs.field) + + return lhs, json_element diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 0078cc077a..79556bc071 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -8,12 +8,15 @@ from django.conf import settings from django.core.exceptions import FieldError from django.db import DatabaseError, NotSupportedError, models from django.db.backends.base.operations import BaseDatabaseOperations +from django.db.models import DateTimeField, JSONField, UUIDField from django.db.models.constants import OnConflict -from django.db.models.expressions import Col +from django.db.models.expressions import Col, Func, Value +from django.db.models.functions import Cast from django.utils import timezone from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.functional import cached_property +from ...models.fields.json import KeyTextTransform from .base import Database @@ -431,3 +434,24 @@ class DatabaseOperations(BaseDatabaseOperations): def force_group_by(self): return ["GROUP BY TRUE"] if Database.sqlite_version_info < (3, 39) else [] + + def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index): + json_array = Cast(rhs, JSONField()) + json_element = KeyTextTransform(index, json_array) + + if isinstance(lhs.field, UUIDField): + json_element = Func( + json_element, + Value("-"), + Value(""), + function="REPLACE", + output_field=UUIDField(), + ) + if json_element.field != lhs.field: + json_element = Cast(json_element, lhs.field) + if isinstance(lhs.field, DateTimeField): + # If joining DateTimeField in SQLite, + # make sure strftime is applied to the left side as well. + lhs = Cast(lhs, DateTimeField()) + + return lhs, json_element diff --git a/django/db/models/fields/composite.py b/django/db/models/fields/composite.py index 550a440dcf..097b6102cf 100644 --- a/django/db/models/fields/composite.py +++ b/django/db/models/fields/composite.py @@ -1,3 +1,5 @@ +import json + from django.core import checks from django.db.models import NOT_PROVIDED, Field from django.db.models.expressions import ColPairs @@ -128,6 +130,40 @@ class CompositePrimaryKey(Field): ) ] + def to_json(self, pk): + from django.core.serializers.json import DjangoJSONEncoder + + fields_len = len(self.fields) + pk_len = len(pk) + if fields_len != pk_len: + raise ValueError( + f"{self.__class__.__name__} has {fields_len} fields " + f"but it tried to serialize {pk_len}." + ) + + return json.dumps( + [value for value in pk], + cls=DjangoJSONEncoder, + ) + + def from_json(self, s): + values = json.loads(s) + + fields_len = len(self.fields) + values_len = len(values) + if fields_len != values_len: + raise ValueError( + f"{self.__class__.__name__} has {fields_len} fields " + f"but it tried to deserialize {values_len}. " + f"Did you change the {self.__class__.__name__} fields " + "and forgot to update the related GenericForeignKey " + '"object_id" fields?' + ) + + return tuple( + field.to_python(value) for (field, value) in zip(self.fields, values) + ) + CompositePrimaryKey.register_lookup(TupleExact) CompositePrimaryKey.register_lookup(TupleGreaterThan) diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py index e515e971b4..a44f2af1f5 100644 --- a/django/db/models/fields/tuple_lookups.py +++ b/django/db/models/fields/tuple_lookups.py @@ -1,4 +1,5 @@ import itertools +from collections.abc import Iterable from django.core.exceptions import EmptyResultSet from django.db.models import Field @@ -211,9 +212,16 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual): class TupleIn(TupleLookupMixin, In): + def check_rhs_is_iterable(self): + if not isinstance(self.rhs, Iterable): + lhs_str = self.get_lhs_str() + raise ValueError( + f"{self.lookup_name!r} lookup of {lhs_str} must be an iterable" + ) + def get_prep_lookup(self): if self.rhs_is_direct_value(): - self.check_rhs_is_tuple_or_list() + self.check_rhs_is_iterable() self.check_rhs_is_collection_of_tuples_or_lists() self.check_rhs_elements_length_equals_lhs_length() else: diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 7c0c14a46e..d18272b50d 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -6,6 +6,7 @@ the SQL domain. import warnings from django.core.exceptions import FullResultSet +from django.db import models from django.db.models.sql.constants import INNER, LOUTER from django.utils.deprecation import RemovedInDjango60Warning @@ -105,6 +106,39 @@ class Join: # the branch for strings. lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs)) rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs)) + join_conditions.append(f"{lhs_full_name} = {rhs_full_name}") + elif isinstance(lhs, (models.CharField, models.TextField)) and isinstance( + rhs, models.CompositePrimaryKey + ): + lhs_col = lhs.get_col(self.parent_alias) + for index, field in enumerate(rhs): + rhs_col = field.get_col(self.table_alias) + lhs_expr, rhs_expr = ( + connection.ops.prepare_join_composite_pk_on_json_array( + rhs_col, lhs_col, index + ) + ) + lhs_sql, lhs_params = compiler.compile(lhs_expr) + rhs_sql, rhs_params = compiler.compile(rhs_expr) + join_conditions.append(f"{lhs_sql} = {rhs_sql}") + params.extend(lhs_params) + params.extend(rhs_params) + elif isinstance(lhs, models.CompositePrimaryKey) and isinstance( + rhs, (models.CharField, models.TextField) + ): + rhs_col = rhs.get_col(self.table_alias) + for index, field in enumerate(lhs): + lhs_col = field.get_col(self.parent_alias) + lhs_expr, rhs_expr = ( + connection.ops.prepare_join_composite_pk_on_json_array( + lhs_col, rhs_col, index + ) + ) + lhs_sql, lhs_params = compiler.compile(lhs_expr) + rhs_sql, rhs_params = compiler.compile(rhs_expr) + join_conditions.append(f"{lhs_sql} = {rhs_sql}") + params.extend(lhs_params) + params.extend(rhs_params) else: lhs, rhs = connection.ops.prepare_join_on_clause( self.parent_alias, lhs, self.table_alias, rhs @@ -113,7 +147,7 @@ class Join: lhs_full_name = lhs_sql % lhs_params rhs_sql, rhs_params = compiler.compile(rhs) rhs_full_name = rhs_sql % rhs_params - join_conditions.append(f"{lhs_full_name} = {rhs_full_name}") + join_conditions.append(f"{lhs_full_name} = {rhs_full_name}") # Add a single condition inside parentheses for whatever # get_extra_restriction() returns. diff --git a/docs/topics/composite-primary-key.txt b/docs/topics/composite-primary-key.txt index 9e5234ca9f..a80cf17df7 100644 --- a/docs/topics/composite-primary-key.txt +++ b/docs/topics/composite-primary-key.txt @@ -62,8 +62,7 @@ A composite primary key can also be filtered by a ``tuple``: 1 We're still working on composite primary key support for -:ref:`relational fields `, including -:class:`.GenericForeignKey` fields, and the Django admin. Models with composite +:ref:`relational fields `, and the Django admin. Models with composite primary keys cannot be registered in the Django admin at this time. You can expect to see this in future releases. @@ -96,9 +95,8 @@ operation. Composite primary keys and relations ==================================== -:ref:`Relationship fields `, including -:ref:`generic relations ` do not support composite primary -keys. +:ref:`Relationship fields ` do not support composite +primary keys. For example, given the ``OrderLineItem`` model, the following is not supported:: @@ -131,6 +129,80 @@ database. ``ForeignObject`` is an internal API. This means it is not covered by our :ref:`deprecation policy `. +Generic foreign keys & relations +================================ + +If a :class:`~django.contrib.contenttypes.fields.GenericForeignKey`\'s +"object_id" field is a :class:`~django.db.models.CharField` or a +:class:`~django.db.models.TextField`, Django will automatically serialize the +related ``CompositePrimaryKey`` into a JSON array:: + + class Tag: + content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) + object_id = models.CharField(max_length=255) + content_object = GenericForeignKey("content_type", "object_id") + +This means ``CompositePrimaryKey`` is backwards-compatible with any +``GenericForeignKey`` that's using a ``CharField`` or ``TextField`` "object_id". + +.. code-block:: pycon + + >>> item.pk + (2, "B142C") + >>> tag = Tag(content_object=item) + >>> tag.object_id + '[2, "B142C"]' + +.. warning:: pk changes are not supported + + Django doesn't automatically update "object_id" if "content_object.pk" + changes. If you change an object's ``pk`` in any way, make sure the related + generic foreign keys are also updated. + +:class:`~django.contrib.contenttypes.fields.GenericRelation`\s on models with +a ``CompositePrimaryKey`` are also supported. When filtering on a reverse +relationship, Django performs ``JOIN``\s using the appropriate backend-specific +JSON functions:: + + class OrderLineItem(models.Model): + ... + tags = GenericRelation("Tag", related_query_name="items") + +This means ``CompositePrimaryKey``\s are backwards-compatible with +``GenericRelation``\s the same way ``GenericForeignKey``\s are. + +.. code-block:: pycon + + >>> item.tags.all() # OK + >>> Tag.objects.filter(items__quantity__gte=2) # OK + +.. warning:: JOINs on JSON functions can be slow + + While performing ``JOIN``\s on JSON functions is convenient, it can lead to + performance issues. Exercise caution when filtering on the reverse + relationship of a ``GenericRelation``. + +.. warning:: MariaDB support + + At the time of writing this, MariaDB's JSON_UNQUOTE function cannot process + surrogate pairs (e.g. emojis). If you're planning to use JOINs on JSON + functions on MariaDB, make sure the primary key doesn't contain any emojis. + +.. admonition:: Supported fields + + ``JOIN``s on JSON functions may not work if the ``CompositePrimaryKey`` + has unsupported fields. + + The supported fields are: + + - ``SmallIntegerField`` + - ``IntegerField`` + - ``BigIntegerField`` + - ``CharField`` + - ``UUIDField`` + - ``DateField`` + - ``DateTimeField`` + Composite primary keys and database functions ============================================= diff --git a/tests/composite_pk/models/__init__.py b/tests/composite_pk/models/__init__.py index 35c3943716..63ee55daf7 100644 --- a/tests/composite_pk/models/__init__.py +++ b/tests/composite_pk/models/__init__.py @@ -1,8 +1,11 @@ -from .tenant import Comment, Post, Tenant, Token, User +from .generic import Dummy +from .tenant import Comment, Post, Tag, Tenant, Token, User __all__ = [ "Comment", + "Dummy", "Post", + "Tag", "Tenant", "Token", "User", diff --git a/tests/composite_pk/models/generic.py b/tests/composite_pk/models/generic.py new file mode 100644 index 0000000000..9ace06c7c1 --- /dev/null +++ b/tests/composite_pk/models/generic.py @@ -0,0 +1,24 @@ +from django.contrib.contenttypes.fields import GenericRelation +from django.db import models + + +class Dummy(models.Model): + pk = models.CompositePrimaryKey( + "small_integer", + "integer", + "big_integer", + "datetime", + "date", + "uuid", + "char", + ) + + small_integer = models.SmallIntegerField() + integer = models.IntegerField() + big_integer = models.BigIntegerField() + datetime = models.DateTimeField() + date = models.DateField() + uuid = models.UUIDField() + char = models.CharField(max_length=5) + + tags = GenericRelation("Tag", related_query_name="dummy") diff --git a/tests/composite_pk/models/tenant.py b/tests/composite_pk/models/tenant.py index ac0b3d9715..70669ee7cd 100644 --- a/tests/composite_pk/models/tenant.py +++ b/tests/composite_pk/models/tenant.py @@ -1,3 +1,5 @@ +from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation +from django.contrib.contenttypes.models import ContentType from django.db import models @@ -48,3 +50,13 @@ class Post(models.Model): pk = models.CompositePrimaryKey("tenant_id", "id") tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) id = models.UUIDField() + tags = GenericRelation("Tag", related_query_name="post") + + +class Tag(models.Model): + name = models.CharField(max_length=10) + content_type = models.ForeignKey( + ContentType, on_delete=models.CASCADE, related_name="composite_pk_tags" + ) + object_id = models.CharField(max_length=255) + content_object = GenericForeignKey("content_type", "object_id") diff --git a/tests/composite_pk/test_generic.py b/tests/composite_pk/test_generic.py new file mode 100644 index 0000000000..fb872b0c6f --- /dev/null +++ b/tests/composite_pk/test_generic.py @@ -0,0 +1,314 @@ +from datetime import date, datetime +from unittest import skipIf +from uuid import UUID + +from django.contrib.contenttypes.models import ContentType +from django.contrib.contenttypes.prefetch import GenericPrefetch +from django.db import connection +from django.db.models import Count +from django.test import TestCase +from django.test.utils import CaptureQueriesContext + +from .models import Comment, Dummy, Post, Tag, Tenant, User + + +class CompositePKGenericTests(TestCase): + POST_1_ID = "e1516ac0-4469-4306-b0ac-2c435e677aa4" + DUMMY_1_UUID = "81598d5a-fa03-4800-bae4-823ca12824d5" + DUMMY_2_UUID = "ac567f4c-9b6e-4770-86a6-1abf071a2958" + + @classmethod + def setUpTestData(cls): + cls.dummy_1 = Dummy.objects.create( + small_integer=32767, + integer=2147483647, + big_integer=9223372036854775807, + datetime=datetime(2024, 11, 30, 6, 26, 1), + date=date(2024, 11, 30), + uuid=UUID(cls.DUMMY_1_UUID), + char="薑戈", + ) + cls.dummy_2 = Dummy.objects.create( + small_integer=-32768, + integer=-2147483648, + big_integer=-9223372036854775808, + datetime=datetime(2024, 12, 8, 1, 2, 3), + date=date(2024, 12, 8), + uuid=UUID(cls.DUMMY_2_UUID), + char="😊", + ) + cls.tenant_1 = Tenant.objects.create() + cls.tenant_2 = Tenant.objects.create() + cls.user_1 = User.objects.create( + tenant=cls.tenant_1, + id=1, + email="user0001@example.com", + ) + cls.user_2 = User.objects.create( + tenant=cls.tenant_1, + id=2, + email="user0002@example.com", + ) + cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1) + cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1) + cls.post_1 = Post.objects.create(tenant=cls.tenant_1, id=UUID(cls.POST_1_ID)) + cls.comment_1_tag = Tag.objects.create( + name="comment_1", content_object=cls.comment_1 + ) + cls.post_1_tag = Tag.objects.create(name="post_1", content_object=cls.post_1) + cls.dummy_1_tag = Tag.objects.create(name="dummy_1", content_object=cls.dummy_1) + cls.dummy_2_tag = Tag.objects.create(name="dummy_2", content_object=cls.dummy_2) + cls.comment_ct = ContentType.objects.get_for_model(Comment) + cls.post_ct = ContentType.objects.get_for_model(Post) + cls.dummy_ct = ContentType.objects.get_for_model(Dummy) + cls.post_1_fk = f'[{cls.tenant_1.id}, "{cls.POST_1_ID}"]' + cls.comment_1_fk = f"[{cls.tenant_1.id}, {cls.comment_1.id}]" + cls.dummy_1_fk = ( + '[32767, 2147483647, 9223372036854775807, "2024-11-30T06:26:01", ' + f'"2024-11-30", "{cls.DUMMY_1_UUID}", "\\u8591\\u6208"]' + ) + cls.dummy_2_fk = ( + '[-32768, -2147483648, -9223372036854775808, "2024-12-08T01:02:03", ' + f'"2024-12-08", "{cls.DUMMY_2_UUID}", "\\ud83d\\ude0a"]' + ) + + def test_fields(self): + self.assertEqual(self.comment_1_tag.content_type, self.comment_ct) + self.assertEqual(self.comment_1_tag.object_id, self.comment_1_fk) + self.assertEqual(self.comment_1_tag.content_object, self.comment_1) + + self.assertEqual(self.post_1_tag.content_type, self.post_ct) + self.assertEqual(self.post_1_tag.object_id, self.post_1_fk) + self.assertEqual(self.post_1_tag.content_object, self.post_1) + + self.assertEqual(self.dummy_1_tag.content_type, self.dummy_ct) + self.assertEqual(self.dummy_1_tag.object_id, self.dummy_1_fk) + self.assertEqual(self.dummy_1_tag.content_object, self.dummy_1) + + self.assertEqual(self.dummy_2_tag.content_type, self.dummy_ct) + self.assertEqual(self.dummy_2_tag.object_id, self.dummy_2_fk) + self.assertEqual(self.dummy_2_tag.content_object, self.dummy_2) + + self.assertSequenceEqual(self.post_1.tags.all(), (self.post_1_tag,)) + self.assertSequenceEqual(self.dummy_1.tags.all(), (self.dummy_1_tag,)) + self.assertSequenceEqual(self.dummy_2.tags.all(), (self.dummy_2_tag,)) + + def test_fields_before_save(self): + comment = Comment(pk=(1, 2)) + tag = Tag(content_object=comment) + self.assertEqual(tag.object_id, "[1, 2]") + comment.id = 3 + self.assertEqual(tag.object_id, "[1, 2]") + + def test_cascade_delete_if_generic_relation(self): + Post.objects.get(pk=self.post_1.pk).delete() + self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists()) + + Dummy.objects.get(pk=self.dummy_1.pk).delete() + self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists()) + + def test_no_cascade_delete_if_no_generic_relation(self): + Comment.objects.get(pk=self.comment_1.pk).delete() + comment_1_tag = Tag.objects.get(pk=self.comment_1_tag.pk) + self.assertIsNone(comment_1_tag.content_object) + + def test_tags_clear(self): + post_1 = Post.objects.get(pk=self.post_1.pk) + post_1.tags.clear() + self.assertEqual(post_1.tags.count(), 0) + self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists()) + + dummy_1 = Dummy.objects.get(pk=self.dummy_1.pk) + dummy_1.tags.clear() + self.assertEqual(dummy_1.tags.count(), 0) + self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists()) + + def test_tags_remove(self): + post_1 = Post.objects.get(pk=self.post_1.pk) + post_1.tags.remove(self.post_1_tag) + self.assertEqual(post_1.tags.count(), 0) + self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists()) + + dummy_1 = Dummy.objects.get(pk=self.dummy_1.pk) + dummy_1.tags.remove(self.dummy_1_tag) + self.assertEqual(dummy_1.tags.count(), 0) + self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists()) + + def test_tags_create(self): + tag_count = Tag.objects.count() + + post_1 = Post.objects.get(pk=self.post_1.pk) + post_1.tags.create(name="foo") + self.assertEqual(post_1.tags.count(), 2) + self.assertEqual(Tag.objects.count(), tag_count + 1) + + tag = Tag.objects.get(name="foo") + self.assertEqual(tag.content_type, self.post_ct) + self.assertEqual(tag.object_id, self.post_1_fk) + self.assertEqual(tag.content_object, post_1) + + def test_tags_add(self): + tag_count = Tag.objects.count() + post_1 = Post.objects.get(pk=self.post_1.pk) + + tag_1 = Tag(name="foo") + post_1.tags.add(tag_1, bulk=False) + self.assertEqual(post_1.tags.count(), 2) + self.assertEqual(Tag.objects.count(), tag_count + 1) + + tag_1 = Tag.objects.get(name="foo") + self.assertEqual(tag_1.content_type, self.post_ct) + self.assertEqual(tag_1.object_id, self.post_1_fk) + self.assertEqual(tag_1.content_object, post_1) + + tag_2 = Tag.objects.create(name="bar", content_object=self.comment_2) + post_1.tags.add(tag_2) + self.assertEqual(post_1.tags.count(), 3) + self.assertEqual(Tag.objects.count(), tag_count + 2) + + tag_2 = Tag.objects.get(name="bar") + self.assertEqual(tag_2.content_type, self.post_ct) + self.assertEqual(tag_2.object_id, self.post_1_fk) + self.assertEqual(tag_2.content_object, post_1) + + def test_tags_set(self): + tag_count = Tag.objects.count() + comment_1_tag = Tag.objects.get(name=self.comment_1_tag.name) + post_1 = Post.objects.get(pk=self.post_1.pk) + post_1.tags.set([comment_1_tag]) + self.assertEqual(post_1.tags.count(), 1) + self.assertEqual(Tag.objects.count(), tag_count - 1) + self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists()) + + def test_tags_get_or_create(self): + post_1 = Post.objects.get(pk=self.post_1.pk) + + tag_1, created = post_1.tags.get_or_create(name=self.post_1_tag.name) + self.assertFalse(created) + self.assertEqual(tag_1.pk, self.post_1_tag.pk) + self.assertEqual(tag_1.content_type, self.post_ct) + self.assertEqual(tag_1.object_id, self.post_1_fk) + self.assertEqual(tag_1.content_object, post_1) + + tag_2, created = post_1.tags.get_or_create(name="foo") + self.assertTrue(created) + self.assertEqual(tag_2.content_type, self.post_ct) + self.assertEqual(tag_2.object_id, self.post_1_fk) + self.assertEqual(tag_2.content_object, post_1) + + def test_tags_update_or_create(self): + post_1 = Post.objects.get(pk=self.post_1.pk) + + tag_1, created = post_1.tags.update_or_create( + name=self.post_1_tag.name, defaults={"name": "foo"} + ) + self.assertFalse(created) + self.assertEqual(tag_1.pk, self.post_1_tag.pk) + self.assertEqual(tag_1.name, "foo") + self.assertEqual(tag_1.content_type, self.post_ct) + self.assertEqual(tag_1.object_id, self.post_1_fk) + self.assertEqual(tag_1.content_object, post_1) + + tag_2, created = post_1.tags.update_or_create(name="bar") + self.assertTrue(created) + self.assertEqual(tag_2.content_type, self.post_ct) + self.assertEqual(tag_2.object_id, self.post_1_fk) + self.assertEqual(tag_2.content_object, post_1) + + def test_filter_by_related_query_name(self): + self.assertSequenceEqual( + Tag.objects.filter(post__id=self.post_1.id), (self.post_1_tag,) + ) + self.assertSequenceEqual( + Tag.objects.filter(dummy__big_integer=self.dummy_1.big_integer), + (self.dummy_1_tag,), + ) + + @skipIf( + connection.vendor == "mysql" and connection.mysql_is_mariadb, + "MariaDB's JSON_UNQUOTE doesn't support surrogate pairs " + "(https://jira.mariadb.org/browse/MDEV-21124)", + ) + def test_aggregate(self): + with self.subTest("Post"): + with CaptureQueriesContext(connection) as ctx: + self.assertEqual( + Post.objects.aggregate(Count("tags")), + {"tags__count": 1}, + ctx[-1]["sql"], + ) + with self.subTest("Dummy"): + with CaptureQueriesContext(connection) as ctx: + self.assertEqual( + Dummy.objects.aggregate(Count("tags")), + {"tags__count": 2}, + ctx[-1]["sql"], + ) + + def test_generic_prefetch(self): + tags = Tag.objects.prefetch_related( + GenericPrefetch( + "content_object", + [ + Post.objects.all(), + Comment.objects.all(), + Dummy.objects.all(), + ], + ) + ).order_by("pk") + + self.assertEqual(len(tags), 4) + self.assertEqual(tags[0], self.comment_1_tag) + self.assertEqual(tags[1], self.post_1_tag) + self.assertEqual(tags[2], self.dummy_1_tag) + + with self.assertNumQueries(0): + self.assertEqual(tags[0].content_object, self.comment_1) + with self.assertNumQueries(0): + self.assertEqual(tags[1].content_object, self.post_1) + with self.assertNumQueries(0): + self.assertEqual(tags[2].content_object, self.dummy_1) + with self.assertNumQueries(0): + self.assertEqual(tags[3].content_object, self.dummy_2) + + def test_to_json(self): + field = Post._meta.pk + self.assertEqual( + field.to_json((1, "004bc28c-a085-44ca-a823-747dcf4ddfd3")), + '[1, "004bc28c-a085-44ca-a823-747dcf4ddfd3"]', + ) + self.assertEqual( + field.to_json((2, UUID("a5a626e5-d34a-4197-94ea-83547029d6ed"))), + '[2, "a5a626e5-d34a-4197-94ea-83547029d6ed"]', + ) + + def test_to_json_length_mismatch(self): + field = Post._meta.pk + msg = "CompositePrimaryKey has 2 fields but it tried to serialize %s." + with self.assertRaisesMessage(ValueError, msg % 1): + self.assertIsNone(field.to_json((1,))) + with self.assertRaisesMessage(ValueError, msg % 3): + self.assertIsNone(field.to_json((1, 2, 3))) + + def test_from_json(self): + field = Post._meta.pk + self.assertEqual( + field.from_json('[1, "cb6b99bc-66b0-497f-bafe-193b90af6296"]'), + (1, UUID("cb6b99bc-66b0-497f-bafe-193b90af6296")), + ) + self.assertEqual( + field.from_json('[2, "1d5c63dda5264a12a2ec929d51e04430"]'), + (2, UUID("1d5c63dd-a526-4a12-a2ec-929d51e04430")), + ) + + def test_from_json_length_mismatch(self): + field = Post._meta.pk + msg = ( + "CompositePrimaryKey has 2 fields but it tried to deserialize %s. " + "Did you change the CompositePrimaryKey fields and forgot to " + 'update the related GenericForeignKey "object_id" fields?' + ) + with self.assertRaisesMessage(ValueError, msg % 1): + self.assertIsNone(field.from_json("[1]")) + with self.assertRaisesMessage(ValueError, msg % 3): + self.assertIsNone(field.from_json("[1, 2, 3]")) diff --git a/tests/composite_pk/test_get.py b/tests/composite_pk/test_get.py index c896ec26ed..45e74cd0b1 100644 --- a/tests/composite_pk/test_get.py +++ b/tests/composite_pk/test_get.py @@ -92,6 +92,7 @@ class CompositePKGetTests(TestCase): def test_lookup_errors(self): m_tuple = "'%s' lookup of 'pk' must be a tuple or a list" + m_iterable = "'%s' lookup of 'pk' must be an iterable" m_2_elements = "'%s' lookup of 'pk' must have 2 elements" m_tuple_collection = ( "'in' lookup of 'pk' must be a collection of tuples or lists" @@ -102,7 +103,7 @@ class CompositePKGetTests(TestCase): ({"pk": (1, 2, 3)}, m_2_elements % "exact"), ({"pk__exact": 1}, m_tuple % "exact"), ({"pk__exact": (1, 2, 3)}, m_2_elements % "exact"), - ({"pk__in": 1}, m_tuple % "in"), + ({"pk__in": 1}, m_iterable % "in"), ({"pk__in": (1, 2, 3)}, m_tuple_collection), ({"pk__in": ((1, 2, 3),)}, m_2_elements_each), ({"pk__gt": 1}, m_tuple % "gt"), diff --git a/tests/foreign_object/test_tuple_lookups.py b/tests/foreign_object/test_tuple_lookups.py index 797fea1c8a..1d30f32722 100644 --- a/tests/foreign_object/test_tuple_lookups.py +++ b/tests/foreign_object/test_tuple_lookups.py @@ -474,7 +474,6 @@ class TupleLookupsTests(TestCase): TupleGreaterThanOrEqual, TupleLessThan, TupleLessThanOrEqual, - TupleIn, ), ( 0, @@ -496,6 +495,24 @@ class TupleLookupsTests(TestCase): ): lookup_cls((F("customer_code"), F("company_code")), rhs) + def test_tuple_in_lookup_rhs_must_be_iterable(self): + test_cases = ( + 0, + 1, + None, + True, + False, + ) + + for rhs in test_cases: + with self.subTest(lookup_name="in", rhs=rhs): + with self.assertRaisesMessage( + ValueError, + "'in' lookup of ('customer_code', 'company_code') " + "must be an iterable", + ): + TupleIn((F("customer_code"), F("company_code")), rhs) + def test_tuple_lookup_rhs_must_have_2_elements(self): test_cases = itertools.product( (