mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	Fixed #27332 -- Added FilteredRelation API for conditional join (ON clause) support.
Thanks Anssi Kääriäinen for contributing to the patch.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							3f9d85d95c
						
					
				
				
					commit
					01d440fa1e
				
			| @@ -348,7 +348,7 @@ class GenericRelation(ForeignObject): | |||||||
|         self.to_fields = [self.model._meta.pk.name] |         self.to_fields = [self.model._meta.pk.name] | ||||||
|         return [(self.remote_field.model._meta.get_field(self.object_id_field_name), self.model._meta.pk)] |         return [(self.remote_field.model._meta.get_field(self.object_id_field_name), self.model._meta.pk)] | ||||||
|  |  | ||||||
|     def _get_path_info_with_parent(self): |     def _get_path_info_with_parent(self, filtered_relation): | ||||||
|         """ |         """ | ||||||
|         Return the path that joins the current model through any parent models. |         Return the path that joins the current model through any parent models. | ||||||
|         The idea is that if you have a GFK defined on a parent model then we |         The idea is that if you have a GFK defined on a parent model then we | ||||||
| @@ -365,7 +365,15 @@ class GenericRelation(ForeignObject): | |||||||
|         opts = self.remote_field.model._meta.concrete_model._meta |         opts = self.remote_field.model._meta.concrete_model._meta | ||||||
|         parent_opts = opts.get_field(self.object_id_field_name).model._meta |         parent_opts = opts.get_field(self.object_id_field_name).model._meta | ||||||
|         target = parent_opts.pk |         target = parent_opts.pk | ||||||
|         path.append(PathInfo(self.model._meta, parent_opts, (target,), self.remote_field, True, False)) |         path.append(PathInfo( | ||||||
|  |             from_opts=self.model._meta, | ||||||
|  |             to_opts=parent_opts, | ||||||
|  |             target_fields=(target,), | ||||||
|  |             join_field=self.remote_field, | ||||||
|  |             m2m=True, | ||||||
|  |             direct=False, | ||||||
|  |             filtered_relation=filtered_relation, | ||||||
|  |         )) | ||||||
|         # Collect joins needed for the parent -> child chain. This is easiest |         # Collect joins needed for the parent -> child chain. This is easiest | ||||||
|         # to do if we collect joins for the child -> parent chain and then |         # to do if we collect joins for the child -> parent chain and then | ||||||
|         # reverse the direction (call to reverse() and use of |         # reverse the direction (call to reverse() and use of | ||||||
| @@ -380,19 +388,35 @@ class GenericRelation(ForeignObject): | |||||||
|             path.extend(field.remote_field.get_path_info()) |             path.extend(field.remote_field.get_path_info()) | ||||||
|         return path |         return path | ||||||
|  |  | ||||||
|     def get_path_info(self): |     def get_path_info(self, filtered_relation=None): | ||||||
|         opts = self.remote_field.model._meta |         opts = self.remote_field.model._meta | ||||||
|         object_id_field = opts.get_field(self.object_id_field_name) |         object_id_field = opts.get_field(self.object_id_field_name) | ||||||
|         if object_id_field.model != opts.model: |         if object_id_field.model != opts.model: | ||||||
|             return self._get_path_info_with_parent() |             return self._get_path_info_with_parent(filtered_relation) | ||||||
|         else: |         else: | ||||||
|             target = opts.pk |             target = opts.pk | ||||||
|             return [PathInfo(self.model._meta, opts, (target,), self.remote_field, True, False)] |             return [PathInfo( | ||||||
|  |                 from_opts=self.model._meta, | ||||||
|  |                 to_opts=opts, | ||||||
|  |                 target_fields=(target,), | ||||||
|  |                 join_field=self.remote_field, | ||||||
|  |                 m2m=True, | ||||||
|  |                 direct=False, | ||||||
|  |                 filtered_relation=filtered_relation, | ||||||
|  |             )] | ||||||
|  |  | ||||||
|     def get_reverse_path_info(self): |     def get_reverse_path_info(self, filtered_relation=None): | ||||||
|         opts = self.model._meta |         opts = self.model._meta | ||||||
|         from_opts = self.remote_field.model._meta |         from_opts = self.remote_field.model._meta | ||||||
|         return [PathInfo(from_opts, opts, (opts.pk,), self, not self.unique, False)] |         return [PathInfo( | ||||||
|  |             from_opts=from_opts, | ||||||
|  |             to_opts=opts, | ||||||
|  |             target_fields=(opts.pk,), | ||||||
|  |             join_field=self, | ||||||
|  |             m2m=not self.unique, | ||||||
|  |             direct=False, | ||||||
|  |             filtered_relation=filtered_relation, | ||||||
|  |         )] | ||||||
|  |  | ||||||
|     def value_to_string(self, obj): |     def value_to_string(self, obj): | ||||||
|         qs = getattr(obj, self.name).all() |         qs = getattr(obj, self.name).all() | ||||||
|   | |||||||
| @@ -20,6 +20,7 @@ from django.db.models.manager import Manager | |||||||
| from django.db.models.query import ( | from django.db.models.query import ( | ||||||
|     Prefetch, Q, QuerySet, prefetch_related_objects, |     Prefetch, Q, QuerySet, prefetch_related_objects, | ||||||
| ) | ) | ||||||
|  | from django.db.models.query_utils import FilteredRelation | ||||||
|  |  | ||||||
| # Imports that would create circular imports if sorted | # Imports that would create circular imports if sorted | ||||||
| from django.db.models.base import DEFERRED, Model  # isort:skip | from django.db.models.base import DEFERRED, Model  # isort:skip | ||||||
| @@ -69,6 +70,7 @@ __all__ += [ | |||||||
|     'Window', 'WindowFrame', |     'Window', 'WindowFrame', | ||||||
|     'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager', |     'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager', | ||||||
|     'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model', |     'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model', | ||||||
|  |     'FilteredRelation', | ||||||
|     'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', |     'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', | ||||||
|     'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', 'permalink', |     'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', 'permalink', | ||||||
| ] | ] | ||||||
|   | |||||||
| @@ -697,18 +697,33 @@ class ForeignObject(RelatedField): | |||||||
|         """ |         """ | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|     def get_path_info(self): |     def get_path_info(self, filtered_relation=None): | ||||||
|         """Get path from this field to the related model.""" |         """Get path from this field to the related model.""" | ||||||
|         opts = self.remote_field.model._meta |         opts = self.remote_field.model._meta | ||||||
|         from_opts = self.model._meta |         from_opts = self.model._meta | ||||||
|         return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True)] |         return [PathInfo( | ||||||
|  |             from_opts=from_opts, | ||||||
|  |             to_opts=opts, | ||||||
|  |             target_fields=self.foreign_related_fields, | ||||||
|  |             join_field=self, | ||||||
|  |             m2m=False, | ||||||
|  |             direct=True, | ||||||
|  |             filtered_relation=filtered_relation, | ||||||
|  |         )] | ||||||
|  |  | ||||||
|     def get_reverse_path_info(self): |     def get_reverse_path_info(self, filtered_relation=None): | ||||||
|         """Get path from the related model to this field's model.""" |         """Get path from the related model to this field's model.""" | ||||||
|         opts = self.model._meta |         opts = self.model._meta | ||||||
|         from_opts = self.remote_field.model._meta |         from_opts = self.remote_field.model._meta | ||||||
|         pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)] |         return [PathInfo( | ||||||
|         return pathinfos |             from_opts=from_opts, | ||||||
|  |             to_opts=opts, | ||||||
|  |             target_fields=(opts.pk,), | ||||||
|  |             join_field=self.remote_field, | ||||||
|  |             m2m=not self.unique, | ||||||
|  |             direct=False, | ||||||
|  |             filtered_relation=filtered_relation, | ||||||
|  |         )] | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     @functools.lru_cache(maxsize=None) |     @functools.lru_cache(maxsize=None) | ||||||
| @@ -861,12 +876,19 @@ class ForeignKey(ForeignObject): | |||||||
|     def target_field(self): |     def target_field(self): | ||||||
|         return self.foreign_related_fields[0] |         return self.foreign_related_fields[0] | ||||||
|  |  | ||||||
|     def get_reverse_path_info(self): |     def get_reverse_path_info(self, filtered_relation=None): | ||||||
|         """Get path from the related model to this field's model.""" |         """Get path from the related model to this field's model.""" | ||||||
|         opts = self.model._meta |         opts = self.model._meta | ||||||
|         from_opts = self.remote_field.model._meta |         from_opts = self.remote_field.model._meta | ||||||
|         pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)] |         return [PathInfo( | ||||||
|         return pathinfos |             from_opts=from_opts, | ||||||
|  |             to_opts=opts, | ||||||
|  |             target_fields=(opts.pk,), | ||||||
|  |             join_field=self.remote_field, | ||||||
|  |             m2m=not self.unique, | ||||||
|  |             direct=False, | ||||||
|  |             filtered_relation=filtered_relation, | ||||||
|  |         )] | ||||||
|  |  | ||||||
|     def validate(self, value, model_instance): |     def validate(self, value, model_instance): | ||||||
|         if self.remote_field.parent_link: |         if self.remote_field.parent_link: | ||||||
| @@ -1435,7 +1457,7 @@ class ManyToManyField(RelatedField): | |||||||
|             ) |             ) | ||||||
|         return name, path, args, kwargs |         return name, path, args, kwargs | ||||||
|  |  | ||||||
|     def _get_path_info(self, direct=False): |     def _get_path_info(self, direct=False, filtered_relation=None): | ||||||
|         """Called by both direct and indirect m2m traversal.""" |         """Called by both direct and indirect m2m traversal.""" | ||||||
|         pathinfos = [] |         pathinfos = [] | ||||||
|         int_model = self.remote_field.through |         int_model = self.remote_field.through | ||||||
| @@ -1443,10 +1465,10 @@ class ManyToManyField(RelatedField): | |||||||
|         linkfield2 = int_model._meta.get_field(self.m2m_reverse_field_name()) |         linkfield2 = int_model._meta.get_field(self.m2m_reverse_field_name()) | ||||||
|         if direct: |         if direct: | ||||||
|             join1infos = linkfield1.get_reverse_path_info() |             join1infos = linkfield1.get_reverse_path_info() | ||||||
|             join2infos = linkfield2.get_path_info() |             join2infos = linkfield2.get_path_info(filtered_relation) | ||||||
|         else: |         else: | ||||||
|             join1infos = linkfield2.get_reverse_path_info() |             join1infos = linkfield2.get_reverse_path_info() | ||||||
|             join2infos = linkfield1.get_path_info() |             join2infos = linkfield1.get_path_info(filtered_relation) | ||||||
|  |  | ||||||
|         # Get join infos between the last model of join 1 and the first model |         # Get join infos between the last model of join 1 and the first model | ||||||
|         # of join 2. Assume the only reason these may differ is due to model |         # of join 2. Assume the only reason these may differ is due to model | ||||||
| @@ -1465,11 +1487,11 @@ class ManyToManyField(RelatedField): | |||||||
|         pathinfos.extend(join2infos) |         pathinfos.extend(join2infos) | ||||||
|         return pathinfos |         return pathinfos | ||||||
|  |  | ||||||
|     def get_path_info(self): |     def get_path_info(self, filtered_relation=None): | ||||||
|         return self._get_path_info(direct=True) |         return self._get_path_info(direct=True, filtered_relation=filtered_relation) | ||||||
|  |  | ||||||
|     def get_reverse_path_info(self): |     def get_reverse_path_info(self, filtered_relation=None): | ||||||
|         return self._get_path_info(direct=False) |         return self._get_path_info(direct=False, filtered_relation=filtered_relation) | ||||||
|  |  | ||||||
|     def _get_m2m_db_table(self, opts): |     def _get_m2m_db_table(self, opts): | ||||||
|         """ |         """ | ||||||
|   | |||||||
| @@ -163,8 +163,8 @@ class ForeignObjectRel(FieldCacheMixin): | |||||||
|             return self.related_name |             return self.related_name | ||||||
|         return opts.model_name + ('_set' if self.multiple else '') |         return opts.model_name + ('_set' if self.multiple else '') | ||||||
|  |  | ||||||
|     def get_path_info(self): |     def get_path_info(self, filtered_relation=None): | ||||||
|         return self.field.get_reverse_path_info() |         return self.field.get_reverse_path_info(filtered_relation) | ||||||
|  |  | ||||||
|     def get_cache_name(self): |     def get_cache_name(self): | ||||||
|         """ |         """ | ||||||
|   | |||||||
| @@ -632,7 +632,15 @@ class Options: | |||||||
|                 final_field = opts.parents[int_model] |                 final_field = opts.parents[int_model] | ||||||
|                 targets = (final_field.remote_field.get_related_field(),) |                 targets = (final_field.remote_field.get_related_field(),) | ||||||
|                 opts = int_model._meta |                 opts = int_model._meta | ||||||
|                 path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True)) |                 path.append(PathInfo( | ||||||
|  |                     from_opts=final_field.model._meta, | ||||||
|  |                     to_opts=opts, | ||||||
|  |                     target_fields=targets, | ||||||
|  |                     join_field=final_field, | ||||||
|  |                     m2m=False, | ||||||
|  |                     direct=True, | ||||||
|  |                     filtered_relation=None, | ||||||
|  |                 )) | ||||||
|         return path |         return path | ||||||
|  |  | ||||||
|     def get_path_from_parent(self, parent): |     def get_path_from_parent(self, parent): | ||||||
|   | |||||||
| @@ -22,7 +22,7 @@ from django.db.models.deletion import Collector | |||||||
| from django.db.models.expressions import F | from django.db.models.expressions import F | ||||||
| from django.db.models.fields import AutoField | from django.db.models.fields import AutoField | ||||||
| from django.db.models.functions import Trunc | from django.db.models.functions import Trunc | ||||||
| from django.db.models.query_utils import InvalidQuery, Q | from django.db.models.query_utils import FilteredRelation, InvalidQuery, Q | ||||||
| from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE | from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE | ||||||
| from django.utils import timezone | from django.utils import timezone | ||||||
| from django.utils.deprecation import RemovedInDjango30Warning | from django.utils.deprecation import RemovedInDjango30Warning | ||||||
| @@ -953,6 +953,12 @@ class QuerySet: | |||||||
|         if lookups == (None,): |         if lookups == (None,): | ||||||
|             clone._prefetch_related_lookups = () |             clone._prefetch_related_lookups = () | ||||||
|         else: |         else: | ||||||
|  |             for lookup in lookups: | ||||||
|  |                 if isinstance(lookup, Prefetch): | ||||||
|  |                     lookup = lookup.prefetch_to | ||||||
|  |                 lookup = lookup.split(LOOKUP_SEP, 1)[0] | ||||||
|  |                 if lookup in self.query._filtered_relations: | ||||||
|  |                     raise ValueError('prefetch_related() is not supported with FilteredRelation.') | ||||||
|             clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups |             clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups | ||||||
|         return clone |         return clone | ||||||
|  |  | ||||||
| @@ -984,6 +990,9 @@ class QuerySet: | |||||||
|             if alias in names: |             if alias in names: | ||||||
|                 raise ValueError("The annotation '%s' conflicts with a field on " |                 raise ValueError("The annotation '%s' conflicts with a field on " | ||||||
|                                  "the model." % alias) |                                  "the model." % alias) | ||||||
|  |             if isinstance(annotation, FilteredRelation): | ||||||
|  |                 clone.query.add_filtered_relation(annotation, alias) | ||||||
|  |             else: | ||||||
|                 clone.query.add_annotation(annotation, alias, is_summary=False) |                 clone.query.add_annotation(annotation, alias, is_summary=False) | ||||||
|  |  | ||||||
|         for alias, annotation in clone.query.annotations.items(): |         for alias, annotation in clone.query.annotations.items(): | ||||||
| @@ -1060,6 +1069,10 @@ class QuerySet: | |||||||
|             # Can only pass None to defer(), not only(), as the rest option. |             # Can only pass None to defer(), not only(), as the rest option. | ||||||
|             # That won't stop people trying to do this, so let's be explicit. |             # That won't stop people trying to do this, so let's be explicit. | ||||||
|             raise TypeError("Cannot pass None as an argument to only().") |             raise TypeError("Cannot pass None as an argument to only().") | ||||||
|  |         for field in fields: | ||||||
|  |             field = field.split(LOOKUP_SEP, 1)[0] | ||||||
|  |             if field in self.query._filtered_relations: | ||||||
|  |                 raise ValueError('only() is not supported with FilteredRelation.') | ||||||
|         clone = self._chain() |         clone = self._chain() | ||||||
|         clone.query.add_immediate_loading(fields) |         clone.query.add_immediate_loading(fields) | ||||||
|         return clone |         return clone | ||||||
| @@ -1730,9 +1743,9 @@ class RelatedPopulator: | |||||||
|         #    model's fields. |         #    model's fields. | ||||||
|         #  - related_populators: a list of RelatedPopulator instances if |         #  - related_populators: a list of RelatedPopulator instances if | ||||||
|         #    select_related() descends to related models from this model. |         #    select_related() descends to related models from this model. | ||||||
|         #  - field, remote_field: the fields to use for populating the |         #  - local_setter, remote_setter: Methods to set cached values on | ||||||
|         #    internal fields cache. If remote_field is set then we also |         #    the object being populated and on the remote object. Usually | ||||||
|         #    set the reverse link. |         #    these are Field.set_cached_value() methods. | ||||||
|         select_fields = klass_info['select_fields'] |         select_fields = klass_info['select_fields'] | ||||||
|         from_parent = klass_info['from_parent'] |         from_parent = klass_info['from_parent'] | ||||||
|         if not from_parent: |         if not from_parent: | ||||||
| @@ -1751,16 +1764,8 @@ class RelatedPopulator: | |||||||
|         self.model_cls = klass_info['model'] |         self.model_cls = klass_info['model'] | ||||||
|         self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname) |         self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname) | ||||||
|         self.related_populators = get_related_populators(klass_info, select, self.db) |         self.related_populators = get_related_populators(klass_info, select, self.db) | ||||||
|         reverse = klass_info['reverse'] |         self.local_setter = klass_info['local_setter'] | ||||||
|         field = klass_info['field'] |         self.remote_setter = klass_info['remote_setter'] | ||||||
|         self.remote_field = None |  | ||||||
|         if reverse: |  | ||||||
|             self.field = field.remote_field |  | ||||||
|             self.remote_field = field |  | ||||||
|         else: |  | ||||||
|             self.field = field |  | ||||||
|             if field.unique: |  | ||||||
|                 self.remote_field = field.remote_field |  | ||||||
|  |  | ||||||
|     def populate(self, row, from_obj): |     def populate(self, row, from_obj): | ||||||
|         if self.reorder_for_init: |         if self.reorder_for_init: | ||||||
| @@ -1774,9 +1779,9 @@ class RelatedPopulator: | |||||||
|             if self.related_populators: |             if self.related_populators: | ||||||
|                 for rel_iter in self.related_populators: |                 for rel_iter in self.related_populators: | ||||||
|                     rel_iter.populate(row, obj) |                     rel_iter.populate(row, obj) | ||||||
|             if self.remote_field: |         self.local_setter(from_obj, obj) | ||||||
|                 self.remote_field.set_cached_value(obj, from_obj) |         if obj is not None: | ||||||
|         self.field.set_cached_value(from_obj, obj) |             self.remote_setter(obj, from_obj) | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_related_populators(klass_info, select, db): | def get_related_populators(klass_info, select, db): | ||||||
|   | |||||||
| @@ -16,7 +16,7 @@ from django.utils import tree | |||||||
| # PathInfo is used when converting lookups (fk__somecol). The contents | # PathInfo is used when converting lookups (fk__somecol). The contents | ||||||
| # describe the relation in Model terms (model Options and Fields for both | # describe the relation in Model terms (model Options and Fields for both | ||||||
| # sides of the relation. The join_field is the field backing the relation. | # sides of the relation. The join_field is the field backing the relation. | ||||||
| PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct') | PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct filtered_relation') | ||||||
|  |  | ||||||
|  |  | ||||||
| class InvalidQuery(Exception): | class InvalidQuery(Exception): | ||||||
| @@ -291,3 +291,44 @@ def check_rel_lookup_compatibility(model, target_opts, field): | |||||||
|         check(target_opts) or |         check(target_opts) or | ||||||
|         (getattr(field, 'primary_key', False) and check(field.model._meta)) |         (getattr(field, 'primary_key', False) and check(field.model._meta)) | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FilteredRelation: | ||||||
|  |     """Specify custom filtering in the ON clause of SQL joins.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, relation_name, *, condition=Q()): | ||||||
|  |         if not relation_name: | ||||||
|  |             raise ValueError('relation_name cannot be empty.') | ||||||
|  |         self.relation_name = relation_name | ||||||
|  |         self.alias = None | ||||||
|  |         if not isinstance(condition, Q): | ||||||
|  |             raise ValueError('condition argument must be a Q() instance.') | ||||||
|  |         self.condition = condition | ||||||
|  |         self.path = [] | ||||||
|  |  | ||||||
|  |     def __eq__(self, other): | ||||||
|  |         return ( | ||||||
|  |             isinstance(other, self.__class__) and | ||||||
|  |             self.relation_name == other.relation_name and | ||||||
|  |             self.alias == other.alias and | ||||||
|  |             self.condition == other.condition | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def clone(self): | ||||||
|  |         clone = FilteredRelation(self.relation_name, condition=self.condition) | ||||||
|  |         clone.alias = self.alias | ||||||
|  |         clone.path = self.path[:] | ||||||
|  |         return clone | ||||||
|  |  | ||||||
|  |     def resolve_expression(self, *args, **kwargs): | ||||||
|  |         """ | ||||||
|  |         QuerySet.annotate() only accepts expression-like arguments | ||||||
|  |         (with a resolve_expression() method). | ||||||
|  |         """ | ||||||
|  |         raise NotImplementedError('FilteredRelation.resolve_expression() is unused.') | ||||||
|  |  | ||||||
|  |     def as_sql(self, compiler, connection): | ||||||
|  |         # Resolve the condition in Join.filtered_relation. | ||||||
|  |         query = compiler.query | ||||||
|  |         where = query.build_filtered_relation_q(self.condition, reuse=set(self.path)) | ||||||
|  |         return compiler.compile(where) | ||||||
|   | |||||||
| @@ -702,7 +702,7 @@ class SQLCompiler: | |||||||
|         """ |         """ | ||||||
|         result = [] |         result = [] | ||||||
|         params = [] |         params = [] | ||||||
|         for alias in self.query.alias_map: |         for alias in tuple(self.query.alias_map): | ||||||
|             if not self.query.alias_refcount[alias]: |             if not self.query.alias_refcount[alias]: | ||||||
|                 continue |                 continue | ||||||
|             try: |             try: | ||||||
| @@ -737,7 +737,7 @@ class SQLCompiler: | |||||||
|                 f.field.related_query_name() |                 f.field.related_query_name() | ||||||
|                 for f in opts.related_objects if f.field.unique |                 for f in opts.related_objects if f.field.unique | ||||||
|             ) |             ) | ||||||
|             return chain(direct_choices, reverse_choices) |             return chain(direct_choices, reverse_choices, self.query._filtered_relations) | ||||||
|  |  | ||||||
|         related_klass_infos = [] |         related_klass_infos = [] | ||||||
|         if not restricted and cur_depth > self.query.max_depth: |         if not restricted and cur_depth > self.query.max_depth: | ||||||
| @@ -788,7 +788,8 @@ class SQLCompiler: | |||||||
|             klass_info = { |             klass_info = { | ||||||
|                 'model': f.remote_field.model, |                 'model': f.remote_field.model, | ||||||
|                 'field': f, |                 'field': f, | ||||||
|                 'reverse': False, |                 'local_setter': f.set_cached_value, | ||||||
|  |                 'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None, | ||||||
|                 'from_parent': False, |                 'from_parent': False, | ||||||
|             } |             } | ||||||
|             related_klass_infos.append(klass_info) |             related_klass_infos.append(klass_info) | ||||||
| @@ -825,7 +826,8 @@ class SQLCompiler: | |||||||
|                 klass_info = { |                 klass_info = { | ||||||
|                     'model': model, |                     'model': model, | ||||||
|                     'field': f, |                     'field': f, | ||||||
|                     'reverse': True, |                     'local_setter': f.remote_field.set_cached_value, | ||||||
|  |                     'remote_setter': f.set_cached_value, | ||||||
|                     'from_parent': from_parent, |                     'from_parent': from_parent, | ||||||
|                 } |                 } | ||||||
|                 related_klass_infos.append(klass_info) |                 related_klass_infos.append(klass_info) | ||||||
| @@ -842,6 +844,47 @@ class SQLCompiler: | |||||||
|                     next, restricted) |                     next, restricted) | ||||||
|                 get_related_klass_infos(klass_info, next_klass_infos) |                 get_related_klass_infos(klass_info, next_klass_infos) | ||||||
|             fields_not_found = set(requested).difference(fields_found) |             fields_not_found = set(requested).difference(fields_found) | ||||||
|  |             for name in list(requested): | ||||||
|  |                 # Filtered relations work only on the topmost level. | ||||||
|  |                 if cur_depth > 1: | ||||||
|  |                     break | ||||||
|  |                 if name in self.query._filtered_relations: | ||||||
|  |                     fields_found.add(name) | ||||||
|  |                     f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias) | ||||||
|  |                     model = join_opts.model | ||||||
|  |                     alias = joins[-1] | ||||||
|  |                     from_parent = issubclass(model, opts.model) and model is not opts.model | ||||||
|  |  | ||||||
|  |                     def local_setter(obj, from_obj): | ||||||
|  |                         f.remote_field.set_cached_value(from_obj, obj) | ||||||
|  |  | ||||||
|  |                     def remote_setter(obj, from_obj): | ||||||
|  |                         setattr(from_obj, name, obj) | ||||||
|  |                     klass_info = { | ||||||
|  |                         'model': model, | ||||||
|  |                         'field': f, | ||||||
|  |                         'local_setter': local_setter, | ||||||
|  |                         'remote_setter': remote_setter, | ||||||
|  |                         'from_parent': from_parent, | ||||||
|  |                     } | ||||||
|  |                     related_klass_infos.append(klass_info) | ||||||
|  |                     select_fields = [] | ||||||
|  |                     columns = self.get_default_columns( | ||||||
|  |                         start_alias=alias, opts=model._meta, | ||||||
|  |                         from_parent=opts.model, | ||||||
|  |                     ) | ||||||
|  |                     for col in columns: | ||||||
|  |                         select_fields.append(len(select)) | ||||||
|  |                         select.append((col, None)) | ||||||
|  |                     klass_info['select_fields'] = select_fields | ||||||
|  |                     next_requested = requested.get(name, {}) | ||||||
|  |                     next_klass_infos = self.get_related_selections( | ||||||
|  |                         select, opts=model._meta, root_alias=alias, | ||||||
|  |                         cur_depth=cur_depth + 1, requested=next_requested, | ||||||
|  |                         restricted=restricted, | ||||||
|  |                     ) | ||||||
|  |                     get_related_klass_infos(klass_info, next_klass_infos) | ||||||
|  |             fields_not_found = set(requested).difference(fields_found) | ||||||
|             if fields_not_found: |             if fields_not_found: | ||||||
|                 invalid_fields = ("'%s'" % s for s in fields_not_found) |                 invalid_fields = ("'%s'" % s for s in fields_not_found) | ||||||
|                 raise FieldError( |                 raise FieldError( | ||||||
|   | |||||||
| @@ -41,7 +41,7 @@ class Join: | |||||||
|         - relabeled_clone() |         - relabeled_clone() | ||||||
|     """ |     """ | ||||||
|     def __init__(self, table_name, parent_alias, table_alias, join_type, |     def __init__(self, table_name, parent_alias, table_alias, join_type, | ||||||
|                  join_field, nullable): |                  join_field, nullable, filtered_relation=None): | ||||||
|         # Join table |         # Join table | ||||||
|         self.table_name = table_name |         self.table_name = table_name | ||||||
|         self.parent_alias = parent_alias |         self.parent_alias = parent_alias | ||||||
| @@ -56,6 +56,7 @@ class Join: | |||||||
|         self.join_field = join_field |         self.join_field = join_field | ||||||
|         # Is this join nullabled? |         # Is this join nullabled? | ||||||
|         self.nullable = nullable |         self.nullable = nullable | ||||||
|  |         self.filtered_relation = filtered_relation | ||||||
|  |  | ||||||
|     def as_sql(self, compiler, connection): |     def as_sql(self, compiler, connection): | ||||||
|         """ |         """ | ||||||
| @@ -85,7 +86,11 @@ class Join: | |||||||
|             extra_sql, extra_params = compiler.compile(extra_cond) |             extra_sql, extra_params = compiler.compile(extra_cond) | ||||||
|             join_conditions.append('(%s)' % extra_sql) |             join_conditions.append('(%s)' % extra_sql) | ||||||
|             params.extend(extra_params) |             params.extend(extra_params) | ||||||
|  |         if self.filtered_relation: | ||||||
|  |             extra_sql, extra_params = compiler.compile(self.filtered_relation) | ||||||
|  |             if extra_sql: | ||||||
|  |                 join_conditions.append('(%s)' % extra_sql) | ||||||
|  |                 params.extend(extra_params) | ||||||
|         if not join_conditions: |         if not join_conditions: | ||||||
|             # This might be a rel on the other end of an actual declared field. |             # This might be a rel on the other end of an actual declared field. | ||||||
|             declared_field = getattr(self.join_field, 'field', self.join_field) |             declared_field = getattr(self.join_field, 'field', self.join_field) | ||||||
| @@ -101,18 +106,27 @@ class Join: | |||||||
|     def relabeled_clone(self, change_map): |     def relabeled_clone(self, change_map): | ||||||
|         new_parent_alias = change_map.get(self.parent_alias, self.parent_alias) |         new_parent_alias = change_map.get(self.parent_alias, self.parent_alias) | ||||||
|         new_table_alias = change_map.get(self.table_alias, self.table_alias) |         new_table_alias = change_map.get(self.table_alias, self.table_alias) | ||||||
|  |         if self.filtered_relation is not None: | ||||||
|  |             filtered_relation = self.filtered_relation.clone() | ||||||
|  |             filtered_relation.path = [change_map.get(p, p) for p in self.filtered_relation.path] | ||||||
|  |         else: | ||||||
|  |             filtered_relation = None | ||||||
|         return self.__class__( |         return self.__class__( | ||||||
|             self.table_name, new_parent_alias, new_table_alias, self.join_type, |             self.table_name, new_parent_alias, new_table_alias, self.join_type, | ||||||
|             self.join_field, self.nullable) |             self.join_field, self.nullable, filtered_relation=filtered_relation, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def __eq__(self, other): |     def equals(self, other, with_filtered_relation): | ||||||
|         if isinstance(other, self.__class__): |  | ||||||
|         return ( |         return ( | ||||||
|  |             isinstance(other, self.__class__) and | ||||||
|             self.table_name == other.table_name and |             self.table_name == other.table_name and | ||||||
|             self.parent_alias == other.parent_alias and |             self.parent_alias == other.parent_alias and | ||||||
|                 self.join_field == other.join_field |             self.join_field == other.join_field and | ||||||
|  |             (not with_filtered_relation or self.filtered_relation == other.filtered_relation) | ||||||
|         ) |         ) | ||||||
|         return False |  | ||||||
|  |     def __eq__(self, other): | ||||||
|  |         return self.equals(other, with_filtered_relation=True) | ||||||
|  |  | ||||||
|     def demote(self): |     def demote(self): | ||||||
|         new = self.relabeled_clone({}) |         new = self.relabeled_clone({}) | ||||||
| @@ -134,6 +148,7 @@ class BaseTable: | |||||||
|     """ |     """ | ||||||
|     join_type = None |     join_type = None | ||||||
|     parent_alias = None |     parent_alias = None | ||||||
|  |     filtered_relation = None | ||||||
|  |  | ||||||
|     def __init__(self, table_name, alias): |     def __init__(self, table_name, alias): | ||||||
|         self.table_name = table_name |         self.table_name = table_name | ||||||
| @@ -146,3 +161,10 @@ class BaseTable: | |||||||
|  |  | ||||||
|     def relabeled_clone(self, change_map): |     def relabeled_clone(self, change_map): | ||||||
|         return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias)) |         return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias)) | ||||||
|  |  | ||||||
|  |     def equals(self, other, with_filtered_relation): | ||||||
|  |         return ( | ||||||
|  |             isinstance(self, other.__class__) and | ||||||
|  |             self.table_name == other.table_name and | ||||||
|  |             self.table_alias == other.table_alias | ||||||
|  |         ) | ||||||
|   | |||||||
| @@ -45,6 +45,14 @@ def get_field_names_from_opts(opts): | |||||||
|     )) |     )) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_children_from_q(q): | ||||||
|  |     for child in q.children: | ||||||
|  |         if isinstance(child, Node): | ||||||
|  |             yield from get_children_from_q(child) | ||||||
|  |         else: | ||||||
|  |             yield child | ||||||
|  |  | ||||||
|  |  | ||||||
| JoinInfo = namedtuple( | JoinInfo = namedtuple( | ||||||
|     'JoinInfo', |     'JoinInfo', | ||||||
|     ('final_field', 'targets', 'opts', 'joins', 'path') |     ('final_field', 'targets', 'opts', 'joins', 'path') | ||||||
| @@ -210,6 +218,8 @@ class Query: | |||||||
|         # load. |         # load. | ||||||
|         self.deferred_loading = (frozenset(), True) |         self.deferred_loading = (frozenset(), True) | ||||||
|  |  | ||||||
|  |         self._filtered_relations = {} | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def extra(self): |     def extra(self): | ||||||
|         if self._extra is None: |         if self._extra is None: | ||||||
| @@ -311,6 +321,7 @@ class Query: | |||||||
|         if 'subq_aliases' in self.__dict__: |         if 'subq_aliases' in self.__dict__: | ||||||
|             obj.subq_aliases = self.subq_aliases.copy() |             obj.subq_aliases = self.subq_aliases.copy() | ||||||
|         obj.used_aliases = self.used_aliases.copy() |         obj.used_aliases = self.used_aliases.copy() | ||||||
|  |         obj._filtered_relations = self._filtered_relations.copy() | ||||||
|         # Clear the cached_property |         # Clear the cached_property | ||||||
|         try: |         try: | ||||||
|             del obj.base_table |             del obj.base_table | ||||||
| @@ -624,6 +635,8 @@ class Query: | |||||||
|             opts = orig_opts |             opts = orig_opts | ||||||
|             for name in parts[:-1]: |             for name in parts[:-1]: | ||||||
|                 old_model = cur_model |                 old_model = cur_model | ||||||
|  |                 if name in self._filtered_relations: | ||||||
|  |                     name = self._filtered_relations[name].relation_name | ||||||
|                 source = opts.get_field(name) |                 source = opts.get_field(name) | ||||||
|                 if is_reverse_o2o(source): |                 if is_reverse_o2o(source): | ||||||
|                     cur_model = source.related_model |                     cur_model = source.related_model | ||||||
| @@ -684,7 +697,7 @@ class Query: | |||||||
|             for model, values in seen.items(): |             for model, values in seen.items(): | ||||||
|                 callback(target, model, values) |                 callback(target, model, values) | ||||||
|  |  | ||||||
|     def table_alias(self, table_name, create=False): |     def table_alias(self, table_name, create=False, filtered_relation=None): | ||||||
|         """ |         """ | ||||||
|         Return a table alias for the given table_name and whether this is a |         Return a table alias for the given table_name and whether this is a | ||||||
|         new alias or not. |         new alias or not. | ||||||
| @@ -704,8 +717,8 @@ class Query: | |||||||
|             alias_list.append(alias) |             alias_list.append(alias) | ||||||
|         else: |         else: | ||||||
|             # The first occurrence of a table uses the table name directly. |             # The first occurrence of a table uses the table name directly. | ||||||
|             alias = table_name |             alias = filtered_relation.alias if filtered_relation is not None else table_name | ||||||
|             self.table_map[alias] = [alias] |             self.table_map[table_name] = [alias] | ||||||
|         self.alias_refcount[alias] = 1 |         self.alias_refcount[alias] = 1 | ||||||
|         return alias, True |         return alias, True | ||||||
|  |  | ||||||
| @@ -881,7 +894,7 @@ class Query: | |||||||
|         """ |         """ | ||||||
|         return len([1 for count in self.alias_refcount.values() if count]) |         return len([1 for count in self.alias_refcount.values() if count]) | ||||||
|  |  | ||||||
|     def join(self, join, reuse=None): |     def join(self, join, reuse=None, reuse_with_filtered_relation=False): | ||||||
|         """ |         """ | ||||||
|         Return an alias for the 'join', either reusing an existing alias for |         Return an alias for the 'join', either reusing an existing alias for | ||||||
|         that join or creating a new one. 'join' is either a |         that join or creating a new one. 'join' is either a | ||||||
| @@ -890,18 +903,29 @@ class Query: | |||||||
|         The 'reuse' parameter can be either None which means all joins are |         The 'reuse' parameter can be either None which means all joins are | ||||||
|         reusable, or it can be a set containing the aliases that can be reused. |         reusable, or it can be a set containing the aliases that can be reused. | ||||||
|  |  | ||||||
|  |         The 'reuse_with_filtered_relation' parameter is used when computing | ||||||
|  |         FilteredRelation instances. | ||||||
|  |  | ||||||
|         A join is always created as LOUTER if the lhs alias is LOUTER to make |         A join is always created as LOUTER if the lhs alias is LOUTER to make | ||||||
|         sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new |         sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new | ||||||
|         joins are created as LOUTER if the join is nullable. |         joins are created as LOUTER if the join is nullable. | ||||||
|         """ |         """ | ||||||
|         reuse = [a for a, j in self.alias_map.items() |         if reuse_with_filtered_relation and reuse: | ||||||
|                  if (reuse is None or a in reuse) and j == join] |             reuse_aliases = [ | ||||||
|         if reuse: |                 a for a, j in self.alias_map.items() | ||||||
|             self.ref_alias(reuse[0]) |                 if a in reuse and j.equals(join, with_filtered_relation=False) | ||||||
|             return reuse[0] |             ] | ||||||
|  |         else: | ||||||
|  |             reuse_aliases = [ | ||||||
|  |                 a for a, j in self.alias_map.items() | ||||||
|  |                 if (reuse is None or a in reuse) and j == join | ||||||
|  |             ] | ||||||
|  |         if reuse_aliases: | ||||||
|  |             self.ref_alias(reuse_aliases[0]) | ||||||
|  |             return reuse_aliases[0] | ||||||
|  |  | ||||||
|         # No reuse is possible, so we need a new alias. |         # No reuse is possible, so we need a new alias. | ||||||
|         alias, _ = self.table_alias(join.table_name, create=True) |         alias, _ = self.table_alias(join.table_name, create=True, filtered_relation=join.filtered_relation) | ||||||
|         if join.join_type: |         if join.join_type: | ||||||
|             if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable: |             if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable: | ||||||
|                 join_type = LOUTER |                 join_type = LOUTER | ||||||
| @@ -1090,7 +1114,8 @@ class Query: | |||||||
|                 (name, lhs.output_field.__class__.__name__)) |                 (name, lhs.output_field.__class__.__name__)) | ||||||
|  |  | ||||||
|     def build_filter(self, filter_expr, branch_negated=False, current_negated=False, |     def build_filter(self, filter_expr, branch_negated=False, current_negated=False, | ||||||
|                      can_reuse=None, allow_joins=True, split_subq=True): |                      can_reuse=None, allow_joins=True, split_subq=True, | ||||||
|  |                      reuse_with_filtered_relation=False): | ||||||
|         """ |         """ | ||||||
|         Build a WhereNode for a single filter clause but don't add it |         Build a WhereNode for a single filter clause but don't add it | ||||||
|         to this Query. Query.add_q() will then add this filter to the where |         to this Query. Query.add_q() will then add this filter to the where | ||||||
| @@ -1112,6 +1137,9 @@ class Query: | |||||||
|  |  | ||||||
|         The 'can_reuse' is a set of reusable joins for multijoins. |         The 'can_reuse' is a set of reusable joins for multijoins. | ||||||
|  |  | ||||||
|  |         If 'reuse_with_filtered_relation' is True, then only joins in can_reuse | ||||||
|  |         will be reused. | ||||||
|  |  | ||||||
|         The method will create a filter clause that can be added to the current |         The method will create a filter clause that can be added to the current | ||||||
|         query. However, if the filter isn't added to the query then the caller |         query. However, if the filter isn't added to the query then the caller | ||||||
|         is responsible for unreffing the joins used. |         is responsible for unreffing the joins used. | ||||||
| @@ -1147,7 +1175,10 @@ class Query: | |||||||
|         allow_many = not branch_negated or not split_subq |         allow_many = not branch_negated or not split_subq | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             join_info = self.setup_joins(parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many) |             join_info = self.setup_joins( | ||||||
|  |                 parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many, | ||||||
|  |                 reuse_with_filtered_relation=reuse_with_filtered_relation, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|             # Prevent iterator from being consumed by check_related_objects() |             # Prevent iterator from being consumed by check_related_objects() | ||||||
|             if isinstance(value, Iterator): |             if isinstance(value, Iterator): | ||||||
| @@ -1250,6 +1281,41 @@ class Query: | |||||||
|         needed_inner = joinpromoter.update_join_types(self) |         needed_inner = joinpromoter.update_join_types(self) | ||||||
|         return target_clause, needed_inner |         return target_clause, needed_inner | ||||||
|  |  | ||||||
|  |     def build_filtered_relation_q(self, q_object, reuse, branch_negated=False, current_negated=False): | ||||||
|  |         """Add a FilteredRelation object to the current filter.""" | ||||||
|  |         connector = q_object.connector | ||||||
|  |         current_negated ^= q_object.negated | ||||||
|  |         branch_negated = branch_negated or q_object.negated | ||||||
|  |         target_clause = self.where_class(connector=connector, negated=q_object.negated) | ||||||
|  |         for child in q_object.children: | ||||||
|  |             if isinstance(child, Node): | ||||||
|  |                 child_clause = self.build_filtered_relation_q( | ||||||
|  |                     child, reuse=reuse, branch_negated=branch_negated, | ||||||
|  |                     current_negated=current_negated, | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 child_clause, _ = self.build_filter( | ||||||
|  |                     child, can_reuse=reuse, branch_negated=branch_negated, | ||||||
|  |                     current_negated=current_negated, | ||||||
|  |                     allow_joins=True, split_subq=False, | ||||||
|  |                     reuse_with_filtered_relation=True, | ||||||
|  |                 ) | ||||||
|  |             target_clause.add(child_clause, connector) | ||||||
|  |         return target_clause | ||||||
|  |  | ||||||
|  |     def add_filtered_relation(self, filtered_relation, alias): | ||||||
|  |         filtered_relation.alias = alias | ||||||
|  |         lookups = dict(get_children_from_q(filtered_relation.condition)) | ||||||
|  |         for lookup in chain((filtered_relation.relation_name,), lookups): | ||||||
|  |             lookup_parts, field_parts, _ = self.solve_lookup_type(lookup) | ||||||
|  |             shift = 2 if not lookup_parts else 1 | ||||||
|  |             if len(field_parts) > (shift + len(lookup_parts)): | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "FilteredRelation's condition doesn't support nested " | ||||||
|  |                     "relations (got %r)." % lookup | ||||||
|  |                 ) | ||||||
|  |         self._filtered_relations[filtered_relation.alias] = filtered_relation | ||||||
|  |  | ||||||
|     def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): |     def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): | ||||||
|         """ |         """ | ||||||
|         Walk the list of names and turns them into PathInfo tuples. A single |         Walk the list of names and turns them into PathInfo tuples. A single | ||||||
| @@ -1272,12 +1338,15 @@ class Query: | |||||||
|                 name = opts.pk.name |                 name = opts.pk.name | ||||||
|  |  | ||||||
|             field = None |             field = None | ||||||
|  |             filtered_relation = None | ||||||
|             try: |             try: | ||||||
|                 field = opts.get_field(name) |                 field = opts.get_field(name) | ||||||
|             except FieldDoesNotExist: |             except FieldDoesNotExist: | ||||||
|                 if name in self.annotation_select: |                 if name in self.annotation_select: | ||||||
|                     field = self.annotation_select[name].output_field |                     field = self.annotation_select[name].output_field | ||||||
|  |                 elif name in self._filtered_relations and pos == 0: | ||||||
|  |                     filtered_relation = self._filtered_relations[name] | ||||||
|  |                     field = opts.get_field(filtered_relation.relation_name) | ||||||
|             if field is not None: |             if field is not None: | ||||||
|                 # Fields that contain one-to-many relations with a generic |                 # Fields that contain one-to-many relations with a generic | ||||||
|                 # model (like a GenericForeignKey) cannot generate reverse |                 # model (like a GenericForeignKey) cannot generate reverse | ||||||
| @@ -1301,7 +1370,10 @@ class Query: | |||||||
|                 pos -= 1 |                 pos -= 1 | ||||||
|                 if pos == -1 or fail_on_missing: |                 if pos == -1 or fail_on_missing: | ||||||
|                     field_names = list(get_field_names_from_opts(opts)) |                     field_names = list(get_field_names_from_opts(opts)) | ||||||
|                     available = sorted(field_names + list(self.annotation_select)) |                     available = sorted( | ||||||
|  |                         field_names + list(self.annotation_select) + | ||||||
|  |                         list(self._filtered_relations) | ||||||
|  |                     ) | ||||||
|                     raise FieldError("Cannot resolve keyword '%s' into field. " |                     raise FieldError("Cannot resolve keyword '%s' into field. " | ||||||
|                                      "Choices are: %s" % (name, ", ".join(available))) |                                      "Choices are: %s" % (name, ", ".join(available))) | ||||||
|                 break |                 break | ||||||
| @@ -1315,7 +1387,7 @@ class Query: | |||||||
|                     cur_names_with_path[1].extend(path_to_parent) |                     cur_names_with_path[1].extend(path_to_parent) | ||||||
|                     opts = path_to_parent[-1].to_opts |                     opts = path_to_parent[-1].to_opts | ||||||
|             if hasattr(field, 'get_path_info'): |             if hasattr(field, 'get_path_info'): | ||||||
|                 pathinfos = field.get_path_info() |                 pathinfos = field.get_path_info(filtered_relation) | ||||||
|                 if not allow_many: |                 if not allow_many: | ||||||
|                     for inner_pos, p in enumerate(pathinfos): |                     for inner_pos, p in enumerate(pathinfos): | ||||||
|                         if p.m2m: |                         if p.m2m: | ||||||
| @@ -1340,7 +1412,8 @@ class Query: | |||||||
|                 break |                 break | ||||||
|         return path, final_field, targets, names[pos + 1:] |         return path, final_field, targets, names[pos + 1:] | ||||||
|  |  | ||||||
|     def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): |     def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True, | ||||||
|  |                     reuse_with_filtered_relation=False): | ||||||
|         """ |         """ | ||||||
|         Compute the necessary table joins for the passage through the fields |         Compute the necessary table joins for the passage through the fields | ||||||
|         given in 'names'. 'opts' is the Options class for the current model |         given in 'names'. 'opts' is the Options class for the current model | ||||||
| @@ -1352,6 +1425,9 @@ class Query: | |||||||
|         that can be reused. Note that non-reverse foreign keys are always |         that can be reused. Note that non-reverse foreign keys are always | ||||||
|         reusable when using setup_joins(). |         reusable when using setup_joins(). | ||||||
|  |  | ||||||
|  |         The 'reuse_with_filtered_relation' can be used to force 'can_reuse' | ||||||
|  |         parameter and force the relation on the given connections. | ||||||
|  |  | ||||||
|         If 'allow_many' is False, then any reverse foreign key seen will |         If 'allow_many' is False, then any reverse foreign key seen will | ||||||
|         generate a MultiJoin exception. |         generate a MultiJoin exception. | ||||||
|  |  | ||||||
| @@ -1374,15 +1450,29 @@ class Query: | |||||||
|         # joins at this stage - we will need the information about join type |         # joins at this stage - we will need the information about join type | ||||||
|         # of the trimmed joins. |         # of the trimmed joins. | ||||||
|         for join in path: |         for join in path: | ||||||
|  |             if join.filtered_relation: | ||||||
|  |                 filtered_relation = join.filtered_relation.clone() | ||||||
|  |                 table_alias = filtered_relation.alias | ||||||
|  |             else: | ||||||
|  |                 filtered_relation = None | ||||||
|  |                 table_alias = None | ||||||
|             opts = join.to_opts |             opts = join.to_opts | ||||||
|             if join.direct: |             if join.direct: | ||||||
|                 nullable = self.is_nullable(join.join_field) |                 nullable = self.is_nullable(join.join_field) | ||||||
|             else: |             else: | ||||||
|                 nullable = True |                 nullable = True | ||||||
|             connection = Join(opts.db_table, alias, None, INNER, join.join_field, nullable) |             connection = Join( | ||||||
|             reuse = can_reuse if join.m2m else None |                 opts.db_table, alias, table_alias, INNER, join.join_field, | ||||||
|             alias = self.join(connection, reuse=reuse) |                 nullable, filtered_relation=filtered_relation, | ||||||
|  |             ) | ||||||
|  |             reuse = can_reuse if join.m2m or reuse_with_filtered_relation else None | ||||||
|  |             alias = self.join( | ||||||
|  |                 connection, reuse=reuse, | ||||||
|  |                 reuse_with_filtered_relation=reuse_with_filtered_relation, | ||||||
|  |             ) | ||||||
|             joins.append(alias) |             joins.append(alias) | ||||||
|  |             if filtered_relation: | ||||||
|  |                 filtered_relation.path = joins[:] | ||||||
|         return JoinInfo(final_field, targets, opts, joins, path) |         return JoinInfo(final_field, targets, opts, joins, path) | ||||||
|  |  | ||||||
|     def trim_joins(self, targets, joins, path): |     def trim_joins(self, targets, joins, path): | ||||||
| @@ -1402,6 +1492,8 @@ class Query: | |||||||
|         for pos, info in enumerate(reversed(path)): |         for pos, info in enumerate(reversed(path)): | ||||||
|             if len(joins) == 1 or not info.direct: |             if len(joins) == 1 or not info.direct: | ||||||
|                 break |                 break | ||||||
|  |             if info.filtered_relation: | ||||||
|  |                 break | ||||||
|             join_targets = {t.column for t in info.join_field.foreign_related_fields} |             join_targets = {t.column for t in info.join_field.foreign_related_fields} | ||||||
|             cur_targets = {t.column for t in targets} |             cur_targets = {t.column for t in targets} | ||||||
|             if not cur_targets.issubset(join_targets): |             if not cur_targets.issubset(join_targets): | ||||||
| @@ -1425,7 +1517,7 @@ class Query: | |||||||
|                 return self.annotation_select[name] |                 return self.annotation_select[name] | ||||||
|         else: |         else: | ||||||
|             field_list = name.split(LOOKUP_SEP) |             field_list = name.split(LOOKUP_SEP) | ||||||
|             join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), reuse) |             join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse) | ||||||
|             targets, _, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path) |             targets, _, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path) | ||||||
|             if len(targets) > 1: |             if len(targets) > 1: | ||||||
|                 raise FieldError("Referencing multicolumn fields with F() objects " |                 raise FieldError("Referencing multicolumn fields with F() objects " | ||||||
| @@ -1602,7 +1694,10 @@ class Query: | |||||||
|                 # from the model on which the lookup failed. |                 # from the model on which the lookup failed. | ||||||
|                 raise |                 raise | ||||||
|             else: |             else: | ||||||
|                 names = sorted(list(get_field_names_from_opts(opts)) + list(self.extra) + list(self.annotation_select)) |                 names = sorted( | ||||||
|  |                     list(get_field_names_from_opts(opts)) + list(self.extra) + | ||||||
|  |                     list(self.annotation_select) + list(self._filtered_relations) | ||||||
|  |                 ) | ||||||
|                 raise FieldError("Cannot resolve keyword %r into field. " |                 raise FieldError("Cannot resolve keyword %r into field. " | ||||||
|                                  "Choices are: %s" % (name, ", ".join(names))) |                                  "Choices are: %s" % (name, ", ".join(names))) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3318,3 +3318,60 @@ lookups or :class:`Prefetch` objects you want to prefetch for. For example:: | |||||||
|     >>> from django.db.models import prefetch_related_objects |     >>> from django.db.models import prefetch_related_objects | ||||||
|     >>> restaurants = fetch_top_restaurants_from_cache()  # A list of Restaurants |     >>> restaurants = fetch_top_restaurants_from_cache()  # A list of Restaurants | ||||||
|     >>> prefetch_related_objects(restaurants, 'pizzas__toppings') |     >>> prefetch_related_objects(restaurants, 'pizzas__toppings') | ||||||
|  |  | ||||||
|  | ``FilteredRelation()`` objects | ||||||
|  | ------------------------------ | ||||||
|  |  | ||||||
|  | .. versionadded:: 2.0 | ||||||
|  |  | ||||||
|  | .. class:: FilteredRelation(relation_name, *, condition=Q()) | ||||||
|  |  | ||||||
|  |     .. attribute:: FilteredRelation.relation_name | ||||||
|  |  | ||||||
|  |         The name of the field on which you'd like to filter the relation. | ||||||
|  |  | ||||||
|  |     .. attribute:: FilteredRelation.condition | ||||||
|  |  | ||||||
|  |         A :class:`~django.db.models.Q` object to control the filtering. | ||||||
|  |  | ||||||
|  | ``FilteredRelation`` is used with :meth:`~.QuerySet.annotate()` to create an | ||||||
|  | ``ON`` clause when a ``JOIN`` is performed. It doesn't act on the default | ||||||
|  | relationship but on the annotation name (``pizzas_vegetarian`` in example | ||||||
|  | below). | ||||||
|  |  | ||||||
|  | For example, to find restaurants that have vegetarian pizzas with | ||||||
|  | ``'mozzarella'`` in the name:: | ||||||
|  |  | ||||||
|  |     >>> from django.db.models import FilteredRelation, Q | ||||||
|  |     >>> Restaurant.objects.annotate( | ||||||
|  |     ...    pizzas_vegetarian=FilteredRelation( | ||||||
|  |     ...        'pizzas', condition=Q(pizzas__vegetarian=True), | ||||||
|  |     ...    ), | ||||||
|  |     ... ).filter(pizzas_vegetarian__name__icontains='mozzarella') | ||||||
|  |  | ||||||
|  | If there are a large number of pizzas, this queryset performs better than:: | ||||||
|  |  | ||||||
|  |     >>> Restaurant.objects.filter( | ||||||
|  |     ...     pizzas__vegetarian=True, | ||||||
|  |     ...     pizzas__name__icontains='mozzarella', | ||||||
|  |     ... ) | ||||||
|  |  | ||||||
|  | because the filtering in the ``WHERE`` clause of the first queryset will only | ||||||
|  | operate on vegetarian pizzas. | ||||||
|  |  | ||||||
|  | ``FilteredRelation`` doesn't support: | ||||||
|  |  | ||||||
|  | * Conditions that span relational fields. For example:: | ||||||
|  |  | ||||||
|  |     >>> Restaurant.objects.annotate( | ||||||
|  |     ...    pizzas_with_toppings_startswith_n=FilteredRelation( | ||||||
|  |     ...        'pizzas__toppings', | ||||||
|  |     ...        condition=Q(pizzas__toppings__name__startswith='n'), | ||||||
|  |     ...    ), | ||||||
|  |     ... ) | ||||||
|  |     Traceback (most recent call last): | ||||||
|  |     ... | ||||||
|  |     ValueError: FilteredRelation's condition doesn't support nested relations (got 'pizzas__toppings__name__startswith'). | ||||||
|  | * :meth:`.QuerySet.only` and :meth:`~.QuerySet.prefetch_related`. | ||||||
|  | * A :class:`~django.contrib.contenttypes.fields.GenericForeignKey` | ||||||
|  |   inherited from a parent model. | ||||||
|   | |||||||
| @@ -354,6 +354,9 @@ Models | |||||||
| * The new ``named`` parameter of :meth:`.QuerySet.values_list` allows fetching | * The new ``named`` parameter of :meth:`.QuerySet.values_list` allows fetching | ||||||
|   results as named tuples. |   results as named tuples. | ||||||
|  |  | ||||||
|  | * The new :class:`.FilteredRelation` class allows adding an ``ON`` clause to | ||||||
|  |   querysets. | ||||||
|  |  | ||||||
| Pagination | Pagination | ||||||
| ~~~~~~~~~~ | ~~~~~~~~~~ | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										0
									
								
								tests/filtered_relation/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tests/filtered_relation/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										108
									
								
								tests/filtered_relation/models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								tests/filtered_relation/models.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,108 @@ | |||||||
|  | from django.contrib.contenttypes.fields import ( | ||||||
|  |     GenericForeignKey, GenericRelation, | ||||||
|  | ) | ||||||
|  | from django.contrib.contenttypes.models import ContentType | ||||||
|  | from django.db import models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Author(models.Model): | ||||||
|  |     name = models.CharField(max_length=50, unique=True) | ||||||
|  |     favorite_books = models.ManyToManyField( | ||||||
|  |         'Book', | ||||||
|  |         related_name='preferred_by_authors', | ||||||
|  |         related_query_name='preferred_by_authors', | ||||||
|  |     ) | ||||||
|  |     content_type = models.ForeignKey(ContentType, models.CASCADE, null=True) | ||||||
|  |     object_id = models.PositiveIntegerField(null=True) | ||||||
|  |     content_object = GenericForeignKey() | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return self.name | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Editor(models.Model): | ||||||
|  |     name = models.CharField(max_length=255) | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return self.name | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Book(models.Model): | ||||||
|  |     AVAILABLE = 'available' | ||||||
|  |     RESERVED = 'reserved' | ||||||
|  |     RENTED = 'rented' | ||||||
|  |     STATES = ( | ||||||
|  |         (AVAILABLE, 'Available'), | ||||||
|  |         (RESERVED, 'reserved'), | ||||||
|  |         (RENTED, 'Rented'), | ||||||
|  |     ) | ||||||
|  |     title = models.CharField(max_length=255) | ||||||
|  |     author = models.ForeignKey( | ||||||
|  |         Author, | ||||||
|  |         models.CASCADE, | ||||||
|  |         related_name='books', | ||||||
|  |         related_query_name='book', | ||||||
|  |     ) | ||||||
|  |     editor = models.ForeignKey(Editor, models.CASCADE) | ||||||
|  |     generic_author = GenericRelation(Author) | ||||||
|  |     state = models.CharField(max_length=9, choices=STATES, default=AVAILABLE) | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return self.title | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Borrower(models.Model): | ||||||
|  |     name = models.CharField(max_length=50, unique=True) | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return self.name | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Reservation(models.Model): | ||||||
|  |     NEW = 'new' | ||||||
|  |     STOPPED = 'stopped' | ||||||
|  |     STATES = ( | ||||||
|  |         (NEW, 'New'), | ||||||
|  |         (STOPPED, 'Stopped'), | ||||||
|  |     ) | ||||||
|  |     borrower = models.ForeignKey( | ||||||
|  |         Borrower, | ||||||
|  |         models.CASCADE, | ||||||
|  |         related_name='reservations', | ||||||
|  |         related_query_name='reservation', | ||||||
|  |     ) | ||||||
|  |     book = models.ForeignKey( | ||||||
|  |         Book, | ||||||
|  |         models.CASCADE, | ||||||
|  |         related_name='reservations', | ||||||
|  |         related_query_name='reservation', | ||||||
|  |     ) | ||||||
|  |     state = models.CharField(max_length=7, choices=STATES, default=NEW) | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return '-'.join((self.book.name, self.borrower.name, self.state)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RentalSession(models.Model): | ||||||
|  |     NEW = 'new' | ||||||
|  |     STOPPED = 'stopped' | ||||||
|  |     STATES = ( | ||||||
|  |         (NEW, 'New'), | ||||||
|  |         (STOPPED, 'Stopped'), | ||||||
|  |     ) | ||||||
|  |     borrower = models.ForeignKey( | ||||||
|  |         Borrower, | ||||||
|  |         models.CASCADE, | ||||||
|  |         related_name='rental_sessions', | ||||||
|  |         related_query_name='rental_session', | ||||||
|  |     ) | ||||||
|  |     book = models.ForeignKey( | ||||||
|  |         Book, | ||||||
|  |         models.CASCADE, | ||||||
|  |         related_name='rental_sessions', | ||||||
|  |         related_query_name='rental_session', | ||||||
|  |     ) | ||||||
|  |     state = models.CharField(max_length=7, choices=STATES, default=NEW) | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return '-'.join((self.book.name, self.borrower.name, self.state)) | ||||||
							
								
								
									
										381
									
								
								tests/filtered_relation/tests.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										381
									
								
								tests/filtered_relation/tests.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,381 @@ | |||||||
|  | from django.db import connection | ||||||
|  | from django.db.models import Case, Count, F, FilteredRelation, Q, When | ||||||
|  | from django.test import TestCase | ||||||
|  | from django.test.testcases import skipUnlessDBFeature | ||||||
|  |  | ||||||
|  | from .models import Author, Book, Borrower, Editor, RentalSession, Reservation | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FilteredRelationTests(TestCase): | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def setUpTestData(cls): | ||||||
|  |         cls.author1 = Author.objects.create(name='Alice') | ||||||
|  |         cls.author2 = Author.objects.create(name='Jane') | ||||||
|  |         cls.editor_a = Editor.objects.create(name='a') | ||||||
|  |         cls.editor_b = Editor.objects.create(name='b') | ||||||
|  |         cls.book1 = Book.objects.create( | ||||||
|  |             title='Poem by Alice', | ||||||
|  |             editor=cls.editor_a, | ||||||
|  |             author=cls.author1, | ||||||
|  |         ) | ||||||
|  |         cls.book1.generic_author.set([cls.author2]) | ||||||
|  |         cls.book2 = Book.objects.create( | ||||||
|  |             title='The book by Jane A', | ||||||
|  |             editor=cls.editor_b, | ||||||
|  |             author=cls.author2, | ||||||
|  |         ) | ||||||
|  |         cls.book3 = Book.objects.create( | ||||||
|  |             title='The book by Jane B', | ||||||
|  |             editor=cls.editor_b, | ||||||
|  |             author=cls.author2, | ||||||
|  |         ) | ||||||
|  |         cls.book4 = Book.objects.create( | ||||||
|  |             title='The book by Alice', | ||||||
|  |             editor=cls.editor_a, | ||||||
|  |             author=cls.author1, | ||||||
|  |         ) | ||||||
|  |         cls.author1.favorite_books.add(cls.book2) | ||||||
|  |         cls.author1.favorite_books.add(cls.book3) | ||||||
|  |  | ||||||
|  |     def test_select_related(self): | ||||||
|  |         qs = Author.objects.annotate( | ||||||
|  |             book_join=FilteredRelation('book'), | ||||||
|  |         ).select_related('book_join__editor').order_by('pk', 'book_join__pk') | ||||||
|  |         with self.assertNumQueries(1): | ||||||
|  |             self.assertQuerysetEqual(qs, [ | ||||||
|  |                 (self.author1, self.book1, self.editor_a, self.author1), | ||||||
|  |                 (self.author1, self.book4, self.editor_a, self.author1), | ||||||
|  |                 (self.author2, self.book2, self.editor_b, self.author2), | ||||||
|  |                 (self.author2, self.book3, self.editor_b, self.author2), | ||||||
|  |             ], lambda x: (x, x.book_join, x.book_join.editor, x.book_join.author)) | ||||||
|  |  | ||||||
|  |     def test_select_related_foreign_key(self): | ||||||
|  |         qs = Book.objects.annotate( | ||||||
|  |             author_join=FilteredRelation('author'), | ||||||
|  |         ).select_related('author_join').order_by('pk') | ||||||
|  |         with self.assertNumQueries(1): | ||||||
|  |             self.assertQuerysetEqual(qs, [ | ||||||
|  |                 (self.book1, self.author1), | ||||||
|  |                 (self.book2, self.author2), | ||||||
|  |                 (self.book3, self.author2), | ||||||
|  |                 (self.book4, self.author1), | ||||||
|  |             ], lambda x: (x, x.author_join)) | ||||||
|  |  | ||||||
|  |     def test_without_join(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             Author.objects.annotate( | ||||||
|  |                 book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), | ||||||
|  |             ), | ||||||
|  |             [self.author1, self.author2] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_with_join(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             Author.objects.annotate( | ||||||
|  |                 book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), | ||||||
|  |             ).filter(book_alice__isnull=False), | ||||||
|  |             [self.author1] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_with_join_and_complex_condition(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             Author.objects.annotate( | ||||||
|  |                 book_alice=FilteredRelation( | ||||||
|  |                     'book', condition=Q( | ||||||
|  |                         Q(book__title__iexact='poem by alice') | | ||||||
|  |                         Q(book__state=Book.RENTED) | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |             ).filter(book_alice__isnull=False), | ||||||
|  |             [self.author1] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_internal_queryset_alias_mapping(self): | ||||||
|  |         queryset = Author.objects.annotate( | ||||||
|  |             book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), | ||||||
|  |         ).filter(book_alice__isnull=False) | ||||||
|  |         self.assertIn( | ||||||
|  |             'INNER JOIN {} book_alice ON'.format(connection.ops.quote_name('filtered_relation_book')), | ||||||
|  |             str(queryset.query) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_with_multiple_filter(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             Author.objects.annotate( | ||||||
|  |                 book_editor_a=FilteredRelation( | ||||||
|  |                     'book', | ||||||
|  |                     condition=Q(book__title__icontains='book', book__editor_id=self.editor_a.pk), | ||||||
|  |                 ), | ||||||
|  |             ).filter(book_editor_a__isnull=False), | ||||||
|  |             [self.author1] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_multiple_times(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             Author.objects.annotate( | ||||||
|  |                 book_title_alice=FilteredRelation('book', condition=Q(book__title__icontains='alice')), | ||||||
|  |             ).filter(book_title_alice__isnull=False).filter(book_title_alice__isnull=False).distinct(), | ||||||
|  |             [self.author1] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_exclude_relation_with_join(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             Author.objects.annotate( | ||||||
|  |                 book_alice=FilteredRelation('book', condition=~Q(book__title__icontains='alice')), | ||||||
|  |             ).filter(book_alice__isnull=False).distinct(), | ||||||
|  |             [self.author2] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_with_m2m(self): | ||||||
|  |         qs = Author.objects.annotate( | ||||||
|  |             favorite_books_written_by_jane=FilteredRelation( | ||||||
|  |                 'favorite_books', condition=Q(favorite_books__in=[self.book2]), | ||||||
|  |             ), | ||||||
|  |         ).filter(favorite_books_written_by_jane__isnull=False) | ||||||
|  |         self.assertSequenceEqual(qs, [self.author1]) | ||||||
|  |  | ||||||
|  |     def test_with_m2m_deep(self): | ||||||
|  |         qs = Author.objects.annotate( | ||||||
|  |             favorite_books_written_by_jane=FilteredRelation( | ||||||
|  |                 'favorite_books', condition=Q(favorite_books__author=self.author2), | ||||||
|  |             ), | ||||||
|  |         ).filter(favorite_books_written_by_jane__title='The book by Jane B') | ||||||
|  |         self.assertSequenceEqual(qs, [self.author1]) | ||||||
|  |  | ||||||
|  |     def test_with_m2m_multijoin(self): | ||||||
|  |         qs = Author.objects.annotate( | ||||||
|  |             favorite_books_written_by_jane=FilteredRelation( | ||||||
|  |                 'favorite_books', condition=Q(favorite_books__author=self.author2), | ||||||
|  |             ) | ||||||
|  |         ).filter(favorite_books_written_by_jane__editor__name='b').distinct() | ||||||
|  |         self.assertSequenceEqual(qs, [self.author1]) | ||||||
|  |  | ||||||
|  |     def test_values_list(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             Author.objects.annotate( | ||||||
|  |                 book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), | ||||||
|  |             ).filter(book_alice__isnull=False).values_list('book_alice__title', flat=True), | ||||||
|  |             ['Poem by Alice'] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_values(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             Author.objects.annotate( | ||||||
|  |                 book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), | ||||||
|  |             ).filter(book_alice__isnull=False).values(), | ||||||
|  |             [{'id': self.author1.pk, 'name': 'Alice', 'content_type_id': None, 'object_id': None}] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_extra(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             Author.objects.annotate( | ||||||
|  |                 book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), | ||||||
|  |             ).filter(book_alice__isnull=False).extra(where=['1 = 1']), | ||||||
|  |             [self.author1] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_select_union') | ||||||
|  |     def test_union(self): | ||||||
|  |         qs1 = Author.objects.annotate( | ||||||
|  |             book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), | ||||||
|  |         ).filter(book_alice__isnull=False) | ||||||
|  |         qs2 = Author.objects.annotate( | ||||||
|  |             book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')), | ||||||
|  |         ).filter(book_jane__isnull=False) | ||||||
|  |         self.assertSequenceEqual(qs1.union(qs2), [self.author1, self.author2]) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_select_intersection') | ||||||
|  |     def test_intersection(self): | ||||||
|  |         qs1 = Author.objects.annotate( | ||||||
|  |             book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), | ||||||
|  |         ).filter(book_alice__isnull=False) | ||||||
|  |         qs2 = Author.objects.annotate( | ||||||
|  |             book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')), | ||||||
|  |         ).filter(book_jane__isnull=False) | ||||||
|  |         self.assertSequenceEqual(qs1.intersection(qs2), []) | ||||||
|  |  | ||||||
|  |     @skipUnlessDBFeature('supports_select_difference') | ||||||
|  |     def test_difference(self): | ||||||
|  |         qs1 = Author.objects.annotate( | ||||||
|  |             book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), | ||||||
|  |         ).filter(book_alice__isnull=False) | ||||||
|  |         qs2 = Author.objects.annotate( | ||||||
|  |             book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')), | ||||||
|  |         ).filter(book_jane__isnull=False) | ||||||
|  |         self.assertSequenceEqual(qs1.difference(qs2), [self.author1]) | ||||||
|  |  | ||||||
|  |     def test_select_for_update(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             Author.objects.annotate( | ||||||
|  |                 book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')), | ||||||
|  |             ).filter(book_jane__isnull=False).select_for_update(), | ||||||
|  |             [self.author2] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_defer(self): | ||||||
|  |         # One query for the list and one query for the deferred title. | ||||||
|  |         with self.assertNumQueries(2): | ||||||
|  |             self.assertQuerysetEqual( | ||||||
|  |                 Author.objects.annotate( | ||||||
|  |                     book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), | ||||||
|  |                 ).filter(book_alice__isnull=False).select_related('book_alice').defer('book_alice__title'), | ||||||
|  |                 ['Poem by Alice'], lambda author: author.book_alice.title | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def test_only_not_supported(self): | ||||||
|  |         msg = 'only() is not supported with FilteredRelation.' | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             Author.objects.annotate( | ||||||
|  |                 book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), | ||||||
|  |             ).filter(book_alice__isnull=False).select_related('book_alice').only('book_alice__state') | ||||||
|  |  | ||||||
|  |     def test_as_subquery(self): | ||||||
|  |         inner_qs = Author.objects.annotate( | ||||||
|  |             book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), | ||||||
|  |         ).filter(book_alice__isnull=False) | ||||||
|  |         qs = Author.objects.filter(id__in=inner_qs) | ||||||
|  |         self.assertSequenceEqual(qs, [self.author1]) | ||||||
|  |  | ||||||
|  |     def test_with_foreign_key_error(self): | ||||||
|  |         msg = ( | ||||||
|  |             "FilteredRelation's condition doesn't support nested relations " | ||||||
|  |             "(got 'author__favorite_books__author')." | ||||||
|  |         ) | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             list(Book.objects.annotate( | ||||||
|  |                 alice_favorite_books=FilteredRelation( | ||||||
|  |                     'author__favorite_books', | ||||||
|  |                     condition=Q(author__favorite_books__author=self.author1), | ||||||
|  |                 ) | ||||||
|  |             )) | ||||||
|  |  | ||||||
|  |     def test_with_foreign_key_on_condition_error(self): | ||||||
|  |         msg = ( | ||||||
|  |             "FilteredRelation's condition doesn't support nested relations " | ||||||
|  |             "(got 'book__editor__name__icontains')." | ||||||
|  |         ) | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             list(Author.objects.annotate( | ||||||
|  |                 book_edited_by_b=FilteredRelation('book', condition=Q(book__editor__name__icontains='b')), | ||||||
|  |             )) | ||||||
|  |  | ||||||
|  |     def test_with_empty_relation_name_error(self): | ||||||
|  |         with self.assertRaisesMessage(ValueError, 'relation_name cannot be empty.'): | ||||||
|  |             FilteredRelation('', condition=Q(blank='')) | ||||||
|  |  | ||||||
|  |     def test_with_condition_as_expression_error(self): | ||||||
|  |         msg = 'condition argument must be a Q() instance.' | ||||||
|  |         expression = Case( | ||||||
|  |             When(book__title__iexact='poem by alice', then=True), default=False, | ||||||
|  |         ) | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             FilteredRelation('book', condition=expression) | ||||||
|  |  | ||||||
|  |     def test_with_prefetch_related(self): | ||||||
|  |         msg = 'prefetch_related() is not supported with FilteredRelation.' | ||||||
|  |         qs = Author.objects.annotate( | ||||||
|  |             book_title_contains_b=FilteredRelation('book', condition=Q(book__title__icontains='b')), | ||||||
|  |         ).filter( | ||||||
|  |             book_title_contains_b__isnull=False, | ||||||
|  |         ) | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             qs.prefetch_related('book_title_contains_b') | ||||||
|  |         with self.assertRaisesMessage(ValueError, msg): | ||||||
|  |             qs.prefetch_related('book_title_contains_b__editor') | ||||||
|  |  | ||||||
|  |     def test_with_generic_foreign_key(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             Book.objects.annotate( | ||||||
|  |                 generic_authored_book=FilteredRelation( | ||||||
|  |                     'generic_author', | ||||||
|  |                     condition=Q(generic_author__isnull=False) | ||||||
|  |                 ), | ||||||
|  |             ).filter(generic_authored_book__isnull=False), | ||||||
|  |             [self.book1] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FilteredRelationAggregationTests(TestCase): | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def setUpTestData(cls): | ||||||
|  |         cls.author1 = Author.objects.create(name='Alice') | ||||||
|  |         cls.editor_a = Editor.objects.create(name='a') | ||||||
|  |         cls.book1 = Book.objects.create( | ||||||
|  |             title='Poem by Alice', | ||||||
|  |             editor=cls.editor_a, | ||||||
|  |             author=cls.author1, | ||||||
|  |         ) | ||||||
|  |         cls.borrower1 = Borrower.objects.create(name='Jenny') | ||||||
|  |         cls.borrower2 = Borrower.objects.create(name='Kevin') | ||||||
|  |         # borrower 1 reserves, rents, and returns book1. | ||||||
|  |         Reservation.objects.create( | ||||||
|  |             borrower=cls.borrower1, | ||||||
|  |             book=cls.book1, | ||||||
|  |             state=Reservation.STOPPED, | ||||||
|  |         ) | ||||||
|  |         RentalSession.objects.create( | ||||||
|  |             borrower=cls.borrower1, | ||||||
|  |             book=cls.book1, | ||||||
|  |             state=RentalSession.STOPPED, | ||||||
|  |         ) | ||||||
|  |         # borrower2 reserves, rents, and returns book1. | ||||||
|  |         Reservation.objects.create( | ||||||
|  |             borrower=cls.borrower2, | ||||||
|  |             book=cls.book1, | ||||||
|  |             state=Reservation.STOPPED, | ||||||
|  |         ) | ||||||
|  |         RentalSession.objects.create( | ||||||
|  |             borrower=cls.borrower2, | ||||||
|  |             book=cls.book1, | ||||||
|  |             state=RentalSession.STOPPED, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_aggregate(self): | ||||||
|  |         """ | ||||||
|  |         filtered_relation() not only improves performance but also creates | ||||||
|  |         correct results when aggregating with multiple LEFT JOINs. | ||||||
|  |  | ||||||
|  |         Books can be reserved then rented by a borrower. Each reservation and | ||||||
|  |         rental session are recorded with Reservation and RentalSession models. | ||||||
|  |         Every time a reservation or a rental session is over, their state is | ||||||
|  |         changed to 'stopped'. | ||||||
|  |  | ||||||
|  |         Goal: Count number of books that are either currently reserved or | ||||||
|  |         rented by borrower1 or available. | ||||||
|  |         """ | ||||||
|  |         qs = Book.objects.annotate( | ||||||
|  |             is_reserved_or_rented_by=Case( | ||||||
|  |                 When(reservation__state=Reservation.NEW, then=F('reservation__borrower__pk')), | ||||||
|  |                 When(rental_session__state=RentalSession.NEW, then=F('rental_session__borrower__pk')), | ||||||
|  |                 default=None, | ||||||
|  |             ) | ||||||
|  |         ).filter( | ||||||
|  |             Q(is_reserved_or_rented_by=self.borrower1.pk) | Q(state=Book.AVAILABLE) | ||||||
|  |         ).distinct() | ||||||
|  |         self.assertEqual(qs.count(), 1) | ||||||
|  |         # If count is equal to 1, the same aggregation should return in the | ||||||
|  |         # same result but it returns 4. | ||||||
|  |         self.assertSequenceEqual(qs.annotate(total=Count('pk')).values('total'), [{'total': 4}]) | ||||||
|  |         # With FilteredRelation, the result is as expected (1). | ||||||
|  |         qs = Book.objects.annotate( | ||||||
|  |             active_reservations=FilteredRelation( | ||||||
|  |                 'reservation', condition=Q( | ||||||
|  |                     reservation__state=Reservation.NEW, | ||||||
|  |                     reservation__borrower=self.borrower1, | ||||||
|  |                 ) | ||||||
|  |             ), | ||||||
|  |         ).annotate( | ||||||
|  |             active_rental_sessions=FilteredRelation( | ||||||
|  |                 'rental_session', condition=Q( | ||||||
|  |                     rental_session__state=RentalSession.NEW, | ||||||
|  |                     rental_session__borrower=self.borrower1, | ||||||
|  |                 ) | ||||||
|  |             ), | ||||||
|  |         ).filter( | ||||||
|  |             (Q(active_reservations__isnull=False) | Q(active_rental_sessions__isnull=False)) | | ||||||
|  |             Q(state=Book.AVAILABLE) | ||||||
|  |         ).distinct() | ||||||
|  |         self.assertEqual(qs.count(), 1) | ||||||
|  |         self.assertSequenceEqual(qs.annotate(total=Count('pk')).values('total'), [{'total': 1}]) | ||||||
| @@ -53,15 +53,31 @@ class StartsWithRelation(models.ForeignObject): | |||||||
|     def get_joining_columns(self, reverse_join=False): |     def get_joining_columns(self, reverse_join=False): | ||||||
|         return () |         return () | ||||||
|  |  | ||||||
|     def get_path_info(self): |     def get_path_info(self, filtered_relation=None): | ||||||
|         to_opts = self.remote_field.model._meta |         to_opts = self.remote_field.model._meta | ||||||
|         from_opts = self.model._meta |         from_opts = self.model._meta | ||||||
|         return [PathInfo(from_opts, to_opts, (to_opts.pk,), self, False, False)] |         return [PathInfo( | ||||||
|  |             from_opts=from_opts, | ||||||
|  |             to_opts=to_opts, | ||||||
|  |             target_fields=(to_opts.pk,), | ||||||
|  |             join_field=self, | ||||||
|  |             m2m=False, | ||||||
|  |             direct=False, | ||||||
|  |             filtered_relation=filtered_relation, | ||||||
|  |         )] | ||||||
|  |  | ||||||
|     def get_reverse_path_info(self): |     def get_reverse_path_info(self, filtered_relation=None): | ||||||
|         to_opts = self.model._meta |         to_opts = self.model._meta | ||||||
|         from_opts = self.remote_field.model._meta |         from_opts = self.remote_field.model._meta | ||||||
|         return [PathInfo(from_opts, to_opts, (to_opts.pk,), self.remote_field, False, False)] |         return [PathInfo( | ||||||
|  |             from_opts=from_opts, | ||||||
|  |             to_opts=to_opts, | ||||||
|  |             target_fields=(to_opts.pk,), | ||||||
|  |             join_field=self.remote_field, | ||||||
|  |             m2m=False, | ||||||
|  |             direct=False, | ||||||
|  |             filtered_relation=filtered_relation, | ||||||
|  |         )] | ||||||
|  |  | ||||||
|     def contribute_to_class(self, cls, name, private_only=False): |     def contribute_to_class(self, cls, name, private_only=False): | ||||||
|         super().contribute_to_class(cls, name, private_only) |         super().contribute_to_class(cls, name, private_only) | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| from django.core.exceptions import FieldError | from django.core.exceptions import FieldError | ||||||
|  | from django.db.models import FilteredRelation | ||||||
| from django.test import SimpleTestCase, TestCase | from django.test import SimpleTestCase, TestCase | ||||||
|  |  | ||||||
| from .models import ( | from .models import ( | ||||||
| @@ -230,3 +231,8 @@ class ReverseSelectRelatedValidationTests(SimpleTestCase): | |||||||
|  |  | ||||||
|         with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)): |         with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)): | ||||||
|             list(User.objects.select_related('username')) |             list(User.objects.select_related('username')) | ||||||
|  |  | ||||||
|  |     def test_reverse_related_validation_with_filtered_relation(self): | ||||||
|  |         fields = 'userprofile, userstat, relation' | ||||||
|  |         with self.assertRaisesMessage(FieldError, self.invalid_error % ('foobar', fields)): | ||||||
|  |             list(User.objects.annotate(relation=FilteredRelation('userprofile')).select_related('foobar')) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user