1
0
mirror of https://github.com/django/django.git synced 2025-07-05 10:19:20 +00:00

[soc2009/multidb] Bring this branch up to date with my external work. This means implementing the using method on querysets as well as a using kwarg on save and delete, plus many internal changes to facilitae this

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@10904 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2009-06-03 02:37:33 +00:00
parent f4bcbbfa8b
commit 23da5c0ac1
29 changed files with 447 additions and 349 deletions

View File

@ -4,6 +4,11 @@ TODO
The follow is a list, more or less in the order I intend to do them of things The follow is a list, more or less in the order I intend to do them of things
that need to be done. I'm trying to be as granular as possible. that need to be done. I'm trying to be as granular as possible.
***
Immediate TODOs:
Finish refactor of WhereNode so that it just takes connection in as_sql.
***
2) Update all old references to ``settings.DATABASE_*`` to reference 2) Update all old references to ``settings.DATABASE_*`` to reference
``settings.DATABASES``. This includes the following locations ``settings.DATABASES``. This includes the following locations
@ -20,29 +25,21 @@ that need to be done. I'm trying to be as granular as possible.
* ``dumpdata``: By default dump the ``default`` database. Later add a * ``dumpdata``: By default dump the ``default`` database. Later add a
``--database`` flag. ``--database`` flag.
* ``loaddata``: Leave as is, will use the standard mechanisms for
determining what DB to save the objects to. Should it get a
``--database`` flag to overide that?
These items will be fixed pending both community consensus, and the API These items will be fixed pending both community consensus, and the API
that will go in that's actually necessary for these to happen. Due to that will go in that's actually necessary for these to happen.
internal APIs loaddata probably will need an update to load stuff into a
specific DB.
4) Rig up the test harness to work with multiple databases. This includes: flush, reset, and syncdb need to not prompt the user multiple times.
* The current strategy is to test on N dbs, where N is however many the 9) Fix transaction support. In practice this means changing all the
user defines and ensuring the data all stays seperate and no exceptions dictionaries that currently map thread to a boolean to being a dictionary
are raised. Practically speaking this means we're only going to have mapping thread to a set of the affected DBs, and changing all the functions
good coverage if we write a lot of tests that can break. That's life. that use these dictionaries to handle the changes appropriately.
7) Remove any references to the global ``django.db.connection`` object in the 7) Remove any references to the global ``django.db.connection`` object in the
SQL creation process. This includes(but is probably not limited to): SQL creation process. This includes(but is probably not limited to):
* The way we create ``Query`` from ``BaseQuery`` is awkward and hacky. * The way we create ``Query`` from ``BaseQuery`` is awkward and hacky.
* ``django.db.models.query.delete_objects``
* ``django.db.models.query.insert_query``
* ``django.db.models.base.Model`` -- in ``save_base``
* ``django.db.models.fields.Field`` This uses it, as do it's subclasses. * ``django.db.models.fields.Field`` This uses it, as do it's subclasses.
* ``django.db.models.fields.related`` It's used all over the place here, * ``django.db.models.fields.related`` It's used all over the place here,
including opening a cursor and executing queries, so that's going to including opening a cursor and executing queries, so that's going to
@ -50,13 +47,9 @@ that need to be done. I'm trying to be as granular as possible.
raw SQL and execution to ``Query``/``QuerySet`` so hopefully that makes raw SQL and execution to ``Query``/``QuerySet`` so hopefully that makes
it in before I need to tackle this. it in before I need to tackle this.
5) Add the ``using`` Meta option. Tests and docs(these are to be assumed at 5) Add the ``using`` Meta option. Tests and docs(these are to be assumed at
each stage from here on out). each stage from here on out).
5) Implement using kwarg on save() method.
6) Add the ``using`` method to ``QuerySet``. This will more or less "just
work" across multiple databases that use the same backend. However, it
will fail gratuitously when trying to use 2 different backends.
8) Implement some way to create a new ``Query`` for a different backend when 8) Implement some way to create a new ``Query`` for a different backend when
we switch. There are several checks against ``self.connection`` prior to we switch. There are several checks against ``self.connection`` prior to
SQL construction, so we either need to defer all these(which will be SQL construction, so we either need to defer all these(which will be
@ -80,8 +73,4 @@ that need to be done. I'm trying to be as granular as possible.
every single one of them, then when it's time excecute the query just every single one of them, then when it's time excecute the query just
pick the right ``Query`` object to use. This *does* not scale, though it pick the right ``Query`` object to use. This *does* not scale, though it
could probably be done fairly easily. could probably be done fairly easily.
9) Fix transaction support. In practice this means changing all the
dictionaries that currently map thread to a boolean to being a dictionary
mapping thread to a set of the affected DBs, and changing all the functions
that use these dictionaries to handle the changes appropriately.
10) Time permitting add support for a ``DatabaseManager``. 10) Time permitting add support for a ``DatabaseManager``.

View File

@ -151,7 +151,7 @@ class GenericRelation(RelatedField, Field):
def get_internal_type(self): def get_internal_type(self):
return "ManyToManyField" return "ManyToManyField"
def db_type(self): def db_type(self, connection):
# Since we're simulating a ManyToManyField, in effect, best return the # Since we're simulating a ManyToManyField, in effect, best return the
# same db_type as well. # same db_type as well.
return None return None

View File

@ -1,12 +1,11 @@
from django.conf import settings
from django.db.models.fields import Field from django.db.models.fields import Field
class USStateField(Field): class USStateField(Field):
def get_internal_type(self): def get_internal_type(self):
return "USStateField" return "USStateField"
def db_type(self): def db_type(self, connection):
if settings.DATABASE_ENGINE == 'oracle': if connection.settings_dict['DATABASE_ENGINE'] == 'oracle':
return 'CHAR(2)' return 'CHAR(2)'
else: else:
return 'varchar(2)' return 'varchar(2)'
@ -21,8 +20,8 @@ class PhoneNumberField(Field):
def get_internal_type(self): def get_internal_type(self):
return "PhoneNumberField" return "PhoneNumberField"
def db_type(self): def db_type(self, connection):
if settings.DATABASE_ENGINE == 'oracle': if connection.settings_dict['DATABASE_ENGINE'] == 'oracle':
return 'VARCHAR2(20)' return 'VARCHAR2(20)'
else: else:
return 'varchar(20)' return 'varchar(20)'
@ -32,4 +31,3 @@ class PhoneNumberField(Field):
defaults = {'form_class': USPhoneNumberField} defaults = {'form_class': USPhoneNumberField}
defaults.update(kwargs) defaults.update(kwargs)
return super(PhoneNumberField, self).formfield(**defaults) return super(PhoneNumberField, self).formfield(**defaults)

View File

@ -2,7 +2,7 @@ import datetime
from django.contrib.sessions.models import Session from django.contrib.sessions.models import Session
from django.contrib.sessions.backends.base import SessionBase, CreateError from django.contrib.sessions.backends.base import SessionBase, CreateError
from django.core.exceptions import SuspiciousOperation from django.core.exceptions import SuspiciousOperation
from django.db import IntegrityError, transaction from django.db import IntegrityError, transaction, DEFAULT_DB_ALIAS
from django.utils.encoding import force_unicode from django.utils.encoding import force_unicode
class SessionStore(SessionBase): class SessionStore(SessionBase):
@ -53,12 +53,13 @@ class SessionStore(SessionBase):
session_data = self.encode(self._get_session(no_load=must_create)), session_data = self.encode(self._get_session(no_load=must_create)),
expire_date = self.get_expiry_date() expire_date = self.get_expiry_date()
) )
sid = transaction.savepoint() # TODO update for multidb
sid = transaction.savepoint(using=DEFAULT_DB_ALIAS)
try: try:
obj.save(force_insert=must_create) obj.save(force_insert=must_create)
except IntegrityError: except IntegrityError:
if must_create: if must_create:
transaction.savepoint_rollback(sid) transaction.savepoint_rollback(sid, using=DEFAULT_DB_ALIAS)
raise CreateError raise CreateError
raise raise

View File

@ -6,12 +6,12 @@ from django.db.models import signals
from django.contrib.sites.models import Site from django.contrib.sites.models import Site
from django.contrib.sites import models as site_app from django.contrib.sites import models as site_app
def create_default_site(app, created_models, verbosity, **kwargs): def create_default_site(app, created_models, verbosity, db, **kwargs):
if Site in created_models: if Site in created_models:
if verbosity >= 2: if verbosity >= 2:
print "Creating example.com Site object" print "Creating example.com Site object"
s = Site(domain="example.com", name="example.com") s = Site(domain="example.com", name="example.com")
s.save() s.save(using=db)
Site.objects.clear_cache() Site.objects.clear_cache()
signals.post_syncdb.connect(create_default_site, sender=site_app) signals.post_syncdb.connect(create_default_site, sender=site_app)

View File

@ -27,9 +27,9 @@ class Command(LabelCommand):
) )
table_output = [] table_output = []
index_output = [] index_output = []
qn = connections.ops.quote_name qn = connection.ops.quote_name
for f in fields: for f in fields:
field_output = [qn(f.name), f.db_type()] field_output = [qn(f.name), f.db_type(connection)]
field_output.append("%sNULL" % (not f.null and "NOT " or "")) field_output.append("%sNULL" % (not f.null and "NOT " or ""))
if f.primary_key: if f.primary_key:
field_output.append("PRIMARY KEY") field_output.append("PRIMARY KEY")

View File

@ -15,14 +15,14 @@ class Command(NoArgsCommand):
make_option('--noinput', action='store_false', dest='interactive', default=True, make_option('--noinput', action='store_false', dest='interactive', default=True,
help='Tells Django to NOT prompt the user for input of any kind.'), help='Tells Django to NOT prompt the user for input of any kind.'),
make_option('--database', action='store', dest='database', make_option('--database', action='store', dest='database',
default='', help='Nominates a database to flush. Defaults to ' default=None, help='Nominates a database to flush. Defaults to '
'flushing all databases.'), 'flushing all databases.'),
) )
help = "Executes ``sqlflush`` on the current database." help = "Executes ``sqlflush`` on the current database."
def handle_noargs(self, **options): def handle_noargs(self, **options):
if not options['database']: if not options['database']:
dbs = connections.all() dbs = connections
else: else:
dbs = [options['database']] dbs = [options['database']]
@ -31,57 +31,14 @@ class Command(NoArgsCommand):
self.style = no_style() self.style = no_style()
# Import the 'management' module within each installed app, to register
# dispatcher events.
for app_name in settings.INSTALLED_APPS: for app_name in settings.INSTALLED_APPS:
try: try:
import_module('.management', app_name) import_module('.management', app_name)
except ImportError: except ImportError:
pass pass
sql_list = sql_flush(self.style, connection, only_django=True) for db in dbs:
connection = connections[db]
if interactive:
confirm = raw_input("""You have requested a flush of the database.
This will IRREVERSIBLY DESTROY all data currently in the %r database,
and return each table to the state it was in after syncdb.
Are you sure you want to do this?
Type 'yes' to continue, or 'no' to cancel: """ % connection.settings_dict['DATABASE_NAME'])
else:
confirm = 'yes'
if confirm == 'yes':
try:
cursor = connection.cursor()
for sql in sql_list:
cursor.execute(sql)
except Exception, e:
transaction.rollback_unless_managed()
raise CommandError("""Database %s couldn't be flushed. Possible reasons:
* The database isn't running or isn't configured correctly.
* At least one of the expected database tables doesn't exist.
* The SQL was invalid.
Hint: Look at the output of 'django-admin.py sqlflush'. That's the SQL this command wasn't able to run.
The full error: %s""" % (connection.settings_dict.DATABASE_NAME, e))
transaction.commit_unless_managed()
# Emit the post sync signal. This allows individual
# applications to respond as if the database had been
# sync'd from scratch.
emit_post_sync_signal(models.get_models(), verbosity, interactive, connection)
# Reinstall the initial_data fixture.
call_command('loaddata', 'initial_data', **options)
for app_name in settings.INSTALLED_APPS:
try:
import_module('.management', app_name)
except ImportError:
pass
for connection in dbs:
# Import the 'management' module within each installed app, to register # Import the 'management' module within each installed app, to register
# dispatcher events. # dispatcher events.
sql_list = sql_flush(self.style, connection, only_django=True) sql_list = sql_flush(self.style, connection, only_django=True)
@ -102,22 +59,24 @@ class Command(NoArgsCommand):
for sql in sql_list: for sql in sql_list:
cursor.execute(sql) cursor.execute(sql)
except Exception, e: except Exception, e:
transaction.rollback_unless_managed() transaction.rollback_unless_managed(using=db)
raise CommandError("""Database %s couldn't be flushed. Possible reasons: raise CommandError("""Database %s couldn't be flushed. Possible reasons:
* The database isn't running or isn't configured correctly. * The database isn't running or isn't configured correctly.
* At least one of the expected database tables doesn't exist. * At least one of the expected database tables doesn't exist.
* The SQL was invalid. * The SQL was invalid.
Hint: Look at the output of 'django-admin.py sqlflush'. That's the SQL this command wasn't able to run. Hint: Look at the output of 'django-admin.py sqlflush'. That's the SQL this command wasn't able to run.
The full error: %s""" % (connection.settings_dict.DATABASE_NAME, e)) The full error: %s""" % (connection.settings_dict.DATABASE_NAME, e))
transaction.commit_unless_managed() transaction.commit_unless_managed(using=db)
# Emit the post sync signal. This allows individual # Emit the post sync signal. This allows individual
# applications to respond as if the database had been # applications to respond as if the database had been
# sync'd from scratch. # sync'd from scratch.
emit_post_sync_signal(models.get_models(), verbosity, interactive, connection) emit_post_sync_signal(models.get_models(), verbosity, interactive, db)
# Reinstall the initial_data fixture. # Reinstall the initial_data fixture.
call_command('loaddata', 'initial_data', **options) kwargs = options.copy()
kwargs['database'] = db
call_command('loaddata', 'initial_data', **kwargs)
else: else:
print "Flush cancelled." print "Flush cancelled."

View File

@ -4,8 +4,13 @@ import gzip
import zipfile import zipfile
from optparse import make_option from optparse import make_option
from django.conf import settings
from django.core import serializers
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.core.management.color import no_style from django.core.management.color import no_style
from django.db import connections, transaction, DEFAULT_DB_ALIAS
from django.db.models import get_apps
try: try:
set set
@ -22,12 +27,15 @@ class Command(BaseCommand):
help = 'Installs the named fixture(s) in the database.' help = 'Installs the named fixture(s) in the database.'
args = "fixture [fixture ...]" args = "fixture [fixture ...]"
def handle(self, *fixture_labels, **options): option_list = BaseCommand.option_list + (
from django.db.models import get_apps make_option('--database', action='store', dest='database',
from django.core import serializers default=DEFAULT_DB_ALIAS, help='Nominates a specific database to load '
from django.db import connection, transaction 'fixtures into. By default uses the "default" database.'),
from django.conf import settings )
def handle(self, *fixture_labels, **options):
using = options['database']
connection = connections[using]
self.style = no_style() self.style = no_style()
verbosity = int(options.get('verbosity', 1)) verbosity = int(options.get('verbosity', 1))
@ -56,9 +64,9 @@ class Command(BaseCommand):
# Start transaction management. All fixtures are installed in a # Start transaction management. All fixtures are installed in a
# single transaction to ensure that all references are resolved. # single transaction to ensure that all references are resolved.
if commit: if commit:
transaction.commit_unless_managed() transaction.commit_unless_managed(using=using)
transaction.enter_transaction_management() transaction.enter_transaction_management(using=using)
transaction.managed(True) transaction.managed(True, using=using)
class SingleZipReader(zipfile.ZipFile): class SingleZipReader(zipfile.ZipFile):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -103,8 +111,8 @@ class Command(BaseCommand):
sys.stderr.write( sys.stderr.write(
self.style.ERROR("Problem installing fixture '%s': %s is not a known serialization format." % self.style.ERROR("Problem installing fixture '%s': %s is not a known serialization format." %
(fixture_name, format))) (fixture_name, format)))
transaction.rollback() transaction.rollback(using=using)
transaction.leave_transaction_management() transaction.leave_transaction_management(using=using)
return return
if os.path.isabs(fixture_name): if os.path.isabs(fixture_name):
@ -136,8 +144,8 @@ class Command(BaseCommand):
fixture.close() fixture.close()
print self.style.ERROR("Multiple fixtures named '%s' in %s. Aborting." % print self.style.ERROR("Multiple fixtures named '%s' in %s. Aborting." %
(fixture_name, humanize(fixture_dir))) (fixture_name, humanize(fixture_dir)))
transaction.rollback() transaction.rollback(using=using)
transaction.leave_transaction_management() transaction.leave_transaction_management(using=using)
return return
else: else:
fixture_count += 1 fixture_count += 1
@ -150,7 +158,7 @@ class Command(BaseCommand):
for obj in objects: for obj in objects:
objects_in_fixture += 1 objects_in_fixture += 1
models.add(obj.object.__class__) models.add(obj.object.__class__)
obj.save() obj.save(using=using)
object_count += objects_in_fixture object_count += objects_in_fixture
label_found = True label_found = True
except (SystemExit, KeyboardInterrupt): except (SystemExit, KeyboardInterrupt):
@ -158,8 +166,8 @@ class Command(BaseCommand):
except Exception: except Exception:
import traceback import traceback
fixture.close() fixture.close()
transaction.rollback() transaction.rollback(using=using)
transaction.leave_transaction_management() transaction.leave_transaction_management(using=using)
if show_traceback: if show_traceback:
traceback.print_exc() traceback.print_exc()
else: else:
@ -176,8 +184,8 @@ class Command(BaseCommand):
sys.stderr.write( sys.stderr.write(
self.style.ERROR("No fixture data found for '%s'. (File format may be invalid.)" % self.style.ERROR("No fixture data found for '%s'. (File format may be invalid.)" %
(fixture_name))) (fixture_name)))
transaction.rollback() transaction.rollback(using=using)
transaction.leave_transaction_management() transaction.leave_transaction_management(using=using)
return return
except Exception, e: except Exception, e:
@ -196,8 +204,8 @@ class Command(BaseCommand):
cursor.execute(line) cursor.execute(line)
if commit: if commit:
transaction.commit() transaction.commit(using=using)
transaction.leave_transaction_management() transaction.leave_transaction_management(using=using)
if object_count == 0: if object_count == 0:
if verbosity > 1: if verbosity > 1:

View File

@ -32,9 +32,9 @@ class Command(NoArgsCommand):
self.style = no_style() self.style = no_style()
if not options['database']: if not options['database']:
dbs = connections.all() dbs = connections
else: else:
dbs = [connections[options['database']]] dbs = [options['database']]
# Import the 'management' module within each installed app, to register # Import the 'management' module within each installed app, to register
# dispatcher events. # dispatcher events.
@ -55,7 +55,8 @@ class Command(NoArgsCommand):
if not msg.startswith('No module named') or 'management' not in msg: if not msg.startswith('No module named') or 'management' not in msg:
raise raise
for connection in dbs: for db in dbs:
connection = connections[db]
cursor = connection.cursor() cursor = connection.cursor()
# Get a list of already installed *models* so that references work right. # Get a list of already installed *models* so that references work right.
@ -102,11 +103,11 @@ class Command(NoArgsCommand):
for statement in sql: for statement in sql:
cursor.execute(statement) cursor.execute(statement)
transaction.commit_unless_managed() transaction.commit_unless_managed(using=db)
# Send the post_syncdb signal, so individual apps can do whatever they need # Send the post_syncdb signal, so individual apps can do whatever they need
# to do at this point. # to do at this point.
emit_post_sync_signal(created_models, verbosity, interactive, connection) emit_post_sync_signal(created_models, verbosity, interactive, db)
# The connection may have been closed by a syncdb handler. # The connection may have been closed by a syncdb handler.
cursor = connection.cursor() cursor = connection.cursor()
@ -130,9 +131,9 @@ class Command(NoArgsCommand):
if show_traceback: if show_traceback:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
transaction.rollback_unless_managed() transaction.rollback_unless_managed(using=db)
else: else:
transaction.commit_unless_managed() transaction.commit_unless_managed(using=db)
else: else:
if verbosity >= 2: if verbosity >= 2:
print "No custom SQL for %s.%s model" % (app_name, model._meta.object_name) print "No custom SQL for %s.%s model" % (app_name, model._meta.object_name)
@ -151,13 +152,9 @@ class Command(NoArgsCommand):
except Exception, e: except Exception, e:
sys.stderr.write("Failed to install index for %s.%s model: %s\n" % \ sys.stderr.write("Failed to install index for %s.%s model: %s\n" % \
(app_name, model._meta.object_name, e)) (app_name, model._meta.object_name, e))
transaction.rollback_unless_managed() transaction.rollback_unless_managed(using=db)
else: else:
transaction.commit_unless_managed() transaction.commit_unless_managed(using=db)
# Install the 'initial_data' fixture, using format discovery
# FIXME we only load the fixture data for one DB right now, since we
# can't control what DB it does into, once we can control this we
# should move it back into the DB loop
from django.core.management import call_command from django.core.management import call_command
call_command('loaddata', 'initial_data', verbosity=verbosity) call_command('loaddata', 'initial_data', verbosity=verbosity, database=db)

View File

@ -188,7 +188,7 @@ def custom_sql_for_model(model, style, connection):
return output return output
def emit_post_sync_signal(created_models, verbosity, interactive, connection): def emit_post_sync_signal(created_models, verbosity, interactive, db):
# Emit the post_sync signal for every application. # Emit the post_sync signal for every application.
for app in models.get_apps(): for app in models.get_apps():
app_name = app.__name__.split('.')[-2] app_name = app.__name__.split('.')[-2]
@ -196,4 +196,4 @@ def emit_post_sync_signal(created_models, verbosity, interactive, connection):
print "Running post-sync handlers for application", app_name print "Running post-sync handlers for application", app_name
models.signals.post_syncdb.send(sender=app, app=app, models.signals.post_syncdb.send(sender=app, app=app,
created_models=created_models, verbosity=verbosity, created_models=created_models, verbosity=verbosity,
interactive=interactive, connection=connection) interactive=interactive, db=db)

View File

@ -154,13 +154,13 @@ class DeserializedObject(object):
def __repr__(self): def __repr__(self):
return "<DeserializedObject: %s>" % smart_str(self.object) return "<DeserializedObject: %s>" % smart_str(self.object)
def save(self, save_m2m=True): def save(self, save_m2m=True, using=None):
# Call save on the Model baseclass directly. This bypasses any # Call save on the Model baseclass directly. This bypasses any
# model-defined save. The save is also forced to be raw. # model-defined save. The save is also forced to be raw.
# This ensures that the data that is deserialized is literally # This ensures that the data that is deserialized is literally
# what came from the file, not post-processed by pre_save/save # what came from the file, not post-processed by pre_save/save
# methods. # methods.
models.Model.save_base(self.object, raw=True) models.Model.save_base(self.object, using=using, raw=True)
if self.m2m_data and save_m2m: if self.m2m_data and save_m2m:
for accessor_name, object_list in self.m2m_data.items(): for accessor_name, object_list in self.m2m_data.items():
setattr(self.object, accessor_name, object_list) setattr(self.object, accessor_name, object_list)

View File

@ -54,12 +54,13 @@ def reset_queries(**kwargs):
connection.queries = [] connection.queries = []
signals.request_started.connect(reset_queries) signals.request_started.connect(reset_queries)
# Register an event that rolls back the connection # Register an event that rolls back the connections
# when a Django request has an exception. # when a Django request has an exception.
def _rollback_on_exception(**kwargs): def _rollback_on_exception(**kwargs):
from django.db import transaction from django.db import transaction
for conn in connections:
try: try:
transaction.rollback_unless_managed() transaction.rollback_unless_managed(using=conn)
except DatabaseError: except DatabaseError:
pass pass
signals.got_request_exception.connect(_rollback_on_exception) signals.got_request_exception.connect(_rollback_on_exception)

View File

@ -90,7 +90,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
super(DatabaseWrapper, self).__init__(*args, **kwargs) super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.features = DatabaseFeatures() self.features = DatabaseFeatures()
self.ops = DatabaseOperations() self.ops = DatabaseOperations(self)
self.client = DatabaseClient(self) self.client = DatabaseClient(self)
self.creation = DatabaseCreation(self) self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self) self.introspection = DatabaseIntrospection(self)

View File

@ -6,14 +6,14 @@ from django.db.backends import BaseDatabaseOperations
# used by both the 'postgresql' and 'postgresql_psycopg2' backends. # used by both the 'postgresql' and 'postgresql_psycopg2' backends.
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
def __init__(self): def __init__(self, connection):
self._postgres_version = None self._postgres_version = None
self.connection = connection
def _get_postgres_version(self): def _get_postgres_version(self):
if self._postgres_version is None: if self._postgres_version is None:
from django.db import connection
from django.db.backends.postgresql.version import get_version from django.db.backends.postgresql.version import get_version
cursor = connection.cursor() cursor = self.connection.cursor()
self._postgres_version = get_version(cursor) self._postgres_version = get_version(cursor)
return self._postgres_version return self._postgres_version
postgres_version = property(_get_postgres_version) postgres_version = property(_get_postgres_version)

View File

@ -66,7 +66,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
autocommit = self.settings_dict["DATABASE_OPTIONS"].get('autocommit', False) autocommit = self.settings_dict["DATABASE_OPTIONS"].get('autocommit', False)
self.features.uses_autocommit = autocommit self.features.uses_autocommit = autocommit
self._set_isolation_level(int(not autocommit)) self._set_isolation_level(int(not autocommit))
self.ops = DatabaseOperations() self.ops = DatabaseOperations(self)
self.client = DatabaseClient(self) self.client = DatabaseClient(self)
self.creation = DatabaseCreation(self) self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self) self.introspection = DatabaseIntrospection(self)

View File

@ -15,7 +15,7 @@ from django.db.models.fields.related import OneToOneRel, ManyToOneRel, OneToOneF
from django.db.models.query import delete_objects, Q from django.db.models.query import delete_objects, Q
from django.db.models.query_utils import CollectedObjects, DeferredAttribute from django.db.models.query_utils import CollectedObjects, DeferredAttribute
from django.db.models.options import Options from django.db.models.options import Options
from django.db import connection, transaction, DatabaseError from django.db import connections, transaction, DatabaseError, DEFAULT_DB_ALIAS
from django.db.models import signals from django.db.models import signals
from django.db.models.loading import register_models, get_model from django.db.models.loading import register_models, get_model
from django.utils.functional import curry from django.utils.functional import curry
@ -395,7 +395,7 @@ class Model(object):
return getattr(self, field_name) return getattr(self, field_name)
return getattr(self, field.attname) return getattr(self, field.attname)
def save(self, force_insert=False, force_update=False): def save(self, force_insert=False, force_update=False, using=None):
""" """
Saves the current instance. Override this in a subclass if you want to Saves the current instance. Override this in a subclass if you want to
control the saving process. control the saving process.
@ -407,18 +407,21 @@ class Model(object):
if force_insert and force_update: if force_insert and force_update:
raise ValueError("Cannot force both insert and updating in " raise ValueError("Cannot force both insert and updating in "
"model saving.") "model saving.")
self.save_base(force_insert=force_insert, force_update=force_update) self.save_base(using=using, force_insert=force_insert, force_update=force_update)
save.alters_data = True save.alters_data = True
def save_base(self, raw=False, cls=None, force_insert=False, def save_base(self, raw=False, cls=None, force_insert=False,
force_update=False): force_update=False, using=None):
""" """
Does the heavy-lifting involved in saving. Subclasses shouldn't need to Does the heavy-lifting involved in saving. Subclasses shouldn't need to
override this method. It's separate from save() in order to hide the override this method. It's separate from save() in order to hide the
need for overrides of save() to pass around internal-only parameters need for overrides of save() to pass around internal-only parameters
('raw' and 'cls'). ('raw' and 'cls').
""" """
if using is None:
using = DEFAULT_DB_ALIAS
connection = connections[using]
assert not (force_insert and force_update) assert not (force_insert and force_update)
if not cls: if not cls:
cls = self.__class__ cls = self.__class__
@ -441,7 +444,7 @@ class Model(object):
if field and getattr(self, parent._meta.pk.attname) is None and getattr(self, field.attname) is not None: if field and getattr(self, parent._meta.pk.attname) is None and getattr(self, field.attname) is not None:
setattr(self, parent._meta.pk.attname, getattr(self, field.attname)) setattr(self, parent._meta.pk.attname, getattr(self, field.attname))
self.save_base(cls=parent) self.save_base(cls=parent, using=using)
if field: if field:
setattr(self, field.attname, self._get_pk_val(parent._meta)) setattr(self, field.attname, self._get_pk_val(parent._meta))
if meta.proxy: if meta.proxy:
@ -457,12 +460,13 @@ class Model(object):
manager = cls._base_manager manager = cls._base_manager
if pk_set: if pk_set:
# Determine whether a record with the primary key already exists. # Determine whether a record with the primary key already exists.
# FIXME work with the using parameter
if (force_update or (not force_insert and if (force_update or (not force_insert and
manager.filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by())): manager.using(using).filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by())):
# It does already exist, so do an UPDATE. # It does already exist, so do an UPDATE.
if force_update or non_pks: if force_update or non_pks:
values = [(f, None, (raw and getattr(self, f.attname) or f.pre_save(self, False))) for f in non_pks] values = [(f, None, (raw and getattr(self, f.attname) or f.pre_save(self, False))) for f in non_pks]
rows = manager.filter(pk=pk_val)._update(values) rows = manager.using(using).filter(pk=pk_val)._update(values)
if force_update and not rows: if force_update and not rows:
raise DatabaseError("Forced update did not affect any rows.") raise DatabaseError("Forced update did not affect any rows.")
else: else:
@ -477,20 +481,20 @@ class Model(object):
if meta.order_with_respect_to: if meta.order_with_respect_to:
field = meta.order_with_respect_to field = meta.order_with_respect_to
values.append((meta.get_field_by_name('_order')[0], manager.filter(**{field.name: getattr(self, field.attname)}).count())) values.append((meta.get_field_by_name('_order')[0], manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count()))
record_exists = False record_exists = False
update_pk = bool(meta.has_auto_field and not pk_set) update_pk = bool(meta.has_auto_field and not pk_set)
if values: if values:
# Create a new record. # Create a new record.
result = manager._insert(values, return_id=update_pk) result = manager._insert(values, return_id=update_pk, using=using)
else: else:
# Create a new record with defaults for everything. # Create a new record with defaults for everything.
result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True) result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True, using=using)
if update_pk: if update_pk:
setattr(self, meta.pk.attname, result) setattr(self, meta.pk.attname, result)
transaction.commit_unless_managed() transaction.commit_unless_managed(using=using)
if signal: if signal:
signals.post_save.send(sender=self.__class__, instance=self, signals.post_save.send(sender=self.__class__, instance=self,
@ -549,7 +553,10 @@ class Model(object):
# delete it and all its descendents. # delete it and all its descendents.
parent_obj._collect_sub_objects(seen_objs) parent_obj._collect_sub_objects(seen_objs)
def delete(self): def delete(self, using=None):
if using is None:
using = DEFAULT_DB_ALIAS
connection = connections[using]
assert self._get_pk_val() is not None, "%s object can't be deleted because its %s attribute is set to None." % (self._meta.object_name, self._meta.pk.attname) assert self._get_pk_val() is not None, "%s object can't be deleted because its %s attribute is set to None." % (self._meta.object_name, self._meta.pk.attname)
# Find all the objects than need to be deleted. # Find all the objects than need to be deleted.
@ -557,7 +564,7 @@ class Model(object):
self._collect_sub_objects(seen_objs) self._collect_sub_objects(seen_objs)
# Actually delete the objects. # Actually delete the objects.
delete_objects(seen_objs) delete_objects(seen_objs, using)
delete.alters_data = True delete.alters_data = True
@ -610,7 +617,8 @@ def method_set_order(ordered_obj, self, id_list):
# for situations like this. # for situations like this.
for i, j in enumerate(id_list): for i, j in enumerate(id_list):
ordered_obj.objects.filter(**{'pk': j, order_name: rel_val}).update(_order=i) ordered_obj.objects.filter(**{'pk': j, order_name: rel_val}).update(_order=i)
transaction.commit_unless_managed() # TODO
transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS)
def method_get_order(ordered_obj, self): def method_get_order(ordered_obj, self):

View File

@ -41,8 +41,8 @@ class ExpressionNode(tree.Node):
def prepare(self, evaluator, query, allow_joins): def prepare(self, evaluator, query, allow_joins):
return evaluator.prepare_node(self, query, allow_joins) return evaluator.prepare_node(self, query, allow_joins)
def evaluate(self, evaluator, qn): def evaluate(self, evaluator, qn, connection):
return evaluator.evaluate_node(self, qn) return evaluator.evaluate_node(self, qn, connection)
############# #############
# OPERATORS # # OPERATORS #
@ -109,5 +109,5 @@ class F(ExpressionNode):
def prepare(self, evaluator, query, allow_joins): def prepare(self, evaluator, query, allow_joins):
return evaluator.prepare_leaf(self, query, allow_joins) return evaluator.prepare_leaf(self, query, allow_joins)
def evaluate(self, evaluator, qn): def evaluate(self, evaluator, qn, connection):
return evaluator.evaluate_leaf(self, qn) return evaluator.evaluate_leaf(self, qn, connection)

View File

@ -233,6 +233,24 @@ class Field(object):
raise TypeError("Field has invalid lookup: %s" % lookup_type) raise TypeError("Field has invalid lookup: %s" % lookup_type)
def validate(self, lookup_type, value):
"""
Validate that the data is valid, as much so as possible without knowing
what connection we are using. Returns True if the value was
successfully validated and false if the value wasn't validated (this
doesn't consider whether the value was actually valid, an exception is
raised in those circumstances).
"""
if hasattr(value, 'validate') or hasattr(value, '_validate'):
if hasattr(value, 'validate'):
value.validate()
else:
value._validate()
return True
if lookup_type == 'isnull':
return True
return False
def has_default(self): def has_default(self):
"Returns a boolean of whether this field has a default value." "Returns a boolean of whether this field has a default value."
return self.default is not NOT_PROVIDED return self.default is not NOT_PROVIDED
@ -360,6 +378,17 @@ class AutoField(Field):
return None return None
return int(value) return int(value)
def validate(self, lookup_type, value):
if super(AutoField, self).validate(lookup_type, value):
return
if value is None or hasattr(value, 'as_sql'):
return
if lookup_type in ('range', 'in'):
for val in value:
int(val)
else:
int(value)
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
assert not cls._meta.has_auto_field, "A model can't have more than one AutoField." assert not cls._meta.has_auto_field, "A model can't have more than one AutoField."
super(AutoField, self).contribute_to_class(cls, name) super(AutoField, self).contribute_to_class(cls, name)
@ -396,6 +425,13 @@ class BooleanField(Field):
value = bool(int(value)) value = bool(int(value))
return super(BooleanField, self).get_db_prep_lookup(lookup_type, value) return super(BooleanField, self).get_db_prep_lookup(lookup_type, value)
def validate(self, lookup_type, value):
if super(BooleanField, self).validate(lookup_type, value):
return
if value in ('1', '0'):
value = int(value)
bool(value)
def get_db_prep_value(self, value): def get_db_prep_value(self, value):
if value is None: if value is None:
return None return None
@ -510,6 +546,20 @@ class DateField(Field):
# Casts dates into the format expected by the backend # Casts dates into the format expected by the backend
return connection.ops.value_to_db_date(self.to_python(value)) return connection.ops.value_to_db_date(self.to_python(value))
def validate(self, lookup_type, value):
if super(DateField, self).validate(lookup_type, value):
return
if value is None:
return
if lookup_type in ('month', 'day', 'year', 'week_day'):
int(value)
return
if lookup_type in ('in', 'range'):
for val in value:
self.to_python(val)
return
self.to_python(value)
def value_to_string(self, obj): def value_to_string(self, obj):
val = self._get_val_from_obj(obj) val = self._get_val_from_obj(obj)
if val is None: if val is None:
@ -754,6 +804,11 @@ class NullBooleanField(Field):
value = bool(int(value)) value = bool(int(value))
return super(NullBooleanField, self).get_db_prep_lookup(lookup_type, value) return super(NullBooleanField, self).get_db_prep_lookup(lookup_type, value)
def validate(self, lookup_type, value):
if value in ('1', '0'):
value = int(value)
bool(value)
def get_db_prep_value(self, value): def get_db_prep_value(self, value):
if value is None: if value is None:
return None return None

View File

@ -1,4 +1,4 @@
from django.db import connection, transaction from django.db import connection, transaction, DEFAULT_DB_ALIAS
from django.db.backends import util from django.db.backends import util
from django.db.models import signals, get_model from django.db.models import signals, get_model
from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist
@ -478,7 +478,9 @@ def create_many_related_manager(superclass, through=False):
cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \ cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
(self.join_table, source_col_name, target_col_name), (self.join_table, source_col_name, target_col_name),
[self._pk_val, obj_id]) [self._pk_val, obj_id])
transaction.commit_unless_managed() # FIXME, once this isn't in related.py it should conditionally
# use the right DB.
transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS)
def _remove_items(self, source_col_name, target_col_name, *objs): def _remove_items(self, source_col_name, target_col_name, *objs):
# source_col_name: the PK colname in join_table for the source object # source_col_name: the PK colname in join_table for the source object
@ -508,7 +510,8 @@ def create_many_related_manager(superclass, through=False):
cursor.execute("DELETE FROM %s WHERE %s = %%s" % \ cursor.execute("DELETE FROM %s WHERE %s = %%s" % \
(self.join_table, source_col_name), (self.join_table, source_col_name),
[self._pk_val]) [self._pk_val])
transaction.commit_unless_managed() # TODO
transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS)
return ManyRelatedManager return ManyRelatedManager

View File

@ -173,6 +173,9 @@ class Manager(object):
def only(self, *args, **kwargs): def only(self, *args, **kwargs):
return self.get_query_set().only(*args, **kwargs) return self.get_query_set().only(*args, **kwargs)
def using(self, *args, **kwargs):
return self.get_query_set().using(*args, **kwargs)
def _insert(self, values, **kwargs): def _insert(self, values, **kwargs):
return insert_query(self.model, values, **kwargs) return insert_query(self.model, values, **kwargs)

View File

@ -7,7 +7,7 @@ try:
except NameError: except NameError:
from sets import Set as set # Python 2.3 fallback from sets import Set as set # Python 2.3 fallback
from django.db import connection, transaction, IntegrityError from django.db import connections, transaction, IntegrityError, DEFAULT_DB_ALIAS
from django.db.models.aggregates import Aggregate from django.db.models.aggregates import Aggregate
from django.db.models.fields import DateField from django.db.models.fields import DateField
from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory
@ -31,10 +31,13 @@ class QuerySet(object):
""" """
def __init__(self, model=None, query=None): def __init__(self, model=None, query=None):
self.model = model self.model = model
connection = connections[DEFAULT_DB_ALIAS]
self.query = query or sql.Query(self.model, connection) self.query = query or sql.Query(self.model, connection)
self._result_cache = None self._result_cache = None
self._iter = None self._iter = None
self._sticky_filter = False self._sticky_filter = False
self._using = DEFAULT_DB_ALIAS # this will be wrong if a custom Query
# is provided with a non default connection
######################## ########################
# PYTHON MAGIC METHODS # # PYTHON MAGIC METHODS #
@ -300,12 +303,12 @@ class QuerySet(object):
params = dict([(k, v) for k, v in kwargs.items() if '__' not in k]) params = dict([(k, v) for k, v in kwargs.items() if '__' not in k])
params.update(defaults) params.update(defaults)
obj = self.model(**params) obj = self.model(**params)
sid = transaction.savepoint() sid = transaction.savepoint(using=self._using)
obj.save(force_insert=True) obj.save(force_insert=True, using=self._using)
transaction.savepoint_commit(sid) transaction.savepoint_commit(sid, using=self._using)
return obj, True return obj, True
except IntegrityError, e: except IntegrityError, e:
transaction.savepoint_rollback(sid) transaction.savepoint_rollback(sid, using=self._using)
try: try:
return self.get(**kwargs), False return self.get(**kwargs), False
except self.model.DoesNotExist: except self.model.DoesNotExist:
@ -364,7 +367,7 @@ class QuerySet(object):
if not seen_objs: if not seen_objs:
break break
delete_objects(seen_objs) delete_objects(seen_objs, del_query._using)
# Clear the result cache, in case this QuerySet gets reused. # Clear the result cache, in case this QuerySet gets reused.
self._result_cache = None self._result_cache = None
@ -379,20 +382,20 @@ class QuerySet(object):
"Cannot update a query once a slice has been taken." "Cannot update a query once a slice has been taken."
query = self.query.clone(sql.UpdateQuery) query = self.query.clone(sql.UpdateQuery)
query.add_update_values(kwargs) query.add_update_values(kwargs)
if not transaction.is_managed(): if not transaction.is_managed(using=self._using):
transaction.enter_transaction_management() transaction.enter_transaction_management(using=self._using)
forced_managed = True forced_managed = True
else: else:
forced_managed = False forced_managed = False
try: try:
rows = query.execute_sql(None) rows = query.execute_sql(None)
if forced_managed: if forced_managed:
transaction.commit() transaction.commit(using=self._using)
else: else:
transaction.commit_unless_managed() transaction.commit_unless_managed(using=self._using)
finally: finally:
if forced_managed: if forced_managed:
transaction.leave_transaction_management() transaction.leave_transaction_management(using=self._using)
self._result_cache = None self._result_cache = None
return rows return rows
update.alters_data = True update.alters_data = True
@ -616,6 +619,16 @@ class QuerySet(object):
clone.query.add_immediate_loading(fields) clone.query.add_immediate_loading(fields)
return clone return clone
def using(self, alias):
"""
Selects which database this QuerySet should excecute it's query against.
"""
clone = self._clone()
clone._using = alias
connection = connections[alias]
clone.query.set_connection(connection)
return clone
################################### ###################################
# PUBLIC INTROSPECTION ATTRIBUTES # # PUBLIC INTROSPECTION ATTRIBUTES #
################################### ###################################
@ -644,6 +657,7 @@ class QuerySet(object):
if self._sticky_filter: if self._sticky_filter:
query.filter_is_sticky = True query.filter_is_sticky = True
c = klass(model=self.model, query=query) c = klass(model=self.model, query=query)
c._using = self._using
c.__dict__.update(kwargs) c.__dict__.update(kwargs)
if setup and hasattr(c, '_setup_query'): if setup and hasattr(c, '_setup_query'):
c._setup_query() c._setup_query()
@ -700,6 +714,13 @@ class QuerySet(object):
obj = self.values("pk") obj = self.values("pk")
return obj.query.as_nested_sql() return obj.query.as_nested_sql()
def _validate(self):
"""
A normal QuerySet is always valid when used as the RHS of a filter,
since it automatically gets filtered down to 1 field.
"""
pass
# When used as part of a nested query, a queryset will never be an "always # When used as part of a nested query, a queryset will never be an "always
# empty" result. # empty" result.
value_annotation = True value_annotation = True
@ -818,6 +839,17 @@ class ValuesQuerySet(QuerySet):
% self.__class__.__name__) % self.__class__.__name__)
return self._clone().query.as_nested_sql() return self._clone().query.as_nested_sql()
def _validate(self):
"""
Validates that we aren't trying to do a query like
value__in=qs.values('value1', 'value2'), which isn't valid.
"""
if ((self._fields and len(self._fields) > 1) or
(not self._fields and len(self.model._meta.fields) > 1)):
raise TypeError('Cannot use a multi-field %s as a filter value.'
% self.__class__.__name__)
class ValuesListQuerySet(ValuesQuerySet): class ValuesListQuerySet(ValuesQuerySet):
def iterator(self): def iterator(self):
if self.flat and len(self._fields) == 1: if self.flat and len(self._fields) == 1:
@ -970,13 +1002,14 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
setattr(obj, f.get_cache_name(), rel_obj) setattr(obj, f.get_cache_name(), rel_obj)
return obj, index_end return obj, index_end
def delete_objects(seen_objs): def delete_objects(seen_objs, using):
""" """
Iterate through a list of seen classes, and remove any instances that are Iterate through a list of seen classes, and remove any instances that are
referred to. referred to.
""" """
if not transaction.is_managed(): connection = connections[using]
transaction.enter_transaction_management() if not transaction.is_managed(using=using):
transaction.enter_transaction_management(using=using)
forced_managed = True forced_managed = True
else: else:
forced_managed = False forced_managed = False
@ -1036,20 +1069,21 @@ def delete_objects(seen_objs):
setattr(instance, cls._meta.pk.attname, None) setattr(instance, cls._meta.pk.attname, None)
if forced_managed: if forced_managed:
transaction.commit() transaction.commit(using=using)
else: else:
transaction.commit_unless_managed() transaction.commit_unless_managed(using=using)
finally: finally:
if forced_managed: if forced_managed:
transaction.leave_transaction_management() transaction.leave_transaction_management(using=using)
def insert_query(model, values, return_id=False, raw_values=False): def insert_query(model, values, return_id=False, raw_values=False, using=None):
""" """
Inserts a new record for the given model. This provides an interface to Inserts a new record for the given model. This provides an interface to
the InsertQuery class and is how Model.save() is implemented. It is not the InsertQuery class and is how Model.save() is implemented. It is not
part of the public API. part of the public API.
""" """
connection = connections[using]
query = sql.InsertQuery(model, connection) query = sql.InsertQuery(model, connection)
query.insert_values(values, raw_values) query.insert_values(values, raw_values)
return query.execute_sql(return_id) return query.execute_sql(return_id)

View File

@ -3,25 +3,23 @@ from django.db.models.fields import FieldDoesNotExist
from django.db.models.sql.constants import LOOKUP_SEP from django.db.models.sql.constants import LOOKUP_SEP
class SQLEvaluator(object): class SQLEvaluator(object):
as_sql_takes_connection = True
def __init__(self, expression, query, allow_joins=True): def __init__(self, expression, query, allow_joins=True):
self.expression = expression self.expression = expression
self.opts = query.get_meta() self.opts = query.get_meta()
self.cols = {} self.cols = {}
self.contains_aggregate = False self.contains_aggregate = False
self.connection = query.connection
self.expression.prepare(self, query, allow_joins) self.expression.prepare(self, query, allow_joins)
def as_sql(self, qn=None): def as_sql(self, qn, connection):
return self.expression.evaluate(self, qn) return self.expression.evaluate(self, qn, connection)
def relabel_aliases(self, change_map): def relabel_aliases(self, change_map):
for node, col in self.cols.items(): for node, col in self.cols.items():
self.cols[node] = (change_map.get(col[0], col[0]), col[1]) self.cols[node] = (change_map.get(col[0], col[0]), col[1])
def update_connection(self, connection):
self.connection = connection
##################################################### #####################################################
# Vistor methods for initial expression preparation # # Vistor methods for initial expression preparation #
##################################################### #####################################################
@ -57,15 +55,12 @@ class SQLEvaluator(object):
# Vistor methods for final expression evaluation # # Vistor methods for final expression evaluation #
################################################## ##################################################
def evaluate_node(self, node, qn): def evaluate_node(self, node, qn, connection):
if not qn:
qn = self.connection.ops.quote_name
expressions = [] expressions = []
expression_params = [] expression_params = []
for child in node.children: for child in node.children:
if hasattr(child, 'evaluate'): if hasattr(child, 'evaluate'):
sql, params = child.evaluate(self, qn) sql, params = child.evaluate(self, qn, connection)
else: else:
sql, params = '%s', (child,) sql, params = '%s', (child,)
@ -78,12 +73,9 @@ class SQLEvaluator(object):
expressions.append(format % sql) expressions.append(format % sql)
expression_params.extend(params) expression_params.extend(params)
return self.connection.ops.combine_expression(node.connector, expressions), expression_params return connection.ops.combine_expression(node.connector, expressions), expression_params
def evaluate_leaf(self, node, qn):
if not qn:
qn = self.connection.ops.quote_name
def evaluate_leaf(self, node, qn, connection):
col = self.cols[node] col = self.cols[node]
if hasattr(col, 'as_sql'): if hasattr(col, 'as_sql'):
return col.as_sql(qn), () return col.as_sql(qn), ()

View File

@ -67,10 +67,10 @@ class BaseQuery(object):
# SQL-related attributes # SQL-related attributes
self.select = [] self.select = []
self.tables = [] # Aliases in the order they are created. self.tables = [] # Aliases in the order they are created.
self.where = where(connection=self.connection) self.where = where()
self.where_class = where self.where_class = where
self.group_by = None self.group_by = None
self.having = where(connection=self.connection) self.having = where()
self.order_by = [] self.order_by = []
self.low_mark, self.high_mark = 0, None # Used for offset/limit self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.distinct = False self.distinct = False
@ -151,8 +151,6 @@ class BaseQuery(object):
# supported. It's the only class-reference to the module-level # supported. It's the only class-reference to the module-level
# connection variable. # connection variable.
self.connection = connection self.connection = connection
self.where.update_connection(self.connection)
self.having.update_connection(self.connection)
def get_meta(self): def get_meta(self):
""" """
@ -245,8 +243,6 @@ class BaseQuery(object):
obj.used_aliases = set() obj.used_aliases = set()
obj.filter_is_sticky = False obj.filter_is_sticky = False
obj.__dict__.update(kwargs) obj.__dict__.update(kwargs)
obj.where.update_connection(obj.connection) # where and having track their own connection
obj.having.update_connection(obj.connection)# we need to keep this up to date
if hasattr(obj, '_setup_query'): if hasattr(obj, '_setup_query'):
obj._setup_query() obj._setup_query()
return obj return obj
@ -405,8 +401,8 @@ class BaseQuery(object):
from_, f_params = self.get_from_clause() from_, f_params = self.get_from_clause()
qn = self.quote_name_unless_alias qn = self.quote_name_unless_alias
where, w_params = self.where.as_sql(qn=qn) where, w_params = self.where.as_sql(qn=qn, connection=self.connection)
having, h_params = self.having.as_sql(qn=qn) having, h_params = self.having.as_sql(qn=qn, connection=self.connection)
params = [] params = []
for val in self.extra_select.itervalues(): for val in self.extra_select.itervalues():
params.extend(val[1]) params.extend(val[1])
@ -534,10 +530,10 @@ class BaseQuery(object):
self.where.add(EverythingNode(), AND) self.where.add(EverythingNode(), AND)
elif self.where: elif self.where:
# rhs has an empty where clause. # rhs has an empty where clause.
w = self.where_class(connection=self.connection) w = self.where_class()
w.add(EverythingNode(), AND) w.add(EverythingNode(), AND)
else: else:
w = self.where_class(connection=self.connection) w = self.where_class()
self.where.add(w, connector) self.where.add(w, connector)
# Selection columns and extra extensions are those provided by 'rhs'. # Selection columns and extra extensions are those provided by 'rhs'.
@ -1550,7 +1546,7 @@ class BaseQuery(object):
for alias, aggregate in self.aggregates.items(): for alias, aggregate in self.aggregates.items():
if alias == parts[0]: if alias == parts[0]:
entry = self.where_class(connection=self.connection) entry = self.where_class()
entry.add((aggregate, lookup_type, value), AND) entry.add((aggregate, lookup_type, value), AND)
if negate: if negate:
entry.negate() entry.negate()
@ -1618,7 +1614,7 @@ class BaseQuery(object):
for alias in join_list: for alias in join_list:
if self.alias_map[alias][JOIN_TYPE] == self.LOUTER: if self.alias_map[alias][JOIN_TYPE] == self.LOUTER:
j_col = self.alias_map[alias][RHS_JOIN_COL] j_col = self.alias_map[alias][RHS_JOIN_COL]
entry = self.where_class(connection=self.connection) entry = self.where_class()
entry.add((Constraint(alias, j_col, None), 'isnull', True), AND) entry.add((Constraint(alias, j_col, None), 'isnull', True), AND)
entry.negate() entry.negate()
self.where.add(entry, AND) self.where.add(entry, AND)
@ -1627,7 +1623,7 @@ class BaseQuery(object):
# Leaky abstraction artifact: We have to specifically # Leaky abstraction artifact: We have to specifically
# exclude the "foo__in=[]" case from this handling, because # exclude the "foo__in=[]" case from this handling, because
# it's short-circuited in the Where class. # it's short-circuited in the Where class.
entry = self.where_class(connection=self.connection) entry = self.where_class()
entry.add((Constraint(alias, col, None), 'isnull', True), AND) entry.add((Constraint(alias, col, None), 'isnull', True), AND)
entry.negate() entry.negate()
self.where.add(entry, AND) self.where.add(entry, AND)
@ -2337,6 +2333,9 @@ class BaseQuery(object):
self.select = [(select_alias, select_col)] self.select = [(select_alias, select_col)]
self.remove_inherited_models() self.remove_inherited_models()
def set_connection(self, connection):
self.connection = connection
def execute_sql(self, result_type=MULTI): def execute_sql(self, result_type=MULTI):
""" """
Run the query against the database and returns the result(s). The Run the query against the database and returns the result(s). The

View File

@ -24,8 +24,9 @@ class DeleteQuery(Query):
""" """
assert len(self.tables) == 1, \ assert len(self.tables) == 1, \
"Can only delete from one table at a time." "Can only delete from one table at a time."
result = ['DELETE FROM %s' % self.quote_name_unless_alias(self.tables[0])] qn = self.quote_name_unless_alias
where, params = self.where.as_sql() result = ['DELETE FROM %s' % qn(self.tables[0])]
where, params = self.where.as_sql(qn=qn, connection=self.connection)
result.append('WHERE %s' % where) result.append('WHERE %s' % where)
return ' '.join(result), tuple(params) return ' '.join(result), tuple(params)
@ -48,7 +49,7 @@ class DeleteQuery(Query):
for related in cls._meta.get_all_related_many_to_many_objects(): for related in cls._meta.get_all_related_many_to_many_objects():
if not isinstance(related.field, generic.GenericRelation): if not isinstance(related.field, generic.GenericRelation):
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = self.where_class(connection=self.connection) where = self.where_class()
where.add((Constraint(None, where.add((Constraint(None,
related.field.m2m_reverse_name(), related.field), related.field.m2m_reverse_name(), related.field),
'in', 'in',
@ -57,14 +58,14 @@ class DeleteQuery(Query):
self.do_query(related.field.m2m_db_table(), where) self.do_query(related.field.m2m_db_table(), where)
for f in cls._meta.many_to_many: for f in cls._meta.many_to_many:
w1 = self.where_class(connection=self.connection) w1 = self.where_class()
if isinstance(f, generic.GenericRelation): if isinstance(f, generic.GenericRelation):
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
field = f.rel.to._meta.get_field(f.content_type_field_name) field = f.rel.to._meta.get_field(f.content_type_field_name)
w1.add((Constraint(None, field.column, field), 'exact', w1.add((Constraint(None, field.column, field), 'exact',
ContentType.objects.get_for_model(cls).id), AND) ContentType.objects.get_for_model(cls).id), AND)
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = self.where_class(connection=self.connection) where = self.where_class()
where.add((Constraint(None, f.m2m_column_name(), f), 'in', where.add((Constraint(None, f.m2m_column_name(), f), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
AND) AND)
@ -81,7 +82,7 @@ class DeleteQuery(Query):
lot of values in pk_list. lot of values in pk_list.
""" """
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = self.where_class(connection=self.connection) where = self.where_class()
field = self.model._meta.pk field = self.model._meta.pk
where.add((Constraint(None, field.column, field), 'in', where.add((Constraint(None, field.column, field), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND)
@ -143,6 +144,9 @@ class UpdateQuery(Query):
values, update_params = [], [] values, update_params = [], []
for name, val, placeholder in self.values: for name, val, placeholder in self.values:
if hasattr(val, 'as_sql'): if hasattr(val, 'as_sql'):
if getattr(val, 'as_sql_takes_connection', False):
sql, params = val.as_sql(qn, self.connection)
else:
sql, params = val.as_sql(qn) sql, params = val.as_sql(qn)
values.append('%s = %s' % (qn(name), sql)) values.append('%s = %s' % (qn(name), sql))
update_params.extend(params) update_params.extend(params)
@ -152,7 +156,7 @@ class UpdateQuery(Query):
else: else:
values.append('%s = NULL' % qn(name)) values.append('%s = NULL' % qn(name))
result.append(', '.join(values)) result.append(', '.join(values))
where, params = self.where.as_sql() where, params = self.where.as_sql(qn=qn, connection=self.connection)
if where: if where:
result.append('WHERE %s' % where) result.append('WHERE %s' % where)
return ' '.join(result), tuple(update_params + params) return ' '.join(result), tuple(update_params + params)
@ -185,7 +189,7 @@ class UpdateQuery(Query):
# Now we adjust the current query: reset the where clause and get rid # Now we adjust the current query: reset the where clause and get rid
# of all the tables we don't need (since they're in the sub-select). # of all the tables we don't need (since they're in the sub-select).
self.where = self.where_class(connection=self.connection) self.where = self.where_class()
if self.related_updates or must_pre_select: if self.related_updates or must_pre_select:
# Either we're using the idents in multiple update queries (so # Either we're using the idents in multiple update queries (so
# don't want them to change), or the db backend doesn't support # don't want them to change), or the db backend doesn't support
@ -209,7 +213,7 @@ class UpdateQuery(Query):
This is used by the QuerySet.delete_objects() method. This is used by the QuerySet.delete_objects() method.
""" """
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.where = self.where_class(connection=self.connection) self.where = self.where_class()
f = self.model._meta.pk f = self.model._meta.pk
self.where.add((Constraint(None, f.column, f), 'in', self.where.add((Constraint(None, f.column, f), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),

View File

@ -32,18 +32,7 @@ class WhereNode(tree.Node):
relabel_aliases() methods. relabel_aliases() methods.
""" """
default = AND default = AND
as_sql_takes_connection = True
def __init__(self, *args, **kwargs):
self.connection = kwargs.pop('connection', None)
super(WhereNode, self).__init__(*args, **kwargs)
def __getstate__(self):
"""
Don't try to pickle the connection, our Query will restore it for us.
"""
data = self.__dict__.copy()
del data['connection']
return data
def add(self, data, connector): def add(self, data, connector):
""" """
@ -62,20 +51,6 @@ class WhereNode(tree.Node):
# Consume any generators immediately, so that we can determine # Consume any generators immediately, so that we can determine
# emptiness and transform any non-empty values correctly. # emptiness and transform any non-empty values correctly.
value = list(value) value = list(value)
if hasattr(obj, "process"):
try:
# FIXME We're calling process too early, the connection could
# change
obj, params = obj.process(lookup_type, value, self.connection)
except (EmptyShortCircuit, EmptyResultSet):
# There are situations where we want to short-circuit any
# comparisons and make sure that nothing is returned. One
# example is when checking for a NULL pk value, or the
# equivalent.
super(WhereNode, self).add(NothingNode(), connector)
return
else:
params = Field().get_db_prep_lookup(lookup_type, value)
# The "annotation" parameter is used to pass auxilliary information # The "annotation" parameter is used to pass auxilliary information
# about the value(s) to the query construction. Specifically, datetime # about the value(s) to the query construction. Specifically, datetime
@ -88,18 +63,19 @@ class WhereNode(tree.Node):
else: else:
annotation = bool(value) annotation = bool(value)
if hasattr(obj, "process"):
obj.validate(lookup_type, value)
super(WhereNode, self).add((obj, lookup_type, annotation, value),
connector)
return
else:
# TODO: Make this lazy just like the above code for constraints.
params = Field().get_db_prep_lookup(lookup_type, value)
super(WhereNode, self).add((obj, lookup_type, annotation, params), super(WhereNode, self).add((obj, lookup_type, annotation, params),
connector) connector)
def update_connection(self, connection): def as_sql(self, qn, connection):
self.connection = connection
for child in self.children:
if hasattr(child, 'update_connection'):
child.update_connection(connection)
elif hasattr(child[3], 'update_connection'):
child[3].update_connection(connection)
def as_sql(self, qn=None):
""" """
Returns the SQL version of the where clause and the value to be Returns the SQL version of the where clause and the value to be
substituted in. Returns None, None if this node is empty. substituted in. Returns None, None if this node is empty.
@ -108,8 +84,6 @@ class WhereNode(tree.Node):
(generally not needed except by the internal implementation for (generally not needed except by the internal implementation for
recursion). recursion).
""" """
if not qn:
qn = self.connection.ops.quote_name
if not self.children: if not self.children:
return None, [] return None, []
result = [] result = []
@ -118,10 +92,13 @@ class WhereNode(tree.Node):
for child in self.children: for child in self.children:
try: try:
if hasattr(child, 'as_sql'): if hasattr(child, 'as_sql'):
if getattr(child, 'as_sql_takes_connection', False):
sql, params = child.as_sql(qn=qn, connection=connection)
else:
sql, params = child.as_sql(qn=qn) sql, params = child.as_sql(qn=qn)
else: else:
# A leaf node in the tree. # A leaf node in the tree.
sql, params = self.make_atom(child, qn) sql, params = self.make_atom(child, qn, connection)
except EmptyResultSet: except EmptyResultSet:
if self.connector == AND and not self.negated: if self.connector == AND and not self.negated:
@ -157,7 +134,7 @@ class WhereNode(tree.Node):
sql_string = '(%s)' % sql_string sql_string = '(%s)' % sql_string
return sql_string, result_params return sql_string, result_params
def make_atom(self, child, qn): def make_atom(self, child, qn, connection):
""" """
Turn a tuple (table_alias, column_name, db_type, lookup_type, Turn a tuple (table_alias, column_name, db_type, lookup_type,
value_annot, params) into valid SQL. value_annot, params) into valid SQL.
@ -165,29 +142,39 @@ class WhereNode(tree.Node):
Returns the string for the SQL fragment and the parameters to use for Returns the string for the SQL fragment and the parameters to use for
it. it.
""" """
lvalue, lookup_type, value_annot, params = child lvalue, lookup_type, value_annot, params_or_value = child
if hasattr(lvalue, 'process'):
try:
lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
except EmptyShortCircuit:
raise EmptyResultSet
else:
params = params_or_value
if isinstance(lvalue, tuple): if isinstance(lvalue, tuple):
# A direct database column lookup. # A direct database column lookup.
field_sql = self.sql_for_columns(lvalue, qn) field_sql = self.sql_for_columns(lvalue, qn, connection)
else: else:
# A smart object with an as_sql() method. # A smart object with an as_sql() method.
field_sql = lvalue.as_sql(quote_func=qn) field_sql = lvalue.as_sql(quote_func=qn)
if value_annot is datetime.datetime: if value_annot is datetime.datetime:
cast_sql = self.connection.ops.datetime_cast_sql() cast_sql = connection.ops.datetime_cast_sql()
else: else:
cast_sql = '%s' cast_sql = '%s'
if hasattr(params, 'as_sql'): if hasattr(params, 'as_sql'):
if getattr(params, 'as_sql_takes_connection', False):
extra, params = params.as_sql(qn, connection)
else:
extra, params = params.as_sql(qn) extra, params = params.as_sql(qn)
cast_sql = '' cast_sql = ''
else: else:
extra = '' extra = ''
if lookup_type in self.connection.operators: if lookup_type in connection.operators:
format = "%s %%s %%s" % (self.connection.ops.lookup_cast(lookup_type),) format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
return (format % (field_sql, return (format % (field_sql,
self.connection.operators[lookup_type] % cast_sql, connection.operators[lookup_type] % cast_sql,
extra), params) extra), params)
if lookup_type == 'in': if lookup_type == 'in':
@ -200,19 +187,19 @@ class WhereNode(tree.Node):
elif lookup_type in ('range', 'year'): elif lookup_type in ('range', 'year'):
return ('%s BETWEEN %%s and %%s' % field_sql, params) return ('%s BETWEEN %%s and %%s' % field_sql, params)
elif lookup_type in ('month', 'day', 'week_day'): elif lookup_type in ('month', 'day', 'week_day'):
return ('%s = %%s' % self.connection.ops.date_extract_sql(lookup_type, field_sql), return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql),
params) params)
elif lookup_type == 'isnull': elif lookup_type == 'isnull':
return ('%s IS %sNULL' % (field_sql, return ('%s IS %sNULL' % (field_sql,
(not value_annot and 'NOT ' or '')), ()) (not value_annot and 'NOT ' or '')), ())
elif lookup_type == 'search': elif lookup_type == 'search':
return (self.connection.ops.fulltext_search_sql(field_sql), params) return (connection.ops.fulltext_search_sql(field_sql), params)
elif lookup_type in ('regex', 'iregex'): elif lookup_type in ('regex', 'iregex'):
return self.connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params return connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params
raise TypeError('Invalid lookup_type: %r' % lookup_type) raise TypeError('Invalid lookup_type: %r' % lookup_type)
def sql_for_columns(self, data, qn): def sql_for_columns(self, data, qn, connection):
""" """
Returns the SQL fragment used for the left-hand side of a column Returns the SQL fragment used for the left-hand side of a column
constraint (for example, the "T1.foo" portion in the clause constraint (for example, the "T1.foo" portion in the clause
@ -223,7 +210,7 @@ class WhereNode(tree.Node):
lhs = '%s.%s' % (qn(table_alias), qn(name)) lhs = '%s.%s' % (qn(table_alias), qn(name))
else: else:
lhs = qn(name) lhs = qn(name)
return self.connection.ops.field_cast_sql(db_type) % lhs return connection.ops.field_cast_sql(db_type) % lhs
def relabel_aliases(self, change_map, node=None): def relabel_aliases(self, change_map, node=None):
""" """
@ -299,3 +286,11 @@ class Constraint(object):
raise EmptyShortCircuit raise EmptyShortCircuit
return (self.alias, self.col, db_type), params return (self.alias, self.col, db_type), params
def relabel_aliases(self, change_map):
if self.alias in change_map:
self.alias = change_map[self.alias]
def validate(self, lookup_type, value):
if hasattr(self.field, 'validate'):
self.field.validate(lookup_type, value)

View File

@ -20,7 +20,7 @@ try:
from functools import wraps from functools import wraps
except ImportError: except ImportError:
from django.utils.functional import wraps # Python 2.3, 2.4 fallback. from django.utils.functional import wraps # Python 2.3, 2.4 fallback.
from django.db import connection from django.db import connections
from django.conf import settings from django.conf import settings
class TransactionManagementError(Exception): class TransactionManagementError(Exception):
@ -30,17 +30,20 @@ class TransactionManagementError(Exception):
""" """
pass pass
# The states are dictionaries of lists. The key to the dict is the current # The states are dictionaries of dictionaries of lists. The key to the outer
# thread and the list is handled as a stack of values. # dict is the current thread, and the key to the inner dictionary is the
# connection alias and the list is handled as a stack of values.
state = {} state = {}
savepoint_state = {} savepoint_state = {}
# The dirty flag is set by *_unless_managed functions to denote that the # The dirty flag is set by *_unless_managed functions to denote that the
# code under transaction management has changed things to require a # code under transaction management has changed things to require a
# database commit. # database commit.
# This is a dictionary mapping thread to a dictionary mapping connection
# alias to a boolean.
dirty = {} dirty = {}
def enter_transaction_management(managed=True): def enter_transaction_management(managed=True, using=None):
""" """
Enters transaction management for a running thread. It must be balanced with Enters transaction management for a running thread. It must be balanced with
the appropriate leave_transaction_management call, since the actual state is the appropriate leave_transaction_management call, since the actual state is
@ -50,166 +53,212 @@ def enter_transaction_management(managed=True):
from the settings, if there is no surrounding block (dirty is always false from the settings, if there is no surrounding block (dirty is always false
when no current block is running). when no current block is running).
""" """
if using is None:
raise ValueError # TODO use default
connection = connections[using]
thread_ident = thread.get_ident() thread_ident = thread.get_ident()
if thread_ident in state and state[thread_ident]: if thread_ident in state and state[thread_ident].get(using):
state[thread_ident].append(state[thread_ident][-1]) state[thread_ident][using].append(state[thread_ident][using][-1])
else: else:
state[thread_ident] = [] state.setdefault(thread_ident, {})
state[thread_ident].append(settings.TRANSACTIONS_MANAGED) state[thread_ident][using] = [settings.TRANSACTIONS_MANAGED]
if thread_ident not in dirty: if thread_ident not in dirty or using not in dirty[thread_ident]:
dirty[thread_ident] = False dirty.setdefault(thread_ident, {})
dirty[thread_ident][using] = False
connection._enter_transaction_management(managed) connection._enter_transaction_management(managed)
def leave_transaction_management(): def leave_transaction_management(using=None):
""" """
Leaves transaction management for a running thread. A dirty flag is carried Leaves transaction management for a running thread. A dirty flag is carried
over to the surrounding block, as a commit will commit all changes, even over to the surrounding block, as a commit will commit all changes, even
those from outside. (Commits are on connection level.) those from outside. (Commits are on connection level.)
""" """
connection._leave_transaction_management(is_managed()) if using is None:
raise ValueError # TODO use default
connection = connections[using]
connection._leave_transaction_management(is_managed(using=using))
thread_ident = thread.get_ident() thread_ident = thread.get_ident()
if thread_ident in state and state[thread_ident]: if thread_ident in state and state[thread_ident].get(using):
del state[thread_ident][-1] del state[thread_ident][using][-1]
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction management")
if dirty.get(thread_ident, False): if dirty.get(thread_ident, {}).get(using, False):
rollback() rollback(using=using)
raise TransactionManagementError("Transaction managed block ended with pending COMMIT/ROLLBACK") raise TransactionManagementError("Transaction managed block ended with pending COMMIT/ROLLBACK")
dirty[thread_ident] = False dirty[thread_ident][using] = False
def is_dirty(): def is_dirty(using=None):
""" """
Returns True if the current transaction requires a commit for changes to Returns True if the current transaction requires a commit for changes to
happen. happen.
""" """
return dirty.get(thread.get_ident(), False) if using is None:
raise ValueError # TODO use default
return dirty.get(thread.get_ident(), {}).get(using, False)
def set_dirty(): def set_dirty(using=None):
""" """
Sets a dirty flag for the current thread and code streak. This can be used Sets a dirty flag for the current thread and code streak. This can be used
to decide in a managed block of code to decide whether there are open to decide in a managed block of code to decide whether there are open
changes waiting for commit. changes waiting for commit.
""" """
if using is None:
raise ValueError # TODO use default
thread_ident = thread.get_ident() thread_ident = thread.get_ident()
if thread_ident in dirty: if thread_ident in dirty and using in dirty[thread_ident]:
dirty[thread_ident] = True dirty[thread_ident][using] = True
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction management")
def set_clean(): def set_clean(using=None):
""" """
Resets a dirty flag for the current thread and code streak. This can be used Resets a dirty flag for the current thread and code streak. This can be used
to decide in a managed block of code to decide whether a commit or rollback to decide in a managed block of code to decide whether a commit or rollback
should happen. should happen.
""" """
if using is None:
raise ValueError # TODO use default
thread_ident = thread.get_ident() thread_ident = thread.get_ident()
if thread_ident in dirty: if thread_ident in dirty and using in dirty[thread_ident]:
dirty[thread_ident] = False dirty[thread_ident][using] = False
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction management")
clean_savepoints() clean_savepoints(using=using)
def clean_savepoints(): def clean_savepoints(using=None):
if using is None:
raise ValueError # TODO use default
thread_ident = thread.get_ident() thread_ident = thread.get_ident()
if thread_ident in savepoint_state: if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
del savepoint_state[thread_ident] del savepoint_state[thread_ident][using]
def is_managed(): def is_managed(using=None):
""" """
Checks whether the transaction manager is in manual or in auto state. Checks whether the transaction manager is in manual or in auto state.
""" """
if using is None:
raise ValueError # TODO use default
thread_ident = thread.get_ident() thread_ident = thread.get_ident()
if thread_ident in state: if thread_ident in state and using in state[thread_ident]:
if state[thread_ident]: if state[thread_ident][using]:
return state[thread_ident][-1] return state[thread_ident][using][-1]
return settings.TRANSACTIONS_MANAGED return settings.TRANSACTIONS_MANAGED
def managed(flag=True): def managed(flag=True, using=None):
""" """
Puts the transaction manager into a manual state: managed transactions have Puts the transaction manager into a manual state: managed transactions have
to be committed explicitly by the user. If you switch off transaction to be committed explicitly by the user. If you switch off transaction
management and there is a pending commit/rollback, the data will be management and there is a pending commit/rollback, the data will be
commited. commited.
""" """
if using is None:
raise ValueError # TODO use default
connection = connections[using]
thread_ident = thread.get_ident() thread_ident = thread.get_ident()
top = state.get(thread_ident, None) top = state.get(thread_ident, {}).get(using, None)
if top: if top:
top[-1] = flag top[-1] = flag
if not flag and is_dirty(): if not flag and is_dirty(using=using):
connection._commit() connection._commit()
set_clean() set_clean(using=using)
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction management")
def commit_unless_managed(): def commit_unless_managed(using=None):
""" """
Commits changes if the system is not in managed transaction mode. Commits changes if the system is not in managed transaction mode.
""" """
if not is_managed(): if using is None:
raise ValueError # TODO use default
connection = connections[using]
if not is_managed(using=using):
connection._commit() connection._commit()
clean_savepoints() clean_savepoints(using=using)
else: else:
set_dirty() set_dirty(using=using)
def rollback_unless_managed(): def rollback_unless_managed(using=None):
""" """
Rolls back changes if the system is not in managed transaction mode. Rolls back changes if the system is not in managed transaction mode.
""" """
if not is_managed(): if using is None:
raise ValueError # TODO use default
connection = connections[using]
if not is_managed(using=using):
connection._rollback() connection._rollback()
else: else:
set_dirty() set_dirty(using=using)
def commit(): def commit(using=None):
""" """
Does the commit itself and resets the dirty flag. Does the commit itself and resets the dirty flag.
""" """
if using is None:
raise ValueError # TODO use default
connection = connections[using]
connection._commit() connection._commit()
set_clean() set_clean(using=using)
def rollback(): def rollback(using=None):
""" """
This function does the rollback itself and resets the dirty flag. This function does the rollback itself and resets the dirty flag.
""" """
if using is None:
raise ValueError # TODO use default
connection = connections[using]
connection._rollback() connection._rollback()
set_clean() set_clean(using=using)
def savepoint(): def savepoint(using=None):
""" """
Creates a savepoint (if supported and required by the backend) inside the Creates a savepoint (if supported and required by the backend) inside the
current transaction. Returns an identifier for the savepoint that will be current transaction. Returns an identifier for the savepoint that will be
used for the subsequent rollback or commit. used for the subsequent rollback or commit.
""" """
if using is None:
raise ValueError # TODO use default
connection = connections[using]
thread_ident = thread.get_ident() thread_ident = thread.get_ident()
if thread_ident in savepoint_state: if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
savepoint_state[thread_ident].append(None) savepoint_state[thread_ident][using].append(None)
else: else:
savepoint_state[thread_ident] = [None] savepoint_state.setdefault(thread_ident, {})
savepoint_state[thread_ident][using] = [None]
tid = str(thread_ident).replace('-', '') tid = str(thread_ident).replace('-', '')
sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident])) sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident][using]))
connection._savepoint(sid) connection._savepoint(sid)
return sid return sid
def savepoint_rollback(sid): def savepoint_rollback(sid, using=None):
""" """
Rolls back the most recent savepoint (if one exists). Does nothing if Rolls back the most recent savepoint (if one exists). Does nothing if
savepoints are not supported. savepoints are not supported.
""" """
if thread.get_ident() in savepoint_state: if using is None:
raise ValueError # TODO use default
connection = connections[using]
thread_ident = thread.get_ident()
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
connection._savepoint_rollback(sid) connection._savepoint_rollback(sid)
def savepoint_commit(sid): def savepoint_commit(sid, using=None):
""" """
Commits the most recent savepoint (if one exists). Does nothing if Commits the most recent savepoint (if one exists). Does nothing if
savepoints are not supported. savepoints are not supported.
""" """
if thread.get_ident() in savepoint_state: if using is None:
raise ValueError # TODO use default
connection = connections[using]
thread_ident = thread.get_ident()
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
connection._savepoint_commit(sid) connection._savepoint_commit(sid)
############## ##############
# DECORATORS # # DECORATORS #
############## ##############
# TODO update all of these for multi-db
def autocommit(func): def autocommit(func):
""" """
Decorator that activates commit on save. This is Django's default behavior; Decorator that activates commit on save. This is Django's default behavior;

View File

@ -201,7 +201,8 @@ class DocTestRunner(doctest.DocTestRunner):
example, exc_info) example, exc_info)
# Rollback, in case of database errors. Otherwise they'd have # Rollback, in case of database errors. Otherwise they'd have
# side effects on other tests. # side effects on other tests.
transaction.rollback_unless_managed() for conn in connections:
transaction.rollback_unless_managed(using=conn)
class TransactionTestCase(unittest.TestCase): class TransactionTestCase(unittest.TestCase):
def _pre_setup(self): def _pre_setup(self):
@ -446,8 +447,9 @@ class TestCase(TransactionTestCase):
if not connections_support_transactions(): if not connections_support_transactions():
return super(TestCase, self)._fixture_setup() return super(TestCase, self)._fixture_setup()
transaction.enter_transaction_management() for conn in connections:
transaction.managed(True) transaction.enter_transaction_management(using=conn)
transaction.managed(True, using=conn)
disable_transaction_methods() disable_transaction_methods()
from django.contrib.sites.models import Site from django.contrib.sites.models import Site
@ -464,7 +466,8 @@ class TestCase(TransactionTestCase):
return super(TestCase, self)._fixture_teardown() return super(TestCase, self)._fixture_teardown()
restore_transaction_methods() restore_transaction_methods()
transaction.rollback() for conn in connections:
transaction.leave_transaction_management() transaction.rollback(using=conn)
transaction.leave_transaction_management(using=conn)
for connection in connections.all(): for connection in connections.all():
connection.close() connection.close()

View File

@ -15,11 +15,11 @@ class RevisionableModel(models.Model):
def __unicode__(self): def __unicode__(self):
return u"%s (%s, %s)" % (self.title, self.id, self.base.id) return u"%s (%s, %s)" % (self.title, self.id, self.base.id)
def save(self, force_insert=False, force_update=False): def save(self, *args, **kwargs):
super(RevisionableModel, self).save(force_insert, force_update) super(RevisionableModel, self).save(*args, **kwargs)
if not self.base: if not self.base:
self.base = self self.base = self
super(RevisionableModel, self).save() super(RevisionableModel, self).save(using=kwargs.pop('using', None))
def new_revision(self): def new_revision(self):
new_revision = copy.copy(self) new_revision = copy.copy(self)