diff --git a/django/db/__init__.py b/django/db/__init__.py index d4ea1403bd..e16de07416 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -11,16 +11,18 @@ if not settings.DATABASE_ENGINE: settings.DATABASE_ENGINE = 'dummy' try: - # Most of the time, the database backend will be one of the official + # Most of the time, the database backend will be one of the official # backends that ships with Django, so look there first. _import_path = 'django.db.backends.' backend = __import__('%s%s.base' % (_import_path, settings.DATABASE_ENGINE), {}, {}, ['']) + creation = __import__('%s%s.creation' % (_import_path, settings.DATABASE_ENGINE), {}, {}, ['']) except ImportError, e: - # If the import failed, we might be looking for a database backend + # If the import failed, we might be looking for a database backend # distributed external to Django. So we'll try that next. try: _import_path = '' backend = __import__('%s.base' % settings.DATABASE_ENGINE, {}, {}, ['']) + creation = __import__('%s.creation' % settings.DATABASE_ENGINE, {}, {}, ['']) except ImportError, e_user: # The database backend wasn't found. Display a helpful error message # listing all possible (built-in) database backends. @@ -37,10 +39,12 @@ def _import_database_module(import_path='', module_name=''): """Lazyily import a database module when requested.""" return __import__('%s%s.%s' % (_import_path, settings.DATABASE_ENGINE, module_name), {}, {}, ['']) -# We don't want to import the introspect/creation modules unless -# someone asks for 'em, so lazily load them on demmand. +# We don't want to import the introspect module unless someone asks for it, so +# lazily load it on demmand. get_introspection_module = curry(_import_database_module, _import_path, 'introspection') -get_creation_module = curry(_import_database_module, _import_path, 'creation') + +def get_creation_module(): + return creation # We want runshell() to work the same way, but we have to treat it a # little differently (since it just runs instead of returning a module like diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 7d28ba1d73..be2580134e 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -790,8 +790,11 @@ class ManyToOneRel(object): self.multiple = True def get_related_field(self): - "Returns the Field in the 'to' object to which this relationship is tied." - return self.to._meta.get_field(self.field_name) + """ + Returns the Field in the 'to' object to which this relationship is + tied. + """ + return self.to._meta.get_field_by_name(self.field_name, True)[0] class OneToOneRel(ManyToOneRel): def __init__(self, to, field_name, num_in_admin=0, edit_inline=False, diff --git a/django/db/models/options.py b/django/db/models/options.py index d49025daaf..624a7d6803 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -93,7 +93,8 @@ class Options(object): def add_field(self, field): # Insert the given field in the order in which it was created, using # the "creation_counter" attribute of the field. - # Move many-to-many related fields from self.fields into self.many_to_many. + # Move many-to-many related fields from self.fields into + # self.many_to_many. if field.rel and isinstance(field.rel, ManyToManyRel): self.many_to_many.insert(bisect(self.many_to_many, field), field) else: @@ -129,6 +130,58 @@ class Options(object): return f raise FieldDoesNotExist, '%s has no field named %r' % (self.object_name, name) + def get_field_by_name(self, name, only_direct=False): + """ + Returns the (field_object, direct, m2m), where field_object is the + Field instance for the given name, direct is True if the field exists + on this model, and m2m is True for many-to-many relations. When + 'direct' is False, 'field_object' is the corresponding RelatedObject + for this field (since the field doesn't have an instance associated + with it). + + If 'only_direct' is True, only forwards relations (and non-relations) + are considered in the result. + + Uses a cache internally, so after the first access, this is very fast. + """ + try: + result = self._name_map.get(name) + except AttributeError: + cache = self.init_name_map() + result = cache.get(name) + + if not result or (not result[1] and only_direct): + raise FieldDoesNotExist('%s has no field named %r' + % (self.object_name, name)) + return result + + def get_all_field_names(self): + """ + Returns a list of all field names that are possible for this model + (including reverse relation names). + """ + try: + cache = self._name_map + except AttributeError: + cache = self.init_name_map() + names = cache.keys() + names.sort() + return names + + def init_name_map(self): + """ + Initialises the field name -> field object mapping. + """ + cache = dict([(f.name, (f, True, False)) for f in self.fields]) + cache.update([(f.name, (f, True, True)) for f in self.many_to_many]) + cache.update([(f.field.related_query_name(), (f, False, True)) + for f in self.get_all_related_many_to_many_objects()]) + cache.update([(f.field.related_query_name(), (f, False, False)) + for f in self.get_all_related_objects()]) + if app_cache_ready(): + self._name_map = cache + return cache + def get_add_permission(self): return 'add_%s' % self.object_name.lower() diff --git a/django/db/models/query.py b/django/db/models/query.py index ffea11eeb7..880562960c 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -24,9 +24,9 @@ CHUNK_SIZE = 100 class _QuerySet(object): "Represents a lazy database lookup for a set of objects" - def __init__(self, model=None): + def __init__(self, model=None, query=None): self.model = model - self.query = sql.Query(self.model, connection) + self.query = query or sql.Query(self.model, connection) self._result_cache = None ######################## @@ -338,7 +338,7 @@ class _QuerySet(object): if tables: clone.query.extra_tables.extend(tables) if order_by: - clone.query.extra_order_by.extend(order_by) + clone.query.extra_order_by = order_by return clone ################### @@ -348,9 +348,7 @@ class _QuerySet(object): def _clone(self, klass=None, setup=False, **kwargs): if klass is None: klass = self.__class__ - c = klass() - c.model = self.model - c.query = self.query.clone() + c = klass(model=self.model, query=self.query.clone()) c.__dict__.update(kwargs) if setup and hasattr(c, '_setup_query'): c._setup_query() @@ -460,8 +458,8 @@ class DateQuerySet(QuerySet): return c class EmptyQuerySet(QuerySet): - def __init__(self, model=None): - super(EmptyQuerySet, self).__init__(model) + def __init__(self, model=None, query=None): + super(EmptyQuerySet, self).__init__(model, query) self._result_cache = [] def count(self): diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 412fe50e21..9f303fa015 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -12,9 +12,11 @@ import re from django.utils.tree import Node from django.utils.datastructures import SortedDict +from django.dispatch import dispatcher +from django.db.models import signals from django.db.models.sql.where import WhereNode, AND, OR from django.db.models.sql.datastructures import Count, Date -from django.db.models.fields import FieldDoesNotExist, Field +from django.db.models.fields import FieldDoesNotExist, Field, related from django.contrib.contenttypes import generic from datastructures import EmptyResultSet @@ -49,7 +51,6 @@ RHS_JOIN_COL = 5 ALIAS_TABLE = 0 ALIAS_REFCOUNT = 1 ALIAS_JOIN = 2 -ALIAS_MERGE_SEP = 3 # How many results to expect from a cursor.execute call MULTI = 'multi' @@ -57,6 +58,12 @@ SINGLE = 'single' NONE = None ORDER_PATTERN = re.compile(r'\?|[-+]?\w+$') +ORDER_DIR = { + 'ASC': ('ASC', 'DESC'), + 'DESC': ('DESC', 'ASC')} + +class Empty(object): + pass class Query(object): """ @@ -76,12 +83,13 @@ class Query(object): self.table_map = {} # Maps table names to list of aliases. self.join_map = {} # Maps join_tuple to list of aliases. self.rev_join_map = {} # Reverse of join_map. + self.quote_cache = {} self.default_cols = True # SQL-related attributes self.select = [] self.tables = [] # Aliases in the order they are created. - self.where = WhereNode(self) + self.where = WhereNode() self.group_by = [] self.having = [] self.order_by = [] @@ -118,29 +126,35 @@ class Query(object): for table names. This avoids problems with some SQL dialects that treat quoted strings specially (e.g. PostgreSQL). """ - if name != self.alias_map.get(name, [name])[0]: + if name in self.quote_cache: + return self.quote_cache[name] + if name in self.alias_map and name not in self.table_map: + self.quote_cache[name] = name return name - return self.connection.ops.quote_name(name) + r = self.connection.ops.quote_name(name) + self.quote_cache[name] = r + return r def clone(self, klass=None, **kwargs): """ Creates a copy of the current instance. The 'kwargs' parameter can be used by clients to update attributes after copying has taken place. """ - if not klass: - klass = self.__class__ - obj = klass(self.model, self.connection) - obj.table_map = self.table_map.copy() + obj = Empty() + obj.__class__ = klass or self.__class__ + obj.model = self.model + obj.connection = self.connection obj.alias_map = copy.deepcopy(self.alias_map) + obj.table_map = self.table_map.copy() obj.join_map = copy.deepcopy(self.join_map) obj.rev_join_map = copy.deepcopy(self.rev_join_map) + obj.quote_cache = {} obj.default_cols = self.default_cols obj.select = self.select[:] obj.tables = self.tables[:] obj.where = copy.deepcopy(self.where) - obj.where.query = obj - obj.having = self.having[:] obj.group_by = self.group_by[:] + obj.having = self.having[:] obj.order_by = self.order_by[:] obj.low_mark, obj.high_mark = self.low_mark, self.high_mark obj.distinct = self.distinct @@ -175,7 +189,7 @@ class Query(object): obj.clear_limits() obj.select_related = False if obj.distinct and len(obj.select) > 1: - obj = self.clone(CountQuery, _query=obj, where=WhereNode(self), + obj = self.clone(CountQuery, _query=obj, where=WhereNode(), distinct=False) obj.add_count_column() data = obj.execute_sql(SINGLE) @@ -205,7 +219,7 @@ class Query(object): # This must come after 'select' and 'ordering' -- see docstring of # get_from_clause() for details. from_, f_params = self.get_from_clause() - where, w_params = self.where.as_sql() + where, w_params = self.where.as_sql(qn=self.quote_name_unless_alias) result = ['SELECT'] if self.distinct: @@ -262,28 +276,28 @@ class Query(object): # Work out how to relabel the rhs aliases, if necessary. change_map = {} used = {} - first_new_join = True + conjunction = (connection == AND) + first = True for alias in rhs.tables: if not rhs.alias_map[alias][ALIAS_REFCOUNT]: # An unused alias. continue promote = (rhs.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] == self.LOUTER) - merge_separate = (connection == AND) - new_alias = self.join(rhs.rev_join_map[alias], exclusions=used, - promote=promote, outer_if_first=True, - merge_separate=merge_separate) - if self.alias_map[alias][ALIAS_REFCOUNT] == 1: - first_new_join = False + new_alias = self.join(rhs.rev_join_map[alias], + (conjunction and not first), used, promote, not conjunction) used[new_alias] = None change_map[alias] = new_alias + first = False - # So that we don't exclude valid results, the first join that is - # exclusive to the lhs (self) must be converted to an outer join. - for alias in self.tables[1:]: - if self.alias_map[alias][ALIAS_REFCOUNT] == 1: - self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER - break + # So that we don't exclude valid results in an "or" query combination, + # the first join that is exclusive to the lhs (self) must be converted + # to an outer join. + if not conjunction: + for alias in self.tables[1:]: + if self.alias_map[alias][ALIAS_REFCOUNT] == 1: + self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER + break # Now relabel a copy of the rhs where-clause and add it to the current # one. @@ -297,16 +311,16 @@ class Query(object): # results. alias = self.join((None, self.model._meta.db_table, None, None)) pk = self.model._meta.pk - self.where.add((alias, pk.column, pk, 'isnull', False), AND) + self.where.add([alias, pk.column, pk, 'isnull', False], AND) elif self.where: # rhs has an empty where clause. Make it match everything (see # above for reasoning). - w = WhereNode(self) + w = WhereNode() alias = self.join((None, self.model._meta.db_table, None, None)) pk = self.model._meta.pk - w.add((alias, pk.column, pk, 'isnull', False), AND) + w.add([alias, pk.column, pk, 'isnull', False], AND) else: - w = WhereNode(self) + w = WhereNode() self.where.add(w, connection) # Selection columns and extra extensions are those provided by 'rhs'. @@ -400,7 +414,7 @@ class Query(object): if join_type: result.append('%s %s%s ON (%s.%s = %s.%s)' % (join_type, qn(name), alias_str, qn(lhs), - qn(lhs_col), qn(alias), qn(col))) + qn(lhs_col), qn(alias), qn(col))) else: connector = not first and ', ' or '' result.append('%s%s%s' % (connector, qn(name), alias_str)) @@ -472,7 +486,7 @@ class Query(object): result.append('%s %s' % (elt, order)) elif get_order_dir(field)[0] not in self.extra_select: # 'col' is of the form 'field' or 'field1__field2' or - # 'field1__field2__field', etc. + # '-field1__field2__field', etc. for table, col, order in self.find_ordering_name(field, self.model._meta): elt = '%s.%s' % (qn(table), qn(col)) @@ -495,16 +509,14 @@ class Query(object): pieces = name.split(LOOKUP_SEP) if not alias: alias = self.join((None, opts.db_table, None, None)) - for elt in pieces: - joins, opts, unused1, field, col, unused2 = \ - self.get_next_join(elt, opts, alias, False) - if joins: - alias = joins[-1] - col = col or field.column + field, target, opts, joins, unused2 = self.setup_joins(pieces, opts, + alias, False) + alias = joins[-1][-1] + col = target.column # If we get to this point and the field is a relation to another model, # append the default ordering for that model. - if joins and opts.ordering: + if len(joins) > 1 and opts.ordering: results = [] for item in opts.ordering: results.extend(self.find_ordering_name(item, opts, alias, @@ -559,8 +571,7 @@ class Query(object): self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER def join(self, (lhs, table, lhs_col, col), always_create=False, - exclusions=(), promote=False, outer_if_first=False, - merge_separate=False): + exclusions=(), promote=False, outer_if_first=False): """ Returns an alias for a join between 'table' and 'lhs' on the given columns, either reusing an existing alias for that join or creating a @@ -581,44 +592,44 @@ class Query(object): If 'outer_if_first' is True and a new join is created, it will have the LOUTER join type. This is used when joining certain types of querysets and Q-objects together. - - If the 'merge_separate' parameter is True, we create a new alias if we - would otherwise reuse an alias that also had 'merge_separate' set to - True when it was created. """ - if lhs not in self.alias_map: + if lhs is None: + lhs_table = None + is_table = False + elif lhs not in self.alias_map: lhs_table = lhs - is_table = (lhs is not None) + is_table = True else: lhs_table = self.alias_map[lhs][ALIAS_TABLE] is_table = False t_ident = (lhs_table, table, lhs_col, col) - aliases = self.join_map.get(t_ident) - if aliases and not always_create: - for alias in aliases: - if (alias not in exclusions and - not (merge_separate and - self.alias_map[alias][ALIAS_MERGE_SEP])): - self.ref_alias(alias) - if promote: - self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = \ - self.LOUTER - return alias - # If we get to here (no non-excluded alias exists), we'll fall - # through to creating a new alias. + if not always_create: + aliases = self.join_map.get(t_ident) + if aliases: + for alias in aliases: + if alias not in exclusions: + self.ref_alias(alias) + if promote: + self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = \ + self.LOUTER + return alias + # If we get to here (no non-excluded alias exists), we'll fall + # through to creating a new alias. # No reuse is possible, so we need a new alias. assert not is_table, \ "Must pass in lhs alias when creating a new join." alias, _ = self.table_alias(table, True) - join_type = (promote or outer_if_first) and self.LOUTER or self.INNER + if promote or outer_if_first: + join_type = self.LOUTER + else: + join_type = self.INNER join = [table, alias, join_type, lhs, lhs_col, col] if not lhs: # Not all tables need to be joined to anything. No join type # means the later columns are ignored. join[JOIN_TYPE] = None self.alias_map[alias][ALIAS_JOIN] = join - self.alias_map[alias][ALIAS_MERGE_SEP] = merge_separate self.join_map.setdefault(t_ident, []).append(alias) self.rev_join_map[alias] = t_ident return alias @@ -677,51 +688,18 @@ class Query(object): opts = self.model._meta alias = self.join((None, opts.db_table, None, None)) - dupe_multis = (connection == AND) - join_list = [] - split = not self.where - null_point = None - # FIXME: Using enumerate() here is expensive. We only need 'i' to - # check we aren't joining against a non-joinable field. Find a - # better way to do this! - for i, name in enumerate(parts): - joins, opts, orig_field, target_field, target_col, nullable = \ - self.get_next_join(name, opts, alias, dupe_multis) - if name == 'pk': - name = target_field.name - if joins is not None: - if null_point is None and nullable: - null_point = len(join_list) - join_list.append(joins) - alias = joins[-1] - if connection == OR and not split: - # FIXME: Document what's going on and why this is needed. - if self.alias_map[joins[0]][ALIAS_REFCOUNT] == 1: - split = True - self.promote_alias(joins[0]) - all_aliases = [] - for a in join_list: - all_aliases.extend(a) - for t in self.tables[1:]: - if t in all_aliases: - continue - self.promote_alias(t) - break - else: - # Normal field lookup must be the last field in the filter. - if i != len(parts) - 1: - raise TypeError("Join on field %r not permitted." - % name) - - col = target_col or target_field.column + field, target, unused, join_list, nullable = self.setup_joins(parts, + opts, alias, (connection == AND)) + col = target.column + alias = join_list[-1][-1] if join_list: # An optimization: if the final join is against the same column as # we are comparing against, we can go back one step in the join # chain and compare against the lhs of the join instead. The result # (potentially) involves one less table join. - join = self.alias_map[join_list[-1][-1]][ALIAS_JOIN] + join = self.alias_map[alias][ALIAS_JOIN] if col == join[RHS_JOIN_COL]: self.unref_alias(alias) alias = join[LHS_ALIAS] @@ -734,17 +712,18 @@ class Query(object): # efficient at the database level. self.promote_alias(join_list[-1][0]) - self.where.add([alias, col, orig_field, lookup_type, value], - connection) + self.where.add([alias, col, field, lookup_type, value], connection) if negate: - if join_list and null_point is not None: - for elt in join_list[null_point:]: - for join in elt: - self.promote_alias(join) - self.where.negate() - self.where.add([alias, col, orig_field, 'isnull', True], OR) - else: - self.where.negate() + flag = False + for pos, null in enumerate(nullable): + if not null: + continue + flag = True + for join in join_list[pos]: + self.promote_alias(join) + self.where.negate() + if flag: + self.where.add([alias, col, field, 'isnull', True], OR) def add_q(self, q_object): """ @@ -765,80 +744,129 @@ class Query(object): else: self.add_filter(child, q_object.connection, q_object.negated) - def get_next_join(self, name, opts, root_alias, dupe_multis): + def setup_joins(self, names, opts, alias, dupe_multis): """ - Compute the necessary table joins for the field called 'name'. 'opts' - is the Options class for the current model (which gives the table we - are joining to), root_alias is the alias for the table we are joining - to. If dupe_multis is True, any many-to-many or many-to-one joins will - always create a new alias (necessary for disjunctive filters). + Compute the necessary table joins for the passage through the fields + given in 'names'. 'opts' is the Options class for the current model + (which gives the table we are joining to), 'alias' is the alias for the + table we are joining to. If dupe_multis is True, any many-to-many or + many-to-one joins will always create a new alias (necessary for + disjunctive filters). - Returns a list of aliases involved in the join, the next value for - 'opts', the field instance that was matched, the new field to include - in the join, the column name on the rhs of the join and whether the - join can include NULL results. + Returns the final field involved in the join, the target database + column (used for any 'where' constraint), the final 'opts' value, the + list of tables joined and a list indicating whether or not each join + can be null. """ - if name == 'pk': - name = opts.pk.name + joins = [[alias]] + nullable = [False] + for pos, name in enumerate(names): + if name == 'pk': + name = opts.pk.name - field = find_field(name, opts.many_to_many, False) - if field: - # Many-to-many field defined on the current model. - remote_opts = field.rel.to._meta - int_alias = self.join((root_alias, field.m2m_db_table(), - opts.pk.column, field.m2m_column_name()), dupe_multis) - far_alias = self.join((int_alias, remote_opts.db_table, - field.m2m_reverse_name(), remote_opts.pk.column), - dupe_multis, merge_separate=True) - return ([int_alias, far_alias], remote_opts, field, remote_opts.pk, - None, field.null) + try: + field, direct, m2m = opts.get_field_by_name(name) + except FieldDoesNotExist: + names = opts.get_all_field_names() + raise TypeError("Cannot resolve keyword %r into field. " + "Choices are: %s" % (name, ", ".join(names))) + cached_data = opts._join_cache.get(name) + orig_opts = opts - field = find_field(name, opts.get_all_related_many_to_many_objects(), - True) - if field: - # Many-to-many field defined on the target model. - remote_opts = field.opts - field = field.field - int_alias = self.join((root_alias, field.m2m_db_table(), - opts.pk.column, field.m2m_reverse_name()), dupe_multis) - far_alias = self.join((int_alias, remote_opts.db_table, - field.m2m_column_name(), remote_opts.pk.column), - dupe_multis, merge_separate=True) - # XXX: Why is the final component able to be None here? - return ([int_alias, far_alias], remote_opts, field, remote_opts.pk, - None, True) + if direct: + if m2m: + # Many-to-many field defined on the current model. + if cached_data: + (table1, from_col1, to_col1, table2, from_col2, + to_col2, opts, target) = cached_data + else: + table1 = field.m2m_db_table() + from_col1 = opts.pk.column + to_col1 = field.m2m_column_name() + opts = field.rel.to._meta + table2 = opts.db_table + from_col2 = field.m2m_reverse_name() + to_col2 = opts.pk.column + target = opts.pk + orig_opts._join_cache[name] = (table1, from_col1, + to_col1, table2, from_col2, to_col2, opts, + target) - field = find_field(name, opts.get_all_related_objects(), True) - if field: - # One-to-many field (ForeignKey defined on the target model) - remote_opts = field.opts - field = field.field - local_field = opts.get_field(field.rel.field_name) - alias = self.join((root_alias, remote_opts.db_table, - local_field.column, field.column), dupe_multis, - merge_separate=True) - return ([alias], remote_opts, field, field, remote_opts.pk.column, - True) + int_alias = self.join((alias, table1, from_col1, to_col1), + dupe_multis) + alias = self.join((int_alias, table2, from_col2, to_col2), + dupe_multis) + joins.append([int_alias, alias]) + nullable.append(field.null) + elif field.rel: + # One-to-one or many-to-one field + if cached_data: + (table, from_col, to_col, opts, target) = cached_data + else: + opts = field.rel.to._meta + target = field.rel.get_related_field() + table = opts.db_table + from_col = field.column + to_col = target.column + orig_opts._join_cache[name] = (table, from_col, to_col, + opts, target) + alias = self.join((alias, table, from_col, to_col)) + joins.append([alias]) + nullable.append(field.null) + else: + target = field + break + else: + orig_field = field + field = field.field + nullable.append(True) + if m2m: + # Many-to-many field defined on the target model. + if cached_data: + (table1, from_col1, to_col1, table2, from_col2, + to_col2, opts, target) = cached_data + else: + table1 = field.m2m_db_table() + from_col1 = opts.pk.column + to_col1 = field.m2m_reverse_name() + opts = orig_field.opts + table2 = opts.db_table + from_col2 = field.m2m_column_name() + to_col2 = opts.pk.column + target = opts.pk + orig_opts._join_cache[name] = (table1, from_col1, + to_col1, table2, from_col2, to_col2, opts, + target) - field = find_field(name, opts.fields, False) - if not field: - raise TypeError, \ - ("Cannot resolve keyword '%s' into field. Choices are: %s" - % (name, ", ".join(get_legal_fields(opts)))) + int_alias = self.join((alias, table1, from_col1, to_col1), + dupe_multis) + alias = self.join((int_alias, table2, from_col2, to_col2), + dupe_multis) + joins.append([int_alias, alias]) + else: + # One-to-many field (ForeignKey defined on the target model) + if cached_data: + (table, from_col, to_col, opts, target) = cached_data + else: + local_field = opts.get_field_by_name( + field.rel.field_name)[0] + opts = orig_field.opts + table = opts.db_table + from_col = local_field.column + to_col = field.column + target = opts.pk + orig_opts._join_cache[name] = (table, from_col, to_col, + opts, target) - if field.rel: - # One-to-one or many-to-one field - remote_opts = field.rel.to._meta - target = field.rel.get_related_field() - alias = self.join((root_alias, remote_opts.db_table, field.column, - target.column)) - return ([alias], remote_opts, field, target, target.column, - field.null) + alias = self.join((alias, table, from_col, to_col), + dupe_multis) + joins.append([alias]) - # Only remaining possibility is a normal (direct lookup) field. No - # join is required. - return None, opts, field, field, None, False + if pos != len(names) - 1: + raise TypeError("Join on field %r not permitted." % name) + + return field, target, opts, joins, nullable def set_limits(self, low=None, high=None): """ @@ -960,13 +988,7 @@ class Query(object): return cursor.fetchone() # The MULTI case. - def it(): - while 1: - rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) - if not rows: - raise StopIteration - yield rows - return it() + return results_iter(cursor) class DeleteQuery(Query): """ @@ -1003,7 +1025,7 @@ class DeleteQuery(Query): for related in cls._meta.get_all_related_many_to_many_objects(): if not isinstance(related.field, generic.GenericRelation): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = WhereNode(self) + where = WhereNode() where.add((None, related.field.m2m_reverse_name(), related.field, 'in', pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]), @@ -1011,14 +1033,14 @@ class DeleteQuery(Query): self.do_query(related.field.m2m_db_table(), where) for f in cls._meta.many_to_many: - w1 = WhereNode(self) + w1 = WhereNode() if isinstance(f, generic.GenericRelation): from django.contrib.contenttypes.models import ContentType field = f.rel.to._meta.get_field(f.content_type_field_name) w1.add((None, field.column, field, 'exact', ContentType.objects.get_for_model(cls).id), AND) for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = WhereNode(self) + where = WhereNode() where.add((None, f.m2m_column_name(), f, 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) @@ -1035,7 +1057,7 @@ class DeleteQuery(Query): lot of values in pk_list. """ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = WhereNode(self) + where = WhereNode() field = self.model._meta.pk where.add((None, field.column, field, 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) @@ -1079,7 +1101,7 @@ class UpdateQuery(Query): This is used by the QuerySet.delete_objects() method. """ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = WhereNode(self) + where = WhereNode() f = self.model._meta.pk where.add((None, f.column, f, 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), @@ -1141,40 +1163,6 @@ class CountQuery(Query): def get_ordering(self): return () -def find_field(name, field_list, related_query): - """ - Finds a field with a specific name in a list of field instances. - Returns None if there are no matches, or several matches. - """ - if related_query: - matches = [f for f in field_list - if f.field.related_query_name() == name] - else: - matches = [f for f in field_list if f.name == name] - if len(matches) != 1: - return None - return matches[0] - -def field_choices(field_list, related_query): - """ - Returns the names of the field objects in field_list. Used to construct - readable error messages. - """ - if related_query: - return [f.field.related_query_name() for f in field_list] - else: - return [f.name for f in field_list] - -def get_legal_fields(opts): - """ - Returns a list of fields that are valid at this point in the query. Used in - error reporting. - """ - return (field_choices(opts.many_to_many, False) - + field_choices( opts.get_all_related_many_to_many_objects(), True) - + field_choices(opts.get_all_related_objects(), True) - + field_choices(opts.fields, False)) - def get_order_dir(field, default='ASC'): """ Returns the field name and direction for an order specification. For @@ -1183,8 +1171,27 @@ def get_order_dir(field, default='ASC'): The 'default' param is used to indicate which way no prefix (or a '+' prefix) should sort. The '-' prefix always sorts the opposite way. """ - dirn = {'ASC': ('ASC', 'DESC'), 'DESC': ('DESC', 'ASC')}[default] + dirn = ORDER_DIR[default] if field[0] == '-': return field[1:], dirn[1] return field, dirn[0] +def results_iter(cursor): + while 1: + rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) + if not rows: + raise StopIteration + yield rows + +def setup_join_cache(sender): + """ + The information needed to join between model fields is something that is + invariant over the life of the model, so we cache it in the model's Options + class, rather than recomputing it all the time. + + This method initialises the (empty) cache when the model is created. + """ + sender._meta._join_cache = {} + +dispatcher.connect(setup_join_cache, signal=signals.class_prepared) + diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index eb50639a3a..5fa1ae8096 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -4,6 +4,7 @@ Code to manage the creation and SQL rendering of 'where' constraints. import datetime from django.utils import tree +from django.db import connection from datastructures import EmptyResultSet # Connection types @@ -23,23 +24,7 @@ class WhereNode(tree.Node): """ default = AND - def __init__(self, query=None, children=None, connection=None): - super(WhereNode, self).__init__(children, connection) - if query: - # XXX: Would be nice to use a weakref here, but it seems tricky to - # make it work. - self.query = query - - def __deepcopy__(self, memodict): - """ - Used by copy.deepcopy(). - """ - obj = super(WhereNode, self).__deepcopy__(memodict) - obj.query = self.query - memodict[id(obj)] = obj - return obj - - def as_sql(self, node=None): + def as_sql(self, node=None, qn=None): """ Returns the SQL version of the where clause and the value to be substituted in. Returns None, None if this node is empty. @@ -50,24 +35,25 @@ class WhereNode(tree.Node): """ if node is None: node = self + if not qn: + qn = connection.ops.quote_name if not node.children: return None, [] result = [] result_params = [] for child in node.children: if hasattr(child, 'as_sql'): - sql, params = child.as_sql() + sql, params = child.as_sql(qn=qn) format = '(%s)' elif isinstance(child, tree.Node): - sql, params = self.as_sql(child) + sql, params = self.as_sql(child, qn) if child.negated: format = 'NOT (%s)' else: format = '(%s)' else: try: - sql = self.make_atom(child) - params = child[2].get_db_prep_lookup(child[3], child[4]) + sql, params = self.make_atom(child, qn) format = '%s' except EmptyResultSet: if self.connection == AND and not node.negated: @@ -80,57 +66,60 @@ class WhereNode(tree.Node): conn = ' %s ' % node.connection return conn.join(result), result_params - def make_atom(self, child): + def make_atom(self, child, qn): """ Turn a tuple (table_alias, field_name, field_class, lookup_type, value) into valid SQL. - Returns the string for the SQL fragment. The caller is responsible for - converting the child's value into an appropriate for for the parameters - list. + Returns the string for the SQL fragment and the parameters to use for + it. """ table_alias, name, field, lookup_type, value = child - conn = self.query.connection - qn = self.query.quote_name_unless_alias if table_alias: lhs = '%s.%s' % (qn(table_alias), qn(name)) else: lhs = qn(name) db_type = field and field.db_type() or None - field_sql = conn.ops.field_cast_sql(db_type) % lhs + field_sql = connection.ops.field_cast_sql(db_type) % lhs if isinstance(value, datetime.datetime): # FIXME datetime_cast_sql() should return '%s' by default. - cast_sql = conn.ops.datetime_cast_sql() or '%s' + cast_sql = connection.ops.datetime_cast_sql() or '%s' else: cast_sql = '%s' # FIXME: This is out of place. Move to a function like # datetime_cast_sql() if (lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith') - and conn.features.needs_upper_for_iops): + and connection.features.needs_upper_for_iops): format = 'UPPER(%s) %s' else: format = '%s %s' - if lookup_type in conn.operators: - return format % (field_sql, conn.operators[lookup_type] % cast_sql) + params = field.get_db_prep_lookup(lookup_type, value) + + if lookup_type in connection.operators: + return (format % (field_sql, + connection.operators[lookup_type] % cast_sql), params) if lookup_type == 'in': if not value: raise EmptyResultSet - return '%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))) + return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))), + params) elif lookup_type in ('range', 'year'): - return '%s BETWEEN %%s and %%s' % field_sql + return ('%s BETWEEN %%s and %%s' % field_sql, + params) elif lookup_type in ('month', 'day'): - return '%s = %%s' % conn.ops.date_extract_sql(lookup_type, - field_sql) + return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, + field_sql), params) elif lookup_type == 'isnull': - return '%s IS %sNULL' % (field_sql, (not value and 'NOT ' or '')) + return ('%s IS %sNULL' % (field_sql, (not value and 'NOT ' or '')), + params) elif lookup_type in 'search': - return conn.op.fulltest_search_sql(field_sql) + return (connection.ops.fulltest_search_sql(field_sql), params) elif lookup_type in ('regex', 'iregex'): - # FIXME: Factor this out in to conn.ops + # FIXME: Factor this out in to connection.ops if settings.DATABASE_ENGINE == 'oracle': if connection.oracle_version and connection.oracle_version <= 9: raise NotImplementedError("Regexes are not supported in Oracle before version 10g.") @@ -138,8 +127,8 @@ class WhereNode(tree.Node): match_option = 'c' else: match_option = 'i' - return "REGEXP_LIKE(%s, %s, '%s')" % (field_sql, cast_sql, - match_option) + return ("REGEXP_LIKE(%s, %s, '%s')" % (field_sql, cast_sql, + match_option), params) else: raise NotImplementedError diff --git a/tests/modeltests/custom_columns/models.py b/tests/modeltests/custom_columns/models.py index e1d0bc6e94..302a9aee27 100644 --- a/tests/modeltests/custom_columns/models.py +++ b/tests/modeltests/custom_columns/models.py @@ -71,7 +71,7 @@ __test__ = {'API_TESTS':""" >>> Author.objects.filter(firstname__exact='John') Traceback (most recent call last): ... -TypeError: Cannot resolve keyword 'firstname' into field. Choices are: article, id, first_name, last_name +TypeError: Cannot resolve keyword 'firstname' into field. Choices are: article, first_name, id, last_name >>> a = Author.objects.get(last_name__exact='Smith') >>> a.first_name diff --git a/tests/modeltests/lookup/models.py b/tests/modeltests/lookup/models.py index 06070b9a01..d581947498 100644 --- a/tests/modeltests/lookup/models.py +++ b/tests/modeltests/lookup/models.py @@ -253,7 +253,7 @@ DoesNotExist: Article matching query does not exist. >>> Article.objects.filter(pub_date_year='2005').count() Traceback (most recent call last): ... -TypeError: Cannot resolve keyword 'pub_date_year' into field. Choices are: id, headline, pub_date +TypeError: Cannot resolve keyword 'pub_date_year' into field. Choices are: headline, id, pub_date >>> Article.objects.filter(headline__starts='Article') Traceback (most recent call last): diff --git a/tests/modeltests/many_to_one/models.py b/tests/modeltests/many_to_one/models.py index 85365e2bc3..2231572b45 100644 --- a/tests/modeltests/many_to_one/models.py +++ b/tests/modeltests/many_to_one/models.py @@ -179,13 +179,13 @@ False >>> Article.objects.filter(reporter_id__exact=1) Traceback (most recent call last): ... -TypeError: Cannot resolve keyword 'reporter_id' into field. Choices are: id, headline, pub_date, reporter +TypeError: Cannot resolve keyword 'reporter_id' into field. Choices are: headline, id, pub_date, reporter # You need to specify a comparison clause >>> Article.objects.filter(reporter_id=1) Traceback (most recent call last): ... -TypeError: Cannot resolve keyword 'reporter_id' into field. Choices are: id, headline, pub_date, reporter +TypeError: Cannot resolve keyword 'reporter_id' into field. Choices are: headline, id, pub_date, reporter # You can also instantiate an Article by passing # the Reporter's ID instead of a Reporter object. diff --git a/tests/modeltests/reverse_lookup/models.py b/tests/modeltests/reverse_lookup/models.py index 80408ad761..5d722e54bf 100644 --- a/tests/modeltests/reverse_lookup/models.py +++ b/tests/modeltests/reverse_lookup/models.py @@ -55,5 +55,5 @@ __test__ = {'API_TESTS':""" >>> Poll.objects.get(choice__name__exact="This is the answer") Traceback (most recent call last): ... -TypeError: Cannot resolve keyword 'choice' into field. Choices are: poll_choice, related_choice, id, question, creator +TypeError: Cannot resolve keyword 'choice' into field. Choices are: creator, id, poll_choice, question, related_choice """} diff --git a/tests/regressiontests/null_queries/models.py b/tests/regressiontests/null_queries/models.py index 2aa36b2c1a..fc8215584e 100644 --- a/tests/regressiontests/null_queries/models.py +++ b/tests/regressiontests/null_queries/models.py @@ -14,7 +14,7 @@ class Choice(models.Model): return u"Choice: %s in poll %s" % (self.choice, self.poll) __test__ = {'API_TESTS':""" -# Regression test for the use of None as a query value. None is interpreted as +# Regression test for the use of None as a query value. None is interpreted as # an SQL NULL, but only in __exact queries. # Set up some initial polls and choices >>> p1 = Poll(question='Why?') @@ -29,10 +29,10 @@ __test__ = {'API_TESTS':""" [] # Valid query, but fails because foo isn't a keyword ->>> Choice.objects.filter(foo__exact=None) +>>> Choice.objects.filter(foo__exact=None) Traceback (most recent call last): ... -TypeError: Cannot resolve keyword 'foo' into field. Choices are: id, poll, choice +TypeError: Cannot resolve keyword 'foo' into field. Choices are: choice, id, poll # Can't use None on anything other than __exact >>> Choice.objects.filter(id__gt=None) diff --git a/tests/regressiontests/queries/models.py b/tests/regressiontests/queries/models.py index f0514dcee9..fe0940152d 100644 --- a/tests/regressiontests/queries/models.py +++ b/tests/regressiontests/queries/models.py @@ -190,6 +190,10 @@ Bug #4464 [] Bug #2080, #3592 +>>> Author.objects.filter(item__name='one') | Author.objects.filter(name='a3') +[, ] +>>> Author.objects.filter(Q(item__name='one') | Q(name='a3')) +[, ] >>> Author.objects.filter(Q(name='a3') | Q(item__name='one')) [, ] @@ -217,6 +221,12 @@ Bug #2253 >>> (q1 & q2).order_by('name') [] +>>> q1 = Item.objects.filter(tags=t1) +>>> q2 = Item.objects.filter(note=n3, tags=t2) +>>> q3 = Item.objects.filter(creator=a4) +>>> ((q1 & q2) | q3).order_by('name') +[, ] + Bugs #4088, #4306 >>> Report.objects.filter(creator=1001) []