mirror of
https://github.com/django/django.git
synced 2024-12-22 17:16:24 +00:00
Fixed #35941 -- Added composite GenericForeignKey support.
This commit is contained in:
parent
1860a1afc9
commit
b641b6a90e
@ -25,6 +25,20 @@ from django.utils.deprecation import RemovedInDjango60Warning
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
def serialize_pk(obj):
|
||||
opts = obj._meta
|
||||
if opts.is_composite_pk:
|
||||
return opts.pk.to_json(obj.pk)
|
||||
return obj.pk
|
||||
|
||||
|
||||
def deserialize_pk(pk, ct):
|
||||
opts = ct.model_class()._meta
|
||||
if opts.is_composite_pk:
|
||||
return opts.pk.from_json(pk)
|
||||
return pk
|
||||
|
||||
|
||||
class GenericForeignKey(FieldCacheMixin, Field):
|
||||
"""
|
||||
Provide a generic many-to-one relation through the ``content_type`` and
|
||||
@ -195,7 +209,9 @@ class GenericForeignKey(FieldCacheMixin, Field):
|
||||
if ct_id is not None:
|
||||
fk_val = getattr(instance, self.fk_field)
|
||||
if fk_val is not None:
|
||||
fk_dict[ct_id].add(fk_val)
|
||||
ct = self.get_content_type(id=ct_id)
|
||||
pk_val = deserialize_pk(fk_val, ct)
|
||||
fk_dict[ct_id].add(pk_val)
|
||||
instance_dict[ct_id] = instance
|
||||
|
||||
ret_val = []
|
||||
@ -225,7 +241,7 @@ class GenericForeignKey(FieldCacheMixin, Field):
|
||||
|
||||
return (
|
||||
ret_val,
|
||||
lambda obj: (obj.pk, obj.__class__),
|
||||
lambda obj: (serialize_pk(obj), obj.__class__),
|
||||
gfk_key,
|
||||
True,
|
||||
self.name,
|
||||
@ -242,15 +258,15 @@ class GenericForeignKey(FieldCacheMixin, Field):
|
||||
# use ContentType.objects.get_for_id(), which has a global cache.
|
||||
f = self.model._meta.get_field(self.ct_field)
|
||||
ct_id = getattr(instance, f.attname, None)
|
||||
pk_val = getattr(instance, self.fk_field)
|
||||
fk_val = getattr(instance, self.fk_field)
|
||||
|
||||
rel_obj = self.get_cached_value(instance, default=None)
|
||||
if rel_obj is None and self.is_cached(instance):
|
||||
return rel_obj
|
||||
if rel_obj is not None:
|
||||
ct_match = (
|
||||
ct_id == self.get_content_type(obj=rel_obj, using=instance._state.db).id
|
||||
)
|
||||
ct = self.get_content_type(obj=rel_obj, using=instance._state.db)
|
||||
ct_match = ct_id == ct.id
|
||||
pk_val = deserialize_pk(fk_val, ct)
|
||||
pk_match = ct_match and rel_obj._meta.pk.to_python(pk_val) == rel_obj.pk
|
||||
if pk_match:
|
||||
return rel_obj
|
||||
@ -258,6 +274,7 @@ class GenericForeignKey(FieldCacheMixin, Field):
|
||||
rel_obj = None
|
||||
if ct_id is not None:
|
||||
ct = self.get_content_type(id=ct_id, using=instance._state.db)
|
||||
pk_val = deserialize_pk(fk_val, ct)
|
||||
try:
|
||||
rel_obj = ct.get_object_for_this_type(
|
||||
using=instance._state.db, pk=pk_val
|
||||
@ -272,7 +289,7 @@ class GenericForeignKey(FieldCacheMixin, Field):
|
||||
fk = None
|
||||
if value is not None:
|
||||
ct = self.get_content_type(obj=value)
|
||||
fk = value.pk
|
||||
fk = serialize_pk(value)
|
||||
|
||||
setattr(instance, self.ct_field, ct)
|
||||
setattr(instance, self.fk_field, fk)
|
||||
@ -541,7 +558,8 @@ class GenericRelation(ForeignObject):
|
||||
% self.content_type_field_name: ContentType.objects.db_manager(using)
|
||||
.get_for_model(self.model, for_concrete_model=self.for_concrete_model)
|
||||
.pk,
|
||||
"%s__in" % self.object_id_field_name: [obj.pk for obj in objs],
|
||||
"%s__in"
|
||||
% self.object_id_field_name: [serialize_pk(obj) for obj in objs],
|
||||
}
|
||||
)
|
||||
|
||||
@ -589,7 +607,7 @@ def create_generic_related_manager(superclass, rel):
|
||||
self.content_type_field_name = rel.field.content_type_field_name
|
||||
self.object_id_field_name = rel.field.object_id_field_name
|
||||
self.prefetch_cache_name = rel.field.attname
|
||||
self.pk_val = instance.pk
|
||||
self.pk_val = serialize_pk(instance)
|
||||
|
||||
self.core_filters = {
|
||||
"%s__pk" % self.content_type_field_name: self.content_type.id,
|
||||
|
@ -804,3 +804,9 @@ class BaseDatabaseOperations:
|
||||
rhs_expr = Col(rhs_table, rhs_field)
|
||||
|
||||
return lhs_expr, rhs_expr
|
||||
|
||||
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseOperations may require a "
|
||||
"prepare_join_composite_pk_on_json_array() method"
|
||||
)
|
||||
|
@ -3,8 +3,10 @@ import uuid
|
||||
from django.conf import settings
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.backends.utils import split_tzname_delta
|
||||
from django.db.models import Exists, ExpressionWrapper, Lookup
|
||||
from django.db.models import Exists, ExpressionWrapper, Func, Lookup, UUIDField, Value
|
||||
from django.db.models.constants import OnConflict
|
||||
from django.db.models.fields.json import KeyTextTransform
|
||||
from django.db.models.functions import Cast
|
||||
from django.utils import timezone
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
@ -453,3 +455,21 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
update_fields,
|
||||
unique_fields,
|
||||
)
|
||||
|
||||
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
|
||||
# e.g. `a`.`id` = CAST((`b`.`object_id` ->> '$[0]') AS signed integer)
|
||||
json_array = rhs
|
||||
json_element = KeyTextTransform(index, json_array)
|
||||
|
||||
if isinstance(lhs.field, UUIDField):
|
||||
json_element = Func(
|
||||
json_element,
|
||||
Value("-"),
|
||||
Value(""),
|
||||
function="REPLACE",
|
||||
output_field=UUIDField(),
|
||||
)
|
||||
if json_element.field != lhs.field:
|
||||
json_element = Cast(json_element, lhs.field)
|
||||
|
||||
return lhs, json_element
|
||||
|
@ -6,8 +6,18 @@ from django.conf import settings
|
||||
from django.db import DatabaseError, NotSupportedError
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
|
||||
from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
|
||||
from django.db.models.expressions import RawSQL
|
||||
from django.db.models import (
|
||||
AutoField,
|
||||
DateTimeField,
|
||||
Exists,
|
||||
ExpressionWrapper,
|
||||
JSONField,
|
||||
Lookup,
|
||||
UUIDField,
|
||||
)
|
||||
from django.db.models.expressions import Func, RawSQL, Value
|
||||
from django.db.models.fields.json import KeyTextTransform
|
||||
from django.db.models.functions import Cast
|
||||
from django.db.models.sql.where import WhereNode
|
||||
from django.utils import timezone
|
||||
from django.utils.encoding import force_bytes, force_str
|
||||
@ -726,3 +736,27 @@ END;
|
||||
if isinstance(expression, RawSQL) and expression.conditional:
|
||||
return True
|
||||
return False
|
||||
|
||||
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
|
||||
json_array = Cast(rhs, JSONField())
|
||||
json_element = KeyTextTransform(index, json_array)
|
||||
|
||||
if isinstance(lhs.field, UUIDField):
|
||||
json_element = Func(
|
||||
json_element,
|
||||
Value("-"),
|
||||
Value(""),
|
||||
function="REPLACE",
|
||||
output_field=UUIDField(),
|
||||
)
|
||||
if isinstance(lhs.field, DateTimeField):
|
||||
json_element = Func(
|
||||
json_element,
|
||||
Value('YYYY-MM-DD"T"HH24:MI:SS'),
|
||||
function="TO_TIMESTAMP",
|
||||
output_field=DateTimeField(),
|
||||
)
|
||||
if json_element.field != lhs.field:
|
||||
json_element = Cast(json_element, lhs.field)
|
||||
|
||||
return lhs, json_element
|
||||
|
@ -12,6 +12,7 @@ from django.db.backends.postgresql.psycopg_any import (
|
||||
)
|
||||
from django.db.backends.utils import split_tzname_delta
|
||||
from django.db.models.constants import OnConflict
|
||||
from django.db.models.fields.json import JSONField, KeyTextTransform
|
||||
from django.db.models.functions import Cast
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
@ -413,3 +414,13 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
rhs_expr = Cast(rhs_expr, lhs_field)
|
||||
|
||||
return lhs_expr, rhs_expr
|
||||
|
||||
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
|
||||
# e.g. `a`.`id` = ((`b`.`object_id`)::jsonb ->> 0)::integer
|
||||
json_array = Cast(rhs, JSONField())
|
||||
json_element = KeyTextTransform(index, json_array)
|
||||
|
||||
if json_element.field != lhs.field:
|
||||
json_element = Cast(json_element, lhs.field)
|
||||
|
||||
return lhs, json_element
|
||||
|
@ -8,12 +8,15 @@ from django.conf import settings
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import DatabaseError, NotSupportedError, models
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.models import DateTimeField, JSONField, UUIDField
|
||||
from django.db.models.constants import OnConflict
|
||||
from django.db.models.expressions import Col
|
||||
from django.db.models.expressions import Col, Func, Value
|
||||
from django.db.models.functions import Cast
|
||||
from django.utils import timezone
|
||||
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from ...models.fields.json import KeyTextTransform
|
||||
from .base import Database
|
||||
|
||||
|
||||
@ -431,3 +434,24 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
|
||||
def force_group_by(self):
|
||||
return ["GROUP BY TRUE"] if Database.sqlite_version_info < (3, 39) else []
|
||||
|
||||
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
|
||||
json_array = Cast(rhs, JSONField())
|
||||
json_element = KeyTextTransform(index, json_array)
|
||||
|
||||
if isinstance(lhs.field, UUIDField):
|
||||
json_element = Func(
|
||||
json_element,
|
||||
Value("-"),
|
||||
Value(""),
|
||||
function="REPLACE",
|
||||
output_field=UUIDField(),
|
||||
)
|
||||
if json_element.field != lhs.field:
|
||||
json_element = Cast(json_element, lhs.field)
|
||||
if isinstance(lhs.field, DateTimeField):
|
||||
# If joining DateTimeField in SQLite,
|
||||
# make sure strftime is applied to the left side as well.
|
||||
lhs = Cast(lhs, DateTimeField())
|
||||
|
||||
return lhs, json_element
|
||||
|
@ -1,3 +1,5 @@
|
||||
import json
|
||||
|
||||
from django.core import checks
|
||||
from django.db.models import NOT_PROVIDED, Field
|
||||
from django.db.models.expressions import ColPairs
|
||||
@ -128,6 +130,40 @@ class CompositePrimaryKey(Field):
|
||||
)
|
||||
]
|
||||
|
||||
def to_json(self, pk):
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
|
||||
fields_len = len(self.fields)
|
||||
pk_len = len(pk)
|
||||
if fields_len != pk_len:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} has {fields_len} fields "
|
||||
f"but it tried to serialize {pk_len}."
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
[value for value in pk],
|
||||
cls=DjangoJSONEncoder,
|
||||
)
|
||||
|
||||
def from_json(self, s):
|
||||
values = json.loads(s)
|
||||
|
||||
fields_len = len(self.fields)
|
||||
values_len = len(values)
|
||||
if fields_len != values_len:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} has {fields_len} fields "
|
||||
f"but it tried to deserialize {values_len}. "
|
||||
f"Did you change the {self.__class__.__name__} fields "
|
||||
"and forgot to update the related GenericForeignKey "
|
||||
'"object_id" fields?'
|
||||
)
|
||||
|
||||
return tuple(
|
||||
field.to_python(value) for (field, value) in zip(self.fields, values)
|
||||
)
|
||||
|
||||
|
||||
CompositePrimaryKey.register_lookup(TupleExact)
|
||||
CompositePrimaryKey.register_lookup(TupleGreaterThan)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import itertools
|
||||
from collections.abc import Iterable
|
||||
|
||||
from django.core.exceptions import EmptyResultSet
|
||||
from django.db.models import Field
|
||||
@ -211,9 +212,16 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
|
||||
|
||||
|
||||
class TupleIn(TupleLookupMixin, In):
|
||||
def check_rhs_is_iterable(self):
|
||||
if not isinstance(self.rhs, Iterable):
|
||||
lhs_str = self.get_lhs_str()
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} lookup of {lhs_str} must be an iterable"
|
||||
)
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if self.rhs_is_direct_value():
|
||||
self.check_rhs_is_tuple_or_list()
|
||||
self.check_rhs_is_iterable()
|
||||
self.check_rhs_is_collection_of_tuples_or_lists()
|
||||
self.check_rhs_elements_length_equals_lhs_length()
|
||||
else:
|
||||
|
@ -6,6 +6,7 @@ the SQL domain.
|
||||
import warnings
|
||||
|
||||
from django.core.exceptions import FullResultSet
|
||||
from django.db import models
|
||||
from django.db.models.sql.constants import INNER, LOUTER
|
||||
from django.utils.deprecation import RemovedInDjango60Warning
|
||||
|
||||
@ -105,6 +106,39 @@ class Join:
|
||||
# the branch for strings.
|
||||
lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs))
|
||||
rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs))
|
||||
join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
|
||||
elif isinstance(lhs, (models.CharField, models.TextField)) and isinstance(
|
||||
rhs, models.CompositePrimaryKey
|
||||
):
|
||||
lhs_col = lhs.get_col(self.parent_alias)
|
||||
for index, field in enumerate(rhs):
|
||||
rhs_col = field.get_col(self.table_alias)
|
||||
lhs_expr, rhs_expr = (
|
||||
connection.ops.prepare_join_composite_pk_on_json_array(
|
||||
rhs_col, lhs_col, index
|
||||
)
|
||||
)
|
||||
lhs_sql, lhs_params = compiler.compile(lhs_expr)
|
||||
rhs_sql, rhs_params = compiler.compile(rhs_expr)
|
||||
join_conditions.append(f"{lhs_sql} = {rhs_sql}")
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
elif isinstance(lhs, models.CompositePrimaryKey) and isinstance(
|
||||
rhs, (models.CharField, models.TextField)
|
||||
):
|
||||
rhs_col = rhs.get_col(self.table_alias)
|
||||
for index, field in enumerate(lhs):
|
||||
lhs_col = field.get_col(self.parent_alias)
|
||||
lhs_expr, rhs_expr = (
|
||||
connection.ops.prepare_join_composite_pk_on_json_array(
|
||||
lhs_col, rhs_col, index
|
||||
)
|
||||
)
|
||||
lhs_sql, lhs_params = compiler.compile(lhs_expr)
|
||||
rhs_sql, rhs_params = compiler.compile(rhs_expr)
|
||||
join_conditions.append(f"{lhs_sql} = {rhs_sql}")
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
else:
|
||||
lhs, rhs = connection.ops.prepare_join_on_clause(
|
||||
self.parent_alias, lhs, self.table_alias, rhs
|
||||
@ -113,7 +147,7 @@ class Join:
|
||||
lhs_full_name = lhs_sql % lhs_params
|
||||
rhs_sql, rhs_params = compiler.compile(rhs)
|
||||
rhs_full_name = rhs_sql % rhs_params
|
||||
join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
|
||||
join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
|
||||
|
||||
# Add a single condition inside parentheses for whatever
|
||||
# get_extra_restriction() returns.
|
||||
|
@ -62,8 +62,7 @@ A composite primary key can also be filtered by a ``tuple``:
|
||||
1
|
||||
|
||||
We're still working on composite primary key support for
|
||||
:ref:`relational fields <cpk-and-relations>`, including
|
||||
:class:`.GenericForeignKey` fields, and the Django admin. Models with composite
|
||||
:ref:`relational fields <cpk-and-relations>`, and the Django admin. Models with composite
|
||||
primary keys cannot be registered in the Django admin at this time. You can
|
||||
expect to see this in future releases.
|
||||
|
||||
@ -96,9 +95,8 @@ operation.
|
||||
Composite primary keys and relations
|
||||
====================================
|
||||
|
||||
:ref:`Relationship fields <relationship-fields>`, including
|
||||
:ref:`generic relations <generic-relations>` do not support composite primary
|
||||
keys.
|
||||
:ref:`Relationship fields <relationship-fields>` do not support composite
|
||||
primary keys.
|
||||
|
||||
For example, given the ``OrderLineItem`` model, the following is not
|
||||
supported::
|
||||
@ -131,6 +129,80 @@ database.
|
||||
``ForeignObject`` is an internal API. This means it is not covered by our
|
||||
:ref:`deprecation policy <internal-release-deprecation-policy>`.
|
||||
|
||||
Generic foreign keys & relations
|
||||
================================
|
||||
|
||||
If a :class:`~django.contrib.contenttypes.fields.GenericForeignKey`\'s
|
||||
"object_id" field is a :class:`~django.db.models.CharField` or a
|
||||
:class:`~django.db.models.TextField`, Django will automatically serialize the
|
||||
related ``CompositePrimaryKey`` into a JSON array::
|
||||
|
||||
class Tag:
|
||||
content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
|
||||
object_id = models.CharField(max_length=255)
|
||||
content_object = GenericForeignKey("content_type", "object_id")
|
||||
|
||||
This means ``CompositePrimaryKey`` is backwards-compatible with any
|
||||
``GenericForeignKey`` that's using a ``CharField`` or ``TextField`` "object_id".
|
||||
|
||||
.. code-block:: pycon
|
||||
|
||||
>>> item.pk
|
||||
(2, "B142C")
|
||||
>>> tag = Tag(content_object=item)
|
||||
>>> tag.object_id
|
||||
'[2, "B142C"]'
|
||||
|
||||
.. warning:: pk changes are not supported
|
||||
|
||||
Django doesn't automatically update "object_id" if "content_object.pk"
|
||||
changes. If you change an object's ``pk`` in any way, make sure the related
|
||||
generic foreign keys are also updated.
|
||||
|
||||
:class:`~django.contrib.contenttypes.fields.GenericRelation`\s on models with
|
||||
a ``CompositePrimaryKey`` are also supported. When filtering on a reverse
|
||||
relationship, Django performs ``JOIN``\s using the appropriate backend-specific
|
||||
JSON functions::
|
||||
|
||||
class OrderLineItem(models.Model):
|
||||
...
|
||||
tags = GenericRelation("Tag", related_query_name="items")
|
||||
|
||||
This means ``CompositePrimaryKey``\s are backwards-compatible with
|
||||
``GenericRelation``\s the same way ``GenericForeignKey``\s are.
|
||||
|
||||
.. code-block:: pycon
|
||||
|
||||
>>> item.tags.all() # OK
|
||||
>>> Tag.objects.filter(items__quantity__gte=2) # OK
|
||||
|
||||
.. warning:: JOINs on JSON functions can be slow
|
||||
|
||||
While performing ``JOIN``\s on JSON functions is convenient, it can lead to
|
||||
performance issues. Exercise caution when filtering on the reverse
|
||||
relationship of a ``GenericRelation``.
|
||||
|
||||
.. warning:: MariaDB support
|
||||
|
||||
At the time of writing this, MariaDB's JSON_UNQUOTE function cannot process
|
||||
surrogate pairs (e.g. emojis). If you're planning to use JOINs on JSON
|
||||
functions on MariaDB, make sure the primary key doesn't contain any emojis.
|
||||
|
||||
.. admonition:: Supported fields
|
||||
|
||||
``JOIN``s on JSON functions may not work if the ``CompositePrimaryKey``
|
||||
has unsupported fields.
|
||||
|
||||
The supported fields are:
|
||||
|
||||
- ``SmallIntegerField``
|
||||
- ``IntegerField``
|
||||
- ``BigIntegerField``
|
||||
- ``CharField``
|
||||
- ``UUIDField``
|
||||
- ``DateField``
|
||||
- ``DateTimeField``
|
||||
|
||||
Composite primary keys and database functions
|
||||
=============================================
|
||||
|
||||
|
@ -1,8 +1,11 @@
|
||||
from .tenant import Comment, Post, Tenant, Token, User
|
||||
from .generic import Dummy
|
||||
from .tenant import Comment, Post, Tag, Tenant, Token, User
|
||||
|
||||
__all__ = [
|
||||
"Comment",
|
||||
"Dummy",
|
||||
"Post",
|
||||
"Tag",
|
||||
"Tenant",
|
||||
"Token",
|
||||
"User",
|
||||
|
24
tests/composite_pk/models/generic.py
Normal file
24
tests/composite_pk/models/generic.py
Normal file
@ -0,0 +1,24 @@
|
||||
from django.contrib.contenttypes.fields import GenericRelation
|
||||
from django.db import models
|
||||
|
||||
|
||||
class Dummy(models.Model):
|
||||
pk = models.CompositePrimaryKey(
|
||||
"small_integer",
|
||||
"integer",
|
||||
"big_integer",
|
||||
"datetime",
|
||||
"date",
|
||||
"uuid",
|
||||
"char",
|
||||
)
|
||||
|
||||
small_integer = models.SmallIntegerField()
|
||||
integer = models.IntegerField()
|
||||
big_integer = models.BigIntegerField()
|
||||
datetime = models.DateTimeField()
|
||||
date = models.DateField()
|
||||
uuid = models.UUIDField()
|
||||
char = models.CharField(max_length=5)
|
||||
|
||||
tags = GenericRelation("Tag", related_query_name="dummy")
|
@ -1,3 +1,5 @@
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import models
|
||||
|
||||
|
||||
@ -48,3 +50,13 @@ class Post(models.Model):
|
||||
pk = models.CompositePrimaryKey("tenant_id", "id")
|
||||
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
|
||||
id = models.UUIDField()
|
||||
tags = GenericRelation("Tag", related_query_name="post")
|
||||
|
||||
|
||||
class Tag(models.Model):
|
||||
name = models.CharField(max_length=10)
|
||||
content_type = models.ForeignKey(
|
||||
ContentType, on_delete=models.CASCADE, related_name="composite_pk_tags"
|
||||
)
|
||||
object_id = models.CharField(max_length=255)
|
||||
content_object = GenericForeignKey("content_type", "object_id")
|
||||
|
314
tests/composite_pk/test_generic.py
Normal file
314
tests/composite_pk/test_generic.py
Normal file
@ -0,0 +1,314 @@
|
||||
from datetime import date, datetime
|
||||
from unittest import skipIf
|
||||
from uuid import UUID
|
||||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.contrib.contenttypes.prefetch import GenericPrefetch
|
||||
from django.db import connection
|
||||
from django.db.models import Count
|
||||
from django.test import TestCase
|
||||
from django.test.utils import CaptureQueriesContext
|
||||
|
||||
from .models import Comment, Dummy, Post, Tag, Tenant, User
|
||||
|
||||
|
||||
class CompositePKGenericTests(TestCase):
|
||||
POST_1_ID = "e1516ac0-4469-4306-b0ac-2c435e677aa4"
|
||||
DUMMY_1_UUID = "81598d5a-fa03-4800-bae4-823ca12824d5"
|
||||
DUMMY_2_UUID = "ac567f4c-9b6e-4770-86a6-1abf071a2958"
|
||||
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
cls.dummy_1 = Dummy.objects.create(
|
||||
small_integer=32767,
|
||||
integer=2147483647,
|
||||
big_integer=9223372036854775807,
|
||||
datetime=datetime(2024, 11, 30, 6, 26, 1),
|
||||
date=date(2024, 11, 30),
|
||||
uuid=UUID(cls.DUMMY_1_UUID),
|
||||
char="薑戈",
|
||||
)
|
||||
cls.dummy_2 = Dummy.objects.create(
|
||||
small_integer=-32768,
|
||||
integer=-2147483648,
|
||||
big_integer=-9223372036854775808,
|
||||
datetime=datetime(2024, 12, 8, 1, 2, 3),
|
||||
date=date(2024, 12, 8),
|
||||
uuid=UUID(cls.DUMMY_2_UUID),
|
||||
char="😊",
|
||||
)
|
||||
cls.tenant_1 = Tenant.objects.create()
|
||||
cls.tenant_2 = Tenant.objects.create()
|
||||
cls.user_1 = User.objects.create(
|
||||
tenant=cls.tenant_1,
|
||||
id=1,
|
||||
email="user0001@example.com",
|
||||
)
|
||||
cls.user_2 = User.objects.create(
|
||||
tenant=cls.tenant_1,
|
||||
id=2,
|
||||
email="user0002@example.com",
|
||||
)
|
||||
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
|
||||
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
|
||||
cls.post_1 = Post.objects.create(tenant=cls.tenant_1, id=UUID(cls.POST_1_ID))
|
||||
cls.comment_1_tag = Tag.objects.create(
|
||||
name="comment_1", content_object=cls.comment_1
|
||||
)
|
||||
cls.post_1_tag = Tag.objects.create(name="post_1", content_object=cls.post_1)
|
||||
cls.dummy_1_tag = Tag.objects.create(name="dummy_1", content_object=cls.dummy_1)
|
||||
cls.dummy_2_tag = Tag.objects.create(name="dummy_2", content_object=cls.dummy_2)
|
||||
cls.comment_ct = ContentType.objects.get_for_model(Comment)
|
||||
cls.post_ct = ContentType.objects.get_for_model(Post)
|
||||
cls.dummy_ct = ContentType.objects.get_for_model(Dummy)
|
||||
cls.post_1_fk = f'[{cls.tenant_1.id}, "{cls.POST_1_ID}"]'
|
||||
cls.comment_1_fk = f"[{cls.tenant_1.id}, {cls.comment_1.id}]"
|
||||
cls.dummy_1_fk = (
|
||||
'[32767, 2147483647, 9223372036854775807, "2024-11-30T06:26:01", '
|
||||
f'"2024-11-30", "{cls.DUMMY_1_UUID}", "\\u8591\\u6208"]'
|
||||
)
|
||||
cls.dummy_2_fk = (
|
||||
'[-32768, -2147483648, -9223372036854775808, "2024-12-08T01:02:03", '
|
||||
f'"2024-12-08", "{cls.DUMMY_2_UUID}", "\\ud83d\\ude0a"]'
|
||||
)
|
||||
|
||||
def test_fields(self):
|
||||
self.assertEqual(self.comment_1_tag.content_type, self.comment_ct)
|
||||
self.assertEqual(self.comment_1_tag.object_id, self.comment_1_fk)
|
||||
self.assertEqual(self.comment_1_tag.content_object, self.comment_1)
|
||||
|
||||
self.assertEqual(self.post_1_tag.content_type, self.post_ct)
|
||||
self.assertEqual(self.post_1_tag.object_id, self.post_1_fk)
|
||||
self.assertEqual(self.post_1_tag.content_object, self.post_1)
|
||||
|
||||
self.assertEqual(self.dummy_1_tag.content_type, self.dummy_ct)
|
||||
self.assertEqual(self.dummy_1_tag.object_id, self.dummy_1_fk)
|
||||
self.assertEqual(self.dummy_1_tag.content_object, self.dummy_1)
|
||||
|
||||
self.assertEqual(self.dummy_2_tag.content_type, self.dummy_ct)
|
||||
self.assertEqual(self.dummy_2_tag.object_id, self.dummy_2_fk)
|
||||
self.assertEqual(self.dummy_2_tag.content_object, self.dummy_2)
|
||||
|
||||
self.assertSequenceEqual(self.post_1.tags.all(), (self.post_1_tag,))
|
||||
self.assertSequenceEqual(self.dummy_1.tags.all(), (self.dummy_1_tag,))
|
||||
self.assertSequenceEqual(self.dummy_2.tags.all(), (self.dummy_2_tag,))
|
||||
|
||||
def test_fields_before_save(self):
|
||||
comment = Comment(pk=(1, 2))
|
||||
tag = Tag(content_object=comment)
|
||||
self.assertEqual(tag.object_id, "[1, 2]")
|
||||
comment.id = 3
|
||||
self.assertEqual(tag.object_id, "[1, 2]")
|
||||
|
||||
def test_cascade_delete_if_generic_relation(self):
|
||||
Post.objects.get(pk=self.post_1.pk).delete()
|
||||
self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())
|
||||
|
||||
Dummy.objects.get(pk=self.dummy_1.pk).delete()
|
||||
self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists())
|
||||
|
||||
def test_no_cascade_delete_if_no_generic_relation(self):
|
||||
Comment.objects.get(pk=self.comment_1.pk).delete()
|
||||
comment_1_tag = Tag.objects.get(pk=self.comment_1_tag.pk)
|
||||
self.assertIsNone(comment_1_tag.content_object)
|
||||
|
||||
def test_tags_clear(self):
|
||||
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||
post_1.tags.clear()
|
||||
self.assertEqual(post_1.tags.count(), 0)
|
||||
self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())
|
||||
|
||||
dummy_1 = Dummy.objects.get(pk=self.dummy_1.pk)
|
||||
dummy_1.tags.clear()
|
||||
self.assertEqual(dummy_1.tags.count(), 0)
|
||||
self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists())
|
||||
|
||||
def test_tags_remove(self):
|
||||
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||
post_1.tags.remove(self.post_1_tag)
|
||||
self.assertEqual(post_1.tags.count(), 0)
|
||||
self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())
|
||||
|
||||
dummy_1 = Dummy.objects.get(pk=self.dummy_1.pk)
|
||||
dummy_1.tags.remove(self.dummy_1_tag)
|
||||
self.assertEqual(dummy_1.tags.count(), 0)
|
||||
self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists())
|
||||
|
||||
def test_tags_create(self):
|
||||
tag_count = Tag.objects.count()
|
||||
|
||||
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||
post_1.tags.create(name="foo")
|
||||
self.assertEqual(post_1.tags.count(), 2)
|
||||
self.assertEqual(Tag.objects.count(), tag_count + 1)
|
||||
|
||||
tag = Tag.objects.get(name="foo")
|
||||
self.assertEqual(tag.content_type, self.post_ct)
|
||||
self.assertEqual(tag.object_id, self.post_1_fk)
|
||||
self.assertEqual(tag.content_object, post_1)
|
||||
|
||||
def test_tags_add(self):
|
||||
tag_count = Tag.objects.count()
|
||||
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||
|
||||
tag_1 = Tag(name="foo")
|
||||
post_1.tags.add(tag_1, bulk=False)
|
||||
self.assertEqual(post_1.tags.count(), 2)
|
||||
self.assertEqual(Tag.objects.count(), tag_count + 1)
|
||||
|
||||
tag_1 = Tag.objects.get(name="foo")
|
||||
self.assertEqual(tag_1.content_type, self.post_ct)
|
||||
self.assertEqual(tag_1.object_id, self.post_1_fk)
|
||||
self.assertEqual(tag_1.content_object, post_1)
|
||||
|
||||
tag_2 = Tag.objects.create(name="bar", content_object=self.comment_2)
|
||||
post_1.tags.add(tag_2)
|
||||
self.assertEqual(post_1.tags.count(), 3)
|
||||
self.assertEqual(Tag.objects.count(), tag_count + 2)
|
||||
|
||||
tag_2 = Tag.objects.get(name="bar")
|
||||
self.assertEqual(tag_2.content_type, self.post_ct)
|
||||
self.assertEqual(tag_2.object_id, self.post_1_fk)
|
||||
self.assertEqual(tag_2.content_object, post_1)
|
||||
|
||||
def test_tags_set(self):
|
||||
tag_count = Tag.objects.count()
|
||||
comment_1_tag = Tag.objects.get(name=self.comment_1_tag.name)
|
||||
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||
post_1.tags.set([comment_1_tag])
|
||||
self.assertEqual(post_1.tags.count(), 1)
|
||||
self.assertEqual(Tag.objects.count(), tag_count - 1)
|
||||
self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())
|
||||
|
||||
def test_tags_get_or_create(self):
|
||||
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||
|
||||
tag_1, created = post_1.tags.get_or_create(name=self.post_1_tag.name)
|
||||
self.assertFalse(created)
|
||||
self.assertEqual(tag_1.pk, self.post_1_tag.pk)
|
||||
self.assertEqual(tag_1.content_type, self.post_ct)
|
||||
self.assertEqual(tag_1.object_id, self.post_1_fk)
|
||||
self.assertEqual(tag_1.content_object, post_1)
|
||||
|
||||
tag_2, created = post_1.tags.get_or_create(name="foo")
|
||||
self.assertTrue(created)
|
||||
self.assertEqual(tag_2.content_type, self.post_ct)
|
||||
self.assertEqual(tag_2.object_id, self.post_1_fk)
|
||||
self.assertEqual(tag_2.content_object, post_1)
|
||||
|
||||
def test_tags_update_or_create(self):
|
||||
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||
|
||||
tag_1, created = post_1.tags.update_or_create(
|
||||
name=self.post_1_tag.name, defaults={"name": "foo"}
|
||||
)
|
||||
self.assertFalse(created)
|
||||
self.assertEqual(tag_1.pk, self.post_1_tag.pk)
|
||||
self.assertEqual(tag_1.name, "foo")
|
||||
self.assertEqual(tag_1.content_type, self.post_ct)
|
||||
self.assertEqual(tag_1.object_id, self.post_1_fk)
|
||||
self.assertEqual(tag_1.content_object, post_1)
|
||||
|
||||
tag_2, created = post_1.tags.update_or_create(name="bar")
|
||||
self.assertTrue(created)
|
||||
self.assertEqual(tag_2.content_type, self.post_ct)
|
||||
self.assertEqual(tag_2.object_id, self.post_1_fk)
|
||||
self.assertEqual(tag_2.content_object, post_1)
|
||||
|
||||
def test_filter_by_related_query_name(self):
|
||||
self.assertSequenceEqual(
|
||||
Tag.objects.filter(post__id=self.post_1.id), (self.post_1_tag,)
|
||||
)
|
||||
self.assertSequenceEqual(
|
||||
Tag.objects.filter(dummy__big_integer=self.dummy_1.big_integer),
|
||||
(self.dummy_1_tag,),
|
||||
)
|
||||
|
||||
@skipIf(
|
||||
connection.vendor == "mysql" and connection.mysql_is_mariadb,
|
||||
"MariaDB's JSON_UNQUOTE doesn't support surrogate pairs "
|
||||
"(https://jira.mariadb.org/browse/MDEV-21124)",
|
||||
)
|
||||
def test_aggregate(self):
|
||||
with self.subTest("Post"):
|
||||
with CaptureQueriesContext(connection) as ctx:
|
||||
self.assertEqual(
|
||||
Post.objects.aggregate(Count("tags")),
|
||||
{"tags__count": 1},
|
||||
ctx[-1]["sql"],
|
||||
)
|
||||
with self.subTest("Dummy"):
|
||||
with CaptureQueriesContext(connection) as ctx:
|
||||
self.assertEqual(
|
||||
Dummy.objects.aggregate(Count("tags")),
|
||||
{"tags__count": 2},
|
||||
ctx[-1]["sql"],
|
||||
)
|
||||
|
||||
def test_generic_prefetch(self):
|
||||
tags = Tag.objects.prefetch_related(
|
||||
GenericPrefetch(
|
||||
"content_object",
|
||||
[
|
||||
Post.objects.all(),
|
||||
Comment.objects.all(),
|
||||
Dummy.objects.all(),
|
||||
],
|
||||
)
|
||||
).order_by("pk")
|
||||
|
||||
self.assertEqual(len(tags), 4)
|
||||
self.assertEqual(tags[0], self.comment_1_tag)
|
||||
self.assertEqual(tags[1], self.post_1_tag)
|
||||
self.assertEqual(tags[2], self.dummy_1_tag)
|
||||
|
||||
with self.assertNumQueries(0):
|
||||
self.assertEqual(tags[0].content_object, self.comment_1)
|
||||
with self.assertNumQueries(0):
|
||||
self.assertEqual(tags[1].content_object, self.post_1)
|
||||
with self.assertNumQueries(0):
|
||||
self.assertEqual(tags[2].content_object, self.dummy_1)
|
||||
with self.assertNumQueries(0):
|
||||
self.assertEqual(tags[3].content_object, self.dummy_2)
|
||||
|
||||
def test_to_json(self):
|
||||
field = Post._meta.pk
|
||||
self.assertEqual(
|
||||
field.to_json((1, "004bc28c-a085-44ca-a823-747dcf4ddfd3")),
|
||||
'[1, "004bc28c-a085-44ca-a823-747dcf4ddfd3"]',
|
||||
)
|
||||
self.assertEqual(
|
||||
field.to_json((2, UUID("a5a626e5-d34a-4197-94ea-83547029d6ed"))),
|
||||
'[2, "a5a626e5-d34a-4197-94ea-83547029d6ed"]',
|
||||
)
|
||||
|
||||
def test_to_json_length_mismatch(self):
|
||||
field = Post._meta.pk
|
||||
msg = "CompositePrimaryKey has 2 fields but it tried to serialize %s."
|
||||
with self.assertRaisesMessage(ValueError, msg % 1):
|
||||
self.assertIsNone(field.to_json((1,)))
|
||||
with self.assertRaisesMessage(ValueError, msg % 3):
|
||||
self.assertIsNone(field.to_json((1, 2, 3)))
|
||||
|
||||
def test_from_json(self):
|
||||
field = Post._meta.pk
|
||||
self.assertEqual(
|
||||
field.from_json('[1, "cb6b99bc-66b0-497f-bafe-193b90af6296"]'),
|
||||
(1, UUID("cb6b99bc-66b0-497f-bafe-193b90af6296")),
|
||||
)
|
||||
self.assertEqual(
|
||||
field.from_json('[2, "1d5c63dda5264a12a2ec929d51e04430"]'),
|
||||
(2, UUID("1d5c63dd-a526-4a12-a2ec-929d51e04430")),
|
||||
)
|
||||
|
||||
def test_from_json_length_mismatch(self):
|
||||
field = Post._meta.pk
|
||||
msg = (
|
||||
"CompositePrimaryKey has 2 fields but it tried to deserialize %s. "
|
||||
"Did you change the CompositePrimaryKey fields and forgot to "
|
||||
'update the related GenericForeignKey "object_id" fields?'
|
||||
)
|
||||
with self.assertRaisesMessage(ValueError, msg % 1):
|
||||
self.assertIsNone(field.from_json("[1]"))
|
||||
with self.assertRaisesMessage(ValueError, msg % 3):
|
||||
self.assertIsNone(field.from_json("[1, 2, 3]"))
|
@ -92,6 +92,7 @@ class CompositePKGetTests(TestCase):
|
||||
|
||||
def test_lookup_errors(self):
|
||||
m_tuple = "'%s' lookup of 'pk' must be a tuple or a list"
|
||||
m_iterable = "'%s' lookup of 'pk' must be an iterable"
|
||||
m_2_elements = "'%s' lookup of 'pk' must have 2 elements"
|
||||
m_tuple_collection = (
|
||||
"'in' lookup of 'pk' must be a collection of tuples or lists"
|
||||
@ -102,7 +103,7 @@ class CompositePKGetTests(TestCase):
|
||||
({"pk": (1, 2, 3)}, m_2_elements % "exact"),
|
||||
({"pk__exact": 1}, m_tuple % "exact"),
|
||||
({"pk__exact": (1, 2, 3)}, m_2_elements % "exact"),
|
||||
({"pk__in": 1}, m_tuple % "in"),
|
||||
({"pk__in": 1}, m_iterable % "in"),
|
||||
({"pk__in": (1, 2, 3)}, m_tuple_collection),
|
||||
({"pk__in": ((1, 2, 3),)}, m_2_elements_each),
|
||||
({"pk__gt": 1}, m_tuple % "gt"),
|
||||
|
@ -474,7 +474,6 @@ class TupleLookupsTests(TestCase):
|
||||
TupleGreaterThanOrEqual,
|
||||
TupleLessThan,
|
||||
TupleLessThanOrEqual,
|
||||
TupleIn,
|
||||
),
|
||||
(
|
||||
0,
|
||||
@ -496,6 +495,24 @@ class TupleLookupsTests(TestCase):
|
||||
):
|
||||
lookup_cls((F("customer_code"), F("company_code")), rhs)
|
||||
|
||||
def test_tuple_in_lookup_rhs_must_be_iterable(self):
|
||||
test_cases = (
|
||||
0,
|
||||
1,
|
||||
None,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
|
||||
for rhs in test_cases:
|
||||
with self.subTest(lookup_name="in", rhs=rhs):
|
||||
with self.assertRaisesMessage(
|
||||
ValueError,
|
||||
"'in' lookup of ('customer_code', 'company_code') "
|
||||
"must be an iterable",
|
||||
):
|
||||
TupleIn((F("customer_code"), F("company_code")), rhs)
|
||||
|
||||
def test_tuple_lookup_rhs_must_have_2_elements(self):
|
||||
test_cases = itertools.product(
|
||||
(
|
||||
|
Loading…
Reference in New Issue
Block a user