1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Refs #373 -- Added Model._is_pk_set() abstraction to check if a Model's PK is set.

This commit is contained in:
Csirmaz Bendegúz 2024-09-10 04:46:50 +08:00 committed by GitHub
parent cdbd31960e
commit 5865ff5adc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 67 additions and 27 deletions

View File

@ -31,7 +31,7 @@ class BaseGenericInlineFormSet(BaseModelFormSet):
+ self.ct_fk_field.name + self.ct_fk_field.name
) )
self.save_as_new = save_as_new self.save_as_new = save_as_new
if self.instance is None or self.instance.pk is None: if self.instance is None or not self.instance._is_pk_set():
qs = self.model._default_manager.none() qs = self.model._default_manager.none()
else: else:
if queryset is None: if queryset is None:

View File

@ -198,7 +198,7 @@ class ExclusionConstraint(BaseConstraint):
lookups.append(lookup) lookups.append(lookup)
queryset = queryset.filter(*lookups) queryset = queryset.filter(*lookups)
model_class_pk = instance._get_pk_val(model._meta) model_class_pk = instance._get_pk_val(model._meta)
if not instance._state.adding and model_class_pk is not None: if not instance._state.adding and instance._is_pk_set(model._meta):
queryset = queryset.exclude(pk=model_class_pk) queryset = queryset.exclude(pk=model_class_pk)
if not self.condition: if not self.condition:
if queryset.exists(): if queryset.exists():

View File

@ -601,7 +601,7 @@ class Model(AltersData, metaclass=ModelBase):
return my_pk == other.pk return my_pk == other.pk
def __hash__(self): def __hash__(self):
if self.pk is None: if not self._is_pk_set():
raise TypeError("Model instances without primary key value are unhashable") raise TypeError("Model instances without primary key value are unhashable")
return hash(self.pk) return hash(self.pk)
@ -662,6 +662,9 @@ class Model(AltersData, metaclass=ModelBase):
pk = property(_get_pk_val, _set_pk_val) pk = property(_get_pk_val, _set_pk_val)
def _is_pk_set(self, meta=None):
return self._get_pk_val(meta) is not None
def get_deferred_fields(self): def get_deferred_fields(self):
""" """
Return a set containing names of deferred fields for this instance. Return a set containing names of deferred fields for this instance.
@ -1094,11 +1097,10 @@ class Model(AltersData, metaclass=ModelBase):
if f.name in update_fields or f.attname in update_fields if f.name in update_fields or f.attname in update_fields
] ]
pk_val = self._get_pk_val(meta) if not self._is_pk_set(meta):
if pk_val is None:
pk_val = meta.pk.get_pk_value_on_save(self) pk_val = meta.pk.get_pk_value_on_save(self)
setattr(self, meta.pk.attname, pk_val) setattr(self, meta.pk.attname, pk_val)
pk_set = pk_val is not None pk_set = self._is_pk_set(meta)
if not pk_set and (force_update or update_fields): if not pk_set and (force_update or update_fields):
raise ValueError("Cannot force an update in save() with no primary key.") raise ValueError("Cannot force an update in save() with no primary key.")
updated = False updated = False
@ -1126,6 +1128,7 @@ class Model(AltersData, metaclass=ModelBase):
for f in non_pks_non_generated for f in non_pks_non_generated
] ]
forced_update = update_fields or force_update forced_update = update_fields or force_update
pk_val = self._get_pk_val(meta)
updated = self._do_update( updated = self._do_update(
base_qs, using, pk_val, values, update_fields, forced_update base_qs, using, pk_val, values, update_fields, forced_update
) )
@ -1226,7 +1229,7 @@ class Model(AltersData, metaclass=ModelBase):
# database to raise an IntegrityError if applicable. If # database to raise an IntegrityError if applicable. If
# constraints aren't supported by the database, there's the # constraints aren't supported by the database, there's the
# unavoidable risk of data corruption. # unavoidable risk of data corruption.
if obj.pk is None: if not obj._is_pk_set():
# Remove the object from a related instance cache. # Remove the object from a related instance cache.
if not field.remote_field.multiple: if not field.remote_field.multiple:
field.remote_field.delete_cached_value(obj) field.remote_field.delete_cached_value(obj)
@ -1254,14 +1257,14 @@ class Model(AltersData, metaclass=ModelBase):
and hasattr(field, "fk_field") and hasattr(field, "fk_field")
): ):
obj = field.get_cached_value(self, default=None) obj = field.get_cached_value(self, default=None)
if obj and obj.pk is None: if obj and not obj._is_pk_set():
raise ValueError( raise ValueError(
f"{operation_name}() prohibited to prevent data loss due to " f"{operation_name}() prohibited to prevent data loss due to "
f"unsaved related object '{field.name}'." f"unsaved related object '{field.name}'."
) )
def delete(self, using=None, keep_parents=False): def delete(self, using=None, keep_parents=False):
if self.pk is None: if not self._is_pk_set():
raise ValueError( raise ValueError(
"%s object can't be deleted because its %s attribute is set " "%s object can't be deleted because its %s attribute is set "
"to None." % (self._meta.object_name, self._meta.pk.attname) "to None." % (self._meta.object_name, self._meta.pk.attname)
@ -1367,7 +1370,7 @@ class Model(AltersData, metaclass=ModelBase):
return field_map return field_map
def prepare_database_save(self, field): def prepare_database_save(self, field):
if self.pk is None: if not self._is_pk_set():
raise ValueError( raise ValueError(
"Unsaved model instance %r cannot be used in an ORM query." % self "Unsaved model instance %r cannot be used in an ORM query." % self
) )
@ -1497,7 +1500,7 @@ class Model(AltersData, metaclass=ModelBase):
# allows single model to have effectively multiple primary keys. # allows single model to have effectively multiple primary keys.
# Refs #17615. # Refs #17615.
model_class_pk = self._get_pk_val(model_class._meta) model_class_pk = self._get_pk_val(model_class._meta)
if not self._state.adding and model_class_pk is not None: if not self._state.adding and self._is_pk_set(model_class._meta):
qs = qs.exclude(pk=model_class_pk) qs = qs.exclude(pk=model_class_pk)
if qs.exists(): if qs.exists():
if len(unique_check) == 1: if len(unique_check) == 1:
@ -1532,7 +1535,7 @@ class Model(AltersData, metaclass=ModelBase):
qs = model_class._default_manager.filter(**lookup_kwargs) qs = model_class._default_manager.filter(**lookup_kwargs)
# Exclude the current object from the query if we are editing an # Exclude the current object from the query if we are editing an
# instance (as opposed to creating a new one) # instance (as opposed to creating a new one)
if not self._state.adding and self.pk is not None: if not self._state.adding and self._is_pk_set():
qs = qs.exclude(pk=self.pk) qs = qs.exclude(pk=self.pk)
if qs.exists(): if qs.exists():

View File

@ -686,7 +686,7 @@ class UniqueConstraint(BaseConstraint):
filters.append(condition) filters.append(condition)
queryset = queryset.filter(*filters) queryset = queryset.filter(*filters)
model_class_pk = instance._get_pk_val(model._meta) model_class_pk = instance._get_pk_val(model._meta)
if not instance._state.adding and model_class_pk is not None: if not instance._state.adding and instance._is_pk_set(model._meta):
queryset = queryset.exclude(pk=model_class_pk) queryset = queryset.exclude(pk=model_class_pk)
if not self.condition: if not self.condition:
if queryset.exists(): if queryset.exists():

View File

@ -1969,7 +1969,7 @@ class ManyToManyField(RelatedField):
pass pass
def value_from_object(self, obj): def value_from_object(self, obj):
return [] if obj.pk is None else list(getattr(obj, self.attname).all()) return list(getattr(obj, self.attname).all()) if obj._is_pk_set() else []
def save_form_data(self, instance, data): def save_form_data(self, instance, data):
getattr(instance, self.attname).set(data) getattr(instance, self.attname).set(data)

View File

@ -511,8 +511,7 @@ class ReverseOneToOneDescriptor:
try: try:
rel_obj = self.related.get_cached_value(instance) rel_obj = self.related.get_cached_value(instance)
except KeyError: except KeyError:
related_pk = instance.pk if not instance._is_pk_set():
if related_pk is None:
rel_obj = None rel_obj = None
else: else:
filter_args = self.related.field.get_forward_related_filter(instance) filter_args = self.related.field.get_forward_related_filter(instance)
@ -753,7 +752,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
# Even if this relation is not to pk, we require still pk value. # Even if this relation is not to pk, we require still pk value.
# The wish is that the instance has been already saved to DB, # The wish is that the instance has been already saved to DB,
# although having a pk value isn't a guarantee of that. # although having a pk value isn't a guarantee of that.
if self.instance.pk is None: if not self.instance._is_pk_set():
raise ValueError( raise ValueError(
f"{self.instance.__class__.__name__!r} instance needs to have a " f"{self.instance.__class__.__name__!r} instance needs to have a "
f"primary key value before this relationship can be used." f"primary key value before this relationship can be used."
@ -1081,7 +1080,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
# Even if this relation is not to pk, we require still pk value. # Even if this relation is not to pk, we require still pk value.
# The wish is that the instance has been already saved to DB, # The wish is that the instance has been already saved to DB,
# although having a pk value isn't a guarantee of that. # although having a pk value isn't a guarantee of that.
if instance.pk is None: if not instance._is_pk_set():
raise ValueError( raise ValueError(
"%r instance needs to have a primary key value before " "%r instance needs to have a primary key value before "
"a many-to-many relationship can be used." "a many-to-many relationship can be used."

View File

@ -16,7 +16,7 @@ def get_normalized_value(value, lhs):
from django.db.models import Model from django.db.models import Model
if isinstance(value, Model): if isinstance(value, Model):
if value.pk is None: if not value._is_pk_set():
raise ValueError("Model instances passed to related filters must be saved.") raise ValueError("Model instances passed to related filters must be saved.")
value_list = [] value_list = []
sources = lhs.output_field.path_infos[-1].target_fields sources = lhs.output_field.path_infos[-1].target_fields

View File

@ -668,7 +668,7 @@ class QuerySet(AltersData):
connection = connections[self.db] connection = connections[self.db]
for obj in objs: for obj in objs:
if obj.pk is None: if not obj._is_pk_set():
# Populate new PK values. # Populate new PK values.
obj.pk = obj._meta.pk.get_pk_value_on_save(obj) obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
if not connection.features.supports_default_keyword_in_bulk_insert: if not connection.features.supports_default_keyword_in_bulk_insert:
@ -794,7 +794,7 @@ class QuerySet(AltersData):
objs = list(objs) objs = list(objs)
self._prepare_for_bulk_create(objs) self._prepare_for_bulk_create(objs)
with transaction.atomic(using=self.db, savepoint=False): with transaction.atomic(using=self.db, savepoint=False):
objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs) objs_without_pk, objs_with_pk = partition(lambda o: o._is_pk_set(), objs)
if objs_with_pk: if objs_with_pk:
returned_columns = self._batched_insert( returned_columns = self._batched_insert(
objs_with_pk, objs_with_pk,
@ -862,7 +862,7 @@ class QuerySet(AltersData):
if not fields: if not fields:
raise ValueError("Field names must be given to bulk_update().") raise ValueError("Field names must be given to bulk_update().")
objs = tuple(objs) objs = tuple(objs)
if any(obj.pk is None for obj in objs): if not all(obj._is_pk_set() for obj in objs):
raise ValueError("All bulk_update() objects must have a primary key set.") raise ValueError("All bulk_update() objects must have a primary key set.")
fields = [self.model._meta.get_field(name) for name in fields] fields = [self.model._meta.get_field(name) for name in fields]
if any(not f.concrete or f.many_to_many for f in fields): if any(not f.concrete or f.many_to_many for f in fields):
@ -1289,7 +1289,7 @@ class QuerySet(AltersData):
return False return False
except AttributeError: except AttributeError:
raise TypeError("'obj' must be a model instance.") raise TypeError("'obj' must be a model instance.")
if obj.pk is None: if not obj._is_pk_set():
raise ValueError("QuerySet.contains() cannot be used on unsaved objects.") raise ValueError("QuerySet.contains() cannot be used on unsaved objects.")
if self._result_cache is not None: if self._result_cache is not None:
return obj in self._result_cache return obj in self._result_cache

View File

@ -220,7 +220,7 @@ class DeferredAttribute:
# might be able to reuse the already loaded value. Refs #18343. # might be able to reuse the already loaded value. Refs #18343.
val = self._check_parent_chain(instance) val = self._check_parent_chain(instance)
if val is None: if val is None:
if instance.pk is None and self.field.generated: if not instance._is_pk_set() and self.field.generated:
raise AttributeError( raise AttributeError(
"Cannot read a generated field from an unsaved model." "Cannot read a generated field from an unsaved model."
) )

View File

@ -935,7 +935,7 @@ class BaseModelFormSet(BaseFormSet, AltersData):
# 1. The object is an unexpected empty model, created by invalid # 1. The object is an unexpected empty model, created by invalid
# POST data such as an object outside the formset's queryset. # POST data such as an object outside the formset's queryset.
# 2. The object was already deleted from the database. # 2. The object was already deleted from the database.
if obj.pk is None: if not obj._is_pk_set():
continue continue
if form in forms_to_delete: if form in forms_to_delete:
self.deleted_objects.append(obj) self.deleted_objects.append(obj)
@ -1103,7 +1103,7 @@ class BaseInlineFormSet(BaseModelFormSet):
self.save_as_new = save_as_new self.save_as_new = save_as_new
if queryset is None: if queryset is None:
queryset = self.model._default_manager queryset = self.model._default_manager
if self.instance.pk is not None: if self.instance._is_pk_set():
qs = queryset.filter(**{self.fk.name: self.instance}) qs = queryset.filter(**{self.fk.name: self.instance})
else: else:
qs = queryset.none() qs = queryset.none()

View File

@ -973,3 +973,14 @@ Other attributes
since they are yet to be saved. Instances fetched from a ``QuerySet`` since they are yet to be saved. Instances fetched from a ``QuerySet``
will have ``adding=False`` and ``db`` set to the alias of the associated will have ``adding=False`` and ``db`` set to the alias of the associated
database. database.
``_is_pk_set()``
----------------
.. method:: Model._is_pk_set()
.. versionadded:: 5.2
The ``_is_pk_set()`` method returns whether the model instance's ``pk`` is set.
It abstracts the model's primary key definition, ensuring consistent behavior
regardless of the specific ``pk`` configuration.

View File

@ -289,7 +289,9 @@ Database backend API
This section describes changes that may be needed in third-party database This section describes changes that may be needed in third-party database
backends. backends.
* ... * The new :meth:`Model._is_pk_set() <django.db.models.Model._is_pk_set>` method
allows checking if a Model instance's primary key is defined.
:mod:`django.contrib.gis` :mod:`django.contrib.gis`
------------------------- -------------------------

View File

@ -661,6 +661,31 @@ class ModelTest(TestCase):
headline__startswith="Area", headline__startswith="Area",
) )
def test_is_pk_unset(self):
cases = [
Article(),
Article(id=None),
]
for case in cases:
with self.subTest(case=case):
self.assertIs(case._is_pk_set(), False)
def test_is_pk_set(self):
def new_instance():
a = Article(pub_date=datetime.today())
a.save()
return a
cases = [
Article(id=1),
Article(id=0),
Article.objects.create(pub_date=datetime.today()),
new_instance(),
]
for case in cases:
with self.subTest(case=case):
self.assertIs(case._is_pk_set(), True)
class ModelLookupTest(TestCase): class ModelLookupTest(TestCase):
@classmethod @classmethod