1
0
mirror of https://github.com/django/django.git synced 2025-10-31 09:41:08 +00:00

Fixed #30953 -- Made select_for_update() lock queryset's model when using "self" with multi-table inheritance.

Thanks Abhijeet Viswa for the report and initial patch.
This commit is contained in:
Mariusz Felisiak
2019-12-02 07:57:19 +01:00
committed by GitHub
parent c33eb6dcd0
commit 0107e3d105
5 changed files with 135 additions and 19 deletions

View File

@@ -953,6 +953,21 @@ class SQLCompiler:
Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
the query.
"""
def _get_parent_klass_info(klass_info):
return (
{
'model': parent_model,
'field': parent_link,
'reverse': False,
'select_fields': [
select_index
for select_index in klass_info['select_fields']
if self.select[select_index][0].target.model == parent_model
],
}
for parent_model, parent_link in klass_info['model']._meta.parents.items()
)
def _get_field_choices():
"""Yield all allowed field paths in breadth-first search order."""
queue = collections.deque([(None, self.klass_info)])
@@ -967,6 +982,10 @@ class SQLCompiler:
field = field.remote_field
path = parent_path + [field.name]
yield LOOKUP_SEP.join(path)
queue.extend(
(path, klass_info)
for klass_info in _get_parent_klass_info(klass_info)
)
queue.extend(
(path, klass_info)
for klass_info in klass_info.get('related_klass_infos', [])
@@ -974,28 +993,42 @@ class SQLCompiler:
result = []
invalid_names = []
for name in self.query.select_for_update_of:
parts = [] if name == 'self' else name.split(LOOKUP_SEP)
klass_info = self.klass_info
for part in parts:
for related_klass_info in klass_info.get('related_klass_infos', []):
field = related_klass_info['field']
if related_klass_info['reverse']:
field = field.remote_field
if field.name == part:
klass_info = related_klass_info
if name == 'self':
# Find the first selected column from a base model. If it
# doesn't exist, don't lock a base model.
for select_index in klass_info['select_fields']:
if self.select[select_index][0].target.model == klass_info['model']:
col = self.select[select_index][0]
break
else:
klass_info = None
break
if klass_info is None:
invalid_names.append(name)
continue
select_index = klass_info['select_fields'][0]
col = self.select[select_index][0]
if self.connection.features.select_for_update_of_column:
result.append(self.compile(col)[0])
col = None
else:
result.append(self.quote_name_unless_alias(col.alias))
for part in name.split(LOOKUP_SEP):
klass_infos = (
*klass_info.get('related_klass_infos', []),
*_get_parent_klass_info(klass_info),
)
for related_klass_info in klass_infos:
field = related_klass_info['field']
if related_klass_info['reverse']:
field = field.remote_field
if field.name == part:
klass_info = related_klass_info
break
else:
klass_info = None
break
if klass_info is None:
invalid_names.append(name)
continue
select_index = klass_info['select_fields'][0]
col = self.select[select_index][0]
if col is not None:
if self.connection.features.select_for_update_of_column:
result.append(self.compile(col)[0])
else:
result.append(self.quote_name_unless_alias(col.alias))
if invalid_names:
raise FieldError(
'Invalid field name(s) given in select_for_update(of=(...)): %s. '