mirror of
https://github.com/django/django.git
synced 2025-10-31 09:41:08 +00:00
Fixed #373 -- Added CompositePrimaryKey.
Thanks Lily Foote and Simon Charette for reviews and mentoring this Google Summer of Code 2024 project. Co-authored-by: Simon Charette <charette.s@gmail.com> Co-authored-by: Lily Foote <code@lilyf.org>
This commit is contained in:
committed by
Sarah Boyce
parent
86661f2449
commit
978aae4334
@@ -113,6 +113,11 @@ class AdminSite:
|
||||
"The model %s is abstract, so it cannot be registered with admin."
|
||||
% model.__name__
|
||||
)
|
||||
if model._meta.is_composite_pk:
|
||||
raise ImproperlyConfigured(
|
||||
"The model %s has a composite primary key, so it cannot be "
|
||||
"registered with admin." % model.__name__
|
||||
)
|
||||
|
||||
if self.is_registered(model):
|
||||
registered_admin = str(self.get_model_admin(model))
|
||||
|
||||
@@ -7,6 +7,7 @@ other serializers.
|
||||
from django.apps import apps
|
||||
from django.core.serializers import base
|
||||
from django.db import DEFAULT_DB_ALIAS, models
|
||||
from django.db.models import CompositePrimaryKey
|
||||
from django.utils.encoding import is_protected_type
|
||||
|
||||
|
||||
@@ -39,6 +40,8 @@ class Serializer(base.Serializer):
|
||||
return data
|
||||
|
||||
def _value_from_field(self, obj, field):
|
||||
if isinstance(field, CompositePrimaryKey):
|
||||
return [self._value_from_field(obj, f) for f in field]
|
||||
value = field.value_from_object(obj)
|
||||
# Protected types (i.e., primitives like None, numbers, dates,
|
||||
# and Decimals) are passed through as is. All other values are
|
||||
|
||||
@@ -14,6 +14,7 @@ from django.db.backends.ddl_references import (
|
||||
)
|
||||
from django.db.backends.utils import names_digest, split_identifier, truncate_name
|
||||
from django.db.models import NOT_PROVIDED, Deferrable, Index
|
||||
from django.db.models.fields.composite import CompositePrimaryKey
|
||||
from django.db.models.sql import Query
|
||||
from django.db.transaction import TransactionManagementError, atomic
|
||||
from django.utils import timezone
|
||||
@@ -106,6 +107,7 @@ class BaseDatabaseSchemaEditor:
|
||||
sql_check_constraint = "CHECK (%(check)s)"
|
||||
sql_delete_constraint = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
|
||||
sql_constraint = "CONSTRAINT %(name)s %(constraint)s"
|
||||
sql_pk_constraint = "PRIMARY KEY (%(columns)s)"
|
||||
|
||||
sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
|
||||
sql_delete_check = sql_delete_constraint
|
||||
@@ -282,6 +284,11 @@ class BaseDatabaseSchemaEditor:
|
||||
constraint.constraint_sql(model, self)
|
||||
for constraint in model._meta.constraints
|
||||
)
|
||||
|
||||
pk = model._meta.pk
|
||||
if isinstance(pk, CompositePrimaryKey):
|
||||
constraint_sqls.append(self._pk_constraint_sql(pk.columns))
|
||||
|
||||
sql = self.sql_create_table % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
"definition": ", ".join(
|
||||
@@ -1999,6 +2006,11 @@ class BaseDatabaseSchemaEditor:
|
||||
result.append(name)
|
||||
return result
|
||||
|
||||
def _pk_constraint_sql(self, columns):
|
||||
return self.sql_pk_constraint % {
|
||||
"columns": ", ".join(self.quote_name(column) for column in columns)
|
||||
}
|
||||
|
||||
def _delete_primary_key(self, model, strict=False):
|
||||
constraint_names = self._constraint_names(model, primary_key=True)
|
||||
if strict and len(constraint_names) != 1:
|
||||
|
||||
@@ -211,6 +211,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
return create_index
|
||||
|
||||
def _is_identity_column(self, table_name, column_name):
|
||||
if not column_name:
|
||||
return False
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
|
||||
@@ -6,7 +6,7 @@ from django.db import NotSupportedError
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
from django.db.backends.ddl_references import Statement
|
||||
from django.db.backends.utils import strip_quotes
|
||||
from django.db.models import NOT_PROVIDED, UniqueConstraint
|
||||
from django.db.models import NOT_PROVIDED, CompositePrimaryKey, UniqueConstraint
|
||||
|
||||
|
||||
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
@@ -104,6 +104,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
f.name: f.clone() if is_self_referential(f) else f
|
||||
for f in model._meta.local_concrete_fields
|
||||
}
|
||||
|
||||
# Since CompositePrimaryKey is not a concrete field (column is None),
|
||||
# it's not copied by default.
|
||||
pk = model._meta.pk
|
||||
if isinstance(pk, CompositePrimaryKey):
|
||||
body[pk.name] = pk.clone()
|
||||
|
||||
# Since mapping might mix column names and default values,
|
||||
# its values must be already quoted.
|
||||
mapping = {
|
||||
@@ -296,6 +303,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
# Special-case implicit M2M tables.
|
||||
if field.many_to_many and field.remote_field.through._meta.auto_created:
|
||||
self.create_model(field.remote_field.through)
|
||||
elif isinstance(field, CompositePrimaryKey):
|
||||
# If a CompositePrimaryKey field was added, the existing primary key field
|
||||
# had to be altered too, resulting in an AddField, AlterField migration.
|
||||
# The table cannot be re-created on AddField, it would result in a
|
||||
# duplicate primary key error.
|
||||
return
|
||||
elif (
|
||||
# Primary keys and unique fields are not supported in ALTER TABLE
|
||||
# ADD COLUMN.
|
||||
|
||||
@@ -38,6 +38,7 @@ from django.db.models.expressions import (
|
||||
)
|
||||
from django.db.models.fields import * # NOQA
|
||||
from django.db.models.fields import __all__ as fields_all
|
||||
from django.db.models.fields.composite import CompositePrimaryKey
|
||||
from django.db.models.fields.files import FileField, ImageField
|
||||
from django.db.models.fields.generated import GeneratedField
|
||||
from django.db.models.fields.json import JSONField
|
||||
@@ -82,6 +83,7 @@ __all__ += [
|
||||
"ProtectedError",
|
||||
"RestrictedError",
|
||||
"Case",
|
||||
"CompositePrimaryKey",
|
||||
"Exists",
|
||||
"Expression",
|
||||
"ExpressionList",
|
||||
|
||||
@@ -3,7 +3,8 @@ Classes to represent the definitions of aggregate functions.
|
||||
"""
|
||||
|
||||
from django.core.exceptions import FieldError, FullResultSet
|
||||
from django.db.models.expressions import Case, Func, Star, Value, When
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.expressions import Case, ColPairs, Func, Star, Value, When
|
||||
from django.db.models.fields import IntegerField
|
||||
from django.db.models.functions.comparison import Coalesce
|
||||
from django.db.models.functions.mixins import (
|
||||
@@ -174,6 +175,22 @@ class Count(Aggregate):
|
||||
raise ValueError("Star cannot be used with filter. Please specify a field.")
|
||||
super().__init__(expression, filter=filter, **extra)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
result = super().resolve_expression(*args, **kwargs)
|
||||
expr = result.source_expressions[0]
|
||||
|
||||
# In case of composite primary keys, count the first column.
|
||||
if isinstance(expr, ColPairs):
|
||||
if self.distinct:
|
||||
raise NotSupportedError(
|
||||
"COUNT(DISTINCT) doesn't support composite primary keys"
|
||||
)
|
||||
|
||||
cols = expr.get_cols()
|
||||
return Count(cols[0], filter=result.filter)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Max(Aggregate):
|
||||
function = "MAX"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import inspect
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from functools import partialmethod
|
||||
from itertools import chain
|
||||
|
||||
@@ -30,6 +31,7 @@ from django.db.models import NOT_PROVIDED, ExpressionWrapper, IntegerField, Max,
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.deletion import CASCADE, Collector
|
||||
from django.db.models.expressions import DatabaseDefault
|
||||
from django.db.models.fields.composite import CompositePrimaryKey
|
||||
from django.db.models.fields.related import (
|
||||
ForeignObjectRel,
|
||||
OneToOneField,
|
||||
@@ -508,7 +510,7 @@ class Model(AltersData, metaclass=ModelBase):
|
||||
for field in fields_iter:
|
||||
is_related_object = False
|
||||
# Virtual field
|
||||
if field.attname not in kwargs and field.column is None or field.generated:
|
||||
if field.column is None or field.generated:
|
||||
continue
|
||||
if kwargs:
|
||||
if isinstance(field.remote_field, ForeignObjectRel):
|
||||
@@ -663,7 +665,11 @@ class Model(AltersData, metaclass=ModelBase):
|
||||
pk = property(_get_pk_val, _set_pk_val)
|
||||
|
||||
def _is_pk_set(self, meta=None):
|
||||
return self._get_pk_val(meta) is not None
|
||||
pk_val = self._get_pk_val(meta)
|
||||
return not (
|
||||
pk_val is None
|
||||
or (isinstance(pk_val, tuple) and any(f is None for f in pk_val))
|
||||
)
|
||||
|
||||
def get_deferred_fields(self):
|
||||
"""
|
||||
@@ -1454,6 +1460,11 @@ class Model(AltersData, metaclass=ModelBase):
|
||||
name = f.name
|
||||
if name in exclude:
|
||||
continue
|
||||
if isinstance(f, CompositePrimaryKey):
|
||||
names = tuple(field.name for field in f.fields)
|
||||
if exclude.isdisjoint(names):
|
||||
unique_checks.append((model_class, names))
|
||||
continue
|
||||
if f.unique:
|
||||
unique_checks.append((model_class, (name,)))
|
||||
if f.unique_for_date and f.unique_for_date not in exclude:
|
||||
@@ -1728,6 +1739,7 @@ class Model(AltersData, metaclass=ModelBase):
|
||||
*cls._check_constraints(databases),
|
||||
*cls._check_default_pk(),
|
||||
*cls._check_db_table_comment(databases),
|
||||
*cls._check_composite_pk(),
|
||||
]
|
||||
|
||||
return errors
|
||||
@@ -1764,6 +1776,63 @@ class Model(AltersData, metaclass=ModelBase):
|
||||
]
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _check_composite_pk(cls):
|
||||
errors = []
|
||||
meta = cls._meta
|
||||
pk = meta.pk
|
||||
|
||||
if not isinstance(pk, CompositePrimaryKey):
|
||||
return errors
|
||||
|
||||
seen_columns = defaultdict(list)
|
||||
|
||||
for field_name in pk.field_names:
|
||||
hint = None
|
||||
|
||||
try:
|
||||
field = meta.get_field(field_name)
|
||||
except FieldDoesNotExist:
|
||||
field = None
|
||||
|
||||
if not field:
|
||||
hint = f"{field_name!r} is not a valid field."
|
||||
elif not field.column:
|
||||
hint = f"{field_name!r} field has no column."
|
||||
elif field.null:
|
||||
hint = f"{field_name!r} field may not set 'null=True'."
|
||||
elif field.generated:
|
||||
hint = f"{field_name!r} field is a generated field."
|
||||
else:
|
||||
seen_columns[field.column].append(field_name)
|
||||
|
||||
if hint:
|
||||
errors.append(
|
||||
checks.Error(
|
||||
f"{field_name!r} cannot be included in the composite primary "
|
||||
"key.",
|
||||
hint=hint,
|
||||
obj=cls,
|
||||
id="models.E042",
|
||||
)
|
||||
)
|
||||
|
||||
for column, field_names in seen_columns.items():
|
||||
if len(field_names) > 1:
|
||||
field_name, *rest = field_names
|
||||
duplicates = ", ".join(repr(field) for field in rest)
|
||||
errors.append(
|
||||
checks.Error(
|
||||
f"{duplicates} cannot be included in the composite primary "
|
||||
"key.",
|
||||
hint=f"{duplicates} and {field_name!r} are the same fields.",
|
||||
obj=cls,
|
||||
id="models.E042",
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
@classmethod
|
||||
def _check_db_table_comment(cls, databases):
|
||||
if not cls._meta.db_table_comment:
|
||||
|
||||
@@ -656,6 +656,8 @@ class Field(RegisterLookupMixin):
|
||||
path = path.replace("django.db.models.fields.json", "django.db.models")
|
||||
elif path.startswith("django.db.models.fields.proxy"):
|
||||
path = path.replace("django.db.models.fields.proxy", "django.db.models")
|
||||
elif path.startswith("django.db.models.fields.composite"):
|
||||
path = path.replace("django.db.models.fields.composite", "django.db.models")
|
||||
elif path.startswith("django.db.models.fields"):
|
||||
path = path.replace("django.db.models.fields", "django.db.models")
|
||||
# Return basic info - other fields should override this.
|
||||
|
||||
150
django/db/models/fields/composite.py
Normal file
150
django/db/models/fields/composite.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from django.core import checks
|
||||
from django.db.models import NOT_PROVIDED, Field
|
||||
from django.db.models.expressions import ColPairs
|
||||
from django.db.models.fields.tuple_lookups import (
|
||||
TupleExact,
|
||||
TupleGreaterThan,
|
||||
TupleGreaterThanOrEqual,
|
||||
TupleIn,
|
||||
TupleIsNull,
|
||||
TupleLessThan,
|
||||
TupleLessThanOrEqual,
|
||||
)
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class CompositeAttribute:
|
||||
def __init__(self, field):
|
||||
self.field = field
|
||||
|
||||
@property
|
||||
def attnames(self):
|
||||
return [field.attname for field in self.field.fields]
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
return tuple(getattr(instance, attname) for attname in self.attnames)
|
||||
|
||||
def __set__(self, instance, values):
|
||||
attnames = self.attnames
|
||||
length = len(attnames)
|
||||
|
||||
if values is None:
|
||||
values = (None,) * length
|
||||
|
||||
if not isinstance(values, (list, tuple)):
|
||||
raise ValueError(f"{self.field.name!r} must be a list or a tuple.")
|
||||
if length != len(values):
|
||||
raise ValueError(f"{self.field.name!r} must have {length} elements.")
|
||||
|
||||
for attname, value in zip(attnames, values):
|
||||
setattr(instance, attname, value)
|
||||
|
||||
|
||||
class CompositePrimaryKey(Field):
|
||||
descriptor_class = CompositeAttribute
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if (
|
||||
not args
|
||||
or not all(isinstance(field, str) for field in args)
|
||||
or len(set(args)) != len(args)
|
||||
):
|
||||
raise ValueError("CompositePrimaryKey args must be unique strings.")
|
||||
if len(args) == 1:
|
||||
raise ValueError("CompositePrimaryKey must include at least two fields.")
|
||||
if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
|
||||
raise ValueError("CompositePrimaryKey cannot have a default.")
|
||||
if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
|
||||
raise ValueError("CompositePrimaryKey cannot have a database default.")
|
||||
if kwargs.setdefault("editable", False):
|
||||
raise ValueError("CompositePrimaryKey cannot be editable.")
|
||||
if not kwargs.setdefault("primary_key", True):
|
||||
raise ValueError("CompositePrimaryKey must be a primary key.")
|
||||
if not kwargs.setdefault("blank", True):
|
||||
raise ValueError("CompositePrimaryKey must be blank.")
|
||||
|
||||
self.field_names = args
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def deconstruct(self):
|
||||
# args is always [] so it can be ignored.
|
||||
name, path, _, kwargs = super().deconstruct()
|
||||
return name, path, self.field_names, kwargs
|
||||
|
||||
@cached_property
|
||||
def fields(self):
|
||||
meta = self.model._meta
|
||||
return tuple(meta.get_field(field_name) for field_name in self.field_names)
|
||||
|
||||
@cached_property
|
||||
def columns(self):
|
||||
return tuple(field.column for field in self.fields)
|
||||
|
||||
def contribute_to_class(self, cls, name, private_only=False):
|
||||
super().contribute_to_class(cls, name, private_only=private_only)
|
||||
cls._meta.pk = self
|
||||
setattr(cls, self.attname, self.descriptor_class(self))
|
||||
|
||||
def get_attname_column(self):
|
||||
return self.get_attname(), None
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.fields)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.field_names)
|
||||
|
||||
@cached_property
|
||||
def cached_col(self):
|
||||
return ColPairs(self.model._meta.db_table, self.fields, self.fields, self)
|
||||
|
||||
def get_col(self, alias, output_field=None):
|
||||
if alias == self.model._meta.db_table and (
|
||||
output_field is None or output_field == self
|
||||
):
|
||||
return self.cached_col
|
||||
|
||||
return ColPairs(alias, self.fields, self.fields, output_field)
|
||||
|
||||
def get_pk_value_on_save(self, instance):
|
||||
values = []
|
||||
|
||||
for field in self.fields:
|
||||
value = field.value_from_object(instance)
|
||||
if value is None:
|
||||
value = field.get_pk_value_on_save(instance)
|
||||
values.append(value)
|
||||
|
||||
return tuple(values)
|
||||
|
||||
def _check_field_name(self):
|
||||
if self.name == "pk":
|
||||
return []
|
||||
return [
|
||||
checks.Error(
|
||||
"'CompositePrimaryKey' must be named 'pk'.",
|
||||
obj=self,
|
||||
id="fields.E013",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
CompositePrimaryKey.register_lookup(TupleExact)
|
||||
CompositePrimaryKey.register_lookup(TupleGreaterThan)
|
||||
CompositePrimaryKey.register_lookup(TupleGreaterThanOrEqual)
|
||||
CompositePrimaryKey.register_lookup(TupleLessThan)
|
||||
CompositePrimaryKey.register_lookup(TupleLessThanOrEqual)
|
||||
CompositePrimaryKey.register_lookup(TupleIn)
|
||||
CompositePrimaryKey.register_lookup(TupleIsNull)
|
||||
|
||||
|
||||
def unnest(fields):
|
||||
result = []
|
||||
|
||||
for field in fields:
|
||||
if isinstance(field, CompositePrimaryKey):
|
||||
result.extend(field.fields)
|
||||
else:
|
||||
result.append(field)
|
||||
|
||||
return result
|
||||
@@ -624,11 +624,21 @@ class ForeignObject(RelatedField):
|
||||
if not has_unique_constraint:
|
||||
foreign_fields = {f.name for f in self.foreign_related_fields}
|
||||
remote_opts = self.remote_field.model._meta
|
||||
has_unique_constraint = any(
|
||||
frozenset(ut) <= foreign_fields for ut in remote_opts.unique_together
|
||||
) or any(
|
||||
frozenset(uc.fields) <= foreign_fields
|
||||
for uc in remote_opts.total_unique_constraints
|
||||
has_unique_constraint = (
|
||||
any(
|
||||
frozenset(ut) <= foreign_fields
|
||||
for ut in remote_opts.unique_together
|
||||
)
|
||||
or any(
|
||||
frozenset(uc.fields) <= foreign_fields
|
||||
for uc in remote_opts.total_unique_constraints
|
||||
)
|
||||
# If the model defines a composite primary key and the foreign key
|
||||
# refers to it, the target is unique.
|
||||
or (
|
||||
frozenset(field.name for field in remote_opts.pk_fields)
|
||||
== foreign_fields
|
||||
)
|
||||
)
|
||||
|
||||
if not has_unique_constraint:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.expressions import ColPairs
|
||||
from django.db.models.fields import composite
|
||||
from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups
|
||||
from django.db.models.lookups import (
|
||||
Exact,
|
||||
@@ -19,7 +20,7 @@ def get_normalized_value(value, lhs):
|
||||
if not value._is_pk_set():
|
||||
raise ValueError("Model instances passed to related filters must be saved.")
|
||||
value_list = []
|
||||
sources = lhs.output_field.path_infos[-1].target_fields
|
||||
sources = composite.unnest(lhs.output_field.path_infos[-1].target_fields)
|
||||
for source in sources:
|
||||
while not isinstance(value, source.model) and source.remote_field:
|
||||
source = source.remote_field.model._meta.get_field(
|
||||
@@ -30,7 +31,8 @@ def get_normalized_value(value, lhs):
|
||||
except AttributeError:
|
||||
# A case like Restaurant.objects.filter(place=restaurant_instance),
|
||||
# where place is a OneToOneField and the primary key of Restaurant.
|
||||
return (value.pk,)
|
||||
pk = value.pk
|
||||
return pk if isinstance(pk, tuple) else (pk,)
|
||||
return tuple(value_list)
|
||||
if not isinstance(value, tuple):
|
||||
return (value,)
|
||||
|
||||
@@ -250,6 +250,8 @@ class TupleIn(TupleLookupMixin, In):
|
||||
|
||||
def check_rhs_select_length_equals_lhs_length(self):
|
||||
len_rhs = len(self.rhs.select)
|
||||
if len_rhs == 1 and isinstance(self.rhs.select[0], ColPairs):
|
||||
len_rhs = len(self.rhs.select[0])
|
||||
len_lhs = len(self.lhs)
|
||||
if len_rhs != len_lhs:
|
||||
lhs_str = self.get_lhs_str()
|
||||
@@ -304,7 +306,13 @@ class TupleIn(TupleLookupMixin, In):
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
def as_subquery(self, compiler, connection):
|
||||
return compiler.compile(In(self.lhs, self.rhs))
|
||||
lhs = self.lhs
|
||||
rhs = self.rhs
|
||||
if isinstance(lhs, ColPairs):
|
||||
rhs = rhs.clone()
|
||||
rhs.set_values([source.name for source in lhs.sources])
|
||||
lhs = Tuple(lhs)
|
||||
return compiler.compile(In(lhs, rhs))
|
||||
|
||||
|
||||
tuple_lookups = {
|
||||
|
||||
@@ -7,7 +7,14 @@ from django.conf import settings
|
||||
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
|
||||
from django.core.signals import setting_changed
|
||||
from django.db import connections
|
||||
from django.db.models import AutoField, Manager, OrderWrt, UniqueConstraint
|
||||
from django.db.models import (
|
||||
AutoField,
|
||||
CompositePrimaryKey,
|
||||
Manager,
|
||||
OrderWrt,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from django.db.models.fields import composite
|
||||
from django.db.models.query_utils import PathInfo
|
||||
from django.utils.datastructures import ImmutableList, OrderedSet
|
||||
from django.utils.functional import cached_property
|
||||
@@ -973,6 +980,14 @@ class Options:
|
||||
)
|
||||
]
|
||||
|
||||
@cached_property
|
||||
def pk_fields(self):
|
||||
return composite.unnest([self.pk])
|
||||
|
||||
@property
|
||||
def is_composite_pk(self):
|
||||
return isinstance(self.pk, CompositePrimaryKey)
|
||||
|
||||
@cached_property
|
||||
def _property_names(self):
|
||||
"""Return a set of the names of the properties defined on the model."""
|
||||
|
||||
@@ -171,11 +171,14 @@ class RawModelIterable(BaseIterable):
|
||||
"Raw query must include the primary key"
|
||||
)
|
||||
fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
|
||||
converters = compiler.get_converters(
|
||||
[f.get_col(f.model._meta.db_table) if f else None for f in fields]
|
||||
)
|
||||
cols = [f.get_col(f.model._meta.db_table) if f else None for f in fields]
|
||||
converters = compiler.get_converters(cols)
|
||||
if converters:
|
||||
query_iterator = compiler.apply_converters(query_iterator, converters)
|
||||
if compiler.has_composite_fields(cols):
|
||||
query_iterator = compiler.composite_fields_to_tuples(
|
||||
query_iterator, cols
|
||||
)
|
||||
for values in query_iterator:
|
||||
# Associate fields to values
|
||||
model_init_values = [values[pos] for pos in model_init_pos]
|
||||
|
||||
@@ -7,7 +7,9 @@ from itertools import chain
|
||||
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
|
||||
from django.db import DatabaseError, NotSupportedError
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
|
||||
from django.db.models.expressions import ColPairs, F, OrderBy, RawSQL, Ref, Value
|
||||
from django.db.models.fields import composite
|
||||
from django.db.models.fields.composite import CompositePrimaryKey
|
||||
from django.db.models.functions import Cast, Random
|
||||
from django.db.models.lookups import Lookup
|
||||
from django.db.models.query_utils import select_related_descend
|
||||
@@ -283,6 +285,9 @@ class SQLCompiler:
|
||||
# Reference to a column.
|
||||
elif isinstance(expression, int):
|
||||
expression = cols[expression]
|
||||
# ColPairs cannot be aliased.
|
||||
if isinstance(expression, ColPairs):
|
||||
alias = None
|
||||
selected.append((alias, expression))
|
||||
|
||||
for select_idx, (alias, expression) in enumerate(selected):
|
||||
@@ -997,6 +1002,7 @@ class SQLCompiler:
|
||||
# alias for a given field. This also includes None -> start_alias to
|
||||
# be used by local fields.
|
||||
seen_models = {None: start_alias}
|
||||
select_mask_fields = set(composite.unnest(select_mask))
|
||||
|
||||
for field in opts.concrete_fields:
|
||||
model = field.model._meta.concrete_model
|
||||
@@ -1017,7 +1023,7 @@ class SQLCompiler:
|
||||
# parent model data is already present in the SELECT clause,
|
||||
# and we want to avoid reloading the same data again.
|
||||
continue
|
||||
if select_mask and field not in select_mask:
|
||||
if select_mask and field not in select_mask_fields:
|
||||
continue
|
||||
alias = self.query.join_parent_model(opts, model, start_alias, seen_models)
|
||||
column = field.get_col(alias)
|
||||
@@ -1110,9 +1116,10 @@ class SQLCompiler:
|
||||
)
|
||||
return results
|
||||
targets, alias, _ = self.query.trim_joins(targets, joins, path)
|
||||
target_fields = composite.unnest(targets)
|
||||
return [
|
||||
(OrderBy(transform_function(t, alias), descending=descending), False)
|
||||
for t in targets
|
||||
for t in target_fields
|
||||
]
|
||||
|
||||
def _setup_joins(self, pieces, opts, alias):
|
||||
@@ -1504,13 +1511,25 @@ class SQLCompiler:
|
||||
return result
|
||||
|
||||
def get_converters(self, expressions):
|
||||
i = 0
|
||||
converters = {}
|
||||
for i, expression in enumerate(expressions):
|
||||
if expression:
|
||||
|
||||
for expression in expressions:
|
||||
if isinstance(expression, ColPairs):
|
||||
cols = expression.get_source_expressions()
|
||||
cols_converters = self.get_converters(cols)
|
||||
for j, (convs, col) in cols_converters.items():
|
||||
converters[i + j] = (convs, col)
|
||||
i += len(expression)
|
||||
elif expression:
|
||||
backend_converters = self.connection.ops.get_db_converters(expression)
|
||||
field_converters = expression.get_db_converters(self.connection)
|
||||
if backend_converters or field_converters:
|
||||
converters[i] = (backend_converters + field_converters, expression)
|
||||
i += 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return converters
|
||||
|
||||
def apply_converters(self, rows, converters):
|
||||
@@ -1524,6 +1543,24 @@ class SQLCompiler:
|
||||
row[pos] = value
|
||||
yield row
|
||||
|
||||
def has_composite_fields(self, expressions):
|
||||
# Check for composite fields before calling the relatively costly
|
||||
# composite_fields_to_tuples.
|
||||
return any(isinstance(expression, ColPairs) for expression in expressions)
|
||||
|
||||
def composite_fields_to_tuples(self, rows, expressions):
|
||||
col_pair_slices = [
|
||||
slice(i, i + len(expression))
|
||||
for i, expression in enumerate(expressions)
|
||||
if isinstance(expression, ColPairs)
|
||||
]
|
||||
|
||||
for row in map(list, rows):
|
||||
for pos in col_pair_slices:
|
||||
row[pos] = (tuple(row[pos]),)
|
||||
|
||||
yield row
|
||||
|
||||
def results_iter(
|
||||
self,
|
||||
results=None,
|
||||
@@ -1541,8 +1578,10 @@ class SQLCompiler:
|
||||
rows = chain.from_iterable(results)
|
||||
if converters:
|
||||
rows = self.apply_converters(rows, converters)
|
||||
if tuple_expected:
|
||||
rows = map(tuple, rows)
|
||||
if self.has_composite_fields(fields):
|
||||
rows = self.composite_fields_to_tuples(rows, fields)
|
||||
if tuple_expected:
|
||||
rows = map(tuple, rows)
|
||||
return rows
|
||||
|
||||
def has_results(self):
|
||||
@@ -1863,6 +1902,18 @@ class SQLInsertCompiler(SQLCompiler):
|
||||
)
|
||||
]
|
||||
cols = [field.get_col(opts.db_table) for field in self.returning_fields]
|
||||
elif isinstance(opts.pk, CompositePrimaryKey):
|
||||
returning_field = returning_fields[0]
|
||||
cols = [returning_field.get_col(opts.db_table)]
|
||||
rows = [
|
||||
(
|
||||
self.connection.ops.last_insert_id(
|
||||
cursor,
|
||||
opts.db_table,
|
||||
returning_field.column,
|
||||
),
|
||||
)
|
||||
]
|
||||
else:
|
||||
cols = [opts.pk.get_col(opts.db_table)]
|
||||
rows = [
|
||||
@@ -1876,8 +1927,10 @@ class SQLInsertCompiler(SQLCompiler):
|
||||
]
|
||||
converters = self.get_converters(cols)
|
||||
if converters:
|
||||
rows = list(self.apply_converters(rows, converters))
|
||||
return rows
|
||||
rows = self.apply_converters(rows, converters)
|
||||
if self.has_composite_fields(cols):
|
||||
rows = self.composite_fields_to_tuples(rows, cols)
|
||||
return list(rows)
|
||||
|
||||
|
||||
class SQLDeleteCompiler(SQLCompiler):
|
||||
@@ -2065,6 +2118,7 @@ class SQLUpdateCompiler(SQLCompiler):
|
||||
query.add_fields(fields)
|
||||
super().pre_sql_setup()
|
||||
|
||||
is_composite_pk = meta.is_composite_pk
|
||||
must_pre_select = (
|
||||
count > 1 and not self.connection.features.update_can_self_select
|
||||
)
|
||||
@@ -2079,7 +2133,8 @@ class SQLUpdateCompiler(SQLCompiler):
|
||||
idents = []
|
||||
related_ids = collections.defaultdict(list)
|
||||
for rows in query.get_compiler(self.using).execute_sql(MULTI):
|
||||
idents.extend(r[0] for r in rows)
|
||||
pks = [row if is_composite_pk else row[0] for row in rows]
|
||||
idents.extend(pks)
|
||||
for parent, index in related_ids_index:
|
||||
related_ids[parent].extend(r[index] for r in rows)
|
||||
self.query.add_filter("pk__in", idents)
|
||||
|
||||
@@ -627,8 +627,12 @@ class Query(BaseExpression):
|
||||
if result is None:
|
||||
result = empty_set_result
|
||||
else:
|
||||
converters = compiler.get_converters(outer_query.annotation_select.values())
|
||||
result = next(compiler.apply_converters((result,), converters))
|
||||
cols = outer_query.annotation_select.values()
|
||||
converters = compiler.get_converters(cols)
|
||||
rows = compiler.apply_converters((result,), converters)
|
||||
if compiler.has_composite_fields(cols):
|
||||
rows = compiler.composite_fields_to_tuples(rows, cols)
|
||||
result = next(rows)
|
||||
|
||||
return dict(zip(outer_query.annotation_select, result))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user