mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	Refs #28305 -- Consolidated field referencing detection in migrations.
This moves all the field referencing resolution methods to shared functions instead of duplicating efforts amongst state_forwards and references methods.
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							734fde7714
						
					
				
				
					commit
					f5ede1cb6d
				
			| @@ -3,9 +3,7 @@ from django.db.models import NOT_PROVIDED | ||||
| from django.utils.functional import cached_property | ||||
|  | ||||
| from .base import Operation | ||||
| from .utils import ( | ||||
|     field_references_model, is_referenced_by_foreign_key, resolve_relation, | ||||
| ) | ||||
| from .utils import field_is_referenced, field_references, get_references | ||||
|  | ||||
|  | ||||
| class FieldOperation(Operation): | ||||
| @@ -33,9 +31,9 @@ class FieldOperation(Operation): | ||||
|         if name_lower == self.model_name_lower: | ||||
|             return True | ||||
|         if self.field: | ||||
|             return field_references_model( | ||||
|             return bool(field_references( | ||||
|                 (app_label, self.model_name_lower), self.field, (app_label, name_lower) | ||||
|             ) | ||||
|             )) | ||||
|         return False | ||||
|  | ||||
|     def references_field(self, model_name, name, app_label): | ||||
| @@ -47,20 +45,14 @@ class FieldOperation(Operation): | ||||
|             elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields: | ||||
|                 return True | ||||
|         # Check if this operation remotely references the field. | ||||
|         if self.field: | ||||
|             model_tuple = (app_label, model_name_lower) | ||||
|             remote_field = self.field.remote_field | ||||
|             if remote_field: | ||||
|                 if (resolve_relation(remote_field.model, app_label, self.model_name_lower) == model_tuple and | ||||
|                         (not hasattr(self.field, 'to_fields') or | ||||
|                             name in self.field.to_fields or None in self.field.to_fields)): | ||||
|                     return True | ||||
|                 through = getattr(remote_field, 'through', None) | ||||
|                 if (through and resolve_relation(through, app_label, self.model_name_lower) == model_tuple and | ||||
|                         (getattr(remote_field, 'through_fields', None) is None or | ||||
|                             name in remote_field.through_fields)): | ||||
|                     return True | ||||
|         return False | ||||
|         if self.field is None: | ||||
|             return False | ||||
|         return bool(field_references( | ||||
|             (app_label, self.model_name_lower), | ||||
|             self.field, | ||||
|             (app_label, model_name_lower), | ||||
|             name, | ||||
|         )) | ||||
|  | ||||
|     def reduce(self, operation, app_label): | ||||
|         return ( | ||||
| @@ -236,7 +228,9 @@ class AlterField(FieldOperation): | ||||
|         # not referenced by a foreign key. | ||||
|         delay = ( | ||||
|             not field.is_relation and | ||||
|             not is_referenced_by_foreign_key(state, self.model_name_lower, self.field, self.name) | ||||
|             not field_is_referenced( | ||||
|                 state, (app_label, self.model_name_lower), (self.name, field), | ||||
|             ) | ||||
|         ) | ||||
|         state.reload_model(app_label, self.model_name_lower, delay=delay) | ||||
|  | ||||
| @@ -305,17 +299,11 @@ class RenameField(FieldOperation): | ||||
|         model_state = state.models[app_label, self.model_name_lower] | ||||
|         # Rename the field | ||||
|         fields = model_state.fields | ||||
|         found = False | ||||
|         found = None | ||||
|         for index, (name, field) in enumerate(fields): | ||||
|             if not found and name == self.old_name: | ||||
|                 fields[index] = (self.new_name, field) | ||||
|                 found = True | ||||
|                 # Delay rendering of relationships if it's not a relational | ||||
|                 # field and not referenced by a foreign key. | ||||
|                 delay = ( | ||||
|                     not field.is_relation and | ||||
|                     not is_referenced_by_foreign_key(state, self.model_name_lower, field, self.name) | ||||
|                 ) | ||||
|                 found = field | ||||
|             # Fix from_fields to refer to the new field. | ||||
|             from_fields = getattr(field, 'from_fields', None) | ||||
|             if from_fields: | ||||
| @@ -323,7 +311,7 @@ class RenameField(FieldOperation): | ||||
|                     self.new_name if from_field_name == self.old_name else from_field_name | ||||
|                     for from_field_name in from_fields | ||||
|                 ]) | ||||
|         if not found: | ||||
|         if found is None: | ||||
|             raise FieldDoesNotExist( | ||||
|                 "%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name) | ||||
|             ) | ||||
| @@ -336,23 +324,21 @@ class RenameField(FieldOperation): | ||||
|                     for together in options[option] | ||||
|                 ] | ||||
|         # Fix to_fields to refer to the new field. | ||||
|         model_tuple = app_label, self.model_name_lower | ||||
|         for (model_app_label, model_name), model_state in state.models.items(): | ||||
|             for index, (name, field) in enumerate(model_state.fields): | ||||
|                 remote_field = field.remote_field | ||||
|                 if remote_field: | ||||
|                     remote_model_tuple = resolve_relation( | ||||
|                         remote_field.model, model_app_label, model_name | ||||
|                     ) | ||||
|                     if remote_model_tuple == model_tuple: | ||||
|                         if getattr(remote_field, 'field_name', None) == self.old_name: | ||||
|                             remote_field.field_name = self.new_name | ||||
|                         to_fields = getattr(field, 'to_fields', None) | ||||
|                         if to_fields: | ||||
|                             field.to_fields = tuple([ | ||||
|                                 self.new_name if to_field_name == self.old_name else to_field_name | ||||
|                                 for to_field_name in to_fields | ||||
|                             ]) | ||||
|         delay = True | ||||
|         references = get_references( | ||||
|             state, (app_label, self.model_name_lower), (self.old_name, found), | ||||
|         ) | ||||
|         for *_, field, reference in references: | ||||
|             delay = False | ||||
|             if reference.to: | ||||
|                 remote_field, to_fields = reference.to | ||||
|                 if getattr(remote_field, 'field_name', None) == self.old_name: | ||||
|                     remote_field.field_name = self.new_name | ||||
|                 if to_fields: | ||||
|                     field.to_fields = tuple([ | ||||
|                         self.new_name if to_field_name == self.old_name else to_field_name | ||||
|                         for to_field_name in to_fields | ||||
|                     ]) | ||||
|         state.reload_model(app_label, self.model_name_lower, delay=delay) | ||||
|  | ||||
|     def database_forwards(self, app_label, schema_editor, from_state, to_state): | ||||
|   | ||||
| @@ -7,7 +7,7 @@ from django.utils.functional import cached_property | ||||
| from .fields import ( | ||||
|     AddField, AlterField, FieldOperation, RemoveField, RenameField, | ||||
| ) | ||||
| from .utils import field_references_model, resolve_relation | ||||
| from .utils import field_references, get_references, resolve_relation | ||||
|  | ||||
|  | ||||
| def _check_for_duplicates(arg_name, objs): | ||||
| @@ -113,7 +113,7 @@ class CreateModel(ModelOperation): | ||||
|  | ||||
|         # Check we have no FKs/M2Ms with it | ||||
|         for _name, field in self.fields: | ||||
|             if field_references_model((app_label, self.name_lower), field, reference_model_tuple): | ||||
|             if field_references((app_label, self.name_lower), field, reference_model_tuple): | ||||
|                 return True | ||||
|         return False | ||||
|  | ||||
| @@ -309,33 +309,19 @@ class RenameModel(ModelOperation): | ||||
|         # Repoint all fields pointing to the old model to the new one. | ||||
|         old_model_tuple = (app_label, self.old_name_lower) | ||||
|         new_remote_model = '%s.%s' % (app_label, self.new_name) | ||||
|         to_reload = [] | ||||
|         for (model_app_label, model_name), model_state in state.models.items(): | ||||
|             model_changed = False | ||||
|             for index, (name, field) in enumerate(model_state.fields): | ||||
|                 changed_field = None | ||||
|                 remote_field = field.remote_field | ||||
|                 if remote_field: | ||||
|                     remote_model_tuple = resolve_relation( | ||||
|                         remote_field.model, model_app_label, model_name | ||||
|                     ) | ||||
|                     if remote_model_tuple == old_model_tuple: | ||||
|                         changed_field = field.clone() | ||||
|                         changed_field.remote_field.model = new_remote_model | ||||
|                     through_model = getattr(remote_field, 'through', None) | ||||
|                     if through_model: | ||||
|                         through_model_tuple = resolve_relation( | ||||
|                             through_model, model_app_label, model_name | ||||
|                         ) | ||||
|                         if through_model_tuple == old_model_tuple: | ||||
|                             if changed_field is None: | ||||
|                                 changed_field = field.clone() | ||||
|                             changed_field.remote_field.through = new_remote_model | ||||
|                 if changed_field: | ||||
|                     model_state.fields[index] = name, changed_field | ||||
|                     model_changed = True | ||||
|             if model_changed: | ||||
|                 to_reload.append((model_app_label, model_name)) | ||||
|         to_reload = set() | ||||
|         for model_state, index, name, field, reference in get_references(state, old_model_tuple): | ||||
|             changed_field = None | ||||
|             if reference.to: | ||||
|                 changed_field = field.clone() | ||||
|                 changed_field.remote_field.model = new_remote_model | ||||
|             if reference.through: | ||||
|                 if changed_field is None: | ||||
|                     changed_field = field.clone() | ||||
|                 changed_field.remote_field.through = new_remote_model | ||||
|             if changed_field: | ||||
|                 model_state.fields[index] = name, changed_field | ||||
|                 to_reload.add((model_state.app_label, model_state.name_lower)) | ||||
|         # Reload models related to old model before removing the old model. | ||||
|         state.reload_models(to_reload, delay=True) | ||||
|         # Remove the old model. | ||||
|   | ||||
| @@ -1,17 +1,8 @@ | ||||
| from collections import namedtuple | ||||
|  | ||||
| from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT | ||||
|  | ||||
|  | ||||
| def is_referenced_by_foreign_key(state, model_name_lower, field, field_name): | ||||
|     for state_app_label, state_model in state.models: | ||||
|         for _, f in state.models[state_app_label, state_model].fields: | ||||
|             if (f.related_model and | ||||
|                     '%s.%s' % (state_app_label, model_name_lower) == f.related_model.lower() and | ||||
|                     hasattr(f, 'to_fields')): | ||||
|                 if (f.to_fields[0] is None and field.primary_key) or field_name in f.to_fields: | ||||
|                     return True | ||||
|     return False | ||||
|  | ||||
|  | ||||
| def resolve_relation(model, app_label=None, model_name=None): | ||||
|     """ | ||||
|     Turn a model class or model reference string and return a model tuple. | ||||
| @@ -38,13 +29,73 @@ def resolve_relation(model, app_label=None, model_name=None): | ||||
|     return model._meta.app_label, model._meta.model_name | ||||
|  | ||||
|  | ||||
| def field_references_model(model_tuple, field, reference_model_tuple): | ||||
|     """Return whether or not field references reference_model_tuple.""" | ||||
| FieldReference = namedtuple('FieldReference', 'to through') | ||||
|  | ||||
|  | ||||
| def field_references( | ||||
|     model_tuple, | ||||
|     field, | ||||
|     reference_model_tuple, | ||||
|     reference_field_name=None, | ||||
|     reference_field=None, | ||||
| ): | ||||
|     """ | ||||
|     Return either False or a FieldReference if `field` references provided | ||||
|     context. | ||||
|  | ||||
|     False positives can be returned if `reference_field_name` is provided | ||||
|     without `reference_field` because of the introspection limitation it | ||||
|     incurs. This should not be an issue when this function is used to determine | ||||
|     whether or not an optimization can take place. | ||||
|     """ | ||||
|     remote_field = field.remote_field | ||||
|     if remote_field: | ||||
|         if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple: | ||||
|             return True | ||||
|         through = getattr(remote_field, 'through', None) | ||||
|         if through and resolve_relation(through, *model_tuple) == reference_model_tuple: | ||||
|             return True | ||||
|     return False | ||||
|     if not remote_field: | ||||
|         return False | ||||
|     references_to = None | ||||
|     references_through = None | ||||
|     if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple: | ||||
|         to_fields = getattr(field, 'to_fields', None) | ||||
|         if ( | ||||
|             reference_field_name is None or | ||||
|             # Unspecified to_field(s). | ||||
|             to_fields is None or | ||||
|             # Reference to primary key. | ||||
|             (None in to_fields and (reference_field is None or reference_field.primary_key)) or | ||||
|             # Reference to field. | ||||
|             reference_field_name in to_fields | ||||
|         ): | ||||
|             references_to = (remote_field, to_fields) | ||||
|     through = getattr(remote_field, 'through', None) | ||||
|     if through and resolve_relation(through, *model_tuple) == reference_model_tuple: | ||||
|         through_fields = remote_field.through_fields | ||||
|         if ( | ||||
|             reference_field_name is None or | ||||
|             # Unspecified through_fields. | ||||
|             through_fields is None or | ||||
|             # Reference to field. | ||||
|             reference_field_name in through_fields | ||||
|         ): | ||||
|             references_through = (remote_field, through_fields) | ||||
|     if not (references_to or references_through): | ||||
|         return False | ||||
|     return FieldReference(references_to, references_through) | ||||
|  | ||||
|  | ||||
| def get_references(state, model_tuple, field_tuple=()): | ||||
|     """ | ||||
|     Generator of (model_state, index, name, field, reference) referencing | ||||
|     provided context. | ||||
|  | ||||
|     If field_tuple is provided only references to this particular field of | ||||
|     model_tuple will be generated. | ||||
|     """ | ||||
|     for state_model_tuple, model_state in state.models.items(): | ||||
|         for index, (name, field) in enumerate(model_state.fields): | ||||
|             reference = field_references(state_model_tuple, field, model_tuple, *field_tuple) | ||||
|             if reference: | ||||
|                 yield model_state, index, name, field, reference | ||||
|  | ||||
|  | ||||
| def field_is_referenced(state, model_tuple, field_tuple): | ||||
|     """Return whether `field_tuple` is referenced by any state models.""" | ||||
|     return next(get_references(state, model_tuple, field_tuple), None) is not None | ||||
|   | ||||
		Reference in New Issue
	
	Block a user