mirror of
				https://github.com/django/django.git
				synced 2025-10-30 09:06:13 +00:00 
			
		
		
		
	Fixed #7270 -- Added the ability to follow reverse OneToOneFields in select_related(). Thanks to George Vilches, Ben Davis, and Alex Gaynor for their work on various stages of this patch.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@12307 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
		| @@ -189,7 +189,7 @@ class SingleRelatedObjectDescriptor(object): | ||||
|     # SingleRelatedObjectDescriptor instance. | ||||
|     def __init__(self, related): | ||||
|         self.related = related | ||||
|         self.cache_name = '_%s_cache' % related.get_accessor_name() | ||||
|         self.cache_name = related.get_cache_name() | ||||
|  | ||||
|     def __get__(self, instance, instance_type=None): | ||||
|         if instance is None: | ||||
| @@ -319,7 +319,7 @@ class ReverseSingleRelatedObjectDescriptor(object): | ||||
|             # cache. This cache also might not exist if the related object | ||||
|             # hasn't been accessed yet. | ||||
|             if related: | ||||
|                 cache_name = '_%s_cache' % self.field.related.get_accessor_name() | ||||
|                 cache_name = self.field.related.get_cache_name() | ||||
|                 try: | ||||
|                     delattr(related, cache_name) | ||||
|                 except AttributeError: | ||||
|   | ||||
| @@ -1116,6 +1116,29 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, | ||||
|     """ | ||||
|     Helper function that recursively returns an object with the specified | ||||
|     related attributes already populated. | ||||
|  | ||||
|     This method may be called recursively to populate deep select_related() | ||||
|     clauses. | ||||
|  | ||||
|     Arguments: | ||||
|      * klass - the class to retrieve (and instantiate) | ||||
|      * row - the row of data returned by the database cursor | ||||
|      * index_start - the index of the row at which data for this | ||||
|        object is known to start | ||||
|      * max_depth - the maximum depth to which a select_related() | ||||
|        relationship should be explored. | ||||
|      * cur_depth - the current depth in the select_related() tree. | ||||
|        Used in recursive calls to determin if we should dig deeper. | ||||
|      * requested - A dictionary describing the select_related() tree | ||||
|        that is to be retrieved. keys are field names; values are | ||||
|        dictionaries describing the keys on that related object that | ||||
|        are themselves to be select_related(). | ||||
|      * offset - the number of additional fields that are known to | ||||
|        exist in `row` for `klass`. This usually means the number of | ||||
|        annotated results on `klass`. | ||||
|      * only_load - if the query has had only() or defer() applied, | ||||
|        this is the list of field names that will be returned. If None, | ||||
|        the full field list for `klass` can be assumed. | ||||
|     """ | ||||
|     if max_depth and requested is None and cur_depth > max_depth: | ||||
|         # We've recursed deeply enough; stop now. | ||||
| @@ -1127,14 +1150,18 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, | ||||
|         # Handle deferred fields. | ||||
|         skip = set() | ||||
|         init_list = [] | ||||
|         pk_val = row[index_start + klass._meta.pk_index()] | ||||
|         # Build the list of fields that *haven't* been requested | ||||
|         for field in klass._meta.fields: | ||||
|             if field.name not in load_fields: | ||||
|                 skip.add(field.name) | ||||
|             else: | ||||
|                 init_list.append(field.attname) | ||||
|         # Retrieve all the requested fields | ||||
|         field_count = len(init_list) | ||||
|         fields = row[index_start : index_start + field_count] | ||||
|         # If all the select_related columns are None, then the related | ||||
|         # object must be non-existent - set the relation to None. | ||||
|         # Otherwise, construct the related object. | ||||
|         if fields == (None,) * field_count: | ||||
|             obj = None | ||||
|         elif skip: | ||||
| @@ -1143,14 +1170,20 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, | ||||
|         else: | ||||
|             obj = klass(*fields) | ||||
|     else: | ||||
|         # Load all fields on klass | ||||
|         field_count = len(klass._meta.fields) | ||||
|         fields = row[index_start : index_start + field_count] | ||||
|         # If all the select_related columns are None, then the related | ||||
|         # object must be non-existent - set the relation to None. | ||||
|         # Otherwise, construct the related object. | ||||
|         if fields == (None,) * field_count: | ||||
|             obj = None | ||||
|         else: | ||||
|             obj = klass(*fields) | ||||
|  | ||||
|     index_end = index_start + field_count + offset | ||||
|     # Iterate over each related object, populating any | ||||
|     # select_related() fields | ||||
|     for f in klass._meta.fields: | ||||
|         if not select_related_descend(f, restricted, requested): | ||||
|             continue | ||||
| @@ -1158,12 +1191,51 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, | ||||
|             next = requested[f.name] | ||||
|         else: | ||||
|             next = None | ||||
|         # Recursively retrieve the data for the related object | ||||
|         cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, | ||||
|                 cur_depth+1, next) | ||||
|         # If the recursive descent found an object, populate the | ||||
|         # descriptor caches relevant to the object | ||||
|         if cached_row: | ||||
|             rel_obj, index_end = cached_row | ||||
|             if obj is not None: | ||||
|                 # If the base object exists, populate the | ||||
|                 # descriptor cache | ||||
|                 setattr(obj, f.get_cache_name(), rel_obj) | ||||
|             if f.unique: | ||||
|                 # If the field is unique, populate the | ||||
|                 # reverse descriptor cache on the related object | ||||
|                 setattr(rel_obj, f.related.get_cache_name(), obj) | ||||
|  | ||||
|     # Now do the same, but for reverse related objects. | ||||
|     # Only handle the restricted case - i.e., don't do a depth | ||||
|     # descent into reverse relations unless explicitly requested | ||||
|     if restricted: | ||||
|         related_fields = [ | ||||
|             (o.field, o.model) | ||||
|             for o in klass._meta.get_all_related_objects() | ||||
|             if o.field.unique | ||||
|         ] | ||||
|         for f, model in related_fields: | ||||
|             if not select_related_descend(f, restricted, requested, reverse=True): | ||||
|                 continue | ||||
|             next = requested[f.related_query_name()] | ||||
|             # Recursively retrieve the data for the related object | ||||
|             cached_row = get_cached_row(model, row, index_end, max_depth, | ||||
|                 cur_depth+1, next) | ||||
|             # If the recursive descent found an object, populate the | ||||
|             # descriptor caches relevant to the object | ||||
|             if cached_row: | ||||
|                 rel_obj, index_end = cached_row | ||||
|                 if obj is not None: | ||||
|                     # If the field is unique, populate the | ||||
|                     # reverse descriptor cache | ||||
|                     setattr(obj, f.related.get_cache_name(), rel_obj) | ||||
|                 if rel_obj is not None: | ||||
|                     # If the related object exists, populate | ||||
|                     # the descriptor cache. | ||||
|                     setattr(rel_obj, f.get_cache_name(), obj) | ||||
|  | ||||
|     return obj, index_end | ||||
|  | ||||
| def delete_objects(seen_objs, using): | ||||
|   | ||||
| @@ -197,18 +197,28 @@ class DeferredAttribute(object): | ||||
|         """ | ||||
|         instance.__dict__[self.field_name] = value | ||||
|  | ||||
| def select_related_descend(field, restricted, requested): | ||||
| def select_related_descend(field, restricted, requested, reverse=False): | ||||
|     """ | ||||
|     Returns True if this field should be used to descend deeper for | ||||
|     select_related() purposes. Used by both the query construction code | ||||
|     (sql.query.fill_related_selections()) and the model instance creation code | ||||
|     (query.get_cached_row()). | ||||
|  | ||||
|     Arguments: | ||||
|      * field - the field to be checked | ||||
|      * restricted - a boolean field, indicating if the field list has been | ||||
|        manually restricted using a requested clause) | ||||
|      * requested - The select_related() dictionary. | ||||
|      * reverse - boolean, True if we are checking a reverse select related | ||||
|     """ | ||||
|     if not field.rel: | ||||
|         return False | ||||
|     if field.rel.parent_link: | ||||
|     if field.rel.parent_link and not reverse: | ||||
|         return False | ||||
|     if restricted and field.name not in requested: | ||||
|     if restricted: | ||||
|         if reverse and field.related_query_name() not in requested: | ||||
|             return False | ||||
|         if not reverse and field.name not in requested: | ||||
|             return False | ||||
|     if not restricted and field.null: | ||||
|         return False | ||||
|   | ||||
| @@ -45,3 +45,6 @@ class RelatedObject(object): | ||||
|             return self.field.rel.related_name or (self.opts.object_name.lower() + '_set') | ||||
|         else: | ||||
|             return self.field.rel.related_name or (self.opts.object_name.lower()) | ||||
|  | ||||
|     def get_cache_name(self): | ||||
|         return "_%s_cache" % self.get_accessor_name() | ||||
|   | ||||
| @@ -520,7 +520,7 @@ class SQLCompiler(object): | ||||
|  | ||||
|         # Setup for the case when only particular related fields should be | ||||
|         # included in the related selection. | ||||
|         if requested is None and restricted is not False: | ||||
|         if requested is None: | ||||
|             if isinstance(self.query.select_related, dict): | ||||
|                 requested = self.query.select_related | ||||
|                 restricted = True | ||||
| @@ -600,6 +600,72 @@ class SQLCompiler(object): | ||||
|             self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, | ||||
|                     used, next, restricted, new_nullable, dupe_set, avoid) | ||||
|  | ||||
|         if restricted: | ||||
|             related_fields = [ | ||||
|                 (o.field, o.model) | ||||
|                 for o in opts.get_all_related_objects() | ||||
|                 if o.field.unique | ||||
|             ] | ||||
|             for f, model in related_fields: | ||||
|                 if not select_related_descend(f, restricted, requested, reverse=True): | ||||
|                     continue | ||||
|                 # The "avoid" set is aliases we want to avoid just for this | ||||
|                 # particular branch of the recursion. They aren't permanently | ||||
|                 # forbidden from reuse in the related selection tables (which is | ||||
|                 # what "used" specifies). | ||||
|                 avoid = avoid_set.copy() | ||||
|                 dupe_set = orig_dupe_set.copy() | ||||
|                 table = model._meta.db_table | ||||
|  | ||||
|                 int_opts = opts | ||||
|                 alias = root_alias | ||||
|                 alias_chain = [] | ||||
|                 chain = opts.get_base_chain(f.rel.to) | ||||
|                 if chain is not None: | ||||
|                     for int_model in chain: | ||||
|                         # Proxy model have elements in base chain | ||||
|                         # with no parents, assign the new options | ||||
|                         # object and skip to the next base in that | ||||
|                         # case | ||||
|                         if not int_opts.parents[int_model]: | ||||
|                             int_opts = int_model._meta | ||||
|                             continue | ||||
|                         lhs_col = int_opts.parents[int_model].column | ||||
|                         dedupe = lhs_col in opts.duplicate_targets | ||||
|                         if dedupe: | ||||
|                             avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col), | ||||
|                                 ()) | ||||
|                             dupe_set.add((opts, lhs_col)) | ||||
|                         int_opts = int_model._meta | ||||
|                         alias = self.query.join( | ||||
|                             (alias, int_opts.db_table, lhs_col, int_opts.pk.column), | ||||
|                             exclusions=used, promote=True, reuse=used | ||||
|                         ) | ||||
|                         alias_chain.append(alias) | ||||
|                         for dupe_opts, dupe_col in dupe_set: | ||||
|                             self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias) | ||||
|                     dedupe = f.column in opts.duplicate_targets | ||||
|                     if dupe_set or dedupe: | ||||
|                         avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ())) | ||||
|                         if dedupe: | ||||
|                             dupe_set.add((opts, f.column)) | ||||
|                 alias = self.query.join( | ||||
|                     (alias, table, f.rel.get_related_field().column, f.column), | ||||
|                     exclusions=used.union(avoid), | ||||
|                     promote=True | ||||
|                 ) | ||||
|                 used.add(alias) | ||||
|                 columns, aliases = self.get_default_columns(start_alias=alias, | ||||
|                     opts=model._meta, as_pairs=True) | ||||
|                 self.query.related_select_cols.extend(columns) | ||||
|                 self.query.related_select_fields.extend(model._meta.fields) | ||||
|  | ||||
|                 next = requested.get(f.related_query_name(), {}) | ||||
|                 new_nullable = f.null or None | ||||
|  | ||||
|                 self.fill_related_selections(model._meta, table, cur_depth+1, | ||||
|                     used, next, restricted, new_nullable) | ||||
|  | ||||
|     def deferred_to_columns(self): | ||||
|         """ | ||||
|         Converts the self.deferred_loading data structure to mapping of table | ||||
|   | ||||
| @@ -619,17 +619,29 @@ This is also valid:: | ||||
|  | ||||
| ...and would also pull in the ``building`` relation. | ||||
|  | ||||
| You can only refer to ``ForeignKey`` relations in the list of fields passed to | ||||
| ``select_related``. You *can* refer to foreign keys that have ``null=True`` | ||||
| (unlike the default ``select_related()`` call). It's an error to use both a | ||||
| list of fields and the ``depth`` parameter in the same ``select_related()`` | ||||
| call, since they are conflicting options. | ||||
| You can refer to any ``ForeignKey`` or ``OneToOneField`` relation in | ||||
| the list of fields passed to ``select_related``. Ths includes foreign | ||||
| keys that have ``null=True`` (unlike the default ``select_related()`` | ||||
| call). It's an error to use both a list of fields and the ``depth`` | ||||
| parameter in the same ``select_related()`` call, since they are | ||||
| conflicting options. | ||||
|  | ||||
| .. versionadded:: 1.0 | ||||
|  | ||||
| Both the ``depth`` argument and the ability to specify field names in the call | ||||
| to ``select_related()`` are new in Django version 1.0. | ||||
|  | ||||
| .. versionchanged:: 1.2 | ||||
|  | ||||
| You can also refer to the reverse direction of a ``OneToOneFields`` in | ||||
| the list of fields passed to ``select_related`` -- that is, you can traverse | ||||
| a ``OneToOneField`` back to the object on which the field is defined. Instead | ||||
| of specifying the field name, use the ``related_name`` for the field on the | ||||
| related object. | ||||
|  | ||||
| ``OneToOneFields`` will not be traversed in the reverse direction if you | ||||
| are performing a depth-based ``select_related``. | ||||
|  | ||||
| .. _queryset-extra: | ||||
|  | ||||
| ``extra(select=None, where=None, params=None, tables=None, order_by=None, select_params=None)`` | ||||
|   | ||||
							
								
								
									
										46
									
								
								tests/regressiontests/select_related_onetoone/models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								tests/regressiontests/select_related_onetoone/models.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | ||||
| from django.db import models | ||||
|  | ||||
|  | ||||
| class User(models.Model): | ||||
|     username = models.CharField(max_length=100) | ||||
|     email = models.EmailField() | ||||
|  | ||||
|     def __unicode__(self): | ||||
|         return self.username | ||||
|  | ||||
|  | ||||
| class UserProfile(models.Model): | ||||
|     user = models.OneToOneField(User) | ||||
|     city = models.CharField(max_length=100) | ||||
|     state = models.CharField(max_length=2) | ||||
|  | ||||
|     def __unicode__(self): | ||||
|         return "%s, %s" % (self.city, self.state) | ||||
|  | ||||
|  | ||||
| class UserStatResult(models.Model): | ||||
|     results = models.CharField(max_length=50) | ||||
|  | ||||
|     def __unicode__(self): | ||||
|         return 'UserStatResults, results = %s' % (self.results,) | ||||
|  | ||||
|  | ||||
| class UserStat(models.Model): | ||||
|     user = models.OneToOneField(User, primary_key=True) | ||||
|     posts = models.IntegerField() | ||||
|     results = models.ForeignKey(UserStatResult) | ||||
|  | ||||
|     def __unicode__(self): | ||||
|         return 'UserStat, posts = %s' % (self.posts,) | ||||
|  | ||||
|  | ||||
| class StatDetails(models.Model): | ||||
|     base_stats = models.OneToOneField(UserStat) | ||||
|     comments = models.IntegerField() | ||||
|  | ||||
|     def __unicode__(self): | ||||
|         return 'StatDetails, comments = %s' % (self.comments,) | ||||
|  | ||||
|  | ||||
| class AdvancedUserStat(UserStat): | ||||
|     pass | ||||
							
								
								
									
										83
									
								
								tests/regressiontests/select_related_onetoone/tests.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								tests/regressiontests/select_related_onetoone/tests.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,83 @@ | ||||
| from django import db | ||||
| from django.conf import settings | ||||
| from django.test import TestCase | ||||
|  | ||||
| from models import User, UserProfile, UserStat, UserStatResult, StatDetails, AdvancedUserStat | ||||
|  | ||||
| class ReverseSelectRelatedTestCase(TestCase): | ||||
|     def setUp(self): | ||||
|         # Explicitly enable debug for these tests - we need to count | ||||
|         # the queries that have been issued. | ||||
|         self.old_debug = settings.DEBUG | ||||
|         settings.DEBUG = True | ||||
|  | ||||
|         user = User.objects.create(username="test") | ||||
|         userprofile = UserProfile.objects.create(user=user, state="KS", | ||||
|                                                  city="Lawrence") | ||||
|         results = UserStatResult.objects.create(results='first results') | ||||
|         userstat = UserStat.objects.create(user=user, posts=150, | ||||
|                                            results=results) | ||||
|         details = StatDetails.objects.create(base_stats=userstat, comments=259) | ||||
|  | ||||
|         user2 = User.objects.create(username="bob") | ||||
|         results2 = UserStatResult.objects.create(results='moar results') | ||||
|         advstat = AdvancedUserStat.objects.create(user=user2, posts=200, | ||||
|                                                   results=results2) | ||||
|         StatDetails.objects.create(base_stats=advstat, comments=250) | ||||
|  | ||||
|         db.reset_queries() | ||||
|  | ||||
|     def assertQueries(self, queries): | ||||
|         self.assertEqual(len(db.connection.queries), queries) | ||||
|  | ||||
|     def tearDown(self): | ||||
|         settings.DEBUG = self.old_debug | ||||
|  | ||||
|     def test_basic(self): | ||||
|         u = User.objects.select_related("userprofile").get(username="test") | ||||
|         self.assertEqual(u.userprofile.state, "KS") | ||||
|         self.assertQueries(1) | ||||
|  | ||||
|     def test_follow_next_level(self): | ||||
|         u = User.objects.select_related("userstat__results").get(username="test") | ||||
|         self.assertEqual(u.userstat.posts, 150) | ||||
|         self.assertEqual(u.userstat.results.results, 'first results') | ||||
|         self.assertQueries(1) | ||||
|  | ||||
|     def test_follow_two(self): | ||||
|         u = User.objects.select_related("userprofile", "userstat").get(username="test") | ||||
|         self.assertEqual(u.userprofile.state, "KS") | ||||
|         self.assertEqual(u.userstat.posts, 150) | ||||
|         self.assertQueries(1) | ||||
|  | ||||
|     def test_follow_two_next_level(self): | ||||
|         u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") | ||||
|         self.assertEqual(u.userstat.results.results, 'first results') | ||||
|         self.assertEqual(u.userstat.statdetails.comments, 259) | ||||
|         self.assertQueries(1) | ||||
|  | ||||
|     def test_forward_and_back(self): | ||||
|         stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") | ||||
|         self.assertEqual(stat.user.userprofile.state, 'KS') | ||||
|         self.assertEqual(stat.user.userstat.posts, 150) | ||||
|         self.assertQueries(1) | ||||
|  | ||||
|     def test_back_and_forward(self): | ||||
|         u = User.objects.select_related("userstat").get(username="test") | ||||
|         self.assertEqual(u.userstat.user.username, 'test') | ||||
|         self.assertQueries(1) | ||||
|  | ||||
|     def test_not_followed_by_default(self): | ||||
|         u = User.objects.select_related().get(username="test") | ||||
|         self.assertEqual(u.userstat.posts, 150) | ||||
|         self.assertQueries(2) | ||||
|  | ||||
|     def test_follow_from_child_class(self): | ||||
|         stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200) | ||||
|         self.assertEqual(stat.statdetails.comments, 250) | ||||
|         self.assertQueries(1) | ||||
|  | ||||
|     def test_follow_inheritance(self): | ||||
|         stat = UserStat.objects.select_related('advanceduserstat').get(posts=200) | ||||
|         self.assertEqual(stat.advanceduserstat.posts, 200) | ||||
|         self.assertQueries(1) | ||||
		Reference in New Issue
	
	Block a user