mirror of
				https://github.com/django/django.git
				synced 2025-10-25 22:56:12 +00:00 
			
		
		
		
	Refs #28305 -- Consolidated field referencing detection in migrations.
This moves all the field referencing resolution methods to shared functions instead of duplicating efforts amongst state_forwards and references methods.
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							734fde7714
						
					
				
				
					commit
					f5ede1cb6d
				
			| @@ -3,9 +3,7 @@ from django.db.models import NOT_PROVIDED | |||||||
| from django.utils.functional import cached_property | from django.utils.functional import cached_property | ||||||
|  |  | ||||||
| from .base import Operation | from .base import Operation | ||||||
| from .utils import ( | from .utils import field_is_referenced, field_references, get_references | ||||||
|     field_references_model, is_referenced_by_foreign_key, resolve_relation, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class FieldOperation(Operation): | class FieldOperation(Operation): | ||||||
| @@ -33,9 +31,9 @@ class FieldOperation(Operation): | |||||||
|         if name_lower == self.model_name_lower: |         if name_lower == self.model_name_lower: | ||||||
|             return True |             return True | ||||||
|         if self.field: |         if self.field: | ||||||
|             return field_references_model( |             return bool(field_references( | ||||||
|                 (app_label, self.model_name_lower), self.field, (app_label, name_lower) |                 (app_label, self.model_name_lower), self.field, (app_label, name_lower) | ||||||
|             ) |             )) | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|     def references_field(self, model_name, name, app_label): |     def references_field(self, model_name, name, app_label): | ||||||
| @@ -47,20 +45,14 @@ class FieldOperation(Operation): | |||||||
|             elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields: |             elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields: | ||||||
|                 return True |                 return True | ||||||
|         # Check if this operation remotely references the field. |         # Check if this operation remotely references the field. | ||||||
|         if self.field: |         if self.field is None: | ||||||
|             model_tuple = (app_label, model_name_lower) |             return False | ||||||
|             remote_field = self.field.remote_field |         return bool(field_references( | ||||||
|             if remote_field: |             (app_label, self.model_name_lower), | ||||||
|                 if (resolve_relation(remote_field.model, app_label, self.model_name_lower) == model_tuple and |             self.field, | ||||||
|                         (not hasattr(self.field, 'to_fields') or |             (app_label, model_name_lower), | ||||||
|                             name in self.field.to_fields or None in self.field.to_fields)): |             name, | ||||||
|                     return True |         )) | ||||||
|                 through = getattr(remote_field, 'through', None) |  | ||||||
|                 if (through and resolve_relation(through, app_label, self.model_name_lower) == model_tuple and |  | ||||||
|                         (getattr(remote_field, 'through_fields', None) is None or |  | ||||||
|                             name in remote_field.through_fields)): |  | ||||||
|                     return True |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     def reduce(self, operation, app_label): |     def reduce(self, operation, app_label): | ||||||
|         return ( |         return ( | ||||||
| @@ -236,7 +228,9 @@ class AlterField(FieldOperation): | |||||||
|         # not referenced by a foreign key. |         # not referenced by a foreign key. | ||||||
|         delay = ( |         delay = ( | ||||||
|             not field.is_relation and |             not field.is_relation and | ||||||
|             not is_referenced_by_foreign_key(state, self.model_name_lower, self.field, self.name) |             not field_is_referenced( | ||||||
|  |                 state, (app_label, self.model_name_lower), (self.name, field), | ||||||
|  |             ) | ||||||
|         ) |         ) | ||||||
|         state.reload_model(app_label, self.model_name_lower, delay=delay) |         state.reload_model(app_label, self.model_name_lower, delay=delay) | ||||||
|  |  | ||||||
| @@ -305,17 +299,11 @@ class RenameField(FieldOperation): | |||||||
|         model_state = state.models[app_label, self.model_name_lower] |         model_state = state.models[app_label, self.model_name_lower] | ||||||
|         # Rename the field |         # Rename the field | ||||||
|         fields = model_state.fields |         fields = model_state.fields | ||||||
|         found = False |         found = None | ||||||
|         for index, (name, field) in enumerate(fields): |         for index, (name, field) in enumerate(fields): | ||||||
|             if not found and name == self.old_name: |             if not found and name == self.old_name: | ||||||
|                 fields[index] = (self.new_name, field) |                 fields[index] = (self.new_name, field) | ||||||
|                 found = True |                 found = field | ||||||
|                 # Delay rendering of relationships if it's not a relational |  | ||||||
|                 # field and not referenced by a foreign key. |  | ||||||
|                 delay = ( |  | ||||||
|                     not field.is_relation and |  | ||||||
|                     not is_referenced_by_foreign_key(state, self.model_name_lower, field, self.name) |  | ||||||
|                 ) |  | ||||||
|             # Fix from_fields to refer to the new field. |             # Fix from_fields to refer to the new field. | ||||||
|             from_fields = getattr(field, 'from_fields', None) |             from_fields = getattr(field, 'from_fields', None) | ||||||
|             if from_fields: |             if from_fields: | ||||||
| @@ -323,7 +311,7 @@ class RenameField(FieldOperation): | |||||||
|                     self.new_name if from_field_name == self.old_name else from_field_name |                     self.new_name if from_field_name == self.old_name else from_field_name | ||||||
|                     for from_field_name in from_fields |                     for from_field_name in from_fields | ||||||
|                 ]) |                 ]) | ||||||
|         if not found: |         if found is None: | ||||||
|             raise FieldDoesNotExist( |             raise FieldDoesNotExist( | ||||||
|                 "%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name) |                 "%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name) | ||||||
|             ) |             ) | ||||||
| @@ -336,23 +324,21 @@ class RenameField(FieldOperation): | |||||||
|                     for together in options[option] |                     for together in options[option] | ||||||
|                 ] |                 ] | ||||||
|         # Fix to_fields to refer to the new field. |         # Fix to_fields to refer to the new field. | ||||||
|         model_tuple = app_label, self.model_name_lower |         delay = True | ||||||
|         for (model_app_label, model_name), model_state in state.models.items(): |         references = get_references( | ||||||
|             for index, (name, field) in enumerate(model_state.fields): |             state, (app_label, self.model_name_lower), (self.old_name, found), | ||||||
|                 remote_field = field.remote_field |         ) | ||||||
|                 if remote_field: |         for *_, field, reference in references: | ||||||
|                     remote_model_tuple = resolve_relation( |             delay = False | ||||||
|                         remote_field.model, model_app_label, model_name |             if reference.to: | ||||||
|                     ) |                 remote_field, to_fields = reference.to | ||||||
|                     if remote_model_tuple == model_tuple: |                 if getattr(remote_field, 'field_name', None) == self.old_name: | ||||||
|                         if getattr(remote_field, 'field_name', None) == self.old_name: |                     remote_field.field_name = self.new_name | ||||||
|                             remote_field.field_name = self.new_name |                 if to_fields: | ||||||
|                         to_fields = getattr(field, 'to_fields', None) |                     field.to_fields = tuple([ | ||||||
|                         if to_fields: |                         self.new_name if to_field_name == self.old_name else to_field_name | ||||||
|                             field.to_fields = tuple([ |                         for to_field_name in to_fields | ||||||
|                                 self.new_name if to_field_name == self.old_name else to_field_name |                     ]) | ||||||
|                                 for to_field_name in to_fields |  | ||||||
|                             ]) |  | ||||||
|         state.reload_model(app_label, self.model_name_lower, delay=delay) |         state.reload_model(app_label, self.model_name_lower, delay=delay) | ||||||
|  |  | ||||||
|     def database_forwards(self, app_label, schema_editor, from_state, to_state): |     def database_forwards(self, app_label, schema_editor, from_state, to_state): | ||||||
|   | |||||||
| @@ -7,7 +7,7 @@ from django.utils.functional import cached_property | |||||||
| from .fields import ( | from .fields import ( | ||||||
|     AddField, AlterField, FieldOperation, RemoveField, RenameField, |     AddField, AlterField, FieldOperation, RemoveField, RenameField, | ||||||
| ) | ) | ||||||
| from .utils import field_references_model, resolve_relation | from .utils import field_references, get_references, resolve_relation | ||||||
|  |  | ||||||
|  |  | ||||||
| def _check_for_duplicates(arg_name, objs): | def _check_for_duplicates(arg_name, objs): | ||||||
| @@ -113,7 +113,7 @@ class CreateModel(ModelOperation): | |||||||
|  |  | ||||||
|         # Check we have no FKs/M2Ms with it |         # Check we have no FKs/M2Ms with it | ||||||
|         for _name, field in self.fields: |         for _name, field in self.fields: | ||||||
|             if field_references_model((app_label, self.name_lower), field, reference_model_tuple): |             if field_references((app_label, self.name_lower), field, reference_model_tuple): | ||||||
|                 return True |                 return True | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
| @@ -309,33 +309,19 @@ class RenameModel(ModelOperation): | |||||||
|         # Repoint all fields pointing to the old model to the new one. |         # Repoint all fields pointing to the old model to the new one. | ||||||
|         old_model_tuple = (app_label, self.old_name_lower) |         old_model_tuple = (app_label, self.old_name_lower) | ||||||
|         new_remote_model = '%s.%s' % (app_label, self.new_name) |         new_remote_model = '%s.%s' % (app_label, self.new_name) | ||||||
|         to_reload = [] |         to_reload = set() | ||||||
|         for (model_app_label, model_name), model_state in state.models.items(): |         for model_state, index, name, field, reference in get_references(state, old_model_tuple): | ||||||
|             model_changed = False |             changed_field = None | ||||||
|             for index, (name, field) in enumerate(model_state.fields): |             if reference.to: | ||||||
|                 changed_field = None |                 changed_field = field.clone() | ||||||
|                 remote_field = field.remote_field |                 changed_field.remote_field.model = new_remote_model | ||||||
|                 if remote_field: |             if reference.through: | ||||||
|                     remote_model_tuple = resolve_relation( |                 if changed_field is None: | ||||||
|                         remote_field.model, model_app_label, model_name |                     changed_field = field.clone() | ||||||
|                     ) |                 changed_field.remote_field.through = new_remote_model | ||||||
|                     if remote_model_tuple == old_model_tuple: |             if changed_field: | ||||||
|                         changed_field = field.clone() |                 model_state.fields[index] = name, changed_field | ||||||
|                         changed_field.remote_field.model = new_remote_model |                 to_reload.add((model_state.app_label, model_state.name_lower)) | ||||||
|                     through_model = getattr(remote_field, 'through', None) |  | ||||||
|                     if through_model: |  | ||||||
|                         through_model_tuple = resolve_relation( |  | ||||||
|                             through_model, model_app_label, model_name |  | ||||||
|                         ) |  | ||||||
|                         if through_model_tuple == old_model_tuple: |  | ||||||
|                             if changed_field is None: |  | ||||||
|                                 changed_field = field.clone() |  | ||||||
|                             changed_field.remote_field.through = new_remote_model |  | ||||||
|                 if changed_field: |  | ||||||
|                     model_state.fields[index] = name, changed_field |  | ||||||
|                     model_changed = True |  | ||||||
|             if model_changed: |  | ||||||
|                 to_reload.append((model_app_label, model_name)) |  | ||||||
|         # Reload models related to old model before removing the old model. |         # Reload models related to old model before removing the old model. | ||||||
|         state.reload_models(to_reload, delay=True) |         state.reload_models(to_reload, delay=True) | ||||||
|         # Remove the old model. |         # Remove the old model. | ||||||
|   | |||||||
| @@ -1,17 +1,8 @@ | |||||||
|  | from collections import namedtuple | ||||||
|  |  | ||||||
| from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT | from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT | ||||||
|  |  | ||||||
|  |  | ||||||
| def is_referenced_by_foreign_key(state, model_name_lower, field, field_name): |  | ||||||
|     for state_app_label, state_model in state.models: |  | ||||||
|         for _, f in state.models[state_app_label, state_model].fields: |  | ||||||
|             if (f.related_model and |  | ||||||
|                     '%s.%s' % (state_app_label, model_name_lower) == f.related_model.lower() and |  | ||||||
|                     hasattr(f, 'to_fields')): |  | ||||||
|                 if (f.to_fields[0] is None and field.primary_key) or field_name in f.to_fields: |  | ||||||
|                     return True |  | ||||||
|     return False |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def resolve_relation(model, app_label=None, model_name=None): | def resolve_relation(model, app_label=None, model_name=None): | ||||||
|     """ |     """ | ||||||
|     Turn a model class or model reference string and return a model tuple. |     Turn a model class or model reference string and return a model tuple. | ||||||
| @@ -38,13 +29,73 @@ def resolve_relation(model, app_label=None, model_name=None): | |||||||
|     return model._meta.app_label, model._meta.model_name |     return model._meta.app_label, model._meta.model_name | ||||||
|  |  | ||||||
|  |  | ||||||
| def field_references_model(model_tuple, field, reference_model_tuple): | FieldReference = namedtuple('FieldReference', 'to through') | ||||||
|     """Return whether or not field references reference_model_tuple.""" |  | ||||||
|  |  | ||||||
|  | def field_references( | ||||||
|  |     model_tuple, | ||||||
|  |     field, | ||||||
|  |     reference_model_tuple, | ||||||
|  |     reference_field_name=None, | ||||||
|  |     reference_field=None, | ||||||
|  | ): | ||||||
|  |     """ | ||||||
|  |     Return either False or a FieldReference if `field` references provided | ||||||
|  |     context. | ||||||
|  |  | ||||||
|  |     False positives can be returned if `reference_field_name` is provided | ||||||
|  |     without `reference_field` because of the introspection limitation it | ||||||
|  |     incurs. This should not be an issue when this function is used to determine | ||||||
|  |     whether or not an optimization can take place. | ||||||
|  |     """ | ||||||
|     remote_field = field.remote_field |     remote_field = field.remote_field | ||||||
|     if remote_field: |     if not remote_field: | ||||||
|         if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple: |         return False | ||||||
|             return True |     references_to = None | ||||||
|         through = getattr(remote_field, 'through', None) |     references_through = None | ||||||
|         if through and resolve_relation(through, *model_tuple) == reference_model_tuple: |     if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple: | ||||||
|             return True |         to_fields = getattr(field, 'to_fields', None) | ||||||
|     return False |         if ( | ||||||
|  |             reference_field_name is None or | ||||||
|  |             # Unspecified to_field(s). | ||||||
|  |             to_fields is None or | ||||||
|  |             # Reference to primary key. | ||||||
|  |             (None in to_fields and (reference_field is None or reference_field.primary_key)) or | ||||||
|  |             # Reference to field. | ||||||
|  |             reference_field_name in to_fields | ||||||
|  |         ): | ||||||
|  |             references_to = (remote_field, to_fields) | ||||||
|  |     through = getattr(remote_field, 'through', None) | ||||||
|  |     if through and resolve_relation(through, *model_tuple) == reference_model_tuple: | ||||||
|  |         through_fields = remote_field.through_fields | ||||||
|  |         if ( | ||||||
|  |             reference_field_name is None or | ||||||
|  |             # Unspecified through_fields. | ||||||
|  |             through_fields is None or | ||||||
|  |             # Reference to field. | ||||||
|  |             reference_field_name in through_fields | ||||||
|  |         ): | ||||||
|  |             references_through = (remote_field, through_fields) | ||||||
|  |     if not (references_to or references_through): | ||||||
|  |         return False | ||||||
|  |     return FieldReference(references_to, references_through) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_references(state, model_tuple, field_tuple=()): | ||||||
|  |     """ | ||||||
|  |     Generator of (model_state, index, name, field, reference) referencing | ||||||
|  |     provided context. | ||||||
|  |  | ||||||
|  |     If field_tuple is provided only references to this particular field of | ||||||
|  |     model_tuple will be generated. | ||||||
|  |     """ | ||||||
|  |     for state_model_tuple, model_state in state.models.items(): | ||||||
|  |         for index, (name, field) in enumerate(model_state.fields): | ||||||
|  |             reference = field_references(state_model_tuple, field, model_tuple, *field_tuple) | ||||||
|  |             if reference: | ||||||
|  |                 yield model_state, index, name, field, reference | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def field_is_referenced(state, model_tuple, field_tuple): | ||||||
|  |     """Return whether `field_tuple` is referenced by any state models.""" | ||||||
|  |     return next(get_references(state, model_tuple, field_tuple), None) is not None | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user