1
0
mirror of https://github.com/django/django.git synced 2025-07-05 18:29:11 +00:00

queryset-refactor: Fixed default (no fields) case of select_related() to work with model inheritance. Refs #6761.

git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7240 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Malcolm Tredinnick 2008-03-14 11:45:52 +00:00
parent d91479a287
commit cf2da4689a
2 changed files with 54 additions and 25 deletions

View File

@ -598,8 +598,10 @@ def QNot(q):
def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
requested=None):
"""Helper function that recursively returns an object with cache filled"""
"""
Helper function that recursively returns an object with the specified
related attributes already populated.
"""
if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now.
return None
@ -608,17 +610,18 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
index_end = index_start + len(klass._meta.fields)
obj = klass(*row[index_start:index_end])
for f in klass._meta.fields:
if f.rel and ((not restricted and not f.null) or
(restricted and f.name in requested)):
if restricted:
next = requested[f.name]
else:
next = None
cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
cur_depth+1, next)
if cached_row:
rel_obj, index_end = cached_row
setattr(obj, f.get_cache_name(), rel_obj)
if (not f.rel or (not restricted and f.null) or
(restricted and f.name not in requested) or f.rel.parent_link):
continue
if restricted:
next = requested[f.name]
else:
next = None
cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
cur_depth+1, next)
if cached_row:
rel_obj, index_end = cached_row
setattr(obj, f.get_cache_name(), rel_obj)
return obj, index_end
def delete_objects(seen_objs):

View File

@ -44,7 +44,7 @@ class Query(object):
self.connection = connection
self.alias_map = {} # Maps alias to table name
self.table_map = {} # Maps table names to list of aliases.
self.rev_join_map = {} # Reverse of join_map.
self.rev_join_map = {} # Reverse of join_map. (FIXME: Update comment)
self.quote_cache = {}
self.default_cols = True
self.default_ordering = True
@ -63,7 +63,7 @@ class Query(object):
self.distinct = False
self.select_related = False
# Arbitrary maximum limit for select_related to prevent infinite
# Arbitrary maximum limit for select_related. Prevents infinite
# recursion. Can be changed by the depth parameter to select_related().
self.max_depth = 5
@ -361,14 +361,15 @@ class Query(object):
if hasattr(col, 'alias'):
aliases.append(col.alias)
elif self.default_cols:
table_alias = self.tables[0]
root_pk = self.model._meta.pk.column
seen = {None: table_alias}
for field, model in self.model._meta.get_fields_with_model():
if model not in seen:
seen[model] = self.join((table_alias, model._meta.db_table,
root_pk, model._meta.pk.column))
result.append('%s.%s' % (qn(seen[model]), qn(field.column)))
#table_alias = self.tables[0]
#root_pk = self.model._meta.pk.column
#seen = {None: table_alias}
#for field, model in self.model._meta.get_fields_with_model():
# if model not in seen:
# seen[model] = self.join((table_alias, model._meta.db_table,
# root_pk, model._meta.pk.column))
# result.append('%s.%s' % (qn(seen[model]), qn(field.column)))
result = self.get_default_columns(lambda x, y: "%s.%s" % (qn(x), qn(y)))
aliases = result[:]
result.extend(['(%s) AS %s' % (col, alias)
@ -378,6 +379,31 @@ class Query(object):
self._select_aliases = dict.fromkeys(aliases)
return result
def get_default_columns(self, combine_func=None):
"""
Computes the default columns for selecting every field in the base
model. Returns a list of default (alias, column) pairs suitable for
direct inclusion as the select columns. The 'combine_func' can be
passed in to change the returned data set to a list of some other
structure.
"""
# Note: We allow 'combine_func' here because this method is called a
# lot. The extra overhead from returning a list and then transforming
# it in get_columns() hurt performance in a measurable way.
result = []
table_alias = self.tables[0]
root_pk = self.model._meta.pk.column
seen = {None: table_alias}
for field, model in self.model._meta.get_fields_with_model():
if model not in seen:
seen[model] = self.join((table_alias, model._meta.db_table,
root_pk, model._meta.pk.column))
if combine_func:
result.append(combine_func(seen[model], field.column))
else:
result.append((seen[model], field.column))
return result
def get_from_clause(self):
"""
Returns a list of strings that are joined together to go after the
@ -744,7 +770,7 @@ class Query(object):
if not opts:
opts = self.get_meta()
root_alias = self.tables[0]
self.select.extend([(root_alias, f.column) for f in opts.fields])
self.select.extend(self.get_default_columns())
if not used:
used = []
@ -759,7 +785,7 @@ class Query(object):
for f in opts.fields:
if (not f.rel or (restricted and f.name not in requested) or
(not restricted and f.null)):
(not restricted and f.null) or f.rel.parent_link):
continue
table = f.rel.to._meta.db_table
alias = self.join((root_alias, table, f.column,