1
0
mirror of https://github.com/django/django.git synced 2025-07-05 10:19:20 +00:00

[soc2009/multidb] Bring this branch up to date with my external work. This means implementing the using method on querysets as well as a using kwarg on save and delete, plus many internal changes to facilitae this

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@10904 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2009-06-03 02:37:33 +00:00
parent f4bcbbfa8b
commit 23da5c0ac1
29 changed files with 447 additions and 349 deletions

View File

@ -4,6 +4,11 @@ TODO
The follow is a list, more or less in the order I intend to do them of things
that need to be done. I'm trying to be as granular as possible.
***
Immediate TODOs:
Finish refactor of WhereNode so that it just takes connection in as_sql.
***
2) Update all old references to ``settings.DATABASE_*`` to reference
``settings.DATABASES``. This includes the following locations
@ -20,29 +25,21 @@ that need to be done. I'm trying to be as granular as possible.
* ``dumpdata``: By default dump the ``default`` database. Later add a
``--database`` flag.
* ``loaddata``: Leave as is, will use the standard mechanisms for
determining what DB to save the objects to. Should it get a
``--database`` flag to overide that?
These items will be fixed pending both community consensus, and the API
that will go in that's actually necessary for these to happen. Due to
internal APIs loaddata probably will need an update to load stuff into a
specific DB.
that will go in that's actually necessary for these to happen.
4) Rig up the test harness to work with multiple databases. This includes:
flush, reset, and syncdb need to not prompt the user multiple times.
* The current strategy is to test on N dbs, where N is however many the
user defines and ensuring the data all stays seperate and no exceptions
are raised. Practically speaking this means we're only going to have
good coverage if we write a lot of tests that can break. That's life.
9) Fix transaction support. In practice this means changing all the
dictionaries that currently map thread to a boolean to being a dictionary
mapping thread to a set of the affected DBs, and changing all the functions
that use these dictionaries to handle the changes appropriately.
7) Remove any references to the global ``django.db.connection`` object in the
SQL creation process. This includes(but is probably not limited to):
* The way we create ``Query`` from ``BaseQuery`` is awkward and hacky.
* ``django.db.models.query.delete_objects``
* ``django.db.models.query.insert_query``
* ``django.db.models.base.Model`` -- in ``save_base``
* ``django.db.models.fields.Field`` This uses it, as do it's subclasses.
* ``django.db.models.fields.related`` It's used all over the place here,
including opening a cursor and executing queries, so that's going to
@ -50,13 +47,9 @@ that need to be done. I'm trying to be as granular as possible.
raw SQL and execution to ``Query``/``QuerySet`` so hopefully that makes
it in before I need to tackle this.
5) Add the ``using`` Meta option. Tests and docs(these are to be assumed at
each stage from here on out).
5) Implement using kwarg on save() method.
6) Add the ``using`` method to ``QuerySet``. This will more or less "just
work" across multiple databases that use the same backend. However, it
will fail gratuitously when trying to use 2 different backends.
8) Implement some way to create a new ``Query`` for a different backend when
we switch. There are several checks against ``self.connection`` prior to
SQL construction, so we either need to defer all these(which will be
@ -80,8 +73,4 @@ that need to be done. I'm trying to be as granular as possible.
every single one of them, then when it's time excecute the query just
pick the right ``Query`` object to use. This *does* not scale, though it
could probably be done fairly easily.
9) Fix transaction support. In practice this means changing all the
dictionaries that currently map thread to a boolean to being a dictionary
mapping thread to a set of the affected DBs, and changing all the functions
that use these dictionaries to handle the changes appropriately.
10) Time permitting add support for a ``DatabaseManager``.

View File

@ -151,7 +151,7 @@ class GenericRelation(RelatedField, Field):
def get_internal_type(self):
return "ManyToManyField"
def db_type(self):
def db_type(self, connection):
# Since we're simulating a ManyToManyField, in effect, best return the
# same db_type as well.
return None

View File

@ -1,12 +1,11 @@
from django.conf import settings
from django.db.models.fields import Field
class USStateField(Field):
def get_internal_type(self):
return "USStateField"
def db_type(self):
if settings.DATABASE_ENGINE == 'oracle':
def db_type(self, connection):
if connection.settings_dict['DATABASE_ENGINE'] == 'oracle':
return 'CHAR(2)'
else:
return 'varchar(2)'
@ -21,8 +20,8 @@ class PhoneNumberField(Field):
def get_internal_type(self):
return "PhoneNumberField"
def db_type(self):
if settings.DATABASE_ENGINE == 'oracle':
def db_type(self, connection):
if connection.settings_dict['DATABASE_ENGINE'] == 'oracle':
return 'VARCHAR2(20)'
else:
return 'varchar(20)'
@ -32,4 +31,3 @@ class PhoneNumberField(Field):
defaults = {'form_class': USPhoneNumberField}
defaults.update(kwargs)
return super(PhoneNumberField, self).formfield(**defaults)

View File

@ -2,7 +2,7 @@ import datetime
from django.contrib.sessions.models import Session
from django.contrib.sessions.backends.base import SessionBase, CreateError
from django.core.exceptions import SuspiciousOperation
from django.db import IntegrityError, transaction
from django.db import IntegrityError, transaction, DEFAULT_DB_ALIAS
from django.utils.encoding import force_unicode
class SessionStore(SessionBase):
@ -53,12 +53,13 @@ class SessionStore(SessionBase):
session_data = self.encode(self._get_session(no_load=must_create)),
expire_date = self.get_expiry_date()
)
sid = transaction.savepoint()
# TODO update for multidb
sid = transaction.savepoint(using=DEFAULT_DB_ALIAS)
try:
obj.save(force_insert=must_create)
except IntegrityError:
if must_create:
transaction.savepoint_rollback(sid)
transaction.savepoint_rollback(sid, using=DEFAULT_DB_ALIAS)
raise CreateError
raise

View File

@ -6,12 +6,12 @@ from django.db.models import signals
from django.contrib.sites.models import Site
from django.contrib.sites import models as site_app
def create_default_site(app, created_models, verbosity, **kwargs):
def create_default_site(app, created_models, verbosity, db, **kwargs):
if Site in created_models:
if verbosity >= 2:
print "Creating example.com Site object"
s = Site(domain="example.com", name="example.com")
s.save()
s.save(using=db)
Site.objects.clear_cache()
signals.post_syncdb.connect(create_default_site, sender=site_app)

View File

@ -27,9 +27,9 @@ class Command(LabelCommand):
)
table_output = []
index_output = []
qn = connections.ops.quote_name
qn = connection.ops.quote_name
for f in fields:
field_output = [qn(f.name), f.db_type()]
field_output = [qn(f.name), f.db_type(connection)]
field_output.append("%sNULL" % (not f.null and "NOT " or ""))
if f.primary_key:
field_output.append("PRIMARY KEY")

View File

@ -15,14 +15,14 @@ class Command(NoArgsCommand):
make_option('--noinput', action='store_false', dest='interactive', default=True,
help='Tells Django to NOT prompt the user for input of any kind.'),
make_option('--database', action='store', dest='database',
default='', help='Nominates a database to flush. Defaults to '
default=None, help='Nominates a database to flush. Defaults to '
'flushing all databases.'),
)
help = "Executes ``sqlflush`` on the current database."
def handle_noargs(self, **options):
if not options['database']:
dbs = connections.all()
dbs = connections
else:
dbs = [options['database']]
@ -31,57 +31,14 @@ class Command(NoArgsCommand):
self.style = no_style()
# Import the 'management' module within each installed app, to register
# dispatcher events.
for app_name in settings.INSTALLED_APPS:
try:
import_module('.management', app_name)
except ImportError:
pass
sql_list = sql_flush(self.style, connection, only_django=True)
if interactive:
confirm = raw_input("""You have requested a flush of the database.
This will IRREVERSIBLY DESTROY all data currently in the %r database,
and return each table to the state it was in after syncdb.
Are you sure you want to do this?
Type 'yes' to continue, or 'no' to cancel: """ % connection.settings_dict['DATABASE_NAME'])
else:
confirm = 'yes'
if confirm == 'yes':
try:
cursor = connection.cursor()
for sql in sql_list:
cursor.execute(sql)
except Exception, e:
transaction.rollback_unless_managed()
raise CommandError("""Database %s couldn't be flushed. Possible reasons:
* The database isn't running or isn't configured correctly.
* At least one of the expected database tables doesn't exist.
* The SQL was invalid.
Hint: Look at the output of 'django-admin.py sqlflush'. That's the SQL this command wasn't able to run.
The full error: %s""" % (connection.settings_dict.DATABASE_NAME, e))
transaction.commit_unless_managed()
# Emit the post sync signal. This allows individual
# applications to respond as if the database had been
# sync'd from scratch.
emit_post_sync_signal(models.get_models(), verbosity, interactive, connection)
# Reinstall the initial_data fixture.
call_command('loaddata', 'initial_data', **options)
for app_name in settings.INSTALLED_APPS:
try:
import_module('.management', app_name)
except ImportError:
pass
for connection in dbs:
for db in dbs:
connection = connections[db]
# Import the 'management' module within each installed app, to register
# dispatcher events.
sql_list = sql_flush(self.style, connection, only_django=True)
@ -102,22 +59,24 @@ class Command(NoArgsCommand):
for sql in sql_list:
cursor.execute(sql)
except Exception, e:
transaction.rollback_unless_managed()
transaction.rollback_unless_managed(using=db)
raise CommandError("""Database %s couldn't be flushed. Possible reasons:
* The database isn't running or isn't configured correctly.
* At least one of the expected database tables doesn't exist.
* The SQL was invalid.
Hint: Look at the output of 'django-admin.py sqlflush'. That's the SQL this command wasn't able to run.
The full error: %s""" % (connection.settings_dict.DATABASE_NAME, e))
transaction.commit_unless_managed()
transaction.commit_unless_managed(using=db)
# Emit the post sync signal. This allows individual
# applications to respond as if the database had been
# sync'd from scratch.
emit_post_sync_signal(models.get_models(), verbosity, interactive, connection)
emit_post_sync_signal(models.get_models(), verbosity, interactive, db)
# Reinstall the initial_data fixture.
call_command('loaddata', 'initial_data', **options)
kwargs = options.copy()
kwargs['database'] = db
call_command('loaddata', 'initial_data', **kwargs)
else:
print "Flush cancelled."

View File

@ -4,8 +4,13 @@ import gzip
import zipfile
from optparse import make_option
from django.conf import settings
from django.core import serializers
from django.core.management.base import BaseCommand
from django.core.management.color import no_style
from django.db import connections, transaction, DEFAULT_DB_ALIAS
from django.db.models import get_apps
try:
set
@ -22,12 +27,15 @@ class Command(BaseCommand):
help = 'Installs the named fixture(s) in the database.'
args = "fixture [fixture ...]"
def handle(self, *fixture_labels, **options):
from django.db.models import get_apps
from django.core import serializers
from django.db import connection, transaction
from django.conf import settings
option_list = BaseCommand.option_list + (
make_option('--database', action='store', dest='database',
default=DEFAULT_DB_ALIAS, help='Nominates a specific database to load '
'fixtures into. By default uses the "default" database.'),
)
def handle(self, *fixture_labels, **options):
using = options['database']
connection = connections[using]
self.style = no_style()
verbosity = int(options.get('verbosity', 1))
@ -56,9 +64,9 @@ class Command(BaseCommand):
# Start transaction management. All fixtures are installed in a
# single transaction to ensure that all references are resolved.
if commit:
transaction.commit_unless_managed()
transaction.enter_transaction_management()
transaction.managed(True)
transaction.commit_unless_managed(using=using)
transaction.enter_transaction_management(using=using)
transaction.managed(True, using=using)
class SingleZipReader(zipfile.ZipFile):
def __init__(self, *args, **kwargs):
@ -103,8 +111,8 @@ class Command(BaseCommand):
sys.stderr.write(
self.style.ERROR("Problem installing fixture '%s': %s is not a known serialization format." %
(fixture_name, format)))
transaction.rollback()
transaction.leave_transaction_management()
transaction.rollback(using=using)
transaction.leave_transaction_management(using=using)
return
if os.path.isabs(fixture_name):
@ -136,8 +144,8 @@ class Command(BaseCommand):
fixture.close()
print self.style.ERROR("Multiple fixtures named '%s' in %s. Aborting." %
(fixture_name, humanize(fixture_dir)))
transaction.rollback()
transaction.leave_transaction_management()
transaction.rollback(using=using)
transaction.leave_transaction_management(using=using)
return
else:
fixture_count += 1
@ -150,7 +158,7 @@ class Command(BaseCommand):
for obj in objects:
objects_in_fixture += 1
models.add(obj.object.__class__)
obj.save()
obj.save(using=using)
object_count += objects_in_fixture
label_found = True
except (SystemExit, KeyboardInterrupt):
@ -158,8 +166,8 @@ class Command(BaseCommand):
except Exception:
import traceback
fixture.close()
transaction.rollback()
transaction.leave_transaction_management()
transaction.rollback(using=using)
transaction.leave_transaction_management(using=using)
if show_traceback:
traceback.print_exc()
else:
@ -176,8 +184,8 @@ class Command(BaseCommand):
sys.stderr.write(
self.style.ERROR("No fixture data found for '%s'. (File format may be invalid.)" %
(fixture_name)))
transaction.rollback()
transaction.leave_transaction_management()
transaction.rollback(using=using)
transaction.leave_transaction_management(using=using)
return
except Exception, e:
@ -196,8 +204,8 @@ class Command(BaseCommand):
cursor.execute(line)
if commit:
transaction.commit()
transaction.leave_transaction_management()
transaction.commit(using=using)
transaction.leave_transaction_management(using=using)
if object_count == 0:
if verbosity > 1:

View File

@ -32,9 +32,9 @@ class Command(NoArgsCommand):
self.style = no_style()
if not options['database']:
dbs = connections.all()
dbs = connections
else:
dbs = [connections[options['database']]]
dbs = [options['database']]
# Import the 'management' module within each installed app, to register
# dispatcher events.
@ -55,7 +55,8 @@ class Command(NoArgsCommand):
if not msg.startswith('No module named') or 'management' not in msg:
raise
for connection in dbs:
for db in dbs:
connection = connections[db]
cursor = connection.cursor()
# Get a list of already installed *models* so that references work right.
@ -102,11 +103,11 @@ class Command(NoArgsCommand):
for statement in sql:
cursor.execute(statement)
transaction.commit_unless_managed()
transaction.commit_unless_managed(using=db)
# Send the post_syncdb signal, so individual apps can do whatever they need
# to do at this point.
emit_post_sync_signal(created_models, verbosity, interactive, connection)
emit_post_sync_signal(created_models, verbosity, interactive, db)
# The connection may have been closed by a syncdb handler.
cursor = connection.cursor()
@ -130,9 +131,9 @@ class Command(NoArgsCommand):
if show_traceback:
import traceback
traceback.print_exc()
transaction.rollback_unless_managed()
transaction.rollback_unless_managed(using=db)
else:
transaction.commit_unless_managed()
transaction.commit_unless_managed(using=db)
else:
if verbosity >= 2:
print "No custom SQL for %s.%s model" % (app_name, model._meta.object_name)
@ -151,13 +152,9 @@ class Command(NoArgsCommand):
except Exception, e:
sys.stderr.write("Failed to install index for %s.%s model: %s\n" % \
(app_name, model._meta.object_name, e))
transaction.rollback_unless_managed()
transaction.rollback_unless_managed(using=db)
else:
transaction.commit_unless_managed()
transaction.commit_unless_managed(using=db)
# Install the 'initial_data' fixture, using format discovery
# FIXME we only load the fixture data for one DB right now, since we
# can't control what DB it does into, once we can control this we
# should move it back into the DB loop
from django.core.management import call_command
call_command('loaddata', 'initial_data', verbosity=verbosity)
call_command('loaddata', 'initial_data', verbosity=verbosity, database=db)

View File

@ -188,7 +188,7 @@ def custom_sql_for_model(model, style, connection):
return output
def emit_post_sync_signal(created_models, verbosity, interactive, connection):
def emit_post_sync_signal(created_models, verbosity, interactive, db):
# Emit the post_sync signal for every application.
for app in models.get_apps():
app_name = app.__name__.split('.')[-2]
@ -196,4 +196,4 @@ def emit_post_sync_signal(created_models, verbosity, interactive, connection):
print "Running post-sync handlers for application", app_name
models.signals.post_syncdb.send(sender=app, app=app,
created_models=created_models, verbosity=verbosity,
interactive=interactive, connection=connection)
interactive=interactive, db=db)

View File

@ -154,13 +154,13 @@ class DeserializedObject(object):
def __repr__(self):
return "<DeserializedObject: %s>" % smart_str(self.object)
def save(self, save_m2m=True):
def save(self, save_m2m=True, using=None):
# Call save on the Model baseclass directly. This bypasses any
# model-defined save. The save is also forced to be raw.
# This ensures that the data that is deserialized is literally
# what came from the file, not post-processed by pre_save/save
# methods.
models.Model.save_base(self.object, raw=True)
models.Model.save_base(self.object, using=using, raw=True)
if self.m2m_data and save_m2m:
for accessor_name, object_list in self.m2m_data.items():
setattr(self.object, accessor_name, object_list)

View File

@ -54,12 +54,13 @@ def reset_queries(**kwargs):
connection.queries = []
signals.request_started.connect(reset_queries)
# Register an event that rolls back the connection
# Register an event that rolls back the connections
# when a Django request has an exception.
def _rollback_on_exception(**kwargs):
from django.db import transaction
for conn in connections:
try:
transaction.rollback_unless_managed()
transaction.rollback_unless_managed(using=conn)
except DatabaseError:
pass
signals.got_request_exception.connect(_rollback_on_exception)

View File

@ -90,7 +90,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.features = DatabaseFeatures()
self.ops = DatabaseOperations()
self.ops = DatabaseOperations(self)
self.client = DatabaseClient(self)
self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self)

View File

@ -6,14 +6,14 @@ from django.db.backends import BaseDatabaseOperations
# used by both the 'postgresql' and 'postgresql_psycopg2' backends.
class DatabaseOperations(BaseDatabaseOperations):
def __init__(self):
def __init__(self, connection):
self._postgres_version = None
self.connection = connection
def _get_postgres_version(self):
if self._postgres_version is None:
from django.db import connection
from django.db.backends.postgresql.version import get_version
cursor = connection.cursor()
cursor = self.connection.cursor()
self._postgres_version = get_version(cursor)
return self._postgres_version
postgres_version = property(_get_postgres_version)

View File

@ -66,7 +66,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
autocommit = self.settings_dict["DATABASE_OPTIONS"].get('autocommit', False)
self.features.uses_autocommit = autocommit
self._set_isolation_level(int(not autocommit))
self.ops = DatabaseOperations()
self.ops = DatabaseOperations(self)
self.client = DatabaseClient(self)
self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self)

View File

@ -15,7 +15,7 @@ from django.db.models.fields.related import OneToOneRel, ManyToOneRel, OneToOneF
from django.db.models.query import delete_objects, Q
from django.db.models.query_utils import CollectedObjects, DeferredAttribute
from django.db.models.options import Options
from django.db import connection, transaction, DatabaseError
from django.db import connections, transaction, DatabaseError, DEFAULT_DB_ALIAS
from django.db.models import signals
from django.db.models.loading import register_models, get_model
from django.utils.functional import curry
@ -395,7 +395,7 @@ class Model(object):
return getattr(self, field_name)
return getattr(self, field.attname)
def save(self, force_insert=False, force_update=False):
def save(self, force_insert=False, force_update=False, using=None):
"""
Saves the current instance. Override this in a subclass if you want to
control the saving process.
@ -407,18 +407,21 @@ class Model(object):
if force_insert and force_update:
raise ValueError("Cannot force both insert and updating in "
"model saving.")
self.save_base(force_insert=force_insert, force_update=force_update)
self.save_base(using=using, force_insert=force_insert, force_update=force_update)
save.alters_data = True
def save_base(self, raw=False, cls=None, force_insert=False,
force_update=False):
force_update=False, using=None):
"""
Does the heavy-lifting involved in saving. Subclasses shouldn't need to
override this method. It's separate from save() in order to hide the
need for overrides of save() to pass around internal-only parameters
('raw' and 'cls').
"""
if using is None:
using = DEFAULT_DB_ALIAS
connection = connections[using]
assert not (force_insert and force_update)
if not cls:
cls = self.__class__
@ -441,7 +444,7 @@ class Model(object):
if field and getattr(self, parent._meta.pk.attname) is None and getattr(self, field.attname) is not None:
setattr(self, parent._meta.pk.attname, getattr(self, field.attname))
self.save_base(cls=parent)
self.save_base(cls=parent, using=using)
if field:
setattr(self, field.attname, self._get_pk_val(parent._meta))
if meta.proxy:
@ -457,12 +460,13 @@ class Model(object):
manager = cls._base_manager
if pk_set:
# Determine whether a record with the primary key already exists.
# FIXME work with the using parameter
if (force_update or (not force_insert and
manager.filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by())):
manager.using(using).filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by())):
# It does already exist, so do an UPDATE.
if force_update or non_pks:
values = [(f, None, (raw and getattr(self, f.attname) or f.pre_save(self, False))) for f in non_pks]
rows = manager.filter(pk=pk_val)._update(values)
rows = manager.using(using).filter(pk=pk_val)._update(values)
if force_update and not rows:
raise DatabaseError("Forced update did not affect any rows.")
else:
@ -477,20 +481,20 @@ class Model(object):
if meta.order_with_respect_to:
field = meta.order_with_respect_to
values.append((meta.get_field_by_name('_order')[0], manager.filter(**{field.name: getattr(self, field.attname)}).count()))
values.append((meta.get_field_by_name('_order')[0], manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count()))
record_exists = False
update_pk = bool(meta.has_auto_field and not pk_set)
if values:
# Create a new record.
result = manager._insert(values, return_id=update_pk)
result = manager._insert(values, return_id=update_pk, using=using)
else:
# Create a new record with defaults for everything.
result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True)
result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True, using=using)
if update_pk:
setattr(self, meta.pk.attname, result)
transaction.commit_unless_managed()
transaction.commit_unless_managed(using=using)
if signal:
signals.post_save.send(sender=self.__class__, instance=self,
@ -549,7 +553,10 @@ class Model(object):
# delete it and all its descendents.
parent_obj._collect_sub_objects(seen_objs)
def delete(self):
def delete(self, using=None):
if using is None:
using = DEFAULT_DB_ALIAS
connection = connections[using]
assert self._get_pk_val() is not None, "%s object can't be deleted because its %s attribute is set to None." % (self._meta.object_name, self._meta.pk.attname)
# Find all the objects than need to be deleted.
@ -557,7 +564,7 @@ class Model(object):
self._collect_sub_objects(seen_objs)
# Actually delete the objects.
delete_objects(seen_objs)
delete_objects(seen_objs, using)
delete.alters_data = True
@ -610,7 +617,8 @@ def method_set_order(ordered_obj, self, id_list):
# for situations like this.
for i, j in enumerate(id_list):
ordered_obj.objects.filter(**{'pk': j, order_name: rel_val}).update(_order=i)
transaction.commit_unless_managed()
# TODO
transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS)
def method_get_order(ordered_obj, self):

View File

@ -41,8 +41,8 @@ class ExpressionNode(tree.Node):
def prepare(self, evaluator, query, allow_joins):
return evaluator.prepare_node(self, query, allow_joins)
def evaluate(self, evaluator, qn):
return evaluator.evaluate_node(self, qn)
def evaluate(self, evaluator, qn, connection):
return evaluator.evaluate_node(self, qn, connection)
#############
# OPERATORS #
@ -109,5 +109,5 @@ class F(ExpressionNode):
def prepare(self, evaluator, query, allow_joins):
return evaluator.prepare_leaf(self, query, allow_joins)
def evaluate(self, evaluator, qn):
return evaluator.evaluate_leaf(self, qn)
def evaluate(self, evaluator, qn, connection):
return evaluator.evaluate_leaf(self, qn, connection)

View File

@ -233,6 +233,24 @@ class Field(object):
raise TypeError("Field has invalid lookup: %s" % lookup_type)
def validate(self, lookup_type, value):
"""
Validate that the data is valid, as much so as possible without knowing
what connection we are using. Returns True if the value was
successfully validated and false if the value wasn't validated (this
doesn't consider whether the value was actually valid, an exception is
raised in those circumstances).
"""
if hasattr(value, 'validate') or hasattr(value, '_validate'):
if hasattr(value, 'validate'):
value.validate()
else:
value._validate()
return True
if lookup_type == 'isnull':
return True
return False
def has_default(self):
"Returns a boolean of whether this field has a default value."
return self.default is not NOT_PROVIDED
@ -360,6 +378,17 @@ class AutoField(Field):
return None
return int(value)
def validate(self, lookup_type, value):
if super(AutoField, self).validate(lookup_type, value):
return
if value is None or hasattr(value, 'as_sql'):
return
if lookup_type in ('range', 'in'):
for val in value:
int(val)
else:
int(value)
def contribute_to_class(self, cls, name):
assert not cls._meta.has_auto_field, "A model can't have more than one AutoField."
super(AutoField, self).contribute_to_class(cls, name)
@ -396,6 +425,13 @@ class BooleanField(Field):
value = bool(int(value))
return super(BooleanField, self).get_db_prep_lookup(lookup_type, value)
def validate(self, lookup_type, value):
if super(BooleanField, self).validate(lookup_type, value):
return
if value in ('1', '0'):
value = int(value)
bool(value)
def get_db_prep_value(self, value):
if value is None:
return None
@ -510,6 +546,20 @@ class DateField(Field):
# Casts dates into the format expected by the backend
return connection.ops.value_to_db_date(self.to_python(value))
def validate(self, lookup_type, value):
if super(DateField, self).validate(lookup_type, value):
return
if value is None:
return
if lookup_type in ('month', 'day', 'year', 'week_day'):
int(value)
return
if lookup_type in ('in', 'range'):
for val in value:
self.to_python(val)
return
self.to_python(value)
def value_to_string(self, obj):
val = self._get_val_from_obj(obj)
if val is None:
@ -754,6 +804,11 @@ class NullBooleanField(Field):
value = bool(int(value))
return super(NullBooleanField, self).get_db_prep_lookup(lookup_type, value)
def validate(self, lookup_type, value):
if value in ('1', '0'):
value = int(value)
bool(value)
def get_db_prep_value(self, value):
if value is None:
return None

View File

@ -1,4 +1,4 @@
from django.db import connection, transaction
from django.db import connection, transaction, DEFAULT_DB_ALIAS
from django.db.backends import util
from django.db.models import signals, get_model
from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist
@ -478,7 +478,9 @@ def create_many_related_manager(superclass, through=False):
cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
(self.join_table, source_col_name, target_col_name),
[self._pk_val, obj_id])
transaction.commit_unless_managed()
# FIXME, once this isn't in related.py it should conditionally
# use the right DB.
transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS)
def _remove_items(self, source_col_name, target_col_name, *objs):
# source_col_name: the PK colname in join_table for the source object
@ -508,7 +510,8 @@ def create_many_related_manager(superclass, through=False):
cursor.execute("DELETE FROM %s WHERE %s = %%s" % \
(self.join_table, source_col_name),
[self._pk_val])
transaction.commit_unless_managed()
# TODO
transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS)
return ManyRelatedManager

View File

@ -173,6 +173,9 @@ class Manager(object):
def only(self, *args, **kwargs):
return self.get_query_set().only(*args, **kwargs)
def using(self, *args, **kwargs):
return self.get_query_set().using(*args, **kwargs)
def _insert(self, values, **kwargs):
return insert_query(self.model, values, **kwargs)

View File

@ -7,7 +7,7 @@ try:
except NameError:
from sets import Set as set # Python 2.3 fallback
from django.db import connection, transaction, IntegrityError
from django.db import connections, transaction, IntegrityError, DEFAULT_DB_ALIAS
from django.db.models.aggregates import Aggregate
from django.db.models.fields import DateField
from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory
@ -31,10 +31,13 @@ class QuerySet(object):
"""
def __init__(self, model=None, query=None):
self.model = model
connection = connections[DEFAULT_DB_ALIAS]
self.query = query or sql.Query(self.model, connection)
self._result_cache = None
self._iter = None
self._sticky_filter = False
self._using = DEFAULT_DB_ALIAS # this will be wrong if a custom Query
# is provided with a non default connection
########################
# PYTHON MAGIC METHODS #
@ -300,12 +303,12 @@ class QuerySet(object):
params = dict([(k, v) for k, v in kwargs.items() if '__' not in k])
params.update(defaults)
obj = self.model(**params)
sid = transaction.savepoint()
obj.save(force_insert=True)
transaction.savepoint_commit(sid)
sid = transaction.savepoint(using=self._using)
obj.save(force_insert=True, using=self._using)
transaction.savepoint_commit(sid, using=self._using)
return obj, True
except IntegrityError, e:
transaction.savepoint_rollback(sid)
transaction.savepoint_rollback(sid, using=self._using)
try:
return self.get(**kwargs), False
except self.model.DoesNotExist:
@ -364,7 +367,7 @@ class QuerySet(object):
if not seen_objs:
break
delete_objects(seen_objs)
delete_objects(seen_objs, del_query._using)
# Clear the result cache, in case this QuerySet gets reused.
self._result_cache = None
@ -379,20 +382,20 @@ class QuerySet(object):
"Cannot update a query once a slice has been taken."
query = self.query.clone(sql.UpdateQuery)
query.add_update_values(kwargs)
if not transaction.is_managed():
transaction.enter_transaction_management()
if not transaction.is_managed(using=self._using):
transaction.enter_transaction_management(using=self._using)
forced_managed = True
else:
forced_managed = False
try:
rows = query.execute_sql(None)
if forced_managed:
transaction.commit()
transaction.commit(using=self._using)
else:
transaction.commit_unless_managed()
transaction.commit_unless_managed(using=self._using)
finally:
if forced_managed:
transaction.leave_transaction_management()
transaction.leave_transaction_management(using=self._using)
self._result_cache = None
return rows
update.alters_data = True
@ -616,6 +619,16 @@ class QuerySet(object):
clone.query.add_immediate_loading(fields)
return clone
def using(self, alias):
"""
Selects which database this QuerySet should excecute it's query against.
"""
clone = self._clone()
clone._using = alias
connection = connections[alias]
clone.query.set_connection(connection)
return clone
###################################
# PUBLIC INTROSPECTION ATTRIBUTES #
###################################
@ -644,6 +657,7 @@ class QuerySet(object):
if self._sticky_filter:
query.filter_is_sticky = True
c = klass(model=self.model, query=query)
c._using = self._using
c.__dict__.update(kwargs)
if setup and hasattr(c, '_setup_query'):
c._setup_query()
@ -700,6 +714,13 @@ class QuerySet(object):
obj = self.values("pk")
return obj.query.as_nested_sql()
def _validate(self):
"""
A normal QuerySet is always valid when used as the RHS of a filter,
since it automatically gets filtered down to 1 field.
"""
pass
# When used as part of a nested query, a queryset will never be an "always
# empty" result.
value_annotation = True
@ -818,6 +839,17 @@ class ValuesQuerySet(QuerySet):
% self.__class__.__name__)
return self._clone().query.as_nested_sql()
def _validate(self):
"""
Validates that we aren't trying to do a query like
value__in=qs.values('value1', 'value2'), which isn't valid.
"""
if ((self._fields and len(self._fields) > 1) or
(not self._fields and len(self.model._meta.fields) > 1)):
raise TypeError('Cannot use a multi-field %s as a filter value.'
% self.__class__.__name__)
class ValuesListQuerySet(ValuesQuerySet):
def iterator(self):
if self.flat and len(self._fields) == 1:
@ -970,13 +1002,14 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
setattr(obj, f.get_cache_name(), rel_obj)
return obj, index_end
def delete_objects(seen_objs):
def delete_objects(seen_objs, using):
"""
Iterate through a list of seen classes, and remove any instances that are
referred to.
"""
if not transaction.is_managed():
transaction.enter_transaction_management()
connection = connections[using]
if not transaction.is_managed(using=using):
transaction.enter_transaction_management(using=using)
forced_managed = True
else:
forced_managed = False
@ -1036,20 +1069,21 @@ def delete_objects(seen_objs):
setattr(instance, cls._meta.pk.attname, None)
if forced_managed:
transaction.commit()
transaction.commit(using=using)
else:
transaction.commit_unless_managed()
transaction.commit_unless_managed(using=using)
finally:
if forced_managed:
transaction.leave_transaction_management()
transaction.leave_transaction_management(using=using)
def insert_query(model, values, return_id=False, raw_values=False):
def insert_query(model, values, return_id=False, raw_values=False, using=None):
"""
Inserts a new record for the given model. This provides an interface to
the InsertQuery class and is how Model.save() is implemented. It is not
part of the public API.
"""
connection = connections[using]
query = sql.InsertQuery(model, connection)
query.insert_values(values, raw_values)
return query.execute_sql(return_id)

View File

@ -3,25 +3,23 @@ from django.db.models.fields import FieldDoesNotExist
from django.db.models.sql.constants import LOOKUP_SEP
class SQLEvaluator(object):
as_sql_takes_connection = True
def __init__(self, expression, query, allow_joins=True):
self.expression = expression
self.opts = query.get_meta()
self.cols = {}
self.contains_aggregate = False
self.connection = query.connection
self.expression.prepare(self, query, allow_joins)
def as_sql(self, qn=None):
return self.expression.evaluate(self, qn)
def as_sql(self, qn, connection):
return self.expression.evaluate(self, qn, connection)
def relabel_aliases(self, change_map):
for node, col in self.cols.items():
self.cols[node] = (change_map.get(col[0], col[0]), col[1])
def update_connection(self, connection):
self.connection = connection
#####################################################
# Vistor methods for initial expression preparation #
#####################################################
@ -57,15 +55,12 @@ class SQLEvaluator(object):
# Vistor methods for final expression evaluation #
##################################################
def evaluate_node(self, node, qn):
if not qn:
qn = self.connection.ops.quote_name
def evaluate_node(self, node, qn, connection):
expressions = []
expression_params = []
for child in node.children:
if hasattr(child, 'evaluate'):
sql, params = child.evaluate(self, qn)
sql, params = child.evaluate(self, qn, connection)
else:
sql, params = '%s', (child,)
@ -78,12 +73,9 @@ class SQLEvaluator(object):
expressions.append(format % sql)
expression_params.extend(params)
return self.connection.ops.combine_expression(node.connector, expressions), expression_params
def evaluate_leaf(self, node, qn):
if not qn:
qn = self.connection.ops.quote_name
return connection.ops.combine_expression(node.connector, expressions), expression_params
def evaluate_leaf(self, node, qn, connection):
col = self.cols[node]
if hasattr(col, 'as_sql'):
return col.as_sql(qn), ()

View File

@ -67,10 +67,10 @@ class BaseQuery(object):
# SQL-related attributes
self.select = []
self.tables = [] # Aliases in the order they are created.
self.where = where(connection=self.connection)
self.where = where()
self.where_class = where
self.group_by = None
self.having = where(connection=self.connection)
self.having = where()
self.order_by = []
self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.distinct = False
@ -151,8 +151,6 @@ class BaseQuery(object):
# supported. It's the only class-reference to the module-level
# connection variable.
self.connection = connection
self.where.update_connection(self.connection)
self.having.update_connection(self.connection)
def get_meta(self):
"""
@ -245,8 +243,6 @@ class BaseQuery(object):
obj.used_aliases = set()
obj.filter_is_sticky = False
obj.__dict__.update(kwargs)
obj.where.update_connection(obj.connection) # where and having track their own connection
obj.having.update_connection(obj.connection)# we need to keep this up to date
if hasattr(obj, '_setup_query'):
obj._setup_query()
return obj
@ -405,8 +401,8 @@ class BaseQuery(object):
from_, f_params = self.get_from_clause()
qn = self.quote_name_unless_alias
where, w_params = self.where.as_sql(qn=qn)
having, h_params = self.having.as_sql(qn=qn)
where, w_params = self.where.as_sql(qn=qn, connection=self.connection)
having, h_params = self.having.as_sql(qn=qn, connection=self.connection)
params = []
for val in self.extra_select.itervalues():
params.extend(val[1])
@ -534,10 +530,10 @@ class BaseQuery(object):
self.where.add(EverythingNode(), AND)
elif self.where:
# rhs has an empty where clause.
w = self.where_class(connection=self.connection)
w = self.where_class()
w.add(EverythingNode(), AND)
else:
w = self.where_class(connection=self.connection)
w = self.where_class()
self.where.add(w, connector)
# Selection columns and extra extensions are those provided by 'rhs'.
@ -1550,7 +1546,7 @@ class BaseQuery(object):
for alias, aggregate in self.aggregates.items():
if alias == parts[0]:
entry = self.where_class(connection=self.connection)
entry = self.where_class()
entry.add((aggregate, lookup_type, value), AND)
if negate:
entry.negate()
@ -1618,7 +1614,7 @@ class BaseQuery(object):
for alias in join_list:
if self.alias_map[alias][JOIN_TYPE] == self.LOUTER:
j_col = self.alias_map[alias][RHS_JOIN_COL]
entry = self.where_class(connection=self.connection)
entry = self.where_class()
entry.add((Constraint(alias, j_col, None), 'isnull', True), AND)
entry.negate()
self.where.add(entry, AND)
@ -1627,7 +1623,7 @@ class BaseQuery(object):
# Leaky abstraction artifact: We have to specifically
# exclude the "foo__in=[]" case from this handling, because
# it's short-circuited in the Where class.
entry = self.where_class(connection=self.connection)
entry = self.where_class()
entry.add((Constraint(alias, col, None), 'isnull', True), AND)
entry.negate()
self.where.add(entry, AND)
@ -2337,6 +2333,9 @@ class BaseQuery(object):
self.select = [(select_alias, select_col)]
self.remove_inherited_models()
def set_connection(self, connection):
self.connection = connection
def execute_sql(self, result_type=MULTI):
"""
Run the query against the database and returns the result(s). The

View File

@ -24,8 +24,9 @@ class DeleteQuery(Query):
"""
assert len(self.tables) == 1, \
"Can only delete from one table at a time."
result = ['DELETE FROM %s' % self.quote_name_unless_alias(self.tables[0])]
where, params = self.where.as_sql()
qn = self.quote_name_unless_alias
result = ['DELETE FROM %s' % qn(self.tables[0])]
where, params = self.where.as_sql(qn=qn, connection=self.connection)
result.append('WHERE %s' % where)
return ' '.join(result), tuple(params)
@ -48,7 +49,7 @@ class DeleteQuery(Query):
for related in cls._meta.get_all_related_many_to_many_objects():
if not isinstance(related.field, generic.GenericRelation):
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = self.where_class(connection=self.connection)
where = self.where_class()
where.add((Constraint(None,
related.field.m2m_reverse_name(), related.field),
'in',
@ -57,14 +58,14 @@ class DeleteQuery(Query):
self.do_query(related.field.m2m_db_table(), where)
for f in cls._meta.many_to_many:
w1 = self.where_class(connection=self.connection)
w1 = self.where_class()
if isinstance(f, generic.GenericRelation):
from django.contrib.contenttypes.models import ContentType
field = f.rel.to._meta.get_field(f.content_type_field_name)
w1.add((Constraint(None, field.column, field), 'exact',
ContentType.objects.get_for_model(cls).id), AND)
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = self.where_class(connection=self.connection)
where = self.where_class()
where.add((Constraint(None, f.m2m_column_name(), f), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
AND)
@ -81,7 +82,7 @@ class DeleteQuery(Query):
lot of values in pk_list.
"""
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
where = self.where_class(connection=self.connection)
where = self.where_class()
field = self.model._meta.pk
where.add((Constraint(None, field.column, field), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND)
@ -143,6 +144,9 @@ class UpdateQuery(Query):
values, update_params = [], []
for name, val, placeholder in self.values:
if hasattr(val, 'as_sql'):
if getattr(val, 'as_sql_takes_connection', False):
sql, params = val.as_sql(qn, self.connection)
else:
sql, params = val.as_sql(qn)
values.append('%s = %s' % (qn(name), sql))
update_params.extend(params)
@ -152,7 +156,7 @@ class UpdateQuery(Query):
else:
values.append('%s = NULL' % qn(name))
result.append(', '.join(values))
where, params = self.where.as_sql()
where, params = self.where.as_sql(qn=qn, connection=self.connection)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(update_params + params)
@ -185,7 +189,7 @@ class UpdateQuery(Query):
# Now we adjust the current query: reset the where clause and get rid
# of all the tables we don't need (since they're in the sub-select).
self.where = self.where_class(connection=self.connection)
self.where = self.where_class()
if self.related_updates or must_pre_select:
# Either we're using the idents in multiple update queries (so
# don't want them to change), or the db backend doesn't support
@ -209,7 +213,7 @@ class UpdateQuery(Query):
This is used by the QuerySet.delete_objects() method.
"""
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.where = self.where_class(connection=self.connection)
self.where = self.where_class()
f = self.model._meta.pk
self.where.add((Constraint(None, f.column, f), 'in',
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),

View File

@ -32,18 +32,7 @@ class WhereNode(tree.Node):
relabel_aliases() methods.
"""
default = AND
def __init__(self, *args, **kwargs):
self.connection = kwargs.pop('connection', None)
super(WhereNode, self).__init__(*args, **kwargs)
def __getstate__(self):
"""
Don't try to pickle the connection, our Query will restore it for us.
"""
data = self.__dict__.copy()
del data['connection']
return data
as_sql_takes_connection = True
def add(self, data, connector):
"""
@ -62,20 +51,6 @@ class WhereNode(tree.Node):
# Consume any generators immediately, so that we can determine
# emptiness and transform any non-empty values correctly.
value = list(value)
if hasattr(obj, "process"):
try:
# FIXME We're calling process too early, the connection could
# change
obj, params = obj.process(lookup_type, value, self.connection)
except (EmptyShortCircuit, EmptyResultSet):
# There are situations where we want to short-circuit any
# comparisons and make sure that nothing is returned. One
# example is when checking for a NULL pk value, or the
# equivalent.
super(WhereNode, self).add(NothingNode(), connector)
return
else:
params = Field().get_db_prep_lookup(lookup_type, value)
# The "annotation" parameter is used to pass auxilliary information
# about the value(s) to the query construction. Specifically, datetime
@ -88,18 +63,19 @@ class WhereNode(tree.Node):
else:
annotation = bool(value)
if hasattr(obj, "process"):
obj.validate(lookup_type, value)
super(WhereNode, self).add((obj, lookup_type, annotation, value),
connector)
return
else:
# TODO: Make this lazy just like the above code for constraints.
params = Field().get_db_prep_lookup(lookup_type, value)
super(WhereNode, self).add((obj, lookup_type, annotation, params),
connector)
def update_connection(self, connection):
self.connection = connection
for child in self.children:
if hasattr(child, 'update_connection'):
child.update_connection(connection)
elif hasattr(child[3], 'update_connection'):
child[3].update_connection(connection)
def as_sql(self, qn=None):
def as_sql(self, qn, connection):
"""
Returns the SQL version of the where clause and the value to be
substituted in. Returns None, None if this node is empty.
@ -108,8 +84,6 @@ class WhereNode(tree.Node):
(generally not needed except by the internal implementation for
recursion).
"""
if not qn:
qn = self.connection.ops.quote_name
if not self.children:
return None, []
result = []
@ -118,10 +92,13 @@ class WhereNode(tree.Node):
for child in self.children:
try:
if hasattr(child, 'as_sql'):
if getattr(child, 'as_sql_takes_connection', False):
sql, params = child.as_sql(qn=qn, connection=connection)
else:
sql, params = child.as_sql(qn=qn)
else:
# A leaf node in the tree.
sql, params = self.make_atom(child, qn)
sql, params = self.make_atom(child, qn, connection)
except EmptyResultSet:
if self.connector == AND and not self.negated:
@ -157,7 +134,7 @@ class WhereNode(tree.Node):
sql_string = '(%s)' % sql_string
return sql_string, result_params
def make_atom(self, child, qn):
def make_atom(self, child, qn, connection):
"""
Turn a tuple (table_alias, column_name, db_type, lookup_type,
value_annot, params) into valid SQL.
@ -165,29 +142,39 @@ class WhereNode(tree.Node):
Returns the string for the SQL fragment and the parameters to use for
it.
"""
lvalue, lookup_type, value_annot, params = child
lvalue, lookup_type, value_annot, params_or_value = child
if hasattr(lvalue, 'process'):
try:
lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
except EmptyShortCircuit:
raise EmptyResultSet
else:
params = params_or_value
if isinstance(lvalue, tuple):
# A direct database column lookup.
field_sql = self.sql_for_columns(lvalue, qn)
field_sql = self.sql_for_columns(lvalue, qn, connection)
else:
# A smart object with an as_sql() method.
field_sql = lvalue.as_sql(quote_func=qn)
if value_annot is datetime.datetime:
cast_sql = self.connection.ops.datetime_cast_sql()
cast_sql = connection.ops.datetime_cast_sql()
else:
cast_sql = '%s'
if hasattr(params, 'as_sql'):
if getattr(params, 'as_sql_takes_connection', False):
extra, params = params.as_sql(qn, connection)
else:
extra, params = params.as_sql(qn)
cast_sql = ''
else:
extra = ''
if lookup_type in self.connection.operators:
format = "%s %%s %%s" % (self.connection.ops.lookup_cast(lookup_type),)
if lookup_type in connection.operators:
format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
return (format % (field_sql,
self.connection.operators[lookup_type] % cast_sql,
connection.operators[lookup_type] % cast_sql,
extra), params)
if lookup_type == 'in':
@ -200,19 +187,19 @@ class WhereNode(tree.Node):
elif lookup_type in ('range', 'year'):
return ('%s BETWEEN %%s and %%s' % field_sql, params)
elif lookup_type in ('month', 'day', 'week_day'):
return ('%s = %%s' % self.connection.ops.date_extract_sql(lookup_type, field_sql),
return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql),
params)
elif lookup_type == 'isnull':
return ('%s IS %sNULL' % (field_sql,
(not value_annot and 'NOT ' or '')), ())
elif lookup_type == 'search':
return (self.connection.ops.fulltext_search_sql(field_sql), params)
return (connection.ops.fulltext_search_sql(field_sql), params)
elif lookup_type in ('regex', 'iregex'):
return self.connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params
return connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params
raise TypeError('Invalid lookup_type: %r' % lookup_type)
def sql_for_columns(self, data, qn):
def sql_for_columns(self, data, qn, connection):
"""
Returns the SQL fragment used for the left-hand side of a column
constraint (for example, the "T1.foo" portion in the clause
@ -223,7 +210,7 @@ class WhereNode(tree.Node):
lhs = '%s.%s' % (qn(table_alias), qn(name))
else:
lhs = qn(name)
return self.connection.ops.field_cast_sql(db_type) % lhs
return connection.ops.field_cast_sql(db_type) % lhs
def relabel_aliases(self, change_map, node=None):
"""
@ -299,3 +286,11 @@ class Constraint(object):
raise EmptyShortCircuit
return (self.alias, self.col, db_type), params
def relabel_aliases(self, change_map):
if self.alias in change_map:
self.alias = change_map[self.alias]
def validate(self, lookup_type, value):
if hasattr(self.field, 'validate'):
self.field.validate(lookup_type, value)

View File

@ -20,7 +20,7 @@ try:
from functools import wraps
except ImportError:
from django.utils.functional import wraps # Python 2.3, 2.4 fallback.
from django.db import connection
from django.db import connections
from django.conf import settings
class TransactionManagementError(Exception):
@ -30,17 +30,20 @@ class TransactionManagementError(Exception):
"""
pass
# The states are dictionaries of lists. The key to the dict is the current
# thread and the list is handled as a stack of values.
# The states are dictionaries of dictionaries of lists. The key to the outer
# dict is the current thread, and the key to the inner dictionary is the
# connection alias and the list is handled as a stack of values.
state = {}
savepoint_state = {}
# The dirty flag is set by *_unless_managed functions to denote that the
# code under transaction management has changed things to require a
# database commit.
# This is a dictionary mapping thread to a dictionary mapping connection
# alias to a boolean.
dirty = {}
def enter_transaction_management(managed=True):
def enter_transaction_management(managed=True, using=None):
"""
Enters transaction management for a running thread. It must be balanced with
the appropriate leave_transaction_management call, since the actual state is
@ -50,166 +53,212 @@ def enter_transaction_management(managed=True):
from the settings, if there is no surrounding block (dirty is always false
when no current block is running).
"""
if using is None:
raise ValueError # TODO use default
connection = connections[using]
thread_ident = thread.get_ident()
if thread_ident in state and state[thread_ident]:
state[thread_ident].append(state[thread_ident][-1])
if thread_ident in state and state[thread_ident].get(using):
state[thread_ident][using].append(state[thread_ident][using][-1])
else:
state[thread_ident] = []
state[thread_ident].append(settings.TRANSACTIONS_MANAGED)
if thread_ident not in dirty:
dirty[thread_ident] = False
state.setdefault(thread_ident, {})
state[thread_ident][using] = [settings.TRANSACTIONS_MANAGED]
if thread_ident not in dirty or using not in dirty[thread_ident]:
dirty.setdefault(thread_ident, {})
dirty[thread_ident][using] = False
connection._enter_transaction_management(managed)
def leave_transaction_management():
def leave_transaction_management(using=None):
"""
Leaves transaction management for a running thread. A dirty flag is carried
over to the surrounding block, as a commit will commit all changes, even
those from outside. (Commits are on connection level.)
"""
connection._leave_transaction_management(is_managed())
if using is None:
raise ValueError # TODO use default
connection = connections[using]
connection._leave_transaction_management(is_managed(using=using))
thread_ident = thread.get_ident()
if thread_ident in state and state[thread_ident]:
del state[thread_ident][-1]
if thread_ident in state and state[thread_ident].get(using):
del state[thread_ident][using][-1]
else:
raise TransactionManagementError("This code isn't under transaction management")
if dirty.get(thread_ident, False):
rollback()
if dirty.get(thread_ident, {}).get(using, False):
rollback(using=using)
raise TransactionManagementError("Transaction managed block ended with pending COMMIT/ROLLBACK")
dirty[thread_ident] = False
dirty[thread_ident][using] = False
def is_dirty():
def is_dirty(using=None):
"""
Returns True if the current transaction requires a commit for changes to
happen.
"""
return dirty.get(thread.get_ident(), False)
if using is None:
raise ValueError # TODO use default
return dirty.get(thread.get_ident(), {}).get(using, False)
def set_dirty():
def set_dirty(using=None):
"""
Sets a dirty flag for the current thread and code streak. This can be used
to decide in a managed block of code to decide whether there are open
changes waiting for commit.
"""
if using is None:
raise ValueError # TODO use default
thread_ident = thread.get_ident()
if thread_ident in dirty:
dirty[thread_ident] = True
if thread_ident in dirty and using in dirty[thread_ident]:
dirty[thread_ident][using] = True
else:
raise TransactionManagementError("This code isn't under transaction management")
def set_clean():
def set_clean(using=None):
"""
Resets a dirty flag for the current thread and code streak. This can be used
to decide in a managed block of code to decide whether a commit or rollback
should happen.
"""
if using is None:
raise ValueError # TODO use default
thread_ident = thread.get_ident()
if thread_ident in dirty:
dirty[thread_ident] = False
if thread_ident in dirty and using in dirty[thread_ident]:
dirty[thread_ident][using] = False
else:
raise TransactionManagementError("This code isn't under transaction management")
clean_savepoints()
clean_savepoints(using=using)
def clean_savepoints():
def clean_savepoints(using=None):
if using is None:
raise ValueError # TODO use default
thread_ident = thread.get_ident()
if thread_ident in savepoint_state:
del savepoint_state[thread_ident]
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
del savepoint_state[thread_ident][using]
def is_managed():
def is_managed(using=None):
"""
Checks whether the transaction manager is in manual or in auto state.
"""
if using is None:
raise ValueError # TODO use default
thread_ident = thread.get_ident()
if thread_ident in state:
if state[thread_ident]:
return state[thread_ident][-1]
if thread_ident in state and using in state[thread_ident]:
if state[thread_ident][using]:
return state[thread_ident][using][-1]
return settings.TRANSACTIONS_MANAGED
def managed(flag=True):
def managed(flag=True, using=None):
"""
Puts the transaction manager into a manual state: managed transactions have
to be committed explicitly by the user. If you switch off transaction
management and there is a pending commit/rollback, the data will be
commited.
"""
if using is None:
raise ValueError # TODO use default
connection = connections[using]
thread_ident = thread.get_ident()
top = state.get(thread_ident, None)
top = state.get(thread_ident, {}).get(using, None)
if top:
top[-1] = flag
if not flag and is_dirty():
if not flag and is_dirty(using=using):
connection._commit()
set_clean()
set_clean(using=using)
else:
raise TransactionManagementError("This code isn't under transaction management")
def commit_unless_managed():
def commit_unless_managed(using=None):
"""
Commits changes if the system is not in managed transaction mode.
"""
if not is_managed():
if using is None:
raise ValueError # TODO use default
connection = connections[using]
if not is_managed(using=using):
connection._commit()
clean_savepoints()
clean_savepoints(using=using)
else:
set_dirty()
set_dirty(using=using)
def rollback_unless_managed():
def rollback_unless_managed(using=None):
"""
Rolls back changes if the system is not in managed transaction mode.
"""
if not is_managed():
if using is None:
raise ValueError # TODO use default
connection = connections[using]
if not is_managed(using=using):
connection._rollback()
else:
set_dirty()
set_dirty(using=using)
def commit():
def commit(using=None):
"""
Does the commit itself and resets the dirty flag.
"""
if using is None:
raise ValueError # TODO use default
connection = connections[using]
connection._commit()
set_clean()
set_clean(using=using)
def rollback():
def rollback(using=None):
"""
This function does the rollback itself and resets the dirty flag.
"""
if using is None:
raise ValueError # TODO use default
connection = connections[using]
connection._rollback()
set_clean()
set_clean(using=using)
def savepoint():
def savepoint(using=None):
"""
Creates a savepoint (if supported and required by the backend) inside the
current transaction. Returns an identifier for the savepoint that will be
used for the subsequent rollback or commit.
"""
if using is None:
raise ValueError # TODO use default
connection = connections[using]
thread_ident = thread.get_ident()
if thread_ident in savepoint_state:
savepoint_state[thread_ident].append(None)
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
savepoint_state[thread_ident][using].append(None)
else:
savepoint_state[thread_ident] = [None]
savepoint_state.setdefault(thread_ident, {})
savepoint_state[thread_ident][using] = [None]
tid = str(thread_ident).replace('-', '')
sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident]))
sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident][using]))
connection._savepoint(sid)
return sid
def savepoint_rollback(sid):
def savepoint_rollback(sid, using=None):
"""
Rolls back the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
if thread.get_ident() in savepoint_state:
if using is None:
raise ValueError # TODO use default
connection = connections[using]
thread_ident = thread.get_ident()
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
connection._savepoint_rollback(sid)
def savepoint_commit(sid):
def savepoint_commit(sid, using=None):
"""
Commits the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
if thread.get_ident() in savepoint_state:
if using is None:
raise ValueError # TODO use default
connection = connections[using]
thread_ident = thread.get_ident()
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
connection._savepoint_commit(sid)
##############
# DECORATORS #
##############
# TODO update all of these for multi-db
def autocommit(func):
"""
Decorator that activates commit on save. This is Django's default behavior;

View File

@ -201,7 +201,8 @@ class DocTestRunner(doctest.DocTestRunner):
example, exc_info)
# Rollback, in case of database errors. Otherwise they'd have
# side effects on other tests.
transaction.rollback_unless_managed()
for conn in connections:
transaction.rollback_unless_managed(using=conn)
class TransactionTestCase(unittest.TestCase):
def _pre_setup(self):
@ -446,8 +447,9 @@ class TestCase(TransactionTestCase):
if not connections_support_transactions():
return super(TestCase, self)._fixture_setup()
transaction.enter_transaction_management()
transaction.managed(True)
for conn in connections:
transaction.enter_transaction_management(using=conn)
transaction.managed(True, using=conn)
disable_transaction_methods()
from django.contrib.sites.models import Site
@ -464,7 +466,8 @@ class TestCase(TransactionTestCase):
return super(TestCase, self)._fixture_teardown()
restore_transaction_methods()
transaction.rollback()
transaction.leave_transaction_management()
for conn in connections:
transaction.rollback(using=conn)
transaction.leave_transaction_management(using=conn)
for connection in connections.all():
connection.close()

View File

@ -15,11 +15,11 @@ class RevisionableModel(models.Model):
def __unicode__(self):
return u"%s (%s, %s)" % (self.title, self.id, self.base.id)
def save(self, force_insert=False, force_update=False):
super(RevisionableModel, self).save(force_insert, force_update)
def save(self, *args, **kwargs):
super(RevisionableModel, self).save(*args, **kwargs)
if not self.base:
self.base = self
super(RevisionableModel, self).save()
super(RevisionableModel, self).save(using=kwargs.pop('using', None))
def new_revision(self):
new_revision = copy.copy(self)