1
0
mirror of https://github.com/django/django.git synced 2025-10-24 06:06:09 +00:00

Fixed #27332 -- Added FilteredRelation API for conditional join (ON clause) support.

Thanks Anssi Kääriäinen for contributing to the patch.
This commit is contained in:
Nicolas Delaby
2017-09-22 17:53:17 +02:00
committed by Tim Graham
parent 3f9d85d95c
commit 01d440fa1e
17 changed files with 916 additions and 83 deletions

View File

@@ -348,7 +348,7 @@ class GenericRelation(ForeignObject):
self.to_fields = [self.model._meta.pk.name]
return [(self.remote_field.model._meta.get_field(self.object_id_field_name), self.model._meta.pk)]
def _get_path_info_with_parent(self):
def _get_path_info_with_parent(self, filtered_relation):
"""
Return the path that joins the current model through any parent models.
The idea is that if you have a GFK defined on a parent model then we
@@ -365,7 +365,15 @@ class GenericRelation(ForeignObject):
opts = self.remote_field.model._meta.concrete_model._meta
parent_opts = opts.get_field(self.object_id_field_name).model._meta
target = parent_opts.pk
path.append(PathInfo(self.model._meta, parent_opts, (target,), self.remote_field, True, False))
path.append(PathInfo(
from_opts=self.model._meta,
to_opts=parent_opts,
target_fields=(target,),
join_field=self.remote_field,
m2m=True,
direct=False,
filtered_relation=filtered_relation,
))
# Collect joins needed for the parent -> child chain. This is easiest
# to do if we collect joins for the child -> parent chain and then
# reverse the direction (call to reverse() and use of
@@ -380,19 +388,35 @@ class GenericRelation(ForeignObject):
path.extend(field.remote_field.get_path_info())
return path
def get_path_info(self):
def get_path_info(self, filtered_relation=None):
opts = self.remote_field.model._meta
object_id_field = opts.get_field(self.object_id_field_name)
if object_id_field.model != opts.model:
return self._get_path_info_with_parent()
return self._get_path_info_with_parent(filtered_relation)
else:
target = opts.pk
return [PathInfo(self.model._meta, opts, (target,), self.remote_field, True, False)]
return [PathInfo(
from_opts=self.model._meta,
to_opts=opts,
target_fields=(target,),
join_field=self.remote_field,
m2m=True,
direct=False,
filtered_relation=filtered_relation,
)]
def get_reverse_path_info(self):
def get_reverse_path_info(self, filtered_relation=None):
opts = self.model._meta
from_opts = self.remote_field.model._meta
return [PathInfo(from_opts, opts, (opts.pk,), self, not self.unique, False)]
return [PathInfo(
from_opts=from_opts,
to_opts=opts,
target_fields=(opts.pk,),
join_field=self,
m2m=not self.unique,
direct=False,
filtered_relation=filtered_relation,
)]
def value_to_string(self, obj):
qs = getattr(obj, self.name).all()

View File

@@ -20,6 +20,7 @@ from django.db.models.manager import Manager
from django.db.models.query import (
Prefetch, Q, QuerySet, prefetch_related_objects,
)
from django.db.models.query_utils import FilteredRelation
# Imports that would create circular imports if sorted
from django.db.models.base import DEFERRED, Model # isort:skip
@@ -69,6 +70,7 @@ __all__ += [
'Window', 'WindowFrame',
'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager',
'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model',
'FilteredRelation',
'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',
'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', 'permalink',
]

View File

@@ -697,18 +697,33 @@ class ForeignObject(RelatedField):
"""
return None
def get_path_info(self):
def get_path_info(self, filtered_relation=None):
"""Get path from this field to the related model."""
opts = self.remote_field.model._meta
from_opts = self.model._meta
return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True)]
return [PathInfo(
from_opts=from_opts,
to_opts=opts,
target_fields=self.foreign_related_fields,
join_field=self,
m2m=False,
direct=True,
filtered_relation=filtered_relation,
)]
def get_reverse_path_info(self):
def get_reverse_path_info(self, filtered_relation=None):
"""Get path from the related model to this field's model."""
opts = self.model._meta
from_opts = self.remote_field.model._meta
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)]
return pathinfos
return [PathInfo(
from_opts=from_opts,
to_opts=opts,
target_fields=(opts.pk,),
join_field=self.remote_field,
m2m=not self.unique,
direct=False,
filtered_relation=filtered_relation,
)]
@classmethod
@functools.lru_cache(maxsize=None)
@@ -861,12 +876,19 @@ class ForeignKey(ForeignObject):
def target_field(self):
return self.foreign_related_fields[0]
def get_reverse_path_info(self):
def get_reverse_path_info(self, filtered_relation=None):
"""Get path from the related model to this field's model."""
opts = self.model._meta
from_opts = self.remote_field.model._meta
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)]
return pathinfos
return [PathInfo(
from_opts=from_opts,
to_opts=opts,
target_fields=(opts.pk,),
join_field=self.remote_field,
m2m=not self.unique,
direct=False,
filtered_relation=filtered_relation,
)]
def validate(self, value, model_instance):
if self.remote_field.parent_link:
@@ -1435,7 +1457,7 @@ class ManyToManyField(RelatedField):
)
return name, path, args, kwargs
def _get_path_info(self, direct=False):
def _get_path_info(self, direct=False, filtered_relation=None):
"""Called by both direct and indirect m2m traversal."""
pathinfos = []
int_model = self.remote_field.through
@@ -1443,10 +1465,10 @@ class ManyToManyField(RelatedField):
linkfield2 = int_model._meta.get_field(self.m2m_reverse_field_name())
if direct:
join1infos = linkfield1.get_reverse_path_info()
join2infos = linkfield2.get_path_info()
join2infos = linkfield2.get_path_info(filtered_relation)
else:
join1infos = linkfield2.get_reverse_path_info()
join2infos = linkfield1.get_path_info()
join2infos = linkfield1.get_path_info(filtered_relation)
# Get join infos between the last model of join 1 and the first model
# of join 2. Assume the only reason these may differ is due to model
@@ -1465,11 +1487,11 @@ class ManyToManyField(RelatedField):
pathinfos.extend(join2infos)
return pathinfos
def get_path_info(self):
return self._get_path_info(direct=True)
def get_path_info(self, filtered_relation=None):
return self._get_path_info(direct=True, filtered_relation=filtered_relation)
def get_reverse_path_info(self):
return self._get_path_info(direct=False)
def get_reverse_path_info(self, filtered_relation=None):
return self._get_path_info(direct=False, filtered_relation=filtered_relation)
def _get_m2m_db_table(self, opts):
"""

View File

@@ -163,8 +163,8 @@ class ForeignObjectRel(FieldCacheMixin):
return self.related_name
return opts.model_name + ('_set' if self.multiple else '')
def get_path_info(self):
return self.field.get_reverse_path_info()
def get_path_info(self, filtered_relation=None):
return self.field.get_reverse_path_info(filtered_relation)
def get_cache_name(self):
"""

View File

@@ -632,7 +632,15 @@ class Options:
final_field = opts.parents[int_model]
targets = (final_field.remote_field.get_related_field(),)
opts = int_model._meta
path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True))
path.append(PathInfo(
from_opts=final_field.model._meta,
to_opts=opts,
target_fields=targets,
join_field=final_field,
m2m=False,
direct=True,
filtered_relation=None,
))
return path
def get_path_from_parent(self, parent):

View File

@@ -22,7 +22,7 @@ from django.db.models.deletion import Collector
from django.db.models.expressions import F
from django.db.models.fields import AutoField
from django.db.models.functions import Trunc
from django.db.models.query_utils import InvalidQuery, Q
from django.db.models.query_utils import FilteredRelation, InvalidQuery, Q
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
from django.utils import timezone
from django.utils.deprecation import RemovedInDjango30Warning
@@ -953,6 +953,12 @@ class QuerySet:
if lookups == (None,):
clone._prefetch_related_lookups = ()
else:
for lookup in lookups:
if isinstance(lookup, Prefetch):
lookup = lookup.prefetch_to
lookup = lookup.split(LOOKUP_SEP, 1)[0]
if lookup in self.query._filtered_relations:
raise ValueError('prefetch_related() is not supported with FilteredRelation.')
clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
return clone
@@ -984,6 +990,9 @@ class QuerySet:
if alias in names:
raise ValueError("The annotation '%s' conflicts with a field on "
"the model." % alias)
if isinstance(annotation, FilteredRelation):
clone.query.add_filtered_relation(annotation, alias)
else:
clone.query.add_annotation(annotation, alias, is_summary=False)
for alias, annotation in clone.query.annotations.items():
@@ -1060,6 +1069,10 @@ class QuerySet:
# Can only pass None to defer(), not only(), as the rest option.
# That won't stop people trying to do this, so let's be explicit.
raise TypeError("Cannot pass None as an argument to only().")
for field in fields:
field = field.split(LOOKUP_SEP, 1)[0]
if field in self.query._filtered_relations:
raise ValueError('only() is not supported with FilteredRelation.')
clone = self._chain()
clone.query.add_immediate_loading(fields)
return clone
@@ -1730,9 +1743,9 @@ class RelatedPopulator:
# model's fields.
# - related_populators: a list of RelatedPopulator instances if
# select_related() descends to related models from this model.
# - field, remote_field: the fields to use for populating the
# internal fields cache. If remote_field is set then we also
# set the reverse link.
# - local_setter, remote_setter: Methods to set cached values on
# the object being populated and on the remote object. Usually
# these are Field.set_cached_value() methods.
select_fields = klass_info['select_fields']
from_parent = klass_info['from_parent']
if not from_parent:
@@ -1751,16 +1764,8 @@ class RelatedPopulator:
self.model_cls = klass_info['model']
self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname)
self.related_populators = get_related_populators(klass_info, select, self.db)
reverse = klass_info['reverse']
field = klass_info['field']
self.remote_field = None
if reverse:
self.field = field.remote_field
self.remote_field = field
else:
self.field = field
if field.unique:
self.remote_field = field.remote_field
self.local_setter = klass_info['local_setter']
self.remote_setter = klass_info['remote_setter']
def populate(self, row, from_obj):
if self.reorder_for_init:
@@ -1774,9 +1779,9 @@ class RelatedPopulator:
if self.related_populators:
for rel_iter in self.related_populators:
rel_iter.populate(row, obj)
if self.remote_field:
self.remote_field.set_cached_value(obj, from_obj)
self.field.set_cached_value(from_obj, obj)
self.local_setter(from_obj, obj)
if obj is not None:
self.remote_setter(obj, from_obj)
def get_related_populators(klass_info, select, db):

View File

@@ -16,7 +16,7 @@ from django.utils import tree
# PathInfo is used when converting lookups (fk__somecol). The contents
# describe the relation in Model terms (model Options and Fields for both
# sides of the relation. The join_field is the field backing the relation.
PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct')
PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct filtered_relation')
class InvalidQuery(Exception):
@@ -291,3 +291,44 @@ def check_rel_lookup_compatibility(model, target_opts, field):
check(target_opts) or
(getattr(field, 'primary_key', False) and check(field.model._meta))
)
class FilteredRelation:
"""Specify custom filtering in the ON clause of SQL joins."""
def __init__(self, relation_name, *, condition=Q()):
if not relation_name:
raise ValueError('relation_name cannot be empty.')
self.relation_name = relation_name
self.alias = None
if not isinstance(condition, Q):
raise ValueError('condition argument must be a Q() instance.')
self.condition = condition
self.path = []
def __eq__(self, other):
return (
isinstance(other, self.__class__) and
self.relation_name == other.relation_name and
self.alias == other.alias and
self.condition == other.condition
)
def clone(self):
clone = FilteredRelation(self.relation_name, condition=self.condition)
clone.alias = self.alias
clone.path = self.path[:]
return clone
def resolve_expression(self, *args, **kwargs):
"""
QuerySet.annotate() only accepts expression-like arguments
(with a resolve_expression() method).
"""
raise NotImplementedError('FilteredRelation.resolve_expression() is unused.')
def as_sql(self, compiler, connection):
# Resolve the condition in Join.filtered_relation.
query = compiler.query
where = query.build_filtered_relation_q(self.condition, reuse=set(self.path))
return compiler.compile(where)

View File

@@ -702,7 +702,7 @@ class SQLCompiler:
"""
result = []
params = []
for alias in self.query.alias_map:
for alias in tuple(self.query.alias_map):
if not self.query.alias_refcount[alias]:
continue
try:
@@ -737,7 +737,7 @@ class SQLCompiler:
f.field.related_query_name()
for f in opts.related_objects if f.field.unique
)
return chain(direct_choices, reverse_choices)
return chain(direct_choices, reverse_choices, self.query._filtered_relations)
related_klass_infos = []
if not restricted and cur_depth > self.query.max_depth:
@@ -788,7 +788,8 @@ class SQLCompiler:
klass_info = {
'model': f.remote_field.model,
'field': f,
'reverse': False,
'local_setter': f.set_cached_value,
'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None,
'from_parent': False,
}
related_klass_infos.append(klass_info)
@@ -825,7 +826,8 @@ class SQLCompiler:
klass_info = {
'model': model,
'field': f,
'reverse': True,
'local_setter': f.remote_field.set_cached_value,
'remote_setter': f.set_cached_value,
'from_parent': from_parent,
}
related_klass_infos.append(klass_info)
@@ -842,6 +844,47 @@ class SQLCompiler:
next, restricted)
get_related_klass_infos(klass_info, next_klass_infos)
fields_not_found = set(requested).difference(fields_found)
for name in list(requested):
# Filtered relations work only on the topmost level.
if cur_depth > 1:
break
if name in self.query._filtered_relations:
fields_found.add(name)
f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias)
model = join_opts.model
alias = joins[-1]
from_parent = issubclass(model, opts.model) and model is not opts.model
def local_setter(obj, from_obj):
f.remote_field.set_cached_value(from_obj, obj)
def remote_setter(obj, from_obj):
setattr(from_obj, name, obj)
klass_info = {
'model': model,
'field': f,
'local_setter': local_setter,
'remote_setter': remote_setter,
'from_parent': from_parent,
}
related_klass_infos.append(klass_info)
select_fields = []
columns = self.get_default_columns(
start_alias=alias, opts=model._meta,
from_parent=opts.model,
)
for col in columns:
select_fields.append(len(select))
select.append((col, None))
klass_info['select_fields'] = select_fields
next_requested = requested.get(name, {})
next_klass_infos = self.get_related_selections(
select, opts=model._meta, root_alias=alias,
cur_depth=cur_depth + 1, requested=next_requested,
restricted=restricted,
)
get_related_klass_infos(klass_info, next_klass_infos)
fields_not_found = set(requested).difference(fields_found)
if fields_not_found:
invalid_fields = ("'%s'" % s for s in fields_not_found)
raise FieldError(

View File

@@ -41,7 +41,7 @@ class Join:
- relabeled_clone()
"""
def __init__(self, table_name, parent_alias, table_alias, join_type,
join_field, nullable):
join_field, nullable, filtered_relation=None):
# Join table
self.table_name = table_name
self.parent_alias = parent_alias
@@ -56,6 +56,7 @@ class Join:
self.join_field = join_field
# Is this join nullabled?
self.nullable = nullable
self.filtered_relation = filtered_relation
def as_sql(self, compiler, connection):
"""
@@ -85,7 +86,11 @@ class Join:
extra_sql, extra_params = compiler.compile(extra_cond)
join_conditions.append('(%s)' % extra_sql)
params.extend(extra_params)
if self.filtered_relation:
extra_sql, extra_params = compiler.compile(self.filtered_relation)
if extra_sql:
join_conditions.append('(%s)' % extra_sql)
params.extend(extra_params)
if not join_conditions:
# This might be a rel on the other end of an actual declared field.
declared_field = getattr(self.join_field, 'field', self.join_field)
@@ -101,18 +106,27 @@ class Join:
def relabeled_clone(self, change_map):
new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
new_table_alias = change_map.get(self.table_alias, self.table_alias)
if self.filtered_relation is not None:
filtered_relation = self.filtered_relation.clone()
filtered_relation.path = [change_map.get(p, p) for p in self.filtered_relation.path]
else:
filtered_relation = None
return self.__class__(
self.table_name, new_parent_alias, new_table_alias, self.join_type,
self.join_field, self.nullable)
self.join_field, self.nullable, filtered_relation=filtered_relation,
)
def __eq__(self, other):
if isinstance(other, self.__class__):
def equals(self, other, with_filtered_relation):
return (
isinstance(other, self.__class__) and
self.table_name == other.table_name and
self.parent_alias == other.parent_alias and
self.join_field == other.join_field
self.join_field == other.join_field and
(not with_filtered_relation or self.filtered_relation == other.filtered_relation)
)
return False
def __eq__(self, other):
return self.equals(other, with_filtered_relation=True)
def demote(self):
new = self.relabeled_clone({})
@@ -134,6 +148,7 @@ class BaseTable:
"""
join_type = None
parent_alias = None
filtered_relation = None
def __init__(self, table_name, alias):
self.table_name = table_name
@@ -146,3 +161,10 @@ class BaseTable:
def relabeled_clone(self, change_map):
return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias))
def equals(self, other, with_filtered_relation):
return (
isinstance(self, other.__class__) and
self.table_name == other.table_name and
self.table_alias == other.table_alias
)

View File

@@ -45,6 +45,14 @@ def get_field_names_from_opts(opts):
))
def get_children_from_q(q):
for child in q.children:
if isinstance(child, Node):
yield from get_children_from_q(child)
else:
yield child
JoinInfo = namedtuple(
'JoinInfo',
('final_field', 'targets', 'opts', 'joins', 'path')
@@ -210,6 +218,8 @@ class Query:
# load.
self.deferred_loading = (frozenset(), True)
self._filtered_relations = {}
@property
def extra(self):
if self._extra is None:
@@ -311,6 +321,7 @@ class Query:
if 'subq_aliases' in self.__dict__:
obj.subq_aliases = self.subq_aliases.copy()
obj.used_aliases = self.used_aliases.copy()
obj._filtered_relations = self._filtered_relations.copy()
# Clear the cached_property
try:
del obj.base_table
@@ -624,6 +635,8 @@ class Query:
opts = orig_opts
for name in parts[:-1]:
old_model = cur_model
if name in self._filtered_relations:
name = self._filtered_relations[name].relation_name
source = opts.get_field(name)
if is_reverse_o2o(source):
cur_model = source.related_model
@@ -684,7 +697,7 @@ class Query:
for model, values in seen.items():
callback(target, model, values)
def table_alias(self, table_name, create=False):
def table_alias(self, table_name, create=False, filtered_relation=None):
"""
Return a table alias for the given table_name and whether this is a
new alias or not.
@@ -704,8 +717,8 @@ class Query:
alias_list.append(alias)
else:
# The first occurrence of a table uses the table name directly.
alias = table_name
self.table_map[alias] = [alias]
alias = filtered_relation.alias if filtered_relation is not None else table_name
self.table_map[table_name] = [alias]
self.alias_refcount[alias] = 1
return alias, True
@@ -881,7 +894,7 @@ class Query:
"""
return len([1 for count in self.alias_refcount.values() if count])
def join(self, join, reuse=None):
def join(self, join, reuse=None, reuse_with_filtered_relation=False):
"""
Return an alias for the 'join', either reusing an existing alias for
that join or creating a new one. 'join' is either a
@@ -890,18 +903,29 @@ class Query:
The 'reuse' parameter can be either None which means all joins are
reusable, or it can be a set containing the aliases that can be reused.
The 'reuse_with_filtered_relation' parameter is used when computing
FilteredRelation instances.
A join is always created as LOUTER if the lhs alias is LOUTER to make
sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new
joins are created as LOUTER if the join is nullable.
"""
reuse = [a for a, j in self.alias_map.items()
if (reuse is None or a in reuse) and j == join]
if reuse:
self.ref_alias(reuse[0])
return reuse[0]
if reuse_with_filtered_relation and reuse:
reuse_aliases = [
a for a, j in self.alias_map.items()
if a in reuse and j.equals(join, with_filtered_relation=False)
]
else:
reuse_aliases = [
a for a, j in self.alias_map.items()
if (reuse is None or a in reuse) and j == join
]
if reuse_aliases:
self.ref_alias(reuse_aliases[0])
return reuse_aliases[0]
# No reuse is possible, so we need a new alias.
alias, _ = self.table_alias(join.table_name, create=True)
alias, _ = self.table_alias(join.table_name, create=True, filtered_relation=join.filtered_relation)
if join.join_type:
if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable:
join_type = LOUTER
@@ -1090,7 +1114,8 @@ class Query:
(name, lhs.output_field.__class__.__name__))
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
can_reuse=None, allow_joins=True, split_subq=True):
can_reuse=None, allow_joins=True, split_subq=True,
reuse_with_filtered_relation=False):
"""
Build a WhereNode for a single filter clause but don't add it
to this Query. Query.add_q() will then add this filter to the where
@@ -1112,6 +1137,9 @@ class Query:
The 'can_reuse' is a set of reusable joins for multijoins.
If 'reuse_with_filtered_relation' is True, then only joins in can_reuse
will be reused.
The method will create a filter clause that can be added to the current
query. However, if the filter isn't added to the query then the caller
is responsible for unreffing the joins used.
@@ -1147,7 +1175,10 @@ class Query:
allow_many = not branch_negated or not split_subq
try:
join_info = self.setup_joins(parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many)
join_info = self.setup_joins(
parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many,
reuse_with_filtered_relation=reuse_with_filtered_relation,
)
# Prevent iterator from being consumed by check_related_objects()
if isinstance(value, Iterator):
@@ -1250,6 +1281,41 @@ class Query:
needed_inner = joinpromoter.update_join_types(self)
return target_clause, needed_inner
def build_filtered_relation_q(self, q_object, reuse, branch_negated=False, current_negated=False):
"""Add a FilteredRelation object to the current filter."""
connector = q_object.connector
current_negated ^= q_object.negated
branch_negated = branch_negated or q_object.negated
target_clause = self.where_class(connector=connector, negated=q_object.negated)
for child in q_object.children:
if isinstance(child, Node):
child_clause = self.build_filtered_relation_q(
child, reuse=reuse, branch_negated=branch_negated,
current_negated=current_negated,
)
else:
child_clause, _ = self.build_filter(
child, can_reuse=reuse, branch_negated=branch_negated,
current_negated=current_negated,
allow_joins=True, split_subq=False,
reuse_with_filtered_relation=True,
)
target_clause.add(child_clause, connector)
return target_clause
def add_filtered_relation(self, filtered_relation, alias):
filtered_relation.alias = alias
lookups = dict(get_children_from_q(filtered_relation.condition))
for lookup in chain((filtered_relation.relation_name,), lookups):
lookup_parts, field_parts, _ = self.solve_lookup_type(lookup)
shift = 2 if not lookup_parts else 1
if len(field_parts) > (shift + len(lookup_parts)):
raise ValueError(
"FilteredRelation's condition doesn't support nested "
"relations (got %r)." % lookup
)
self._filtered_relations[filtered_relation.alias] = filtered_relation
def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False):
"""
Walk the list of names and turns them into PathInfo tuples. A single
@@ -1272,12 +1338,15 @@ class Query:
name = opts.pk.name
field = None
filtered_relation = None
try:
field = opts.get_field(name)
except FieldDoesNotExist:
if name in self.annotation_select:
field = self.annotation_select[name].output_field
elif name in self._filtered_relations and pos == 0:
filtered_relation = self._filtered_relations[name]
field = opts.get_field(filtered_relation.relation_name)
if field is not None:
# Fields that contain one-to-many relations with a generic
# model (like a GenericForeignKey) cannot generate reverse
@@ -1301,7 +1370,10 @@ class Query:
pos -= 1
if pos == -1 or fail_on_missing:
field_names = list(get_field_names_from_opts(opts))
available = sorted(field_names + list(self.annotation_select))
available = sorted(
field_names + list(self.annotation_select) +
list(self._filtered_relations)
)
raise FieldError("Cannot resolve keyword '%s' into field. "
"Choices are: %s" % (name, ", ".join(available)))
break
@@ -1315,7 +1387,7 @@ class Query:
cur_names_with_path[1].extend(path_to_parent)
opts = path_to_parent[-1].to_opts
if hasattr(field, 'get_path_info'):
pathinfos = field.get_path_info()
pathinfos = field.get_path_info(filtered_relation)
if not allow_many:
for inner_pos, p in enumerate(pathinfos):
if p.m2m:
@@ -1340,7 +1412,8 @@ class Query:
break
return path, final_field, targets, names[pos + 1:]
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
reuse_with_filtered_relation=False):
"""
Compute the necessary table joins for the passage through the fields
given in 'names'. 'opts' is the Options class for the current model
@@ -1352,6 +1425,9 @@ class Query:
that can be reused. Note that non-reverse foreign keys are always
reusable when using setup_joins().
The 'reuse_with_filtered_relation' can be used to force 'can_reuse'
parameter and force the relation on the given connections.
If 'allow_many' is False, then any reverse foreign key seen will
generate a MultiJoin exception.
@@ -1374,15 +1450,29 @@ class Query:
# joins at this stage - we will need the information about join type
# of the trimmed joins.
for join in path:
if join.filtered_relation:
filtered_relation = join.filtered_relation.clone()
table_alias = filtered_relation.alias
else:
filtered_relation = None
table_alias = None
opts = join.to_opts
if join.direct:
nullable = self.is_nullable(join.join_field)
else:
nullable = True
connection = Join(opts.db_table, alias, None, INNER, join.join_field, nullable)
reuse = can_reuse if join.m2m else None
alias = self.join(connection, reuse=reuse)
connection = Join(
opts.db_table, alias, table_alias, INNER, join.join_field,
nullable, filtered_relation=filtered_relation,
)
reuse = can_reuse if join.m2m or reuse_with_filtered_relation else None
alias = self.join(
connection, reuse=reuse,
reuse_with_filtered_relation=reuse_with_filtered_relation,
)
joins.append(alias)
if filtered_relation:
filtered_relation.path = joins[:]
return JoinInfo(final_field, targets, opts, joins, path)
def trim_joins(self, targets, joins, path):
@@ -1402,6 +1492,8 @@ class Query:
for pos, info in enumerate(reversed(path)):
if len(joins) == 1 or not info.direct:
break
if info.filtered_relation:
break
join_targets = {t.column for t in info.join_field.foreign_related_fields}
cur_targets = {t.column for t in targets}
if not cur_targets.issubset(join_targets):
@@ -1425,7 +1517,7 @@ class Query:
return self.annotation_select[name]
else:
field_list = name.split(LOOKUP_SEP)
join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), reuse)
join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse)
targets, _, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path)
if len(targets) > 1:
raise FieldError("Referencing multicolumn fields with F() objects "
@@ -1602,7 +1694,10 @@ class Query:
# from the model on which the lookup failed.
raise
else:
names = sorted(list(get_field_names_from_opts(opts)) + list(self.extra) + list(self.annotation_select))
names = sorted(
list(get_field_names_from_opts(opts)) + list(self.extra) +
list(self.annotation_select) + list(self._filtered_relations)
)
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))

View File

@@ -3318,3 +3318,60 @@ lookups or :class:`Prefetch` objects you want to prefetch for. For example::
>>> from django.db.models import prefetch_related_objects
>>> restaurants = fetch_top_restaurants_from_cache() # A list of Restaurants
>>> prefetch_related_objects(restaurants, 'pizzas__toppings')
``FilteredRelation()`` objects
------------------------------
.. versionadded:: 2.0
.. class:: FilteredRelation(relation_name, *, condition=Q())
.. attribute:: FilteredRelation.relation_name
The name of the field on which you'd like to filter the relation.
.. attribute:: FilteredRelation.condition
A :class:`~django.db.models.Q` object to control the filtering.
``FilteredRelation`` is used with :meth:`~.QuerySet.annotate()` to create an
``ON`` clause when a ``JOIN`` is performed. It doesn't act on the default
relationship but on the annotation name (``pizzas_vegetarian`` in example
below).
For example, to find restaurants that have vegetarian pizzas with
``'mozzarella'`` in the name::
>>> from django.db.models import FilteredRelation, Q
>>> Restaurant.objects.annotate(
... pizzas_vegetarian=FilteredRelation(
... 'pizzas', condition=Q(pizzas__vegetarian=True),
... ),
... ).filter(pizzas_vegetarian__name__icontains='mozzarella')
If there are a large number of pizzas, this queryset performs better than::
>>> Restaurant.objects.filter(
... pizzas__vegetarian=True,
... pizzas__name__icontains='mozzarella',
... )
because the filtering in the ``WHERE`` clause of the first queryset will only
operate on vegetarian pizzas.
``FilteredRelation`` doesn't support:
* Conditions that span relational fields. For example::
>>> Restaurant.objects.annotate(
... pizzas_with_toppings_startswith_n=FilteredRelation(
... 'pizzas__toppings',
... condition=Q(pizzas__toppings__name__startswith='n'),
... ),
... )
Traceback (most recent call last):
...
ValueError: FilteredRelation's condition doesn't support nested relations (got 'pizzas__toppings__name__startswith').
* :meth:`.QuerySet.only` and :meth:`~.QuerySet.prefetch_related`.
* A :class:`~django.contrib.contenttypes.fields.GenericForeignKey`
inherited from a parent model.

View File

@@ -354,6 +354,9 @@ Models
* The new ``named`` parameter of :meth:`.QuerySet.values_list` allows fetching
results as named tuples.
* The new :class:`.FilteredRelation` class allows adding an ``ON`` clause to
querysets.
Pagination
~~~~~~~~~~

View File

View File

@@ -0,0 +1,108 @@
from django.contrib.contenttypes.fields import (
GenericForeignKey, GenericRelation,
)
from django.contrib.contenttypes.models import ContentType
from django.db import models
class Author(models.Model):
name = models.CharField(max_length=50, unique=True)
favorite_books = models.ManyToManyField(
'Book',
related_name='preferred_by_authors',
related_query_name='preferred_by_authors',
)
content_type = models.ForeignKey(ContentType, models.CASCADE, null=True)
object_id = models.PositiveIntegerField(null=True)
content_object = GenericForeignKey()
def __str__(self):
return self.name
class Editor(models.Model):
name = models.CharField(max_length=255)
def __str__(self):
return self.name
class Book(models.Model):
AVAILABLE = 'available'
RESERVED = 'reserved'
RENTED = 'rented'
STATES = (
(AVAILABLE, 'Available'),
(RESERVED, 'reserved'),
(RENTED, 'Rented'),
)
title = models.CharField(max_length=255)
author = models.ForeignKey(
Author,
models.CASCADE,
related_name='books',
related_query_name='book',
)
editor = models.ForeignKey(Editor, models.CASCADE)
generic_author = GenericRelation(Author)
state = models.CharField(max_length=9, choices=STATES, default=AVAILABLE)
def __str__(self):
return self.title
class Borrower(models.Model):
name = models.CharField(max_length=50, unique=True)
def __str__(self):
return self.name
class Reservation(models.Model):
NEW = 'new'
STOPPED = 'stopped'
STATES = (
(NEW, 'New'),
(STOPPED, 'Stopped'),
)
borrower = models.ForeignKey(
Borrower,
models.CASCADE,
related_name='reservations',
related_query_name='reservation',
)
book = models.ForeignKey(
Book,
models.CASCADE,
related_name='reservations',
related_query_name='reservation',
)
state = models.CharField(max_length=7, choices=STATES, default=NEW)
def __str__(self):
return '-'.join((self.book.name, self.borrower.name, self.state))
class RentalSession(models.Model):
NEW = 'new'
STOPPED = 'stopped'
STATES = (
(NEW, 'New'),
(STOPPED, 'Stopped'),
)
borrower = models.ForeignKey(
Borrower,
models.CASCADE,
related_name='rental_sessions',
related_query_name='rental_session',
)
book = models.ForeignKey(
Book,
models.CASCADE,
related_name='rental_sessions',
related_query_name='rental_session',
)
state = models.CharField(max_length=7, choices=STATES, default=NEW)
def __str__(self):
return '-'.join((self.book.name, self.borrower.name, self.state))

View File

@@ -0,0 +1,381 @@
from django.db import connection
from django.db.models import Case, Count, F, FilteredRelation, Q, When
from django.test import TestCase
from django.test.testcases import skipUnlessDBFeature
from .models import Author, Book, Borrower, Editor, RentalSession, Reservation
class FilteredRelationTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.author1 = Author.objects.create(name='Alice')
cls.author2 = Author.objects.create(name='Jane')
cls.editor_a = Editor.objects.create(name='a')
cls.editor_b = Editor.objects.create(name='b')
cls.book1 = Book.objects.create(
title='Poem by Alice',
editor=cls.editor_a,
author=cls.author1,
)
cls.book1.generic_author.set([cls.author2])
cls.book2 = Book.objects.create(
title='The book by Jane A',
editor=cls.editor_b,
author=cls.author2,
)
cls.book3 = Book.objects.create(
title='The book by Jane B',
editor=cls.editor_b,
author=cls.author2,
)
cls.book4 = Book.objects.create(
title='The book by Alice',
editor=cls.editor_a,
author=cls.author1,
)
cls.author1.favorite_books.add(cls.book2)
cls.author1.favorite_books.add(cls.book3)
def test_select_related(self):
qs = Author.objects.annotate(
book_join=FilteredRelation('book'),
).select_related('book_join__editor').order_by('pk', 'book_join__pk')
with self.assertNumQueries(1):
self.assertQuerysetEqual(qs, [
(self.author1, self.book1, self.editor_a, self.author1),
(self.author1, self.book4, self.editor_a, self.author1),
(self.author2, self.book2, self.editor_b, self.author2),
(self.author2, self.book3, self.editor_b, self.author2),
], lambda x: (x, x.book_join, x.book_join.editor, x.book_join.author))
def test_select_related_foreign_key(self):
qs = Book.objects.annotate(
author_join=FilteredRelation('author'),
).select_related('author_join').order_by('pk')
with self.assertNumQueries(1):
self.assertQuerysetEqual(qs, [
(self.book1, self.author1),
(self.book2, self.author2),
(self.book3, self.author2),
(self.book4, self.author1),
], lambda x: (x, x.author_join))
def test_without_join(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
),
[self.author1, self.author2]
)
def test_with_join(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False),
[self.author1]
)
def test_with_join_and_complex_condition(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation(
'book', condition=Q(
Q(book__title__iexact='poem by alice') |
Q(book__state=Book.RENTED)
),
),
).filter(book_alice__isnull=False),
[self.author1]
)
def test_internal_queryset_alias_mapping(self):
queryset = Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False)
self.assertIn(
'INNER JOIN {} book_alice ON'.format(connection.ops.quote_name('filtered_relation_book')),
str(queryset.query)
)
def test_with_multiple_filter(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_editor_a=FilteredRelation(
'book',
condition=Q(book__title__icontains='book', book__editor_id=self.editor_a.pk),
),
).filter(book_editor_a__isnull=False),
[self.author1]
)
def test_multiple_times(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_title_alice=FilteredRelation('book', condition=Q(book__title__icontains='alice')),
).filter(book_title_alice__isnull=False).filter(book_title_alice__isnull=False).distinct(),
[self.author1]
)
def test_exclude_relation_with_join(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=~Q(book__title__icontains='alice')),
).filter(book_alice__isnull=False).distinct(),
[self.author2]
)
def test_with_m2m(self):
qs = Author.objects.annotate(
favorite_books_written_by_jane=FilteredRelation(
'favorite_books', condition=Q(favorite_books__in=[self.book2]),
),
).filter(favorite_books_written_by_jane__isnull=False)
self.assertSequenceEqual(qs, [self.author1])
def test_with_m2m_deep(self):
qs = Author.objects.annotate(
favorite_books_written_by_jane=FilteredRelation(
'favorite_books', condition=Q(favorite_books__author=self.author2),
),
).filter(favorite_books_written_by_jane__title='The book by Jane B')
self.assertSequenceEqual(qs, [self.author1])
def test_with_m2m_multijoin(self):
qs = Author.objects.annotate(
favorite_books_written_by_jane=FilteredRelation(
'favorite_books', condition=Q(favorite_books__author=self.author2),
)
).filter(favorite_books_written_by_jane__editor__name='b').distinct()
self.assertSequenceEqual(qs, [self.author1])
def test_values_list(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False).values_list('book_alice__title', flat=True),
['Poem by Alice']
)
def test_values(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False).values(),
[{'id': self.author1.pk, 'name': 'Alice', 'content_type_id': None, 'object_id': None}]
)
def test_extra(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False).extra(where=['1 = 1']),
[self.author1]
)
@skipUnlessDBFeature('supports_select_union')
def test_union(self):
qs1 = Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False)
qs2 = Author.objects.annotate(
book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')),
).filter(book_jane__isnull=False)
self.assertSequenceEqual(qs1.union(qs2), [self.author1, self.author2])
@skipUnlessDBFeature('supports_select_intersection')
def test_intersection(self):
qs1 = Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False)
qs2 = Author.objects.annotate(
book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')),
).filter(book_jane__isnull=False)
self.assertSequenceEqual(qs1.intersection(qs2), [])
@skipUnlessDBFeature('supports_select_difference')
def test_difference(self):
qs1 = Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False)
qs2 = Author.objects.annotate(
book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')),
).filter(book_jane__isnull=False)
self.assertSequenceEqual(qs1.difference(qs2), [self.author1])
def test_select_for_update(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')),
).filter(book_jane__isnull=False).select_for_update(),
[self.author2]
)
def test_defer(self):
# One query for the list and one query for the deferred title.
with self.assertNumQueries(2):
self.assertQuerysetEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False).select_related('book_alice').defer('book_alice__title'),
['Poem by Alice'], lambda author: author.book_alice.title
)
def test_only_not_supported(self):
msg = 'only() is not supported with FilteredRelation.'
with self.assertRaisesMessage(ValueError, msg):
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False).select_related('book_alice').only('book_alice__state')
def test_as_subquery(self):
inner_qs = Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False)
qs = Author.objects.filter(id__in=inner_qs)
self.assertSequenceEqual(qs, [self.author1])
def test_with_foreign_key_error(self):
msg = (
"FilteredRelation's condition doesn't support nested relations "
"(got 'author__favorite_books__author')."
)
with self.assertRaisesMessage(ValueError, msg):
list(Book.objects.annotate(
alice_favorite_books=FilteredRelation(
'author__favorite_books',
condition=Q(author__favorite_books__author=self.author1),
)
))
def test_with_foreign_key_on_condition_error(self):
msg = (
"FilteredRelation's condition doesn't support nested relations "
"(got 'book__editor__name__icontains')."
)
with self.assertRaisesMessage(ValueError, msg):
list(Author.objects.annotate(
book_edited_by_b=FilteredRelation('book', condition=Q(book__editor__name__icontains='b')),
))
def test_with_empty_relation_name_error(self):
with self.assertRaisesMessage(ValueError, 'relation_name cannot be empty.'):
FilteredRelation('', condition=Q(blank=''))
def test_with_condition_as_expression_error(self):
msg = 'condition argument must be a Q() instance.'
expression = Case(
When(book__title__iexact='poem by alice', then=True), default=False,
)
with self.assertRaisesMessage(ValueError, msg):
FilteredRelation('book', condition=expression)
def test_with_prefetch_related(self):
msg = 'prefetch_related() is not supported with FilteredRelation.'
qs = Author.objects.annotate(
book_title_contains_b=FilteredRelation('book', condition=Q(book__title__icontains='b')),
).filter(
book_title_contains_b__isnull=False,
)
with self.assertRaisesMessage(ValueError, msg):
qs.prefetch_related('book_title_contains_b')
with self.assertRaisesMessage(ValueError, msg):
qs.prefetch_related('book_title_contains_b__editor')
def test_with_generic_foreign_key(self):
self.assertSequenceEqual(
Book.objects.annotate(
generic_authored_book=FilteredRelation(
'generic_author',
condition=Q(generic_author__isnull=False)
),
).filter(generic_authored_book__isnull=False),
[self.book1]
)
class FilteredRelationAggregationTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.author1 = Author.objects.create(name='Alice')
cls.editor_a = Editor.objects.create(name='a')
cls.book1 = Book.objects.create(
title='Poem by Alice',
editor=cls.editor_a,
author=cls.author1,
)
cls.borrower1 = Borrower.objects.create(name='Jenny')
cls.borrower2 = Borrower.objects.create(name='Kevin')
# borrower 1 reserves, rents, and returns book1.
Reservation.objects.create(
borrower=cls.borrower1,
book=cls.book1,
state=Reservation.STOPPED,
)
RentalSession.objects.create(
borrower=cls.borrower1,
book=cls.book1,
state=RentalSession.STOPPED,
)
# borrower2 reserves, rents, and returns book1.
Reservation.objects.create(
borrower=cls.borrower2,
book=cls.book1,
state=Reservation.STOPPED,
)
RentalSession.objects.create(
borrower=cls.borrower2,
book=cls.book1,
state=RentalSession.STOPPED,
)
def test_aggregate(self):
"""
filtered_relation() not only improves performance but also creates
correct results when aggregating with multiple LEFT JOINs.
Books can be reserved then rented by a borrower. Each reservation and
rental session are recorded with Reservation and RentalSession models.
Every time a reservation or a rental session is over, their state is
changed to 'stopped'.
Goal: Count number of books that are either currently reserved or
rented by borrower1 or available.
"""
qs = Book.objects.annotate(
is_reserved_or_rented_by=Case(
When(reservation__state=Reservation.NEW, then=F('reservation__borrower__pk')),
When(rental_session__state=RentalSession.NEW, then=F('rental_session__borrower__pk')),
default=None,
)
).filter(
Q(is_reserved_or_rented_by=self.borrower1.pk) | Q(state=Book.AVAILABLE)
).distinct()
self.assertEqual(qs.count(), 1)
# If count is equal to 1, the same aggregation should return in the
# same result but it returns 4.
self.assertSequenceEqual(qs.annotate(total=Count('pk')).values('total'), [{'total': 4}])
# With FilteredRelation, the result is as expected (1).
qs = Book.objects.annotate(
active_reservations=FilteredRelation(
'reservation', condition=Q(
reservation__state=Reservation.NEW,
reservation__borrower=self.borrower1,
)
),
).annotate(
active_rental_sessions=FilteredRelation(
'rental_session', condition=Q(
rental_session__state=RentalSession.NEW,
rental_session__borrower=self.borrower1,
)
),
).filter(
(Q(active_reservations__isnull=False) | Q(active_rental_sessions__isnull=False)) |
Q(state=Book.AVAILABLE)
).distinct()
self.assertEqual(qs.count(), 1)
self.assertSequenceEqual(qs.annotate(total=Count('pk')).values('total'), [{'total': 1}])

View File

@@ -53,15 +53,31 @@ class StartsWithRelation(models.ForeignObject):
def get_joining_columns(self, reverse_join=False):
return ()
def get_path_info(self):
def get_path_info(self, filtered_relation=None):
to_opts = self.remote_field.model._meta
from_opts = self.model._meta
return [PathInfo(from_opts, to_opts, (to_opts.pk,), self, False, False)]
return [PathInfo(
from_opts=from_opts,
to_opts=to_opts,
target_fields=(to_opts.pk,),
join_field=self,
m2m=False,
direct=False,
filtered_relation=filtered_relation,
)]
def get_reverse_path_info(self):
def get_reverse_path_info(self, filtered_relation=None):
to_opts = self.model._meta
from_opts = self.remote_field.model._meta
return [PathInfo(from_opts, to_opts, (to_opts.pk,), self.remote_field, False, False)]
return [PathInfo(
from_opts=from_opts,
to_opts=to_opts,
target_fields=(to_opts.pk,),
join_field=self.remote_field,
m2m=False,
direct=False,
filtered_relation=filtered_relation,
)]
def contribute_to_class(self, cls, name, private_only=False):
super().contribute_to_class(cls, name, private_only)

View File

@@ -1,4 +1,5 @@
from django.core.exceptions import FieldError
from django.db.models import FilteredRelation
from django.test import SimpleTestCase, TestCase
from .models import (
@@ -230,3 +231,8 @@ class ReverseSelectRelatedValidationTests(SimpleTestCase):
with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)):
list(User.objects.select_related('username'))
def test_reverse_related_validation_with_filtered_relation(self):
fields = 'userprofile, userstat, relation'
with self.assertRaisesMessage(FieldError, self.invalid_error % ('foobar', fields)):
list(User.objects.annotate(relation=FilteredRelation('userprofile')).select_related('foobar'))