Mirror of https://github.com/django/django.git

[soc2009/multidb] Split SQL construction into two separate classes: the Query class, which stores data about the query being constructed, and a Compiler class, which generates the SQL.

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11759 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Alex Gaynor 2009-11-21 07:03:40 +00:00
parent e9e73c4b68
commit 4f40925785
34 changed files with 1231 additions and 1219 deletions
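The commit message above describes the split: the Query holds the data about the query being built, and a compiler paired with a connection produces and runs the SQL. A minimal sketch of how the two halves fit together, assuming a hypothetical Author model with a name field and the default connection alias (the calls shown all appear in the hunks below):

from django.db import DEFAULT_DB_ALIAS
from django.db.models import sql

# Author: hypothetical model with a 'name' field
query = sql.Query(Author)                               # holds query state; no connection involved
query.add_filter(('name__startswith', 'A'))             # pure data manipulation on the Query

compiler = query.get_compiler(using=DEFAULT_DB_ALIAS)   # pairs the Query with a connection
sql_text, params = compiler.as_sql()                    # SQL generation now lives on the compiler
for row in compiler.results_iter():                     # ...and so does execution
    print row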

TODO (new file, 43 lines)

@ -0,0 +1,43 @@
Django Multiple Database TODO List
==================================
Required for v1.2
~~~~~~~~~~~~~~~~~
* Finalize the sql.Query internals
* Clean up the use of db.backend.query_class()
* Verify it still works with GeoDjango
* Resolve internal uses of multidb interface
* Update database backend for session store to use Multidb
* Check default Site creation behavior
* Resolve the public facing UI issues around using multi-db
* Should we take the opportunity to modify DB backends to use fully qualified paths?
* Meta.using? Is it still required/desirable?
* syncdb
* Add --exclude/--include argument? (not sure this approach will work due to flush)
* Flush - which models are flushed?
* Fixture loading over multiple DBs
* Testing infrastructure
* Most tests don't need multidb. Some absolutely require it, but only to prove you
can write to a different db. Second DB could be a SQLite temp file. Need to have
test infrastructure to allow creation of the temp database.
* Cleanup of new API entry points
* validate() on a field
* name/purpose clash with Honza?
* any overlap with existing methods?
* Accessing _using in BaseModelFormSet.
Optional for v1.2
~~~~~~~~~~~~~~~~~
These are the next layer of UI. We can deliver for v1.2 without these if necessary.
* Technique for determining using() at runtime (by callback?)
* Sticky models
* Related objects
* saving and deleting
* default or an option
* Sample docs for how to do (see the using() sketch after this list):
* master/slave
* Sharding
* Test protection against cross-database joins.
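As a teaser for the master/slave docs item above, per-query routing with using() already works on this branch; a minimal sketch, assuming a hypothetical Author model and a 'slave' alias configured in settings:

# Reads go to the replica; writes stay on the 'default' alias.
recent = Author.objects.using('slave').filter(name__startswith='A')
total = Author.objects.using('slave').count()
Author.objects.create(name='Alex')    # still written via 'default'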


@ -98,7 +98,7 @@ class GeoQuery(sql.Query):
aliases.add(r)
col_aliases.add(col[1])
else:
result.append(col.as_sql(quote_func=qn))
result.append(col.as_sql(qn=qn))
if hasattr(col, 'alias'):
aliases.add(col.alias)
@ -112,7 +112,7 @@ class GeoQuery(sql.Query):
result.extend([
'%s%s' % (
self.get_extra_select_format(alias) % aggregate.as_sql(quote_func=qn),
self.get_extra_select_format(alias) % aggregate.as_sql(qn=qn, connection=self.connection),
alias is not None and ' AS %s' % alias or ''
)
for alias, aggregate in self.aggregate_select.items()


@ -78,7 +78,7 @@ class GeoWhereNode(WhereNode):
annotation = GeoAnnotation(field, value, where)
return super(WhereNode, self).add(((obj.alias, col, field.db_type()), lookup_type, annotation, params), connector)
def make_atom(self, child, qn):
def make_atom(self, child, qn, connection):
obj, lookup_type, value_annot, params = child
if isinstance(value_annot, GeoAnnotation):
@ -94,7 +94,7 @@ class GeoWhereNode(WhereNode):
else:
# If not a GeometryField, call the `make_atom` from the
# base class.
return super(GeoWhereNode, self).make_atom(child, qn)
return super(GeoWhereNode, self).make_atom(child, qn, connection)
@classmethod
def _check_geo_field(cls, opts, lookup):


@ -279,11 +279,11 @@ class RelatedGeoModelTest(unittest.TestCase):
def test14_collect(self):
"Testing the `collect` GeoQuerySet method and `Collect` aggregate."
# Reference query:
# SELECT AsText(ST_Collect("relatedapp_location"."point")) FROM "relatedapp_city" LEFT OUTER JOIN
# "relatedapp_location" ON ("relatedapp_city"."location_id" = "relatedapp_location"."id")
# SELECT AsText(ST_Collect("relatedapp_location"."point")) FROM "relatedapp_city" LEFT OUTER JOIN
# "relatedapp_location" ON ("relatedapp_city"."location_id" = "relatedapp_location"."id")
# WHERE "relatedapp_city"."state" = 'TX';
ref_geom = fromstr('MULTIPOINT(-97.516111 33.058333,-96.801611 32.782057,-95.363151 29.763374,-96.801611 32.782057)')
c1 = City.objects.filter(state='TX').collect(field_name='location__point')
c2 = City.objects.filter(state='TX').aggregate(Collect('location__point'))['location__point__collect']
@ -293,6 +293,7 @@ class RelatedGeoModelTest(unittest.TestCase):
self.assertEqual(4, len(coll))
self.assertEqual(ref_geom, coll)
# TODO: Related tests for KML, GML, and distance lookups.
def suite():


@ -5,21 +5,17 @@ from django.conf import settings
from django.core.management.base import NoArgsCommand
from django.core.management.color import no_style
from django.core.management.sql import custom_sql_for_model, emit_post_sync_signal
from django.db import connections, transaction, models
from django.db import connections, transaction, models, DEFAULT_DB_ALIAS
from django.utils.importlib import import_module
try:
set
except NameError:
from sets import Set as set # Python 2.3 fallback
class Command(NoArgsCommand):
option_list = NoArgsCommand.option_list + (
make_option('--noinput', action='store_false', dest='interactive', default=True,
help='Tells Django to NOT prompt the user for input of any kind.'),
make_option('--database', action='store', dest='database',
default='', help='Nominates a database to sync. Defaults to the '
'"default" database.'),
default=DEFAULT_DB_ALIAS, help='Nominates a database to sync. '
'Defaults to the "default" database.'),
)
help = "Create the database tables for all apps in INSTALLED_APPS whose tables haven't already been created."
@ -30,8 +26,6 @@ class Command(NoArgsCommand):
show_traceback = options.get('traceback', False)
self.style = no_style()
connection = connections[options["database"]]
# Import the 'management' module within each installed app, to register
# dispatcher events.
@ -52,6 +46,8 @@ class Command(NoArgsCommand):
if not msg.startswith('No module named') or 'management' not in msg:
raise
db = options['database']
connection = connections[db]
cursor = connection.cursor()
# Get a list of already installed *models* so that references work right.
@ -88,11 +84,11 @@ class Command(NoArgsCommand):
tables.append(connection.introspection.table_name_converter(model._meta.db_table))
transaction.commit_unless_managed()
transaction.commit_unless_managed(using=db)
# Send the post_syncdb signal, so individual apps can do whatever they need
# to do at this point.
emit_post_sync_signal(created_models, verbosity, interactive)
emit_post_sync_signal(created_models, verbosity, interactive, db)
# The connection may have been closed by a syncdb handler.
cursor = connection.cursor()
@ -103,7 +99,7 @@ class Command(NoArgsCommand):
app_name = app.__name__.split('.')[-2]
for model in models.get_models(app):
if model in created_models:
custom_sql = custom_sql_for_model(model, self.style)
custom_sql = custom_sql_for_model(model, self.style, connection)
if custom_sql:
if verbosity >= 1:
print "Installing custom SQL for %s.%s model" % (app_name, model._meta.object_name)
@ -116,9 +112,9 @@ class Command(NoArgsCommand):
if show_traceback:
import traceback
traceback.print_exc()
transaction.rollback_unless_managed()
transaction.rollback_unless_managed(using=db)
else:
transaction.commit_unless_managed()
transaction.commit_unless_managed(using=db)
else:
if verbosity >= 2:
print "No custom SQL for %s.%s model" % (app_name, model._meta.object_name)
@ -137,28 +133,9 @@ class Command(NoArgsCommand):
except Exception, e:
sys.stderr.write("Failed to install index for %s.%s model: %s\n" % \
(app_name, model._meta.object_name, e))
transaction.rollback_unless_managed()
transaction.rollback_unless_managed(using=db)
else:
if verbosity >= 2:
print "No custom SQL for %s.%s model" % (app_name, model._meta.object_name)
# Install SQL indicies for all newly created models
for app in models.get_apps():
app_name = app.__name__.split('.')[-2]
for model in models.get_models(app):
if model in created_models:
index_sql = connection.creation.sql_indexes_for_model(model, self.style)
if index_sql:
if verbosity >= 1:
print "Installing index for %s.%s model" % (app_name, model._meta.object_name)
try:
for sql in index_sql:
cursor.execute(sql)
except Exception, e:
sys.stderr.write("Failed to install index for %s.%s model: %s\n" % \
(app_name, model._meta.object_name, e))
transaction.rollback_unless_managed(using=db)
else:
transaction.commit_unless_managed(using=db)
transaction.commit_unless_managed(using=db)
from django.core.management import call_command
call_command('loaddata', 'initial_data', verbosity=verbosity, database=db)
from django.core.management import call_command
call_command('loaddata', 'initial_data', verbosity=verbosity, database=db)
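For reference, the reworked --database option lets a single non-default alias be synced on its own; a minimal usage sketch with a hypothetical 'legacy' alias:

from django.core.management import call_command

# Equivalent to: manage.py syncdb --database=legacy --noinput
call_command('syncdb', database='legacy', interactive=False)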


@ -18,12 +18,14 @@ except ImportError:
from django.db.backends import util
from django.utils import datetime_safe
from django.utils.importlib import import_module
class BaseDatabaseWrapper(local):
"""
Represents a database connection.
"""
ops = None
def __init__(self, settings_dict):
# `settings_dict` should be a dictionary containing keys such as
# DATABASE_NAME, DATABASE_USER, etc. It's called `settings_dict`
@ -114,8 +116,9 @@ class BaseDatabaseOperations(object):
a backend performs ordering or calculates the ID of a recently-inserted
row.
"""
compiler_module = "django.db.models.sql.compiler"
def __init__(self):
# this cache is used for backends that provide custom compiler classes
self._cache = {}
def autoinc_sql(self, table, column):
@ -280,15 +283,17 @@ class BaseDatabaseOperations(object):
"""
pass
def query_class(self, DefaultQueryClass, subclass=None):
def compiler(self, compiler_name):
"""
Returns the SQLCompiler class corresponding to the given name, looked
up in the module named by self.compiler_module. The class is cached so
the import only happens once per name.
"""
if subclass is not None:
return subclass
return DefaultQueryClass
if compiler_name not in self._cache:
self._cache[compiler_name] = getattr(
import_module(self.compiler_module), compiler_name
)
return self._cache[compiler_name]
def quote_name(self, name):
"""

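The compiler() hook above replaces query_class(): a backend only names a module via compiler_module, and compiler classes are looked up in that module by name and cached on ops._cache. A small sketch of the lookup from the caller's side, assuming a configured 'default' alias (the Oracle override in the comment is taken from later in this commit):

from django.db import connections

connection = connections['default']
# A backend opts in to custom SQL generation by overriding, e.g.:
#     compiler_module = "django.db.backends.oracle.compiler"
SQLCompiler = connection.ops.compiler('SQLCompiler')            # getattr(import_module(compiler_module), name)
assert SQLCompiler is connection.ops.compiler('SQLCompiler')    # second lookup comes from the cache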

@ -316,7 +316,7 @@ class BaseDatabaseCreation(object):
output.append(ds)
return output
def create_test_db(self, verbosity=1, autoclobber=False, alias=''):
def create_test_db(self, verbosity=1, autoclobber=False, alias=None):
"""
Creates a test database, prompting the user for confirmation if the
database already exists. Returns the name of the test database created.


@ -26,7 +26,6 @@ except ImportError, e:
from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.oracle import query
from django.db.backends.oracle.client import DatabaseClient
from django.db.backends.oracle.creation import DatabaseCreation
from django.db.backends.oracle.introspection import DatabaseIntrospection
@ -47,13 +46,13 @@ else:
class DatabaseFeatures(BaseDatabaseFeatures):
empty_fetchmany_value = ()
needs_datetime_string_cast = False
uses_custom_query_class = True
interprets_empty_strings_as_nulls = True
uses_savepoints = True
can_return_id_from_insert = True
class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django.db.backends.oracle.compiler"
def autoinc_sql(self, table, column):
# To simulate auto-incrementing primary keys in Oracle, we have to
@ -102,6 +101,54 @@ WHEN (new.%(col_name)s IS NULL)
sql = "TRUNC(%s, '%s')" % (field_name, lookup_type)
return sql
def convert_values(self, value, field):
if isinstance(value, Database.LOB):
value = value.read()
if field and field.get_internal_type() == 'TextField':
value = force_unicode(value)
# Oracle stores empty strings as null. We need to undo this in
# order to adhere to the Django convention of using the empty
# string instead of null, but only if the field accepts the
# empty string.
if value is None and field and field.empty_strings_allowed:
value = u''
# Convert 1 or 0 to True or False
elif value in (1, 0) and field and field.get_internal_type() in ('BooleanField', 'NullBooleanField'):
value = bool(value)
# Force floats to the correct type
elif value is not None and field and field.get_internal_type() == 'FloatField':
value = float(value)
# Convert floats to decimals
elif value is not None and field and field.get_internal_type() == 'DecimalField':
value = util.typecast_decimal(field.format_number(value))
# cx_Oracle always returns datetime.datetime objects for
# DATE and TIMESTAMP columns, but Django wants to see a
# python datetime.date, .time, or .datetime. We use the type
# of the Field to determine which to cast to, but it's not
# always available.
# As a workaround, we cast to date if all the time-related
# values are 0, or to time if the date is 1/1/1900.
# This could be cleaned a bit by adding a method to the Field
# classes to normalize values from the database (the to_python
# method is used for validation and isn't what we want here).
elif isinstance(value, Database.Timestamp):
# In Python 2.3, the cx_Oracle driver returns its own
# Timestamp object that we must convert to a datetime class.
if not isinstance(value, datetime.datetime):
value = datetime.datetime(value.year, value.month,
value.day, value.hour, value.minute, value.second,
value.fsecond)
if field and field.get_internal_type() == 'DateTimeField':
pass
elif field and field.get_internal_type() == 'DateField':
value = value.date()
elif field and field.get_internal_type() == 'TimeField' or (value.year == 1900 and value.month == value.day == 1):
value = value.time()
elif value.hour == value.minute == value.second == value.microsecond == 0:
value = value.date()
return value
def datetime_cast_sql(self):
return "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')"
@ -141,15 +188,6 @@ WHEN (new.%(col_name)s IS NULL)
return u''
return force_unicode(value.read())
def query_class(self, DefaultQueryClass, subclass=None):
if (DefaultQueryClass, subclass) in self._cache:
return self._cache[DefaultQueryClass, subclass]
Query = query.query_class(DefaultQueryClass, Database)
if subclass is not None:
Query = type('Query', (subclass, Query), {})
self._cache[DefaultQueryClass, subclass] = Query
return Query
def quote_name(self, name):
# SQL92 requires delimited (quoted) names to be case-sensitive. When
# not quoted, Oracle has case-insensitive behavior for identifiers, but


@ -0,0 +1,66 @@
from django.db.models.sql import compiler
class SQLCompiler(compiler.SQLCompiler):
def resolve_columns(self, row, fields=()):
# If this query has limit/offset information, then we expect the
# first column to be an extra "_RN" column that we need to throw
# away.
if self.query.high_mark is not None or self.query.low_mark:
rn_offset = 1
else:
rn_offset = 0
index_start = rn_offset + len(self.query.extra_select.keys())
values = [self.query.convert_values(v, None, connection=self.connection)
for v in row[rn_offset:index_start]]
for value, field in map(None, row[index_start:], fields):
values.append(self.query.convert_values(value, field, connection=self.connection))
return tuple(values)
def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Creates the SQL for this query. Returns the SQL string and list
of parameters. This is overridden from the original Query class
to handle the additional SQL Oracle requires to emulate LIMIT
and OFFSET.
If 'with_limits' is False, any limit/offset information is not
included in the query.
"""
# The `do_offset` flag indicates whether we need to construct
# the SQL needed to use limit/offset with Oracle.
do_offset = with_limits and (self.query.high_mark is not None
or self.query.low_mark)
if not do_offset:
sql, params = super(SQLCompiler, self).as_sql(with_limits=False,
with_col_aliases=with_col_aliases)
else:
sql, params = super(SQLCompiler, self).as_sql(with_limits=False,
with_col_aliases=True)
# Wrap the base query in an outer SELECT * with boundaries on
# the "_RN" column. This is the canonical way to emulate LIMIT
# and OFFSET on Oracle.
high_where = ''
if self.query.high_mark is not None:
high_where = 'WHERE ROWNUM <= %d' % (self.query.high_mark,)
sql = 'SELECT * FROM (SELECT ROWNUM AS "_RN", "_SUB".* FROM (%s) "_SUB" %s) WHERE "_RN" > %d' % (sql, high_where, self.query.low_mark)
return sql, params
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
pass
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
pass
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
pass
class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
pass
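As a worked illustration of the ROWNUM wrapping above (hand-derived from the format string, not captured backend output), a slice such as qs[10:20] sets low_mark=10 and high_mark=20 and yields SQL shaped like this:

# Hypothetical inner query; the outer SELECT is exactly the format string used above.
inner = 'SELECT "AUTHOR"."ID", "AUTHOR"."NAME" FROM "AUTHOR"'
high_where = 'WHERE ROWNUM <= %d' % 20
print 'SELECT * FROM (SELECT ROWNUM AS "_RN", "_SUB".* FROM (%s) "_SUB" %s) WHERE "_RN" > %d' % (
    inner, high_where, 10)
# -> SELECT * FROM (SELECT ROWNUM AS "_RN", "_SUB".* FROM (SELECT ... FROM "AUTHOR") "_SUB"
#    WHERE ROWNUM <= 20) WHERE "_RN" > 10
# resolve_columns() then strips the leading "_RN" column from each returned row.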


@ -1,140 +0,0 @@
"""
Custom Query class for Oracle.
Derives from: django.db.models.sql.query.Query
"""
import datetime
from django.db.backends import util
from django.utils.encoding import force_unicode
def query_class(QueryClass, Database):
"""
Returns a custom django.db.models.sql.query.Query subclass that is
appropriate for Oracle.
The 'Database' module (cx_Oracle) is passed in here so that all the setup
required to import it only needs to be done by the calling module.
"""
class OracleQuery(QueryClass):
def __reduce__(self):
"""
Enable pickling for this class (normal pickling handling doesn't
work as Python can only pickle module-level classes by default).
"""
if hasattr(QueryClass, '__getstate__'):
assert hasattr(QueryClass, '__setstate__')
data = self.__getstate__()
else:
data = self.__dict__
return (unpickle_query_class, (QueryClass,), data)
def resolve_columns(self, row, fields=()):
# If this query has limit/offset information, then we expect the
# first column to be an extra "_RN" column that we need to throw
# away.
if self.high_mark is not None or self.low_mark:
rn_offset = 1
else:
rn_offset = 0
index_start = rn_offset + len(self.extra_select.keys())
values = [self.convert_values(v, None)
for v in row[rn_offset:index_start]]
for value, field in map(None, row[index_start:], fields):
values.append(self.convert_values(value, field))
return tuple(values)
def convert_values(self, value, field):
if isinstance(value, Database.LOB):
value = value.read()
if field and field.get_internal_type() == 'TextField':
value = force_unicode(value)
# Oracle stores empty strings as null. We need to undo this in
# order to adhere to the Django convention of using the empty
# string instead of null, but only if the field accepts the
# empty string.
if value is None and field and field.empty_strings_allowed:
value = u''
# Convert 1 or 0 to True or False
elif value in (1, 0) and field and field.get_internal_type() in ('BooleanField', 'NullBooleanField'):
value = bool(value)
# Force floats to the correct type
elif value is not None and field and field.get_internal_type() == 'FloatField':
value = float(value)
# Convert floats to decimals
elif value is not None and field and field.get_internal_type() == 'DecimalField':
value = util.typecast_decimal(field.format_number(value))
# cx_Oracle always returns datetime.datetime objects for
# DATE and TIMESTAMP columns, but Django wants to see a
# python datetime.date, .time, or .datetime. We use the type
# of the Field to determine which to cast to, but it's not
# always available.
# As a workaround, we cast to date if all the time-related
# values are 0, or to time if the date is 1/1/1900.
# This could be cleaned a bit by adding a method to the Field
# classes to normalize values from the database (the to_python
# method is used for validation and isn't what we want here).
elif isinstance(value, Database.Timestamp):
# In Python 2.3, the cx_Oracle driver returns its own
# Timestamp object that we must convert to a datetime class.
if not isinstance(value, datetime.datetime):
value = datetime.datetime(value.year, value.month,
value.day, value.hour, value.minute, value.second,
value.fsecond)
if field and field.get_internal_type() == 'DateTimeField':
pass
elif field and field.get_internal_type() == 'DateField':
value = value.date()
elif field and field.get_internal_type() == 'TimeField' or (value.year == 1900 and value.month == value.day == 1):
value = value.time()
elif value.hour == value.minute == value.second == value.microsecond == 0:
value = value.date()
return value
def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Creates the SQL for this query. Returns the SQL string and list
of parameters. This is overriden from the original Query class
to handle the additional SQL Oracle requires to emulate LIMIT
and OFFSET.
If 'with_limits' is False, any limit/offset information is not
included in the query.
"""
# The `do_offset` flag indicates whether we need to construct
# the SQL needed to use limit/offset with Oracle.
do_offset = with_limits and (self.high_mark is not None
or self.low_mark)
if not do_offset:
sql, params = super(OracleQuery, self).as_sql(with_limits=False,
with_col_aliases=with_col_aliases)
else:
sql, params = super(OracleQuery, self).as_sql(with_limits=False,
with_col_aliases=True)
# Wrap the base query in an outer SELECT * with boundaries on
# the "_RN" column. This is the canonical way to emulate LIMIT
# and OFFSET on Oracle.
high_where = ''
if self.high_mark is not None:
high_where = 'WHERE ROWNUM <= %d' % (self.high_mark,)
sql = 'SELECT * FROM (SELECT ROWNUM AS "_RN", "_SUB".* FROM (%s) "_SUB" %s) WHERE "_RN" > %d' % (sql, high_where, self.low_mark)
return sql, params
return OracleQuery
def unpickle_query_class(QueryClass):
"""
Utility function, called by Python's unpickling machinery, that handles
unpickling of Oracle Query subclasses.
"""
# XXX: Would be nice to not have any dependency on cx_Oracle here. Since
# modules can't be pickled, we need a way to know to load the right module.
import cx_Oracle
klass = query_class(QueryClass, cx_Oracle)
return klass.__new__(klass)
unpickle_query_class.__safe_for_unpickling__ = True


@ -7,6 +7,7 @@ from django.db.backends import BaseDatabaseOperations
class DatabaseOperations(BaseDatabaseOperations):
def __init__(self, connection):
super(DatabaseOperations, self).__init__()
self._postgres_version = None
self.connection = connection


@ -43,9 +43,6 @@ class Aggregate(object):
"""
klass = getattr(query.aggregates_module, self.name)
aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)
# Validate that the backend has a fully supported, correct
# implementation of this aggregate
query.connection.ops.check_aggregate_support(aggregate)
query.aggregates[alias] = aggregate
class Avg(Aggregate):


@ -201,9 +201,9 @@ class Field(object):
if hasattr(value, 'relabel_aliases'):
return value
if hasattr(value, 'as_sql'):
sql, params = value.as_sql(connection)
sql, params = value.as_sql()
else:
sql, params = value._as_sql(connection)
sql, params = value._as_sql(connection=connection)
return QueryWrapper(('(%s)' % sql), params)


@ -145,15 +145,18 @@ class RelatedField(object):
v = v[0]
return v
if hasattr(value, 'get_compiler'):
value = value.get_compiler(connection=connection)
if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'):
# If the value has a relabel_aliases method, it will need to
# be invoked before the final SQL is evaluated
if hasattr(value, 'relabel_aliases'):
return value
if hasattr(value, 'as_sql'):
sql, params = value.as_sql(connection)
sql, params = value.as_sql()
else:
sql, params = value._as_sql(connection)
sql, params = value._as_sql(connection=connection)
return QueryWrapper(('(%s)' % sql), params)
# FIXME: lt and gt are explicitally allowed to make


@ -34,12 +34,11 @@ class QuerySet(object):
using = None
using = using or DEFAULT_DB_ALIAS
connection = connections[using]
self.query = query or connection.ops.query_class(sql.Query)(self.model, connection)
self.query = query or sql.Query(self.model)
self._result_cache = None
self._iter = None
self._sticky_filter = False
self._using = (query and
connections.alias_for_connection(self.query.connection) or using)
self._using = using
########################
# PYTHON MAGIC METHODS #
@ -237,8 +236,9 @@ class QuerySet(object):
else:
init_list.append(field.attname)
model_cls = deferred_class_factory(self.model, skip)
for row in self.query.results_iter():
compiler = self.query.get_compiler(using=self._using)
for row in compiler.results_iter():
if fill_cache:
obj, _ = get_cached_row(self.model, row,
index_start, max_depth,
@ -279,7 +279,7 @@ class QuerySet(object):
query.add_aggregate(aggregate_expr, self.model, alias,
is_summary=True)
return query.get_aggregation()
return query.get_aggregation(using=self._using)
def count(self):
"""
@ -292,7 +292,7 @@ class QuerySet(object):
if self._result_cache is not None and not self._iter:
return len(self._result_cache)
return self.query.get_count()
return self.query.get_count(using=self._using)
def get(self, *args, **kwargs):
"""
@ -420,7 +420,7 @@ class QuerySet(object):
else:
forced_managed = False
try:
rows = query.execute_sql(None)
rows = query.get_compiler(self._using).execute_sql(None)
if forced_managed:
transaction.commit(using=self._using)
else:
@ -444,12 +444,12 @@ class QuerySet(object):
query = self.query.clone(sql.UpdateQuery)
query.add_update_fields(values)
self._result_cache = None
return query.execute_sql(None)
return query.get_compiler(self._using).execute_sql(None)
_update.alters_data = True
def exists(self):
if self._result_cache is None:
return self.query.has_results()
return self.query.has_results(using=self._using)
return bool(self._result_cache)
##################################################
@ -662,16 +662,6 @@ class QuerySet(object):
"""
clone = self._clone()
clone._using = alias
connection = connections[alias]
clone.query.set_connection(connection)
cls = clone.query.get_query_class()
if cls is sql.Query:
subclass = None
else:
subclass = cls
clone.query.__class__ = connection.ops.query_class(
sql.Query, subclass
)
return clone
###################################
@ -757,8 +747,8 @@ class QuerySet(object):
Returns the internal query's SQL and parameters (as a tuple).
"""
obj = self.values("pk")
if connection == obj.query.connection:
return obj.query.as_nested_sql()
if connection == connections[obj._using]:
return obj.query.get_compiler(connection=connection).as_nested_sql()
raise ValueError("Can't do subqueries with queries on different DBs.")
def _validate(self):
@ -789,7 +779,7 @@ class ValuesQuerySet(QuerySet):
names = extra_names + field_names + aggregate_names
for row in self.query.results_iter():
for row in self.query.get_compiler(self._using).results_iter():
yield dict(zip(names, row))
def _setup_query(self):
@ -886,8 +876,8 @@ class ValuesQuerySet(QuerySet):
% self.__class__.__name__)
obj = self._clone()
if connection == obj.query.connection:
return obj.query.as_nested_sql()
if connection == connections[obj._using]:
return obj.query.get_compiler(connection=connection).as_nested_sql()
raise ValueError("Can't do subqueries with queries on different DBs.")
def _validate(self):
@ -904,10 +894,10 @@ class ValuesQuerySet(QuerySet):
class ValuesListQuerySet(ValuesQuerySet):
def iterator(self):
if self.flat and len(self._fields) == 1:
for row in self.query.results_iter():
for row in self.query.get_compiler(self._using).results_iter():
yield row[0]
elif not self.query.extra_select and not self.query.aggregate_select:
for row in self.query.results_iter():
for row in self.query.get_compiler(self._using).results_iter():
yield tuple(row)
else:
# When extra(select=...) or an annotation is involved, the extra
@ -926,7 +916,7 @@ class ValuesListQuerySet(ValuesQuerySet):
else:
fields = names
for row in self.query.results_iter():
for row in self.query.get_compiler(self._using).results_iter():
data = dict(zip(names, row))
yield tuple([data[f] for f in fields])
@ -938,7 +928,7 @@ class ValuesListQuerySet(ValuesQuerySet):
class DateQuerySet(QuerySet):
def iterator(self):
return self.query.results_iter()
return self.query.get_compiler(self._using).results_iter()
def _setup_query(self):
"""
@ -948,10 +938,7 @@ class DateQuerySet(QuerySet):
instance.
"""
self.query.clear_deferred_loading()
self.query = self.query.clone(
klass=self.query.connection.ops.query_class(sql.Query, sql.DateQuery),
setup=True
)
self.query = self.query.clone(klass=sql.DateQuery, setup=True)
self.query.select = []
field = self.model._meta.get_field(self._field_name, many_to_many=False)
assert isinstance(field, DateField), "%r isn't a DateField." \
@ -1089,19 +1076,18 @@ def delete_objects(seen_objs, using):
signals.pre_delete.send(sender=cls, instance=instance)
pk_list = [pk for pk,instance in items]
del_query = connection.ops.query_class(sql.Query, sql.DeleteQuery)(cls, connection)
del_query.delete_batch_related(pk_list)
del_query = sql.DeleteQuery(cls)
del_query.delete_batch_related(pk_list, using=using)
update_query = connection.ops.query_class(sql.Query, sql.UpdateQuery)(cls, connection)
update_query = sql.UpdateQuery(cls)
for field, model in cls._meta.get_fields_with_model():
if (field.rel and field.null and field.rel.to in seen_objs and
filter(lambda f: f.column == field.rel.get_related_field().column,
field.rel.to._meta.fields)):
if model:
connection.ops.query_class(sql.Query, sql.UpdateQuery)(model, connection).clear_related(field,
pk_list)
sql.UpdateQuery(model).clear_related(field, pk_list, using=using)
else:
update_query.clear_related(field, pk_list)
update_query.clear_related(field, pk_list, using=using)
# Now delete the actual data.
for cls in ordered_classes:
@ -1109,8 +1095,8 @@ def delete_objects(seen_objs, using):
items.reverse()
pk_list = [pk for pk,instance in items]
del_query = connection.ops.query_class(sql.Query, sql.DeleteQuery)(cls, connection)
del_query.delete_batch(pk_list)
del_query = sql.DeleteQuery(cls)
del_query.delete_batch(pk_list, using=using)
# Last cleanup; set NULLs where there once was a reference to the
# object, NULL the primary key of the found objects, and perform
@ -1139,7 +1125,7 @@ def insert_query(model, values, return_id=False, raw_values=False, using=None):
the InsertQuery class and is how Model.save() is implemented. It is not
part of the public API.
"""
connection = connections[using]
query = connection.ops.query_class(sql.Query, sql.InsertQuery)(model, connection)
query = sql.InsertQuery(model)
query.insert_values(values, raw_values)
return query.execute_sql(return_id)
compiler = query.get_compiler(using=using)
return compiler.execute_sql(return_id)


@ -0,0 +1,902 @@
from django.core.exceptions import FieldError
from django.db import connections
from django.db.backends.util import truncate_name
from django.db.models.sql.constants import *
from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.query import get_proxied_model, get_order_dir, \
select_related_descend, Query
class SQLCompiler(object):
def __init__(self, query, connection, using):
self.query = query
self.connection = connection
self.using = using
self.quote_cache = {}
# Check that the compiler will be able to execute the query
for alias, aggregate in self.query.aggregate_select.items():
self.connection.ops.check_aggregate_support(aggregate)
def pre_sql_setup(self):
"""
Does any necessary class setup immediately prior to producing SQL. This
is for things that can't necessarily be done in __init__ because we
might not have all the pieces in place at that time.
"""
if not self.query.tables:
self.query.join((None, self.query.model._meta.db_table, None, None))
if (not self.query.select and self.query.default_cols and not
self.query.included_inherited_models):
self.query.setup_inherited_models()
if self.query.select_related and not self.query.related_select_cols:
self.fill_related_selections()
def quote_name_unless_alias(self, name):
"""
A wrapper around connection.ops.quote_name that doesn't quote aliases
for table names. This avoids problems with some SQL dialects that treat
quoted strings specially (e.g. PostgreSQL).
"""
if name in self.quote_cache:
return self.quote_cache[name]
if ((name in self.query.alias_map and name not in self.query.table_map) or
name in self.query.extra_select):
self.quote_cache[name] = name
return name
r = self.connection.ops.quote_name(name)
self.quote_cache[name] = r
return r
def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Creates the SQL for this query. Returns the SQL string and list of
parameters.
If 'with_limits' is False, any limit/offset information is not included
in the query.
"""
self.pre_sql_setup()
out_cols = self.get_columns(with_col_aliases)
ordering, ordering_group_by = self.get_ordering()
# This must come after 'select' and 'ordering' -- see docstring of
# get_from_clause() for details.
from_, f_params = self.get_from_clause()
qn = self.quote_name_unless_alias
where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
params = []
for val in self.query.extra_select.itervalues():
params.extend(val[1])
result = ['SELECT']
if self.query.distinct:
result.append('DISTINCT')
result.append(', '.join(out_cols + self.query.ordering_aliases))
result.append('FROM')
result.extend(from_)
params.extend(f_params)
if where:
result.append('WHERE %s' % where)
params.extend(w_params)
if self.query.extra_where:
if not where:
result.append('WHERE')
else:
result.append('AND')
result.append(' AND '.join(self.query.extra_where))
grouping, gb_params = self.get_grouping()
if grouping:
if ordering:
# If the backend can't group by PK (i.e., any database
# other than MySQL), then any fields mentioned in the
# ordering clause need to be in the group by clause.
if not self.connection.features.allows_group_by_pk:
for col, col_params in ordering_group_by:
if col not in grouping:
grouping.append(str(col))
gb_params.extend(col_params)
else:
ordering = self.connection.ops.force_no_ordering()
result.append('GROUP BY %s' % ', '.join(grouping))
params.extend(gb_params)
if having:
result.append('HAVING %s' % having)
params.extend(h_params)
if ordering:
result.append('ORDER BY %s' % ', '.join(ordering))
if with_limits:
if self.query.high_mark is not None:
result.append('LIMIT %d' % (self.query.high_mark - self.query.low_mark))
if self.query.low_mark:
if self.query.high_mark is None:
val = self.connection.ops.no_limit_value()
if val:
result.append('LIMIT %d' % val)
result.append('OFFSET %d' % self.query.low_mark)
params.extend(self.query.extra_params)
return ' '.join(result), tuple(params)
def as_nested_sql(self):
"""
Perform the same functionality as the as_sql() method, returning an
SQL string and parameters. However, the alias prefixes are bumped
beforehand (in a copy -- the current query isn't changed) and any
ordering is removed.
Used when nesting this query inside another.
"""
obj = self.query.clone()
obj.clear_ordering(True)
obj.bump_prefix()
return obj.get_compiler(connection=self.connection).as_sql()
def get_columns(self, with_aliases=False):
"""
Returns the list of columns to use in the select statement. If no
columns have been specified, returns all columns relating to fields in
the model.
If 'with_aliases' is true, any column names that are duplicated
(without the table names) are given unique aliases. This is needed in
some cases to avoid ambiguity with nested queries.
"""
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()]
aliases = set(self.query.extra_select.keys())
if with_aliases:
col_aliases = aliases.copy()
else:
col_aliases = set()
if self.query.select:
only_load = self.deferred_to_columns()
for col in self.query.select:
if isinstance(col, (list, tuple)):
alias, column = col
table = self.query.alias_map[alias][TABLE_NAME]
if table in only_load and col not in only_load[table]:
continue
r = '%s.%s' % (qn(alias), qn(column))
if with_aliases:
if col[1] in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
result.append('%s AS %s' % (r, c_alias))
aliases.add(c_alias)
col_aliases.add(c_alias)
else:
result.append('%s AS %s' % (r, qn2(col[1])))
aliases.add(r)
col_aliases.add(col[1])
else:
result.append(r)
aliases.add(r)
col_aliases.add(col[1])
else:
result.append(col.as_sql(qn, self.connection))
if hasattr(col, 'alias'):
aliases.add(col.alias)
col_aliases.add(col.alias)
elif self.query.default_cols:
cols, new_aliases = self.get_default_columns(with_aliases,
col_aliases)
result.extend(cols)
aliases.update(new_aliases)
max_name_length = self.connection.ops.max_name_length()
result.extend([
'%s%s' % (
aggregate.as_sql(qn, self.connection),
alias is not None
and ' AS %s' % qn(truncate_name(alias, max_name_length))
or ''
)
for alias, aggregate in self.query.aggregate_select.items()
])
for table, col in self.query.related_select_cols:
r = '%s.%s' % (qn(table), qn(col))
if with_aliases and col in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
result.append('%s AS %s' % (r, c_alias))
aliases.add(c_alias)
col_aliases.add(c_alias)
else:
result.append(r)
aliases.add(r)
col_aliases.add(col)
self._select_aliases = aliases
return result
def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False):
"""
Computes the default columns for selecting every field in the base
model. Will sometimes be called to pull in related models (e.g. via
select_related), in which case "opts" and "start_alias" will be given
to provide a starting point for the traversal.
Returns a list of strings, quoted appropriately for use in SQL
directly, as well as a set of aliases used in the select statement (if
'as_pairs' is True, returns a list of (alias, col_name) pairs instead
of strings as the first component and None as the second component).
"""
result = []
if opts is None:
opts = self.query.model._meta
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
aliases = set()
only_load = self.deferred_to_columns()
# Skip all proxy to the root proxied model
proxied_model = get_proxied_model(opts)
if start_alias:
seen = {None: start_alias}
for field, model in opts.get_fields_with_model():
if start_alias:
try:
alias = seen[model]
except KeyError:
if model is proxied_model:
alias = start_alias
else:
link_field = opts.get_ancestor_link(model)
alias = self.query.join((start_alias, model._meta.db_table,
link_field.column, model._meta.pk.column))
seen[model] = alias
else:
# If we're starting from the base model of the queryset, the
# aliases will have already been set up in pre_sql_setup(), so
# we can save time here.
alias = self.query.included_inherited_models[model]
table = self.query.alias_map[alias][TABLE_NAME]
if table in only_load and field.column not in only_load[table]:
continue
if as_pairs:
result.append((alias, field.column))
aliases.add(alias)
continue
if with_aliases and field.column in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
result.append('%s.%s AS %s' % (qn(alias),
qn2(field.column), c_alias))
col_aliases.add(c_alias)
aliases.add(c_alias)
else:
r = '%s.%s' % (qn(alias), qn2(field.column))
result.append(r)
aliases.add(r)
if with_aliases:
col_aliases.add(field.column)
return result, aliases
def get_ordering(self):
"""
Returns a tuple containing a list representing the SQL elements in the
"order by" clause, and the list of SQL elements that need to be added
to the GROUP BY clause as a result of the ordering.
Also sets the ordering_aliases attribute on this instance to a list of
extra aliases needed in the select.
Determining the ordering SQL can change the tables we need to include,
so this should be run *before* get_from_clause().
"""
if self.query.extra_order_by:
ordering = self.query.extra_order_by
elif not self.query.default_ordering:
ordering = self.query.order_by
else:
ordering = self.query.order_by or self.query.model._meta.ordering
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
distinct = self.query.distinct
select_aliases = self._select_aliases
result = []
group_by = []
ordering_aliases = []
if self.query.standard_ordering:
asc, desc = ORDER_DIR['ASC']
else:
asc, desc = ORDER_DIR['DESC']
# It's possible, due to model inheritance, that normal usage might try
# to include the same field more than once in the ordering. We track
# the table/column pairs we use and discard any after the first use.
processed_pairs = set()
for field in ordering:
if field == '?':
result.append(self.connection.ops.random_function_sql())
continue
if isinstance(field, int):
if field < 0:
order = desc
field = -field
else:
order = asc
result.append('%s %s' % (field, order))
group_by.append((field, []))
continue
col, order = get_order_dir(field, asc)
if col in self.query.aggregate_select:
result.append('%s %s' % (col, order))
continue
if '.' in field:
# This came in through an extra(order_by=...) addition. Pass it
# on verbatim.
table, col = col.split('.', 1)
if (table, col) not in processed_pairs:
elt = '%s.%s' % (qn(table), col)
processed_pairs.add((table, col))
if not distinct or elt in select_aliases:
result.append('%s %s' % (elt, order))
group_by.append((elt, []))
elif get_order_dir(field)[0] not in self.query.extra_select:
# 'col' is of the form 'field' or 'field1__field2' or
# '-field1__field2__field', etc.
for table, col, order in self.find_ordering_name(field,
self.query.model._meta, default_order=asc):
if (table, col) not in processed_pairs:
elt = '%s.%s' % (qn(table), qn2(col))
processed_pairs.add((table, col))
if distinct and elt not in select_aliases:
ordering_aliases.append(elt)
result.append('%s %s' % (elt, order))
group_by.append((elt, []))
else:
elt = qn2(col)
if distinct and col not in select_aliases:
ordering_aliases.append(elt)
result.append('%s %s' % (elt, order))
group_by.append(self.query.extra_select[col])
self.query.ordering_aliases = ordering_aliases
return result, group_by
def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
already_seen=None):
"""
Returns the table alias (the name might be ambiguous, the alias will
not be) and column name for ordering by the given 'name' parameter.
The 'name' is of the form 'field1__field2__...__fieldN'.
"""
name, order = get_order_dir(name, default_order)
pieces = name.split(LOOKUP_SEP)
if not alias:
alias = self.query.get_initial_alias()
field, target, opts, joins, last, extra = self.query.setup_joins(pieces,
opts, alias, False)
alias = joins[-1]
col = target.column
if not field.rel:
# To avoid inadvertent trimming of a necessary alias, use the
# refcount to show that we are referencing a non-relation field on
# the model.
self.query.ref_alias(alias)
# Must use left outer joins for nullable fields and their relations.
self.query.promote_alias_chain(joins,
self.query.alias_map[joins[0]][JOIN_TYPE] == self.query.LOUTER)
# If we get to this point and the field is a relation to another model,
# append the default ordering for that model.
if field.rel and len(joins) > 1 and opts.ordering:
# Firstly, avoid infinite loops.
if not already_seen:
already_seen = set()
join_tuple = tuple([self.query.alias_map[j][TABLE_NAME] for j in joins])
if join_tuple in already_seen:
raise FieldError('Infinite loop caused by ordering.')
already_seen.add(join_tuple)
results = []
for item in opts.ordering:
results.extend(self.find_ordering_name(item, opts, alias,
order, already_seen))
return results
if alias:
# We have to do the same "final join" optimisation as in
# add_filter, since the final column might not otherwise be part of
# the select set (so we can't order on it).
while 1:
join = self.query.alias_map[alias]
if col != join[RHS_JOIN_COL]:
break
self.query.unref_alias(alias)
alias = join[LHS_ALIAS]
col = join[LHS_JOIN_COL]
return [(alias, col, order)]
def get_from_clause(self):
"""
Returns a list of strings that are joined together to go after the
"FROM" part of the query, as well as a list of any extra parameters that
need to be included. Subclasses can override this to create a
from-clause via a "select".
This should only be called after any SQL construction methods that
might change the tables we need. This means the select columns and
ordering must be done first.
"""
result = []
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
first = True
for alias in self.query.tables:
if not self.query.alias_refcount[alias]:
continue
try:
name, alias, join_type, lhs, lhs_col, col, nullable = self.query.alias_map[alias]
except KeyError:
# Extra tables can end up in self.tables, but not in the
# alias_map if they aren't in a join. That's OK. We skip them.
continue
alias_str = (alias != name and ' %s' % alias or '')
if join_type and not first:
result.append('%s %s%s ON (%s.%s = %s.%s)'
% (join_type, qn(name), alias_str, qn(lhs),
qn2(lhs_col), qn(alias), qn2(col)))
else:
connector = not first and ', ' or ''
result.append('%s%s%s' % (connector, qn(name), alias_str))
first = False
for t in self.query.extra_tables:
alias, unused = self.query.table_alias(t)
# Only add the alias if it's not already present (the table_alias()
# call increments the refcount, so an alias refcount of one means
# this is the only reference).
if alias not in self.query.alias_map or self.query.alias_refcount[alias] == 1:
connector = not first and ', ' or ''
result.append('%s%s' % (connector, qn(alias)))
first = False
return result, []
def get_grouping(self):
"""
Returns a tuple representing the SQL elements in the "group by" clause.
"""
qn = self.quote_name_unless_alias
result, params = [], []
if self.query.group_by is not None:
if len(self.query.model._meta.fields) == len(self.query.select) and \
self.connection.features.allows_group_by_pk:
self.query.group_by = [(self.query.model._meta.db_table, self.query.model._meta.pk.column)]
group_by = self.query.group_by or []
extra_selects = []
for extra_select, extra_params in self.query.extra_select.itervalues():
extra_selects.append(extra_select)
params.extend(extra_params)
for col in group_by + self.query.related_select_cols + extra_selects:
if isinstance(col, (list, tuple)):
result.append('%s.%s' % (qn(col[0]), qn(col[1])))
elif hasattr(col, 'as_sql'):
result.append(col.as_sql(qn))
else:
result.append(str(col))
return result, params
def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
used=None, requested=None, restricted=None, nullable=None,
dupe_set=None, avoid_set=None):
"""
Fill in the information needed for a select_related query. The current
depth is measured as the number of connections away from the root model
(for example, cur_depth=1 means we are looking at models with direct
connections to the root model).
"""
if not restricted and self.query.max_depth and cur_depth > self.query.max_depth:
# We've recursed far enough; bail out.
return
if not opts:
opts = self.query.get_meta()
root_alias = self.query.get_initial_alias()
self.query.related_select_cols = []
self.query.related_select_fields = []
if not used:
used = set()
if dupe_set is None:
dupe_set = set()
if avoid_set is None:
avoid_set = set()
orig_dupe_set = dupe_set
# Setup for the case when only particular related fields should be
# included in the related selection.
if requested is None and restricted is not False:
if isinstance(self.query.select_related, dict):
requested = self.query.select_related
restricted = True
else:
restricted = False
for f, model in opts.get_fields_with_model():
if not select_related_descend(f, restricted, requested):
continue
# The "avoid" set is aliases we want to avoid just for this
# particular branch of the recursion. They aren't permanently
# forbidden from reuse in the related selection tables (which is
# what "used" specifies).
avoid = avoid_set.copy()
dupe_set = orig_dupe_set.copy()
table = f.rel.to._meta.db_table
if nullable or f.null:
promote = True
else:
promote = False
if model:
int_opts = opts
alias = root_alias
alias_chain = []
for int_model in opts.get_base_chain(model):
# Proxy models have elements in the base chain
# with no parents, assign the new options
# object and skip to the next base in that
# case
if not int_opts.parents[int_model]:
int_opts = int_model._meta
continue
lhs_col = int_opts.parents[int_model].column
dedupe = lhs_col in opts.duplicate_targets
if dedupe:
avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col),
())
dupe_set.add((opts, lhs_col))
int_opts = int_model._meta
alias = self.query.join((alias, int_opts.db_table, lhs_col,
int_opts.pk.column), exclusions=used,
promote=promote)
alias_chain.append(alias)
for (dupe_opts, dupe_col) in dupe_set:
self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias)
if self.query.alias_map[root_alias][JOIN_TYPE] == self.query.LOUTER:
self.query.promote_alias_chain(alias_chain, True)
else:
alias = root_alias
dedupe = f.column in opts.duplicate_targets
if dupe_set or dedupe:
avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ()))
if dedupe:
dupe_set.add((opts, f.column))
alias = self.query.join((alias, table, f.column,
f.rel.get_related_field().column),
exclusions=used.union(avoid), promote=promote)
used.add(alias)
columns, aliases = self.get_default_columns(start_alias=alias,
opts=f.rel.to._meta, as_pairs=True)
self.query.related_select_cols.extend(columns)
if self.query.alias_map[alias][JOIN_TYPE] == self.query.LOUTER:
self.query.promote_alias_chain(aliases, True)
self.query.related_select_fields.extend(f.rel.to._meta.fields)
if restricted:
next = requested.get(f.name, {})
else:
next = False
if f.null is not None:
new_nullable = f.null
else:
new_nullable = None
for dupe_opts, dupe_col in dupe_set:
self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias)
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
used, next, restricted, new_nullable, dupe_set, avoid)
def deferred_to_columns(self):
"""
Converts the self.deferred_loading data structure to mapping of table
names to sets of column names which are to be loaded. Returns the
dictionary.
"""
columns = {}
self.query.deferred_to_data(columns, self.query.deferred_to_columns_cb)
return columns
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
"""
resolve_columns = hasattr(self, 'resolve_columns')
fields = None
for rows in self.execute_sql(MULTI):
for row in rows:
if resolve_columns:
if fields is None:
# We only set this up here because
# related_select_fields isn't populated until
# execute_sql() has been called.
if self.query.select_fields:
fields = self.query.select_fields + self.query.related_select_fields
else:
fields = self.query.model._meta.fields
row = self.resolve_columns(row, fields)
if self.query.aggregate_select:
aggregate_start = len(self.query.extra_select.keys()) + len(self.query.select)
aggregate_end = aggregate_start + len(self.query.aggregate_select)
row = tuple(row[:aggregate_start]) + tuple([
self.query.resolve_aggregate(value, aggregate, self.connection)
for (alias, aggregate), value
in zip(self.query.aggregate_select.items(), row[aggregate_start:aggregate_end])
]) + tuple(row[aggregate_end:])
yield row
def execute_sql(self, result_type=MULTI):
"""
Run the query against the database and returns the result(s). The
return value is a single data item if result_type is SINGLE, or an
iterator over the results if the result_type is MULTI.
result_type is either MULTI (use fetchmany() to retrieve all rows),
SINGLE (only retrieve a single row), or None. In this last case, the
cursor is returned if any query is executed (since it's used by
subclasses such as InsertQuery). It's possible, however, that no query
is needed, as the filters describe an empty set. In that case, None is
returned, to avoid any unnecessary database interaction.
"""
try:
sql, params = self.as_sql()
if not sql:
raise EmptyResultSet
except EmptyResultSet:
if result_type == MULTI:
return empty_iter()
else:
return
cursor = self.connection.cursor()
cursor.execute(sql, params)
if not result_type:
return cursor
if result_type == SINGLE:
if self.query.ordering_aliases:
return cursor.fetchone()[:-len(self.query.ordering_aliases)]
return cursor.fetchone()
# The MULTI case.
if self.query.ordering_aliases:
result = order_modified_iter(cursor, len(self.query.ordering_aliases),
self.connection.features.empty_fetchmany_value)
else:
result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
self.connection.features.empty_fetchmany_value)
if not self.connection.features.can_use_chunked_reads:
# If we are using non-chunked reads, we return the same data
# structure as normally, but ensure it is all read into memory
# before going any further.
return list(result)
return result
class SQLInsertCompiler(SQLCompiler):
def as_sql(self):
# We don't need quote_name_unless_alias() here, since these are all
# going to be column names (so we can avoid the extra overhead).
qn = self.connection.ops.quote_name
opts = self.query.model._meta
result = ['INSERT INTO %s' % qn(opts.db_table)]
result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns]))
result.append('VALUES (%s)' % ', '.join(self.query.values))
params = self.query.params
if self.query.return_id and self.connection.features.can_return_id_from_insert:
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
r_fmt, r_params = self.connection.ops.return_insert_id()
result.append(r_fmt % col)
params = params + r_params
return ' '.join(result), params
def execute_sql(self, return_id=False):
self.query.return_id = return_id
cursor = super(SQLInsertCompiler, self).execute_sql(None)
if not (return_id and cursor):
return
if self.connection.features.can_return_id_from_insert:
return self.connection.ops.fetch_returned_insert_id(cursor)
return self.connection.ops.last_insert_id(cursor,
self.query.model._meta.db_table, self.query.model._meta.pk.column)
class SQLDeleteCompiler(SQLCompiler):
def as_sql(self):
"""
Creates the SQL for this query. Returns the SQL string and list of
parameters.
"""
assert len(self.query.tables) == 1, \
"Can only delete from one table at a time."
qn = self.quote_name_unless_alias
result = ['DELETE FROM %s' % qn(self.query.tables[0])]
where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
result.append('WHERE %s' % where)
return ' '.join(result), tuple(params)
class SQLUpdateCompiler(SQLCompiler):
def as_sql(self):
"""
Creates the SQL for this query. Returns the SQL string and list of
parameters.
"""
from django.db.models.base import Model
self.pre_sql_setup()
if not self.query.values:
return '', ()
table = self.query.tables[0]
qn = self.quote_name_unless_alias
result = ['UPDATE %s' % qn(table)]
result.append('SET')
values, update_params = [], []
for field, model, val in self.query.values:
if hasattr(val, 'prepare_database_save'):
val = val.prepare_database_save(field)
else:
val = field.get_db_prep_save(val, connection=self.connection)
# Getting the placeholder for the field.
if hasattr(field, 'get_placeholder'):
placeholder = field.get_placeholder(val)
else:
placeholder = '%s'
if hasattr(val, 'evaluate'):
val = SQLEvaluator(val, self.query, allow_joins=False)
name = field.column
if hasattr(val, 'as_sql'):
sql, params = val.as_sql(qn, self.connection)
values.append('%s = %s' % (qn(name), sql))
update_params.extend(params)
elif val is not None:
values.append('%s = %s' % (qn(name), placeholder))
update_params.append(val)
else:
values.append('%s = NULL' % qn(name))
if not values:
return '', ()
result.append(', '.join(values))
where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(update_params + params)
def execute_sql(self, result_type):
"""
Execute the specified update. Returns the number of rows affected by
the primary update query. The "primary update query" is the first
non-empty query that is executed. Row counts for any subsequent,
related queries are not available.
"""
cursor = super(SQLUpdateCompiler, self).execute_sql(result_type)
rows = cursor and cursor.rowcount or 0
is_empty = cursor is None
del cursor
for query in self.query.get_related_updates():
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
if is_empty:
rows = aux_rows
is_empty = False
return rows
def pre_sql_setup(self):
"""
If the update depends on results from other tables, we need to do some
munging of the "where" conditions to match the format required for
(portable) SQL updates. That is done here.
Further, if we are going to be running multiple updates, we pull out
the id values to update at this point so that they don't change as a
result of the progressive updates.
"""
self.query.select_related = False
self.query.clear_ordering(True)
super(SQLUpdateCompiler, self).pre_sql_setup()
count = self.query.count_active_tables()
if not self.query.related_updates and count == 1:
return
# We need to use a sub-select in the where clause to filter on things
# from other tables.
query = self.query.clone(klass=Query)
query.bump_prefix()
query.extra = {}
query.select = []
query.add_fields([query.model._meta.pk.name])
must_pre_select = count > 1 and not self.connection.features.update_can_self_select
# Now we adjust the current query: reset the where clause and get rid
# of all the tables we don't need (since they're in the sub-select).
self.query.where = self.query.where_class()
if self.query.related_updates or must_pre_select:
# Either we're using the idents in multiple update queries (so
# don't want them to change), or the db backend doesn't support
# selecting from the updating table (e.g. MySQL).
idents = []
for rows in query.get_compiler(self.using).execute_sql(MULTI):
idents.extend([r[0] for r in rows])
self.query.add_filter(('pk__in', idents))
self.query.related_ids = idents
else:
# The fast path. Filters and updates in one query.
self.query.add_filter(('pk__in', query.get_compiler(self.using)))
for alias in self.query.tables[1:]:
self.query.alias_refcount[alias] = 0
class SQLAggregateCompiler(SQLCompiler):
def as_sql(self, qn=None):
"""
Creates the SQL for this query. Returns the SQL string and list of
parameters.
"""
if qn is None:
qn = self.quote_name_unless_alias
sql = ('SELECT %s FROM (%s) subquery' % (
', '.join([
aggregate.as_sql(qn, self.connection)
for aggregate in self.query.aggregate_select.values()
]),
self.query.subquery)
)
params = self.query.sub_params
return (sql, params)
class SQLDateCompiler(SQLCompiler):
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
"""
resolve_columns = hasattr(self, 'resolve_columns')
if resolve_columns:
from django.db.models.fields import DateTimeField
fields = [DateTimeField()]
else:
from django.db.backends.util import typecast_timestamp
needs_string_cast = self.connection.features.needs_datetime_string_cast
offset = len(self.query.extra_select)
for rows in self.execute_sql(MULTI):
for row in rows:
date = row[offset]
if resolve_columns:
date = self.resolve_columns(row, fields)[offset]
elif needs_string_cast:
date = typecast_timestamp(str(date))
yield date
def empty_iter():
"""
Returns an iterator containing no results.
"""
yield iter([]).next()
def order_modified_iter(cursor, trim, sentinel):
"""
Yields blocks of rows from a cursor. We use this iterator in the special
case when extra output columns have been added to support ordering
requirements. We must trim those extra columns before anything else can use
the results, since they're only needed to make the SQL valid.
"""
for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
sentinel):
yield [r[:-trim] for r in rows]
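As a rough illustration of the new split between query state and SQL generation, the following sketch shows how SQL is now obtained through a compiler bound to a connection alias. The ``Entry`` model is an assumption made for the example; ``DEFAULT_DB_ALIAS`` is importable from ``django.db``.

from django.db import DEFAULT_DB_ALIAS
from myapp.models import Entry  # assumed example model

qs = Entry.objects.filter(pk__in=[1, 2, 3])
compiler = qs.query.get_compiler(DEFAULT_DB_ALIAS)  # the Query holds state, the compiler renders SQL
sql, params = compiler.as_sql()
print sql % params  # parameters interpolated without quoting, as in Query.__str__()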

View File

@ -11,17 +11,16 @@ from copy import deepcopy
from django.utils.tree import Node
from django.utils.datastructures import SortedDict
from django.utils.encoding import force_unicode
from django.db.backends.util import truncate_name
from django.db import connection, connections
from django.db import connection, connections, DEFAULT_DB_ALIAS
from django.db.models import signals
from django.db.models.fields import FieldDoesNotExist
from django.db.models.query_utils import select_related_descend
from django.db.models.sql import aggregates as base_aggregates_module
from django.db.models.sql.constants import *
from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR
from django.core.exceptions import FieldError
from datastructures import EmptyResultSet, Empty, MultiJoin
from constants import *
__all__ = ['Query']
@ -38,9 +37,10 @@ class Query(object):
query_terms = QUERY_TERMS
aggregates_module = base_aggregates_module
def __init__(self, model, connection, where=WhereNode):
compiler = 'SQLCompiler'
def __init__(self, model, where=WhereNode):
self.model = model
self.connection = connection
self.alias_refcount = {}
self.alias_map = {} # Maps alias to join information
self.table_map = {} # Maps table names to list of aliases.
@ -104,7 +104,7 @@ class Query(object):
Parameter values won't necessarily be quoted correctly, since that is
done by the database interface at execution time.
"""
sql, params = self.as_sql()
sql, params = self.get_compiler(DEFAULT_DB_ALIAS).as_sql()
return sql % params
def __deepcopy__(self, memo):
@ -119,8 +119,6 @@ class Query(object):
obj_dict = self.__dict__.copy()
obj_dict['related_select_fields'] = []
obj_dict['related_select_cols'] = []
del obj_dict['connection']
obj_dict['connection_settings'] = self.connection.settings_dict
# Fields can't be pickled, so if a field list has been
# specified, we pickle the list of field names instead.
@ -142,11 +140,13 @@ class Query(object):
]
self.__dict__.update(obj_dict)
self.connection = connections[connections.alias_for_settings(
obj_dict['connection_settings'])]
def get_query_class(self):
return Query
def get_compiler(self, using=None, connection=None):
if using is None and connection is None:
raise ValueError("Need either using or connection")
if using:
connection = connections[using]
return connection.ops.compiler(self.compiler)(self, connection, using)
def get_meta(self):
"""
@ -156,22 +156,6 @@ class Query(object):
"""
return self.model._meta
def quote_name_unless_alias(self, name):
"""
A wrapper around connection.ops.quote_name that doesn't quote aliases
for table names. This avoids problems with some SQL dialects that treat
quoted strings specially (e.g. PostgreSQL).
"""
if name in self.quote_cache:
return self.quote_cache[name]
if ((name in self.alias_map and name not in self.table_map) or
name in self.extra_select):
self.quote_cache[name] = name
return name
r = self.connection.ops.quote_name(name)
self.quote_cache[name] = r
return r
def clone(self, klass=None, **kwargs):
"""
Creates a copy of the current instance. The 'kwargs' parameter can be
@ -180,7 +164,6 @@ class Query(object):
obj = Empty()
obj.__class__ = klass or self.__class__
obj.model = self.model
obj.connection = self.connection
obj.alias_refcount = self.alias_refcount.copy()
obj.alias_map = self.alias_map.copy()
obj.table_map = self.table_map.copy()
@ -243,16 +226,16 @@ class Query(object):
obj._setup_query()
return obj
def convert_values(self, value, field):
def convert_values(self, value, field, connection):
"""Convert the database-returned value into a type that is consistent
across database backends.
By default, this defers to the underlying backend operations, but
it can be overridden by Query classes for specific backends.
"""
return self.connection.ops.convert_values(value, field)
return connection.ops.convert_values(value, field)
def resolve_aggregate(self, value, aggregate):
def resolve_aggregate(self, value, aggregate, connection):
"""Resolve the value of aggregates returned by the database to
consistent (and reasonable) types.
@ -272,39 +255,9 @@ class Query(object):
return float(value)
else:
# Return value depends on the type of the field being processed.
return self.convert_values(value, aggregate.field)
return self.convert_values(value, aggregate.field, connection)
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
"""
resolve_columns = hasattr(self, 'resolve_columns')
fields = None
for rows in self.execute_sql(MULTI):
for row in rows:
if resolve_columns:
if fields is None:
# We only set this up here because
# related_select_fields isn't populated until
# execute_sql() has been called.
if self.select_fields:
fields = self.select_fields + self.related_select_fields
else:
fields = self.model._meta.fields
row = self.resolve_columns(row, fields)
if self.aggregate_select:
aggregate_start = len(self.extra_select.keys()) + len(self.select)
aggregate_end = aggregate_start + len(self.aggregate_select)
row = tuple(row[:aggregate_start]) + tuple([
self.resolve_aggregate(value, aggregate)
for (alias, aggregate), value
in zip(self.aggregate_select.items(), row[aggregate_start:aggregate_end])
]) + tuple(row[aggregate_end:])
yield row
def get_aggregation(self):
def get_aggregation(self, using):
"""
Returns the dictionary with the values of the existing aggregations.
"""
@ -316,7 +269,7 @@ class Query(object):
# over the subquery instead.
if self.group_by is not None:
from subqueries import AggregateQuery
query = self.connection.ops.query_class(Query, AggregateQuery)(self.model, self.connection)
query = AggregateQuery(self.model)
obj = self.clone()
@ -327,7 +280,7 @@ class Query(object):
query.aggregate_select[alias] = aggregate
del obj.aggregate_select[alias]
query.add_subquery(obj)
query.add_subquery(obj, using)
else:
query = self
self.select = []
@ -341,17 +294,17 @@ class Query(object):
query.related_select_cols = []
query.related_select_fields = []
result = query.execute_sql(SINGLE)
result = query.get_compiler(using).execute_sql(SINGLE)
if result is None:
result = [None for q in query.aggregate_select.items()]
return dict([
(alias, self.resolve_aggregate(val, aggregate))
(alias, self.resolve_aggregate(val, aggregate, connection=connections[using]))
for (alias, aggregate), val
in zip(query.aggregate_select.items(), result)
])
def get_count(self):
def get_count(self, using):
"""
Performs a COUNT() query using the current filter constraints.
"""
@ -365,11 +318,11 @@ class Query(object):
subquery.clear_ordering(True)
subquery.clear_limits()
obj = self.connection.ops.query_class(Query, AggregateQuery)(obj.model, obj.connection)
obj.add_subquery(subquery)
obj = AggregateQuery(obj.model)
obj.add_subquery(subquery, using=using)
obj.add_count_column()
number = obj.get_aggregation()[None]
number = obj.get_aggregation(using=using)[None]
# Apply offset and limit constraints manually, since using LIMIT/OFFSET
# in SQL (in variants that provide them) doesn't change the COUNT
@ -380,7 +333,7 @@ class Query(object):
return number
def has_results(self):
def has_results(self, using):
q = self.clone()
q.add_extra({'a': 1}, None, None, None, None, None)
q.add_fields(())
@ -388,99 +341,8 @@ class Query(object):
q.set_aggregate_mask(())
q.clear_ordering()
q.set_limits(high=1)
return bool(q.execute_sql(SINGLE))
def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Creates the SQL for this query. Returns the SQL string and list of
parameters.
If 'with_limits' is False, any limit/offset information is not included
in the query.
"""
self.pre_sql_setup()
out_cols = self.get_columns(with_col_aliases)
ordering, ordering_group_by = self.get_ordering()
# This must come after 'select' and 'ordering' -- see docstring of
# get_from_clause() for details.
from_, f_params = self.get_from_clause()
qn = self.quote_name_unless_alias
where, w_params = self.where.as_sql(qn=qn, connection=self.connection)
having, h_params = self.having.as_sql(qn=qn, connection=self.connection)
params = []
for val in self.extra_select.itervalues():
params.extend(val[1])
result = ['SELECT']
if self.distinct:
result.append('DISTINCT')
result.append(', '.join(out_cols + self.ordering_aliases))
result.append('FROM')
result.extend(from_)
params.extend(f_params)
if where:
result.append('WHERE %s' % where)
params.extend(w_params)
if self.extra_where:
if not where:
result.append('WHERE')
else:
result.append('AND')
result.append(' AND '.join(self.extra_where))
grouping, gb_params = self.get_grouping()
if grouping:
if ordering:
# If the backend can't group by PK (i.e., any database
# other than MySQL), then any fields mentioned in the
# ordering clause needs to be in the group by clause.
if not self.connection.features.allows_group_by_pk:
for col, col_params in ordering_group_by:
if col not in grouping:
grouping.append(str(col))
gb_params.extend(col_params)
else:
ordering = self.connection.ops.force_no_ordering()
result.append('GROUP BY %s' % ', '.join(grouping))
params.extend(gb_params)
if having:
result.append('HAVING %s' % having)
params.extend(h_params)
if ordering:
result.append('ORDER BY %s' % ', '.join(ordering))
if with_limits:
if self.high_mark is not None:
result.append('LIMIT %d' % (self.high_mark - self.low_mark))
if self.low_mark:
if self.high_mark is None:
val = self.connection.ops.no_limit_value()
if val:
result.append('LIMIT %d' % val)
result.append('OFFSET %d' % self.low_mark)
params.extend(self.extra_params)
return ' '.join(result), tuple(params)
def as_nested_sql(self):
"""
Perform the same functionality as the as_sql() method, returning an
SQL string and parameters. However, the alias prefixes are bumped
beforehand (in a copy -- the current query isn't changed) and any
ordering is removed.
Used when nesting this query inside another.
"""
obj = self.clone()
obj.clear_ordering(True)
obj.bump_prefix()
return obj.as_sql()
compiler = q.get_compiler(using=using)
return bool(compiler.execute_sql(SINGLE))
def combine(self, rhs, connector):
"""
@ -580,20 +442,6 @@ class Query(object):
self.order_by = rhs.order_by and rhs.order_by[:] or self.order_by
self.extra_order_by = rhs.extra_order_by or self.extra_order_by
def pre_sql_setup(self):
"""
Does any necessary class setup immediately prior to producing SQL. This
is for things that can't necessarily be done in __init__ because we
might not have all the pieces in place at that time.
"""
if not self.tables:
self.join((None, self.model._meta.db_table, None, None))
if (not self.select and self.default_cols and not
self.included_inherited_models):
self.setup_inherited_models()
if self.select_related and not self.related_select_cols:
self.fill_related_selections()
def deferred_to_data(self, target, callback):
"""
Converts the self.deferred_loading data structure to an alternate data
@ -672,15 +520,6 @@ class Query(object):
for model, values in seen.iteritems():
callback(target, model, values)
def deferred_to_columns(self):
"""
Converts the self.deferred_loading data structure to a mapping of table
names to sets of column names which are to be loaded. Returns the
dictionary.
"""
columns = {}
self.deferred_to_data(columns, self.deferred_to_columns_cb)
return columns
def deferred_to_columns_cb(self, target, model, fields):
"""
@ -693,352 +532,6 @@ class Query(object):
for field in fields:
target[table].add(field.column)
def get_columns(self, with_aliases=False):
"""
Returns the list of columns to use in the select statement. If no
columns have been specified, returns all columns relating to fields in
the model.
If 'with_aliases' is true, any column names that are duplicated
(without the table names) are given unique aliases. This is needed in
some cases to avoid ambiguity with nested queries.
"""
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.extra_select.iteritems()]
aliases = set(self.extra_select.keys())
if with_aliases:
col_aliases = aliases.copy()
else:
col_aliases = set()
if self.select:
only_load = self.deferred_to_columns()
for col in self.select:
if isinstance(col, (list, tuple)):
alias, column = col
table = self.alias_map[alias][TABLE_NAME]
if table in only_load and col not in only_load[table]:
continue
r = '%s.%s' % (qn(alias), qn(column))
if with_aliases:
if col[1] in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
result.append('%s AS %s' % (r, c_alias))
aliases.add(c_alias)
col_aliases.add(c_alias)
else:
result.append('%s AS %s' % (r, qn2(col[1])))
aliases.add(r)
col_aliases.add(col[1])
else:
result.append(r)
aliases.add(r)
col_aliases.add(col[1])
else:
result.append(col.as_sql(qn, self.connection))
if hasattr(col, 'alias'):
aliases.add(col.alias)
col_aliases.add(col.alias)
elif self.default_cols:
cols, new_aliases = self.get_default_columns(with_aliases,
col_aliases)
result.extend(cols)
aliases.update(new_aliases)
result.extend([
'%s%s' % (
aggregate.as_sql(qn, self.connection),
alias is not None and ' AS %s' % qn(alias) or ''
)
for alias, aggregate in self.aggregate_select.items()
])
for table, col in self.related_select_cols:
r = '%s.%s' % (qn(table), qn(col))
if with_aliases and col in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
result.append('%s AS %s' % (r, c_alias))
aliases.add(c_alias)
col_aliases.add(c_alias)
else:
result.append(r)
aliases.add(r)
col_aliases.add(col)
self._select_aliases = aliases
return result
def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False):
"""
Computes the default columns for selecting every field in the base
model. Will sometimes be called to pull in related models (e.g. via
select_related), in which case "opts" and "start_alias" will be given
to provide a starting point for the traversal.
Returns a list of strings, quoted appropriately for use in SQL
directly, as well as a set of aliases used in the select statement (if
'as_pairs' is True, returns a list of (alias, col_name) pairs instead
of strings as the first component and None as the second component).
"""
result = []
if opts is None:
opts = self.model._meta
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
aliases = set()
only_load = self.deferred_to_columns()
# Skip all proxy to the root proxied model
proxied_model = get_proxied_model(opts)
if start_alias:
seen = {None: start_alias}
for field, model in opts.get_fields_with_model():
if start_alias:
try:
alias = seen[model]
except KeyError:
if model is proxied_model:
alias = start_alias
else:
link_field = opts.get_ancestor_link(model)
alias = self.join((start_alias, model._meta.db_table,
link_field.column, model._meta.pk.column))
seen[model] = alias
else:
# If we're starting from the base model of the queryset, the
# aliases will have already been set up in pre_sql_setup(), so
# we can save time here.
alias = self.included_inherited_models[model]
table = self.alias_map[alias][TABLE_NAME]
if table in only_load and field.column not in only_load[table]:
continue
if as_pairs:
result.append((alias, field.column))
aliases.add(alias)
continue
if with_aliases and field.column in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
result.append('%s.%s AS %s' % (qn(alias),
qn2(field.column), c_alias))
col_aliases.add(c_alias)
aliases.add(c_alias)
else:
r = '%s.%s' % (qn(alias), qn2(field.column))
result.append(r)
aliases.add(r)
if with_aliases:
col_aliases.add(field.column)
return result, aliases
def get_from_clause(self):
"""
Returns a list of strings that are joined together to go after the
"FROM" part of the query, as well as a list any extra parameters that
need to be included. Sub-classes, can override this to create a
from-clause via a "select".
This should only be called after any SQL construction methods that
might change the tables we need. This means the select columns and
ordering must be done first.
"""
result = []
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
first = True
for alias in self.tables:
if not self.alias_refcount[alias]:
continue
try:
name, alias, join_type, lhs, lhs_col, col, nullable = self.alias_map[alias]
except KeyError:
# Extra tables can end up in self.tables, but not in the
# alias_map if they aren't in a join. That's OK. We skip them.
continue
alias_str = (alias != name and ' %s' % alias or '')
if join_type and not first:
result.append('%s %s%s ON (%s.%s = %s.%s)'
% (join_type, qn(name), alias_str, qn(lhs),
qn2(lhs_col), qn(alias), qn2(col)))
else:
connector = not first and ', ' or ''
result.append('%s%s%s' % (connector, qn(name), alias_str))
first = False
for t in self.extra_tables:
alias, unused = self.table_alias(t)
# Only add the alias if it's not already present (the table_alias()
# call increments the refcount, so an alias refcount of one means
# this is the only reference).
if alias not in self.alias_map or self.alias_refcount[alias] == 1:
connector = not first and ', ' or ''
result.append('%s%s' % (connector, qn(alias)))
first = False
return result, []
def get_grouping(self):
"""
Returns a tuple representing the SQL elements in the "group by" clause.
"""
qn = self.quote_name_unless_alias
result, params = [], []
if self.group_by is not None:
if len(self.model._meta.fields) == len(self.group_by) and \
self.connection.features.allows_group_by_pk:
self.group_by = [(self.model._meta.db_table, self.model._meta.pk.column)]
group_by = self.group_by or []
extra_selects = []
for extra_select, extra_params in self.extra_select.itervalues():
extra_selects.append(extra_select)
params.extend(extra_params)
for col in group_by + self.related_select_cols + extra_selects:
if isinstance(col, (list, tuple)):
result.append('%s.%s' % (qn(col[0]), qn(col[1])))
elif hasattr(col, 'as_sql'):
result.append(col.as_sql(qn))
else:
result.append(str(col))
return result, params
def get_ordering(self):
"""
Returns a tuple containing a list representing the SQL elements in the
"order by" clause, and the list of SQL elements that need to be added
to the GROUP BY clause as a result of the ordering.
Also sets the ordering_aliases attribute on this instance to a list of
extra aliases needed in the select.
Determining the ordering SQL can change the tables we need to include,
so this should be run *before* get_from_clause().
"""
if self.extra_order_by:
ordering = self.extra_order_by
elif not self.default_ordering:
ordering = self.order_by
else:
ordering = self.order_by or self.model._meta.ordering
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
distinct = self.distinct
select_aliases = self._select_aliases
result = []
group_by = []
ordering_aliases = []
if self.standard_ordering:
asc, desc = ORDER_DIR['ASC']
else:
asc, desc = ORDER_DIR['DESC']
# It's possible, due to model inheritance, that normal usage might try
# to include the same field more than once in the ordering. We track
# the table/column pairs we use and discard any after the first use.
processed_pairs = set()
for field in ordering:
if field == '?':
result.append(self.connection.ops.random_function_sql())
continue
if isinstance(field, int):
if field < 0:
order = desc
field = -field
else:
order = asc
result.append('%s %s' % (field, order))
group_by.append((field, []))
continue
col, order = get_order_dir(field, asc)
if col in self.aggregate_select:
result.append('%s %s' % (col, order))
continue
if '.' in field:
# This came in through an extra(order_by=...) addition. Pass it
# on verbatim.
table, col = col.split('.', 1)
if (table, col) not in processed_pairs:
elt = '%s.%s' % (qn(table), col)
processed_pairs.add((table, col))
if not distinct or elt in select_aliases:
result.append('%s %s' % (elt, order))
group_by.append((elt, []))
elif get_order_dir(field)[0] not in self.extra_select:
# 'col' is of the form 'field' or 'field1__field2' or
# '-field1__field2__field', etc.
for table, col, order in self.find_ordering_name(field,
self.model._meta, default_order=asc):
if (table, col) not in processed_pairs:
elt = '%s.%s' % (qn(table), qn2(col))
processed_pairs.add((table, col))
if distinct and elt not in select_aliases:
ordering_aliases.append(elt)
result.append('%s %s' % (elt, order))
group_by.append((elt, []))
else:
elt = qn2(col)
if distinct and col not in select_aliases:
ordering_aliases.append(elt)
result.append('%s %s' % (elt, order))
group_by.append(self.extra_select[col])
self.ordering_aliases = ordering_aliases
return result, group_by
def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
already_seen=None):
"""
Returns the table alias (the name might be ambiguous, the alias will
not be) and column name for ordering by the given 'name' parameter.
The 'name' is of the form 'field1__field2__...__fieldN'.
"""
name, order = get_order_dir(name, default_order)
pieces = name.split(LOOKUP_SEP)
if not alias:
alias = self.get_initial_alias()
field, target, opts, joins, last, extra = self.setup_joins(pieces,
opts, alias, False)
alias = joins[-1]
col = target.column
if not field.rel:
# To avoid inadvertent trimming of a necessary alias, use the
# refcount to show that we are referencing a non-relation field on
# the model.
self.ref_alias(alias)
# Must use left outer joins for nullable fields and their relations.
self.promote_alias_chain(joins,
self.alias_map[joins[0]][JOIN_TYPE] == self.LOUTER)
# If we get to this point and the field is a relation to another model,
# append the default ordering for that model.
if field.rel and len(joins) > 1 and opts.ordering:
# Firstly, avoid infinite loops.
if not already_seen:
already_seen = set()
join_tuple = tuple([self.alias_map[j][TABLE_NAME] for j in joins])
if join_tuple in already_seen:
raise FieldError('Infinite loop caused by ordering.')
already_seen.add(join_tuple)
results = []
for item in opts.ordering:
results.extend(self.find_ordering_name(item, opts, alias,
order, already_seen))
return results
if alias:
# We have to do the same "final join" optimisation as in
# add_filter, since the final column might not otherwise be part of
# the select set (so we can't order on it).
while 1:
join = self.alias_map[alias]
if col != join[RHS_JOIN_COL]:
break
self.unref_alias(alias)
alias = join[LHS_ALIAS]
col = join[LHS_JOIN_COL]
return [(alias, col, order)]
def table_alias(self, table_name, create=False):
"""
@ -1342,113 +835,6 @@ class Query(object):
self.unref_alias(alias)
self.included_inherited_models = {}
def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
used=None, requested=None, restricted=None, nullable=None,
dupe_set=None, avoid_set=None):
"""
Fill in the information needed for a select_related query. The current
depth is measured as the number of connections away from the root model
(for example, cur_depth=1 means we are looking at models with direct
connections to the root model).
"""
if not restricted and self.max_depth and cur_depth > self.max_depth:
# We've recursed far enough; bail out.
return
if not opts:
opts = self.get_meta()
root_alias = self.get_initial_alias()
self.related_select_cols = []
self.related_select_fields = []
if not used:
used = set()
if dupe_set is None:
dupe_set = set()
if avoid_set is None:
avoid_set = set()
orig_dupe_set = dupe_set
# Setup for the case when only particular related fields should be
# included in the related selection.
if requested is None and restricted is not False:
if isinstance(self.select_related, dict):
requested = self.select_related
restricted = True
else:
restricted = False
for f, model in opts.get_fields_with_model():
if not select_related_descend(f, restricted, requested):
continue
# The "avoid" set is aliases we want to avoid just for this
# particular branch of the recursion. They aren't permanently
# forbidden from reuse in the related selection tables (which is
# what "used" specifies).
avoid = avoid_set.copy()
dupe_set = orig_dupe_set.copy()
table = f.rel.to._meta.db_table
if nullable or f.null:
promote = True
else:
promote = False
if model:
int_opts = opts
alias = root_alias
alias_chain = []
for int_model in opts.get_base_chain(model):
# Proxy models have elements in the base chain
# with no parents, assign the new options
# object and skip to the next base in that
# case
if not int_opts.parents[int_model]:
int_opts = int_model._meta
continue
lhs_col = int_opts.parents[int_model].column
dedupe = lhs_col in opts.duplicate_targets
if dedupe:
avoid.update(self.dupe_avoidance.get(id(opts), lhs_col),
())
dupe_set.add((opts, lhs_col))
int_opts = int_model._meta
alias = self.join((alias, int_opts.db_table, lhs_col,
int_opts.pk.column), exclusions=used,
promote=promote)
alias_chain.append(alias)
for (dupe_opts, dupe_col) in dupe_set:
self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
if self.alias_map[root_alias][JOIN_TYPE] == self.LOUTER:
self.promote_alias_chain(alias_chain, True)
else:
alias = root_alias
dedupe = f.column in opts.duplicate_targets
if dupe_set or dedupe:
avoid.update(self.dupe_avoidance.get((id(opts), f.column), ()))
if dedupe:
dupe_set.add((opts, f.column))
alias = self.join((alias, table, f.column,
f.rel.get_related_field().column),
exclusions=used.union(avoid), promote=promote)
used.add(alias)
columns, aliases = self.get_default_columns(start_alias=alias,
opts=f.rel.to._meta, as_pairs=True)
self.related_select_cols.extend(columns)
if self.alias_map[alias][JOIN_TYPE] == self.LOUTER:
self.promote_alias_chain(aliases, True)
self.related_select_fields.extend(f.rel.to._meta.fields)
if restricted:
next = requested.get(f.name, {})
else:
next = False
if f.null is not None:
new_nullable = f.null
else:
new_nullable = None
for dupe_opts, dupe_col in dupe_set:
self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
used, next, restricted, new_nullable, dupe_set, avoid)
def add_aggregate(self, aggregate, model, alias, is_summary):
"""
@ -1497,7 +883,6 @@ class Query(object):
col = field_name
# Add the aggregate to the query
alias = truncate_name(alias, self.connection.ops.max_name_length())
aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
@ -1548,10 +933,6 @@ class Query(object):
raise ValueError("Cannot use None as a query value")
lookup_type = 'isnull'
value = True
elif (value == '' and lookup_type == 'exact' and
self.connection.features.interprets_empty_strings_as_nulls):
lookup_type = 'isnull'
value = True
elif callable(value):
value = value()
elif hasattr(value, 'evaluate'):
@ -1969,7 +1350,7 @@ class Query(object):
original exclude filter (filter_expr) and the portion up to the first
N-to-many relation field.
"""
query = self.connection.ops.query_class(Query)(self.model, self.connection)
query = Query(self.model)
query.add_filter(filter_expr, can_reuse=can_reuse)
query.bump_prefix()
query.clear_ordering(True)
@ -2347,54 +1728,6 @@ class Query(object):
self.select = [(select_alias, select_col)]
self.remove_inherited_models()
def set_connection(self, connection):
self.connection = connection
def execute_sql(self, result_type=MULTI):
"""
Runs the query against the database and returns the result(s). The
return value is a single data item if result_type is SINGLE, or an
iterator over the results if the result_type is MULTI.
result_type is either MULTI (use fetchmany() to retrieve all rows),
SINGLE (only retrieve a single row), or None. In this last case, the
cursor is returned if any query is executed, since it's used by
subclasses such as InsertQuery. It's possible, however, that no query
is needed, as the filters describe an empty set. In that case, None is
returned, to avoid any unnecessary database interaction.
"""
try:
sql, params = self.as_sql()
if not sql:
raise EmptyResultSet
except EmptyResultSet:
if result_type == MULTI:
return empty_iter()
else:
return
cursor = self.connection.cursor()
cursor.execute(sql, params)
if not result_type:
return cursor
if result_type == SINGLE:
if self.ordering_aliases:
return cursor.fetchone()[:-len(self.ordering_aliases)]
return cursor.fetchone()
# The MULTI case.
if self.ordering_aliases:
result = order_modified_iter(cursor, len(self.ordering_aliases),
self.connection.features.empty_fetchmany_value)
else:
result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
self.connection.features.empty_fetchmany_value)
if not self.connection.features.can_use_chunked_reads:
# If we are using non-chunked reads, we return the same data
# structure as normally, but ensure it is all read into memory
# before going any further.
return list(result)
return result
def get_order_dir(field, default='ASC'):
"""
@ -2409,22 +1742,6 @@ def get_order_dir(field, default='ASC'):
return field[1:], dirn[1]
return field, dirn[0]
def empty_iter():
"""
Returns an iterator containing no results.
"""
yield iter([]).next()
def order_modified_iter(cursor, trim, sentinel):
"""
Yields blocks of rows from a cursor. We use this iterator in the special
case when extra output columns have been added to support ordering
requirements. We must trim those extra columns before anything else can use
the results, since they're only needed to make the SQL valid.
"""
for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
sentinel):
yield [r[:-trim] for r in rows]
def setup_join_cache(sender, **kwargs):
"""

View File

@ -3,6 +3,7 @@ Query subclasses which provide extra functionality beyond simple data retrieval.
"""
from django.core.exceptions import FieldError
from django.db import connections
from django.db.models.sql.constants import *
from django.db.models.sql.datastructures import Date
from django.db.models.sql.expressions import SQLEvaluator
@ -17,28 +18,15 @@ class DeleteQuery(Query):
Delete queries are done through this class, since they are more constrained
than general queries.
"""
def get_query_class(self):
return DeleteQuery
def as_sql(self):
"""
Creates the SQL for this query. Returns the SQL string and list of
parameters.
"""
assert len(self.tables) == 1, \
"Can only delete from one table at a time."
qn = self.quote_name_unless_alias
result = ['DELETE FROM %s' % qn(self.tables[0])]
where, params = self.where.as_sql(qn=qn, connection=self.connection)
result.append('WHERE %s' % where)
return ' '.join(result), tuple(params)
compiler = 'SQLDeleteCompiler'
def do_query(self, table, where):
def do_query(self, table, where, using):
self.tables = [table]
self.where = where
self.execute_sql(None)
self.get_compiler(using).execute_sql(None)
def delete_batch_related(self, pk_list):
def delete_batch_related(self, pk_list, using):
"""
Set up and execute delete queries for all the objects related to the
primary key values in pk_list. To delete the objects themselves, use
@ -58,7 +46,7 @@ class DeleteQuery(Query):
'in',
pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]),
AND)
self.do_query(related.field.m2m_db_table(), where)
self.do_query(related.field.m2m_db_table(), where, using=using)
for f in cls._meta.many_to_many:
w1 = self.where_class()
@ -74,9 +62,9 @@ class DeleteQuery(Query):
AND)
if w1:
where.add(w1, AND)
self.do_query(f.m2m_db_table(), where)
self.do_query(f.m2m_db_table(), where, using=using)
def delete_batch(self, pk_list):
def delete_batch(self, pk_list, using):
"""
Set up and execute delete queries for all the objects in pk_list. This
should be called after delete_batch_related(), if necessary.
@ -89,19 +77,19 @@ class DeleteQuery(Query):
field = self.model._meta.pk
where.add((Constraint(None, field.column, field), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND)
self.do_query(self.model._meta.db_table, where)
self.do_query(self.model._meta.db_table, where, using=using)
class UpdateQuery(Query):
"""
Represents an "update" SQL query.
"""
compiler = 'SQLUpdateCompiler'
def __init__(self, *args, **kwargs):
super(UpdateQuery, self).__init__(*args, **kwargs)
self._setup_query()
def get_query_class(self):
return UpdateQuery
def _setup_query(self):
"""
Runs on initialization and after cloning. Any attributes that would
@ -117,98 +105,8 @@ class UpdateQuery(Query):
return super(UpdateQuery, self).clone(klass,
related_updates=self.related_updates.copy(), **kwargs)
def execute_sql(self, result_type=None):
"""
Execute the specified update. Returns the number of rows affected by
the primary update query. The "primary update query" is the first
non-empty query that is executed. Row counts for any subsequent,
related queries are not available.
"""
cursor = super(UpdateQuery, self).execute_sql(result_type)
rows = cursor and cursor.rowcount or 0
is_empty = cursor is None
del cursor
for query in self.get_related_updates():
aux_rows = query.execute_sql(result_type)
if is_empty:
rows = aux_rows
is_empty = False
return rows
def as_sql(self):
"""
Creates the SQL for this query. Returns the SQL string and list of
parameters.
"""
self.pre_sql_setup()
if not self.values:
return '', ()
table = self.tables[0]
qn = self.quote_name_unless_alias
result = ['UPDATE %s' % qn(table)]
result.append('SET')
values, update_params = [], []
for name, val, placeholder in self.values:
if hasattr(val, 'as_sql'):
sql, params = val.as_sql(qn, self.connection)
values.append('%s = %s' % (qn(name), sql))
update_params.extend(params)
elif val is not None:
values.append('%s = %s' % (qn(name), placeholder))
update_params.append(val)
else:
values.append('%s = NULL' % qn(name))
result.append(', '.join(values))
where, params = self.where.as_sql(qn=qn, connection=self.connection)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(update_params + params)
def pre_sql_setup(self):
"""
If the update depends on results from other tables, we need to do some
munging of the "where" conditions to match the format required for
(portable) SQL updates. That is done here.
Further, if we are going to be running multiple updates, we pull out
the id values to update at this point so that they don't change as a
result of the progressive updates.
"""
self.select_related = False
self.clear_ordering(True)
super(UpdateQuery, self).pre_sql_setup()
count = self.count_active_tables()
if not self.related_updates and count == 1:
return
# We need to use a sub-select in the where clause to filter on things
# from other tables.
query = self.clone(klass=Query)
query.bump_prefix()
query.extra = {}
query.select = []
query.add_fields([query.model._meta.pk.name])
must_pre_select = count > 1 and not self.connection.features.update_can_self_select
# Now we adjust the current query: reset the where clause and get rid
# of all the tables we don't need (since they're in the sub-select).
self.where = self.where_class()
if self.related_updates or must_pre_select:
# Either we're using the idents in multiple update queries (so
# don't want them to change), or the db backend doesn't support
# selecting from the updating table (e.g. MySQL).
idents = []
for rows in query.execute_sql(MULTI):
idents.extend([r[0] for r in rows])
self.add_filter(('pk__in', idents))
self.related_ids = idents
else:
# The fast path. Filters and updates in one query.
self.add_filter(('pk__in', query))
for alias in self.tables[1:]:
self.alias_refcount[alias] = 0
def clear_related(self, related_field, pk_list):
def clear_related(self, related_field, pk_list, using):
"""
Set up and execute an update query that clears related entries for the
keys in pk_list.
@ -221,8 +119,8 @@ class UpdateQuery(Query):
self.where.add((Constraint(None, f.column, f), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
AND)
self.values = [(related_field.column, None, '%s')]
self.execute_sql(None)
self.values = [(related_field, None, None)]
self.get_compiler(using).execute_sql(None)
def add_update_values(self, values):
"""
@ -235,6 +133,9 @@ class UpdateQuery(Query):
field, model, direct, m2m = self.model._meta.get_field_by_name(name)
if not direct or m2m:
raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field)
if model:
self.add_related_update(model, field, val)
continue
values_seq.append((field, model, val))
return self.add_update_fields(values_seq)
@ -244,36 +145,18 @@ class UpdateQuery(Query):
Used by add_update_values() as well as the "fast" update path when
saving models.
"""
from django.db.models.base import Model
for field, model, val in values_seq:
if hasattr(val, 'prepare_database_save'):
val = val.prepare_database_save(field)
else:
val = field.get_db_prep_save(val, connection=self.connection)
self.values.extend(values_seq)
# Getting the placeholder for the field.
if hasattr(field, 'get_placeholder'):
placeholder = field.get_placeholder(val)
else:
placeholder = '%s'
if hasattr(val, 'evaluate'):
val = SQLEvaluator(val, self, allow_joins=False)
if model:
self.add_related_update(model, field.column, val, placeholder)
else:
self.values.append((field.column, val, placeholder))
def add_related_update(self, model, column, value, placeholder):
def add_related_update(self, model, field, value):
"""
Adds (name, value) to an update query for an ancestor model.
Updates are coalesced so that we only run one update query per ancestor.
"""
try:
self.related_updates[model].append((column, value, placeholder))
self.related_updates[model].append((field, None, value))
except KeyError:
self.related_updates[model] = [(column, value, placeholder)]
self.related_updates[model] = [(field, None, value)]
def get_related_updates(self):
"""
@ -285,7 +168,7 @@ class UpdateQuery(Query):
return []
result = []
for model, values in self.related_updates.iteritems():
query = self.connection.ops.query_class(Query, UpdateQuery)(model, self.connection)
query = UpdateQuery(model)
query.values = values
if self.related_ids:
query.add_filter(('pk__in', self.related_ids))
@ -293,6 +176,8 @@ class UpdateQuery(Query):
return result
class InsertQuery(Query):
compiler = 'SQLInsertCompiler'
def __init__(self, *args, **kwargs):
super(InsertQuery, self).__init__(*args, **kwargs)
self.columns = []
@ -300,41 +185,12 @@ class InsertQuery(Query):
self.params = ()
self.return_id = False
def get_query_class(self):
return InsertQuery
def clone(self, klass=None, **kwargs):
extras = {'columns': self.columns[:], 'values': self.values[:],
'params': self.params, 'return_id': self.return_id}
extras.update(kwargs)
return super(InsertQuery, self).clone(klass, **extras)
def as_sql(self):
# We don't need quote_name_unless_alias() here, since these are all
# going to be column names (so we can avoid the extra overhead).
qn = self.connection.ops.quote_name
opts = self.model._meta
result = ['INSERT INTO %s' % qn(opts.db_table)]
result.append('(%s)' % ', '.join([qn(c) for c in self.columns]))
result.append('VALUES (%s)' % ', '.join(self.values))
params = self.params
if self.return_id and self.connection.features.can_return_id_from_insert:
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
r_fmt, r_params = self.connection.ops.return_insert_id()
result.append(r_fmt % col)
params = params + r_params
return ' '.join(result), params
def execute_sql(self, return_id=False):
self.return_id = return_id
cursor = super(InsertQuery, self).execute_sql(None)
if not (return_id and cursor):
return
if self.connection.features.can_return_id_from_insert:
return self.connection.ops.fetch_returned_insert_id(cursor)
return self.connection.ops.last_insert_id(cursor,
self.model._meta.db_table, self.model._meta.pk.column)
def insert_values(self, insert_values, raw_values=False):
"""
Set up the insert query from the 'insert_values' dictionary. The
@ -368,47 +224,8 @@ class DateQuery(Query):
date field. This requires some special handling when converting the results
back to Python objects, so we put it in a separate class.
"""
def __getstate__(self):
"""
Special DateQuery-specific pickle handling.
"""
for elt in self.select:
if isinstance(elt, Date):
# Eliminate a method reference that can't be pickled. The
# __setstate__ method restores this.
elt.date_sql_func = None
return super(DateQuery, self).__getstate__()
def __setstate__(self, obj_dict):
super(DateQuery, self).__setstate__(obj_dict)
for elt in self.select:
if isinstance(elt, Date):
self.date_sql_func = self.connection.ops.date_trunc_sql
def get_query_class(self):
return DateQuery
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
"""
resolve_columns = hasattr(self, 'resolve_columns')
if resolve_columns:
from django.db.models.fields import DateTimeField
fields = [DateTimeField()]
else:
from django.db.backends.util import typecast_timestamp
needs_string_cast = self.connection.features.needs_datetime_string_cast
offset = len(self.extra_select)
for rows in self.execute_sql(MULTI):
for row in rows:
date = row[offset]
if resolve_columns:
date = self.resolve_columns(row, fields)[offset]
elif needs_string_cast:
date = typecast_timestamp(str(date))
yield date
compiler = 'SQLDateCompiler'
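# A rough usage sketch (assumed ``Entry`` model with a ``pub_date`` field):
# ``Entry.objects.dates('pub_date', 'month', order='DESC')`` builds a
# DateQuery through add_date_select() and streams the truncated dates back
# through SQLDateCompiler.results_iter().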
def add_date_select(self, field, lookup_type, order='ASC'):
"""
@ -430,25 +247,8 @@ class AggregateQuery(Query):
An AggregateQuery takes another query as a parameter to the FROM
clause and only selects the elements in the provided list.
"""
def get_query_class(self):
return AggregateQuery
def add_subquery(self, query):
self.subquery, self.sub_params = query.as_sql(with_col_aliases=True)
compiler = 'SQLAggregateCompiler'
def as_sql(self, qn=None):
"""
Creates the SQL for this query. Returns the SQL string and list of
parameters.
"""
if qn is None:
qn = self.quote_name_unless_alias
sql = ('SELECT %s FROM (%s) subquery' % (
', '.join([
aggregate.as_sql(qn, self.connection)
for aggregate in self.aggregate_select.values()
]),
self.subquery)
)
params = self.sub_params
return (sql, params)
def add_subquery(self, query, using):
self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True)
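For orientation, here is a hedged sketch of an ORM call that exercises add_subquery(); the ``Book`` model, its ``authors`` relation and ``price`` field are assumptions made for the example.

from django.db.models import Avg, Count
from myapp.models import Book  # assumed example model

# annotate() establishes grouping, so the subsequent aggregate() wraps the
# grouped query as a subquery and renders it with SQLAggregateCompiler.
Book.objects.annotate(num_authors=Count('authors')).aggregate(Avg('price'))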

View File

@ -162,6 +162,11 @@ class WhereNode(tree.Node):
else:
extra = ''
if (len(params) == 1 and params[0] == '' and lookup_type == 'exact'
and connection.features.interprets_empty_strings_as_nulls):
lookup_type = 'isnull'
value_annot = True
if lookup_type in connection.operators:
format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
return (format % (field_sql,
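This hunk takes over the empty-string special case removed from Query.add_filter above, so it is now applied per connection at SQL-generation time. A hedged illustration, assuming an ``Entry`` model with a nullable ``headline`` field:

from myapp.models import Entry  # assumed example model

qs = Entry.objects.filter(headline__exact='')
# On a backend whose features.interprets_empty_strings_as_nulls is True
# (e.g. Oracle), this condition compiles to ``headline IS NULL`` rather than
# a comparison against the empty string.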

View File

@ -10,6 +10,7 @@ def load_backend(backend_name):
# backends that ships with Django, so look there first.
return import_module('.base', 'django.db.backends.%s' % backend_name)
except ImportError, e:
raise
# If the import failed, we might be looking for a database backend
# distributed external to Django. So we'll try that next.
try:
@ -17,7 +18,7 @@ def load_backend(backend_name):
except ImportError, e_user:
# The database backend wasn't found. Display a helpful error message
# listing all possible (built-in) database backends.
backend_dir = os.path.join(__path__[0], 'backends')
backend_dir = os.path.join(os.path.dirname(__file__), 'backends')
try:
available_backends = [f for f in os.listdir(backend_dir)
if os.path.isdir(os.path.join(backend_dir, f))

View File

@ -3,6 +3,7 @@ Helper functions for creating Form classes from Django models
and database field objects.
"""
from django.db import connections
from django.utils.encoding import smart_unicode, force_unicode
from django.utils.datastructures import SortedDict
from django.utils.text import get_text_list, capfirst
@ -471,8 +472,7 @@ class BaseModelFormSet(BaseFormSet):
pk = self.data[pk_key]
pk_field = self.model._meta.pk
pk = pk_field.get_db_prep_lookup('exact', pk,
connection=self.get_queryset().query.connection)
pk = pk_field.get_db_prep_lookup('exact', pk,
connection=connections[self.get_queryset()._using])
if isinstance(pk, list):
pk = pk[0]
kwargs['instance'] = self._existing_object(pk)

View File

@ -256,6 +256,7 @@ Here's a sample configuration which uses a MySQL option file::
}
}
# my.cnf
[client]
database = DATABASE_NAME

View File

@ -765,6 +765,14 @@ with an appropriate extension (e.g. ``json`` or ``xml``). See the
documentation for ``loaddata`` for details on the specification of fixture
data files.
--database
~~~~~~~~~~
The alias of the database to install the tables for. Defaults to the
``'default'`` alias.
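For example, assuming a configured ``legacy`` alias (the alias name is
illustrative)::

    django-admin.py syncdb --database=legacy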
--noinput
~~~~~~~~~
The :djadminopt:`--noinput` option may be provided to suppress all user
prompts.

View File

@ -144,8 +144,6 @@ Default: ``600``
The default number of seconds to cache a page when the caching middleware or
``cache_page()`` decorator is used.
.. setting:: DATABASES
.. setting:: CSRF_COOKIE_NAME
CSRF_COOKIE_NAME
@ -192,6 +190,9 @@ end users) indicating the reason the request was rejected. See
:ref:`ref-contrib-csrf`.
.. setting:: DATABASES
DATABASES
---------

View File

@ -171,9 +171,9 @@ True
# temporarily replace the UpdateQuery class to verify that E.f is actually nulled out first
>>> import django.db.models.sql
>>> class LoggingUpdateQuery(django.db.models.sql.UpdateQuery):
... def clear_related(self, related_field, pk_list):
... def clear_related(self, related_field, pk_list, using):
... print "CLEARING FIELD",related_field.name
... return super(LoggingUpdateQuery, self).clear_related(related_field, pk_list)
... return super(LoggingUpdateQuery, self).clear_related(related_field, pk_list, using)
>>> original_class = django.db.models.sql.UpdateQuery
>>> django.db.models.sql.UpdateQuery = LoggingUpdateQuery
>>> e1.delete()

View File

@ -157,7 +157,7 @@ False
# The underlying query only makes one join when a related table is referenced twice.
>>> queryset = Article.objects.filter(reporter__first_name__exact='John', reporter__last_name__exact='Smith')
>>> sql = queryset.query.as_sql()[0]
>>> sql = queryset.query.get_compiler(queryset._using).as_sql()[0]
>>> sql.count('INNER JOIN')
1

View File

@ -166,12 +166,13 @@ class ProxyImprovement(Improvement):
__test__ = {'API_TESTS' : """
# The MyPerson model should be generating the same database queries as the
# Person model (when the same manager is used in each case).
>>> MyPerson.other.all().query.as_sql() == Person.objects.order_by("name").query.as_sql()
>>> from django.db import DEFAULT_DB_ALIAS
>>> MyPerson.other.all().query.get_compiler(DEFAULT_DB_ALIAS).as_sql() == Person.objects.order_by("name").query.get_compiler(DEFAULT_DB_ALIAS).as_sql()
True
# The StatusPerson model should have its own table (it's using ORM-level
# inheritance).
>>> StatusPerson.objects.all().query.as_sql() == Person.objects.all().query.as_sql()
>>> StatusPerson.objects.all().query.get_compiler(DEFAULT_DB_ALIAS).as_sql() == Person.objects.all().query.get_compiler(DEFAULT_DB_ALIAS).as_sql()
False
# Creating a Person makes them accessible through the MyPerson proxy.

View File

@ -250,10 +250,10 @@ FieldError: Cannot resolve keyword 'foo' into field. Choices are: authors, conta
>>> out = pickle.dumps(qs)
# Then check that the round trip works.
>>> query = qs.query.as_sql()[0]
>>> query = qs.query.get_compiler(qs._using).as_sql()[0]
>>> select_fields = qs.query.select_fields
>>> query2 = pickle.loads(pickle.dumps(qs))
>>> query2.query.as_sql()[0] == query
>>> query2.query.get_compiler(query2._using).as_sql()[0] == query
True
>>> query2.query.select_fields = select_fields
@ -380,5 +380,4 @@ if run_stddev_tests():
>>> Book.objects.aggregate(Variance('price', sample=True))
{'price__variance': 700.53...}
"""

View File

@ -18,13 +18,13 @@ class Callproc(unittest.TestCase):
return True
else:
return True
class LongString(unittest.TestCase):
def test_long_string(self):
# If the backend is Oracle, test that we can save a text longer
# than 4000 chars and read it properly
if settings.DATABASE_ENGINE == 'oracle':
if settings.DATABASES[DEFAULT_DB_ALIAS]['DATABASE_ENGINE'] == 'oracle':
c = connection.cursor()
c.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
long_str = ''.join([unicode(x) for x in xrange(4000)])

View File

@ -314,7 +314,8 @@ DoesNotExist: ArticleWithAuthor matching query does not exist.
# likely to occur naturally with model inheritance, so we check it here).
# Regression test for #9390. This necessarily pokes at the SQL string for the
# query, since the duplicate problems are only apparent at that late stage.
>>> sql = ArticleWithAuthor.objects.order_by('pub_date', 'pk').query.as_sql()[0]
>>> qs = ArticleWithAuthor.objects.order_by('pub_date', 'pk')
>>> sql = qs.query.get_compiler(qs._using).as_sql()[0]
>>> fragment = sql[sql.find('ORDER BY'):]
>>> pos = fragment.find('pub_date')
>>> fragment.find('pub_date', pos + 1) == -1

View File

@ -71,8 +71,6 @@ class PickleQuerySetTestCase(TestCase):
def test_pickling(self):
for db in connections:
qs = Book.objects.all()
self.assertEqual(qs.query.connection,
pickle.loads(pickle.dumps(qs)).query.connection)
self.assertEqual(qs._using, pickle.loads(pickle.dumps(qs))._using)

View File

@ -822,8 +822,8 @@ We can do slicing beyond what is currently in the result cache, too.
Bug #7045 -- extra tables used to crash SQL construction on the second use.
>>> qs = Ranking.objects.extra(tables=['django_site'])
>>> s = qs.query.as_sql()
>>> s = qs.query.as_sql() # test passes if this doesn't raise an exception.
>>> s = qs.query.get_compiler(qs._using).as_sql()
>>> s = qs.query.get_compiler(qs._using).as_sql() # test passes if this doesn't raise an exception.
Bug #7098 -- Make sure semi-deprecated ordering by related models syntax still
works.
@ -912,9 +912,9 @@ We should also be able to pickle things that use select_related(). The only
tricky thing here is to ensure that we do the related selections properly after
unpickling.
>>> qs = Item.objects.select_related()
>>> query = qs.query.as_sql()[0]
>>> query = qs.query.get_compiler(qs._using).as_sql()[0]
>>> query2 = pickle.loads(pickle.dumps(qs.query))
>>> query2.as_sql()[0] == query
>>> query2.get_compiler(qs._using).as_sql()[0] == query
True
Check pickling of deferred-loading querysets
@ -1051,7 +1051,7 @@ sufficient that this query runs without error.
Calling order_by() with no parameters removes any existing ordering on the
model. But it should still be possible to add new ordering after that.
>>> qs = Author.objects.order_by().order_by('name')
>>> 'ORDER BY' in qs.query.as_sql()[0]
>>> 'ORDER BY' in qs.query.get_compiler(qs._using).as_sql()[0]
True
Incorrect SQL was being generated for certain types of exclude() queries that
@ -1085,7 +1085,8 @@ performance problems on backends like MySQL.
Nested queries should not evaluate the inner query as part of constructing the
SQL (so we should see a nested query here, indicated by two "SELECT" calls).
>>> Annotation.objects.filter(notes__in=Note.objects.filter(note="xyzzy")).query.as_sql()[0].count('SELECT')
>>> qs = Annotation.objects.filter(notes__in=Note.objects.filter(note="xyzzy"))
>>> qs.query.get_compiler(qs._using).as_sql()[0].count('SELECT')
2
Bug #10181 -- Avoid raising an EmptyResultSet if an inner query is provably
@ -1235,7 +1236,7 @@ portion in MySQL to prevent unnecessary sorting.
>>> query = Tag.objects.values_list('parent_id', flat=True).order_by().query
>>> query.group_by = ['parent_id']
>>> sql = query.as_sql()[0]
>>> sql = query.get_compiler(DEFAULT_DB_ALIAS).as_sql()[0]
>>> fragment = "ORDER BY "
>>> pos = sql.find(fragment)
>>> sql.find(fragment, pos + 1) == -1