diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 74b08884da..7c458cb001 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -117,6 +117,13 @@ class Field(object): # This is needed because bisect does not take a comparison function. return cmp(self.creation_counter, other.creation_counter) + def __deepcopy__(self, memodict): + # Slight hack; deepcopy() is difficult to do on classes with + # dynamically created methods. Fortunately, we can get away with doing + # a shallow copy in this particular case. + import copy + return copy.copy(self) + def to_python(self, value): """ Converts the input value into the expected Python data type, raising diff --git a/django/db/models/query.py b/django/db/models/query.py index a750ef5550..967234a6d7 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1,103 +1,59 @@ +import datetime +import operator +import re +import warnings + from django.conf import settings from django.db import connection, transaction from django.db.models.fields import DateField, FieldDoesNotExist -from django.db.models import signals, loading +from django.db.models.query_utils import Q, QNot, EmptyResultSet +from django.db.models import signals, loading, sql from django.dispatch import dispatcher from django.utils.datastructures import SortedDict from django.utils.encoding import smart_unicode from django.contrib.contenttypes import generic -import datetime -import operator -import re try: set except NameError: from sets import Set as set # Python 2.3 fallback -# The string constant used to separate query parts -LOOKUP_SEPARATOR = '__' - -# The list of valid query types -QUERY_TERMS = ( - 'exact', 'iexact', 'contains', 'icontains', - 'gt', 'gte', 'lt', 'lte', 'in', - 'startswith', 'istartswith', 'endswith', 'iendswith', - 'range', 'year', 'month', 'day', 'isnull', 'search', - 'regex', 'iregex', -) - -# Size of each "chunk" for get_iterator calls. 
-# Larger values are slightly faster at the expense of more storage space. -GET_ITERATOR_CHUNK_SIZE = 100 - -class EmptyResultSet(Exception): - pass +# Used to control how many objects are worked with at once in some cases (e.g. +# when deleting objects). +CHUNK_SIZE = 100 #################### # HELPER FUNCTIONS # #################### -# Django currently supports two forms of ordering. -# Form 1 (deprecated) example: -# order_by=(('pub_date', 'DESC'), ('headline', 'ASC'), (None, 'RANDOM')) -# Form 2 (new-style) example: -# order_by=('-pub_date', 'headline', '?') -# Form 1 is deprecated and will no longer be supported for Django's first -# official release. The following code converts from Form 1 to Form 2. - -LEGACY_ORDERING_MAPPING = {'ASC': '_', 'DESC': '-_', 'RANDOM': '?'} - -def handle_legacy_orderlist(order_list): - if not order_list or isinstance(order_list[0], basestring): - return order_list - else: - import warnings - new_order_list = [LEGACY_ORDERING_MAPPING[j.upper()].replace('_', smart_unicode(i)) for i, j in order_list] - warnings.warn("%r ordering syntax is deprecated. Use %r instead." % (order_list, new_order_list), DeprecationWarning) - return new_order_list - -def orderfield2column(f, opts): - try: - return opts.get_field(f, False).column - except FieldDoesNotExist: - return f - +# FIXME def orderlist2sql(order_list, opts, prefix=''): - qn = connection.ops.quote_name - if prefix.endswith('.'): - prefix = qn(prefix[:-1]) + '.' - output = [] - for f in handle_legacy_orderlist(order_list): - if f.startswith('-'): - output.append('%s%s DESC' % (prefix, qn(orderfield2column(f[1:], opts)))) - elif f == '?': - output.append(connection.ops.random_function_sql()) - else: - output.append('%s%s ASC' % (prefix, qn(orderfield2column(f, opts)))) - return ', '.join(output) + raise NotImplementedError +##def orderlist2sql(order_list, opts, prefix=''): +## qn = connection.ops.quote_name +## if prefix.endswith('.'): +## prefix = qn(prefix[:-1]) + '.' 
+## output = [] +## for f in handle_legacy_orderlist(order_list): +## if f.startswith('-'): +## output.append('%s%s DESC' % (prefix, qn(orderfield2column(f[1:], opts)))) +## elif f == '?': +## output.append(connection.ops.random_function_sql()) +## else: +## output.append('%s%s ASC' % (prefix, qn(orderfield2column(f, opts)))) +## return ', '.join(output) -def quote_only_if_word(word): - if re.search('\W', word): # Don't quote if there are spaces or non-word chars. - return word - else: - return connection.ops.quote_name(word) +##def quote_only_if_word(word): +## if re.search('\W', word): # Don't quote if there are spaces or non-word chars. +## return word +## else: +## return connection.ops.quote_name(word) class _QuerySet(object): "Represents a lazy database lookup for a set of objects" def __init__(self, model=None): self.model = model - self._filters = Q() - self._order_by = None # Ordering, e.g. ('date', '-name'). If None, use model's ordering. - self._select_related = False # Whether to fill cache for related objects. - self._max_related_depth = 0 # Maximum "depth" for select_related - self._distinct = False # Whether the query should use SELECT DISTINCT. - self._select = {} # Dictionary of attname -> SQL. - self._where = [] # List of extra WHERE clauses to use. - self._params = [] # List of params to use for extra WHERE clauses. - self._tables = [] # List of extra tables to use. - self._offset = None # OFFSET clause. - self._limit = None # LIMIT clause. + self.query = sql.Query(self.model, connection) self._result_cache = None ######################## @@ -117,57 +73,33 @@ class _QuerySet(object): "Retrieve an item or slice from the set of results." if not isinstance(k, (slice, int, long)): raise TypeError - assert (not isinstance(k, slice) and (k >= 0)) \ - or (isinstance(k, slice) and (k.start is None or k.start >= 0) and (k.stop is None or k.stop >= 0)), \ - "Negative indexing is not supported." 
- if self._result_cache is None: - if isinstance(k, slice): - # Offset: - if self._offset is None: - offset = k.start - elif k.start is None: - offset = self._offset - else: - offset = self._offset + k.start - # Now adjust offset to the bounds of any existing limit: - if self._limit is not None and k.start is not None: - limit = self._limit - k.start - else: - limit = self._limit + assert ((not isinstance(k, slice) and (k >= 0)) + or (isinstance(k, slice) and (k.start is None or k.start >= 0) + and (k.stop is None or k.stop >= 0))), \ + "Negative indexing is not supported." - # Limit: - if k.stop is not None and k.start is not None: - if limit is None: - limit = k.stop - k.start - else: - limit = min((k.stop - k.start), limit) - else: - if limit is None: - limit = k.stop - else: - if k.stop is not None: - limit = min(k.stop, limit) - - if k.step is None: - return self._clone(_offset=offset, _limit=limit) - else: - return list(self._clone(_offset=offset, _limit=limit))[::k.step] - else: - try: - return list(self._clone(_offset=k, _limit=1))[0] - except self.model.DoesNotExist, e: - raise IndexError, e.args - else: + if self._result_cache is not None: return self._result_cache[k] + if isinstance(k, slice): + qs = self._clone() + qs.query.set_limits(k.start, k.stop) + return k.step and list(qs)[::k.step] or qs + try: + qs = self._clone() + qs.query.set_limits(k, k + 1) + return list(qs)[0] + except self.model.DoesNotExist, e: + raise IndexError, e.args + def __and__(self, other): - combined = self._combine(other) - combined._filters = self._filters & other._filters + combined = self._clone() + combined.query.combine(other.query, sql.AND) return combined def __or__(self, other): - combined = self._combine(other) - combined._filters = self._filters | other._filters + combined = self._clone() + combined.query.combine(other.query, sql.OR) return combined #################################### @@ -175,38 +107,24 @@ class _QuerySet(object): 
#################################### def iterator(self): - "Performs the SELECT database lookup of this QuerySet." - try: - select, sql, params = self._get_sql_clause() - except EmptyResultSet: - raise StopIteration - - # self._select is a dictionary, and dictionaries' key order is - # undefined, so we convert it to a list of tuples. - extra_select = self._select.items() - - cursor = connection.cursor() - cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) - - fill_cache = self._select_related - fields = self.model._meta.fields - index_end = len(fields) - has_resolve_columns = hasattr(self, 'resolve_columns') - while 1: - rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) - if not rows: - raise StopIteration - for row in rows: - if has_resolve_columns: - row = self.resolve_columns(row, fields) - if fill_cache: - obj, index_end = get_cached_row(klass=self.model, row=row, - index_start=0, max_depth=self._max_related_depth) - else: - obj = self.model(*row[:index_end]) - for i, k in enumerate(extra_select): - setattr(obj, k[0], row[index_end+i]) - yield obj + """ + An iterator over the results from applying this QuerySet to the + database. 
+        """
+        fill_cache = self.query.select_related
+        max_depth = self.query.max_depth
+        index_end = len(self.model._meta.fields)
+        extra_select = self.query.extra_select.keys()
+        extra_select.sort()
+        for row in self.query.results_iter():
+            if fill_cache:
+                obj, index_end = get_cached_row(klass=self.model, row=row,
+                        index_start=0, max_depth=max_depth)
+            else:
+                obj = self.model(*row[:index_end])
+            for i, k in enumerate(extra_select):
+                setattr(obj, k, row[index_end + i])
+            yield obj
 
     def count(self):
         """
@@ -220,47 +138,19 @@ class _QuerySet(object):
         if self._result_cache is not None:
             return len(self._result_cache)
 
-        counter = self._clone()
-        counter._order_by = ()
-        counter._select_related = False
-
-        offset = counter._offset
-        limit = counter._limit
-        counter._offset = None
-        counter._limit = None
-
-        try:
-            select, sql, params = counter._get_sql_clause()
-        except EmptyResultSet:
-            return 0
-
-        cursor = connection.cursor()
-        if self._distinct:
-            id_col = "%s.%s" % (connection.ops.quote_name(self.model._meta.db_table),
-                connection.ops.quote_name(self.model._meta.pk.column))
-            cursor.execute("SELECT COUNT(DISTINCT(%s))" % id_col + sql, params)
-        else:
-            cursor.execute("SELECT COUNT(*)" + sql, params)
-        count = cursor.fetchone()[0]
-
-        # Apply any offset and limit constraints manually, since using LIMIT or
-        # OFFSET in SQL doesn't change the output of COUNT.
-        if offset:
-            count = max(0, count - offset)
-        if limit:
-            count = min(limit, count)
-
-        return count
+        return self.query.get_count()
 
     def get(self, *args, **kwargs):
-        "Performs the SELECT and returns a single object matching the given keyword arguments."
+        """
+        Performs the query and returns a single object matching the given
+        keyword arguments.
+ """ clone = self.filter(*args, **kwargs) - # clean up SQL by removing unneeded ORDER BY - if not clone._order_by: - clone._order_by = () + clone.query.clear_ordering() obj_list = list(clone) if len(obj_list) < 1: - raise self.model.DoesNotExist, "%s matching query does not exist." % self.model._meta.object_name + raise self.model.DoesNotExist("%s matching query does not exist." + % self.model._meta.object_name) assert len(obj_list) == 1, "get() returned more than one %s -- it returned %s! Lookup parameters were %s" % (self.model._meta.object_name, len(obj_list), kwargs) return obj_list[0] @@ -279,7 +169,8 @@ class _QuerySet(object): Returns a tuple of (object, created), where created is a boolean specifying whether an object was created. """ - assert len(kwargs), 'get_or_create() must be passed at least one keyword argument' + assert kwargs, \ + 'get_or_create() must be passed at least one keyword argument' defaults = kwargs.pop('defaults', {}) try: return self.get(**kwargs), False @@ -297,54 +188,58 @@ class _QuerySet(object): """ latest_by = field_name or self.model._meta.get_latest_by assert bool(latest_by), "latest() requires either a field_name parameter or 'get_latest_by' in the model" - assert self._limit is None and self._offset is None, \ + assert self.query.can_filter(), \ "Cannot change a query once a slice has been taken." - return self._clone(_limit=1, _order_by=('-'+latest_by,)).get() + obj = self._clone() + obj.query.set_limits(high=1) + obj.query.add_ordering('-%s' % latest_by) + return obj.get() def in_bulk(self, id_list): """ Returns a dictionary mapping each of the given IDs to the object with that ID. """ - assert self._limit is None and self._offset is None, \ + assert self.query.can_filter(), \ "Cannot use 'limit' or 'offset' with in_bulk" - assert isinstance(id_list, (tuple, list)), "in_bulk() must be provided with a list of IDs." 
- qn = connection.ops.quote_name - id_list = list(id_list) - if id_list == []: + assert isinstance(id_list, (tuple, list)), \ + "in_bulk() must be provided with a list of IDs." + if not id_list: return {} qs = self._clone() - qs._where.append("%s.%s IN (%s)" % (qn(self.model._meta.db_table), qn(self.model._meta.pk.column), ",".join(['%s'] * len(id_list)))) - qs._params.extend(id_list) + qs.query.add_filter(('pk__in', id_list)) return dict([(obj._get_pk_val(), obj) for obj in qs.iterator()]) + # XXX Mostly DONE def delete(self): """ Deletes the records in the current QuerySet. """ - assert self._limit is None and self._offset is None, \ - "Cannot use 'limit' or 'offset' with delete." + assert self.query.can_filter(), \ + "Cannot use 'limit' or 'offset' with delete." del_query = self._clone() - # disable non-supported fields - del_query._select_related = False - del_query._order_by = [] + # Disable non-supported fields. + del_query.query.select_related = False + del_query.query.clear_ordering() - # Delete objects in chunks to prevent an the list of - # related objects from becoming too long + # Delete objects in chunks to prevent the list of related objects from + # becoming too long. more_objects = True while more_objects: - # Collect all the objects to be deleted in this chunk, and all the objects - # that are related to the objects that are to be deleted + # Collect all the objects to be deleted in this chunk, and all the + # objects that are related to the objects that are to be deleted. seen_objs = SortedDict() more_objects = False - for object in del_query[0:GET_ITERATOR_CHUNK_SIZE]: + for object in del_query[:CHUNK_SIZE]: more_objects = True object._collect_sub_objects(seen_objs) # If one or more objects were found, delete them. # Otherwise, stop looping. + # FIXME: Does "if seen_objs:.." work here? If so, we can get rid of + # more_objects. 
if more_objects: delete_objects(seen_objs) @@ -359,13 +254,16 @@ class _QuerySet(object): def values(self, *fields): return self._clone(klass=ValuesQuerySet, _fields=fields) + # FIXME: Not converted yet! def dates(self, field_name, kind, order='ASC'): """ Returns a list of datetime objects representing all available dates for the given field_name, scoped to 'kind'. """ - assert kind in ("month", "year", "day"), "'kind' must be one of 'year', 'month' or 'day'." - assert order in ('ASC', 'DESC'), "'order' must be either 'ASC' or 'DESC'." + assert kind in ("month", "year", "day"), \ + "'kind' must be one of 'year', 'month' or 'day'." + assert order in ('ASC', 'DESC'), \ + "'order' must be either 'ASC' or 'DESC'." # Let the FieldDoesNotExist exception propagate. field = self.model._meta.get_field(field_name, many_to_many=False) assert isinstance(field, DateField), "%r isn't a DateField." % field_name @@ -376,62 +274,89 @@ class _QuerySet(object): ################################################################## def filter(self, *args, **kwargs): - "Returns a new QuerySet instance with the args ANDed to the existing set." + """ + Returns a new QuerySet instance with the args ANDed to the existing + set. + """ return self._filter_or_exclude(None, *args, **kwargs) def exclude(self, *args, **kwargs): - "Returns a new QuerySet instance with NOT (args) ANDed to the existing set." + """ + Returns a new QuerySet instance with NOT (args) ANDed to the existing + set. + """ return self._filter_or_exclude(QNot, *args, **kwargs) def _filter_or_exclude(self, mapper, *args, **kwargs): # mapper is a callable used to transform Q objects, - # or None for identity transform + # or None for identity transform. if mapper is None: mapper = lambda x: x - if len(args) > 0 or len(kwargs) > 0: - assert self._limit is None and self._offset is None, \ + if args or kwargs: + assert self.query.can_filter(), \ "Cannot filter a query once a slice has been taken." 
clone = self._clone() - if len(kwargs) > 0: - clone._filters = clone._filters & mapper(Q(**kwargs)) - if len(args) > 0: - clone._filters = clone._filters & reduce(operator.and_, map(mapper, args)) + if kwargs: + clone.query.add_q(mapper(Q(**kwargs))) + for arg in args: + clone.query.add_q(arg) return clone def complex_filter(self, filter_obj): - """Returns a new QuerySet instance with filter_obj added to the filters. - filter_obj can be a Q object (has 'get_sql' method) or a dictionary of - keyword lookup arguments.""" - # This exists to support framework features such as 'limit_choices_to', - # and usually it will be more natural to use other methods. - if hasattr(filter_obj, 'get_sql'): + """ + Returns a new QuerySet instance with filter_obj added to the filters. + filter_obj can be a Q object (or anything with an add_to_query() + method) or a dictionary of keyword lookup arguments. + + This exists to support framework features such as 'limit_choices_to', + and usually it will be more natural to use other methods. + """ + if isinstance(filter_obj, Q) or hasattr(filter_obj, 'add_to_query'): return self._filter_or_exclude(None, filter_obj) else: return self._filter_or_exclude(None, **filter_obj) def select_related(self, true_or_false=True, depth=0): - "Returns a new QuerySet instance with '_select_related' modified." - return self._clone(_select_related=true_or_false, _max_related_depth=depth) + """Returns a new QuerySet instance that will select related objects.""" + obj = self._clone() + obj.query.select_related = true_or_false + obj.query.max_depth = depth + return obj def order_by(self, *field_names): - "Returns a new QuerySet instance with the ordering changed." - assert self._limit is None and self._offset is None, \ + """Returns a new QuerySet instance with the ordering changed.""" + assert self.query.can_filter(), \ "Cannot reorder a query once a slice has been taken." 
- return self._clone(_order_by=field_names) + obj = self._clone() + obj.query.add_ordering(*field_names) + return obj def distinct(self, true_or_false=True): - "Returns a new QuerySet instance with '_distinct' modified." - return self._clone(_distinct=true_or_false) + """ + Returns a new QuerySet instance that will select only distinct results. + """ + obj = self._clone() + obj.query.distinct = true_or_false + return obj def extra(self, select=None, where=None, params=None, tables=None): - assert self._limit is None and self._offset is None, \ + """ + Add extra SQL fragments to the query. These are applied more or less + verbatim (no quoting, no alias renaming, etc), so care should be taken + when using extra() with other complex filters and combinations. + """ + assert self.query.can_filter(), \ "Cannot change a query once a slice has been taken" clone = self._clone() - if select: clone._select.update(select) - if where: clone._where.extend(where) - if params: clone._params.extend(params) - if tables: clone._tables.extend(tables) + if select: + clone.query.extra_select.update(select) + if where: + clone.query.extra_where.extend(where) + if params: + clone.query.extra_params.extend(params) + if tables: + clone.query.extra_tables.extend(tables) return clone ################### @@ -443,127 +368,15 @@ class _QuerySet(object): klass = self.__class__ c = klass() c.model = self.model - c._filters = self._filters - c._order_by = self._order_by - c._select_related = self._select_related - c._max_related_depth = self._max_related_depth - c._distinct = self._distinct - c._select = self._select.copy() - c._where = self._where[:] - c._params = self._params[:] - c._tables = self._tables[:] - c._offset = self._offset - c._limit = self._limit + c.query = self.query.clone() c.__dict__.update(kwargs) return c - def _combine(self, other): - assert self._limit is None and self._offset is None \ - and other._limit is None and other._offset is None, \ - "Cannot combine queries once a 
slice has been taken." - assert self._distinct == other._distinct, \ - "Cannot combine a unique query with a non-unique query" - # use 'other's order by - # (so that A.filter(args1) & A.filter(args2) does the same as - # A.filter(args1).filter(args2) - combined = other._clone() - if self._select: combined._select.update(self._select) - if self._where: combined._where.extend(self._where) - if self._params: combined._params.extend(self._params) - if self._tables: combined._tables.extend(self._tables) - # If 'self' is ordered and 'other' isn't, propagate 'self's ordering - if (self._order_by is not None and len(self._order_by) > 0) and \ - (combined._order_by is None or len(combined._order_by) == 0): - combined._order_by = self._order_by - return combined - def _get_data(self): if self._result_cache is None: self._result_cache = list(self.iterator()) return self._result_cache - def _get_sql_clause(self): - qn = connection.ops.quote_name - opts = self.model._meta - - # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z. - select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields] - tables = [quote_only_if_word(t) for t in self._tables] - joins = SortedDict() - where = self._where[:] - params = self._params[:] - - # Convert self._filters into SQL. - joins2, where2, params2 = self._filters.get_sql(opts) - joins.update(joins2) - where.extend(where2) - params.extend(params2) - - # Add additional tables and WHERE clauses based on select_related. - if self._select_related: - fill_table_cache(opts, select, tables, where, - old_prefix=opts.db_table, - cache_tables_seen=[opts.db_table], - max_depth=self._max_related_depth) - - # Add any additional SELECTs. - if self._select: - select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()]) - - # Start composing the body of the SQL statement. - sql = [" FROM", qn(opts.db_table)] - - # Compose the join dictionary into SQL describing the joins. 
- if joins: - sql.append(" ".join(["%s %s AS %s ON %s" % (join_type, table, alias, condition) - for (alias, (table, join_type, condition)) in joins.items()])) - - # Compose the tables clause into SQL. - if tables: - sql.append(", " + ", ".join(tables)) - - # Compose the where clause into SQL. - if where: - sql.append(where and "WHERE " + " AND ".join(where)) - - # ORDER BY clause - order_by = [] - if self._order_by is not None: - ordering_to_use = self._order_by - else: - ordering_to_use = opts.ordering - for f in handle_legacy_orderlist(ordering_to_use): - if f == '?': # Special case. - order_by.append(connection.ops.random_function_sql()) - else: - if f.startswith('-'): - col_name = f[1:] - order = "DESC" - else: - col_name = f - order = "ASC" - if "." in col_name: - table_prefix, col_name = col_name.split('.', 1) - table_prefix = qn(table_prefix) + '.' - else: - # Use the database table as a column prefix if it wasn't given, - # and if the requested column isn't a custom SELECT. - if "." not in col_name and col_name not in (self._select or ()): - table_prefix = qn(opts.db_table) + '.' - else: - table_prefix = '' - order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order)) - if order_by: - sql.append("ORDER BY " + ", ".join(order_by)) - - # LIMIT and OFFSET clauses - if self._limit is not None: - sql.append("%s " % connection.ops.limit_offset_sql(self._limit, self._offset)) - else: - assert self._offset is None, "'offset' is not allowed without 'limit'" - - return select, " ".join(sql), params - # Use the backend's QuerySet class if it defines one. Otherwise, use _QuerySet. if connection.features.uses_custom_queryset: QuerySet = connection.ops.query_set_class(_QuerySet) @@ -574,7 +387,7 @@ class ValuesQuerySet(QuerySet): def __init__(self, *args, **kwargs): super(ValuesQuerySet, self).__init__(*args, **kwargs) # select_related isn't supported in values(). 
-        self._select_related = False
+        self.query.select_related = False
 
     def iterator(self):
         try:
@@ -707,123 +520,14 @@ class EmptyQuerySet(QuerySet):
     def _get_sql_clause(self):
        raise EmptyResultSet
 
-class QOperator(object):
-    "Base class for QAnd and QOr"
-    def __init__(self, *args):
-        self.args = args
+# QOperator, QAnd and QOr are temporarily retained for backwards compatibility.
+# All the old functionality is now part of the 'Q' class.
+class QOperator(Q):
+    def __init__(self, *args, **kwargs):
+        warnings.warn('Use Q instead of QOr, QAnd or QOperator.',
+                DeprecationWarning, stacklevel=2)
 
-    def get_sql(self, opts):
-        joins, where, params = SortedDict(), [], []
-        for val in self.args:
-            try:
-                joins2, where2, params2 = val.get_sql(opts)
-                joins.update(joins2)
-                where.extend(where2)
-                params.extend(params2)
-            except EmptyResultSet:
-                if not isinstance(self, QOr):
-                    raise EmptyResultSet
-        if where:
-            return joins, ['(%s)' % self.operator.join(where)], params
-        return joins, [], params
-
-class QAnd(QOperator):
-    "Encapsulates a combined query that uses 'AND'."
-    operator = ' AND '
-    def __or__(self, other):
-        return QOr(self, other)
-
-    def __and__(self, other):
-        if isinstance(other, QAnd):
-            return QAnd(*(self.args+other.args))
-        elif isinstance(other, (Q, QOr)):
-            return QAnd(*(self.args+(other,)))
-        else:
-            raise TypeError, other
-
-class QOr(QOperator):
-    "Encapsulates a combined query that uses 'OR'."
-    operator = ' OR '
-    def __and__(self, other):
-        return QAnd(self, other)
-
-    def __or__(self, other):
-        if isinstance(other, QOr):
-            return QOr(*(self.args+other.args))
-        elif isinstance(other, (Q, QAnd)):
-            return QOr(*(self.args+(other,)))
-        else:
-            raise TypeError, other
-
-class Q(object):
-    "Encapsulates queries as objects that can be combined logically."
- def __init__(self, **kwargs): - self.kwargs = kwargs - - def __and__(self, other): - return QAnd(self, other) - - def __or__(self, other): - return QOr(self, other) - - def get_sql(self, opts): - return parse_lookup(self.kwargs.items(), opts) - -class QNot(Q): - "Encapsulates NOT (...) queries as objects" - def __init__(self, q): - "Creates a negation of the q object passed in." - self.q = q - - def get_sql(self, opts): - try: - joins, where, params = self.q.get_sql(opts) - where2 = ['(NOT (%s))' % " AND ".join(where)] - except EmptyResultSet: - return SortedDict(), [], [] - return joins, where2, params - -def get_where_clause(lookup_type, table_prefix, field_name, value, db_type): - if table_prefix.endswith('.'): - table_prefix = connection.ops.quote_name(table_prefix[:-1])+'.' - field_name = connection.ops.quote_name(field_name) - if type(value) == datetime.datetime and connection.ops.datetime_cast_sql(): - cast_sql = connection.ops.datetime_cast_sql() - else: - cast_sql = '%s' - field_sql = connection.ops.field_cast_sql(db_type) % (table_prefix + field_name) - if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith') and connection.features.needs_upper_for_iops: - format = 'UPPER(%s) %s' - else: - format = '%s %s' - try: - return format % (field_sql, connection.operators[lookup_type] % cast_sql) - except KeyError: - pass - if lookup_type == 'in': - in_string = ','.join(['%s' for id in value]) - if in_string: - return '%s IN (%s)' % (field_sql, in_string) - else: - raise EmptyResultSet - elif lookup_type in ('range', 'year'): - return '%s BETWEEN %%s AND %%s' % field_sql - elif lookup_type in ('month', 'day'): - return "%s = %%s" % connection.ops.date_extract_sql(lookup_type, field_sql) - elif lookup_type == 'isnull': - return "%s IS %sNULL" % (field_sql, (not value and 'NOT ' or '')) - elif lookup_type == 'search': - return connection.ops.fulltext_search_sql(field_sql) - elif lookup_type in ('regex', 'iregex'): - if settings.DATABASE_ENGINE == 
'oracle': - if lookup_type == 'regex': - match_option = 'c' - else: - match_option = 'i' - return "REGEXP_LIKE(%s, %s, '%s')" % (field_sql, cast_sql, match_option) - else: - raise NotImplementedError - raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type) +QOr = QAnd = QOperator def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0): """Helper function that recursively returns an object with cache filled""" @@ -842,342 +546,50 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0): setattr(obj, f.get_cache_name(), rel_obj) return obj, index_end -def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, cur_depth=0): - """ - Helper function that recursively populates the select, tables and where (in - place) for select_related queries. - """ - - # If we've got a max_depth set and we've exceeded that depth, bail now. - if max_depth and cur_depth > max_depth: - return None - - qn = connection.ops.quote_name - for f in opts.fields: - if f.rel and not f.null: - db_table = f.rel.to._meta.db_table - if db_table not in cache_tables_seen: - tables.append(qn(db_table)) - else: # The table was already seen, so give it a table alias. - new_prefix = '%s%s' % (db_table, len(cache_tables_seen)) - tables.append('%s %s' % (qn(db_table), qn(new_prefix))) - db_table = new_prefix - cache_tables_seen.append(db_table) - where.append('%s.%s = %s.%s' % \ - (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column))) - select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields]) - fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, cur_depth+1) - -def parse_lookup(kwarg_items, opts): - # Helper function that handles converting API kwargs - # (e.g. "name__exact": "tom") to SQL. - # Returns a tuple of (joins, where, params). 
- - # 'joins' is a sorted dictionary describing the tables that must be joined - # to complete the query. The dictionary is sorted because creation order - # is significant; it is a dictionary to ensure uniqueness of alias names. - # - # Each key-value pair follows the form - # alias: (table, join_type, condition) - # where - # alias is the AS alias for the joined table - # table is the actual table name to be joined - # join_type is the type of join (INNER JOIN, LEFT OUTER JOIN, etc) - # condition is the where-like statement over which narrows the join. - # alias will be derived from the lookup list name. - # - # At present, this method only every returns INNER JOINs; the option is - # there for others to implement custom Q()s, etc that return other join - # types. - joins, where, params = SortedDict(), [], [] - - for kwarg, value in kwarg_items: - path = kwarg.split(LOOKUP_SEPARATOR) - # Extract the last elements of the kwarg. - # The very-last is the lookup_type (equals, like, etc). - # The second-last is the table column on which the lookup_type is - # to be performed. If this name is 'pk', it will be substituted with - # the name of the primary key. - # If there is only one part, or the last part is not a query - # term, assume that the query is an __exact - lookup_type = path.pop() - if lookup_type == 'pk': - lookup_type = 'exact' - path.append(None) - elif len(path) == 0 or lookup_type not in QUERY_TERMS: - path.append(lookup_type) - lookup_type = 'exact' - - if len(path) < 1: - raise TypeError, "Cannot parse keyword query %r" % kwarg - - if value is None: - # Interpret '__exact=None' as the sql '= NULL'; otherwise, reject - # all uses of None as a query value. 
- if lookup_type != 'exact': - raise ValueError, "Cannot use None as a query value" - elif callable(value): - value = value() - - joins2, where2, params2 = lookup_inner(path, lookup_type, value, opts, opts.db_table, None) - joins.update(joins2) - where.extend(where2) - params.extend(params2) - return joins, where, params - -class FieldFound(Exception): - "Exception used to short circuit field-finding operations." - pass - -def find_field(name, field_list, related_query): - """ - Finds a field with a specific name in a list of field instances. - Returns None if there are no matches, or several matches. - """ - if related_query: - matches = [f for f in field_list if f.field.related_query_name() == name] - else: - matches = [f for f in field_list if f.name == name] - if len(matches) != 1: - return None - return matches[0] - -def field_choices(field_list, related_query): - if related_query: - choices = [f.field.related_query_name() for f in field_list] - else: - choices = [f.name for f in field_list] - return choices - -def lookup_inner(path, lookup_type, value, opts, table, column): - qn = connection.ops.quote_name - joins, where, params = SortedDict(), [], [] - current_opts = opts - current_table = table - current_column = column - intermediate_table = None - join_required = False - - name = path.pop(0) - # Has the primary key been requested? If so, expand it out - # to be the name of the current class' primary key - if name is None or name == 'pk': - name = current_opts.pk.name - - # Try to find the name in the fields associated with the current class - try: - # Does the name belong to a defined many-to-many field? - field = find_field(name, current_opts.many_to_many, False) - if field: - new_table = current_table + '__' + name - new_opts = field.rel.to._meta - new_column = new_opts.pk.column - - # Need to create an intermediate table join over the m2m table - # This process hijacks current_table/column to point to the - # intermediate table. 
- current_table = "m2m_" + new_table - intermediate_table = field.m2m_db_table() - join_column = field.m2m_reverse_name() - intermediate_column = field.m2m_column_name() - - raise FieldFound - - # Does the name belong to a reverse defined many-to-many field? - field = find_field(name, current_opts.get_all_related_many_to_many_objects(), True) - if field: - new_table = current_table + '__' + name - new_opts = field.opts - new_column = new_opts.pk.column - - # Need to create an intermediate table join over the m2m table. - # This process hijacks current_table/column to point to the - # intermediate table. - current_table = "m2m_" + new_table - intermediate_table = field.field.m2m_db_table() - join_column = field.field.m2m_column_name() - intermediate_column = field.field.m2m_reverse_name() - - raise FieldFound - - # Does the name belong to a one-to-many field? - field = find_field(name, current_opts.get_all_related_objects(), True) - if field: - new_table = table + '__' + name - new_opts = field.opts - new_column = field.field.column - join_column = opts.pk.column - - # 1-N fields MUST be joined, regardless of any other conditions. - join_required = True - - raise FieldFound - - # Does the name belong to a one-to-one, many-to-one, or regular field? - field = find_field(name, current_opts.fields, False) - if field: - if field.rel: # One-to-One/Many-to-one field - new_table = current_table + '__' + name - new_opts = field.rel.to._meta - new_column = new_opts.pk.column - join_column = field.column - raise FieldFound - elif path: - # For regular fields, if there are still items on the path, - # an error has been made. We munge "name" so that the error - # properly identifies the cause of the problem. - name += LOOKUP_SEPARATOR + path[0] - else: - raise FieldFound - - except FieldFound: # Match found, loop has been shortcut. - pass - else: # No match found. 
- choices = field_choices(current_opts.many_to_many, False) + \ - field_choices(current_opts.get_all_related_many_to_many_objects(), True) + \ - field_choices(current_opts.get_all_related_objects(), True) + \ - field_choices(current_opts.fields, False) - raise TypeError, "Cannot resolve keyword '%s' into field. Choices are: %s" % (name, ", ".join(choices)) - - # Check whether an intermediate join is required between current_table - # and new_table. - if intermediate_table: - joins[qn(current_table)] = ( - qn(intermediate_table), "LEFT OUTER JOIN", - "%s.%s = %s.%s" % (qn(table), qn(current_opts.pk.column), qn(current_table), qn(intermediate_column)) - ) - - if path: - # There are elements left in the path. More joins are required. - if len(path) == 1 and path[0] in (new_opts.pk.name, None) \ - and lookup_type in ('exact', 'isnull') and not join_required: - # If the next and final name query is for a primary key, - # and the search is for isnull/exact, then the current - # (for N-1) or intermediate (for N-N) table can be used - # for the search. No need to join an extra table just - # to check the primary key. - new_table = current_table - else: - # There are 1 or more name queries pending, and we have ruled out - # any shortcuts; therefore, a join is required. - joins[qn(new_table)] = ( - qn(new_opts.db_table), "INNER JOIN", - "%s.%s = %s.%s" % (qn(current_table), qn(join_column), qn(new_table), qn(new_column)) - ) - # If we have made the join, we don't need to tell subsequent - # recursive calls about the column name we joined on. - join_column = None - - # There are name queries remaining. Recurse deeper. - joins2, where2, params2 = lookup_inner(path, lookup_type, value, new_opts, new_table, join_column) - - joins.update(joins2) - where.extend(where2) - params.extend(params2) - else: - # No elements left in path. Current element is the element on which - # the search is being performed. 
- db_type = None - - if join_required: - # Last query term is a RelatedObject - if field.field.rel.multiple: - # RelatedObject is from a 1-N relation. - # Join is required; query operates on joined table. - column = new_opts.pk.name - joins[qn(new_table)] = ( - qn(new_opts.db_table), "INNER JOIN", - "%s.%s = %s.%s" % (qn(current_table), qn(join_column), qn(new_table), qn(new_column)) - ) - current_table = new_table - else: - # RelatedObject is from a 1-1 relation, - # No need to join; get the pk value from the related object, - # and compare using that. - column = current_opts.pk.name - elif intermediate_table: - # Last query term is a related object from an N-N relation. - # Join from intermediate table is sufficient. - column = join_column - elif name == current_opts.pk.name and lookup_type in ('exact', 'isnull') and current_column: - # Last query term is for a primary key. If previous iterations - # introduced a current/intermediate table that can be used to - # optimize the query, then use that table and column name. - column = current_column - else: - # Last query term was a normal field. - column = field.column - db_type = field.db_type() - - where.append(get_where_clause(lookup_type, current_table + '.', column, value, db_type)) - params.extend(field.get_db_prep_lookup(lookup_type, value)) - - return joins, where, params - def delete_objects(seen_objs): - "Iterate through a list of seen classes, and remove any instances that are referred to" - qn = connection.ops.quote_name + """ + Iterate through a list of seen classes, and remove any instances that are + referred to. 
+ """ ordered_classes = seen_objs.keys() ordered_classes.reverse() - cursor = connection.cursor() - for cls in ordered_classes: seen_objs[cls] = seen_objs[cls].items() seen_objs[cls].sort() # Pre notify all instances to be deleted for pk_val, instance in seen_objs[cls]: - dispatcher.send(signal=signals.pre_delete, sender=cls, instance=instance) + dispatcher.send(signal=signals.pre_delete, sender=cls, + instance=instance) pk_list = [pk for pk,instance in seen_objs[cls]] - for related in cls._meta.get_all_related_many_to_many_objects(): - if not isinstance(related.field, generic.GenericRelation): - for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ - (qn(related.field.m2m_db_table()), - qn(related.field.m2m_reverse_name()), - ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), - pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) - for f in cls._meta.many_to_many: - if isinstance(f, generic.GenericRelation): - from django.contrib.contenttypes.models import ContentType - query_extra = 'AND %s=%%s' % f.rel.to._meta.get_field(f.content_type_field_name).column - args_extra = [ContentType.objects.get_for_model(cls).id] - else: - query_extra = '' - args_extra = [] - for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - cursor.execute(("DELETE FROM %s WHERE %s IN (%s)" % \ - (qn(f.m2m_db_table()), qn(f.m2m_column_name()), - ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]]))) + query_extra, - pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE] + args_extra) + del_query = sql.DeleteQuery(cls, connection) + del_query.delete_batch_related(pk_list) + + update_query = sql.UpdateQuery(cls, connection) for field in cls._meta.fields: if field.rel and field.null and field.rel.to in seen_objs: - for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - cursor.execute("UPDATE %s SET %s=NULL WHERE %s IN (%s)" % \ - (qn(cls._meta.db_table), qn(field.column), 
qn(cls._meta.pk.column), - ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), - pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) + update_query.clear_related(field.column, pk_list) # Now delete the actual data for cls in ordered_classes: seen_objs[cls].reverse() pk_list = [pk for pk,instance in seen_objs[cls]] - for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ - (qn(cls._meta.db_table), qn(cls._meta.pk.column), - ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), - pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) + del_query = sql.DeleteQuery(cls, connection) + del_query.delete_batch(pk_list) - # Last cleanup; set NULLs where there once was a reference to the object, - # NULL the primary key of the found objects, and perform post-notification. + # Last cleanup; set NULLs where there once was a reference to the + # object, NULL the primary key of the found objects, and perform + # post-notification. for pk_val, instance in seen_objs[cls]: for field in cls._meta.fields: if field.rel and field.null and field.rel.to in seen_objs: setattr(instance, field.attname, None) setattr(instance, cls._meta.pk.attname, None) - dispatcher.send(signal=signals.post_delete, sender=cls, instance=instance) + dispatcher.send(signal=signals.post_delete, sender=cls, + instance=instance) transaction.commit_unless_managed() + diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py new file mode 100644 index 0000000000..274a2b7835 --- /dev/null +++ b/django/db/models/query_utils.py @@ -0,0 +1,52 @@ +""" +Various data structures used in query construction. + +Factored out from django.db.models.query so that they can also be used by other +modules without getting into circular import difficulties. +""" +from django.utils import tree + +class EmptyResultSet(Exception): + """ + Raised when a QuerySet cannot contain any data. 
+ """ + pass + +class Q(tree.Node): + """ + Encapsulates filters as objects that can then be combined logically (using + & and |). + """ + # Connection types + AND = 'AND' + OR = 'OR' + default = AND + + def __init__(self, *args, **kwargs): + if args and kwargs: + raise TypeError('Use positional *or* kwargs; not both!') + nodes = list(args) + kwargs.items() + super(Q, self).__init__(children=nodes) + + def _combine(self, other, conn): + if not isinstance(other, Q): + raise TypeError(other) + self.add(other, conn) + return self + + def __or__(self, other): + return self._combine(other, self.OR) + + def __and__(self, other): + return self._combine(other, self.AND) + +class QNot(Q): + """ + Encapsulates the negation of a Q object. + """ + def __init__(self, q): + """Creates the negation of the Q object passed in.""" + super(QNot, self).__init__() + self.add(q, self.AND) + self.negate() + diff --git a/django/db/models/sql/__init__.py b/django/db/models/sql/__init__.py new file mode 100644 index 0000000000..ec40c03283 --- /dev/null +++ b/django/db/models/sql/__init__.py @@ -0,0 +1,5 @@ +from query import * +from where import AND, OR + +__all__ = ['Query', 'AND', 'OR'] + diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py new file mode 100644 index 0000000000..5b01abbb87 --- /dev/null +++ b/django/db/models/sql/datastructures.py @@ -0,0 +1,59 @@ +""" +Useful auxilliary data structures for query construction. Not useful outside +the SQL domain. +""" + +class EmptyResultSet(Exception): + pass + +class Aggregate(object): + """ + Base class for all aggregate-related classes (min, max, avg, count, sum). + """ + def relabel_aliases(self, change_map): + """ + Relabel the column alias, if necessary. Must be implemented by + subclasses. + """ + raise NotImplementedError + + def as_sql(self, quote_func=None): + """ + Returns the SQL string fragment for this object. + + The quote_func function is used to quote the column components. 
If + None, it defaults to doing nothing. + + Must be implemented by subclasses. + """ + raise NotImplementedError + +class Count(Aggregate): + """ + Perform a count on the given column. + """ + def __init__(self, col=None, distinct=False): + """ + Set the column to count on (defaults to '*') and set whether the count + should be distinct or not. + """ + self.col = col and col or '*' + self.distinct = distinct + + def relabel_aliases(self, change_map): + c = self.col + if isinstance(c, (list, tuple)): + self.col = (change_map.get(c[0], c[0]), c[1]) + + def as_sql(self, quote_func=None): + if not quote_func: + quote_func = lambda x: x + if isinstance(self.col, (list, tuple)): + col = '%s.%s' % tuple([quote_func(c) for c in self.col]) + else: + col = self.col + if self.distinct: + return 'COUNT(DISTINCT(%s))' % col + else: + return 'COUNT(%s)' % col + diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py new file mode 100644 index 0000000000..089bd75774 --- /dev/null +++ b/django/db/models/sql/query.py @@ -0,0 +1,880 @@ +""" +Create SQL statements for QuerySets. + +The code in here encapsulates all of the SQL construction so that QuerySets +themselves do not have to (and could be backed by things other than SQL +databases). The abstraction barrier only works one way: this module has to know +all about the internals of models in order to get the information it needs. +""" + +import copy + +from django.utils import tree +from django.db.models.sql.where import WhereNode, AND +from django.db.models.sql.datastructures import Count +from django.db.models.fields import FieldDoesNotExist +from django.contrib.contenttypes import generic +from datastructures import EmptyResultSet +from utils import handle_legacy_orderlist + +try: + reversed +except NameError: + from django.utils.itercompat import reversed # For python 2.3. + +# Valid query types (a dictionary is used for speedy lookups). 
+QUERY_TERMS = dict([(x, None) for x in (
+    'exact', 'iexact', 'contains', 'icontains', 'gt', 'gte', 'lt', 'lte', 'in',
+    'startswith', 'istartswith', 'endswith', 'iendswith', 'range', 'year',
+    'month', 'day', 'isnull', 'search', 'regex', 'iregex',
+    )])
+
+# Size of each "chunk" for get_iterator calls.
+# Larger values are slightly faster at the expense of more storage space.
+GET_ITERATOR_CHUNK_SIZE = 100
+
+# Separator used to split filter strings apart.
+LOOKUP_SEP = '__'
+
+# Constants to make looking up tuple values clearer.
+# Join lists
+TABLE_NAME = 0
+RHS_ALIAS = 1
+JOIN_TYPE = 2
+LHS_ALIAS = 3
+LHS_JOIN_COL = 4
+RHS_JOIN_COL = 5
+# Alias maps lists
+ALIAS_TABLE = 0
+ALIAS_REFCOUNT = 1
+ALIAS_JOIN = 2
+
+# How many results to expect from a cursor.execute call
+MULTI = 'multi'
+SINGLE = 'single'
+NONE = None
+
+class Query(object):
+    """
+    A single SQL query.
+    """
+    # SQL join types. These are part of the class because their string forms
+    # vary from database to database and can be customised by a subclass.
+    INNER = 'INNER JOIN'
+    LOUTER = 'LEFT OUTER JOIN'
+
+    alias_prefix = 'T'
+
+    def __init__(self, model, connection):
+        self.model = model
+        self.connection = connection
+        self.alias_map = {} # Maps alias to table name
+        self.table_map = {} # Maps table names to list of aliases.
+        self.join_map = {} # Maps join_tuple to list of aliases.
+        self.rev_join_map = {} # Reverse of join_map.
+
+        # SQL-related attributes
+        self.select = []
+        self.tables = [] # Aliases in the order they are created.
+        self.where = WhereNode(self)
+        self.having = []
+        self.group_by = []
+        self.order_by = []
+        self.low_mark, self.high_mark = 0, None # Used for offset/limit
+        self.distinct = False
+        self.select_related = False
+        self.max_depth = 0
+
+        # These are for extensions. The contents are more or less appended
+        # verbatim to the appropriate clause.
+        self.extra_select = {} # Maps col_alias -> col_sql.
+ self.extra_tables = [] + self.extra_where = [] + self.extra_params = [] + + def __str__(self): + """ + Returns the query as a string of SQL with the parameter values + substituted in. + + Parameter values won't necessarily be quoted correctly, since that is + done by the database interface at execution time. + """ + sql, params = self.as_sql() + return sql % params + + def clone(self, **kwargs): + """ + Creates a copy of the current instance. The 'kwargs' parameter can be + used by clients to update attributes after copying has taken place. + """ + obj = self.__class__(self.model, self.connection) + obj.table_map = self.table_map.copy() + obj.alias_map = copy.deepcopy(self.alias_map) + obj.join_map = copy.deepcopy(self.join_map) + obj.rev_join_map = copy.deepcopy(self.rev_join_map) + obj.select = self.select[:] + obj.tables = self.tables[:] + obj.where = copy.deepcopy(self.where) + obj.having = self.having[:] + obj.group_by = self.group_by[:] + obj.order_by = self.order_by[:] + obj.low_mark, obj.high_mark = self.low_mark, self.high_mark + obj.distinct = self.distinct + obj.select_related = self.select_related + obj.max_depth = self.max_depth + obj.extra_select = self.extra_select.copy() + obj.extra_tables = self.extra_tables[:] + obj.extra_where = self.extra_where[:] + obj.extra_params = self.extra_params[:] + obj.__dict__.update(kwargs) + return obj + + def results_iter(self): + """ + Returns an iterator over the results from executing this query. + """ + fields = self.model._meta.fields + resolve_columns = hasattr(self, 'resolve_columns') + for rows in self.execute_sql(MULTI): + for row in rows: + if resolve_columns: + row = self.resolve_columns(row, fields) + yield row + + def get_count(self): + """ + Performs a COUNT() or COUNT(DISTINCT()) query, as appropriate, using + the current filter constraints. 
+ """ + counter = self.clone() + counter.clear_ordering() + counter.clear_limits() + counter.select_related = False + counter.add_count_column() + data = counter.execute_sql(SINGLE) + if not data: + return 0 + number = data[0] + + # Apply offset and limit constraints manually, since using LIMIT/OFFSET + # in SQL doesn't change the COUNT output. + number = max(0, number - self.low_mark) + if self.high_mark: + number = min(number, self.high_mark - self.low_mark) + + return number + + def as_sql(self, with_limits=True): + """ + Creates the SQL for this query. Returns the SQL string and list of + parameters. + + If 'with_limits' is False, any limit/offset information is not included + in the query. + """ + self.pre_sql_setup() + result = ['SELECT'] + if self.distinct: + result.append('DISTINCT') + out_cols = self.get_columns() + result.append(', '.join(out_cols)) + + result.append('FROM') + for alias in self.tables: + if not self.alias_map[alias][ALIAS_REFCOUNT]: + continue + name, alias, join_type, lhs, lhs_col, col = \ + self.alias_map[alias][ALIAS_JOIN] + alias_str = (alias != name and ' AS %s' % alias or '') + if join_type: + result.append('%s %s%s ON (%s.%s = %s.%s)' + % (join_type, name, alias_str, lhs, lhs_col, alias, + col)) + else: + result.append('%s%s' % (name, alias_str)) + result.extend(self.extra_tables) + + where, params = self.where.as_sql() + if where: + result.append('WHERE %s' % where) + result.append(' AND'.join(self.extra_where)) + + ordering = self.get_ordering() + if ordering: + result.append('ORDER BY %s' % ', '.join(ordering)) + + if with_limits: + if self.high_mark: + result.append('LIMIT %d' % (self.high_mark - self.low_mark)) + if self.low_mark: + assert self.high_mark, "OFFSET not allowed without LIMIT." 
+ result.append('OFFSET %d' % self.low_mark) + + params.extend(self.extra_params) + return ' '.join(result), tuple(params) + + def combine(self, rhs, connection): + """ + Merge the 'rhs' query into the current one (with any 'rhs' effects + being applied *after* (that is, "to the right of") anything in the + current query. 'rhs' is not modified during a call to this function. + + The 'connection' parameter describes how to connect filters from the + 'rhs' query. + """ + assert self.model == rhs.model, \ + "Cannot combine queries on two different base models." + assert self.can_filter(), \ + "Cannot combine queries once a slice has been taken." + assert self.distinct == rhs.distinct, \ + "Cannot combine a unique query with a non-unique query." + + # Work out how to relabel the rhs aliases, if necessary. + change_map = {} + used = {} + for alias in rhs.tables: + promote = (rhs.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] == + self.LOUTER) + new_alias = self.join(rhs.rev_join_map[alias], exclusions=used, + promote=promote) + used[new_alias] = None + change_map[alias] = new_alias + + # Now relabel a copy of the rhs where-clause and add it to the current + # one. + if rhs.where: + w = copy.deepcopy(rhs.where) + w.relabel_aliases(change_map) + if not self.where: + # Since 'self' matches everything, add an explicit "include + # everything" (pk is not NULL) where-constraint so that + # connections between the where clauses won't exclude valid + # results. + alias = self.join((None, self.model._meta.db_table, None, None)) + pk = self.model._meta.pk + self.where.add((alias, pk.column, pk, 'isnull', False), AND) + elif self.where: + # rhs has an empty where clause. Make it match everything (see + # above for reasoning). 
+ w = WhereNode() + alias = self.join((None, self.model._meta.db_table, None, None)) + pk = self.model._meta.pk + w.add((alias, pk.column, pk, 'isnull', False), AND) + else: + w = WhereNode() + self.where.add(w, connection) + + # Selection columns and extra extensions are those provided by 'rhs'. + self.select = [] + for col in rhs.select: + if isinstance(col, (list, tuple)): + self.select.append((change_map.get(col[0], col[0]), col[1])) + else: + item = copy.deepcopy(col) + item.relabel_aliases(change_map) + self.select.append(item) + self.extra_select = rhs.extra_select.copy() + self.extra_tables = rhs.extra_tables[:] + self.extra_where = rhs.extra_where[:] + self.extra_params = rhs.extra_params[:] + + # Ordering uses the 'rhs' ordering, unless it has none, in which case + # the current ordering is used. + self.order_by = rhs.order_by and rhs.order_by[:] or self.order_by + + def pre_sql_setup(self): + """ + Does any necessary class setup prior to producing SQL. This is for + things that can't necessarily be done in __init__. + """ + if not self.tables: + self.join((None, self.model._meta.db_table, None, None)) + + def get_columns(self): + """ + Return the list of columns to use in the select statement. If no + columns have been specified, returns all columns relating to fields in + the model. + """ + qn = self.connection.ops.quote_name + result = [] + if self.select: + for col in self.select: + if isinstance(col, (list, tuple)): + result.append('%s.%s' % (qn(col[0]), qn(col[1]))) + else: + result.append(col.as_sql()) + else: + table_alias = self.tables[0] + result = ['%s.%s' % (table_alias, qn(f.column)) + for f in self.model._meta.fields] + + # We sort extra_select so that the result columns are in a well-defined + # order (and thus QuerySet.iterator can extract them correctly). 
+ extra_select = self.extra_select.items() + extra_select.sort() + result.extend(['(%s) AS %s' % (col, alias) + for alias, col in extra_select]) + return result + + def get_ordering(self): + """ + Returns a tuple representing the SQL elements in the "order by" clause. + """ + ordering = self.order_by or self.model._meta.ordering + qn = self.connection.ops.quote_name + opts = self.model._meta + result = [] + for field in handle_legacy_orderlist(ordering): + if field == '?': + result.append(self.connection.ops.random_function_sql()) + continue + if field[0] == '-': + col = field[1:] + order = 'DESC' + else: + col = field + order = 'ASC' + if '.' in col: + table, col = col.split('.', 1) + table = '%s.' % self.table_alias[table] + elif col not in self.extra_select: + # Use the root model's database table as the referenced table. + table = '%s.' % self.tables[0] + else: + table = '' + result.append('%s%s %s' % (table, + qn(orderfield_to_column(col, opts)), order)) + return result + + def table_alias(self, table_name, create=False): + """ + Returns a table alias for the given table_name and whether this is a + new alias or not. + + If 'create' is true, a new alias is always created. Otherwise, the + most recently created alias for the table (if one exists) is reused. + """ + if not create and table_name in self.table_map: + alias = self.table_map[table_name][-1] + self.alias_map[alias][ALIAS_REFCOUNT] += 1 + return alias, False + + # Create a new alias for this table. + if table_name not in self.table_map: + # The first occurence of a table uses the table name directly. + alias = table_name + else: + alias = '%s%d' % (self.alias_prefix, len(self.alias_map) + 1) + self.alias_map[alias] = [table_name, 1, None] + self.table_map.setdefault(table_name, []).append(alias) + self.tables.append(alias) + return alias, True + + def ref_alias(self, alias): + """ Increases the reference count for this alias. 
""" + self.alias_map[alias][ALIAS_REFCOUNT] += 1 + + def unref_alias(self, alias): + """ Decreases the reference count for this alias. """ + self.alias_map[alias][ALIAS_REFCOUNT] -= 1 + + def join(self, (lhs, table, lhs_col, col), always_create=False, + exclusions=(), promote=False): + """ + Returns an alias for a join between 'table' and 'lhs' on the given + columns, either reusing an existing alias for that join or creating a + new one. + + 'lhs' is either an existing table alias or a table name. If + 'always_create' is True, a new alias is always created, regardless of + whether one already exists or not. + + If 'exclusions' is specified, it is something satisfying the container + protocol ("foo in exclusions" must work) and specifies a list of + aliases that should not be returned, even if they satisfy the join. + + If 'promote' is True, the join type for the alias will be LOUTER (if + the alias previously existed, the join type will be promoted from INNER + to LOUTER, if necessary). + """ + if lhs not in self.alias_map: + lhs_table = lhs + is_table = (lhs is not None) + else: + lhs_table = self.alias_map[lhs][ALIAS_TABLE] + is_table = False + t_ident = (lhs_table, table, lhs_col, col) + aliases = self.join_map.get(t_ident) + if aliases and not always_create: + for alias in aliases: + if alias not in exclusions: + self.ref_alias(alias) + if promote: + self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = \ + self.LOUTER + return alias + # If we get to here (no non-excluded alias exists), we'll fall + # through to creating a new alias. + + # No reuse is possible, so we need a new alias. + assert not is_table, \ + "Must pass in lhs alias when creating a new join." + alias, _ = self.table_alias(table, True) + join_type = promote and self.LOUTER or self.INNER + join = [table, alias, join_type, lhs, lhs_col, col] + if not lhs: + # Not all tables need to be joined to anything. No join type + # means the later columns are ignored. 
+            join[JOIN_TYPE] = None
+        self.alias_map[alias][ALIAS_JOIN] = join
+        self.join_map.setdefault(t_ident, []).append(alias)
+        self.rev_join_map[alias] = t_ident
+        return alias
+
+    def fill_table_cache(self, opts=None, root_alias=None, cur_depth=0,
+            used=None):
+        """
+        Fill in the information needed for a select_related query.
+        """
+        if self.max_depth and cur_depth > self.max_depth:
+            # We've recursed too deeply; bail out.
+            return
+        if not opts:
+            opts = self.model._meta
+            root_alias = self.tables[0]
+            self.select.extend([(root_alias, f) for f in opts.fields])
+        if not used:
+            used = []
+
+        for f in opts.fields:
+            if not f.rel or f.null:
+                continue
+            table = f.rel.to._meta.db_table
+            alias = self.join((root_alias, table, f.column,
+                    f.rel.get_related_field().column), exclusions=used)
+            used.append(alias)
+            self.select.extend([(table, f2.column)
+                    for f2 in f.rel.to._meta.fields])
+            self.fill_table_cache(f.rel.to._meta, alias, cur_depth + 1, used)
+
+    def add_filter(self, filter_expr, connection=AND, negate=False):
+        """
+        Add a single filter to the query.
+        """
+        arg, value = filter_expr
+        parts = arg.split(LOOKUP_SEP)
+        if not parts:
+            raise TypeError("Cannot parse keyword query %r" % arg)
+
+        # Work out the lookup type and remove it from 'parts', if necessary.
+        if len(parts) == 1 or parts[-1] not in QUERY_TERMS:
+            lookup_type = 'exact'
+        else:
+            lookup_type = parts.pop()
+
+        # Interpret '__exact=None' as the sql '= NULL'; otherwise, reject all
+        # uses of None as a query value.
+        # FIXME: Weren't we going to change this so that '__exact=None' was the
+        # same as '__isnull=True'? Need to check the conclusion of the mailing
+        # list thread.
+        if value is None and lookup_type != 'exact':
+            raise ValueError("Cannot use None as a query value")
+        elif callable(value):
+            value = value()
+
+        opts = self.model._meta
+        alias = self.join((None, opts.db_table, None, None))
+        dupe_multis = (connection == AND)
+        last = None
+
+        # FIXME: Using enumerate() here is expensive. We only need 'i' to
+        # check we aren't joining against a non-joinable field. Find a
+        # better way to do this!
+        for i, name in enumerate(parts):
+            joins, opts, orig_field, target_field, target_col = \
+                    self.get_next_join(name, opts, alias, dupe_multis)
+            if name == 'pk':
+                name = target_field.name
+            if joins is not None:
+                last = joins
+                alias = joins[-1]
+            else:
+                # Normal field lookup must be the last field in the filter.
+                if i != len(parts) - 1:
+                    raise TypeError("Joins on field %r not permitted."
+                            % name)
+
+        name = target_col or name
+
+        if target_field is opts.pk and last:
+            # An optimization: if the final join is against a primary key,
+            # we can go back one step in the join chain and compare against
+            # the lhs of the join instead. The result (potentially) involves
+            # one less table join.
+            self.unref_alias(alias)
+            join = self.alias_map[last[-1]][ALIAS_JOIN]
+            alias = join[LHS_ALIAS]
+            name = join[LHS_JOIN_COL]
+
+        if (lookup_type == 'isnull' and value is True):
+            # If the comparison is against NULL, we need to use a left outer
+            # join when connecting to the previous model. We make that
+            # adjustment here. We don't do this unless needed because it's less
+            # efficient at the database level.
+            self.alias_map[joins[0]][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER
+
+        self.where.add([alias, name, orig_field, lookup_type, value],
+                connection)
+        if negate:
+            self.alias_map[last[0]][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER
+            self.where.negate()
+
+    def add_q(self, q_object):
+        """
+        Adds a Q-object to the current filter.
+
+        Can also be used to add anything that has an 'add_to_query()' method.
+        """
+        if hasattr(q_object, 'add_to_query'):
+            # Complex custom objects are responsible for adding themselves.
+ q_object.add_to_query(self) + return + + for child in q_object.children: + if isinstance(child, tree.Node): + self.where.start_subtree(q_object.connection) + self.add_q(child) + self.where.end_subtree() + else: + self.add_filter(child, q_object.connection, q_object.negated) + + def get_next_join(self, name, opts, root_alias, dupe_multis): + """ + Compute the necessary table joins for the field called 'name'. 'opts' + is the Options class for the current model (which gives the table we + are joining to), root_alias is the alias for the table we are joining + to. If dupe_multis is True, any many-to-many or many-to-one joins will + always create a new alias (necessary for disjunctive filters). + + Returns a list of aliases involved in the join, the next value for + 'opts' and the field class that was matched. For a non-joining field, + the first value (join alias) is None. + """ + if name == 'pk': + name = opts.pk.name + + field = find_field(name, opts.many_to_many, False) + if field: + # Many-to-many field defined on the current model. + remote_opts = field.rel.to._meta + int_alias = self.join((root_alias, field.m2m_db_table(), + opts.pk.column, field.m2m_column_name()), dupe_multis) + far_alias = self.join((int_alias, remote_opts.db_table, + field.m2m_reverse_name(), remote_opts.pk.column), + dupe_multis) + return ([int_alias, far_alias], remote_opts, field, remote_opts.pk, + None) + + field = find_field(name, opts.get_all_related_many_to_many_objects(), + True) + if field: + # Many-to-many field defined on the target model. + remote_opts = field.opts + field = field.field + int_alias = self.join((root_alias, field.m2m_db_table(), + opts.pk.column, field.m2m_reverse_name()), dupe_multis) + far_alias = self.join((int_alias, remote_opts.db_table, + field.m2m_column_name(), remote_opts.pk.column), + dupe_multis) + # XXX: Why is the final component able to be None here? 
+ return ([int_alias, far_alias], remote_opts, field, remote_opts.pk, + None) + + field = find_field(name, opts.get_all_related_objects(), True) + if field: + # One-to-many field (ForeignKey defined on the target model) + remote_opts = field.opts + field = field.field + local_field = opts.get_field(field.rel.field_name) + alias = self.join((root_alias, remote_opts.db_table, + local_field.column, field.column), dupe_multis) + return ([alias], remote_opts, field, field, remote_opts.pk.column) + + + field = find_field(name, opts.fields, False) + if not field: + raise TypeError, \ + ("Cannot resolve keyword '%s' into field. Choices are: %s" + % (name, ", ".join(get_legal_fields(opts)))) + + if field.rel: + # One-to-one or many-to-one field + remote_opts = field.rel.to._meta + alias = self.join((root_alias, remote_opts.db_table, field.column, + field.rel.field_name)) + target = remote_opts.get_field(field.rel.field_name) + return [alias], remote_opts, field, target, target.column + + # Only remaining possibility is a normal (direct lookup) field. No + # join is required. + return None, opts, field, field, None + + def set_limits(self, low=None, high=None): + """ + Adjusts the limits on the rows retrieved. We use low/high to set these, + as it makes it more Pythonic to read and write. When the SQL query is + created, they are converted to the appropriate offset and limit values. + + Any limits passed in here are applied relative to the existing + constraints. So low is added to the current low value and both will be + clamped to any existing high value. + """ + if high: + # None (high_mark's default) is less than any number, so this works. + self.high_mark = max(self.high_mark, high) + if low: + self.low_mark = max(self.high_mark, self.low_mark + low) + + def clear_limits(self): + """ + Clears any existing limits. 
+ """ + self.low_mark, self.high_mark = 0, None + + def add_ordering(self, *ordering): + """ + Adds items from the 'ordering' sequence to the query's "order by" + clause. + """ + self.order_by.extend(ordering) + + def clear_ordering(self): + """ + Removes any ordering settings. + """ + self.order_by = [] + + def can_filter(self): + """ + Returns True if adding filters to this instance is still possible. + + Typically, this means no limits or offsets have been put on the results. + """ + return not (self.low_mark or self.high_mark) + + def add_count_column(self): + """ + Converts the query to do count(*) or count(distinct(pk)) in order to + get its size. + """ + # TODO: When group_by support is added, this needs to be adjusted so + # that it doesn't totally overwrite the select list. + if not self.distinct: + select = Count() + # Distinct handling is now done in Count(), so don't do it at this + # level. + self.distinct = False + else: + select = Count((self.table_map[self.model._meta.db_table][0], + self.model._meta.pk.column), True) + self.select = [select] + self.extra_select = {} + + def execute_sql(self, result_type=MULTI): + """ + Run the query against the database and returns the result(s). The + return value is a single data item if result_type is SINGLE, or an + iterator over the results if the result_type is MULTI. + + result_type is either MULTI (use fetchmany() to retrieve all rows), + SINGLE (only retrieve a single row), or NONE (no results expected). + """ + try: + sql, params = self.as_sql() + except EmptyResultSet: + raise StopIteration + + cursor = self.connection.cursor() + cursor.execute(sql, params) + + if result_type == NONE: + return + + if result_type == SINGLE: + return cursor.fetchone() + + # The MULTI case. 
+ def it(): + while 1: + rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) + if not rows: + raise StopIteration + yield rows + return it() + +class DeleteQuery(Query): + """ + Delete queries are done through this class, since they are more constrained + than general queries. + """ + def as_sql(self): + """ + Creates the SQL for this query. Returns the SQL string and list of + parameters. + """ + assert len(self.tables) == 1, \ + "Can only delete from one table at a time." + result = ['DELETE FROM %s' % self.tables[0]] + where, params = self.where.as_sql() + result.append('WHERE %s' % where) + return ' '.join(result), tuple(params) + + def do_query(self, table, where): + self.tables = [table] + self.where = where + self.execute_sql(NONE) + + def delete_batch_related(self, pk_list): + """ + Set up and execute delete queries for all the objects related to the + primary key values in pk_list. To delete the objects themselves, use + the delete_batch() method. + + More than one physical query may be executed if there are a + lot of values in pk_list. 
+ """ + cls = self.model + for related in cls._meta.get_all_related_many_to_many_objects(): + if not isinstance(related.field, generic.GenericRelation): + for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): + where = WhereNode(self) + where.add((None, related.field.m2m_reverse_name(), None, + 'in', + pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]), + AND) + self.do_query(related.field.m2m_db_table(), where) + + for f in cls._meta.many_to_many: + w1 = WhereNode(self) + if isinstance(f, generic.GenericRelation): + from django.contrib.contenttypes.models import ContentType + field = f.rel.to._meta.get_field(f.content_type_field_name) + w1.add((None, field.column, field, 'exact', + ContentType.objects.get_for_model(cls).id), AND) + for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): + where = WhereNode(self) + where.add((None, f.m2m_column_name(), None, 'in', + pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), + AND) + if w1: + where.add(w1, AND) + self.do_query(f.m2m_db_table(), where) + + def delete_batch(self, pk_list): + """ + Set up and execute delete queries for all the objects in pk_list. This + should be called after delete_batch_related(), if necessary. + + More than one physical query may be executed if there are a + lot of values in pk_list. + """ + for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): + where = WhereNode(self) + field = self.model._meta.pk + where.add((None, field.column, field, 'in', + pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) + self.do_query(self.model._meta.db_table, where) + +class UpdateQuery(Query): + """ + Represents an "update" SQL query. + """ + def __init__(self, *args, **kwargs): + super(UpdateQuery, self).__init__(*args, **kwargs) + self.values = [] + + def as_sql(self): + """ + Creates the SQL for this query. Returns the SQL string and list of + parameters. + """ + assert len(self.tables) == 1, \ + "Can only update one table at a time." 
+ result = ['UPDATE %s' % self.tables[0]] + result.append('SET') + qn = self.connection.ops.quote_name + values = ['%s = %s' % (qn(v[0]), v[1]) for v in self.values] + result.append(', '.join(values)) + where, params = self.where.as_sql() + result.append('WHERE %s' % where) + return ' '.join(result), tuple(params) + + def do_query(self, table, values, where): + self.tables = [table] + self.values = values + self.where = where + self.execute_sql(NONE) + + def clear_related(self, related_field, pk_list): + """ + Set up and execute an update query that clears related entries for the + keys in pk_list. + + This is used by the QuerySet.delete_objects() method. + """ + for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): + where = WhereNode() + f = self.model._meta.pk + where.add((None, f, f.db_type(), 'in', + pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), + AND) + values = [(related_field.column, 'NULL')] + self.do_query(self.model._meta.db_table, values, where) + +def find_field(name, field_list, related_query): + """ + Finds a field with a specific name in a list of field instances. + Returns None if there are no matches, or several matches. + """ + if related_query: + matches = [f for f in field_list + if f.field.related_query_name() == name] + else: + matches = [f for f in field_list if f.name == name] + if len(matches) != 1: + return None + return matches[0] + +def field_choices(field_list, related_query): + """ + Returns the names of the field objects in field_list. Used to construct + readable error messages. + """ + if related_query: + return [f.field.related_query_name() for f in field_list] + else: + return [f.name for f in field_list] + +def get_legal_fields(opts): + """ + Returns a list of fields that are valid at this point in the query. Used in + error reporting. 
+ """ + return (field_choices(opts.many_to_many, False) + + field_choices( opts.get_all_related_many_to_many_objects(), True) + + field_choices(opts.get_all_related_objects(), True) + + field_choices(opts.fields, False)) + +def orderfield_to_column(name, opts): + """ + For a field name specified in an "order by" clause, returns the database + column name. If 'name' is not a field in the current model, it is returned + unchanged. + """ + try: + return opts.get_field(name, False).column + except FieldDoesNotExist: + return name + diff --git a/django/db/models/sql/utils.py b/django/db/models/sql/utils.py new file mode 100644 index 0000000000..6c6c32ab05 --- /dev/null +++ b/django/db/models/sql/utils.py @@ -0,0 +1,27 @@ +""" +Miscellaneous helper functions. +""" + +import warnings + +from django.utils.encoding import smart_unicode + +# Django currently supports two forms of ordering. +# Form 1 (deprecated) example: +# order_by=(('pub_date', 'DESC'), ('headline', 'ASC'), (None, 'RANDOM')) +# Form 2 (new-style) example: +# order_by=('-pub_date', 'headline', '?') +# Form 1 is deprecated and will no longer be supported for Django's first +# official release. The following code converts from Form 1 to Form 2. + +LEGACY_ORDERING_MAPPING = {'ASC': '_', 'DESC': '-_', 'RANDOM': '?'} + +def handle_legacy_orderlist(order_list): + if not order_list or isinstance(order_list[0], basestring): + return order_list + else: + new_order_list = [LEGACY_ORDERING_MAPPING[j.upper()].replace('_', smart_unicode(i)) for i, j in order_list] + warnings.warn("%r ordering syntax is deprecated. Use %r instead." + % (order_list, new_order_list), DeprecationWarning) + return new_order_list + diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py new file mode 100644 index 0000000000..7c27fc92b3 --- /dev/null +++ b/django/db/models/sql/where.py @@ -0,0 +1,151 @@ +""" +Code to manage the creation and SQL rendering of 'where' constraints. 
+""" +import datetime + +from django.utils import tree +from datastructures import EmptyResultSet + +# Connection types +AND = 'AND' +OR = 'OR' + +class WhereNode(tree.Node): + """ + Used to represent the SQL where-clause. + + The class is tied to the Query class that created it (in order to create + the corret SQL). + + The children in this tree are usually either Q-like objects or lists of + [table_alias, field_name, field_class, lookup_type, value]. However, a + child could also be any class with as_sql() and relabel_aliases() methods. + """ + default = AND + + def __init__(self, query=None, children=None, connection=None): + super(WhereNode, self).__init__(children, connection) + if query: + # XXX: Would be nice to use a weakref here, but it seems tricky to + # make it work. + self.query = query + + def __deepcopy__(self, memodict): + """ + Used by copy.deepcopy(). + """ + obj = super(WhereNode, self).__deepcopy__(memodict) + obj.query = self.query + return obj + + def as_sql(self, node=None): + """ + Returns the SQL version of the where clause and the value to be + substituted in. Returns None, None if this node is empty. + + If 'node' is provided, that is the root of the SQL generation + (generally not needed except by the internal implementation for + recursion). 
+ """ + if node is None: + node = self + if not node.children: + return None, [] + result = [] + result_params = [] + for child in node.children: + if hasattr(child, 'as_sql'): + sql, params = child.as_sql() + format = '(%s)' + elif isinstance(child, tree.Node): + sql, params = self.as_sql(child) + if child.negated: + format = 'NOT (%s)' + else: + format = '(%s)' + else: + sql = self.make_atom(child) + params = child[2].get_db_prep_lookup(child[3], child[4]) + format = '%s' + result.append(format % sql) + result_params.extend(params) + conn = ' %s ' % node.connection + return conn.join(result), result_params + + def make_atom(self, child): + """ + Turn a tuple (table_alias, field_name, field_class, lookup_type, value) + into valid SQL. + + Returns the string for the SQL fragment. The caller is responsible for + converting the child's value into an appropriate for for the parameters + list. + """ + table_alias, name, field, lookup_type, value = child + conn = self.query.connection + if table_alias: + lhs = '%s.%s' % (table_alias, conn.ops.quote_name(name)) + else: + lhs = conn.ops.quote_name(name) + field_sql = conn.ops.field_cast_sql(field.db_type()) % lhs + + if isinstance(value, datetime.datetime): + # FIXME datetime_cast_sql() should return '%s' by default. + cast_sql = conn.ops.datetime_cast_sql() or '%s' + else: + cast_sql = '%s' + + # FIXME: This is out of place. 
Move to a function like
+        # datetime_cast_sql()
+        if (lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith')
+            and conn.features.needs_upper_for_iops):
+            format = 'UPPER(%s) %s'
+        else:
+            format = '%s %s'
+
+        if lookup_type in conn.operators:
+            return format % (field_sql, conn.operators[lookup_type] % cast_sql)
+
+        if lookup_type == 'in':
+            if not value:
+                raise EmptyResultSet
+            return '%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value)))
+        elif lookup_type in ('range', 'year'):
+            return '%s BETWEEN %%s and %%s' % field_sql
+        elif lookup_type in ('month', 'day'):
+            return '%s = %%s' % conn.ops.date_extract_sql(lookup_type,
+                    field_sql)
+        elif lookup_type == 'isnull':
+            return '%s IS %sNULL' % (field_sql, (not value and 'NOT ' or ''))
+        elif lookup_type == 'search':
+            return conn.ops.fulltext_search_sql(field_sql)
+        elif lookup_type in ('regex', 'iregex'):
+            # FIXME: Factor this out into conn.ops
+            if settings.DATABASE_ENGINE == 'oracle':
+                if lookup_type == 'regex':
+                    match_option = 'c'
+                else:
+                    match_option = 'i'
+                return "REGEXP_LIKE(%s, %s, '%s')" % (field_sql, cast_sql,
+                        match_option)
+            else:
+                raise NotImplementedError
+
+        raise TypeError('Invalid lookup_type: %r' % lookup_type)
+
+    def relabel_aliases(self, change_map, node=None):
+        """
+        Relabels the alias values of any children. 'change_map' is a dictionary
+        mapping old (current) alias values to the new values.
+        """
+        if not node:
+            node = self
+        for child in node.children:
+            if hasattr(child, 'relabel_aliases'):
+                child.relabel_aliases(change_map)
+            elif isinstance(child, tree.Node):
+                self.relabel_aliases(change_map, child)
+            else:
+                val = child[0]
+                child[0] = change_map.get(val, val)
+