1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Refs #28010 -- Allowed reverse related fields in SELECT FOR UPDATE .. OF.

Thanks Adam Chidlow for polishing the patch.
This commit is contained in:
Ran Benita 2017-10-17 11:28:00 +08:00 committed by Tim Graham
parent 56b364bacc
commit 03049fb8d9
5 changed files with 61 additions and 6 deletions

View File

@ -783,6 +783,7 @@ class SQLCompiler:
klass_info = { klass_info = {
'model': f.remote_field.model, 'model': f.remote_field.model,
'field': f, 'field': f,
'reverse': False,
'local_setter': f.set_cached_value, 'local_setter': f.set_cached_value,
'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None, 'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None,
'from_parent': False, 'from_parent': False,
@ -821,6 +822,7 @@ class SQLCompiler:
klass_info = { klass_info = {
'model': model, 'model': model,
'field': f, 'field': f,
'reverse': True,
'local_setter': f.remote_field.set_cached_value, 'local_setter': f.remote_field.set_cached_value,
'remote_setter': f.set_cached_value, 'remote_setter': f.set_cached_value,
'from_parent': from_parent, 'from_parent': from_parent,
@ -858,6 +860,7 @@ class SQLCompiler:
klass_info = { klass_info = {
'model': model, 'model': model,
'field': f, 'field': f,
'reverse': True,
'local_setter': local_setter, 'local_setter': local_setter,
'remote_setter': remote_setter, 'remote_setter': remote_setter,
'from_parent': from_parent, 'from_parent': from_parent,
@ -905,7 +908,10 @@ class SQLCompiler:
path = [] path = []
yield 'self' yield 'self'
else: else:
path = parent_path + [klass_info['field'].name] field = klass_info['field']
if klass_info['reverse']:
field = field.remote_field
path = parent_path + [field.name]
yield LOOKUP_SEP.join(path) yield LOOKUP_SEP.join(path)
queue.extend( queue.extend(
(path, klass_info) (path, klass_info)
@ -918,7 +924,10 @@ class SQLCompiler:
klass_info = self.klass_info klass_info = self.klass_info
for part in parts: for part in parts:
for related_klass_info in klass_info.get('related_klass_infos', []): for related_klass_info in klass_info.get('related_klass_infos', []):
if related_klass_info['field'].name == part: field = related_klass_info['field']
if related_klass_info['reverse']:
field = field.remote_field
if field.name == part:
klass_info = related_klass_info klass_info = related_klass_info
break break
else: else:

View File

@ -1628,6 +1628,19 @@ specify the related objects you want to lock in ``select_for_update(of=(...))``
using the same fields syntax as :meth:`select_related`. Use the value ``'self'`` using the same fields syntax as :meth:`select_related`. Use the value ``'self'``
to refer to the queryset's model. to refer to the queryset's model.
You can't use ``select_for_update()`` on nullable relations::
>>> Person.objects.select_related('hometown').select_for_update()
Traceback (most recent call last):
...
django.db.utils.NotSupportedError: FOR UPDATE cannot be applied to the nullable side of an outer join
To avoid that restriction, you can exclude null objects if you don't care about
them::
>>> Person.objects.select_related('hometown').select_for_update().exclude(hometown=None)
<QuerySet [<Person: ...)>, ...]>
Currently, the ``postgresql``, ``oracle``, and ``mysql`` database Currently, the ``postgresql``, ``oracle``, and ``mysql`` database
backends support ``select_for_update()``. However, MySQL doesn't support the backends support ``select_for_update()``. However, MySQL doesn't support the
``nowait``, ``skip_locked``, and ``of`` arguments. ``nowait``, ``skip_locked``, and ``of`` arguments.

View File

@ -1,4 +1,4 @@
from django.db import connection from django.db import connection, transaction
from django.db.models import Case, Count, F, FilteredRelation, Q, When from django.db.models import Case, Count, F, FilteredRelation, Q, When
from django.test import TestCase from django.test import TestCase
from django.test.testcases import skipUnlessDBFeature from django.test.testcases import skipUnlessDBFeature
@ -62,6 +62,20 @@ class FilteredRelationTests(TestCase):
(self.book4, self.author1), (self.book4, self.author1),
], lambda x: (x, x.author_join)) ], lambda x: (x, x.author_join))
@skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
def test_select_related_foreign_key_for_update_of(self):
with transaction.atomic():
qs = Book.objects.annotate(
author_join=FilteredRelation('author'),
).select_related('author_join').select_for_update(of=('self',)).order_by('pk')
with self.assertNumQueries(1):
self.assertQuerysetEqual(qs, [
(self.book1, self.author1),
(self.book2, self.author2),
(self.book3, self.author2),
(self.book4, self.author1),
], lambda x: (x, x.author_join))
def test_without_join(self): def test_without_join(self):
self.assertSequenceEqual( self.assertSequenceEqual(
Author.objects.annotate( Author.objects.annotate(

View File

@ -14,3 +14,7 @@ class Person(models.Model):
name = models.CharField(max_length=30) name = models.CharField(max_length=30)
born = models.ForeignKey(City, models.CASCADE, related_name='+') born = models.ForeignKey(City, models.CASCADE, related_name='+')
died = models.ForeignKey(City, models.CASCADE, related_name='+') died = models.ForeignKey(City, models.CASCADE, related_name='+')
class PersonProfile(models.Model):
person = models.OneToOneField(Person, models.CASCADE, related_name='profile')

View File

@ -15,7 +15,7 @@ from django.test import (
) )
from django.test.utils import CaptureQueriesContext from django.test.utils import CaptureQueriesContext
from .models import City, Country, Person from .models import City, Country, Person, PersonProfile
class SelectForUpdateTests(TransactionTestCase): class SelectForUpdateTests(TransactionTestCase):
@ -30,6 +30,7 @@ class SelectForUpdateTests(TransactionTestCase):
self.city1 = City.objects.create(name='Liberchies', country=self.country1) self.city1 = City.objects.create(name='Liberchies', country=self.country1)
self.city2 = City.objects.create(name='Samois-sur-Seine', country=self.country2) self.city2 = City.objects.create(name='Samois-sur-Seine', country=self.country2)
self.person = Person.objects.create(name='Reinhardt', born=self.city1, died=self.city2) self.person = Person.objects.create(name='Reinhardt', born=self.city1, died=self.city2)
self.person_profile = PersonProfile.objects.create(person=self.person)
# We need another database connection in transaction to test that one # We need another database connection in transaction to test that one
# connection issuing a SELECT ... FOR UPDATE will block. # connection issuing a SELECT ... FOR UPDATE will block.
@ -225,13 +226,27 @@ class SelectForUpdateTests(TransactionTestCase):
msg = ( msg = (
'Invalid field name(s) given in select_for_update(of=(...)): %s. ' 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
'Only relational fields followed in the query are allowed. ' 'Only relational fields followed in the query are allowed. '
'Choices are: self, born.' 'Choices are: self, born, profile.'
) )
for name in ['born__country', 'died', 'died__country']: for name in ['born__country', 'died', 'died__country']:
with self.subTest(name=name): with self.subTest(name=name):
with self.assertRaisesMessage(FieldError, msg % name): with self.assertRaisesMessage(FieldError, msg % name):
with transaction.atomic(): with transaction.atomic():
Person.objects.select_related('born').select_for_update(of=(name,)).get() Person.objects.select_related(
'born', 'profile',
).exclude(profile=None).select_for_update(of=(name,)).get()
@skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
def test_reverse_one_to_one_of_arguments(self):
"""
Reverse OneToOneFields may be included in of=(...) as long as NULLs
are excluded because LEFT JOIN isn't allowed in SELECT FOR UPDATE.
"""
with transaction.atomic():
person = Person.objects.select_related(
'profile',
).exclude(profile=None).select_for_update(of=('profile',)).get()
self.assertEqual(person.profile, self.person_profile)
@skipUnlessDBFeature('has_select_for_update') @skipUnlessDBFeature('has_select_for_update')
def test_for_update_after_from(self): def test_for_update_after_from(self):