1
0
mirror of https://github.com/django/django.git synced 2025-07-06 10:49:17 +00:00

queryset-refactor: Optimisation pass. The test suite is now within 2% of trunk and it's a fairly pathological case. Introduces a couple of test failures due to some simplification in the code. They'll be fixed later.

git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@6730 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Malcolm Tredinnick 2007-11-29 04:56:09 +00:00
parent a97abcffc2
commit b43a018032
12 changed files with 350 additions and 286 deletions

View File

@ -11,16 +11,18 @@ if not settings.DATABASE_ENGINE:
settings.DATABASE_ENGINE = 'dummy'
try:
# Most of the time, the database backend will be one of the official
# Most of the time, the database backend will be one of the official
# backends that ships with Django, so look there first.
_import_path = 'django.db.backends.'
backend = __import__('%s%s.base' % (_import_path, settings.DATABASE_ENGINE), {}, {}, [''])
creation = __import__('%s%s.creation' % (_import_path, settings.DATABASE_ENGINE), {}, {}, [''])
except ImportError, e:
# If the import failed, we might be looking for a database backend
# If the import failed, we might be looking for a database backend
# distributed external to Django. So we'll try that next.
try:
_import_path = ''
backend = __import__('%s.base' % settings.DATABASE_ENGINE, {}, {}, [''])
creation = __import__('%s.creation' % settings.DATABASE_ENGINE, {}, {}, [''])
except ImportError, e_user:
# The database backend wasn't found. Display a helpful error message
# listing all possible (built-in) database backends.
@ -37,10 +39,12 @@ def _import_database_module(import_path='', module_name=''):
"""Lazyily import a database module when requested."""
return __import__('%s%s.%s' % (_import_path, settings.DATABASE_ENGINE, module_name), {}, {}, [''])
# We don't want to import the introspect/creation modules unless
# someone asks for 'em, so lazily load them on demmand.
# We don't want to import the introspect module unless someone asks for it, so
# lazily load it on demmand.
get_introspection_module = curry(_import_database_module, _import_path, 'introspection')
get_creation_module = curry(_import_database_module, _import_path, 'creation')
def get_creation_module():
return creation
# We want runshell() to work the same way, but we have to treat it a
# little differently (since it just runs instead of returning a module like

View File

@ -790,8 +790,11 @@ class ManyToOneRel(object):
self.multiple = True
def get_related_field(self):
"Returns the Field in the 'to' object to which this relationship is tied."
return self.to._meta.get_field(self.field_name)
"""
Returns the Field in the 'to' object to which this relationship is
tied.
"""
return self.to._meta.get_field_by_name(self.field_name, True)[0]
class OneToOneRel(ManyToOneRel):
def __init__(self, to, field_name, num_in_admin=0, edit_inline=False,

View File

@ -93,7 +93,8 @@ class Options(object):
def add_field(self, field):
# Insert the given field in the order in which it was created, using
# the "creation_counter" attribute of the field.
# Move many-to-many related fields from self.fields into self.many_to_many.
# Move many-to-many related fields from self.fields into
# self.many_to_many.
if field.rel and isinstance(field.rel, ManyToManyRel):
self.many_to_many.insert(bisect(self.many_to_many, field), field)
else:
@ -129,6 +130,58 @@ class Options(object):
return f
raise FieldDoesNotExist, '%s has no field named %r' % (self.object_name, name)
def get_field_by_name(self, name, only_direct=False):
"""
Returns the (field_object, direct, m2m), where field_object is the
Field instance for the given name, direct is True if the field exists
on this model, and m2m is True for many-to-many relations. When
'direct' is False, 'field_object' is the corresponding RelatedObject
for this field (since the field doesn't have an instance associated
with it).
If 'only_direct' is True, only forwards relations (and non-relations)
are considered in the result.
Uses a cache internally, so after the first access, this is very fast.
"""
try:
result = self._name_map.get(name)
except AttributeError:
cache = self.init_name_map()
result = cache.get(name)
if not result or (not result[1] and only_direct):
raise FieldDoesNotExist('%s has no field named %r'
% (self.object_name, name))
return result
def get_all_field_names(self):
"""
Returns a list of all field names that are possible for this model
(including reverse relation names).
"""
try:
cache = self._name_map
except AttributeError:
cache = self.init_name_map()
names = cache.keys()
names.sort()
return names
def init_name_map(self):
"""
Initialises the field name -> field object mapping.
"""
cache = dict([(f.name, (f, True, False)) for f in self.fields])
cache.update([(f.name, (f, True, True)) for f in self.many_to_many])
cache.update([(f.field.related_query_name(), (f, False, True))
for f in self.get_all_related_many_to_many_objects()])
cache.update([(f.field.related_query_name(), (f, False, False))
for f in self.get_all_related_objects()])
if app_cache_ready():
self._name_map = cache
return cache
def get_add_permission(self):
return 'add_%s' % self.object_name.lower()

View File

@ -24,9 +24,9 @@ CHUNK_SIZE = 100
class _QuerySet(object):
"Represents a lazy database lookup for a set of objects"
def __init__(self, model=None):
def __init__(self, model=None, query=None):
self.model = model
self.query = sql.Query(self.model, connection)
self.query = query or sql.Query(self.model, connection)
self._result_cache = None
########################
@ -338,7 +338,7 @@ class _QuerySet(object):
if tables:
clone.query.extra_tables.extend(tables)
if order_by:
clone.query.extra_order_by.extend(order_by)
clone.query.extra_order_by = order_by
return clone
###################
@ -348,9 +348,7 @@ class _QuerySet(object):
def _clone(self, klass=None, setup=False, **kwargs):
if klass is None:
klass = self.__class__
c = klass()
c.model = self.model
c.query = self.query.clone()
c = klass(model=self.model, query=self.query.clone())
c.__dict__.update(kwargs)
if setup and hasattr(c, '_setup_query'):
c._setup_query()
@ -460,8 +458,8 @@ class DateQuerySet(QuerySet):
return c
class EmptyQuerySet(QuerySet):
def __init__(self, model=None):
super(EmptyQuerySet, self).__init__(model)
def __init__(self, model=None, query=None):
super(EmptyQuerySet, self).__init__(model, query)
self._result_cache = []
def count(self):

View File

@ -12,9 +12,11 @@ import re
from django.utils.tree import Node
from django.utils.datastructures import SortedDict
from django.dispatch import dispatcher
from django.db.models import signals
from django.db.models.sql.where import WhereNode, AND, OR
from django.db.models.sql.datastructures import Count, Date
from django.db.models.fields import FieldDoesNotExist, Field
from django.db.models.fields import FieldDoesNotExist, Field, related
from django.contrib.contenttypes import generic
from datastructures import EmptyResultSet
@ -49,7 +51,6 @@ RHS_JOIN_COL = 5
ALIAS_TABLE = 0
ALIAS_REFCOUNT = 1
ALIAS_JOIN = 2
ALIAS_MERGE_SEP = 3
# How many results to expect from a cursor.execute call
MULTI = 'multi'
@ -57,6 +58,12 @@ SINGLE = 'single'
NONE = None
ORDER_PATTERN = re.compile(r'\?|[-+]?\w+$')
ORDER_DIR = {
'ASC': ('ASC', 'DESC'),
'DESC': ('DESC', 'ASC')}
class Empty(object):
pass
class Query(object):
"""
@ -76,12 +83,13 @@ class Query(object):
self.table_map = {} # Maps table names to list of aliases.
self.join_map = {} # Maps join_tuple to list of aliases.
self.rev_join_map = {} # Reverse of join_map.
self.quote_cache = {}
self.default_cols = True
# SQL-related attributes
self.select = []
self.tables = [] # Aliases in the order they are created.
self.where = WhereNode(self)
self.where = WhereNode()
self.group_by = []
self.having = []
self.order_by = []
@ -118,29 +126,35 @@ class Query(object):
for table names. This avoids problems with some SQL dialects that treat
quoted strings specially (e.g. PostgreSQL).
"""
if name != self.alias_map.get(name, [name])[0]:
if name in self.quote_cache:
return self.quote_cache[name]
if name in self.alias_map and name not in self.table_map:
self.quote_cache[name] = name
return name
return self.connection.ops.quote_name(name)
r = self.connection.ops.quote_name(name)
self.quote_cache[name] = r
return r
def clone(self, klass=None, **kwargs):
"""
Creates a copy of the current instance. The 'kwargs' parameter can be
used by clients to update attributes after copying has taken place.
"""
if not klass:
klass = self.__class__
obj = klass(self.model, self.connection)
obj.table_map = self.table_map.copy()
obj = Empty()
obj.__class__ = klass or self.__class__
obj.model = self.model
obj.connection = self.connection
obj.alias_map = copy.deepcopy(self.alias_map)
obj.table_map = self.table_map.copy()
obj.join_map = copy.deepcopy(self.join_map)
obj.rev_join_map = copy.deepcopy(self.rev_join_map)
obj.quote_cache = {}
obj.default_cols = self.default_cols
obj.select = self.select[:]
obj.tables = self.tables[:]
obj.where = copy.deepcopy(self.where)
obj.where.query = obj
obj.having = self.having[:]
obj.group_by = self.group_by[:]
obj.having = self.having[:]
obj.order_by = self.order_by[:]
obj.low_mark, obj.high_mark = self.low_mark, self.high_mark
obj.distinct = self.distinct
@ -175,7 +189,7 @@ class Query(object):
obj.clear_limits()
obj.select_related = False
if obj.distinct and len(obj.select) > 1:
obj = self.clone(CountQuery, _query=obj, where=WhereNode(self),
obj = self.clone(CountQuery, _query=obj, where=WhereNode(),
distinct=False)
obj.add_count_column()
data = obj.execute_sql(SINGLE)
@ -205,7 +219,7 @@ class Query(object):
# This must come after 'select' and 'ordering' -- see docstring of
# get_from_clause() for details.
from_, f_params = self.get_from_clause()
where, w_params = self.where.as_sql()
where, w_params = self.where.as_sql(qn=self.quote_name_unless_alias)
result = ['SELECT']
if self.distinct:
@ -262,28 +276,28 @@ class Query(object):
# Work out how to relabel the rhs aliases, if necessary.
change_map = {}
used = {}
first_new_join = True
conjunction = (connection == AND)
first = True
for alias in rhs.tables:
if not rhs.alias_map[alias][ALIAS_REFCOUNT]:
# An unused alias.
continue
promote = (rhs.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] ==
self.LOUTER)
merge_separate = (connection == AND)
new_alias = self.join(rhs.rev_join_map[alias], exclusions=used,
promote=promote, outer_if_first=True,
merge_separate=merge_separate)
if self.alias_map[alias][ALIAS_REFCOUNT] == 1:
first_new_join = False
new_alias = self.join(rhs.rev_join_map[alias],
(conjunction and not first), used, promote, not conjunction)
used[new_alias] = None
change_map[alias] = new_alias
first = False
# So that we don't exclude valid results, the first join that is
# exclusive to the lhs (self) must be converted to an outer join.
for alias in self.tables[1:]:
if self.alias_map[alias][ALIAS_REFCOUNT] == 1:
self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER
break
# So that we don't exclude valid results in an "or" query combination,
# the first join that is exclusive to the lhs (self) must be converted
# to an outer join.
if not conjunction:
for alias in self.tables[1:]:
if self.alias_map[alias][ALIAS_REFCOUNT] == 1:
self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER
break
# Now relabel a copy of the rhs where-clause and add it to the current
# one.
@ -297,16 +311,16 @@ class Query(object):
# results.
alias = self.join((None, self.model._meta.db_table, None, None))
pk = self.model._meta.pk
self.where.add((alias, pk.column, pk, 'isnull', False), AND)
self.where.add([alias, pk.column, pk, 'isnull', False], AND)
elif self.where:
# rhs has an empty where clause. Make it match everything (see
# above for reasoning).
w = WhereNode(self)
w = WhereNode()
alias = self.join((None, self.model._meta.db_table, None, None))
pk = self.model._meta.pk
w.add((alias, pk.column, pk, 'isnull', False), AND)
w.add([alias, pk.column, pk, 'isnull', False], AND)
else:
w = WhereNode(self)
w = WhereNode()
self.where.add(w, connection)
# Selection columns and extra extensions are those provided by 'rhs'.
@ -400,7 +414,7 @@ class Query(object):
if join_type:
result.append('%s %s%s ON (%s.%s = %s.%s)'
% (join_type, qn(name), alias_str, qn(lhs),
qn(lhs_col), qn(alias), qn(col)))
qn(lhs_col), qn(alias), qn(col)))
else:
connector = not first and ', ' or ''
result.append('%s%s%s' % (connector, qn(name), alias_str))
@ -472,7 +486,7 @@ class Query(object):
result.append('%s %s' % (elt, order))
elif get_order_dir(field)[0] not in self.extra_select:
# 'col' is of the form 'field' or 'field1__field2' or
# 'field1__field2__field', etc.
# '-field1__field2__field', etc.
for table, col, order in self.find_ordering_name(field,
self.model._meta):
elt = '%s.%s' % (qn(table), qn(col))
@ -495,16 +509,14 @@ class Query(object):
pieces = name.split(LOOKUP_SEP)
if not alias:
alias = self.join((None, opts.db_table, None, None))
for elt in pieces:
joins, opts, unused1, field, col, unused2 = \
self.get_next_join(elt, opts, alias, False)
if joins:
alias = joins[-1]
col = col or field.column
field, target, opts, joins, unused2 = self.setup_joins(pieces, opts,
alias, False)
alias = joins[-1][-1]
col = target.column
# If we get to this point and the field is a relation to another model,
# append the default ordering for that model.
if joins and opts.ordering:
if len(joins) > 1 and opts.ordering:
results = []
for item in opts.ordering:
results.extend(self.find_ordering_name(item, opts, alias,
@ -559,8 +571,7 @@ class Query(object):
self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER
def join(self, (lhs, table, lhs_col, col), always_create=False,
exclusions=(), promote=False, outer_if_first=False,
merge_separate=False):
exclusions=(), promote=False, outer_if_first=False):
"""
Returns an alias for a join between 'table' and 'lhs' on the given
columns, either reusing an existing alias for that join or creating a
@ -581,44 +592,44 @@ class Query(object):
If 'outer_if_first' is True and a new join is created, it will have the
LOUTER join type. This is used when joining certain types of querysets
and Q-objects together.
If the 'merge_separate' parameter is True, we create a new alias if we
would otherwise reuse an alias that also had 'merge_separate' set to
True when it was created.
"""
if lhs not in self.alias_map:
if lhs is None:
lhs_table = None
is_table = False
elif lhs not in self.alias_map:
lhs_table = lhs
is_table = (lhs is not None)
is_table = True
else:
lhs_table = self.alias_map[lhs][ALIAS_TABLE]
is_table = False
t_ident = (lhs_table, table, lhs_col, col)
aliases = self.join_map.get(t_ident)
if aliases and not always_create:
for alias in aliases:
if (alias not in exclusions and
not (merge_separate and
self.alias_map[alias][ALIAS_MERGE_SEP])):
self.ref_alias(alias)
if promote:
self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = \
self.LOUTER
return alias
# If we get to here (no non-excluded alias exists), we'll fall
# through to creating a new alias.
if not always_create:
aliases = self.join_map.get(t_ident)
if aliases:
for alias in aliases:
if alias not in exclusions:
self.ref_alias(alias)
if promote:
self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = \
self.LOUTER
return alias
# If we get to here (no non-excluded alias exists), we'll fall
# through to creating a new alias.
# No reuse is possible, so we need a new alias.
assert not is_table, \
"Must pass in lhs alias when creating a new join."
alias, _ = self.table_alias(table, True)
join_type = (promote or outer_if_first) and self.LOUTER or self.INNER
if promote or outer_if_first:
join_type = self.LOUTER
else:
join_type = self.INNER
join = [table, alias, join_type, lhs, lhs_col, col]
if not lhs:
# Not all tables need to be joined to anything. No join type
# means the later columns are ignored.
join[JOIN_TYPE] = None
self.alias_map[alias][ALIAS_JOIN] = join
self.alias_map[alias][ALIAS_MERGE_SEP] = merge_separate
self.join_map.setdefault(t_ident, []).append(alias)
self.rev_join_map[alias] = t_ident
return alias
@ -677,51 +688,18 @@ class Query(object):
opts = self.model._meta
alias = self.join((None, opts.db_table, None, None))
dupe_multis = (connection == AND)
join_list = []
split = not self.where
null_point = None
# FIXME: Using enumerate() here is expensive. We only need 'i' to
# check we aren't joining against a non-joinable field. Find a
# better way to do this!
for i, name in enumerate(parts):
joins, opts, orig_field, target_field, target_col, nullable = \
self.get_next_join(name, opts, alias, dupe_multis)
if name == 'pk':
name = target_field.name
if joins is not None:
if null_point is None and nullable:
null_point = len(join_list)
join_list.append(joins)
alias = joins[-1]
if connection == OR and not split:
# FIXME: Document what's going on and why this is needed.
if self.alias_map[joins[0]][ALIAS_REFCOUNT] == 1:
split = True
self.promote_alias(joins[0])
all_aliases = []
for a in join_list:
all_aliases.extend(a)
for t in self.tables[1:]:
if t in all_aliases:
continue
self.promote_alias(t)
break
else:
# Normal field lookup must be the last field in the filter.
if i != len(parts) - 1:
raise TypeError("Join on field %r not permitted."
% name)
col = target_col or target_field.column
field, target, unused, join_list, nullable = self.setup_joins(parts,
opts, alias, (connection == AND))
col = target.column
alias = join_list[-1][-1]
if join_list:
# An optimization: if the final join is against the same column as
# we are comparing against, we can go back one step in the join
# chain and compare against the lhs of the join instead. The result
# (potentially) involves one less table join.
join = self.alias_map[join_list[-1][-1]][ALIAS_JOIN]
join = self.alias_map[alias][ALIAS_JOIN]
if col == join[RHS_JOIN_COL]:
self.unref_alias(alias)
alias = join[LHS_ALIAS]
@ -734,17 +712,18 @@ class Query(object):
# efficient at the database level.
self.promote_alias(join_list[-1][0])
self.where.add([alias, col, orig_field, lookup_type, value],
connection)
self.where.add([alias, col, field, lookup_type, value], connection)
if negate:
if join_list and null_point is not None:
for elt in join_list[null_point:]:
for join in elt:
self.promote_alias(join)
self.where.negate()
self.where.add([alias, col, orig_field, 'isnull', True], OR)
else:
self.where.negate()
flag = False
for pos, null in enumerate(nullable):
if not null:
continue
flag = True
for join in join_list[pos]:
self.promote_alias(join)
self.where.negate()
if flag:
self.where.add([alias, col, field, 'isnull', True], OR)
def add_q(self, q_object):
"""
@ -765,80 +744,129 @@ class Query(object):
else:
self.add_filter(child, q_object.connection, q_object.negated)
def get_next_join(self, name, opts, root_alias, dupe_multis):
def setup_joins(self, names, opts, alias, dupe_multis):
"""
Compute the necessary table joins for the field called 'name'. 'opts'
is the Options class for the current model (which gives the table we
are joining to), root_alias is the alias for the table we are joining
to. If dupe_multis is True, any many-to-many or many-to-one joins will
always create a new alias (necessary for disjunctive filters).
Compute the necessary table joins for the passage through the fields
given in 'names'. 'opts' is the Options class for the current model
(which gives the table we are joining to), 'alias' is the alias for the
table we are joining to. If dupe_multis is True, any many-to-many or
many-to-one joins will always create a new alias (necessary for
disjunctive filters).
Returns a list of aliases involved in the join, the next value for
'opts', the field instance that was matched, the new field to include
in the join, the column name on the rhs of the join and whether the
join can include NULL results.
Returns the final field involved in the join, the target database
column (used for any 'where' constraint), the final 'opts' value, the
list of tables joined and a list indicating whether or not each join
can be null.
"""
if name == 'pk':
name = opts.pk.name
joins = [[alias]]
nullable = [False]
for pos, name in enumerate(names):
if name == 'pk':
name = opts.pk.name
field = find_field(name, opts.many_to_many, False)
if field:
# Many-to-many field defined on the current model.
remote_opts = field.rel.to._meta
int_alias = self.join((root_alias, field.m2m_db_table(),
opts.pk.column, field.m2m_column_name()), dupe_multis)
far_alias = self.join((int_alias, remote_opts.db_table,
field.m2m_reverse_name(), remote_opts.pk.column),
dupe_multis, merge_separate=True)
return ([int_alias, far_alias], remote_opts, field, remote_opts.pk,
None, field.null)
try:
field, direct, m2m = opts.get_field_by_name(name)
except FieldDoesNotExist:
names = opts.get_all_field_names()
raise TypeError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
cached_data = opts._join_cache.get(name)
orig_opts = opts
field = find_field(name, opts.get_all_related_many_to_many_objects(),
True)
if field:
# Many-to-many field defined on the target model.
remote_opts = field.opts
field = field.field
int_alias = self.join((root_alias, field.m2m_db_table(),
opts.pk.column, field.m2m_reverse_name()), dupe_multis)
far_alias = self.join((int_alias, remote_opts.db_table,
field.m2m_column_name(), remote_opts.pk.column),
dupe_multis, merge_separate=True)
# XXX: Why is the final component able to be None here?
return ([int_alias, far_alias], remote_opts, field, remote_opts.pk,
None, True)
if direct:
if m2m:
# Many-to-many field defined on the current model.
if cached_data:
(table1, from_col1, to_col1, table2, from_col2,
to_col2, opts, target) = cached_data
else:
table1 = field.m2m_db_table()
from_col1 = opts.pk.column
to_col1 = field.m2m_column_name()
opts = field.rel.to._meta
table2 = opts.db_table
from_col2 = field.m2m_reverse_name()
to_col2 = opts.pk.column
target = opts.pk
orig_opts._join_cache[name] = (table1, from_col1,
to_col1, table2, from_col2, to_col2, opts,
target)
field = find_field(name, opts.get_all_related_objects(), True)
if field:
# One-to-many field (ForeignKey defined on the target model)
remote_opts = field.opts
field = field.field
local_field = opts.get_field(field.rel.field_name)
alias = self.join((root_alias, remote_opts.db_table,
local_field.column, field.column), dupe_multis,
merge_separate=True)
return ([alias], remote_opts, field, field, remote_opts.pk.column,
True)
int_alias = self.join((alias, table1, from_col1, to_col1),
dupe_multis)
alias = self.join((int_alias, table2, from_col2, to_col2),
dupe_multis)
joins.append([int_alias, alias])
nullable.append(field.null)
elif field.rel:
# One-to-one or many-to-one field
if cached_data:
(table, from_col, to_col, opts, target) = cached_data
else:
opts = field.rel.to._meta
target = field.rel.get_related_field()
table = opts.db_table
from_col = field.column
to_col = target.column
orig_opts._join_cache[name] = (table, from_col, to_col,
opts, target)
alias = self.join((alias, table, from_col, to_col))
joins.append([alias])
nullable.append(field.null)
else:
target = field
break
else:
orig_field = field
field = field.field
nullable.append(True)
if m2m:
# Many-to-many field defined on the target model.
if cached_data:
(table1, from_col1, to_col1, table2, from_col2,
to_col2, opts, target) = cached_data
else:
table1 = field.m2m_db_table()
from_col1 = opts.pk.column
to_col1 = field.m2m_reverse_name()
opts = orig_field.opts
table2 = opts.db_table
from_col2 = field.m2m_column_name()
to_col2 = opts.pk.column
target = opts.pk
orig_opts._join_cache[name] = (table1, from_col1,
to_col1, table2, from_col2, to_col2, opts,
target)
field = find_field(name, opts.fields, False)
if not field:
raise TypeError, \
("Cannot resolve keyword '%s' into field. Choices are: %s"
% (name, ", ".join(get_legal_fields(opts))))
int_alias = self.join((alias, table1, from_col1, to_col1),
dupe_multis)
alias = self.join((int_alias, table2, from_col2, to_col2),
dupe_multis)
joins.append([int_alias, alias])
else:
# One-to-many field (ForeignKey defined on the target model)
if cached_data:
(table, from_col, to_col, opts, target) = cached_data
else:
local_field = opts.get_field_by_name(
field.rel.field_name)[0]
opts = orig_field.opts
table = opts.db_table
from_col = local_field.column
to_col = field.column
target = opts.pk
orig_opts._join_cache[name] = (table, from_col, to_col,
opts, target)
if field.rel:
# One-to-one or many-to-one field
remote_opts = field.rel.to._meta
target = field.rel.get_related_field()
alias = self.join((root_alias, remote_opts.db_table, field.column,
target.column))
return ([alias], remote_opts, field, target, target.column,
field.null)
alias = self.join((alias, table, from_col, to_col),
dupe_multis)
joins.append([alias])
# Only remaining possibility is a normal (direct lookup) field. No
# join is required.
return None, opts, field, field, None, False
if pos != len(names) - 1:
raise TypeError("Join on field %r not permitted." % name)
return field, target, opts, joins, nullable
def set_limits(self, low=None, high=None):
"""
@ -960,13 +988,7 @@ class Query(object):
return cursor.fetchone()
# The MULTI case.
def it():
while 1:
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
if not rows:
raise StopIteration
yield rows
return it()
return results_iter(cursor)
class DeleteQuery(Query):
"""
@ -1003,7 +1025,7 @@ class DeleteQuery(Query):
for related in cls._meta.get_all_related_many_to_many_objects():
if not isinstance(related.field, generic.GenericRelation):
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = WhereNode(self)
where = WhereNode()
where.add((None, related.field.m2m_reverse_name(),
related.field, 'in',
pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]),
@ -1011,14 +1033,14 @@ class DeleteQuery(Query):
self.do_query(related.field.m2m_db_table(), where)
for f in cls._meta.many_to_many:
w1 = WhereNode(self)
w1 = WhereNode()
if isinstance(f, generic.GenericRelation):
from django.contrib.contenttypes.models import ContentType
field = f.rel.to._meta.get_field(f.content_type_field_name)
w1.add((None, field.column, field, 'exact',
ContentType.objects.get_for_model(cls).id), AND)
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = WhereNode(self)
where = WhereNode()
where.add((None, f.m2m_column_name(), f, 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
AND)
@ -1035,7 +1057,7 @@ class DeleteQuery(Query):
lot of values in pk_list.
"""
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = WhereNode(self)
where = WhereNode()
field = self.model._meta.pk
where.add((None, field.column, field, 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND)
@ -1079,7 +1101,7 @@ class UpdateQuery(Query):
This is used by the QuerySet.delete_objects() method.
"""
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = WhereNode(self)
where = WhereNode()
f = self.model._meta.pk
where.add((None, f.column, f, 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
@ -1141,40 +1163,6 @@ class CountQuery(Query):
def get_ordering(self):
return ()
def find_field(name, field_list, related_query):
"""
Finds a field with a specific name in a list of field instances.
Returns None if there are no matches, or several matches.
"""
if related_query:
matches = [f for f in field_list
if f.field.related_query_name() == name]
else:
matches = [f for f in field_list if f.name == name]
if len(matches) != 1:
return None
return matches[0]
def field_choices(field_list, related_query):
"""
Returns the names of the field objects in field_list. Used to construct
readable error messages.
"""
if related_query:
return [f.field.related_query_name() for f in field_list]
else:
return [f.name for f in field_list]
def get_legal_fields(opts):
"""
Returns a list of fields that are valid at this point in the query. Used in
error reporting.
"""
return (field_choices(opts.many_to_many, False)
+ field_choices( opts.get_all_related_many_to_many_objects(), True)
+ field_choices(opts.get_all_related_objects(), True)
+ field_choices(opts.fields, False))
def get_order_dir(field, default='ASC'):
"""
Returns the field name and direction for an order specification. For
@ -1183,8 +1171,27 @@ def get_order_dir(field, default='ASC'):
The 'default' param is used to indicate which way no prefix (or a '+'
prefix) should sort. The '-' prefix always sorts the opposite way.
"""
dirn = {'ASC': ('ASC', 'DESC'), 'DESC': ('DESC', 'ASC')}[default]
dirn = ORDER_DIR[default]
if field[0] == '-':
return field[1:], dirn[1]
return field, dirn[0]
def results_iter(cursor):
while 1:
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
if not rows:
raise StopIteration
yield rows
def setup_join_cache(sender):
"""
The information needed to join between model fields is something that is
invariant over the life of the model, so we cache it in the model's Options
class, rather than recomputing it all the time.
This method initialises the (empty) cache when the model is created.
"""
sender._meta._join_cache = {}
dispatcher.connect(setup_join_cache, signal=signals.class_prepared)

View File

@ -4,6 +4,7 @@ Code to manage the creation and SQL rendering of 'where' constraints.
import datetime
from django.utils import tree
from django.db import connection
from datastructures import EmptyResultSet
# Connection types
@ -23,23 +24,7 @@ class WhereNode(tree.Node):
"""
default = AND
def __init__(self, query=None, children=None, connection=None):
super(WhereNode, self).__init__(children, connection)
if query:
# XXX: Would be nice to use a weakref here, but it seems tricky to
# make it work.
self.query = query
def __deepcopy__(self, memodict):
"""
Used by copy.deepcopy().
"""
obj = super(WhereNode, self).__deepcopy__(memodict)
obj.query = self.query
memodict[id(obj)] = obj
return obj
def as_sql(self, node=None):
def as_sql(self, node=None, qn=None):
"""
Returns the SQL version of the where clause and the value to be
substituted in. Returns None, None if this node is empty.
@ -50,24 +35,25 @@ class WhereNode(tree.Node):
"""
if node is None:
node = self
if not qn:
qn = connection.ops.quote_name
if not node.children:
return None, []
result = []
result_params = []
for child in node.children:
if hasattr(child, 'as_sql'):
sql, params = child.as_sql()
sql, params = child.as_sql(qn=qn)
format = '(%s)'
elif isinstance(child, tree.Node):
sql, params = self.as_sql(child)
sql, params = self.as_sql(child, qn)
if child.negated:
format = 'NOT (%s)'
else:
format = '(%s)'
else:
try:
sql = self.make_atom(child)
params = child[2].get_db_prep_lookup(child[3], child[4])
sql, params = self.make_atom(child, qn)
format = '%s'
except EmptyResultSet:
if self.connection == AND and not node.negated:
@ -80,57 +66,60 @@ class WhereNode(tree.Node):
conn = ' %s ' % node.connection
return conn.join(result), result_params
def make_atom(self, child):
def make_atom(self, child, qn):
"""
Turn a tuple (table_alias, field_name, field_class, lookup_type, value)
into valid SQL.
Returns the string for the SQL fragment. The caller is responsible for
converting the child's value into an appropriate form for the parameters
list.
Returns the string for the SQL fragment and the parameters to use for
it.
"""
table_alias, name, field, lookup_type, value = child
conn = self.query.connection
qn = self.query.quote_name_unless_alias
if table_alias:
lhs = '%s.%s' % (qn(table_alias), qn(name))
else:
lhs = qn(name)
db_type = field and field.db_type() or None
field_sql = conn.ops.field_cast_sql(db_type) % lhs
field_sql = connection.ops.field_cast_sql(db_type) % lhs
if isinstance(value, datetime.datetime):
# FIXME datetime_cast_sql() should return '%s' by default.
cast_sql = conn.ops.datetime_cast_sql() or '%s'
cast_sql = connection.ops.datetime_cast_sql() or '%s'
else:
cast_sql = '%s'
# FIXME: This is out of place. Move to a function like
# datetime_cast_sql()
if (lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith')
and conn.features.needs_upper_for_iops):
and connection.features.needs_upper_for_iops):
format = 'UPPER(%s) %s'
else:
format = '%s %s'
if lookup_type in conn.operators:
return format % (field_sql, conn.operators[lookup_type] % cast_sql)
params = field.get_db_prep_lookup(lookup_type, value)
if lookup_type in connection.operators:
return (format % (field_sql,
connection.operators[lookup_type] % cast_sql), params)
if lookup_type == 'in':
if not value:
raise EmptyResultSet
return '%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value)))
return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))),
params)
elif lookup_type in ('range', 'year'):
return '%s BETWEEN %%s and %%s' % field_sql
return ('%s BETWEEN %%s and %%s' % field_sql,
params)
elif lookup_type in ('month', 'day'):
return '%s = %%s' % conn.ops.date_extract_sql(lookup_type,
field_sql)
return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type,
field_sql), params)
elif lookup_type == 'isnull':
return '%s IS %sNULL' % (field_sql, (not value and 'NOT ' or ''))
return ('%s IS %sNULL' % (field_sql, (not value and 'NOT ' or '')),
params)
elif lookup_type in 'search':
return conn.op.fulltest_search_sql(field_sql)
return (connection.ops.fulltest_search_sql(field_sql), params)
elif lookup_type in ('regex', 'iregex'):
# FIXME: Factor this out into conn.ops
# FIXME: Factor this out into connection.ops
if settings.DATABASE_ENGINE == 'oracle':
if connection.oracle_version and connection.oracle_version <= 9:
raise NotImplementedError("Regexes are not supported in Oracle before version 10g.")
@ -138,8 +127,8 @@ class WhereNode(tree.Node):
match_option = 'c'
else:
match_option = 'i'
return "REGEXP_LIKE(%s, %s, '%s')" % (field_sql, cast_sql,
match_option)
return ("REGEXP_LIKE(%s, %s, '%s')" % (field_sql, cast_sql,
match_option), params)
else:
raise NotImplementedError

View File

@ -71,7 +71,7 @@ __test__ = {'API_TESTS':"""
>>> Author.objects.filter(firstname__exact='John')
Traceback (most recent call last):
...
TypeError: Cannot resolve keyword 'firstname' into field. Choices are: article, id, first_name, last_name
TypeError: Cannot resolve keyword 'firstname' into field. Choices are: article, first_name, id, last_name
>>> a = Author.objects.get(last_name__exact='Smith')
>>> a.first_name

View File

@ -253,7 +253,7 @@ DoesNotExist: Article matching query does not exist.
>>> Article.objects.filter(pub_date_year='2005').count()
Traceback (most recent call last):
...
TypeError: Cannot resolve keyword 'pub_date_year' into field. Choices are: id, headline, pub_date
TypeError: Cannot resolve keyword 'pub_date_year' into field. Choices are: headline, id, pub_date
>>> Article.objects.filter(headline__starts='Article')
Traceback (most recent call last):

View File

@ -179,13 +179,13 @@ False
>>> Article.objects.filter(reporter_id__exact=1)
Traceback (most recent call last):
...
TypeError: Cannot resolve keyword 'reporter_id' into field. Choices are: id, headline, pub_date, reporter
TypeError: Cannot resolve keyword 'reporter_id' into field. Choices are: headline, id, pub_date, reporter
# You need to specify a comparison clause
>>> Article.objects.filter(reporter_id=1)
Traceback (most recent call last):
...
TypeError: Cannot resolve keyword 'reporter_id' into field. Choices are: id, headline, pub_date, reporter
TypeError: Cannot resolve keyword 'reporter_id' into field. Choices are: headline, id, pub_date, reporter
# You can also instantiate an Article by passing
# the Reporter's ID instead of a Reporter object.

View File

@ -55,5 +55,5 @@ __test__ = {'API_TESTS':"""
>>> Poll.objects.get(choice__name__exact="This is the answer")
Traceback (most recent call last):
...
TypeError: Cannot resolve keyword 'choice' into field. Choices are: poll_choice, related_choice, id, question, creator
TypeError: Cannot resolve keyword 'choice' into field. Choices are: creator, id, poll_choice, question, related_choice
"""}

View File

@ -14,7 +14,7 @@ class Choice(models.Model):
return u"Choice: %s in poll %s" % (self.choice, self.poll)
__test__ = {'API_TESTS':"""
# Regression test for the use of None as a query value. None is interpreted as
# Regression test for the use of None as a query value. None is interpreted as
# an SQL NULL, but only in __exact queries.
# Set up some initial polls and choices
>>> p1 = Poll(question='Why?')
@ -29,10 +29,10 @@ __test__ = {'API_TESTS':"""
[]
# Valid query, but fails because foo isn't a keyword
>>> Choice.objects.filter(foo__exact=None)
>>> Choice.objects.filter(foo__exact=None)
Traceback (most recent call last):
...
TypeError: Cannot resolve keyword 'foo' into field. Choices are: id, poll, choice
TypeError: Cannot resolve keyword 'foo' into field. Choices are: choice, id, poll
# Can't use None on anything other than __exact
>>> Choice.objects.filter(id__gt=None)

View File

@ -190,6 +190,10 @@ Bug #4464
[<Item: two>]
Bug #2080, #3592
>>> Author.objects.filter(item__name='one') | Author.objects.filter(name='a3')
[<Author: a1>, <Author: a3>]
>>> Author.objects.filter(Q(item__name='one') | Q(name='a3'))
[<Author: a1>, <Author: a3>]
>>> Author.objects.filter(Q(name='a3') | Q(item__name='one'))
[<Author: a1>, <Author: a3>]
@ -217,6 +221,12 @@ Bug #2253
>>> (q1 & q2).order_by('name')
[<Item: one>]
>>> q1 = Item.objects.filter(tags=t1)
>>> q2 = Item.objects.filter(note=n3, tags=t2)
>>> q3 = Item.objects.filter(creator=a4)
>>> ((q1 & q2) | q3).order_by('name')
[<Item: four>, <Item: one>]
Bugs #4088, #4306
>>> Report.objects.filter(creator=1001)
[<Report: r1>]