mirror of
https://github.com/django/django.git
synced 2024-12-22 17:16:24 +00:00
Fixed #30953 -- Made select_for_update() lock queryset's model when using "self" with multi-table inheritance.
Thanks Abhijeet Viswa for the report and initial patch.
This commit is contained in:
parent
c33eb6dcd0
commit
0107e3d105
@ -953,6 +953,21 @@ class SQLCompiler:
|
|||||||
Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
|
Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
|
||||||
the query.
|
the query.
|
||||||
"""
|
"""
|
||||||
|
def _get_parent_klass_info(klass_info):
|
||||||
|
return (
|
||||||
|
{
|
||||||
|
'model': parent_model,
|
||||||
|
'field': parent_link,
|
||||||
|
'reverse': False,
|
||||||
|
'select_fields': [
|
||||||
|
select_index
|
||||||
|
for select_index in klass_info['select_fields']
|
||||||
|
if self.select[select_index][0].target.model == parent_model
|
||||||
|
],
|
||||||
|
}
|
||||||
|
for parent_model, parent_link in klass_info['model']._meta.parents.items()
|
||||||
|
)
|
||||||
|
|
||||||
def _get_field_choices():
|
def _get_field_choices():
|
||||||
"""Yield all allowed field paths in breadth-first search order."""
|
"""Yield all allowed field paths in breadth-first search order."""
|
||||||
queue = collections.deque([(None, self.klass_info)])
|
queue = collections.deque([(None, self.klass_info)])
|
||||||
@ -967,6 +982,10 @@ class SQLCompiler:
|
|||||||
field = field.remote_field
|
field = field.remote_field
|
||||||
path = parent_path + [field.name]
|
path = parent_path + [field.name]
|
||||||
yield LOOKUP_SEP.join(path)
|
yield LOOKUP_SEP.join(path)
|
||||||
|
queue.extend(
|
||||||
|
(path, klass_info)
|
||||||
|
for klass_info in _get_parent_klass_info(klass_info)
|
||||||
|
)
|
||||||
queue.extend(
|
queue.extend(
|
||||||
(path, klass_info)
|
(path, klass_info)
|
||||||
for klass_info in klass_info.get('related_klass_infos', [])
|
for klass_info in klass_info.get('related_klass_infos', [])
|
||||||
@ -974,28 +993,42 @@ class SQLCompiler:
|
|||||||
result = []
|
result = []
|
||||||
invalid_names = []
|
invalid_names = []
|
||||||
for name in self.query.select_for_update_of:
|
for name in self.query.select_for_update_of:
|
||||||
parts = [] if name == 'self' else name.split(LOOKUP_SEP)
|
|
||||||
klass_info = self.klass_info
|
klass_info = self.klass_info
|
||||||
for part in parts:
|
if name == 'self':
|
||||||
for related_klass_info in klass_info.get('related_klass_infos', []):
|
# Find the first selected column from a base model. If it
|
||||||
field = related_klass_info['field']
|
# doesn't exist, don't lock a base model.
|
||||||
if related_klass_info['reverse']:
|
for select_index in klass_info['select_fields']:
|
||||||
field = field.remote_field
|
if self.select[select_index][0].target.model == klass_info['model']:
|
||||||
if field.name == part:
|
col = self.select[select_index][0]
|
||||||
klass_info = related_klass_info
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
klass_info = None
|
col = None
|
||||||
break
|
|
||||||
if klass_info is None:
|
|
||||||
invalid_names.append(name)
|
|
||||||
continue
|
|
||||||
select_index = klass_info['select_fields'][0]
|
|
||||||
col = self.select[select_index][0]
|
|
||||||
if self.connection.features.select_for_update_of_column:
|
|
||||||
result.append(self.compile(col)[0])
|
|
||||||
else:
|
else:
|
||||||
result.append(self.quote_name_unless_alias(col.alias))
|
for part in name.split(LOOKUP_SEP):
|
||||||
|
klass_infos = (
|
||||||
|
*klass_info.get('related_klass_infos', []),
|
||||||
|
*_get_parent_klass_info(klass_info),
|
||||||
|
)
|
||||||
|
for related_klass_info in klass_infos:
|
||||||
|
field = related_klass_info['field']
|
||||||
|
if related_klass_info['reverse']:
|
||||||
|
field = field.remote_field
|
||||||
|
if field.name == part:
|
||||||
|
klass_info = related_klass_info
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
klass_info = None
|
||||||
|
break
|
||||||
|
if klass_info is None:
|
||||||
|
invalid_names.append(name)
|
||||||
|
continue
|
||||||
|
select_index = klass_info['select_fields'][0]
|
||||||
|
col = self.select[select_index][0]
|
||||||
|
if col is not None:
|
||||||
|
if self.connection.features.select_for_update_of_column:
|
||||||
|
result.append(self.compile(col)[0])
|
||||||
|
else:
|
||||||
|
result.append(self.quote_name_unless_alias(col.alias))
|
||||||
if invalid_names:
|
if invalid_names:
|
||||||
raise FieldError(
|
raise FieldError(
|
||||||
'Invalid field name(s) given in select_for_update(of=(...)): %s. '
|
'Invalid field name(s) given in select_for_update(of=(...)): %s. '
|
||||||
|
@ -1692,6 +1692,14 @@ specify the related objects you want to lock in ``select_for_update(of=(...))``
|
|||||||
using the same fields syntax as :meth:`select_related`. Use the value ``'self'``
|
using the same fields syntax as :meth:`select_related`. Use the value ``'self'``
|
||||||
to refer to the queryset's model.
|
to refer to the queryset's model.
|
||||||
|
|
||||||
|
.. admonition:: Lock parents models in ``select_for_update(of=(...))``
|
||||||
|
|
||||||
|
If you want to lock parents models when using :ref:`multi-table inheritance
|
||||||
|
<multi-table-inheritance>`, you must specify parent link fields (by default
|
||||||
|
``<parent_model_name>_ptr``) in the ``of`` argument. For example::
|
||||||
|
|
||||||
|
Restaurant.objects.select_for_update(of=('self', 'place_ptr'))
|
||||||
|
|
||||||
You can't use ``select_for_update()`` on nullable relations::
|
You can't use ``select_for_update()`` on nullable relations::
|
||||||
|
|
||||||
>>> Person.objects.select_related('hometown').select_for_update()
|
>>> Person.objects.select_related('hometown').select_for_update()
|
||||||
|
@ -17,3 +17,9 @@ Bugfixes
|
|||||||
* Fixed a regression in Django 2.2.1 that caused a crash when migrating
|
* Fixed a regression in Django 2.2.1 that caused a crash when migrating
|
||||||
permissions for proxy models with a multiple database setup if the
|
permissions for proxy models with a multiple database setup if the
|
||||||
``default`` entry was empty (:ticket:`31021`).
|
``default`` entry was empty (:ticket:`31021`).
|
||||||
|
|
||||||
|
* Fixed a data loss possibility in the
|
||||||
|
:meth:`~django.db.models.query.QuerySet.select_for_update()`. When using
|
||||||
|
``'self'`` in the ``of`` argument with :ref:`multi-table inheritance
|
||||||
|
<multi-table-inheritance>`, a parent model was locked instead of the
|
||||||
|
queryset's model (:ticket:`30953`).
|
||||||
|
@ -5,11 +5,20 @@ class Country(models.Model):
|
|||||||
name = models.CharField(max_length=30)
|
name = models.CharField(max_length=30)
|
||||||
|
|
||||||
|
|
||||||
|
class EUCountry(Country):
|
||||||
|
join_date = models.DateField()
|
||||||
|
|
||||||
|
|
||||||
class City(models.Model):
|
class City(models.Model):
|
||||||
name = models.CharField(max_length=30)
|
name = models.CharField(max_length=30)
|
||||||
country = models.ForeignKey(Country, models.CASCADE)
|
country = models.ForeignKey(Country, models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class EUCity(models.Model):
|
||||||
|
name = models.CharField(max_length=30)
|
||||||
|
country = models.ForeignKey(EUCountry, models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
class Person(models.Model):
|
class Person(models.Model):
|
||||||
name = models.CharField(max_length=30)
|
name = models.CharField(max_length=30)
|
||||||
born = models.ForeignKey(City, models.CASCADE, related_name='+')
|
born = models.ForeignKey(City, models.CASCADE, related_name='+')
|
||||||
|
@ -15,7 +15,7 @@ from django.test import (
|
|||||||
)
|
)
|
||||||
from django.test.utils import CaptureQueriesContext
|
from django.test.utils import CaptureQueriesContext
|
||||||
|
|
||||||
from .models import City, Country, Person, PersonProfile
|
from .models import City, Country, EUCity, EUCountry, Person, PersonProfile
|
||||||
|
|
||||||
|
|
||||||
class SelectForUpdateTests(TransactionTestCase):
|
class SelectForUpdateTests(TransactionTestCase):
|
||||||
@ -119,6 +119,47 @@ class SelectForUpdateTests(TransactionTestCase):
|
|||||||
expected = [connection.ops.quote_name(value) for value in expected]
|
expected = [connection.ops.quote_name(value) for value in expected]
|
||||||
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
|
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('has_select_for_update_of')
|
||||||
|
def test_for_update_sql_model_inheritance_generated_of(self):
|
||||||
|
with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
|
||||||
|
list(EUCountry.objects.select_for_update(of=('self',)))
|
||||||
|
if connection.features.select_for_update_of_column:
|
||||||
|
expected = ['select_for_update_eucountry"."country_ptr_id']
|
||||||
|
else:
|
||||||
|
expected = ['select_for_update_eucountry']
|
||||||
|
expected = [connection.ops.quote_name(value) for value in expected]
|
||||||
|
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('has_select_for_update_of')
|
||||||
|
def test_for_update_sql_model_inheritance_ptr_generated_of(self):
|
||||||
|
with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
|
||||||
|
list(EUCountry.objects.select_for_update(of=('self', 'country_ptr',)))
|
||||||
|
if connection.features.select_for_update_of_column:
|
||||||
|
expected = [
|
||||||
|
'select_for_update_eucountry"."country_ptr_id',
|
||||||
|
'select_for_update_country"."id',
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
expected = ['select_for_update_eucountry', 'select_for_update_country']
|
||||||
|
expected = [connection.ops.quote_name(value) for value in expected]
|
||||||
|
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('has_select_for_update_of')
|
||||||
|
def test_for_update_sql_model_inheritance_nested_ptr_generated_of(self):
|
||||||
|
with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
|
||||||
|
list(EUCity.objects.select_related('country').select_for_update(
|
||||||
|
of=('self', 'country__country_ptr',),
|
||||||
|
))
|
||||||
|
if connection.features.select_for_update_of_column:
|
||||||
|
expected = [
|
||||||
|
'select_for_update_eucity"."id',
|
||||||
|
'select_for_update_country"."id',
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
expected = ['select_for_update_eucity', 'select_for_update_country']
|
||||||
|
expected = [connection.ops.quote_name(value) for value in expected]
|
||||||
|
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
|
||||||
|
|
||||||
@skipUnlessDBFeature('has_select_for_update_of')
|
@skipUnlessDBFeature('has_select_for_update_of')
|
||||||
def test_for_update_of_followed_by_values(self):
|
def test_for_update_of_followed_by_values(self):
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
@ -257,6 +298,25 @@ class SelectForUpdateTests(TransactionTestCase):
|
|||||||
'born', 'profile',
|
'born', 'profile',
|
||||||
).exclude(profile=None).select_for_update(of=(name,)).get()
|
).exclude(profile=None).select_for_update(of=(name,)).get()
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
|
||||||
|
def test_model_inheritance_of_argument_raises_error_ptr_in_choices(self):
|
||||||
|
msg = (
|
||||||
|
'Invalid field name(s) given in select_for_update(of=(...)): '
|
||||||
|
'name. Only relational fields followed in the query are allowed. '
|
||||||
|
'Choices are: self, %s.'
|
||||||
|
)
|
||||||
|
with self.assertRaisesMessage(
|
||||||
|
FieldError,
|
||||||
|
msg % 'country, country__country_ptr',
|
||||||
|
):
|
||||||
|
with transaction.atomic():
|
||||||
|
EUCity.objects.select_related(
|
||||||
|
'country',
|
||||||
|
).select_for_update(of=('name',)).get()
|
||||||
|
with self.assertRaisesMessage(FieldError, msg % 'country_ptr'):
|
||||||
|
with transaction.atomic():
|
||||||
|
EUCountry.objects.select_for_update(of=('name',)).get()
|
||||||
|
|
||||||
@skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
|
@skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
|
||||||
def test_reverse_one_to_one_of_arguments(self):
|
def test_reverse_one_to_one_of_arguments(self):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user