1
0
mirror of https://github.com/django/django.git synced 2025-07-06 02:39:12 +00:00

queryset-refactor: Fixed default (no fields) case of select_related() to work with model inheritance. Refs #6761.

git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7240 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Malcolm Tredinnick 2008-03-14 11:45:52 +00:00
parent d91479a287
commit cf2da4689a
2 changed files with 54 additions and 25 deletions

View File

@ -598,8 +598,10 @@ def QNot(q):
def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
requested=None): requested=None):
"""Helper function that recursively returns an object with cache filled""" """
Helper function that recursively returns an object with the specified
related attributes already populated.
"""
if max_depth and requested is None and cur_depth > max_depth: if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now. # We've recursed deeply enough; stop now.
return None return None
@ -608,17 +610,18 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
index_end = index_start + len(klass._meta.fields) index_end = index_start + len(klass._meta.fields)
obj = klass(*row[index_start:index_end]) obj = klass(*row[index_start:index_end])
for f in klass._meta.fields: for f in klass._meta.fields:
if f.rel and ((not restricted and not f.null) or if (not f.rel or (not restricted and f.null) or
(restricted and f.name in requested)): (restricted and f.name not in requested) or f.rel.parent_link):
if restricted: continue
next = requested[f.name] if restricted:
else: next = requested[f.name]
next = None else:
cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, next = None
cur_depth+1, next) cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
if cached_row: cur_depth+1, next)
rel_obj, index_end = cached_row if cached_row:
setattr(obj, f.get_cache_name(), rel_obj) rel_obj, index_end = cached_row
setattr(obj, f.get_cache_name(), rel_obj)
return obj, index_end return obj, index_end
def delete_objects(seen_objs): def delete_objects(seen_objs):

View File

@ -44,7 +44,7 @@ class Query(object):
self.connection = connection self.connection = connection
self.alias_map = {} # Maps alias to table name self.alias_map = {} # Maps alias to table name
self.table_map = {} # Maps table names to list of aliases. self.table_map = {} # Maps table names to list of aliases.
self.rev_join_map = {} # Reverse of join_map. self.rev_join_map = {} # Reverse of join_map. (FIXME: Update comment)
self.quote_cache = {} self.quote_cache = {}
self.default_cols = True self.default_cols = True
self.default_ordering = True self.default_ordering = True
@ -63,7 +63,7 @@ class Query(object):
self.distinct = False self.distinct = False
self.select_related = False self.select_related = False
# Arbitrary maximum limit for select_related to prevent infinite # Arbitrary maximum limit for select_related. Prevents infinite
# recursion. Can be changed by the depth parameter to select_related(). # recursion. Can be changed by the depth parameter to select_related().
self.max_depth = 5 self.max_depth = 5
@ -361,14 +361,15 @@ class Query(object):
if hasattr(col, 'alias'): if hasattr(col, 'alias'):
aliases.append(col.alias) aliases.append(col.alias)
elif self.default_cols: elif self.default_cols:
table_alias = self.tables[0] #table_alias = self.tables[0]
root_pk = self.model._meta.pk.column #root_pk = self.model._meta.pk.column
seen = {None: table_alias} #seen = {None: table_alias}
for field, model in self.model._meta.get_fields_with_model(): #for field, model in self.model._meta.get_fields_with_model():
if model not in seen: # if model not in seen:
seen[model] = self.join((table_alias, model._meta.db_table, # seen[model] = self.join((table_alias, model._meta.db_table,
root_pk, model._meta.pk.column)) # root_pk, model._meta.pk.column))
result.append('%s.%s' % (qn(seen[model]), qn(field.column))) # result.append('%s.%s' % (qn(seen[model]), qn(field.column)))
result = self.get_default_columns(lambda x, y: "%s.%s" % (qn(x), qn(y)))
aliases = result[:] aliases = result[:]
result.extend(['(%s) AS %s' % (col, alias) result.extend(['(%s) AS %s' % (col, alias)
@ -378,6 +379,31 @@ class Query(object):
self._select_aliases = dict.fromkeys(aliases) self._select_aliases = dict.fromkeys(aliases)
return result return result
def get_default_columns(self, combine_func=None):
"""
Computes the default columns for selecting every field in the base
model. Returns a list of default (alias, column) pairs suitable for
direct inclusion as the select columns. The 'combine_func' can be
passed in to change the returned data set to a list of some other
structure.
"""
# Note: We allow 'combine_func' here because this method is called a
# lot. The extra overhead from returning a list and then transforming
# it in get_columns() hurt performance in a measurable way.
result = []
table_alias = self.tables[0]
root_pk = self.model._meta.pk.column
seen = {None: table_alias}
for field, model in self.model._meta.get_fields_with_model():
if model not in seen:
seen[model] = self.join((table_alias, model._meta.db_table,
root_pk, model._meta.pk.column))
if combine_func:
result.append(combine_func(seen[model], field.column))
else:
result.append((seen[model], field.column))
return result
def get_from_clause(self): def get_from_clause(self):
""" """
Returns a list of strings that are joined together to go after the Returns a list of strings that are joined together to go after the
@ -744,7 +770,7 @@ class Query(object):
if not opts: if not opts:
opts = self.get_meta() opts = self.get_meta()
root_alias = self.tables[0] root_alias = self.tables[0]
self.select.extend([(root_alias, f.column) for f in opts.fields]) self.select.extend(self.get_default_columns())
if not used: if not used:
used = [] used = []
@ -759,7 +785,7 @@ class Query(object):
for f in opts.fields: for f in opts.fields:
if (not f.rel or (restricted and f.name not in requested) or if (not f.rel or (restricted and f.name not in requested) or
(not restricted and f.null)): (not restricted and f.null) or f.rel.parent_link):
continue continue
table = f.rel.to._meta.db_table table = f.rel.to._meta.db_table
alias = self.join((root_alias, table, f.column, alias = self.join((root_alias, table, f.column,