From da6570bf082620205737c82cd7deb7185daaf538 Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Sun, 17 Feb 2008 18:47:57 +0000 Subject: [PATCH] queryset-refactor: Model inheritance support. This adds both types of model inheritance: abstract base classes (ABCs) and multi-table inheritance. See the documentation and tests / examples for details. Still a few known bugs here, so don't file tickets (I know about them). Not quite ready for prime-time usage, but it mostly works as expected. git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7126 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/core/management/sql.py | 16 +- django/core/management/validation.py | 23 +- django/db/models/base.py | 127 +++++---- django/db/models/fields/__init__.py | 26 +- django/db/models/fields/related.py | 7 +- django/db/models/manager.py | 4 +- django/db/models/options.py | 260 +++++++++++++++---- django/db/models/sql/query.py | 26 +- docs/model-api.txt | 99 +++++++ tests/modeltests/model_inheritance/models.py | 157 ++++++++++- 10 files changed, 607 insertions(+), 138 deletions(-) diff --git a/django/core/management/sql.py b/django/core/management/sql.py index 15bffce26b..9e606a745c 100644 --- a/django/core/management/sql.py +++ b/django/core/management/sql.py @@ -26,7 +26,7 @@ def django_table_list(only_existing=False): for app in models.get_apps(): for model in models.get_models(app): tables.append(model._meta.db_table) - tables.extend([f.m2m_db_table() for f in model._meta.many_to_many]) + tables.extend([f.m2m_db_table() for f in model._meta.local_many_to_many]) if only_existing: existing = table_list() tables = [t for t in tables if t in existing] @@ -54,12 +54,12 @@ def sequence_list(): for app in apps: for model in models.get_models(app): - for f in model._meta.fields: + for f in model._meta.local_fields: if isinstance(f, models.AutoField): sequence_list.append({'table': model._meta.db_table, 'column': f.column}) break # Only one AutoField is allowed per model, so don't bother continuing. - for f in model._meta.many_to_many: + for f in model._meta.local_many_to_many: sequence_list.append({'table': f.m2m_db_table(), 'column': None}) return sequence_list @@ -147,7 +147,7 @@ def sql_delete(app, style): if cursor and table_name_converter(model._meta.db_table) in table_names: # The table exists, so it needs to be dropped opts = model._meta - for f in opts.fields: + for f in opts.local_fields: if f.rel and f.rel.to not in to_delete: references_to_delete.setdefault(f.rel.to, []).append( (model, f) ) @@ -179,7 +179,7 @@ def sql_delete(app, style): # Output DROP TABLE statements for many-to-many tables. for model in app_models: opts = model._meta - for f in opts.many_to_many: + for f in opts.local_many_to_many: if isinstance(f.rel, generic.GenericRel): continue if cursor and table_name_converter(f.m2m_db_table()) in table_names: @@ -256,7 +256,7 @@ def sql_model_create(model, style, known_models=set()): pending_references = {} qn = connection.ops.quote_name inline_references = connection.features.inline_fk_references - for f in opts.fields: + for f in opts.local_fields: col_type = f.db_type() tablespace = f.db_tablespace or opts.db_tablespace if col_type is None: @@ -351,7 +351,7 @@ def many_to_many_sql_for_model(model, style): final_output = [] qn = connection.ops.quote_name inline_references = connection.features.inline_fk_references - for f in opts.many_to_many: + for f in opts.local_many_to_many: if not isinstance(f.rel, generic.GenericRel): tablespace = f.db_tablespace or opts.db_tablespace if tablespace and connection.features.supports_tablespaces and connection.features.autoindexes_primary_keys: @@ -458,7 +458,7 @@ def sql_indexes_for_model(model, style): output = [] qn = connection.ops.quote_name - for f in model._meta.fields: + for f in model._meta.local_fields: if f.db_index and not ((f.primary_key or f.unique) and connection.features.autoindexes_primary_keys): unique = f.unique and 'UNIQUE ' or '' tablespace = f.db_tablespace or model._meta.db_tablespace diff --git a/django/core/management/validation.py b/django/core/management/validation.py index fc3a7162a1..f5e4f7fc70 100644 --- a/django/core/management/validation.py +++ b/django/core/management/validation.py @@ -32,7 +32,7 @@ def get_validation_errors(outfile, app=None): opts = cls._meta # Do field-specific validation. - for f in opts.fields: + for f in opts.local_fields: if f.name == 'id' and not f.primary_key and opts.pk.name == 'id': e.add(opts, '"%s": You can\'t use "id" as a field name, because each model automatically gets an "id" field if none of the fields have primary_key=True. You need to either remove/rename your "id" field or add primary_key=True to a field.' % f.name) if f.name.endswith('_'): @@ -69,8 +69,8 @@ def get_validation_errors(outfile, app=None): if db_version < (5, 0, 3) and isinstance(f, (models.CharField, models.CommaSeparatedIntegerField, models.SlugField)) and f.max_length > 255: e.add(opts, '"%s": %s cannot have a "max_length" greater than 255 when you are using a version of MySQL prior to 5.0.3 (you are using %s).' % (f.name, f.__class__.__name__, '.'.join([str(n) for n in db_version[:3]]))) - # Check to see if the related field will clash with any - # existing fields, m2m fields, m2m related objects or related objects + # Check to see if the related field will clash with any existing + # fields, m2m fields, m2m related objects or related objects if f.rel: if f.rel.to not in models.get_models(): e.add(opts, "'%s' has relation with model %s, which has not been installed" % (f.name, f.rel.to)) @@ -87,7 +87,7 @@ def get_validation_errors(outfile, app=None): e.add(opts, "Accessor for field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name)) if r.name == rel_query_name: e.add(opts, "Reverse query name for field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name)) - for r in rel_opts.many_to_many: + for r in rel_opts.local_many_to_many: if r.name == rel_name: e.add(opts, "Accessor for field '%s' clashes with m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name)) if r.name == rel_query_name: @@ -104,9 +104,10 @@ def get_validation_errors(outfile, app=None): if r.get_accessor_name() == rel_query_name: e.add(opts, "Reverse query name for field '%s' clashes with related field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.get_accessor_name(), f.name)) - for i, f in enumerate(opts.many_to_many): + for i, f in enumerate(opts.local_many_to_many): # Check to see if the related m2m field will clash with any - # existing fields, m2m fields, m2m related objects or related objects + # existing fields, m2m fields, m2m related objects or related + # objects if f.rel.to not in models.get_models(): e.add(opts, "'%s' has m2m relation with model %s, which has not been installed" % (f.name, f.rel.to)) # it is a string and we could not find the model it refers to @@ -117,17 +118,17 @@ def get_validation_errors(outfile, app=None): rel_opts = f.rel.to._meta rel_name = RelatedObject(f.rel.to, cls, f).get_accessor_name() rel_query_name = f.related_query_name() - # If rel_name is none, there is no reverse accessor. - # (This only occurs for symmetrical m2m relations to self). - # If this is the case, there are no clashes to check for this field, as - # there are no reverse descriptors for this field. + # If rel_name is none, there is no reverse accessor (this only + # occurs for symmetrical m2m relations to self). If this is the + # case, there are no clashes to check for this field, as there are + # no reverse descriptors for this field. if rel_name is not None: for r in rel_opts.fields: if r.name == rel_name: e.add(opts, "Accessor for m2m field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name)) if r.name == rel_query_name: e.add(opts, "Reverse query name for m2m field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name)) - for r in rel_opts.many_to_many: + for r in rel_opts.local_many_to_many: if r.name == rel_name: e.add(opts, "Accessor for m2m field '%s' clashes with m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name)) if r.name == rel_query_name: diff --git a/django/db/models/base.py b/django/db/models/base.py index 9236207f94..45bb82c3f4 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -1,9 +1,14 @@ +import types +import sys +import os +from itertools import izip + import django.db.models.manipulators import django.db.models.manager from django.core import validators from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned from django.db.models.fields import AutoField, ImageField, FieldDoesNotExist -from django.db.models.fields.related import OneToOneRel, ManyToOneRel +from django.db.models.fields.related import OneToOneRel, ManyToOneRel, OneToOneField from django.db.models.query import delete_objects, Q from django.db.models.options import Options, AdminOptions from django.db import connection, transaction @@ -14,10 +19,6 @@ from django.utils.datastructures import SortedDict from django.utils.functional import curry from django.utils.encoding import smart_str, force_unicode, smart_unicode from django.conf import settings -from itertools import izip -import types -import sys -import os class ModelBase(type): "Metaclass for all models" @@ -25,29 +26,46 @@ class ModelBase(type): # If this isn't a subclass of Model, don't do anything special. try: parents = [b for b in bases if issubclass(b, Model)] - if not parents: - return super(ModelBase, cls).__new__(cls, name, bases, attrs) except NameError: # 'Model' isn't defined yet, meaning we're looking at Django's own # Model class, defined below. + parents = [] + if not parents: return super(ModelBase, cls).__new__(cls, name, bases, attrs) # Create the class. new_class = type.__new__(cls, name, bases, {'__module__': attrs.pop('__module__')}) - new_class.add_to_class('_meta', Options(attrs.pop('Meta', None))) + meta = attrs.pop('Meta', None) + # FIXME: Promote Meta to a newstyle class before attaching it to the + # model. + ## if meta: + ## new_class.Meta = meta + new_class.add_to_class('_meta', Options(meta)) + # FIXME: Need to be smarter here. Exception is an old-style class in + # Python <= 2.4, new-style in Python 2.5+. This construction is only + # really correct for old-style classes. new_class.add_to_class('DoesNotExist', types.ClassType('DoesNotExist', (ObjectDoesNotExist,), {})) - new_class.add_to_class('MultipleObjectsReturned', - types.ClassType('MultipleObjectsReturned', (MultipleObjectsReturned, ), {})) + new_class.add_to_class('MultipleObjectsReturned', types.ClassType('MultipleObjectsReturned', (MultipleObjectsReturned, ), {})) - # Build complete list of parents + # Do the appropriate setup for any model parents. + abstract_parents = [] for base in parents: - # Things without _meta aren't functional models, so they're - # uninteresting parents. - if hasattr(base, '_meta'): - new_class._meta.parents.append(base) - new_class._meta.parents.extend(base._meta.parents) - + if not hasattr(base, '_meta'): + # Things without _meta aren't functional models, so they're + # uninteresting parents. + continue + if not base._meta.abstract: + attr_name = '%s_ptr' % base._meta.module_name + field = OneToOneField(base, name=attr_name, auto_created=True) + new_class.add_to_class(attr_name, field) + new_class._meta.parents[base] = field + else: + abstract_parents.append(base) + if getattr(new_class, '_default_manager', None) is not None: + # We have a parent who set the default manager. We need to override + # this. + new_class._default_manager = None if getattr(new_class._meta, 'app_label', None) is None: # Figure out the app_label by looking one level up. # For 'django.contrib.sites.models', this would be 'sites'. @@ -63,21 +81,26 @@ class ModelBase(type): for obj_name, obj in attrs.items(): new_class.add_to_class(obj_name, obj) - # Add Fields inherited from parents - for parent in new_class._meta.parents: - for field in parent._meta.fields: - # Only add parent fields if they aren't defined for this class. - try: - new_class._meta.get_field(field.name) - except FieldDoesNotExist: - field.contribute_to_class(new_class, field.name) + for parent in abstract_parents: + names = [f.name for f in new_class._meta.local_fields + new_class._meta.many_to_many] + for field in parent._meta.local_fields: + if field.name in names: + raise TypeError('Local field %r in class %r clashes with field of similar name from abstract base class %r' + % (field.name, name, parent.__name__)) + new_class.add_to_class(field.name, field) + + if new_class._meta.abstract: + # Abstract base models can't be instantiated and don't appear in + # the list of models for an app. We do the final setup for them a + # little differently from normal models. + return new_class new_class._prepare() - register_models(new_class._meta.app_label, new_class) + # Because of the way imports happen (recursively), we may or may not be - # the first class for this model to register with the framework. There - # should only be one class for each model, so we must always return the + # the first time this model tries to register with the framework. There + # should only be one class for each model, so we always return the # registered version. return get_model(new_class._meta.app_label, name, False) @@ -113,8 +136,10 @@ class ModelBase(type): class Model(object): __metaclass__ = ModelBase - def _get_pk_val(self): - return getattr(self, self._meta.pk.attname) + def _get_pk_val(self, meta=None): + if not meta: + meta = self._meta + return getattr(self, meta.pk.attname) def _set_pk_val(self, value): return setattr(self, self._meta.pk.attname, value) @@ -207,19 +232,30 @@ class Model(object): raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0] dispatcher.send(signal=signals.post_init, sender=self.__class__, instance=self) - def save(self, raw=False): - dispatcher.send(signal=signals.pre_save, sender=self.__class__, - instance=self, raw=raw) + def save(self, raw=False, cls=None): + if not cls: + dispatcher.send(signal=signals.pre_save, sender=self.__class__, + instance=self, raw=raw) + cls = self.__class__ + meta = self._meta + signal = True + else: + meta = cls._meta + signal = False - non_pks = [f for f in self._meta.fields if not f.primary_key] + for parent, field in meta.parents.items(): + self.save(raw, parent) + setattr(self, field.attname, self._get_pk_val(parent._meta)) + + non_pks = [f for f in self._meta.local_fields if not f.primary_key] # First, try an UPDATE. If that doesn't update anything, do an INSERT. - pk_val = self._get_pk_val() + pk_val = self._get_pk_val(meta) # Note: the comparison with '' is required for compatibility with # oldforms-style model creation. pk_set = pk_val is not None and smart_unicode(pk_val) != u'' record_exists = True - manager = self.__class__._default_manager + manager = cls._default_manager if pk_set: # Determine whether a record with the primary key already exists. if manager.filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by(): @@ -231,16 +267,16 @@ class Model(object): record_exists = False if not pk_set or not record_exists: if not pk_set: - values = [(f.name, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in self._meta.fields if not isinstance(f, AutoField)] + values = [(f.name, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in meta.local_fields if not isinstance(f, AutoField)] else: - values = [(f.name, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in self._meta.fields] + values = [(f.name, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in meta.local_fields] - if self._meta.order_with_respect_to: - field = self._meta.order_with_respect_to + if meta.order_with_respect_to: + field = meta.order_with_respect_to values.append(('_order', manager.filter(**{field.name: getattr(self, field.attname)}).count())) record_exists = False - update_pk = bool(self._meta.has_auto_field and not pk_set) + update_pk = bool(meta.has_auto_field and not pk_set) if values: # Create a new record. result = manager._insert(_return_id=update_pk, **dict(values)) @@ -250,12 +286,13 @@ class Model(object): _raw_values=True, pk=connection.ops.pk_default_value()) if update_pk: - setattr(self, self._meta.pk.attname, result) + setattr(self, meta.pk.attname, result) transaction.commit_unless_managed() - # Run any post-save hooks. - dispatcher.send(signal=signals.post_save, sender=self.__class__, - instance=self, created=(not record_exists), raw=raw) + if signal: + # Run any post-save hooks. + dispatcher.send(signal=signals.post_save, sender=self.__class__, + instance=self, created=(not record_exists), raw=raw) save.alters_data = True diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index a3d7d05e16..c2edc71573 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -75,15 +75,19 @@ class Field(object): # database level. empty_strings_allowed = True - # Tracks each time a Field instance is created. Used to retain order. + # These track each time a Field instance is created. Used to retain order. + # The auto_creation_counter is used for fields that Django implicitly + # creates, creation_counter is used for all user-specified fields. creation_counter = 0 + auto_creation_counter = -1 def __init__(self, verbose_name=None, name=None, primary_key=False, - max_length=None, unique=False, blank=False, null=False, db_index=False, - core=False, rel=None, default=NOT_PROVIDED, editable=True, serialize=True, - prepopulate_from=None, unique_for_date=None, unique_for_month=None, - unique_for_year=None, validator_list=None, choices=None, radio_admin=None, - help_text='', db_column=None, db_tablespace=None): + max_length=None, unique=False, blank=False, null=False, + db_index=False, core=False, rel=None, default=NOT_PROVIDED, + editable=True, serialize=True, prepopulate_from=None, + unique_for_date=None, unique_for_month=None, unique_for_year=None, + validator_list=None, choices=None, radio_admin=None, help_text='', + db_column=None, db_tablespace=None, auto_created=False): self.name = name self.verbose_name = verbose_name self.primary_key = primary_key @@ -109,9 +113,13 @@ class Field(object): # Set db_index to True if the field has a relationship and doesn't explicitly set db_index. self.db_index = db_index - # Increase the creation counter, and save our local copy. - self.creation_counter = Field.creation_counter - Field.creation_counter += 1 + # Adjust the appropriate creation counter, and save our local copy. + if auto_created: + self.creation_counter = Field.auto_creation_counter + Field.auto_creation_counter -= 1 + else: + self.creation_counter = Field.creation_counter + Field.creation_counter += 1 def __cmp__(self, other): # This is needed because bisect does not take a comparison function. diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index a7c4fca2fd..39fe6d794c 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -494,8 +494,9 @@ class OneToOneRel(ManyToOneRel): # ignored here. We accept them as parameters only to match the calling # signature of ManyToOneRel.__init__(). super(OneToOneRel, self).__init__(to, field_name, num_in_admin, - edit_inline, related_name, limit_choices_to, lookup_overrides, - raw_id_admin) + edit_inline=edit_inline, related_name=related_name, + limit_choices_to=limit_choices_to, + lookup_overrides=lookup_overrides, raw_id_admin=raw_id_admin) self.multiple = False class ManyToManyRel(object): @@ -754,7 +755,7 @@ class ManyToManyField(RelatedField, Field): def save_form_data(self, instance, data): setattr(instance, self.attname, data) - + def formfield(self, **kwargs): defaults = {'form_class': forms.ModelMultipleChoiceField, 'queryset': self.rel.to._default_manager.all()} defaults.update(kwargs) diff --git a/django/db/models/manager.py b/django/db/models/manager.py index ed420e1333..34811cd324 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -5,7 +5,7 @@ from django.db.models.fields import FieldDoesNotExist def ensure_default_manager(sender): cls = sender - if not hasattr(cls, '_default_manager'): + if not hasattr(cls, '_default_manager') or cls._default_manager is None: # Create the default manager, if needed. try: cls._meta.get_field('objects') @@ -31,7 +31,7 @@ class Manager(object): # TODO: Use weakref because of possible memory leak / circular reference. self.model = model setattr(model, name, ManagerDescriptor(self)) - if not hasattr(model, '_default_manager') or self.creation_counter < model._default_manager.creation_counter: + if not hasattr(model, '_default_manager') or model._default_manager is None or self.creation_counter < model._default_manager.creation_counter: model._default_manager = self ####################### diff --git a/django/db/models/options.py b/django/db/models/options.py index 8d80a0ac8d..8c369f98a6 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -1,3 +1,6 @@ +import re +from bisect import bisect + from django.conf import settings from django.db.models.related import RelatedObject from django.db.models.fields.related import ManyToManyRel @@ -7,19 +10,19 @@ from django.db.models.loading import get_models, app_cache_ready from django.db.models import Manager from django.utils.translation import activate, deactivate_all, get_language, string_concat from django.utils.encoding import force_unicode, smart_str -from bisect import bisect -import re +from django.utils.datastructures import SortedDict # Calculate the verbose_name by converting from InitialCaps to "lowercase with spaces". get_verbose_name = lambda class_name: re.sub('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))', ' \\1', class_name).lower().strip() DEFAULT_NAMES = ('verbose_name', 'db_table', 'ordering', 'unique_together', 'permissions', 'get_latest_by', - 'order_with_respect_to', 'app_label', 'db_tablespace') + 'order_with_respect_to', 'app_label', 'db_tablespace', + 'abstract') class Options(object): def __init__(self, meta): - self.fields, self.many_to_many = [], [] + self.local_fields, self.local_many_to_many = [], [] self.module_name, self.verbose_name = None, None self.verbose_name_plural = None self.db_table = '' @@ -35,7 +38,8 @@ class Options(object): self.pk = None self.has_auto_field, self.auto_field = False, None self.one_to_one_field = None - self.parents = [] + self.abstract = False + self.parents = SortedDict() def contribute_to_class(self, cls, name): cls._meta = self @@ -82,9 +86,16 @@ class Options(object): self.order_with_respect_to = None if self.pk is None: - auto = AutoField(verbose_name='ID', primary_key=True) - auto.creation_counter = -1 - model.add_to_class('id', auto) + if self.parents: + # Promote the first parent link in lieu of adding yet another + # field. + field = self.parents.value_for_index(0) + field.primary_key = True + self.pk = field + else: + auto = AutoField(verbose_name='ID', primary_key=True, + auto_created=True) + model.add_to_class('id', auto) # If the db_table wasn't provided, use the app_label + module_name. if not self.db_table: @@ -97,13 +108,23 @@ class Options(object): # Move many-to-many related fields from self.fields into # self.many_to_many. if field.rel and isinstance(field.rel, ManyToManyRel): - self.many_to_many.insert(bisect(self.many_to_many, field), field) + self.local_many_to_many.insert(bisect(self.local_many_to_many, field), field) else: - self.fields.insert(bisect(self.fields, field), field) + self.local_fields.insert(bisect(self.local_fields, field), field) if not self.pk and field.primary_key: self.pk = field field.serialize = False + # All of these internal caches need to be updated the next time they + # are used. + # TODO: Do this more neatly. (Also, use less caches!) + if hasattr(self, '_field_cache'): + del self._field_cache + if hasattr(self, '_m2m_cache'): + del self._m2m_cache + if hasattr(self, '_name_map'): + del self._name_map + def __repr__(self): return '' % self.object_name @@ -123,8 +144,76 @@ class Options(object): return raw verbose_name_raw = property(verbose_name_raw) + def _fields(self): + """ + The getter for self.fields. This returns the list of field objects + available to this model (including through parent models). + """ + try: + self._field_cache + except AttributeError: + self._fill_fields_cache() + return self._field_cache.keys() + fields = property(_fields) + + def get_fields_with_model(self): + """ + Returns a list of (field, model) pairs for all fields. The "model" + element is None for fields on the current model. Mostly of use when + constructing queries so that we know which model a field belongs to. + """ + try: + self._field_cache + except AttributeError: + self._fill_fields_cache() + return self._field_cache.items() + + def _fill_fields_cache(self): + cache = SortedDict() + for parent in self.parents: + for field, model in parent._meta.get_fields_with_model(): + if model: + cache[field] = model + else: + cache[field] = parent + for field in self.local_fields: + cache[field] = None + self._field_cache = cache + + def _many_to_many(self): + try: + self._m2m_cache + except AttributeError: + self._fill_m2m_cache() + return self._m2m_cache.keys() + many_to_many = property(_many_to_many) + + def get_m2m_with_model(self): + """ + The many-to-many version of get_fields_with_model(). + """ + try: + self._m2m_cache + except AttributeError: + self._fill_m2m_cache() + return self._m2m_cache.items() + + def _fill_m2m_cache(self): + cache = SortedDict() + for parent in self.parents: + for field, model in parent._meta.get_m2m_with_model(): + if model: + cache[field] = model + else: + cache[field] = parent + for field in self.local_many_to_many: + cache[field] = None + self._m2m_cache = cache + def get_field(self, name, many_to_many=True): - "Returns the requested field by name. Raises FieldDoesNotExist on error." + """ + Returns the requested field by name. Raises FieldDoesNotExist on error. + """ to_search = many_to_many and (self.fields + self.many_to_many) or self.fields for f in to_search: if f.name == name: @@ -133,8 +222,9 @@ class Options(object): def get_field_by_name(self, name, only_direct=False): """ - Returns the (field_object, direct, m2m), where field_object is the - Field instance for the given name, direct is True if the field exists + Returns the (field_object, model, direct, m2m), where field_object is + the Field instance for the given name, model is the model containing + this field (None for local fields), direct is True if the field exists on this model, and m2m is True for many-to-many relations. When 'direct' is False, 'field_object' is the corresponding RelatedObject for this field (since the field doesn't have an instance associated @@ -151,7 +241,7 @@ class Options(object): cache = self.init_name_map() result = cache.get(name) - if not result or (not result[1] and only_direct): + if not result or (only_direct and not result[2]): raise FieldDoesNotExist('%s has no field named %r' % (self.object_name, name)) return result @@ -173,15 +263,16 @@ class Options(object): """ Initialises the field name -> field object mapping. """ - cache = dict([(f.name, (f, True, False)) for f in self.fields]) - for f in self.many_to_many: - cache[f.name] = (f, True, True) - for f in self.get_all_related_many_to_many_objects(): - cache[f.field.related_query_name()] = (f, False, True) - for f in self.get_all_related_objects(): - cache[f.field.related_query_name()] = (f, False, False) + cache = dict([(f.name, (f, m, True, False)) for f, m in + self.get_fields_with_model()]) + for f, model in self.get_m2m_with_model(): + cache[f.name] = (f, model, True, True) + for f, model in self.get_all_related_m2m_objects_with_model(): + cache[f.field.related_query_name()] = (f, model, False, True) + for f, model in self.get_all_related_objects_with_model(): + cache[f.field.related_query_name()] = (f, model, False, False) if self.order_with_respect_to: - cache['_order'] = OrderWrt(), True, False + cache['_order'] = OrderWrt(), None, True, False if app_cache_ready(): self._name_map = cache return cache @@ -195,17 +286,81 @@ class Options(object): def get_delete_permission(self): return 'delete_%s' % self.object_name.lower() - def get_all_related_objects(self): - try: # Try the cache first. - return self._all_related_objects + def get_all_related_objects(self, local_only=False): + try: + self._related_objects_cache except AttributeError: - rel_objs = [] - for klass in get_models(): - for f in klass._meta.fields: - if f.rel and not isinstance(f.rel.to, str) and self == f.rel.to._meta: - rel_objs.append(RelatedObject(f.rel.to, klass, f)) - self._all_related_objects = rel_objs - return rel_objs + self._fill_related_objects_cache() + if local_only: + return [k for k, v in self._related_objects_cache.items() if not v] + return self._related_objects_cache.keys() + + def get_all_related_objects_with_model(self): + """ + Returns a list of (related-object, model) pairs. Similar to + get_fields_with_model(). + """ + try: + self._related_objects_cache + except AttributeError: + self._fill_related_objects_cache() + return self._related_objects_cache.items() + + def _fill_related_objects_cache(self): + cache = SortedDict() + parent_list = self.get_parent_list() + for parent in self.parents: + for obj, model in parent._meta.get_all_related_objects_with_model(): + if obj.field.creation_counter < 0 and obj.model not in parent_list: + continue + if not model: + cache[obj] = parent + else: + cache[obj] = model + for klass in get_models(): + for f in klass._meta.local_fields: + if f.rel and not isinstance(f.rel.to, str) and self == f.rel.to._meta: + cache[RelatedObject(f.rel.to, klass, f)] = None + self._related_objects_cache = cache + + def get_all_related_many_to_many_objects(self, local_only=False): + try: + cache = self._related_many_to_many_cache + except AttributeError: + cache = self._fill_related_many_to_many_cache() + if local_only: + return [k for k, v in cache.items() if not v] + return cache.keys() + + def get_all_related_m2m_objects_with_model(self): + """ + Returns a list of (related-m2m-object, model) pairs. Similar to + get_fields_with_model(). + """ + try: + cache = self._related_many_to_many_cache + except AttributeError: + cache = self._fill_related_many_to_many_cache() + return cache.items() + + def _fill_related_many_to_many_cache(self): + cache = SortedDict() + parent_list = self.get_parent_list() + for parent in self.parents: + for obj, model in parent._meta.get_all_related_m2m_objects_with_model(): + if obj.field.creation_counter < 0 and obj.model not in parent_list: + continue + if not model: + cache[obj] = parent + else: + cache[obj] = model + for klass in get_models(): + for f in klass._meta.local_many_to_many: + if f.rel and not isinstance(f.rel.to, str) and self == f.rel.to._meta: + cache[RelatedObject(f.rel.to, klass, f)] = None + if app_cache_ready(): + self._related_many_to_many_cache = cache + return cache def get_followed_related_objects(self, follow=None): if follow == None: @@ -229,18 +384,35 @@ class Options(object): follow[f.name] = fol return follow - def get_all_related_many_to_many_objects(self): - try: # Try the cache first. - return self._all_related_many_to_many_objects - except AttributeError: - rel_objs = [] - for klass in get_models(): - for f in klass._meta.many_to_many: - if f.rel and not isinstance(f.rel.to, str) and self == f.rel.to._meta: - rel_objs.append(RelatedObject(f.rel.to, klass, f)) - if app_cache_ready(): - self._all_related_many_to_many_objects = rel_objs - return rel_objs + def get_base_chain(self, model): + """ + Returns a list of parent classes leading to 'model' (order from closet + to most distant ancestor). This has to handle the case were 'model' is + a granparent or even more distant relation. + """ + if not self.parents: + return + if model in self.parents: + return [model] + for parent in self.parents: + res = parent._meta.get_base_chain(model) + if res: + res.insert(0, parent) + return res + raise TypeError('%r is not an ancestor of this model' + % model._meta.module_name) + + def get_parent_list(self): + """ + Returns a list of all the ancestor of this model as a list. Useful for + determining if something is an ancestor, regardless of lineage. + """ + # FIXME: Fix model hashing and then use a Set here. + result = [] + for parent in self.parents: + result.append(parent) + result.extend(parent._meta.get_parent_list()) + return result def get_ordered_objects(self): "Returns a list of Options objects that are ordered with respect to this object." diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 7d95e1d603..7b70e4d1b2 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -389,8 +389,13 @@ class Query(object): aliases.append(col.alias) elif self.default_cols: table_alias = self.tables[0] - result = ['%s.%s' % (qn(table_alias), qn(f.column)) - for f in self.model._meta.fields] + root_pk = self.model._meta.pk.column + seen = {None: table_alias} + for field, model in self.model._meta.get_fields_with_model(): + if model not in seen: + seen[model] = self.join((table_alias, model._meta.db_table, + root_pk, model._meta.pk.column)) + result.append('%s.%s' % (qn(seen[model]), qn(field.column))) aliases = result[:] result.extend(['(%s) AS %s' % (col, alias) @@ -742,7 +747,7 @@ class Query(object): opts = self.model._meta alias = self.join((None, opts.db_table, None, None)) - field, target, unused, join_list, = self.setup_joins(parts, opts, + field, target, opts, join_list, = self.setup_joins(parts, opts, alias, (connector == AND)) col = target.column alias = join_list[-1][-1] @@ -850,11 +855,21 @@ class Query(object): name = opts.pk.name try: - field, direct, m2m = opts.get_field_by_name(name) + field, model, direct, m2m = opts.get_field_by_name(name) except FieldDoesNotExist: names = opts.get_all_field_names() raise TypeError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) + if model: + # The field lives on a base class of the current model. + alias_list = [] + for int_model in opts.get_base_chain(model): + lhs_col = opts.parents[int_model].column + opts = int_model._meta + alias = self.join((alias, opts.db_table, lhs_col, + opts.pk.column)) + alias_list.append(alias) + joins.append(alias_list) cached_data = opts._join_cache.get(name) orig_opts = opts @@ -899,6 +914,7 @@ class Query(object): nullable=field.null) joins.append([alias]) else: + # Non-relation fields. target = field break else: @@ -1242,7 +1258,7 @@ class UpdateQuery(Query): def add_update_values(self, values): from django.db.models.base import Model for name, val in values.items(): - field, direct, m2m = self.model._meta.get_field_by_name(name) + field, model, direct, m2m = self.model._meta.get_field_by_name(name) if not direct or m2m: # Can only update non-relation fields and foreign keys. raise TypeError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field) diff --git a/docs/model-api.txt b/docs/model-api.txt index cccd2ffded..6ca08ae4e5 100644 --- a/docs/model-api.txt +++ b/docs/model-api.txt @@ -2028,6 +2028,105 @@ You can also prevent saving:: .. _database API docs: ../db-api/ +Model inheritance +================= + +Abstract base classes +--------------------- + +Abstract base classes are useful when you want to put some common information +into a number of other models. You write your base class and put +``abstract=True`` in the ``Meta`` class. This model will then not be used to +create any database table. Instead, when it is used as a base class for other +models, its fields will be added to those of the child class. It is an error +to have fields in the abstract base class with the same name as those in the +child (and Django will raise an exception). + +An example:: + + class CommonInfo(models.Model): + name = models.CharField(max_length=100) + age = models.PositiveIntegerField() + + class Meta: + abstract = True + + class Student(CommonInfo): + home_group = models.CharField(max_length=5) + +The ``Student`` model will have three fields: ``name``, ``age`` and +``home_group``. The ``CommonInfo`` model cannot be used as a normal Django +model, since it is an abstract base class. It does not generate a database +table or have a manager or anything like that. + +For many uses, this type of model inheritance will be exactly what you want. +It provides a way to factor out common information at the Python level, whilst +still only creating one database table per child model at the database level. + +Multi-table inheritance +----------------------- + +The second type of model inheritance supported by Django is when each model in +the hierarchy is a model all by itself. Each model corresponds to its own +database table and can be queried and created indvidually. The inheritance +relationship introduces links between the child model and each of its parents +(via an automatically created ``OneToOneField``). For example:: + + class Place(models.Model): + name = models.CharField(max_length=50) + address = models.CharField(max_length=80) + + class Restaurant(Place): + serves_hot_dogs = models.BooleanField() + serves_pizza = models.BooleanField() + +All of the fields of ``Place`` will also be available in ``Restaurant``, +although the data will reside in a different database table. So these are both +possible:: + + >>> Place.objects.filter(name="Bob's Cafe") + >>> Restaurant.objects.filter(name="Bob's Cafe") + +If you have a ``Place`` that is also a ``Restaurant``, you can get from the +``Place`` object to the ``Restaurant`` object by using the lower-case version +of the model name:: + + >>> p = Place.objects.filter(name="Bob's Cafe") + # If Bob's Cafe is a Restaurant object, this will give the child class: + >>> p.restaurant + + +However, if ``p`` in the above example was *not* a ``Restaurant`` (it had been +created directly as a ``Place`` object or was the parent of some other class), +referring to ``p.restaurant`` would give an error. + +Normally you won't need to worry too much about how model inheritance works. +It will behave similarly to Python class inheritance. + +Inheritance and reverse relations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Because multi-table inheritance uses an implicit ``OneToOneField`` to link the +child and the parent, it's possible to move from the parent down to the child, +as in the above example. However, this uses up the name that is the default +``related_name`` value for ``ForeignKey`` and ``ManyToManyField`` relations. +If you are putting those type of relations on a subclass of another model, you +**must** specify the ``related_name`` attribute on each such field. If you +forget, Django will raise an error when you run ``manage.py validate`` or try +to syncdb. + +For example, using the above ``Place`` class again, let's create another +subclass with a ``ManyToManyField``:: + + class Supplier(Place): + # Must specify related_name on all relations. + customers = models.ManyToManyField(Restaurant, + related_name='provider') + +For more information about reverse relations, refer to the `Database API +reference`_ . For now, just remember to run ``manage.py validate`` when +you're writing your models and pay attention to the error messages. + Models across files =================== diff --git a/tests/modeltests/model_inheritance/models.py b/tests/modeltests/model_inheritance/models.py index d9956a5452..940b2ee10c 100644 --- a/tests/modeltests/model_inheritance/models.py +++ b/tests/modeltests/model_inheritance/models.py @@ -1,11 +1,35 @@ """ XX. Model inheritance -Model inheritance isn't yet supported. +Model inheritance exists in two varieties: + - abstract base classes which are a way of specifying common + information inherited by the subclasses. They don't exist as a separate + model. + - non-abstract base classes (the default), which are models in their own + right with their own database tables and everything. Their subclasses + have references back to them, created automatically. + +Both styles are demonstrated here. """ from django.db import models +class CommonInfo(models.Model): + name = models.CharField(max_length=50) + age = models.PositiveIntegerField() + + class Meta: + abstract = True + + def __unicode__(self): + return u'%s %s' % (self.__class__.__name__, self.name) + +class Worker(CommonInfo): + job = models.CharField(max_length=50) + +class Student(CommonInfo): + school_class = models.CharField(max_length=10) + class Place(models.Model): name = models.CharField(max_length=50) address = models.CharField(max_length=80) @@ -26,16 +50,44 @@ class ItalianRestaurant(Restaurant): def __unicode__(self): return u"%s the italian restaurant" % self.name -# XFAIL: Recent changes to model saving mean these now fail catastrophically. -# They'll be re-enabled when the porting is a bit further along. -not__test__ = {'API_TESTS':""" -# Make sure Restaurant has the right fields in the right order. ->>> [f.name for f in Restaurant._meta.fields] -['id', 'name', 'address', 'serves_hot_dogs', 'serves_pizza'] +class Supplier(Place): + customers = models.ManyToManyField(Restaurant, related_name='provider') -# Make sure ItalianRestaurant has the right fields in the right order. ->>> [f.name for f in ItalianRestaurant._meta.fields] -['id', 'name', 'address', 'serves_hot_dogs', 'serves_pizza', 'serves_gnocchi'] + def __unicode__(self): + return u"%s the supplier" % self.name + +class ParkingLot(Place): + main_site = models.ForeignKey(Place, related_name='lot') + + def __unicode__(self): + return u"%s the parking lot" % self.name + +__test__ = {'API_TESTS':""" +# The Student and Worker models both have 'name' and 'age' fields on them and +# inherit the __unicode__() method, just as with normal Python subclassing. +# This is useful if you want to factor out common information for programming +# purposes, but still completely independent separate models at the database +# level. + +>>> w = Worker(name='Fred', age=35, job='Quarry worker') +>>> w.save() +>>> s = Student(name='Pebbles', age=5, school_class='1B') +>>> s.save() +>>> unicode(w) +u'Worker Fred' +>>> unicode(s) +u'Student Pebbles' + +# However, the CommonInfo class cannot be used as a normal model (it doesn't +# exist as a model). +>>> CommonInfo.objects.all() +Traceback (most recent call last): + ... +AttributeError: type object 'CommonInfo' has no attribute 'objects' + +# The Place/Restaurant/ItalianRestaurant models, on the other hand, all exist +# as independent models. However, the subclasses also have transparent access +# to the fields of their ancestors. # Create a couple of Places. >>> p1 = Place(name='Master Shakes', address='666 W. Jersey') @@ -43,7 +95,7 @@ not__test__ = {'API_TESTS':""" >>> p2 = Place(name='Ace Hardware', address='1013 N. Ashland') >>> p2.save() -# Test constructor for Restaurant. +Test constructor for Restaurant. >>> r = Restaurant(name='Demon Dogs', address='944 W. Fullerton', serves_hot_dogs=True, serves_pizza=False) >>> r.save() @@ -51,5 +103,88 @@ not__test__ = {'API_TESTS':""" >>> ir = ItalianRestaurant(name='Ristorante Miron', address='1234 W. Elm', serves_hot_dogs=False, serves_pizza=False, serves_gnocchi=True) >>> ir.save() +# Make sure Restaurant and ItalianRestaurant have the right fields in the right +# order. +>>> [f.name for f in Restaurant._meta.fields] +['id', 'name', 'address', 'place_ptr', 'serves_hot_dogs', 'serves_pizza'] +>>> [f.name for f in ItalianRestaurant._meta.fields] +['id', 'name', 'address', 'place_ptr', 'serves_hot_dogs', 'serves_pizza', 'restaurant_ptr', 'serves_gnocchi'] + +# Even though p.supplier for a Place 'p' (a parent of a Supplier), a Restaurant +# object cannot access that reverse relation, since it's not part of the +# Place-Supplier Hierarchy. +>>> Place.objects.filter(supplier__name='foo') +[] +>>> Restaurant.objects.filter(supplier__name='foo') +Traceback (most recent call last): + ... +TypeError: Cannot resolve keyword 'supplier' into field. Choices are: address, id, italianrestaurant, lot, name, place_ptr, provider, serves_hot_dogs, serves_pizza + +# Parent fields can be used directly in filters on the child model. +>>> Restaurant.objects.filter(name='Demon Dogs') +[] +>>> ItalianRestaurant.objects.filter(address='1234 W. Elm') +[] + +# Filters against the parent model return objects of the parent's type. +>>> Place.objects.filter(name='Demon Dogs') +[] + +# Since the parent and child are linked by an automatically created +# OneToOneField, you can get from the parent to the child by using the child's +# name. +>>> place = Place.objects.get(name='Demon Dogs') +>>> place.restaurant + + +>>> Place.objects.get(name='Ristorante Miron').restaurant.italianrestaurant + +>>> Restaurant.objects.get(name='Ristorante Miron').italianrestaurant + + +# This won't work because the Demon Dogs restaurant is not an Italian +# restaurant. +>>> place.restaurant.italianrestaurant +Traceback (most recent call last): + ... +DoesNotExist: ItalianRestaurant matching query does not exist. + +# Related objects work just as they normally do. + +>>> s1 = Supplier(name="Joe's Chickens", address='123 Sesame St') +>>> s1.save() +>>> s1.customers = [r, ir] +>>> s2 = Supplier(name="Luigi's Pasta", address='456 Sesame St') +>>> s2.save() +>>> s2.customers = [ir] + +# This won't work because the Place we select is not a Restaurant (it's a +# Supplier). +>>> p = Place.objects.get(name="Joe's Chickens") +>>> p.restaurant +Traceback (most recent call last): + ... +DoesNotExist: Restaurant matching query does not exist. + +# But we can descend from p to the Supplier child, as expected. +>>> p.supplier + + +>>> ir.provider.order_by('-name') +[, ] + +>>> Restaurant.objects.filter(provider__name__contains="Chickens") +[, ] +>>> ItalianRestaurant.objects.filter(provider__name__contains="Chickens") +[] + +>>> park1 = ParkingLot(name='Main St', address='111 Main St', main_site=s1) +>>> park1.save() +>>> park2 = ParkingLot(name='Well Lit', address='124 Sesame St', main_site=ir) +>>> park2.save() + +>>> Restaurant.objects.get(lot__name='Well Lit') + + """}