mirror of
				https://github.com/django/django.git
				synced 2025-10-30 17:16:10 +00:00 
			
		
		
		
	Fixed #7270 -- Added the ability to follow reverse OneToOneFields in select_related(). Thanks to George Vilches, Ben Davis, and Alex Gaynor for their work on various stages of this patch.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@12307 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
		| @@ -189,7 +189,7 @@ class SingleRelatedObjectDescriptor(object): | |||||||
|     # SingleRelatedObjectDescriptor instance. |     # SingleRelatedObjectDescriptor instance. | ||||||
|     def __init__(self, related): |     def __init__(self, related): | ||||||
|         self.related = related |         self.related = related | ||||||
|         self.cache_name = '_%s_cache' % related.get_accessor_name() |         self.cache_name = related.get_cache_name() | ||||||
|  |  | ||||||
|     def __get__(self, instance, instance_type=None): |     def __get__(self, instance, instance_type=None): | ||||||
|         if instance is None: |         if instance is None: | ||||||
| @@ -319,7 +319,7 @@ class ReverseSingleRelatedObjectDescriptor(object): | |||||||
|             # cache. This cache also might not exist if the related object |             # cache. This cache also might not exist if the related object | ||||||
|             # hasn't been accessed yet. |             # hasn't been accessed yet. | ||||||
|             if related: |             if related: | ||||||
|                 cache_name = '_%s_cache' % self.field.related.get_accessor_name() |                 cache_name = self.field.related.get_cache_name() | ||||||
|                 try: |                 try: | ||||||
|                     delattr(related, cache_name) |                     delattr(related, cache_name) | ||||||
|                 except AttributeError: |                 except AttributeError: | ||||||
|   | |||||||
| @@ -1116,6 +1116,29 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, | |||||||
|     """ |     """ | ||||||
|     Helper function that recursively returns an object with the specified |     Helper function that recursively returns an object with the specified | ||||||
|     related attributes already populated. |     related attributes already populated. | ||||||
|  |  | ||||||
|  |     This method may be called recursively to populate deep select_related() | ||||||
|  |     clauses. | ||||||
|  |  | ||||||
|  |     Arguments: | ||||||
|  |      * klass - the class to retrieve (and instantiate) | ||||||
|  |      * row - the row of data returned by the database cursor | ||||||
|  |      * index_start - the index of the row at which data for this | ||||||
|  |        object is known to start | ||||||
|  |      * max_depth - the maximum depth to which a select_related() | ||||||
|  |        relationship should be explored. | ||||||
|  |      * cur_depth - the current depth in the select_related() tree. | ||||||
|  |        Used in recursive calls to determin if we should dig deeper. | ||||||
|  |      * requested - A dictionary describing the select_related() tree | ||||||
|  |        that is to be retrieved. keys are field names; values are | ||||||
|  |        dictionaries describing the keys on that related object that | ||||||
|  |        are themselves to be select_related(). | ||||||
|  |      * offset - the number of additional fields that are known to | ||||||
|  |        exist in `row` for `klass`. This usually means the number of | ||||||
|  |        annotated results on `klass`. | ||||||
|  |      * only_load - if the query has had only() or defer() applied, | ||||||
|  |        this is the list of field names that will be returned. If None, | ||||||
|  |        the full field list for `klass` can be assumed. | ||||||
|     """ |     """ | ||||||
|     if max_depth and requested is None and cur_depth > max_depth: |     if max_depth and requested is None and cur_depth > max_depth: | ||||||
|         # We've recursed deeply enough; stop now. |         # We've recursed deeply enough; stop now. | ||||||
| @@ -1127,14 +1150,18 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, | |||||||
|         # Handle deferred fields. |         # Handle deferred fields. | ||||||
|         skip = set() |         skip = set() | ||||||
|         init_list = [] |         init_list = [] | ||||||
|         pk_val = row[index_start + klass._meta.pk_index()] |         # Build the list of fields that *haven't* been requested | ||||||
|         for field in klass._meta.fields: |         for field in klass._meta.fields: | ||||||
|             if field.name not in load_fields: |             if field.name not in load_fields: | ||||||
|                 skip.add(field.name) |                 skip.add(field.name) | ||||||
|             else: |             else: | ||||||
|                 init_list.append(field.attname) |                 init_list.append(field.attname) | ||||||
|  |         # Retrieve all the requested fields | ||||||
|         field_count = len(init_list) |         field_count = len(init_list) | ||||||
|         fields = row[index_start : index_start + field_count] |         fields = row[index_start : index_start + field_count] | ||||||
|  |         # If all the select_related columns are None, then the related | ||||||
|  |         # object must be non-existent - set the relation to None. | ||||||
|  |         # Otherwise, construct the related object. | ||||||
|         if fields == (None,) * field_count: |         if fields == (None,) * field_count: | ||||||
|             obj = None |             obj = None | ||||||
|         elif skip: |         elif skip: | ||||||
| @@ -1143,14 +1170,20 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, | |||||||
|         else: |         else: | ||||||
|             obj = klass(*fields) |             obj = klass(*fields) | ||||||
|     else: |     else: | ||||||
|  |         # Load all fields on klass | ||||||
|         field_count = len(klass._meta.fields) |         field_count = len(klass._meta.fields) | ||||||
|         fields = row[index_start : index_start + field_count] |         fields = row[index_start : index_start + field_count] | ||||||
|  |         # If all the select_related columns are None, then the related | ||||||
|  |         # object must be non-existent - set the relation to None. | ||||||
|  |         # Otherwise, construct the related object. | ||||||
|         if fields == (None,) * field_count: |         if fields == (None,) * field_count: | ||||||
|             obj = None |             obj = None | ||||||
|         else: |         else: | ||||||
|             obj = klass(*fields) |             obj = klass(*fields) | ||||||
|  |  | ||||||
|     index_end = index_start + field_count + offset |     index_end = index_start + field_count + offset | ||||||
|  |     # Iterate over each related object, populating any | ||||||
|  |     # select_related() fields | ||||||
|     for f in klass._meta.fields: |     for f in klass._meta.fields: | ||||||
|         if not select_related_descend(f, restricted, requested): |         if not select_related_descend(f, restricted, requested): | ||||||
|             continue |             continue | ||||||
| @@ -1158,12 +1191,51 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, | |||||||
|             next = requested[f.name] |             next = requested[f.name] | ||||||
|         else: |         else: | ||||||
|             next = None |             next = None | ||||||
|  |         # Recursively retrieve the data for the related object | ||||||
|         cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, |         cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, | ||||||
|                 cur_depth+1, next) |                 cur_depth+1, next) | ||||||
|  |         # If the recursive descent found an object, populate the | ||||||
|  |         # descriptor caches relevant to the object | ||||||
|         if cached_row: |         if cached_row: | ||||||
|             rel_obj, index_end = cached_row |             rel_obj, index_end = cached_row | ||||||
|             if obj is not None: |             if obj is not None: | ||||||
|  |                 # If the base object exists, populate the | ||||||
|  |                 # descriptor cache | ||||||
|                 setattr(obj, f.get_cache_name(), rel_obj) |                 setattr(obj, f.get_cache_name(), rel_obj) | ||||||
|  |             if f.unique: | ||||||
|  |                 # If the field is unique, populate the | ||||||
|  |                 # reverse descriptor cache on the related object | ||||||
|  |                 setattr(rel_obj, f.related.get_cache_name(), obj) | ||||||
|  |  | ||||||
|  |     # Now do the same, but for reverse related objects. | ||||||
|  |     # Only handle the restricted case - i.e., don't do a depth | ||||||
|  |     # descent into reverse relations unless explicitly requested | ||||||
|  |     if restricted: | ||||||
|  |         related_fields = [ | ||||||
|  |             (o.field, o.model) | ||||||
|  |             for o in klass._meta.get_all_related_objects() | ||||||
|  |             if o.field.unique | ||||||
|  |         ] | ||||||
|  |         for f, model in related_fields: | ||||||
|  |             if not select_related_descend(f, restricted, requested, reverse=True): | ||||||
|  |                 continue | ||||||
|  |             next = requested[f.related_query_name()] | ||||||
|  |             # Recursively retrieve the data for the related object | ||||||
|  |             cached_row = get_cached_row(model, row, index_end, max_depth, | ||||||
|  |                 cur_depth+1, next) | ||||||
|  |             # If the recursive descent found an object, populate the | ||||||
|  |             # descriptor caches relevant to the object | ||||||
|  |             if cached_row: | ||||||
|  |                 rel_obj, index_end = cached_row | ||||||
|  |                 if obj is not None: | ||||||
|  |                     # If the field is unique, populate the | ||||||
|  |                     # reverse descriptor cache | ||||||
|  |                     setattr(obj, f.related.get_cache_name(), rel_obj) | ||||||
|  |                 if rel_obj is not None: | ||||||
|  |                     # If the related object exists, populate | ||||||
|  |                     # the descriptor cache. | ||||||
|  |                     setattr(rel_obj, f.get_cache_name(), obj) | ||||||
|  |  | ||||||
|     return obj, index_end |     return obj, index_end | ||||||
|  |  | ||||||
| def delete_objects(seen_objs, using): | def delete_objects(seen_objs, using): | ||||||
|   | |||||||
| @@ -197,19 +197,29 @@ class DeferredAttribute(object): | |||||||
|         """ |         """ | ||||||
|         instance.__dict__[self.field_name] = value |         instance.__dict__[self.field_name] = value | ||||||
|  |  | ||||||
| def select_related_descend(field, restricted, requested): | def select_related_descend(field, restricted, requested, reverse=False): | ||||||
|     """ |     """ | ||||||
|     Returns True if this field should be used to descend deeper for |     Returns True if this field should be used to descend deeper for | ||||||
|     select_related() purposes. Used by both the query construction code |     select_related() purposes. Used by both the query construction code | ||||||
|     (sql.query.fill_related_selections()) and the model instance creation code |     (sql.query.fill_related_selections()) and the model instance creation code | ||||||
|     (query.get_cached_row()). |     (query.get_cached_row()). | ||||||
|  |  | ||||||
|  |     Arguments: | ||||||
|  |      * field - the field to be checked | ||||||
|  |      * restricted - a boolean field, indicating if the field list has been | ||||||
|  |        manually restricted using a requested clause) | ||||||
|  |      * requested - The select_related() dictionary. | ||||||
|  |      * reverse - boolean, True if we are checking a reverse select related | ||||||
|     """ |     """ | ||||||
|     if not field.rel: |     if not field.rel: | ||||||
|         return False |         return False | ||||||
|     if field.rel.parent_link: |     if field.rel.parent_link and not reverse: | ||||||
|         return False |  | ||||||
|     if restricted and field.name not in requested: |  | ||||||
|         return False |         return False | ||||||
|  |     if restricted: | ||||||
|  |         if reverse and field.related_query_name() not in requested: | ||||||
|  |             return False | ||||||
|  |         if not reverse and field.name not in requested: | ||||||
|  |             return False | ||||||
|     if not restricted and field.null: |     if not restricted and field.null: | ||||||
|         return False |         return False | ||||||
|     return True |     return True | ||||||
|   | |||||||
| @@ -45,3 +45,6 @@ class RelatedObject(object): | |||||||
|             return self.field.rel.related_name or (self.opts.object_name.lower() + '_set') |             return self.field.rel.related_name or (self.opts.object_name.lower() + '_set') | ||||||
|         else: |         else: | ||||||
|             return self.field.rel.related_name or (self.opts.object_name.lower()) |             return self.field.rel.related_name or (self.opts.object_name.lower()) | ||||||
|  |  | ||||||
|  |     def get_cache_name(self): | ||||||
|  |         return "_%s_cache" % self.get_accessor_name() | ||||||
|   | |||||||
| @@ -520,7 +520,7 @@ class SQLCompiler(object): | |||||||
|  |  | ||||||
|         # Setup for the case when only particular related fields should be |         # Setup for the case when only particular related fields should be | ||||||
|         # included in the related selection. |         # included in the related selection. | ||||||
|         if requested is None and restricted is not False: |         if requested is None: | ||||||
|             if isinstance(self.query.select_related, dict): |             if isinstance(self.query.select_related, dict): | ||||||
|                 requested = self.query.select_related |                 requested = self.query.select_related | ||||||
|                 restricted = True |                 restricted = True | ||||||
| @@ -600,6 +600,72 @@ class SQLCompiler(object): | |||||||
|             self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, |             self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, | ||||||
|                     used, next, restricted, new_nullable, dupe_set, avoid) |                     used, next, restricted, new_nullable, dupe_set, avoid) | ||||||
|  |  | ||||||
|  |         if restricted: | ||||||
|  |             related_fields = [ | ||||||
|  |                 (o.field, o.model) | ||||||
|  |                 for o in opts.get_all_related_objects() | ||||||
|  |                 if o.field.unique | ||||||
|  |             ] | ||||||
|  |             for f, model in related_fields: | ||||||
|  |                 if not select_related_descend(f, restricted, requested, reverse=True): | ||||||
|  |                     continue | ||||||
|  |                 # The "avoid" set is aliases we want to avoid just for this | ||||||
|  |                 # particular branch of the recursion. They aren't permanently | ||||||
|  |                 # forbidden from reuse in the related selection tables (which is | ||||||
|  |                 # what "used" specifies). | ||||||
|  |                 avoid = avoid_set.copy() | ||||||
|  |                 dupe_set = orig_dupe_set.copy() | ||||||
|  |                 table = model._meta.db_table | ||||||
|  |  | ||||||
|  |                 int_opts = opts | ||||||
|  |                 alias = root_alias | ||||||
|  |                 alias_chain = [] | ||||||
|  |                 chain = opts.get_base_chain(f.rel.to) | ||||||
|  |                 if chain is not None: | ||||||
|  |                     for int_model in chain: | ||||||
|  |                         # Proxy model have elements in base chain | ||||||
|  |                         # with no parents, assign the new options | ||||||
|  |                         # object and skip to the next base in that | ||||||
|  |                         # case | ||||||
|  |                         if not int_opts.parents[int_model]: | ||||||
|  |                             int_opts = int_model._meta | ||||||
|  |                             continue | ||||||
|  |                         lhs_col = int_opts.parents[int_model].column | ||||||
|  |                         dedupe = lhs_col in opts.duplicate_targets | ||||||
|  |                         if dedupe: | ||||||
|  |                             avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col), | ||||||
|  |                                 ()) | ||||||
|  |                             dupe_set.add((opts, lhs_col)) | ||||||
|  |                         int_opts = int_model._meta | ||||||
|  |                         alias = self.query.join( | ||||||
|  |                             (alias, int_opts.db_table, lhs_col, int_opts.pk.column), | ||||||
|  |                             exclusions=used, promote=True, reuse=used | ||||||
|  |                         ) | ||||||
|  |                         alias_chain.append(alias) | ||||||
|  |                         for dupe_opts, dupe_col in dupe_set: | ||||||
|  |                             self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias) | ||||||
|  |                     dedupe = f.column in opts.duplicate_targets | ||||||
|  |                     if dupe_set or dedupe: | ||||||
|  |                         avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ())) | ||||||
|  |                         if dedupe: | ||||||
|  |                             dupe_set.add((opts, f.column)) | ||||||
|  |                 alias = self.query.join( | ||||||
|  |                     (alias, table, f.rel.get_related_field().column, f.column), | ||||||
|  |                     exclusions=used.union(avoid), | ||||||
|  |                     promote=True | ||||||
|  |                 ) | ||||||
|  |                 used.add(alias) | ||||||
|  |                 columns, aliases = self.get_default_columns(start_alias=alias, | ||||||
|  |                     opts=model._meta, as_pairs=True) | ||||||
|  |                 self.query.related_select_cols.extend(columns) | ||||||
|  |                 self.query.related_select_fields.extend(model._meta.fields) | ||||||
|  |  | ||||||
|  |                 next = requested.get(f.related_query_name(), {}) | ||||||
|  |                 new_nullable = f.null or None | ||||||
|  |  | ||||||
|  |                 self.fill_related_selections(model._meta, table, cur_depth+1, | ||||||
|  |                     used, next, restricted, new_nullable) | ||||||
|  |  | ||||||
|     def deferred_to_columns(self): |     def deferred_to_columns(self): | ||||||
|         """ |         """ | ||||||
|         Converts the self.deferred_loading data structure to mapping of table |         Converts the self.deferred_loading data structure to mapping of table | ||||||
|   | |||||||
| @@ -619,17 +619,29 @@ This is also valid:: | |||||||
|  |  | ||||||
| ...and would also pull in the ``building`` relation. | ...and would also pull in the ``building`` relation. | ||||||
|  |  | ||||||
| You can only refer to ``ForeignKey`` relations in the list of fields passed to | You can refer to any ``ForeignKey`` or ``OneToOneField`` relation in | ||||||
| ``select_related``. You *can* refer to foreign keys that have ``null=True`` | the list of fields passed to ``select_related``. Ths includes foreign | ||||||
| (unlike the default ``select_related()`` call). It's an error to use both a | keys that have ``null=True`` (unlike the default ``select_related()`` | ||||||
| list of fields and the ``depth`` parameter in the same ``select_related()`` | call). It's an error to use both a list of fields and the ``depth`` | ||||||
| call, since they are conflicting options. | parameter in the same ``select_related()`` call, since they are | ||||||
|  | conflicting options. | ||||||
|  |  | ||||||
| .. versionadded:: 1.0 | .. versionadded:: 1.0 | ||||||
|  |  | ||||||
| Both the ``depth`` argument and the ability to specify field names in the call | Both the ``depth`` argument and the ability to specify field names in the call | ||||||
| to ``select_related()`` are new in Django version 1.0. | to ``select_related()`` are new in Django version 1.0. | ||||||
|  |  | ||||||
|  | .. versionchanged:: 1.2 | ||||||
|  |  | ||||||
|  | You can also refer to the reverse direction of a ``OneToOneFields`` in | ||||||
|  | the list of fields passed to ``select_related`` -- that is, you can traverse | ||||||
|  | a ``OneToOneField`` back to the object on which the field is defined. Instead | ||||||
|  | of specifying the field name, use the ``related_name`` for the field on the | ||||||
|  | related object. | ||||||
|  |  | ||||||
|  | ``OneToOneFields`` will not be traversed in the reverse direction if you | ||||||
|  | are performing a depth-based ``select_related``. | ||||||
|  |  | ||||||
| .. _queryset-extra: | .. _queryset-extra: | ||||||
|  |  | ||||||
| ``extra(select=None, where=None, params=None, tables=None, order_by=None, select_params=None)`` | ``extra(select=None, where=None, params=None, tables=None, order_by=None, select_params=None)`` | ||||||
|   | |||||||
							
								
								
									
										46
									
								
								tests/regressiontests/select_related_onetoone/models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								tests/regressiontests/select_related_onetoone/models.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | |||||||
|  | from django.db import models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class User(models.Model): | ||||||
|  |     username = models.CharField(max_length=100) | ||||||
|  |     email = models.EmailField() | ||||||
|  |  | ||||||
|  |     def __unicode__(self): | ||||||
|  |         return self.username | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UserProfile(models.Model): | ||||||
|  |     user = models.OneToOneField(User) | ||||||
|  |     city = models.CharField(max_length=100) | ||||||
|  |     state = models.CharField(max_length=2) | ||||||
|  |  | ||||||
|  |     def __unicode__(self): | ||||||
|  |         return "%s, %s" % (self.city, self.state) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UserStatResult(models.Model): | ||||||
|  |     results = models.CharField(max_length=50) | ||||||
|  |  | ||||||
|  |     def __unicode__(self): | ||||||
|  |         return 'UserStatResults, results = %s' % (self.results,) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UserStat(models.Model): | ||||||
|  |     user = models.OneToOneField(User, primary_key=True) | ||||||
|  |     posts = models.IntegerField() | ||||||
|  |     results = models.ForeignKey(UserStatResult) | ||||||
|  |  | ||||||
|  |     def __unicode__(self): | ||||||
|  |         return 'UserStat, posts = %s' % (self.posts,) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class StatDetails(models.Model): | ||||||
|  |     base_stats = models.OneToOneField(UserStat) | ||||||
|  |     comments = models.IntegerField() | ||||||
|  |  | ||||||
|  |     def __unicode__(self): | ||||||
|  |         return 'StatDetails, comments = %s' % (self.comments,) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AdvancedUserStat(UserStat): | ||||||
|  |     pass | ||||||
							
								
								
									
										83
									
								
								tests/regressiontests/select_related_onetoone/tests.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								tests/regressiontests/select_related_onetoone/tests.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,83 @@ | |||||||
|  | from django import db | ||||||
|  | from django.conf import settings | ||||||
|  | from django.test import TestCase | ||||||
|  |  | ||||||
|  | from models import User, UserProfile, UserStat, UserStatResult, StatDetails, AdvancedUserStat | ||||||
|  |  | ||||||
|  | class ReverseSelectRelatedTestCase(TestCase): | ||||||
|  |     def setUp(self): | ||||||
|  |         # Explicitly enable debug for these tests - we need to count | ||||||
|  |         # the queries that have been issued. | ||||||
|  |         self.old_debug = settings.DEBUG | ||||||
|  |         settings.DEBUG = True | ||||||
|  |  | ||||||
|  |         user = User.objects.create(username="test") | ||||||
|  |         userprofile = UserProfile.objects.create(user=user, state="KS", | ||||||
|  |                                                  city="Lawrence") | ||||||
|  |         results = UserStatResult.objects.create(results='first results') | ||||||
|  |         userstat = UserStat.objects.create(user=user, posts=150, | ||||||
|  |                                            results=results) | ||||||
|  |         details = StatDetails.objects.create(base_stats=userstat, comments=259) | ||||||
|  |  | ||||||
|  |         user2 = User.objects.create(username="bob") | ||||||
|  |         results2 = UserStatResult.objects.create(results='moar results') | ||||||
|  |         advstat = AdvancedUserStat.objects.create(user=user2, posts=200, | ||||||
|  |                                                   results=results2) | ||||||
|  |         StatDetails.objects.create(base_stats=advstat, comments=250) | ||||||
|  |  | ||||||
|  |         db.reset_queries() | ||||||
|  |  | ||||||
|  |     def assertQueries(self, queries): | ||||||
|  |         self.assertEqual(len(db.connection.queries), queries) | ||||||
|  |  | ||||||
|  |     def tearDown(self): | ||||||
|  |         settings.DEBUG = self.old_debug | ||||||
|  |  | ||||||
|  |     def test_basic(self): | ||||||
|  |         u = User.objects.select_related("userprofile").get(username="test") | ||||||
|  |         self.assertEqual(u.userprofile.state, "KS") | ||||||
|  |         self.assertQueries(1) | ||||||
|  |  | ||||||
|  |     def test_follow_next_level(self): | ||||||
|  |         u = User.objects.select_related("userstat__results").get(username="test") | ||||||
|  |         self.assertEqual(u.userstat.posts, 150) | ||||||
|  |         self.assertEqual(u.userstat.results.results, 'first results') | ||||||
|  |         self.assertQueries(1) | ||||||
|  |  | ||||||
|  |     def test_follow_two(self): | ||||||
|  |         u = User.objects.select_related("userprofile", "userstat").get(username="test") | ||||||
|  |         self.assertEqual(u.userprofile.state, "KS") | ||||||
|  |         self.assertEqual(u.userstat.posts, 150) | ||||||
|  |         self.assertQueries(1) | ||||||
|  |  | ||||||
|  |     def test_follow_two_next_level(self): | ||||||
|  |         u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") | ||||||
|  |         self.assertEqual(u.userstat.results.results, 'first results') | ||||||
|  |         self.assertEqual(u.userstat.statdetails.comments, 259) | ||||||
|  |         self.assertQueries(1) | ||||||
|  |  | ||||||
|  |     def test_forward_and_back(self): | ||||||
|  |         stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") | ||||||
|  |         self.assertEqual(stat.user.userprofile.state, 'KS') | ||||||
|  |         self.assertEqual(stat.user.userstat.posts, 150) | ||||||
|  |         self.assertQueries(1) | ||||||
|  |  | ||||||
|  |     def test_back_and_forward(self): | ||||||
|  |         u = User.objects.select_related("userstat").get(username="test") | ||||||
|  |         self.assertEqual(u.userstat.user.username, 'test') | ||||||
|  |         self.assertQueries(1) | ||||||
|  |  | ||||||
|  |     def test_not_followed_by_default(self): | ||||||
|  |         u = User.objects.select_related().get(username="test") | ||||||
|  |         self.assertEqual(u.userstat.posts, 150) | ||||||
|  |         self.assertQueries(2) | ||||||
|  |  | ||||||
|  |     def test_follow_from_child_class(self): | ||||||
|  |         stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200) | ||||||
|  |         self.assertEqual(stat.statdetails.comments, 250) | ||||||
|  |         self.assertQueries(1) | ||||||
|  |  | ||||||
|  |     def test_follow_inheritance(self): | ||||||
|  |         stat = UserStat.objects.select_related('advanceduserstat').get(posts=200) | ||||||
|  |         self.assertEqual(stat.advanceduserstat.posts, 200) | ||||||
|  |         self.assertQueries(1) | ||||||
		Reference in New Issue
	
	Block a user