mirror of
https://github.com/django/django.git
synced 2025-06-11 06:29:13 +00:00
Fixed #35941 -- Added composite GenericForeignKey support.
This commit is contained in:
parent
1860a1afc9
commit
b641b6a90e
@ -25,6 +25,20 @@ from django.utils.deprecation import RemovedInDjango60Warning
|
|||||||
from django.utils.functional import cached_property
|
from django.utils.functional import cached_property
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_pk(obj):
|
||||||
|
opts = obj._meta
|
||||||
|
if opts.is_composite_pk:
|
||||||
|
return opts.pk.to_json(obj.pk)
|
||||||
|
return obj.pk
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_pk(pk, ct):
|
||||||
|
opts = ct.model_class()._meta
|
||||||
|
if opts.is_composite_pk:
|
||||||
|
return opts.pk.from_json(pk)
|
||||||
|
return pk
|
||||||
|
|
||||||
|
|
||||||
class GenericForeignKey(FieldCacheMixin, Field):
|
class GenericForeignKey(FieldCacheMixin, Field):
|
||||||
"""
|
"""
|
||||||
Provide a generic many-to-one relation through the ``content_type`` and
|
Provide a generic many-to-one relation through the ``content_type`` and
|
||||||
@ -195,7 +209,9 @@ class GenericForeignKey(FieldCacheMixin, Field):
|
|||||||
if ct_id is not None:
|
if ct_id is not None:
|
||||||
fk_val = getattr(instance, self.fk_field)
|
fk_val = getattr(instance, self.fk_field)
|
||||||
if fk_val is not None:
|
if fk_val is not None:
|
||||||
fk_dict[ct_id].add(fk_val)
|
ct = self.get_content_type(id=ct_id)
|
||||||
|
pk_val = deserialize_pk(fk_val, ct)
|
||||||
|
fk_dict[ct_id].add(pk_val)
|
||||||
instance_dict[ct_id] = instance
|
instance_dict[ct_id] = instance
|
||||||
|
|
||||||
ret_val = []
|
ret_val = []
|
||||||
@ -225,7 +241,7 @@ class GenericForeignKey(FieldCacheMixin, Field):
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
ret_val,
|
ret_val,
|
||||||
lambda obj: (obj.pk, obj.__class__),
|
lambda obj: (serialize_pk(obj), obj.__class__),
|
||||||
gfk_key,
|
gfk_key,
|
||||||
True,
|
True,
|
||||||
self.name,
|
self.name,
|
||||||
@ -242,15 +258,15 @@ class GenericForeignKey(FieldCacheMixin, Field):
|
|||||||
# use ContentType.objects.get_for_id(), which has a global cache.
|
# use ContentType.objects.get_for_id(), which has a global cache.
|
||||||
f = self.model._meta.get_field(self.ct_field)
|
f = self.model._meta.get_field(self.ct_field)
|
||||||
ct_id = getattr(instance, f.attname, None)
|
ct_id = getattr(instance, f.attname, None)
|
||||||
pk_val = getattr(instance, self.fk_field)
|
fk_val = getattr(instance, self.fk_field)
|
||||||
|
|
||||||
rel_obj = self.get_cached_value(instance, default=None)
|
rel_obj = self.get_cached_value(instance, default=None)
|
||||||
if rel_obj is None and self.is_cached(instance):
|
if rel_obj is None and self.is_cached(instance):
|
||||||
return rel_obj
|
return rel_obj
|
||||||
if rel_obj is not None:
|
if rel_obj is not None:
|
||||||
ct_match = (
|
ct = self.get_content_type(obj=rel_obj, using=instance._state.db)
|
||||||
ct_id == self.get_content_type(obj=rel_obj, using=instance._state.db).id
|
ct_match = ct_id == ct.id
|
||||||
)
|
pk_val = deserialize_pk(fk_val, ct)
|
||||||
pk_match = ct_match and rel_obj._meta.pk.to_python(pk_val) == rel_obj.pk
|
pk_match = ct_match and rel_obj._meta.pk.to_python(pk_val) == rel_obj.pk
|
||||||
if pk_match:
|
if pk_match:
|
||||||
return rel_obj
|
return rel_obj
|
||||||
@ -258,6 +274,7 @@ class GenericForeignKey(FieldCacheMixin, Field):
|
|||||||
rel_obj = None
|
rel_obj = None
|
||||||
if ct_id is not None:
|
if ct_id is not None:
|
||||||
ct = self.get_content_type(id=ct_id, using=instance._state.db)
|
ct = self.get_content_type(id=ct_id, using=instance._state.db)
|
||||||
|
pk_val = deserialize_pk(fk_val, ct)
|
||||||
try:
|
try:
|
||||||
rel_obj = ct.get_object_for_this_type(
|
rel_obj = ct.get_object_for_this_type(
|
||||||
using=instance._state.db, pk=pk_val
|
using=instance._state.db, pk=pk_val
|
||||||
@ -272,7 +289,7 @@ class GenericForeignKey(FieldCacheMixin, Field):
|
|||||||
fk = None
|
fk = None
|
||||||
if value is not None:
|
if value is not None:
|
||||||
ct = self.get_content_type(obj=value)
|
ct = self.get_content_type(obj=value)
|
||||||
fk = value.pk
|
fk = serialize_pk(value)
|
||||||
|
|
||||||
setattr(instance, self.ct_field, ct)
|
setattr(instance, self.ct_field, ct)
|
||||||
setattr(instance, self.fk_field, fk)
|
setattr(instance, self.fk_field, fk)
|
||||||
@ -541,7 +558,8 @@ class GenericRelation(ForeignObject):
|
|||||||
% self.content_type_field_name: ContentType.objects.db_manager(using)
|
% self.content_type_field_name: ContentType.objects.db_manager(using)
|
||||||
.get_for_model(self.model, for_concrete_model=self.for_concrete_model)
|
.get_for_model(self.model, for_concrete_model=self.for_concrete_model)
|
||||||
.pk,
|
.pk,
|
||||||
"%s__in" % self.object_id_field_name: [obj.pk for obj in objs],
|
"%s__in"
|
||||||
|
% self.object_id_field_name: [serialize_pk(obj) for obj in objs],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -589,7 +607,7 @@ def create_generic_related_manager(superclass, rel):
|
|||||||
self.content_type_field_name = rel.field.content_type_field_name
|
self.content_type_field_name = rel.field.content_type_field_name
|
||||||
self.object_id_field_name = rel.field.object_id_field_name
|
self.object_id_field_name = rel.field.object_id_field_name
|
||||||
self.prefetch_cache_name = rel.field.attname
|
self.prefetch_cache_name = rel.field.attname
|
||||||
self.pk_val = instance.pk
|
self.pk_val = serialize_pk(instance)
|
||||||
|
|
||||||
self.core_filters = {
|
self.core_filters = {
|
||||||
"%s__pk" % self.content_type_field_name: self.content_type.id,
|
"%s__pk" % self.content_type_field_name: self.content_type.id,
|
||||||
|
@ -804,3 +804,9 @@ class BaseDatabaseOperations:
|
|||||||
rhs_expr = Col(rhs_table, rhs_field)
|
rhs_expr = Col(rhs_table, rhs_field)
|
||||||
|
|
||||||
return lhs_expr, rhs_expr
|
return lhs_expr, rhs_expr
|
||||||
|
|
||||||
|
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"subclasses of BaseDatabaseOperations may require a "
|
||||||
|
"prepare_join_composite_pk_on_json_array() method"
|
||||||
|
)
|
||||||
|
@ -3,8 +3,10 @@ import uuid
|
|||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||||
from django.db.backends.utils import split_tzname_delta
|
from django.db.backends.utils import split_tzname_delta
|
||||||
from django.db.models import Exists, ExpressionWrapper, Lookup
|
from django.db.models import Exists, ExpressionWrapper, Func, Lookup, UUIDField, Value
|
||||||
from django.db.models.constants import OnConflict
|
from django.db.models.constants import OnConflict
|
||||||
|
from django.db.models.fields.json import KeyTextTransform
|
||||||
|
from django.db.models.functions import Cast
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.encoding import force_str
|
from django.utils.encoding import force_str
|
||||||
from django.utils.regex_helper import _lazy_re_compile
|
from django.utils.regex_helper import _lazy_re_compile
|
||||||
@ -453,3 +455,21 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||||||
update_fields,
|
update_fields,
|
||||||
unique_fields,
|
unique_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
|
||||||
|
# e.g. `a`.`id` = CAST((`b`.`object_id` ->> '$[0]') AS signed integer)
|
||||||
|
json_array = rhs
|
||||||
|
json_element = KeyTextTransform(index, json_array)
|
||||||
|
|
||||||
|
if isinstance(lhs.field, UUIDField):
|
||||||
|
json_element = Func(
|
||||||
|
json_element,
|
||||||
|
Value("-"),
|
||||||
|
Value(""),
|
||||||
|
function="REPLACE",
|
||||||
|
output_field=UUIDField(),
|
||||||
|
)
|
||||||
|
if json_element.field != lhs.field:
|
||||||
|
json_element = Cast(json_element, lhs.field)
|
||||||
|
|
||||||
|
return lhs, json_element
|
||||||
|
@ -6,8 +6,18 @@ from django.conf import settings
|
|||||||
from django.db import DatabaseError, NotSupportedError
|
from django.db import DatabaseError, NotSupportedError
|
||||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||||
from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
|
from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
|
||||||
from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
|
from django.db.models import (
|
||||||
from django.db.models.expressions import RawSQL
|
AutoField,
|
||||||
|
DateTimeField,
|
||||||
|
Exists,
|
||||||
|
ExpressionWrapper,
|
||||||
|
JSONField,
|
||||||
|
Lookup,
|
||||||
|
UUIDField,
|
||||||
|
)
|
||||||
|
from django.db.models.expressions import Func, RawSQL, Value
|
||||||
|
from django.db.models.fields.json import KeyTextTransform
|
||||||
|
from django.db.models.functions import Cast
|
||||||
from django.db.models.sql.where import WhereNode
|
from django.db.models.sql.where import WhereNode
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.encoding import force_bytes, force_str
|
from django.utils.encoding import force_bytes, force_str
|
||||||
@ -726,3 +736,27 @@ END;
|
|||||||
if isinstance(expression, RawSQL) and expression.conditional:
|
if isinstance(expression, RawSQL) and expression.conditional:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
|
||||||
|
json_array = Cast(rhs, JSONField())
|
||||||
|
json_element = KeyTextTransform(index, json_array)
|
||||||
|
|
||||||
|
if isinstance(lhs.field, UUIDField):
|
||||||
|
json_element = Func(
|
||||||
|
json_element,
|
||||||
|
Value("-"),
|
||||||
|
Value(""),
|
||||||
|
function="REPLACE",
|
||||||
|
output_field=UUIDField(),
|
||||||
|
)
|
||||||
|
if isinstance(lhs.field, DateTimeField):
|
||||||
|
json_element = Func(
|
||||||
|
json_element,
|
||||||
|
Value('YYYY-MM-DD"T"HH24:MI:SS'),
|
||||||
|
function="TO_TIMESTAMP",
|
||||||
|
output_field=DateTimeField(),
|
||||||
|
)
|
||||||
|
if json_element.field != lhs.field:
|
||||||
|
json_element = Cast(json_element, lhs.field)
|
||||||
|
|
||||||
|
return lhs, json_element
|
||||||
|
@ -12,6 +12,7 @@ from django.db.backends.postgresql.psycopg_any import (
|
|||||||
)
|
)
|
||||||
from django.db.backends.utils import split_tzname_delta
|
from django.db.backends.utils import split_tzname_delta
|
||||||
from django.db.models.constants import OnConflict
|
from django.db.models.constants import OnConflict
|
||||||
|
from django.db.models.fields.json import JSONField, KeyTextTransform
|
||||||
from django.db.models.functions import Cast
|
from django.db.models.functions import Cast
|
||||||
from django.utils.regex_helper import _lazy_re_compile
|
from django.utils.regex_helper import _lazy_re_compile
|
||||||
|
|
||||||
@ -413,3 +414,13 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||||||
rhs_expr = Cast(rhs_expr, lhs_field)
|
rhs_expr = Cast(rhs_expr, lhs_field)
|
||||||
|
|
||||||
return lhs_expr, rhs_expr
|
return lhs_expr, rhs_expr
|
||||||
|
|
||||||
|
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
|
||||||
|
# e.g. `a`.`id` = ((`b`.`object_id`)::jsonb ->> 0)::integer
|
||||||
|
json_array = Cast(rhs, JSONField())
|
||||||
|
json_element = KeyTextTransform(index, json_array)
|
||||||
|
|
||||||
|
if json_element.field != lhs.field:
|
||||||
|
json_element = Cast(json_element, lhs.field)
|
||||||
|
|
||||||
|
return lhs, json_element
|
||||||
|
@ -8,12 +8,15 @@ from django.conf import settings
|
|||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db import DatabaseError, NotSupportedError, models
|
from django.db import DatabaseError, NotSupportedError, models
|
||||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||||
|
from django.db.models import DateTimeField, JSONField, UUIDField
|
||||||
from django.db.models.constants import OnConflict
|
from django.db.models.constants import OnConflict
|
||||||
from django.db.models.expressions import Col
|
from django.db.models.expressions import Col, Func, Value
|
||||||
|
from django.db.models.functions import Cast
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
||||||
from django.utils.functional import cached_property
|
from django.utils.functional import cached_property
|
||||||
|
|
||||||
|
from ...models.fields.json import KeyTextTransform
|
||||||
from .base import Database
|
from .base import Database
|
||||||
|
|
||||||
|
|
||||||
@ -431,3 +434,24 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||||||
|
|
||||||
def force_group_by(self):
|
def force_group_by(self):
|
||||||
return ["GROUP BY TRUE"] if Database.sqlite_version_info < (3, 39) else []
|
return ["GROUP BY TRUE"] if Database.sqlite_version_info < (3, 39) else []
|
||||||
|
|
||||||
|
def prepare_join_composite_pk_on_json_array(self, lhs, rhs, index):
|
||||||
|
json_array = Cast(rhs, JSONField())
|
||||||
|
json_element = KeyTextTransform(index, json_array)
|
||||||
|
|
||||||
|
if isinstance(lhs.field, UUIDField):
|
||||||
|
json_element = Func(
|
||||||
|
json_element,
|
||||||
|
Value("-"),
|
||||||
|
Value(""),
|
||||||
|
function="REPLACE",
|
||||||
|
output_field=UUIDField(),
|
||||||
|
)
|
||||||
|
if json_element.field != lhs.field:
|
||||||
|
json_element = Cast(json_element, lhs.field)
|
||||||
|
if isinstance(lhs.field, DateTimeField):
|
||||||
|
# If joining DateTimeField in SQLite,
|
||||||
|
# make sure strftime is applied to the left side as well.
|
||||||
|
lhs = Cast(lhs, DateTimeField())
|
||||||
|
|
||||||
|
return lhs, json_element
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
from django.core import checks
|
from django.core import checks
|
||||||
from django.db.models import NOT_PROVIDED, Field
|
from django.db.models import NOT_PROVIDED, Field
|
||||||
from django.db.models.expressions import ColPairs
|
from django.db.models.expressions import ColPairs
|
||||||
@ -128,6 +130,40 @@ class CompositePrimaryKey(Field):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def to_json(self, pk):
|
||||||
|
from django.core.serializers.json import DjangoJSONEncoder
|
||||||
|
|
||||||
|
fields_len = len(self.fields)
|
||||||
|
pk_len = len(pk)
|
||||||
|
if fields_len != pk_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.__class__.__name__} has {fields_len} fields "
|
||||||
|
f"but it tried to serialize {pk_len}."
|
||||||
|
)
|
||||||
|
|
||||||
|
return json.dumps(
|
||||||
|
[value for value in pk],
|
||||||
|
cls=DjangoJSONEncoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_json(self, s):
|
||||||
|
values = json.loads(s)
|
||||||
|
|
||||||
|
fields_len = len(self.fields)
|
||||||
|
values_len = len(values)
|
||||||
|
if fields_len != values_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.__class__.__name__} has {fields_len} fields "
|
||||||
|
f"but it tried to deserialize {values_len}. "
|
||||||
|
f"Did you change the {self.__class__.__name__} fields "
|
||||||
|
"and forgot to update the related GenericForeignKey "
|
||||||
|
'"object_id" fields?'
|
||||||
|
)
|
||||||
|
|
||||||
|
return tuple(
|
||||||
|
field.to_python(value) for (field, value) in zip(self.fields, values)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
CompositePrimaryKey.register_lookup(TupleExact)
|
CompositePrimaryKey.register_lookup(TupleExact)
|
||||||
CompositePrimaryKey.register_lookup(TupleGreaterThan)
|
CompositePrimaryKey.register_lookup(TupleGreaterThan)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import itertools
|
import itertools
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
from django.core.exceptions import EmptyResultSet
|
from django.core.exceptions import EmptyResultSet
|
||||||
from django.db.models import Field
|
from django.db.models import Field
|
||||||
@ -211,9 +212,16 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
|
|||||||
|
|
||||||
|
|
||||||
class TupleIn(TupleLookupMixin, In):
|
class TupleIn(TupleLookupMixin, In):
|
||||||
|
def check_rhs_is_iterable(self):
|
||||||
|
if not isinstance(self.rhs, Iterable):
|
||||||
|
lhs_str = self.get_lhs_str()
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.lookup_name!r} lookup of {lhs_str} must be an iterable"
|
||||||
|
)
|
||||||
|
|
||||||
def get_prep_lookup(self):
|
def get_prep_lookup(self):
|
||||||
if self.rhs_is_direct_value():
|
if self.rhs_is_direct_value():
|
||||||
self.check_rhs_is_tuple_or_list()
|
self.check_rhs_is_iterable()
|
||||||
self.check_rhs_is_collection_of_tuples_or_lists()
|
self.check_rhs_is_collection_of_tuples_or_lists()
|
||||||
self.check_rhs_elements_length_equals_lhs_length()
|
self.check_rhs_elements_length_equals_lhs_length()
|
||||||
else:
|
else:
|
||||||
|
@ -6,6 +6,7 @@ the SQL domain.
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from django.core.exceptions import FullResultSet
|
from django.core.exceptions import FullResultSet
|
||||||
|
from django.db import models
|
||||||
from django.db.models.sql.constants import INNER, LOUTER
|
from django.db.models.sql.constants import INNER, LOUTER
|
||||||
from django.utils.deprecation import RemovedInDjango60Warning
|
from django.utils.deprecation import RemovedInDjango60Warning
|
||||||
|
|
||||||
@ -105,6 +106,39 @@ class Join:
|
|||||||
# the branch for strings.
|
# the branch for strings.
|
||||||
lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs))
|
lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs))
|
||||||
rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs))
|
rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs))
|
||||||
|
join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
|
||||||
|
elif isinstance(lhs, (models.CharField, models.TextField)) and isinstance(
|
||||||
|
rhs, models.CompositePrimaryKey
|
||||||
|
):
|
||||||
|
lhs_col = lhs.get_col(self.parent_alias)
|
||||||
|
for index, field in enumerate(rhs):
|
||||||
|
rhs_col = field.get_col(self.table_alias)
|
||||||
|
lhs_expr, rhs_expr = (
|
||||||
|
connection.ops.prepare_join_composite_pk_on_json_array(
|
||||||
|
rhs_col, lhs_col, index
|
||||||
|
)
|
||||||
|
)
|
||||||
|
lhs_sql, lhs_params = compiler.compile(lhs_expr)
|
||||||
|
rhs_sql, rhs_params = compiler.compile(rhs_expr)
|
||||||
|
join_conditions.append(f"{lhs_sql} = {rhs_sql}")
|
||||||
|
params.extend(lhs_params)
|
||||||
|
params.extend(rhs_params)
|
||||||
|
elif isinstance(lhs, models.CompositePrimaryKey) and isinstance(
|
||||||
|
rhs, (models.CharField, models.TextField)
|
||||||
|
):
|
||||||
|
rhs_col = rhs.get_col(self.table_alias)
|
||||||
|
for index, field in enumerate(lhs):
|
||||||
|
lhs_col = field.get_col(self.parent_alias)
|
||||||
|
lhs_expr, rhs_expr = (
|
||||||
|
connection.ops.prepare_join_composite_pk_on_json_array(
|
||||||
|
lhs_col, rhs_col, index
|
||||||
|
)
|
||||||
|
)
|
||||||
|
lhs_sql, lhs_params = compiler.compile(lhs_expr)
|
||||||
|
rhs_sql, rhs_params = compiler.compile(rhs_expr)
|
||||||
|
join_conditions.append(f"{lhs_sql} = {rhs_sql}")
|
||||||
|
params.extend(lhs_params)
|
||||||
|
params.extend(rhs_params)
|
||||||
else:
|
else:
|
||||||
lhs, rhs = connection.ops.prepare_join_on_clause(
|
lhs, rhs = connection.ops.prepare_join_on_clause(
|
||||||
self.parent_alias, lhs, self.table_alias, rhs
|
self.parent_alias, lhs, self.table_alias, rhs
|
||||||
@ -113,7 +147,7 @@ class Join:
|
|||||||
lhs_full_name = lhs_sql % lhs_params
|
lhs_full_name = lhs_sql % lhs_params
|
||||||
rhs_sql, rhs_params = compiler.compile(rhs)
|
rhs_sql, rhs_params = compiler.compile(rhs)
|
||||||
rhs_full_name = rhs_sql % rhs_params
|
rhs_full_name = rhs_sql % rhs_params
|
||||||
join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
|
join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
|
||||||
|
|
||||||
# Add a single condition inside parentheses for whatever
|
# Add a single condition inside parentheses for whatever
|
||||||
# get_extra_restriction() returns.
|
# get_extra_restriction() returns.
|
||||||
|
@ -62,8 +62,7 @@ A composite primary key can also be filtered by a ``tuple``:
|
|||||||
1
|
1
|
||||||
|
|
||||||
We're still working on composite primary key support for
|
We're still working on composite primary key support for
|
||||||
:ref:`relational fields <cpk-and-relations>`, including
|
:ref:`relational fields <cpk-and-relations>`, and the Django admin. Models with composite
|
||||||
:class:`.GenericForeignKey` fields, and the Django admin. Models with composite
|
|
||||||
primary keys cannot be registered in the Django admin at this time. You can
|
primary keys cannot be registered in the Django admin at this time. You can
|
||||||
expect to see this in future releases.
|
expect to see this in future releases.
|
||||||
|
|
||||||
@ -96,9 +95,8 @@ operation.
|
|||||||
Composite primary keys and relations
|
Composite primary keys and relations
|
||||||
====================================
|
====================================
|
||||||
|
|
||||||
:ref:`Relationship fields <relationship-fields>`, including
|
:ref:`Relationship fields <relationship-fields>` do not support composite
|
||||||
:ref:`generic relations <generic-relations>` do not support composite primary
|
primary keys.
|
||||||
keys.
|
|
||||||
|
|
||||||
For example, given the ``OrderLineItem`` model, the following is not
|
For example, given the ``OrderLineItem`` model, the following is not
|
||||||
supported::
|
supported::
|
||||||
@ -131,6 +129,80 @@ database.
|
|||||||
``ForeignObject`` is an internal API. This means it is not covered by our
|
``ForeignObject`` is an internal API. This means it is not covered by our
|
||||||
:ref:`deprecation policy <internal-release-deprecation-policy>`.
|
:ref:`deprecation policy <internal-release-deprecation-policy>`.
|
||||||
|
|
||||||
|
Generic foreign keys & relations
|
||||||
|
================================
|
||||||
|
|
||||||
|
If a :class:`~django.contrib.contenttypes.fields.GenericForeignKey`\'s
|
||||||
|
"object_id" field is a :class:`~django.db.models.CharField` or a
|
||||||
|
:class:`~django.db.models.TextField`, Django will automatically serialize the
|
||||||
|
related ``CompositePrimaryKey`` into a JSON array::
|
||||||
|
|
||||||
|
class Tag:
|
||||||
|
content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
|
||||||
|
object_id = models.CharField(max_length=255)
|
||||||
|
content_object = GenericForeignKey("content_type", "object_id")
|
||||||
|
|
||||||
|
This means ``CompositePrimaryKey`` is backwards-compatible with any
|
||||||
|
``GenericForeignKey`` that's using a ``CharField`` or ``TextField`` "object_id".
|
||||||
|
|
||||||
|
.. code-block:: pycon
|
||||||
|
|
||||||
|
>>> item.pk
|
||||||
|
(2, "B142C")
|
||||||
|
>>> tag = Tag(content_object=item)
|
||||||
|
>>> tag.object_id
|
||||||
|
'[2, "B142C"]'
|
||||||
|
|
||||||
|
.. warning:: pk changes are not supported
|
||||||
|
|
||||||
|
Django doesn't automatically update "object_id" if "content_object.pk"
|
||||||
|
changes. If you change an object's ``pk`` in any way, make sure the related
|
||||||
|
generic foreign keys are also updated.
|
||||||
|
|
||||||
|
:class:`~django.contrib.contenttypes.fields.GenericRelation`\s on models with
|
||||||
|
a ``CompositePrimaryKey`` are also supported. When filtering on a reverse
|
||||||
|
relationship, Django performs ``JOIN``\s using the appropriate backend-specific
|
||||||
|
JSON functions::
|
||||||
|
|
||||||
|
class OrderLineItem(models.Model):
|
||||||
|
...
|
||||||
|
tags = GenericRelation("Tag", related_query_name="items")
|
||||||
|
|
||||||
|
This means ``CompositePrimaryKey``\s are backwards-compatible with
|
||||||
|
``GenericRelation``\s the same way ``GenericForeignKey``\s are.
|
||||||
|
|
||||||
|
.. code-block:: pycon
|
||||||
|
|
||||||
|
>>> item.tags.all() # OK
|
||||||
|
>>> Tag.objects.filter(items__quantity__gte=2) # OK
|
||||||
|
|
||||||
|
.. warning:: JOINs on JSON functions can be slow
|
||||||
|
|
||||||
|
While performing ``JOIN``\s on JSON functions is convenient, it can lead to
|
||||||
|
performance issues. Exercise caution when filtering on the reverse
|
||||||
|
relationship of a ``GenericRelation``.
|
||||||
|
|
||||||
|
.. warning:: MariaDB support
|
||||||
|
|
||||||
|
At the time of writing this, MariaDB's JSON_UNQUOTE function cannot process
|
||||||
|
surrogate pairs (e.g. emojis). If you're planning to use JOINs on JSON
|
||||||
|
functions on MariaDB, make sure the primary key doesn't contain any emojis.
|
||||||
|
|
||||||
|
.. admonition:: Supported fields
|
||||||
|
|
||||||
|
``JOIN``s on JSON functions may not work if the ``CompositePrimaryKey``
|
||||||
|
has unsupported fields.
|
||||||
|
|
||||||
|
The supported fields are:
|
||||||
|
|
||||||
|
- ``SmallIntegerField``
|
||||||
|
- ``IntegerField``
|
||||||
|
- ``BigIntegerField``
|
||||||
|
- ``CharField``
|
||||||
|
- ``UUIDField``
|
||||||
|
- ``DateField``
|
||||||
|
- ``DateTimeField``
|
||||||
|
|
||||||
Composite primary keys and database functions
|
Composite primary keys and database functions
|
||||||
=============================================
|
=============================================
|
||||||
|
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
from .tenant import Comment, Post, Tenant, Token, User
|
from .generic import Dummy
|
||||||
|
from .tenant import Comment, Post, Tag, Tenant, Token, User
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Comment",
|
"Comment",
|
||||||
|
"Dummy",
|
||||||
"Post",
|
"Post",
|
||||||
|
"Tag",
|
||||||
"Tenant",
|
"Tenant",
|
||||||
"Token",
|
"Token",
|
||||||
"User",
|
"User",
|
||||||
|
24
tests/composite_pk/models/generic.py
Normal file
24
tests/composite_pk/models/generic.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from django.contrib.contenttypes.fields import GenericRelation
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
|
||||||
|
class Dummy(models.Model):
|
||||||
|
pk = models.CompositePrimaryKey(
|
||||||
|
"small_integer",
|
||||||
|
"integer",
|
||||||
|
"big_integer",
|
||||||
|
"datetime",
|
||||||
|
"date",
|
||||||
|
"uuid",
|
||||||
|
"char",
|
||||||
|
)
|
||||||
|
|
||||||
|
small_integer = models.SmallIntegerField()
|
||||||
|
integer = models.IntegerField()
|
||||||
|
big_integer = models.BigIntegerField()
|
||||||
|
datetime = models.DateTimeField()
|
||||||
|
date = models.DateField()
|
||||||
|
uuid = models.UUIDField()
|
||||||
|
char = models.CharField(max_length=5)
|
||||||
|
|
||||||
|
tags = GenericRelation("Tag", related_query_name="dummy")
|
@ -1,3 +1,5 @@
|
|||||||
|
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
|
||||||
|
from django.contrib.contenttypes.models import ContentType
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
|
|
||||||
@ -48,3 +50,13 @@ class Post(models.Model):
|
|||||||
pk = models.CompositePrimaryKey("tenant_id", "id")
|
pk = models.CompositePrimaryKey("tenant_id", "id")
|
||||||
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
|
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
|
||||||
id = models.UUIDField()
|
id = models.UUIDField()
|
||||||
|
tags = GenericRelation("Tag", related_query_name="post")
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(models.Model):
|
||||||
|
name = models.CharField(max_length=10)
|
||||||
|
content_type = models.ForeignKey(
|
||||||
|
ContentType, on_delete=models.CASCADE, related_name="composite_pk_tags"
|
||||||
|
)
|
||||||
|
object_id = models.CharField(max_length=255)
|
||||||
|
content_object = GenericForeignKey("content_type", "object_id")
|
||||||
|
314
tests/composite_pk/test_generic.py
Normal file
314
tests/composite_pk/test_generic.py
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
from datetime import date, datetime
|
||||||
|
from unittest import skipIf
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from django.contrib.contenttypes.models import ContentType
|
||||||
|
from django.contrib.contenttypes.prefetch import GenericPrefetch
|
||||||
|
from django.db import connection
|
||||||
|
from django.db.models import Count
|
||||||
|
from django.test import TestCase
|
||||||
|
from django.test.utils import CaptureQueriesContext
|
||||||
|
|
||||||
|
from .models import Comment, Dummy, Post, Tag, Tenant, User
|
||||||
|
|
||||||
|
|
||||||
|
class CompositePKGenericTests(TestCase):
|
||||||
|
POST_1_ID = "e1516ac0-4469-4306-b0ac-2c435e677aa4"
|
||||||
|
DUMMY_1_UUID = "81598d5a-fa03-4800-bae4-823ca12824d5"
|
||||||
|
DUMMY_2_UUID = "ac567f4c-9b6e-4770-86a6-1abf071a2958"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpTestData(cls):
|
||||||
|
cls.dummy_1 = Dummy.objects.create(
|
||||||
|
small_integer=32767,
|
||||||
|
integer=2147483647,
|
||||||
|
big_integer=9223372036854775807,
|
||||||
|
datetime=datetime(2024, 11, 30, 6, 26, 1),
|
||||||
|
date=date(2024, 11, 30),
|
||||||
|
uuid=UUID(cls.DUMMY_1_UUID),
|
||||||
|
char="薑戈",
|
||||||
|
)
|
||||||
|
cls.dummy_2 = Dummy.objects.create(
|
||||||
|
small_integer=-32768,
|
||||||
|
integer=-2147483648,
|
||||||
|
big_integer=-9223372036854775808,
|
||||||
|
datetime=datetime(2024, 12, 8, 1, 2, 3),
|
||||||
|
date=date(2024, 12, 8),
|
||||||
|
uuid=UUID(cls.DUMMY_2_UUID),
|
||||||
|
char="😊",
|
||||||
|
)
|
||||||
|
cls.tenant_1 = Tenant.objects.create()
|
||||||
|
cls.tenant_2 = Tenant.objects.create()
|
||||||
|
cls.user_1 = User.objects.create(
|
||||||
|
tenant=cls.tenant_1,
|
||||||
|
id=1,
|
||||||
|
email="user0001@example.com",
|
||||||
|
)
|
||||||
|
cls.user_2 = User.objects.create(
|
||||||
|
tenant=cls.tenant_1,
|
||||||
|
id=2,
|
||||||
|
email="user0002@example.com",
|
||||||
|
)
|
||||||
|
cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
|
||||||
|
cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
|
||||||
|
cls.post_1 = Post.objects.create(tenant=cls.tenant_1, id=UUID(cls.POST_1_ID))
|
||||||
|
cls.comment_1_tag = Tag.objects.create(
|
||||||
|
name="comment_1", content_object=cls.comment_1
|
||||||
|
)
|
||||||
|
cls.post_1_tag = Tag.objects.create(name="post_1", content_object=cls.post_1)
|
||||||
|
cls.dummy_1_tag = Tag.objects.create(name="dummy_1", content_object=cls.dummy_1)
|
||||||
|
cls.dummy_2_tag = Tag.objects.create(name="dummy_2", content_object=cls.dummy_2)
|
||||||
|
cls.comment_ct = ContentType.objects.get_for_model(Comment)
|
||||||
|
cls.post_ct = ContentType.objects.get_for_model(Post)
|
||||||
|
cls.dummy_ct = ContentType.objects.get_for_model(Dummy)
|
||||||
|
cls.post_1_fk = f'[{cls.tenant_1.id}, "{cls.POST_1_ID}"]'
|
||||||
|
cls.comment_1_fk = f"[{cls.tenant_1.id}, {cls.comment_1.id}]"
|
||||||
|
cls.dummy_1_fk = (
|
||||||
|
'[32767, 2147483647, 9223372036854775807, "2024-11-30T06:26:01", '
|
||||||
|
f'"2024-11-30", "{cls.DUMMY_1_UUID}", "\\u8591\\u6208"]'
|
||||||
|
)
|
||||||
|
cls.dummy_2_fk = (
|
||||||
|
'[-32768, -2147483648, -9223372036854775808, "2024-12-08T01:02:03", '
|
||||||
|
f'"2024-12-08", "{cls.DUMMY_2_UUID}", "\\ud83d\\ude0a"]'
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_fields(self):
|
||||||
|
self.assertEqual(self.comment_1_tag.content_type, self.comment_ct)
|
||||||
|
self.assertEqual(self.comment_1_tag.object_id, self.comment_1_fk)
|
||||||
|
self.assertEqual(self.comment_1_tag.content_object, self.comment_1)
|
||||||
|
|
||||||
|
self.assertEqual(self.post_1_tag.content_type, self.post_ct)
|
||||||
|
self.assertEqual(self.post_1_tag.object_id, self.post_1_fk)
|
||||||
|
self.assertEqual(self.post_1_tag.content_object, self.post_1)
|
||||||
|
|
||||||
|
self.assertEqual(self.dummy_1_tag.content_type, self.dummy_ct)
|
||||||
|
self.assertEqual(self.dummy_1_tag.object_id, self.dummy_1_fk)
|
||||||
|
self.assertEqual(self.dummy_1_tag.content_object, self.dummy_1)
|
||||||
|
|
||||||
|
self.assertEqual(self.dummy_2_tag.content_type, self.dummy_ct)
|
||||||
|
self.assertEqual(self.dummy_2_tag.object_id, self.dummy_2_fk)
|
||||||
|
self.assertEqual(self.dummy_2_tag.content_object, self.dummy_2)
|
||||||
|
|
||||||
|
self.assertSequenceEqual(self.post_1.tags.all(), (self.post_1_tag,))
|
||||||
|
self.assertSequenceEqual(self.dummy_1.tags.all(), (self.dummy_1_tag,))
|
||||||
|
self.assertSequenceEqual(self.dummy_2.tags.all(), (self.dummy_2_tag,))
|
||||||
|
|
||||||
|
def test_fields_before_save(self):
|
||||||
|
comment = Comment(pk=(1, 2))
|
||||||
|
tag = Tag(content_object=comment)
|
||||||
|
self.assertEqual(tag.object_id, "[1, 2]")
|
||||||
|
comment.id = 3
|
||||||
|
self.assertEqual(tag.object_id, "[1, 2]")
|
||||||
|
|
||||||
|
def test_cascade_delete_if_generic_relation(self):
|
||||||
|
Post.objects.get(pk=self.post_1.pk).delete()
|
||||||
|
self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())
|
||||||
|
|
||||||
|
Dummy.objects.get(pk=self.dummy_1.pk).delete()
|
||||||
|
self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists())
|
||||||
|
|
||||||
|
def test_no_cascade_delete_if_no_generic_relation(self):
|
||||||
|
Comment.objects.get(pk=self.comment_1.pk).delete()
|
||||||
|
comment_1_tag = Tag.objects.get(pk=self.comment_1_tag.pk)
|
||||||
|
self.assertIsNone(comment_1_tag.content_object)
|
||||||
|
|
||||||
|
def test_tags_clear(self):
|
||||||
|
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||||
|
post_1.tags.clear()
|
||||||
|
self.assertEqual(post_1.tags.count(), 0)
|
||||||
|
self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())
|
||||||
|
|
||||||
|
dummy_1 = Dummy.objects.get(pk=self.dummy_1.pk)
|
||||||
|
dummy_1.tags.clear()
|
||||||
|
self.assertEqual(dummy_1.tags.count(), 0)
|
||||||
|
self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists())
|
||||||
|
|
||||||
|
def test_tags_remove(self):
|
||||||
|
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||||
|
post_1.tags.remove(self.post_1_tag)
|
||||||
|
self.assertEqual(post_1.tags.count(), 0)
|
||||||
|
self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())
|
||||||
|
|
||||||
|
dummy_1 = Dummy.objects.get(pk=self.dummy_1.pk)
|
||||||
|
dummy_1.tags.remove(self.dummy_1_tag)
|
||||||
|
self.assertEqual(dummy_1.tags.count(), 0)
|
||||||
|
self.assertFalse(Tag.objects.filter(pk=self.dummy_1_tag.pk).exists())
|
||||||
|
|
||||||
|
def test_tags_create(self):
|
||||||
|
tag_count = Tag.objects.count()
|
||||||
|
|
||||||
|
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||||
|
post_1.tags.create(name="foo")
|
||||||
|
self.assertEqual(post_1.tags.count(), 2)
|
||||||
|
self.assertEqual(Tag.objects.count(), tag_count + 1)
|
||||||
|
|
||||||
|
tag = Tag.objects.get(name="foo")
|
||||||
|
self.assertEqual(tag.content_type, self.post_ct)
|
||||||
|
self.assertEqual(tag.object_id, self.post_1_fk)
|
||||||
|
self.assertEqual(tag.content_object, post_1)
|
||||||
|
|
||||||
|
def test_tags_add(self):
|
||||||
|
tag_count = Tag.objects.count()
|
||||||
|
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||||
|
|
||||||
|
tag_1 = Tag(name="foo")
|
||||||
|
post_1.tags.add(tag_1, bulk=False)
|
||||||
|
self.assertEqual(post_1.tags.count(), 2)
|
||||||
|
self.assertEqual(Tag.objects.count(), tag_count + 1)
|
||||||
|
|
||||||
|
tag_1 = Tag.objects.get(name="foo")
|
||||||
|
self.assertEqual(tag_1.content_type, self.post_ct)
|
||||||
|
self.assertEqual(tag_1.object_id, self.post_1_fk)
|
||||||
|
self.assertEqual(tag_1.content_object, post_1)
|
||||||
|
|
||||||
|
tag_2 = Tag.objects.create(name="bar", content_object=self.comment_2)
|
||||||
|
post_1.tags.add(tag_2)
|
||||||
|
self.assertEqual(post_1.tags.count(), 3)
|
||||||
|
self.assertEqual(Tag.objects.count(), tag_count + 2)
|
||||||
|
|
||||||
|
tag_2 = Tag.objects.get(name="bar")
|
||||||
|
self.assertEqual(tag_2.content_type, self.post_ct)
|
||||||
|
self.assertEqual(tag_2.object_id, self.post_1_fk)
|
||||||
|
self.assertEqual(tag_2.content_object, post_1)
|
||||||
|
|
||||||
|
def test_tags_set(self):
|
||||||
|
tag_count = Tag.objects.count()
|
||||||
|
comment_1_tag = Tag.objects.get(name=self.comment_1_tag.name)
|
||||||
|
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||||
|
post_1.tags.set([comment_1_tag])
|
||||||
|
self.assertEqual(post_1.tags.count(), 1)
|
||||||
|
self.assertEqual(Tag.objects.count(), tag_count - 1)
|
||||||
|
self.assertFalse(Tag.objects.filter(pk=self.post_1_tag.pk).exists())
|
||||||
|
|
||||||
|
def test_tags_get_or_create(self):
|
||||||
|
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||||
|
|
||||||
|
tag_1, created = post_1.tags.get_or_create(name=self.post_1_tag.name)
|
||||||
|
self.assertFalse(created)
|
||||||
|
self.assertEqual(tag_1.pk, self.post_1_tag.pk)
|
||||||
|
self.assertEqual(tag_1.content_type, self.post_ct)
|
||||||
|
self.assertEqual(tag_1.object_id, self.post_1_fk)
|
||||||
|
self.assertEqual(tag_1.content_object, post_1)
|
||||||
|
|
||||||
|
tag_2, created = post_1.tags.get_or_create(name="foo")
|
||||||
|
self.assertTrue(created)
|
||||||
|
self.assertEqual(tag_2.content_type, self.post_ct)
|
||||||
|
self.assertEqual(tag_2.object_id, self.post_1_fk)
|
||||||
|
self.assertEqual(tag_2.content_object, post_1)
|
||||||
|
|
||||||
|
def test_tags_update_or_create(self):
|
||||||
|
post_1 = Post.objects.get(pk=self.post_1.pk)
|
||||||
|
|
||||||
|
tag_1, created = post_1.tags.update_or_create(
|
||||||
|
name=self.post_1_tag.name, defaults={"name": "foo"}
|
||||||
|
)
|
||||||
|
self.assertFalse(created)
|
||||||
|
self.assertEqual(tag_1.pk, self.post_1_tag.pk)
|
||||||
|
self.assertEqual(tag_1.name, "foo")
|
||||||
|
self.assertEqual(tag_1.content_type, self.post_ct)
|
||||||
|
self.assertEqual(tag_1.object_id, self.post_1_fk)
|
||||||
|
self.assertEqual(tag_1.content_object, post_1)
|
||||||
|
|
||||||
|
tag_2, created = post_1.tags.update_or_create(name="bar")
|
||||||
|
self.assertTrue(created)
|
||||||
|
self.assertEqual(tag_2.content_type, self.post_ct)
|
||||||
|
self.assertEqual(tag_2.object_id, self.post_1_fk)
|
||||||
|
self.assertEqual(tag_2.content_object, post_1)
|
||||||
|
|
||||||
|
def test_filter_by_related_query_name(self):
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
Tag.objects.filter(post__id=self.post_1.id), (self.post_1_tag,)
|
||||||
|
)
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
Tag.objects.filter(dummy__big_integer=self.dummy_1.big_integer),
|
||||||
|
(self.dummy_1_tag,),
|
||||||
|
)
|
||||||
|
|
||||||
|
@skipIf(
|
||||||
|
connection.vendor == "mysql" and connection.mysql_is_mariadb,
|
||||||
|
"MariaDB's JSON_UNQUOTE doesn't support surrogate pairs "
|
||||||
|
"(https://jira.mariadb.org/browse/MDEV-21124)",
|
||||||
|
)
|
||||||
|
def test_aggregate(self):
|
||||||
|
with self.subTest("Post"):
|
||||||
|
with CaptureQueriesContext(connection) as ctx:
|
||||||
|
self.assertEqual(
|
||||||
|
Post.objects.aggregate(Count("tags")),
|
||||||
|
{"tags__count": 1},
|
||||||
|
ctx[-1]["sql"],
|
||||||
|
)
|
||||||
|
with self.subTest("Dummy"):
|
||||||
|
with CaptureQueriesContext(connection) as ctx:
|
||||||
|
self.assertEqual(
|
||||||
|
Dummy.objects.aggregate(Count("tags")),
|
||||||
|
{"tags__count": 2},
|
||||||
|
ctx[-1]["sql"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_generic_prefetch(self):
|
||||||
|
tags = Tag.objects.prefetch_related(
|
||||||
|
GenericPrefetch(
|
||||||
|
"content_object",
|
||||||
|
[
|
||||||
|
Post.objects.all(),
|
||||||
|
Comment.objects.all(),
|
||||||
|
Dummy.objects.all(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
).order_by("pk")
|
||||||
|
|
||||||
|
self.assertEqual(len(tags), 4)
|
||||||
|
self.assertEqual(tags[0], self.comment_1_tag)
|
||||||
|
self.assertEqual(tags[1], self.post_1_tag)
|
||||||
|
self.assertEqual(tags[2], self.dummy_1_tag)
|
||||||
|
|
||||||
|
with self.assertNumQueries(0):
|
||||||
|
self.assertEqual(tags[0].content_object, self.comment_1)
|
||||||
|
with self.assertNumQueries(0):
|
||||||
|
self.assertEqual(tags[1].content_object, self.post_1)
|
||||||
|
with self.assertNumQueries(0):
|
||||||
|
self.assertEqual(tags[2].content_object, self.dummy_1)
|
||||||
|
with self.assertNumQueries(0):
|
||||||
|
self.assertEqual(tags[3].content_object, self.dummy_2)
|
||||||
|
|
||||||
|
def test_to_json(self):
|
||||||
|
field = Post._meta.pk
|
||||||
|
self.assertEqual(
|
||||||
|
field.to_json((1, "004bc28c-a085-44ca-a823-747dcf4ddfd3")),
|
||||||
|
'[1, "004bc28c-a085-44ca-a823-747dcf4ddfd3"]',
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
field.to_json((2, UUID("a5a626e5-d34a-4197-94ea-83547029d6ed"))),
|
||||||
|
'[2, "a5a626e5-d34a-4197-94ea-83547029d6ed"]',
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_to_json_length_mismatch(self):
|
||||||
|
field = Post._meta.pk
|
||||||
|
msg = "CompositePrimaryKey has 2 fields but it tried to serialize %s."
|
||||||
|
with self.assertRaisesMessage(ValueError, msg % 1):
|
||||||
|
self.assertIsNone(field.to_json((1,)))
|
||||||
|
with self.assertRaisesMessage(ValueError, msg % 3):
|
||||||
|
self.assertIsNone(field.to_json((1, 2, 3)))
|
||||||
|
|
||||||
|
def test_from_json(self):
|
||||||
|
field = Post._meta.pk
|
||||||
|
self.assertEqual(
|
||||||
|
field.from_json('[1, "cb6b99bc-66b0-497f-bafe-193b90af6296"]'),
|
||||||
|
(1, UUID("cb6b99bc-66b0-497f-bafe-193b90af6296")),
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
field.from_json('[2, "1d5c63dda5264a12a2ec929d51e04430"]'),
|
||||||
|
(2, UUID("1d5c63dd-a526-4a12-a2ec-929d51e04430")),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_from_json_length_mismatch(self):
|
||||||
|
field = Post._meta.pk
|
||||||
|
msg = (
|
||||||
|
"CompositePrimaryKey has 2 fields but it tried to deserialize %s. "
|
||||||
|
"Did you change the CompositePrimaryKey fields and forgot to "
|
||||||
|
'update the related GenericForeignKey "object_id" fields?'
|
||||||
|
)
|
||||||
|
with self.assertRaisesMessage(ValueError, msg % 1):
|
||||||
|
self.assertIsNone(field.from_json("[1]"))
|
||||||
|
with self.assertRaisesMessage(ValueError, msg % 3):
|
||||||
|
self.assertIsNone(field.from_json("[1, 2, 3]"))
|
@ -92,6 +92,7 @@ class CompositePKGetTests(TestCase):
|
|||||||
|
|
||||||
def test_lookup_errors(self):
|
def test_lookup_errors(self):
|
||||||
m_tuple = "'%s' lookup of 'pk' must be a tuple or a list"
|
m_tuple = "'%s' lookup of 'pk' must be a tuple or a list"
|
||||||
|
m_iterable = "'%s' lookup of 'pk' must be an iterable"
|
||||||
m_2_elements = "'%s' lookup of 'pk' must have 2 elements"
|
m_2_elements = "'%s' lookup of 'pk' must have 2 elements"
|
||||||
m_tuple_collection = (
|
m_tuple_collection = (
|
||||||
"'in' lookup of 'pk' must be a collection of tuples or lists"
|
"'in' lookup of 'pk' must be a collection of tuples or lists"
|
||||||
@ -102,7 +103,7 @@ class CompositePKGetTests(TestCase):
|
|||||||
({"pk": (1, 2, 3)}, m_2_elements % "exact"),
|
({"pk": (1, 2, 3)}, m_2_elements % "exact"),
|
||||||
({"pk__exact": 1}, m_tuple % "exact"),
|
({"pk__exact": 1}, m_tuple % "exact"),
|
||||||
({"pk__exact": (1, 2, 3)}, m_2_elements % "exact"),
|
({"pk__exact": (1, 2, 3)}, m_2_elements % "exact"),
|
||||||
({"pk__in": 1}, m_tuple % "in"),
|
({"pk__in": 1}, m_iterable % "in"),
|
||||||
({"pk__in": (1, 2, 3)}, m_tuple_collection),
|
({"pk__in": (1, 2, 3)}, m_tuple_collection),
|
||||||
({"pk__in": ((1, 2, 3),)}, m_2_elements_each),
|
({"pk__in": ((1, 2, 3),)}, m_2_elements_each),
|
||||||
({"pk__gt": 1}, m_tuple % "gt"),
|
({"pk__gt": 1}, m_tuple % "gt"),
|
||||||
|
@ -474,7 +474,6 @@ class TupleLookupsTests(TestCase):
|
|||||||
TupleGreaterThanOrEqual,
|
TupleGreaterThanOrEqual,
|
||||||
TupleLessThan,
|
TupleLessThan,
|
||||||
TupleLessThanOrEqual,
|
TupleLessThanOrEqual,
|
||||||
TupleIn,
|
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
0,
|
0,
|
||||||
@ -496,6 +495,24 @@ class TupleLookupsTests(TestCase):
|
|||||||
):
|
):
|
||||||
lookup_cls((F("customer_code"), F("company_code")), rhs)
|
lookup_cls((F("customer_code"), F("company_code")), rhs)
|
||||||
|
|
||||||
|
def test_tuple_in_lookup_rhs_must_be_iterable(self):
|
||||||
|
test_cases = (
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
None,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
for rhs in test_cases:
|
||||||
|
with self.subTest(lookup_name="in", rhs=rhs):
|
||||||
|
with self.assertRaisesMessage(
|
||||||
|
ValueError,
|
||||||
|
"'in' lookup of ('customer_code', 'company_code') "
|
||||||
|
"must be an iterable",
|
||||||
|
):
|
||||||
|
TupleIn((F("customer_code"), F("company_code")), rhs)
|
||||||
|
|
||||||
def test_tuple_lookup_rhs_must_have_2_elements(self):
|
def test_tuple_lookup_rhs_must_have_2_elements(self):
|
||||||
test_cases = itertools.product(
|
test_cases = itertools.product(
|
||||||
(
|
(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user