mirror of
https://github.com/django/django.git
synced 2025-10-29 16:46:11 +00:00
Added a "depth" argument to select_related() to control how many "levels" of relations select_related() is willing to follow (refs #3275).
Also added unit tests for select_related(). git-svn-id: http://code.djangoproject.com/svn/django/trunk@4645 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
@@ -84,6 +84,7 @@ class QuerySet(object):
|
||||
self._filters = Q()
|
||||
self._order_by = None # Ordering, e.g. ('date', '-name'). If None, use model's ordering.
|
||||
self._select_related = False # Whether to fill cache for related objects.
|
||||
self._max_related_depth = 0 # Maximum "depth" for select_related
|
||||
self._distinct = False # Whether the query should use SELECT DISTINCT.
|
||||
self._select = {} # Dictionary of attname -> SQL.
|
||||
self._where = [] # List of extra WHERE clauses to use.
|
||||
@@ -186,7 +187,8 @@ class QuerySet(object):
|
||||
raise StopIteration
|
||||
for row in rows:
|
||||
if fill_cache:
|
||||
obj, index_end = get_cached_row(self.model, row, 0)
|
||||
obj, index_end = get_cached_row(klass=self.model, row=row,
|
||||
index_start=0, max_depth=self._max_related_depth)
|
||||
else:
|
||||
obj = self.model(*row[:index_end])
|
||||
for i, k in enumerate(extra_select):
|
||||
@@ -394,9 +396,9 @@ class QuerySet(object):
|
||||
else:
|
||||
return self._filter_or_exclude(None, **filter_obj)
|
||||
|
||||
def select_related(self, true_or_false=True):
|
||||
def select_related(self, true_or_false=True, depth=0):
|
||||
"Returns a new QuerySet instance with '_select_related' modified."
|
||||
return self._clone(_select_related=true_or_false)
|
||||
return self._clone(_select_related=true_or_false, _max_related_depth=depth)
|
||||
|
||||
def order_by(self, *field_names):
|
||||
"Returns a new QuerySet instance with the ordering changed."
|
||||
@@ -430,6 +432,7 @@ class QuerySet(object):
|
||||
c._filters = self._filters
|
||||
c._order_by = self._order_by
|
||||
c._select_related = self._select_related
|
||||
c._max_related_depth = self._max_related_depth
|
||||
c._distinct = self._distinct
|
||||
c._select = self._select.copy()
|
||||
c._where = self._where[:]
|
||||
@@ -483,7 +486,10 @@ class QuerySet(object):
|
||||
|
||||
# Add additional tables and WHERE clauses based on select_related.
|
||||
if self._select_related:
|
||||
fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table])
|
||||
fill_table_cache(opts, select, tables, where,
|
||||
old_prefix=opts.db_table,
|
||||
cache_tables_seen=[opts.db_table],
|
||||
max_depth=self._max_related_depth)
|
||||
|
||||
# Add any additional SELECTs.
|
||||
if self._select:
|
||||
@@ -728,21 +734,33 @@ def get_where_clause(lookup_type, table_prefix, field_name, value):
|
||||
return backend.get_fulltext_search_sql(table_prefix + field_name)
|
||||
raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type)
|
||||
|
||||
def get_cached_row(klass, row, index_start):
|
||||
"Helper function that recursively returns an object with cache filled"
|
||||
def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0):
|
||||
"""Helper function that recursively returns an object with cache filled"""
|
||||
|
||||
# If we've got a max_depth set and we've exceeded that depth, bail now.
|
||||
if max_depth and cur_depth > max_depth:
|
||||
return None
|
||||
|
||||
index_end = index_start + len(klass._meta.fields)
|
||||
obj = klass(*row[index_start:index_end])
|
||||
for f in klass._meta.fields:
|
||||
if f.rel and not f.null:
|
||||
rel_obj, index_end = get_cached_row(f.rel.to, row, index_end)
|
||||
setattr(obj, f.get_cache_name(), rel_obj)
|
||||
cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1)
|
||||
if cached_row:
|
||||
rel_obj, index_end = cached_row
|
||||
setattr(obj, f.get_cache_name(), rel_obj)
|
||||
return obj, index_end
|
||||
|
||||
def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen):
|
||||
def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, cur_depth=0):
|
||||
"""
|
||||
Helper function that recursively populates the select, tables and where (in
|
||||
place) for select_related queries.
|
||||
"""
|
||||
|
||||
# If we've got a max_depth set and we've exceeded that depth, bail now.
|
||||
if max_depth and cur_depth > max_depth:
|
||||
return None
|
||||
|
||||
qn = backend.quote_name
|
||||
for f in opts.fields:
|
||||
if f.rel and not f.null:
|
||||
@@ -757,7 +775,7 @@ def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen)
|
||||
where.append('%s.%s = %s.%s' % \
|
||||
(qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column)))
|
||||
select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields])
|
||||
fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen)
|
||||
fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, cur_depth+1)
|
||||
|
||||
def parse_lookup(kwarg_items, opts):
|
||||
# Helper function that handles converting API kwargs
|
||||
|
||||
Reference in New Issue
Block a user