mirror of
				https://github.com/django/django.git
				synced 2025-10-31 09:41:08 +00:00 
			
		
		
		
	Fixed #17485 -- Made defer work with select_related
This commit tackles a couple of issues. First, in certain cases there were some mixups if field.attname or field.name should be deferred. Field.attname is now always used. Another issue tackled is a case where field is both deferred by .only(), and selected by select_related. This case is now an error. A lot of thanks to koniiiik (Michal Petrucha) for the patch, and to Andrei Antoukh for review.
This commit is contained in:
		| @@ -1296,7 +1296,7 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, | |||||||
|         # Build the list of fields that *haven't* been requested |         # Build the list of fields that *haven't* been requested | ||||||
|         for field, model in klass._meta.get_fields_with_model(): |         for field, model in klass._meta.get_fields_with_model(): | ||||||
|             if field.name not in load_fields: |             if field.name not in load_fields: | ||||||
|                 skip.add(field.name) |                 skip.add(field.attname) | ||||||
|             elif local_only and model is not None: |             elif local_only and model is not None: | ||||||
|                 continue |                 continue | ||||||
|             else: |             else: | ||||||
| @@ -1327,7 +1327,7 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, | |||||||
|  |  | ||||||
|     related_fields = [] |     related_fields = [] | ||||||
|     for f in klass._meta.fields: |     for f in klass._meta.fields: | ||||||
|         if select_related_descend(f, restricted, requested): |         if select_related_descend(f, restricted, requested, load_fields): | ||||||
|             if restricted: |             if restricted: | ||||||
|                 next = requested[f.name] |                 next = requested[f.name] | ||||||
|             else: |             else: | ||||||
| @@ -1339,7 +1339,8 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, | |||||||
|     reverse_related_fields = [] |     reverse_related_fields = [] | ||||||
|     if restricted: |     if restricted: | ||||||
|         for o in klass._meta.get_all_related_objects(): |         for o in klass._meta.get_all_related_objects(): | ||||||
|             if o.field.unique and select_related_descend(o.field, restricted, requested, reverse=True): |             if o.field.unique and select_related_descend(o.field, restricted, requested, | ||||||
|  |                                                          only_load.get(o.model), reverse=True): | ||||||
|                 next = requested[o.field.related_query_name()] |                 next = requested[o.field.related_query_name()] | ||||||
|                 klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1, |                 klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1, | ||||||
|                                             requested=next, only_load=only_load, local_only=True) |                                             requested=next, only_load=only_load, local_only=True) | ||||||
|   | |||||||
| @@ -126,18 +126,19 @@ class DeferredAttribute(object): | |||||||
|         return None |         return None | ||||||
|  |  | ||||||
|  |  | ||||||
| def select_related_descend(field, restricted, requested, reverse=False): | def select_related_descend(field, restricted, requested, load_fields, reverse=False): | ||||||
|     """ |     """ | ||||||
|     Returns True if this field should be used to descend deeper for |     Returns True if this field should be used to descend deeper for | ||||||
|     select_related() purposes. Used by both the query construction code |     select_related() purposes. Used by both the query construction code | ||||||
|     (sql.query.fill_related_selections()) and the model instance creation code |     (sql.query.fill_related_selections()) and the model instance creation code | ||||||
|     (query.get_cached_row()). |     (query.get_klass_info()). | ||||||
|  |  | ||||||
|     Arguments: |     Arguments: | ||||||
|      * field - the field to be checked |      * field - the field to be checked | ||||||
|      * restricted - a boolean field, indicating if the field list has been |      * restricted - a boolean field, indicating if the field list has been | ||||||
|        manually restricted using a requested clause) |        manually restricted using a requested clause) | ||||||
|      * requested - The select_related() dictionary. |      * requested - The select_related() dictionary. | ||||||
|  |      * load_fields - the set of fields to be loaded on this model | ||||||
|      * reverse - boolean, True if we are checking a reverse select related |      * reverse - boolean, True if we are checking a reverse select related | ||||||
|     """ |     """ | ||||||
|     if not field.rel: |     if not field.rel: | ||||||
| @@ -151,6 +152,14 @@ def select_related_descend(field, restricted, requested, reverse=False): | |||||||
|             return False |             return False | ||||||
|     if not restricted and field.null: |     if not restricted and field.null: | ||||||
|         return False |         return False | ||||||
|  |     if load_fields: | ||||||
|  |         if field.name not in load_fields: | ||||||
|  |             if restricted and field.name in requested: | ||||||
|  |                 raise InvalidQuery("Field %s.%s cannot be both deferred" | ||||||
|  |                                    " and traversed using select_related" | ||||||
|  |                                    " at the same time." % | ||||||
|  |                                    (field.model._meta.object_name, field.name)) | ||||||
|  |             return False | ||||||
|     return True |     return True | ||||||
|  |  | ||||||
| # This function is needed because data descriptors must be defined on a class | # This function is needed because data descriptors must be defined on a class | ||||||
|   | |||||||
| @@ -596,6 +596,7 @@ class SQLCompiler(object): | |||||||
|         if avoid_set is None: |         if avoid_set is None: | ||||||
|             avoid_set = set() |             avoid_set = set() | ||||||
|         orig_dupe_set = dupe_set |         orig_dupe_set = dupe_set | ||||||
|  |         only_load = self.query.get_loaded_field_names() | ||||||
|  |  | ||||||
|         # Setup for the case when only particular related fields should be |         # Setup for the case when only particular related fields should be | ||||||
|         # included in the related selection. |         # included in the related selection. | ||||||
| @@ -607,7 +608,8 @@ class SQLCompiler(object): | |||||||
|                 restricted = False |                 restricted = False | ||||||
|  |  | ||||||
|         for f, model in opts.get_fields_with_model(): |         for f, model in opts.get_fields_with_model(): | ||||||
|             if not select_related_descend(f, restricted, requested): |             if not select_related_descend(f, restricted, requested, | ||||||
|  |                                           only_load.get(model or self.query.model)): | ||||||
|                 continue |                 continue | ||||||
|             # The "avoid" set is aliases we want to avoid just for this |             # The "avoid" set is aliases we want to avoid just for this | ||||||
|             # particular branch of the recursion. They aren't permanently |             # particular branch of the recursion. They aren't permanently | ||||||
| @@ -680,7 +682,8 @@ class SQLCompiler(object): | |||||||
|                 if o.field.unique |                 if o.field.unique | ||||||
|             ] |             ] | ||||||
|             for f, model in related_fields: |             for f, model in related_fields: | ||||||
|                 if not select_related_descend(f, restricted, requested, reverse=True): |                 if not select_related_descend(f, restricted, requested, | ||||||
|  |                                               only_load.get(model), reverse=True): | ||||||
|                     continue |                     continue | ||||||
|                 # The "avoid" set is aliases we want to avoid just for this |                 # The "avoid" set is aliases we want to avoid just for this | ||||||
|                 # particular branch of the recursion. They aren't permanently |                 # particular branch of the recursion. They aren't permanently | ||||||
|   | |||||||
| @@ -1845,8 +1845,14 @@ class Query(object): | |||||||
|  |  | ||||||
|         If no fields are marked for deferral, returns an empty dictionary. |         If no fields are marked for deferral, returns an empty dictionary. | ||||||
|         """ |         """ | ||||||
|  |         # We cache this because we call this function multiple times | ||||||
|  |         # (compiler.fill_related_selections, query.iterator) | ||||||
|  |         try: | ||||||
|  |             return self._loaded_field_names_cache | ||||||
|  |         except AttributeError: | ||||||
|             collection = {} |             collection = {} | ||||||
|             self.deferred_to_data(collection, self.get_loaded_field_names_cb) |             self.deferred_to_data(collection, self.get_loaded_field_names_cb) | ||||||
|  |             self._loaded_field_names_cache = collection | ||||||
|             return collection |             return collection | ||||||
|  |  | ||||||
|     def get_loaded_field_names_cb(self, target, model, fields): |     def get_loaded_field_names_cb(self, target, model, fields): | ||||||
|   | |||||||
| @@ -1081,11 +1081,13 @@ to ``defer()``:: | |||||||
|     # Load all fields immediately. |     # Load all fields immediately. | ||||||
|     my_queryset.defer(None) |     my_queryset.defer(None) | ||||||
|  |  | ||||||
|  | .. versionchanged:: 1.5 | ||||||
|  |  | ||||||
| Some fields in a model won't be deferred, even if you ask for them. You can | Some fields in a model won't be deferred, even if you ask for them. You can | ||||||
| never defer the loading of the primary key. If you are using | never defer the loading of the primary key. If you are using | ||||||
| :meth:`select_related()` to retrieve related models, you shouldn't defer the | :meth:`select_related()` to retrieve related models, you shouldn't defer the | ||||||
| loading of the field that connects from the primary model to the related one | loading of the field that connects from the primary model to the related | ||||||
| (at the moment, that doesn't raise an error, but it will eventually). | one, doing so will result in an error. | ||||||
|  |  | ||||||
| .. note:: | .. note:: | ||||||
|  |  | ||||||
| @@ -1145,9 +1147,12 @@ logically:: | |||||||
|     # existing set of fields). |     # existing set of fields). | ||||||
|     Entry.objects.defer("body").only("headline", "body") |     Entry.objects.defer("body").only("headline", "body") | ||||||
|  |  | ||||||
|  | .. versionchanged:: 1.5 | ||||||
|  |  | ||||||
| All of the cautions in the note for the :meth:`defer` documentation apply to | All of the cautions in the note for the :meth:`defer` documentation apply to | ||||||
| ``only()`` as well. Use it cautiously and only after exhausting your other | ``only()`` as well. Use it cautiously and only after exhausting your other | ||||||
| options. | options. Also note that using :meth:`only` and omitting a field requested | ||||||
|  | using :meth:`select_related` is an error as well. | ||||||
|  |  | ||||||
| using | using | ||||||
| ~~~~~ | ~~~~~ | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| from __future__ import absolute_import | from __future__ import absolute_import | ||||||
|  |  | ||||||
| from django.db.models.query_utils import DeferredAttribute | from django.db.models.query_utils import DeferredAttribute, InvalidQuery | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  |  | ||||||
| from .models import Secondary, Primary, Child, BigChild, ChildProxy | from .models import Secondary, Primary, Child, BigChild, ChildProxy | ||||||
| @@ -73,9 +73,19 @@ class DeferTests(TestCase): | |||||||
|         self.assert_delayed(qs.defer("name").get(pk=p1.pk), 1) |         self.assert_delayed(qs.defer("name").get(pk=p1.pk), 1) | ||||||
|         self.assert_delayed(qs.only("name").get(pk=p1.pk), 2) |         self.assert_delayed(qs.only("name").get(pk=p1.pk), 2) | ||||||
|  |  | ||||||
|         # DOES THIS WORK? |         # When we defer a field and also select_related it, the query is | ||||||
|         self.assert_delayed(qs.only("name").select_related("related")[0], 1) |         # invalid and raises an exception. | ||||||
|         self.assert_delayed(qs.defer("related").select_related("related")[0], 0) |         with self.assertRaises(InvalidQuery): | ||||||
|  |             qs.only("name").select_related("related")[0] | ||||||
|  |         with self.assertRaises(InvalidQuery): | ||||||
|  |             qs.defer("related").select_related("related")[0] | ||||||
|  |  | ||||||
|  |         # With a depth-based select_related, all deferred ForeignKeys are | ||||||
|  |         # deferred instead of traversed. | ||||||
|  |         with self.assertNumQueries(3): | ||||||
|  |             obj = qs.defer("related").select_related()[0] | ||||||
|  |             self.assert_delayed(obj, 1) | ||||||
|  |             self.assertEqual(obj.related.id, s1.pk) | ||||||
|  |  | ||||||
|         # Saving models with deferred fields is possible (but inefficient, |         # Saving models with deferred fields is possible (but inefficient, | ||||||
|         # since every field has to be retrieved first). |         # since every field has to be retrieved first). | ||||||
| @@ -155,7 +165,7 @@ class DeferTests(TestCase): | |||||||
|         children = ChildProxy.objects.all().select_related().only('id', 'name') |         children = ChildProxy.objects.all().select_related().only('id', 'name') | ||||||
|         self.assertEqual(len(children), 1) |         self.assertEqual(len(children), 1) | ||||||
|         child = children[0] |         child = children[0] | ||||||
|         self.assert_delayed(child, 1) |         self.assert_delayed(child, 2) | ||||||
|         self.assertEqual(child.name, 'p1') |         self.assertEqual(child.name, 'p1') | ||||||
|         self.assertEqual(child.value, 'xx') |         self.assertEqual(child.value, 'xx') | ||||||
|  |  | ||||||
|   | |||||||
| @@ -47,3 +47,7 @@ class SimpleItem(models.Model): | |||||||
|  |  | ||||||
| class Feature(models.Model): | class Feature(models.Model): | ||||||
|     item = models.ForeignKey(SimpleItem) |     item = models.ForeignKey(SimpleItem) | ||||||
|  |  | ||||||
|  | class ItemAndSimpleItem(models.Model): | ||||||
|  |     item = models.ForeignKey(Item) | ||||||
|  |     simple = models.ForeignKey(SimpleItem) | ||||||
|   | |||||||
| @@ -9,7 +9,7 @@ from django.db.models.loading import cache | |||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  |  | ||||||
| from .models import (ResolveThis, Item, RelatedItem, Child, Leaf, Proxy, | from .models import (ResolveThis, Item, RelatedItem, Child, Leaf, Proxy, | ||||||
|     SimpleItem, Feature) |     SimpleItem, Feature, ItemAndSimpleItem) | ||||||
|  |  | ||||||
|  |  | ||||||
| class DeferRegressionTest(TestCase): | class DeferRegressionTest(TestCase): | ||||||
| @@ -109,6 +109,7 @@ class DeferRegressionTest(TestCase): | |||||||
|                 Child, |                 Child, | ||||||
|                 Feature, |                 Feature, | ||||||
|                 Item, |                 Item, | ||||||
|  |                 ItemAndSimpleItem, | ||||||
|                 Leaf, |                 Leaf, | ||||||
|                 Proxy, |                 Proxy, | ||||||
|                 RelatedItem, |                 RelatedItem, | ||||||
| @@ -125,12 +126,16 @@ class DeferRegressionTest(TestCase): | |||||||
|                 ), |                 ), | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|  |         # FIXME: This is dependent on the order in which tests are run -- | ||||||
|  |         # this test case has to be the first, otherwise a LOT more classes | ||||||
|  |         # appear. | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             klasses, [ |             klasses, [ | ||||||
|                 "Child", |                 "Child", | ||||||
|                 "Child_Deferred_value", |                 "Child_Deferred_value", | ||||||
|                 "Feature", |                 "Feature", | ||||||
|                 "Item", |                 "Item", | ||||||
|  |                 "ItemAndSimpleItem", | ||||||
|                 "Item_Deferred_name", |                 "Item_Deferred_name", | ||||||
|                 "Item_Deferred_name_other_value_text", |                 "Item_Deferred_name_other_value_text", | ||||||
|                 "Item_Deferred_name_other_value_value", |                 "Item_Deferred_name_other_value_value", | ||||||
| @@ -139,7 +144,7 @@ class DeferRegressionTest(TestCase): | |||||||
|                 "Leaf", |                 "Leaf", | ||||||
|                 "Leaf_Deferred_child_id_second_child_id_value", |                 "Leaf_Deferred_child_id_second_child_id_value", | ||||||
|                 "Leaf_Deferred_name_value", |                 "Leaf_Deferred_name_value", | ||||||
|                 "Leaf_Deferred_second_child_value", |                 "Leaf_Deferred_second_child_id_value", | ||||||
|                 "Leaf_Deferred_value", |                 "Leaf_Deferred_value", | ||||||
|                 "Proxy", |                 "Proxy", | ||||||
|                 "RelatedItem", |                 "RelatedItem", | ||||||
| @@ -175,6 +180,23 @@ class DeferRegressionTest(TestCase): | |||||||
|         self.assertEqual(1, qs.count()) |         self.assertEqual(1, qs.count()) | ||||||
|         self.assertEqual('Foobar', qs[0].name) |         self.assertEqual('Foobar', qs[0].name) | ||||||
|  |  | ||||||
|  |     def test_defer_with_select_related(self): | ||||||
|  |         item1 = Item.objects.create(name="first", value=47) | ||||||
|  |         item2 = Item.objects.create(name="second", value=42) | ||||||
|  |         simple = SimpleItem.objects.create(name="simple", value="23") | ||||||
|  |         related = ItemAndSimpleItem.objects.create(item=item1, simple=simple) | ||||||
|  |  | ||||||
|  |         obj = ItemAndSimpleItem.objects.defer('item').select_related('simple').get() | ||||||
|  |         self.assertEqual(obj.item, item1) | ||||||
|  |         self.assertEqual(obj.item_id, item1.id) | ||||||
|  |  | ||||||
|  |         obj.item = item2 | ||||||
|  |         obj.save() | ||||||
|  |  | ||||||
|  |         obj = ItemAndSimpleItem.objects.defer('item').select_related('simple').get() | ||||||
|  |         self.assertEqual(obj.item, item2) | ||||||
|  |         self.assertEqual(obj.item_id, item2.id) | ||||||
|  |  | ||||||
|     def test_deferred_class_factory(self): |     def test_deferred_class_factory(self): | ||||||
|         from django.db.models.query_utils import deferred_class_factory |         from django.db.models.query_utils import deferred_class_factory | ||||||
|         new_class = deferred_class_factory(Item, |         new_class = deferred_class_factory(Item, | ||||||
|   | |||||||
| @@ -133,7 +133,7 @@ class SelectRelatedRegressTests(TestCase): | |||||||
|         self.assertEqual(troy.state.name, 'Western Australia') |         self.assertEqual(troy.state.name, 'Western Australia') | ||||||
|  |  | ||||||
|         # Also works if you use only, rather than defer |         # Also works if you use only, rather than defer | ||||||
|         troy = SpecialClient.objects.select_related('state').only('name').get(name='Troy Buswell') |         troy = SpecialClient.objects.select_related('state').only('name', 'state').get(name='Troy Buswell') | ||||||
|  |  | ||||||
|         self.assertEqual(troy.name, 'Troy Buswell') |         self.assertEqual(troy.name, 'Troy Buswell') | ||||||
|         self.assertEqual(troy.value, 42) |         self.assertEqual(troy.value, 42) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user