From cf2da4689a569c9b436e2f4a1c2c0b2295c821c6 Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Fri, 14 Mar 2008 11:45:52 +0000 Subject: [PATCH] queryset-refactor: Fixed default (no fields) case of select_related() to work with model inheritance. Refs #6761. git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7240 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/query.py | 29 +++++++++++--------- django/db/models/sql/query.py | 50 ++++++++++++++++++++++++++--------- 2 files changed, 54 insertions(+), 25 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index 09badef5aa..2e62b85602 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -598,8 +598,10 @@ def QNot(q): def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, requested=None): - """Helper function that recursively returns an object with cache filled""" - + """ + Helper function that recursively returns an object with the specified + related attributes already populated. + """ if max_depth and requested is None and cur_depth > max_depth: # We've recursed deeply enough; stop now. return None @@ -608,17 +610,18 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, index_end = index_start + len(klass._meta.fields) obj = klass(*row[index_start:index_end]) for f in klass._meta.fields: - if f.rel and ((not restricted and not f.null) or - (restricted and f.name in requested)): - if restricted: - next = requested[f.name] - else: - next = None - cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, - cur_depth+1, next) - if cached_row: - rel_obj, index_end = cached_row - setattr(obj, f.get_cache_name(), rel_obj) + if (not f.rel or (not restricted and f.null) or + (restricted and f.name not in requested) or f.rel.parent_link): + continue + if restricted: + next = requested[f.name] + else: + next = None + cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, + cur_depth+1, next) + if cached_row: + rel_obj, index_end = cached_row + setattr(obj, f.get_cache_name(), rel_obj) return obj, index_end def delete_objects(seen_objs): diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index dd6365282b..77f3519516 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -44,7 +44,7 @@ class Query(object): self.connection = connection self.alias_map = {} # Maps alias to table name self.table_map = {} # Maps table names to list of aliases. - self.rev_join_map = {} # Reverse of join_map. + self.rev_join_map = {} # Reverse of join_map. (FIXME: Update comment) self.quote_cache = {} self.default_cols = True self.default_ordering = True @@ -63,7 +63,7 @@ class Query(object): self.distinct = False self.select_related = False - # Arbitrary maximum limit for select_related to prevent infinite + # Arbitrary maximum limit for select_related. Prevents infinite # recursion. Can be changed by the depth parameter to select_related(). self.max_depth = 5 @@ -361,14 +361,15 @@ class Query(object): if hasattr(col, 'alias'): aliases.append(col.alias) elif self.default_cols: - table_alias = self.tables[0] - root_pk = self.model._meta.pk.column - seen = {None: table_alias} - for field, model in self.model._meta.get_fields_with_model(): - if model not in seen: - seen[model] = self.join((table_alias, model._meta.db_table, - root_pk, model._meta.pk.column)) - result.append('%s.%s' % (qn(seen[model]), qn(field.column))) + #table_alias = self.tables[0] + #root_pk = self.model._meta.pk.column + #seen = {None: table_alias} + #for field, model in self.model._meta.get_fields_with_model(): + # if model not in seen: + # seen[model] = self.join((table_alias, model._meta.db_table, + # root_pk, model._meta.pk.column)) + # result.append('%s.%s' % (qn(seen[model]), qn(field.column))) + result = self.get_default_columns(lambda x, y: "%s.%s" % (qn(x), qn(y))) aliases = result[:] result.extend(['(%s) AS %s' % (col, alias) @@ -378,6 +379,31 @@ class Query(object): self._select_aliases = dict.fromkeys(aliases) return result + def get_default_columns(self, combine_func=None): + """ + Computes the default columns for selecting every field in the base + model. Returns a list of default (alias, column) pairs suitable for + direct inclusion as the select columns. The 'combine_func' can be + passed in to change the returned data set to a list of some other + structure. + """ + # Note: We allow 'combine_func' here because this method is called a + # lot. The extra overhead from returning a list and then transforming + # it in get_columns() hurt performance in a measurable way. + result = [] + table_alias = self.tables[0] + root_pk = self.model._meta.pk.column + seen = {None: table_alias} + for field, model in self.model._meta.get_fields_with_model(): + if model not in seen: + seen[model] = self.join((table_alias, model._meta.db_table, + root_pk, model._meta.pk.column)) + if combine_func: + result.append(combine_func(seen[model], field.column)) + else: + result.append((seen[model], field.column)) + return result + def get_from_clause(self): """ Returns a list of strings that are joined together to go after the @@ -744,7 +770,7 @@ class Query(object): if not opts: opts = self.get_meta() root_alias = self.tables[0] - self.select.extend([(root_alias, f.column) for f in opts.fields]) + self.select.extend(self.get_default_columns()) if not used: used = [] @@ -759,7 +785,7 @@ class Query(object): for f in opts.fields: if (not f.rel or (restricted and f.name not in requested) or - (not restricted and f.null)): + (not restricted and f.null) or f.rel.parent_link): continue table = f.rel.to._meta.db_table alias = self.join((root_alias, table, f.column,