diff --git a/django/db/models/base.py b/django/db/models/base.py index 6c59c0038a..ee7f4fe9c9 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -458,6 +458,16 @@ class Model(six.with_metaclass(ModelBase)): super(Model, self).__init__() signals.post_init.send(sender=self.__class__, instance=self) + @classmethod + def from_db(cls, db, field_names, values): + if cls._deferred: + new = cls(**dict(zip(field_names, values))) + else: + new = cls(*values) + new._state.adding = False + new._state.db = db + return new + def __repr__(self): try: u = six.text_type(self) diff --git a/django/db/models/query.py b/django/db/models/query.py index edb2cd88fb..57d9f6144a 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -241,8 +241,7 @@ class QuerySet(object): aggregate_select = list(self.query.aggregate_select) only_load = self.query.get_loaded_field_names() - if not fill_cache: - fields = self.model._meta.concrete_fields + fields = self.model._meta.concrete_fields load_fields = [] # If only/defer clauses have been specified, @@ -260,9 +259,6 @@ class QuerySet(object): # Therefore, we need to load all fields from this model load_fields.append(field.name) - index_start = len(extra_select) - aggregate_start = index_start + len(load_fields or self.model._meta.concrete_fields) - skip = None if load_fields and not fill_cache: # Some fields have been deferred, so we have to initialize @@ -275,30 +271,25 @@ class QuerySet(object): else: init_list.append(field.attname) model_cls = deferred_class_factory(self.model, skip) + else: + model_cls = self.model + init_list = [f.attname for f in fields] # Cache db and model outside the loop db = self.db - model = self.model compiler = self.query.get_compiler(using=db) + index_start = len(extra_select) + aggregate_start = index_start + len(init_list) + if fill_cache: - klass_info = get_klass_info(model, max_depth=max_depth, + klass_info = get_klass_info(model_cls, max_depth=max_depth, requested=requested, only_load=only_load) for row in compiler.results_iter(): if fill_cache: obj, _ = get_cached_row(row, index_start, db, klass_info, offset=len(aggregate_select)) else: - # Omit aggregates in object creation. - row_data = row[index_start:aggregate_start] - if skip: - obj = model_cls(**dict(zip(init_list, row_data))) - else: - obj = model(*row_data) - - # Store the source database of the object - obj._state.db = db - # This object came from the database; it's not being added. - obj._state.adding = False + obj = model_cls.from_db(db, init_list, row[index_start:aggregate_start]) if extra_select: for i, k in enumerate(extra_select): @@ -1417,6 +1408,21 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx +def reorder_for_init(model, field_names, values): + """ + Reorders given field names and values for those fields + to be in the same order as model.__init__() expects to find them. + """ + new_names, new_values = [], [] + for f in model._meta.concrete_fields: + if f.attname not in field_names: + continue + new_names.append(f.attname) + new_values.append(values[field_names.index(f.attname)]) + assert len(new_names) == len(field_names) + return new_names, new_values + + def get_cached_row(row, index_start, using, klass_info, offset=0, parent_data=()): """ @@ -1451,18 +1457,19 @@ def get_cached_row(row, index_start, using, klass_info, offset=0, fields[pk_idx] == '')): obj = None elif field_names: - fields = list(fields) + values = list(fields) + parent_values = [] + parent_field_names = [] for rel_field, value in parent_data: - field_names.append(rel_field.attname) - fields.append(value) - obj = klass(**dict(zip(field_names, fields))) + parent_field_names.append(rel_field.attname) + parent_values.append(value) + field_names, values = reorder_for_init( + klass, parent_field_names + field_names, + parent_values + values) + obj = klass.from_db(using, field_names, values) else: - obj = klass(*fields) - # If an object was retrieved, set the database state. - if obj: - obj._state.db = using - obj._state.adding = False - + field_names = [f.attname for f in klass._meta.concrete_fields] + obj = klass.from_db(using, field_names, fields) # Instantiate related fields index_end = index_start + field_count + offset # Iterate over each related object, populating any @@ -1494,7 +1501,7 @@ def get_cached_row(row, index_start, using, klass_info, offset=0, parent_data.append((rel_field, getattr(obj, rel_field.attname))) # Recursively retrieve the data for the related object cached_row = get_cached_row(row, index_end, using, klass_info, - parent_data=parent_data) + parent_data=parent_data) # If the recursive descent found an object, populate the # descriptor caches relevant to the object if cached_row: @@ -1534,15 +1541,18 @@ class RawQuerySet(object): self.params = params or () self.translations = translations or {} - def __iter__(self): - # Mapping of attrnames to row column positions. Used for constructing - # the model using kwargs, needed when not all model's fields are present - # in the query. - model_init_field_names = {} - # A list of tuples of (column name, column position). Used for - # annotation fields. - annotation_fields = [] + def resolve_model_init_order(self): + """ + Resolve the init field names and value positions + """ + model_init_names = [f.attname for f in self.model._meta.fields + if f.attname in self.columns] + annotation_fields = [(column, pos) for pos, column in enumerate(self.columns) + if column not in self.model_fields] + model_init_order = [self.columns.index(fname) for fname in model_init_names] + return model_init_names, model_init_order, annotation_fields + def __iter__(self): # Cache some things for performance reasons outside the loop. db = self.db compiler = connections[db].ops.compiler('SQLCompiler')( @@ -1553,18 +1563,12 @@ class RawQuerySet(object): query = iter(self.query) try: - # Find out which columns are model's fields, and which ones should be - # annotated to the model. - for pos, column in enumerate(self.columns): - if column in self.model_fields: - model_init_field_names[self.model_fields[column].attname] = pos - else: - annotation_fields.append((column, pos)) + model_init_names, model_init_pos, annotation_fields = self.resolve_model_init_order() # Find out which model's fields are not present in the query. skip = set() for field in self.model._meta.fields: - if field.attname not in model_init_field_names: + if field.attname not in model_init_names: skip.add(field.attname) if skip: if self.model._meta.pk.attname in skip: @@ -1572,34 +1576,17 @@ class RawQuerySet(object): model_cls = deferred_class_factory(self.model, skip) else: model_cls = self.model - # All model's fields are present in the query. So, it is possible - # to use *args based model instantiation. For each field of the model, - # record the query column position matching that field. - model_init_field_pos = [] - for field in self.model._meta.fields: - model_init_field_pos.append(model_init_field_names[field.attname]) if need_resolv_columns: fields = [self.model_fields.get(c, None) for c in self.columns] - # Begin looping through the query values. for values in query: if need_resolv_columns: values = compiler.resolve_columns(values, fields) # Associate fields to values - if skip: - model_init_kwargs = {} - for attname, pos in six.iteritems(model_init_field_names): - model_init_kwargs[attname] = values[pos] - instance = model_cls(**model_init_kwargs) - else: - model_init_args = [values[pos] for pos in model_init_field_pos] - instance = model_cls(*model_init_args) + model_init_values = [values[pos] for pos in model_init_pos] + instance = model_cls.from_db(db, model_init_names, model_init_values) if annotation_fields: for column, pos in annotation_fields: setattr(instance, column, values[pos]) - - instance._state.db = db - instance._state.adding = False - yield instance finally: # Done iterating the Query. If it has its own cursor, close it. diff --git a/docs/ref/models/instances.txt b/docs/ref/models/instances.txt index 519da0cf3d..984f64d47e 100644 --- a/docs/ref/models/instances.txt +++ b/docs/ref/models/instances.txt @@ -62,6 +62,60 @@ that, you need to :meth:`~Model.save()`. book = Book.objects.create_book("Pride and Prejudice") +Customizing model loading +------------------------- + +.. classmethod:: Model.from_db(db, field_names, values) + +.. versionadded:: 1.8 + +The ``from_db()`` method can be used to customize model instance creation +when loading from the database. + +The ``db`` argument contains the database alias for the database the model +is loaded from, ``field_names`` contains the names of all loaded fields, and +``values`` contains the loaded values for each field in ``field_names``. The +``field_names`` are in the same order as the ``values``, so it is possible to +use ``cls(**(zip(field_names, values)))`` to instantiate the object. If all +of the model's fields are present, then ``values`` are guaranteed to be in +the order ``__init__()`` expects them. That is, the instance can be created +by ``cls(*values)``. It is possible to check if all fields are present by +consulting ``cls._deferred`` - if ``False``, then all fields have been loaded +from the database. + +In addition to creating the new model, the ``from_db()`` method must set the +``adding`` and ``db`` flags in the new instance's ``_state`` attribute. + +Below is an example showing how torecord the initial values of fields that +are loaded from the database:: + + @classmethod + def from_db(cls, db, field_names, values): + # default implementation of from_db() (could be replaced + # with super()) + if cls._deferred: + instance = cls(**zip(field_names, values)) + else: + instance = cls(*values) + instance._state.adding = False + instance._state.db = db + # customization to store the original field values on the instance + instance._loaded_values = zip(field_names, values) + return instance + + def save(self, *args, **kwargs): + # Check how the current values differ from ._loaded_values. For example, + # prevent changing the creator_id of the model. (This example doesn't + # support cases where 'creator_id' is deferred). + if not self._state.adding and ( + self.creator_id != self._loaded_values['creator_id']): + raise ValueError("Updating the value of creator isn't allowed") + super(...).save(*args, **kwargs) + +The example above shows a full ``from_db()`` implementation to clarify how that +is done. In this case it would of course be possible to just use ``super()`` call +in the ``from_db()`` method. + .. _validating-objects: Validating objects diff --git a/docs/releases/1.8.txt b/docs/releases/1.8.txt index 793c781e2f..1eafb3337e 100644 --- a/docs/releases/1.8.txt +++ b/docs/releases/1.8.txt @@ -193,6 +193,10 @@ Models when these objects are unpickled in a different version than the one in which they were pickled. +* Added :meth:`Model.from_db() ` which + Django uses whenever objects are loaded using the ORM. The method allows + customizing model loading behavior. + Signals ^^^^^^^