diff --git a/TODO.TXT b/TODO.TXT index 7863fd3577..b8d77fe11a 100644 --- a/TODO.TXT +++ b/TODO.TXT @@ -4,6 +4,11 @@ TODO The follow is a list, more or less in the order I intend to do them of things that need to be done. I'm trying to be as granular as possible. +*** +Immediate TODOs: +Finish refactor of WhereNode so that it just takes connection in as_sql. +*** + 2) Update all old references to ``settings.DATABASE_*`` to reference ``settings.DATABASES``. This includes the following locations @@ -20,29 +25,21 @@ that need to be done. I'm trying to be as granular as possible. * ``dumpdata``: By default dump the ``default`` database. Later add a ``--database`` flag. - * ``loaddata``: Leave as is, will use the standard mechanisms for - determining what DB to save the objects to. Should it get a - ``--database`` flag to overide that? These items will be fixed pending both community consensus, and the API - that will go in that's actually necessary for these to happen. Due to - internal APIs loaddata probably will need an update to load stuff into a - specific DB. + that will go in that's actually necessary for these to happen. -4) Rig up the test harness to work with multiple databases. This includes: + flush, reset, and syncdb need to not prompt the user multiple times. - * The current strategy is to test on N dbs, where N is however many the - user defines and ensuring the data all stays seperate and no exceptions - are raised. Practically speaking this means we're only going to have - good coverage if we write a lot of tests that can break. That's life. +9) Fix transaction support. In practice this means changing all the + dictionaries that currently map thread to a boolean to being a dictionary + mapping thread to a set of the affected DBs, and changing all the functions + that use these dictionaries to handle the changes appropriately. 7) Remove any references to the global ``django.db.connection`` object in the SQL creation process. This includes(but is probably not limited to): * The way we create ``Query`` from ``BaseQuery`` is awkward and hacky. - * ``django.db.models.query.delete_objects`` - * ``django.db.models.query.insert_query`` - * ``django.db.models.base.Model`` -- in ``save_base`` * ``django.db.models.fields.Field`` This uses it, as do it's subclasses. * ``django.db.models.fields.related`` It's used all over the place here, including opening a cursor and executing queries, so that's going to @@ -50,13 +47,9 @@ that need to be done. I'm trying to be as granular as possible. raw SQL and execution to ``Query``/``QuerySet`` so hopefully that makes it in before I need to tackle this. + 5) Add the ``using`` Meta option. Tests and docs(these are to be assumed at each stage from here on out). -5) Implement using kwarg on save() method. -6) Add the ``using`` method to ``QuerySet``. This will more or less "just - work" across multiple databases that use the same backend. However, it - will fail gratuitously when trying to use 2 different backends. - 8) Implement some way to create a new ``Query`` for a different backend when we switch. There are several checks against ``self.connection`` prior to SQL construction, so we either need to defer all these(which will be @@ -80,8 +73,4 @@ that need to be done. I'm trying to be as granular as possible. every single one of them, then when it's time excecute the query just pick the right ``Query`` object to use. This *does* not scale, though it could probably be done fairly easily. -9) Fix transaction support. In practice this means changing all the - dictionaries that currently map thread to a boolean to being a dictionary - mapping thread to a set of the affected DBs, and changing all the functions - that use these dictionaries to handle the changes appropriately. 10) Time permitting add support for a ``DatabaseManager``. diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py index 5564548133..f2cb989c2c 100644 --- a/django/contrib/contenttypes/generic.py +++ b/django/contrib/contenttypes/generic.py @@ -151,7 +151,7 @@ class GenericRelation(RelatedField, Field): def get_internal_type(self): return "ManyToManyField" - def db_type(self): + def db_type(self, connection): # Since we're simulating a ManyToManyField, in effect, best return the # same db_type as well. return None diff --git a/django/contrib/localflavor/us/models.py b/django/contrib/localflavor/us/models.py index 5158da4e87..444035b251 100644 --- a/django/contrib/localflavor/us/models.py +++ b/django/contrib/localflavor/us/models.py @@ -1,28 +1,27 @@ -from django.conf import settings from django.db.models.fields import Field -class USStateField(Field): - def get_internal_type(self): - return "USStateField" - - def db_type(self): - if settings.DATABASE_ENGINE == 'oracle': +class USStateField(Field): + def get_internal_type(self): + return "USStateField" + + def db_type(self, connection): + if connection.settings_dict['DATABASE_ENGINE'] == 'oracle': return 'CHAR(2)' else: return 'varchar(2)' - - def formfield(self, **kwargs): - from django.contrib.localflavor.us.forms import USStateSelect - defaults = {'widget': USStateSelect} - defaults.update(kwargs) + + def formfield(self, **kwargs): + from django.contrib.localflavor.us.forms import USStateSelect + defaults = {'widget': USStateSelect} + defaults.update(kwargs) return super(USStateField, self).formfield(**defaults) class PhoneNumberField(Field): def get_internal_type(self): return "PhoneNumberField" - def db_type(self): - if settings.DATABASE_ENGINE == 'oracle': + def db_type(self, connection): + if connection.settings_dict['DATABASE_ENGINE'] == 'oracle': return 'VARCHAR2(20)' else: return 'varchar(20)' @@ -32,4 +31,3 @@ class PhoneNumberField(Field): defaults = {'form_class': USPhoneNumberField} defaults.update(kwargs) return super(PhoneNumberField, self).formfield(**defaults) - diff --git a/django/contrib/sessions/backends/db.py b/django/contrib/sessions/backends/db.py index 85c4a0dd3c..f31c292d36 100644 --- a/django/contrib/sessions/backends/db.py +++ b/django/contrib/sessions/backends/db.py @@ -2,7 +2,7 @@ import datetime from django.contrib.sessions.models import Session from django.contrib.sessions.backends.base import SessionBase, CreateError from django.core.exceptions import SuspiciousOperation -from django.db import IntegrityError, transaction +from django.db import IntegrityError, transaction, DEFAULT_DB_ALIAS from django.utils.encoding import force_unicode class SessionStore(SessionBase): @@ -53,12 +53,13 @@ class SessionStore(SessionBase): session_data = self.encode(self._get_session(no_load=must_create)), expire_date = self.get_expiry_date() ) - sid = transaction.savepoint() + # TODO update for multidb + sid = transaction.savepoint(using=DEFAULT_DB_ALIAS) try: obj.save(force_insert=must_create) except IntegrityError: if must_create: - transaction.savepoint_rollback(sid) + transaction.savepoint_rollback(sid, using=DEFAULT_DB_ALIAS) raise CreateError raise diff --git a/django/contrib/sites/management.py b/django/contrib/sites/management.py index 25c076099a..19872740ee 100644 --- a/django/contrib/sites/management.py +++ b/django/contrib/sites/management.py @@ -6,12 +6,12 @@ from django.db.models import signals from django.contrib.sites.models import Site from django.contrib.sites import models as site_app -def create_default_site(app, created_models, verbosity, **kwargs): +def create_default_site(app, created_models, verbosity, db, **kwargs): if Site in created_models: if verbosity >= 2: print "Creating example.com Site object" s = Site(domain="example.com", name="example.com") - s.save() + s.save(using=db) Site.objects.clear_cache() signals.post_syncdb.connect(create_default_site, sender=site_app) diff --git a/django/core/management/commands/createcachetable.py b/django/core/management/commands/createcachetable.py index b0ad180ee9..798a1312e0 100644 --- a/django/core/management/commands/createcachetable.py +++ b/django/core/management/commands/createcachetable.py @@ -27,9 +27,9 @@ class Command(LabelCommand): ) table_output = [] index_output = [] - qn = connections.ops.quote_name + qn = connection.ops.quote_name for f in fields: - field_output = [qn(f.name), f.db_type()] + field_output = [qn(f.name), f.db_type(connection)] field_output.append("%sNULL" % (not f.null and "NOT " or "")) if f.primary_key: field_output.append("PRIMARY KEY") diff --git a/django/core/management/commands/flush.py b/django/core/management/commands/flush.py index 0b164d249d..6c6bbb5c11 100644 --- a/django/core/management/commands/flush.py +++ b/django/core/management/commands/flush.py @@ -15,14 +15,14 @@ class Command(NoArgsCommand): make_option('--noinput', action='store_false', dest='interactive', default=True, help='Tells Django to NOT prompt the user for input of any kind.'), make_option('--database', action='store', dest='database', - default='', help='Nominates a database to flush. Defaults to ' + default=None, help='Nominates a database to flush. Defaults to ' 'flushing all databases.'), ) help = "Executes ``sqlflush`` on the current database." def handle_noargs(self, **options): if not options['database']: - dbs = connections.all() + dbs = connections else: dbs = [options['database']] @@ -31,57 +31,14 @@ class Command(NoArgsCommand): self.style = no_style() - # Import the 'management' module within each installed app, to register - # dispatcher events. - for app_name in settings.INSTALLED_APPS: - try: - import_module('.management', app_name) - except ImportError: - pass - - sql_list = sql_flush(self.style, connection, only_django=True) - - if interactive: - confirm = raw_input("""You have requested a flush of the database. - This will IRREVERSIBLY DESTROY all data currently in the %r database, - and return each table to the state it was in after syncdb. - Are you sure you want to do this? - - Type 'yes' to continue, or 'no' to cancel: """ % connection.settings_dict['DATABASE_NAME']) - else: - confirm = 'yes' - - if confirm == 'yes': - try: - cursor = connection.cursor() - for sql in sql_list: - cursor.execute(sql) - except Exception, e: - transaction.rollback_unless_managed() - raise CommandError("""Database %s couldn't be flushed. Possible reasons: - * The database isn't running or isn't configured correctly. - * At least one of the expected database tables doesn't exist. - * The SQL was invalid. - Hint: Look at the output of 'django-admin.py sqlflush'. That's the SQL this command wasn't able to run. - The full error: %s""" % (connection.settings_dict.DATABASE_NAME, e)) - transaction.commit_unless_managed() - - # Emit the post sync signal. This allows individual - # applications to respond as if the database had been - # sync'd from scratch. - emit_post_sync_signal(models.get_models(), verbosity, interactive, connection) - - # Reinstall the initial_data fixture. - call_command('loaddata', 'initial_data', **options) - for app_name in settings.INSTALLED_APPS: try: import_module('.management', app_name) except ImportError: pass - for connection in dbs: - + for db in dbs: + connection = connections[db] # Import the 'management' module within each installed app, to register # dispatcher events. sql_list = sql_flush(self.style, connection, only_django=True) @@ -102,22 +59,24 @@ class Command(NoArgsCommand): for sql in sql_list: cursor.execute(sql) except Exception, e: - transaction.rollback_unless_managed() + transaction.rollback_unless_managed(using=db) raise CommandError("""Database %s couldn't be flushed. Possible reasons: * The database isn't running or isn't configured correctly. * At least one of the expected database tables doesn't exist. * The SQL was invalid. Hint: Look at the output of 'django-admin.py sqlflush'. That's the SQL this command wasn't able to run. The full error: %s""" % (connection.settings_dict.DATABASE_NAME, e)) - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=db) # Emit the post sync signal. This allows individual # applications to respond as if the database had been # sync'd from scratch. - emit_post_sync_signal(models.get_models(), verbosity, interactive, connection) + emit_post_sync_signal(models.get_models(), verbosity, interactive, db) # Reinstall the initial_data fixture. - call_command('loaddata', 'initial_data', **options) + kwargs = options.copy() + kwargs['database'] = db + call_command('loaddata', 'initial_data', **kwargs) else: print "Flush cancelled." diff --git a/django/core/management/commands/loaddata.py b/django/core/management/commands/loaddata.py index e70f85b9a6..d2819d7fd8 100644 --- a/django/core/management/commands/loaddata.py +++ b/django/core/management/commands/loaddata.py @@ -4,8 +4,13 @@ import gzip import zipfile from optparse import make_option +from django.conf import settings +from django.core import serializers from django.core.management.base import BaseCommand from django.core.management.color import no_style +from django.db import connections, transaction, DEFAULT_DB_ALIAS +from django.db.models import get_apps + try: set @@ -22,12 +27,15 @@ class Command(BaseCommand): help = 'Installs the named fixture(s) in the database.' args = "fixture [fixture ...]" - def handle(self, *fixture_labels, **options): - from django.db.models import get_apps - from django.core import serializers - from django.db import connection, transaction - from django.conf import settings + option_list = BaseCommand.option_list + ( + make_option('--database', action='store', dest='database', + default=DEFAULT_DB_ALIAS, help='Nominates a specific database to load ' + 'fixtures into. By default uses the "default" database.'), + ) + def handle(self, *fixture_labels, **options): + using = options['database'] + connection = connections[using] self.style = no_style() verbosity = int(options.get('verbosity', 1)) @@ -56,9 +64,9 @@ class Command(BaseCommand): # Start transaction management. All fixtures are installed in a # single transaction to ensure that all references are resolved. if commit: - transaction.commit_unless_managed() - transaction.enter_transaction_management() - transaction.managed(True) + transaction.commit_unless_managed(using=using) + transaction.enter_transaction_management(using=using) + transaction.managed(True, using=using) class SingleZipReader(zipfile.ZipFile): def __init__(self, *args, **kwargs): @@ -103,8 +111,8 @@ class Command(BaseCommand): sys.stderr.write( self.style.ERROR("Problem installing fixture '%s': %s is not a known serialization format." % (fixture_name, format))) - transaction.rollback() - transaction.leave_transaction_management() + transaction.rollback(using=using) + transaction.leave_transaction_management(using=using) return if os.path.isabs(fixture_name): @@ -119,25 +127,25 @@ class Command(BaseCommand): label_found = False for format in formats: for compression_format in compression_formats: - if compression_format: - file_name = '.'.join([fixture_name, format, + if compression_format: + file_name = '.'.join([fixture_name, format, compression_format]) - else: + else: file_name = '.'.join([fixture_name, format]) - + if verbosity > 1: print "Trying %s for %s fixture '%s'..." % \ (humanize(fixture_dir), file_name, fixture_name) full_path = os.path.join(fixture_dir, file_name) - open_method = compression_types[compression_format] - try: + open_method = compression_types[compression_format] + try: fixture = open_method(full_path, 'r') if label_found: fixture.close() print self.style.ERROR("Multiple fixtures named '%s' in %s. Aborting." % (fixture_name, humanize(fixture_dir))) - transaction.rollback() - transaction.leave_transaction_management() + transaction.rollback(using=using) + transaction.leave_transaction_management(using=using) return else: fixture_count += 1 @@ -150,7 +158,7 @@ class Command(BaseCommand): for obj in objects: objects_in_fixture += 1 models.add(obj.object.__class__) - obj.save() + obj.save(using=using) object_count += objects_in_fixture label_found = True except (SystemExit, KeyboardInterrupt): @@ -158,15 +166,15 @@ class Command(BaseCommand): except Exception: import traceback fixture.close() - transaction.rollback() - transaction.leave_transaction_management() + transaction.rollback(using=using) + transaction.leave_transaction_management(using=using) if show_traceback: traceback.print_exc() else: sys.stderr.write( self.style.ERROR("Problem installing fixture '%s': %s\n" % - (full_path, ''.join(traceback.format_exception(sys.exc_type, - sys.exc_value, sys.exc_traceback))))) + (full_path, ''.join(traceback.format_exception(sys.exc_type, + sys.exc_value, sys.exc_traceback))))) return fixture.close() @@ -176,8 +184,8 @@ class Command(BaseCommand): sys.stderr.write( self.style.ERROR("No fixture data found for '%s'. (File format may be invalid.)" % (fixture_name))) - transaction.rollback() - transaction.leave_transaction_management() + transaction.rollback(using=using) + transaction.leave_transaction_management(using=using) return except Exception, e: @@ -196,8 +204,8 @@ class Command(BaseCommand): cursor.execute(line) if commit: - transaction.commit() - transaction.leave_transaction_management() + transaction.commit(using=using) + transaction.leave_transaction_management(using=using) if object_count == 0: if verbosity > 1: diff --git a/django/core/management/commands/sqlsequencereset.py b/django/core/management/commands/sqlsequencereset.py index 3373c607e2..a1bbf48f5b 100644 --- a/django/core/management/commands/sqlsequencereset.py +++ b/django/core/management/commands/sqlsequencereset.py @@ -10,7 +10,7 @@ class Command(AppCommand): make_option('--database', action='store', dest='database', default='default', help='Nominates a database to print the SQL ' 'for. Defaults to the "default" database.'), - ) + ) output_transaction = True diff --git a/django/core/management/commands/syncdb.py b/django/core/management/commands/syncdb.py index fd6b5b5b1a..ff196494bb 100644 --- a/django/core/management/commands/syncdb.py +++ b/django/core/management/commands/syncdb.py @@ -32,9 +32,9 @@ class Command(NoArgsCommand): self.style = no_style() if not options['database']: - dbs = connections.all() + dbs = connections else: - dbs = [connections[options['database']]] + dbs = [options['database']] # Import the 'management' module within each installed app, to register # dispatcher events. @@ -55,7 +55,8 @@ class Command(NoArgsCommand): if not msg.startswith('No module named') or 'management' not in msg: raise - for connection in dbs: + for db in dbs: + connection = connections[db] cursor = connection.cursor() # Get a list of already installed *models* so that references work right. @@ -102,11 +103,11 @@ class Command(NoArgsCommand): for statement in sql: cursor.execute(statement) - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=db) # Send the post_syncdb signal, so individual apps can do whatever they need # to do at this point. - emit_post_sync_signal(created_models, verbosity, interactive, connection) + emit_post_sync_signal(created_models, verbosity, interactive, db) # The connection may have been closed by a syncdb handler. cursor = connection.cursor() @@ -130,9 +131,9 @@ class Command(NoArgsCommand): if show_traceback: import traceback traceback.print_exc() - transaction.rollback_unless_managed() + transaction.rollback_unless_managed(using=db) else: - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=db) else: if verbosity >= 2: print "No custom SQL for %s.%s model" % (app_name, model._meta.object_name) @@ -151,13 +152,9 @@ class Command(NoArgsCommand): except Exception, e: sys.stderr.write("Failed to install index for %s.%s model: %s\n" % \ (app_name, model._meta.object_name, e)) - transaction.rollback_unless_managed() + transaction.rollback_unless_managed(using=db) else: - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=db) - # Install the 'initial_data' fixture, using format discovery - # FIXME we only load the fixture data for one DB right now, since we - # can't control what DB it does into, once we can control this we - # should move it back into the DB loop - from django.core.management import call_command - call_command('loaddata', 'initial_data', verbosity=verbosity) + from django.core.management import call_command + call_command('loaddata', 'initial_data', verbosity=verbosity, database=db) diff --git a/django/core/management/sql.py b/django/core/management/sql.py index 9591ce000a..63122ec96d 100644 --- a/django/core/management/sql.py +++ b/django/core/management/sql.py @@ -188,7 +188,7 @@ def custom_sql_for_model(model, style, connection): return output -def emit_post_sync_signal(created_models, verbosity, interactive, connection): +def emit_post_sync_signal(created_models, verbosity, interactive, db): # Emit the post_sync signal for every application. for app in models.get_apps(): app_name = app.__name__.split('.')[-2] @@ -196,4 +196,4 @@ def emit_post_sync_signal(created_models, verbosity, interactive, connection): print "Running post-sync handlers for application", app_name models.signals.post_syncdb.send(sender=app, app=app, created_models=created_models, verbosity=verbosity, - interactive=interactive, connection=connection) + interactive=interactive, db=db) diff --git a/django/core/serializers/base.py b/django/core/serializers/base.py index 22de2d70d0..484b3fc66d 100644 --- a/django/core/serializers/base.py +++ b/django/core/serializers/base.py @@ -154,13 +154,13 @@ class DeserializedObject(object): def __repr__(self): return "" % smart_str(self.object) - def save(self, save_m2m=True): + def save(self, save_m2m=True, using=None): # Call save on the Model baseclass directly. This bypasses any # model-defined save. The save is also forced to be raw. # This ensures that the data that is deserialized is literally # what came from the file, not post-processed by pre_save/save # methods. - models.Model.save_base(self.object, raw=True) + models.Model.save_base(self.object, using=using, raw=True) if self.m2m_data and save_m2m: for accessor_name, object_list in self.m2m_data.items(): setattr(self.object, accessor_name, object_list) diff --git a/django/db/__init__.py b/django/db/__init__.py index 9da54541da..d53b60bff4 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -54,12 +54,13 @@ def reset_queries(**kwargs): connection.queries = [] signals.request_started.connect(reset_queries) -# Register an event that rolls back the connection +# Register an event that rolls back the connections # when a Django request has an exception. def _rollback_on_exception(**kwargs): from django.db import transaction - try: - transaction.rollback_unless_managed() - except DatabaseError: - pass + for conn in connections: + try: + transaction.rollback_unless_managed(using=conn) + except DatabaseError: + pass signals.got_request_exception.connect(_rollback_on_exception) diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index 84429dc45d..282606a06d 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -90,7 +90,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): super(DatabaseWrapper, self).__init__(*args, **kwargs) self.features = DatabaseFeatures() - self.ops = DatabaseOperations() + self.ops = DatabaseOperations(self) self.client = DatabaseClient(self) self.creation = DatabaseCreation(self) self.introspection = DatabaseIntrospection(self) diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index ee74be1624..bb0d64515a 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -6,14 +6,14 @@ from django.db.backends import BaseDatabaseOperations # used by both the 'postgresql' and 'postgresql_psycopg2' backends. class DatabaseOperations(BaseDatabaseOperations): - def __init__(self): + def __init__(self, connection): self._postgres_version = None + self.connection = connection def _get_postgres_version(self): if self._postgres_version is None: - from django.db import connection from django.db.backends.postgresql.version import get_version - cursor = connection.cursor() + cursor = self.connection.cursor() self._postgres_version = get_version(cursor) return self._postgres_version postgres_version = property(_get_postgres_version) diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 3d804f09b9..b2a0e943ef 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -66,7 +66,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): autocommit = self.settings_dict["DATABASE_OPTIONS"].get('autocommit', False) self.features.uses_autocommit = autocommit self._set_isolation_level(int(not autocommit)) - self.ops = DatabaseOperations() + self.ops = DatabaseOperations(self) self.client = DatabaseClient(self) self.creation = DatabaseCreation(self) self.introspection = DatabaseIntrospection(self) diff --git a/django/db/models/base.py b/django/db/models/base.py index bb2dbf0504..12e84625a7 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -15,7 +15,7 @@ from django.db.models.fields.related import OneToOneRel, ManyToOneRel, OneToOneF from django.db.models.query import delete_objects, Q from django.db.models.query_utils import CollectedObjects, DeferredAttribute from django.db.models.options import Options -from django.db import connection, transaction, DatabaseError +from django.db import connections, transaction, DatabaseError, DEFAULT_DB_ALIAS from django.db.models import signals from django.db.models.loading import register_models, get_model from django.utils.functional import curry @@ -395,7 +395,7 @@ class Model(object): return getattr(self, field_name) return getattr(self, field.attname) - def save(self, force_insert=False, force_update=False): + def save(self, force_insert=False, force_update=False, using=None): """ Saves the current instance. Override this in a subclass if you want to control the saving process. @@ -407,18 +407,21 @@ class Model(object): if force_insert and force_update: raise ValueError("Cannot force both insert and updating in " "model saving.") - self.save_base(force_insert=force_insert, force_update=force_update) + self.save_base(using=using, force_insert=force_insert, force_update=force_update) save.alters_data = True def save_base(self, raw=False, cls=None, force_insert=False, - force_update=False): + force_update=False, using=None): """ Does the heavy-lifting involved in saving. Subclasses shouldn't need to override this method. It's separate from save() in order to hide the need for overrides of save() to pass around internal-only parameters ('raw' and 'cls'). """ + if using is None: + using = DEFAULT_DB_ALIAS + connection = connections[using] assert not (force_insert and force_update) if not cls: cls = self.__class__ @@ -441,7 +444,7 @@ class Model(object): if field and getattr(self, parent._meta.pk.attname) is None and getattr(self, field.attname) is not None: setattr(self, parent._meta.pk.attname, getattr(self, field.attname)) - self.save_base(cls=parent) + self.save_base(cls=parent, using=using) if field: setattr(self, field.attname, self._get_pk_val(parent._meta)) if meta.proxy: @@ -457,12 +460,13 @@ class Model(object): manager = cls._base_manager if pk_set: # Determine whether a record with the primary key already exists. + # FIXME work with the using parameter if (force_update or (not force_insert and - manager.filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by())): + manager.using(using).filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by())): # It does already exist, so do an UPDATE. if force_update or non_pks: values = [(f, None, (raw and getattr(self, f.attname) or f.pre_save(self, False))) for f in non_pks] - rows = manager.filter(pk=pk_val)._update(values) + rows = manager.using(using).filter(pk=pk_val)._update(values) if force_update and not rows: raise DatabaseError("Forced update did not affect any rows.") else: @@ -477,20 +481,20 @@ class Model(object): if meta.order_with_respect_to: field = meta.order_with_respect_to - values.append((meta.get_field_by_name('_order')[0], manager.filter(**{field.name: getattr(self, field.attname)}).count())) + values.append((meta.get_field_by_name('_order')[0], manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count())) record_exists = False update_pk = bool(meta.has_auto_field and not pk_set) if values: # Create a new record. - result = manager._insert(values, return_id=update_pk) + result = manager._insert(values, return_id=update_pk, using=using) else: # Create a new record with defaults for everything. - result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True) + result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True, using=using) if update_pk: setattr(self, meta.pk.attname, result) - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=using) if signal: signals.post_save.send(sender=self.__class__, instance=self, @@ -549,7 +553,10 @@ class Model(object): # delete it and all its descendents. parent_obj._collect_sub_objects(seen_objs) - def delete(self): + def delete(self, using=None): + if using is None: + using = DEFAULT_DB_ALIAS + connection = connections[using] assert self._get_pk_val() is not None, "%s object can't be deleted because its %s attribute is set to None." % (self._meta.object_name, self._meta.pk.attname) # Find all the objects than need to be deleted. @@ -557,7 +564,7 @@ class Model(object): self._collect_sub_objects(seen_objs) # Actually delete the objects. - delete_objects(seen_objs) + delete_objects(seen_objs, using) delete.alters_data = True @@ -610,7 +617,8 @@ def method_set_order(ordered_obj, self, id_list): # for situations like this. for i, j in enumerate(id_list): ordered_obj.objects.filter(**{'pk': j, order_name: rel_val}).update(_order=i) - transaction.commit_unless_managed() + # TODO + transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS) def method_get_order(ordered_obj, self): diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index e54aabacb5..e782d98c61 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -41,8 +41,8 @@ class ExpressionNode(tree.Node): def prepare(self, evaluator, query, allow_joins): return evaluator.prepare_node(self, query, allow_joins) - def evaluate(self, evaluator, qn): - return evaluator.evaluate_node(self, qn) + def evaluate(self, evaluator, qn, connection): + return evaluator.evaluate_node(self, qn, connection) ############# # OPERATORS # @@ -109,5 +109,5 @@ class F(ExpressionNode): def prepare(self, evaluator, query, allow_joins): return evaluator.prepare_leaf(self, query, allow_joins) - def evaluate(self, evaluator, qn): - return evaluator.evaluate_leaf(self, qn) + def evaluate(self, evaluator, qn, connection): + return evaluator.evaluate_leaf(self, qn, connection) diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index dc57e9401a..f8a94e41ec 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -233,6 +233,24 @@ class Field(object): raise TypeError("Field has invalid lookup: %s" % lookup_type) + def validate(self, lookup_type, value): + """ + Validate that the data is valid, as much so as possible without knowing + what connection we are using. Returns True if the value was + successfully validated and false if the value wasn't validated (this + doesn't consider whether the value was actually valid, an exception is + raised in those circumstances). + """ + if hasattr(value, 'validate') or hasattr(value, '_validate'): + if hasattr(value, 'validate'): + value.validate() + else: + value._validate() + return True + if lookup_type == 'isnull': + return True + return False + def has_default(self): "Returns a boolean of whether this field has a default value." return self.default is not NOT_PROVIDED @@ -360,6 +378,17 @@ class AutoField(Field): return None return int(value) + def validate(self, lookup_type, value): + if super(AutoField, self).validate(lookup_type, value): + return + if value is None or hasattr(value, 'as_sql'): + return + if lookup_type in ('range', 'in'): + for val in value: + int(val) + else: + int(value) + def contribute_to_class(self, cls, name): assert not cls._meta.has_auto_field, "A model can't have more than one AutoField." super(AutoField, self).contribute_to_class(cls, name) @@ -396,6 +425,13 @@ class BooleanField(Field): value = bool(int(value)) return super(BooleanField, self).get_db_prep_lookup(lookup_type, value) + def validate(self, lookup_type, value): + if super(BooleanField, self).validate(lookup_type, value): + return + if value in ('1', '0'): + value = int(value) + bool(value) + def get_db_prep_value(self, value): if value is None: return None @@ -510,6 +546,20 @@ class DateField(Field): # Casts dates into the format expected by the backend return connection.ops.value_to_db_date(self.to_python(value)) + def validate(self, lookup_type, value): + if super(DateField, self).validate(lookup_type, value): + return + if value is None: + return + if lookup_type in ('month', 'day', 'year', 'week_day'): + int(value) + return + if lookup_type in ('in', 'range'): + for val in value: + self.to_python(val) + return + self.to_python(value) + def value_to_string(self, obj): val = self._get_val_from_obj(obj) if val is None: @@ -754,6 +804,11 @@ class NullBooleanField(Field): value = bool(int(value)) return super(NullBooleanField, self).get_db_prep_lookup(lookup_type, value) + def validate(self, lookup_type, value): + if value in ('1', '0'): + value = int(value) + bool(value) + def get_db_prep_value(self, value): if value is None: return None diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 4d7d771b3f..9d9c082112 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -1,4 +1,4 @@ -from django.db import connection, transaction +from django.db import connection, transaction, DEFAULT_DB_ALIAS from django.db.backends import util from django.db.models import signals, get_model from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist @@ -478,7 +478,9 @@ def create_many_related_manager(superclass, through=False): cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \ (self.join_table, source_col_name, target_col_name), [self._pk_val, obj_id]) - transaction.commit_unless_managed() + # FIXME, once this isn't in related.py it should conditionally + # use the right DB. + transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS) def _remove_items(self, source_col_name, target_col_name, *objs): # source_col_name: the PK colname in join_table for the source object @@ -508,7 +510,8 @@ def create_many_related_manager(superclass, through=False): cursor.execute("DELETE FROM %s WHERE %s = %%s" % \ (self.join_table, source_col_name), [self._pk_val]) - transaction.commit_unless_managed() + # TODO + transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS) return ManyRelatedManager diff --git a/django/db/models/manager.py b/django/db/models/manager.py index 52612d8f64..18ed1c161f 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -173,6 +173,9 @@ class Manager(object): def only(self, *args, **kwargs): return self.get_query_set().only(*args, **kwargs) + def using(self, *args, **kwargs): + return self.get_query_set().using(*args, **kwargs) + def _insert(self, values, **kwargs): return insert_query(self.model, values, **kwargs) diff --git a/django/db/models/query.py b/django/db/models/query.py index 6a8d7d5e64..f9d66c94c5 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -7,7 +7,7 @@ try: except NameError: from sets import Set as set # Python 2.3 fallback -from django.db import connection, transaction, IntegrityError +from django.db import connections, transaction, IntegrityError, DEFAULT_DB_ALIAS from django.db.models.aggregates import Aggregate from django.db.models.fields import DateField from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory @@ -31,10 +31,13 @@ class QuerySet(object): """ def __init__(self, model=None, query=None): self.model = model + connection = connections[DEFAULT_DB_ALIAS] self.query = query or sql.Query(self.model, connection) self._result_cache = None self._iter = None self._sticky_filter = False + self._using = DEFAULT_DB_ALIAS # this will be wrong if a custom Query + # is provided with a non default connection ######################## # PYTHON MAGIC METHODS # @@ -300,12 +303,12 @@ class QuerySet(object): params = dict([(k, v) for k, v in kwargs.items() if '__' not in k]) params.update(defaults) obj = self.model(**params) - sid = transaction.savepoint() - obj.save(force_insert=True) - transaction.savepoint_commit(sid) + sid = transaction.savepoint(using=self._using) + obj.save(force_insert=True, using=self._using) + transaction.savepoint_commit(sid, using=self._using) return obj, True except IntegrityError, e: - transaction.savepoint_rollback(sid) + transaction.savepoint_rollback(sid, using=self._using) try: return self.get(**kwargs), False except self.model.DoesNotExist: @@ -364,7 +367,7 @@ class QuerySet(object): if not seen_objs: break - delete_objects(seen_objs) + delete_objects(seen_objs, del_query._using) # Clear the result cache, in case this QuerySet gets reused. self._result_cache = None @@ -379,20 +382,20 @@ class QuerySet(object): "Cannot update a query once a slice has been taken." query = self.query.clone(sql.UpdateQuery) query.add_update_values(kwargs) - if not transaction.is_managed(): - transaction.enter_transaction_management() + if not transaction.is_managed(using=self._using): + transaction.enter_transaction_management(using=self._using) forced_managed = True else: forced_managed = False try: rows = query.execute_sql(None) if forced_managed: - transaction.commit() + transaction.commit(using=self._using) else: - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=self._using) finally: if forced_managed: - transaction.leave_transaction_management() + transaction.leave_transaction_management(using=self._using) self._result_cache = None return rows update.alters_data = True @@ -616,6 +619,16 @@ class QuerySet(object): clone.query.add_immediate_loading(fields) return clone + def using(self, alias): + """ + Selects which database this QuerySet should excecute it's query against. + """ + clone = self._clone() + clone._using = alias + connection = connections[alias] + clone.query.set_connection(connection) + return clone + ################################### # PUBLIC INTROSPECTION ATTRIBUTES # ################################### @@ -644,6 +657,7 @@ class QuerySet(object): if self._sticky_filter: query.filter_is_sticky = True c = klass(model=self.model, query=query) + c._using = self._using c.__dict__.update(kwargs) if setup and hasattr(c, '_setup_query'): c._setup_query() @@ -700,6 +714,13 @@ class QuerySet(object): obj = self.values("pk") return obj.query.as_nested_sql() + def _validate(self): + """ + A normal QuerySet is always valid when used as the RHS of a filter, + since it automatically gets filtered down to 1 field. + """ + pass + # When used as part of a nested query, a queryset will never be an "always # empty" result. value_annotation = True @@ -818,6 +839,17 @@ class ValuesQuerySet(QuerySet): % self.__class__.__name__) return self._clone().query.as_nested_sql() + def _validate(self): + """ + Validates that we aren't trying to do a query like + value__in=qs.values('value1', 'value2'), which isn't valid. + """ + if ((self._fields and len(self._fields) > 1) or + (not self._fields and len(self.model._meta.fields) > 1)): + raise TypeError('Cannot use a multi-field %s as a filter value.' + % self.__class__.__name__) + + class ValuesListQuerySet(ValuesQuerySet): def iterator(self): if self.flat and len(self._fields) == 1: @@ -970,13 +1002,14 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, setattr(obj, f.get_cache_name(), rel_obj) return obj, index_end -def delete_objects(seen_objs): +def delete_objects(seen_objs, using): """ Iterate through a list of seen classes, and remove any instances that are referred to. """ - if not transaction.is_managed(): - transaction.enter_transaction_management() + connection = connections[using] + if not transaction.is_managed(using=using): + transaction.enter_transaction_management(using=using) forced_managed = True else: forced_managed = False @@ -1036,20 +1069,21 @@ def delete_objects(seen_objs): setattr(instance, cls._meta.pk.attname, None) if forced_managed: - transaction.commit() + transaction.commit(using=using) else: - transaction.commit_unless_managed() + transaction.commit_unless_managed(using=using) finally: if forced_managed: - transaction.leave_transaction_management() + transaction.leave_transaction_management(using=using) -def insert_query(model, values, return_id=False, raw_values=False): +def insert_query(model, values, return_id=False, raw_values=False, using=None): """ Inserts a new record for the given model. This provides an interface to the InsertQuery class and is how Model.save() is implemented. It is not part of the public API. """ + connection = connections[using] query = sql.InsertQuery(model, connection) query.insert_values(values, raw_values) return query.execute_sql(return_id) diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 826c4f398c..2c447022d3 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -3,25 +3,23 @@ from django.db.models.fields import FieldDoesNotExist from django.db.models.sql.constants import LOOKUP_SEP class SQLEvaluator(object): + as_sql_takes_connection = True + def __init__(self, expression, query, allow_joins=True): self.expression = expression self.opts = query.get_meta() self.cols = {} self.contains_aggregate = False - self.connection = query.connection self.expression.prepare(self, query, allow_joins) - def as_sql(self, qn=None): - return self.expression.evaluate(self, qn) + def as_sql(self, qn, connection): + return self.expression.evaluate(self, qn, connection) def relabel_aliases(self, change_map): for node, col in self.cols.items(): self.cols[node] = (change_map.get(col[0], col[0]), col[1]) - def update_connection(self, connection): - self.connection = connection - ##################################################### # Vistor methods for initial expression preparation # ##################################################### @@ -57,15 +55,12 @@ class SQLEvaluator(object): # Vistor methods for final expression evaluation # ################################################## - def evaluate_node(self, node, qn): - if not qn: - qn = self.connection.ops.quote_name - + def evaluate_node(self, node, qn, connection): expressions = [] expression_params = [] for child in node.children: if hasattr(child, 'evaluate'): - sql, params = child.evaluate(self, qn) + sql, params = child.evaluate(self, qn, connection) else: sql, params = '%s', (child,) @@ -78,12 +73,9 @@ class SQLEvaluator(object): expressions.append(format % sql) expression_params.extend(params) - return self.connection.ops.combine_expression(node.connector, expressions), expression_params - - def evaluate_leaf(self, node, qn): - if not qn: - qn = self.connection.ops.quote_name + return connection.ops.combine_expression(node.connector, expressions), expression_params + def evaluate_leaf(self, node, qn, connection): col = self.cols[node] if hasattr(col, 'as_sql'): return col.as_sql(qn), () diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 2059139600..02c3c732e5 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -67,10 +67,10 @@ class BaseQuery(object): # SQL-related attributes self.select = [] self.tables = [] # Aliases in the order they are created. - self.where = where(connection=self.connection) + self.where = where() self.where_class = where self.group_by = None - self.having = where(connection=self.connection) + self.having = where() self.order_by = [] self.low_mark, self.high_mark = 0, None # Used for offset/limit self.distinct = False @@ -151,8 +151,6 @@ class BaseQuery(object): # supported. It's the only class-reference to the module-level # connection variable. self.connection = connection - self.where.update_connection(self.connection) - self.having.update_connection(self.connection) def get_meta(self): """ @@ -245,8 +243,6 @@ class BaseQuery(object): obj.used_aliases = set() obj.filter_is_sticky = False obj.__dict__.update(kwargs) - obj.where.update_connection(obj.connection) # where and having track their own connection - obj.having.update_connection(obj.connection)# we need to keep this up to date if hasattr(obj, '_setup_query'): obj._setup_query() return obj @@ -405,8 +401,8 @@ class BaseQuery(object): from_, f_params = self.get_from_clause() qn = self.quote_name_unless_alias - where, w_params = self.where.as_sql(qn=qn) - having, h_params = self.having.as_sql(qn=qn) + where, w_params = self.where.as_sql(qn=qn, connection=self.connection) + having, h_params = self.having.as_sql(qn=qn, connection=self.connection) params = [] for val in self.extra_select.itervalues(): params.extend(val[1]) @@ -534,10 +530,10 @@ class BaseQuery(object): self.where.add(EverythingNode(), AND) elif self.where: # rhs has an empty where clause. - w = self.where_class(connection=self.connection) + w = self.where_class() w.add(EverythingNode(), AND) else: - w = self.where_class(connection=self.connection) + w = self.where_class() self.where.add(w, connector) # Selection columns and extra extensions are those provided by 'rhs'. @@ -1550,7 +1546,7 @@ class BaseQuery(object): for alias, aggregate in self.aggregates.items(): if alias == parts[0]: - entry = self.where_class(connection=self.connection) + entry = self.where_class() entry.add((aggregate, lookup_type, value), AND) if negate: entry.negate() @@ -1618,7 +1614,7 @@ class BaseQuery(object): for alias in join_list: if self.alias_map[alias][JOIN_TYPE] == self.LOUTER: j_col = self.alias_map[alias][RHS_JOIN_COL] - entry = self.where_class(connection=self.connection) + entry = self.where_class() entry.add((Constraint(alias, j_col, None), 'isnull', True), AND) entry.negate() self.where.add(entry, AND) @@ -1627,7 +1623,7 @@ class BaseQuery(object): # Leaky abstraction artifact: We have to specifically # exclude the "foo__in=[]" case from this handling, because # it's short-circuited in the Where class. - entry = self.where_class(connection=self.connection) + entry = self.where_class() entry.add((Constraint(alias, col, None), 'isnull', True), AND) entry.negate() self.where.add(entry, AND) @@ -2337,6 +2333,9 @@ class BaseQuery(object): self.select = [(select_alias, select_col)] self.remove_inherited_models() + def set_connection(self, connection): + self.connection = connection + def execute_sql(self, result_type=MULTI): """ Run the query against the database and returns the result(s). The diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index def1ff8ad8..51f1acafbf 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -24,8 +24,9 @@ class DeleteQuery(Query): """ assert len(self.tables) == 1, \ "Can only delete from one table at a time." - result = ['DELETE FROM %s' % self.quote_name_unless_alias(self.tables[0])] - where, params = self.where.as_sql() + qn = self.quote_name_unless_alias + result = ['DELETE FROM %s' % qn(self.tables[0])] + where, params = self.where.as_sql(qn=qn, connection=self.connection) result.append('WHERE %s' % where) return ' '.join(result), tuple(params) @@ -48,7 +49,7 @@ class DeleteQuery(Query): for related in cls._meta.get_all_related_many_to_many_objects(): if not isinstance(related.field, generic.GenericRelation): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = self.where_class(connection=self.connection) + where = self.where_class() where.add((Constraint(None, related.field.m2m_reverse_name(), related.field), 'in', @@ -57,14 +58,14 @@ class DeleteQuery(Query): self.do_query(related.field.m2m_db_table(), where) for f in cls._meta.many_to_many: - w1 = self.where_class(connection=self.connection) + w1 = self.where_class() if isinstance(f, generic.GenericRelation): from django.contrib.contenttypes.models import ContentType field = f.rel.to._meta.get_field(f.content_type_field_name) w1.add((Constraint(None, field.column, field), 'exact', ContentType.objects.get_for_model(cls).id), AND) for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = self.where_class(connection=self.connection) + where = self.where_class() where.add((Constraint(None, f.m2m_column_name(), f), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) @@ -81,7 +82,7 @@ class DeleteQuery(Query): lot of values in pk_list. """ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = self.where_class(connection=self.connection) + where = self.where_class() field = self.model._meta.pk where.add((Constraint(None, field.column, field), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) @@ -143,7 +144,10 @@ class UpdateQuery(Query): values, update_params = [], [] for name, val, placeholder in self.values: if hasattr(val, 'as_sql'): - sql, params = val.as_sql(qn) + if getattr(val, 'as_sql_takes_connection', False): + sql, params = val.as_sql(qn, self.connection) + else: + sql, params = val.as_sql(qn) values.append('%s = %s' % (qn(name), sql)) update_params.extend(params) elif val is not None: @@ -152,7 +156,7 @@ class UpdateQuery(Query): else: values.append('%s = NULL' % qn(name)) result.append(', '.join(values)) - where, params = self.where.as_sql() + where, params = self.where.as_sql(qn=qn, connection=self.connection) if where: result.append('WHERE %s' % where) return ' '.join(result), tuple(update_params + params) @@ -185,7 +189,7 @@ class UpdateQuery(Query): # Now we adjust the current query: reset the where clause and get rid # of all the tables we don't need (since they're in the sub-select). - self.where = self.where_class(connection=self.connection) + self.where = self.where_class() if self.related_updates or must_pre_select: # Either we're using the idents in multiple update queries (so # don't want them to change), or the db backend doesn't support @@ -209,7 +213,7 @@ class UpdateQuery(Query): This is used by the QuerySet.delete_objects() method. """ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - self.where = self.where_class(connection=self.connection) + self.where = self.where_class() f = self.model._meta.pk self.where.add((Constraint(None, f.column, f), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 8acbc9eef9..7336d0731c 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -32,18 +32,7 @@ class WhereNode(tree.Node): relabel_aliases() methods. """ default = AND - - def __init__(self, *args, **kwargs): - self.connection = kwargs.pop('connection', None) - super(WhereNode, self).__init__(*args, **kwargs) - - def __getstate__(self): - """ - Don't try to pickle the connection, our Query will restore it for us. - """ - data = self.__dict__.copy() - del data['connection'] - return data + as_sql_takes_connection = True def add(self, data, connector): """ @@ -62,20 +51,6 @@ class WhereNode(tree.Node): # Consume any generators immediately, so that we can determine # emptiness and transform any non-empty values correctly. value = list(value) - if hasattr(obj, "process"): - try: - # FIXME We're calling process too early, the connection could - # change - obj, params = obj.process(lookup_type, value, self.connection) - except (EmptyShortCircuit, EmptyResultSet): - # There are situations where we want to short-circuit any - # comparisons and make sure that nothing is returned. One - # example is when checking for a NULL pk value, or the - # equivalent. - super(WhereNode, self).add(NothingNode(), connector) - return - else: - params = Field().get_db_prep_lookup(lookup_type, value) # The "annotation" parameter is used to pass auxilliary information # about the value(s) to the query construction. Specifically, datetime @@ -88,18 +63,19 @@ class WhereNode(tree.Node): else: annotation = bool(value) + if hasattr(obj, "process"): + obj.validate(lookup_type, value) + super(WhereNode, self).add((obj, lookup_type, annotation, value), + connector) + return + else: + # TODO: Make this lazy just like the above code for constraints. + params = Field().get_db_prep_lookup(lookup_type, value) + super(WhereNode, self).add((obj, lookup_type, annotation, params), connector) - def update_connection(self, connection): - self.connection = connection - for child in self.children: - if hasattr(child, 'update_connection'): - child.update_connection(connection) - elif hasattr(child[3], 'update_connection'): - child[3].update_connection(connection) - - def as_sql(self, qn=None): + def as_sql(self, qn, connection): """ Returns the SQL version of the where clause and the value to be substituted in. Returns None, None if this node is empty. @@ -108,8 +84,6 @@ class WhereNode(tree.Node): (generally not needed except by the internal implementation for recursion). """ - if not qn: - qn = self.connection.ops.quote_name if not self.children: return None, [] result = [] @@ -118,10 +92,13 @@ class WhereNode(tree.Node): for child in self.children: try: if hasattr(child, 'as_sql'): - sql, params = child.as_sql(qn=qn) + if getattr(child, 'as_sql_takes_connection', False): + sql, params = child.as_sql(qn=qn, connection=connection) + else: + sql, params = child.as_sql(qn=qn) else: # A leaf node in the tree. - sql, params = self.make_atom(child, qn) + sql, params = self.make_atom(child, qn, connection) except EmptyResultSet: if self.connector == AND and not self.negated: @@ -157,7 +134,7 @@ class WhereNode(tree.Node): sql_string = '(%s)' % sql_string return sql_string, result_params - def make_atom(self, child, qn): + def make_atom(self, child, qn, connection): """ Turn a tuple (table_alias, column_name, db_type, lookup_type, value_annot, params) into valid SQL. @@ -165,29 +142,39 @@ class WhereNode(tree.Node): Returns the string for the SQL fragment and the parameters to use for it. """ - lvalue, lookup_type, value_annot, params = child + lvalue, lookup_type, value_annot, params_or_value = child + if hasattr(lvalue, 'process'): + try: + lvalue, params = lvalue.process(lookup_type, params_or_value, connection) + except EmptyShortCircuit: + raise EmptyResultSet + else: + params = params_or_value if isinstance(lvalue, tuple): # A direct database column lookup. - field_sql = self.sql_for_columns(lvalue, qn) + field_sql = self.sql_for_columns(lvalue, qn, connection) else: # A smart object with an as_sql() method. field_sql = lvalue.as_sql(quote_func=qn) if value_annot is datetime.datetime: - cast_sql = self.connection.ops.datetime_cast_sql() + cast_sql = connection.ops.datetime_cast_sql() else: cast_sql = '%s' if hasattr(params, 'as_sql'): - extra, params = params.as_sql(qn) + if getattr(params, 'as_sql_takes_connection', False): + extra, params = params.as_sql(qn, connection) + else: + extra, params = params.as_sql(qn) cast_sql = '' else: extra = '' - if lookup_type in self.connection.operators: - format = "%s %%s %%s" % (self.connection.ops.lookup_cast(lookup_type),) + if lookup_type in connection.operators: + format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) return (format % (field_sql, - self.connection.operators[lookup_type] % cast_sql, + connection.operators[lookup_type] % cast_sql, extra), params) if lookup_type == 'in': @@ -200,19 +187,19 @@ class WhereNode(tree.Node): elif lookup_type in ('range', 'year'): return ('%s BETWEEN %%s and %%s' % field_sql, params) elif lookup_type in ('month', 'day', 'week_day'): - return ('%s = %%s' % self.connection.ops.date_extract_sql(lookup_type, field_sql), + return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql), params) elif lookup_type == 'isnull': return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or '')), ()) elif lookup_type == 'search': - return (self.connection.ops.fulltext_search_sql(field_sql), params) + return (connection.ops.fulltext_search_sql(field_sql), params) elif lookup_type in ('regex', 'iregex'): - return self.connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params + return connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params raise TypeError('Invalid lookup_type: %r' % lookup_type) - def sql_for_columns(self, data, qn): + def sql_for_columns(self, data, qn, connection): """ Returns the SQL fragment used for the left-hand side of a column constraint (for example, the "T1.foo" portion in the clause @@ -223,7 +210,7 @@ class WhereNode(tree.Node): lhs = '%s.%s' % (qn(table_alias), qn(name)) else: lhs = qn(name) - return self.connection.ops.field_cast_sql(db_type) % lhs + return connection.ops.field_cast_sql(db_type) % lhs def relabel_aliases(self, change_map, node=None): """ @@ -299,3 +286,11 @@ class Constraint(object): raise EmptyShortCircuit return (self.alias, self.col, db_type), params + + def relabel_aliases(self, change_map): + if self.alias in change_map: + self.alias = change_map[self.alias] + + def validate(self, lookup_type, value): + if hasattr(self.field, 'validate'): + self.field.validate(lookup_type, value) diff --git a/django/db/transaction.py b/django/db/transaction.py index 5d80bf24f0..39aaab4482 100644 --- a/django/db/transaction.py +++ b/django/db/transaction.py @@ -20,7 +20,7 @@ try: from functools import wraps except ImportError: from django.utils.functional import wraps # Python 2.3, 2.4 fallback. -from django.db import connection +from django.db import connections from django.conf import settings class TransactionManagementError(Exception): @@ -30,17 +30,20 @@ class TransactionManagementError(Exception): """ pass -# The states are dictionaries of lists. The key to the dict is the current -# thread and the list is handled as a stack of values. +# The states are dictionaries of dictionaries of lists. The key to the outer +# dict is the current thread, and the key to the inner dictionary is the +# connection alias and the list is handled as a stack of values. state = {} savepoint_state = {} # The dirty flag is set by *_unless_managed functions to denote that the # code under transaction management has changed things to require a # database commit. +# This is a dictionary mapping thread to a dictionary mapping connection +# alias to a boolean. dirty = {} -def enter_transaction_management(managed=True): +def enter_transaction_management(managed=True, using=None): """ Enters transaction management for a running thread. It must be balanced with the appropriate leave_transaction_management call, since the actual state is @@ -50,166 +53,212 @@ def enter_transaction_management(managed=True): from the settings, if there is no surrounding block (dirty is always false when no current block is running). """ + if using is None: + raise ValueError # TODO use default + connection = connections[using] thread_ident = thread.get_ident() - if thread_ident in state and state[thread_ident]: - state[thread_ident].append(state[thread_ident][-1]) + if thread_ident in state and state[thread_ident].get(using): + state[thread_ident][using].append(state[thread_ident][using][-1]) else: - state[thread_ident] = [] - state[thread_ident].append(settings.TRANSACTIONS_MANAGED) - if thread_ident not in dirty: - dirty[thread_ident] = False + state.setdefault(thread_ident, {}) + state[thread_ident][using] = [settings.TRANSACTIONS_MANAGED] + if thread_ident not in dirty or using not in dirty[thread_ident]: + dirty.setdefault(thread_ident, {}) + dirty[thread_ident][using] = False connection._enter_transaction_management(managed) -def leave_transaction_management(): +def leave_transaction_management(using=None): """ Leaves transaction management for a running thread. A dirty flag is carried over to the surrounding block, as a commit will commit all changes, even those from outside. (Commits are on connection level.) """ - connection._leave_transaction_management(is_managed()) + if using is None: + raise ValueError # TODO use default + connection = connections[using] + connection._leave_transaction_management(is_managed(using=using)) thread_ident = thread.get_ident() - if thread_ident in state and state[thread_ident]: - del state[thread_ident][-1] + if thread_ident in state and state[thread_ident].get(using): + del state[thread_ident][using][-1] else: raise TransactionManagementError("This code isn't under transaction management") - if dirty.get(thread_ident, False): - rollback() + if dirty.get(thread_ident, {}).get(using, False): + rollback(using=using) raise TransactionManagementError("Transaction managed block ended with pending COMMIT/ROLLBACK") - dirty[thread_ident] = False + dirty[thread_ident][using] = False -def is_dirty(): +def is_dirty(using=None): """ Returns True if the current transaction requires a commit for changes to happen. """ - return dirty.get(thread.get_ident(), False) + if using is None: + raise ValueError # TODO use default + return dirty.get(thread.get_ident(), {}).get(using, False) -def set_dirty(): +def set_dirty(using=None): """ Sets a dirty flag for the current thread and code streak. This can be used to decide in a managed block of code to decide whether there are open changes waiting for commit. """ + if using is None: + raise ValueError # TODO use default thread_ident = thread.get_ident() - if thread_ident in dirty: - dirty[thread_ident] = True + if thread_ident in dirty and using in dirty[thread_ident]: + dirty[thread_ident][using] = True else: raise TransactionManagementError("This code isn't under transaction management") -def set_clean(): +def set_clean(using=None): """ Resets a dirty flag for the current thread and code streak. This can be used to decide in a managed block of code to decide whether a commit or rollback should happen. """ + if using is None: + raise ValueError # TODO use default thread_ident = thread.get_ident() - if thread_ident in dirty: - dirty[thread_ident] = False + if thread_ident in dirty and using in dirty[thread_ident]: + dirty[thread_ident][using] = False else: raise TransactionManagementError("This code isn't under transaction management") - clean_savepoints() + clean_savepoints(using=using) -def clean_savepoints(): +def clean_savepoints(using=None): + if using is None: + raise ValueError # TODO use default thread_ident = thread.get_ident() - if thread_ident in savepoint_state: - del savepoint_state[thread_ident] + if thread_ident in savepoint_state and using in savepoint_state[thread_ident]: + del savepoint_state[thread_ident][using] -def is_managed(): +def is_managed(using=None): """ Checks whether the transaction manager is in manual or in auto state. """ + if using is None: + raise ValueError # TODO use default thread_ident = thread.get_ident() - if thread_ident in state: - if state[thread_ident]: - return state[thread_ident][-1] + if thread_ident in state and using in state[thread_ident]: + if state[thread_ident][using]: + return state[thread_ident][using][-1] return settings.TRANSACTIONS_MANAGED -def managed(flag=True): +def managed(flag=True, using=None): """ Puts the transaction manager into a manual state: managed transactions have to be committed explicitly by the user. If you switch off transaction management and there is a pending commit/rollback, the data will be commited. """ + if using is None: + raise ValueError # TODO use default + connection = connections[using] thread_ident = thread.get_ident() - top = state.get(thread_ident, None) + top = state.get(thread_ident, {}).get(using, None) if top: top[-1] = flag - if not flag and is_dirty(): + if not flag and is_dirty(using=using): connection._commit() - set_clean() + set_clean(using=using) else: raise TransactionManagementError("This code isn't under transaction management") -def commit_unless_managed(): +def commit_unless_managed(using=None): """ Commits changes if the system is not in managed transaction mode. """ - if not is_managed(): + if using is None: + raise ValueError # TODO use default + connection = connections[using] + if not is_managed(using=using): connection._commit() - clean_savepoints() + clean_savepoints(using=using) else: - set_dirty() + set_dirty(using=using) -def rollback_unless_managed(): +def rollback_unless_managed(using=None): """ Rolls back changes if the system is not in managed transaction mode. """ - if not is_managed(): + if using is None: + raise ValueError # TODO use default + connection = connections[using] + if not is_managed(using=using): connection._rollback() else: - set_dirty() + set_dirty(using=using) -def commit(): +def commit(using=None): """ Does the commit itself and resets the dirty flag. """ + if using is None: + raise ValueError # TODO use default + connection = connections[using] connection._commit() - set_clean() + set_clean(using=using) -def rollback(): +def rollback(using=None): """ This function does the rollback itself and resets the dirty flag. """ + if using is None: + raise ValueError # TODO use default + connection = connections[using] connection._rollback() - set_clean() + set_clean(using=using) -def savepoint(): +def savepoint(using=None): """ Creates a savepoint (if supported and required by the backend) inside the current transaction. Returns an identifier for the savepoint that will be used for the subsequent rollback or commit. """ + if using is None: + raise ValueError # TODO use default + connection = connections[using] thread_ident = thread.get_ident() - if thread_ident in savepoint_state: - savepoint_state[thread_ident].append(None) + if thread_ident in savepoint_state and using in savepoint_state[thread_ident]: + savepoint_state[thread_ident][using].append(None) else: - savepoint_state[thread_ident] = [None] + savepoint_state.setdefault(thread_ident, {}) + savepoint_state[thread_ident][using] = [None] tid = str(thread_ident).replace('-', '') - sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident])) + sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident][using])) connection._savepoint(sid) return sid -def savepoint_rollback(sid): +def savepoint_rollback(sid, using=None): """ Rolls back the most recent savepoint (if one exists). Does nothing if savepoints are not supported. """ - if thread.get_ident() in savepoint_state: + if using is None: + raise ValueError # TODO use default + connection = connections[using] + thread_ident = thread.get_ident() + if thread_ident in savepoint_state and using in savepoint_state[thread_ident]: connection._savepoint_rollback(sid) -def savepoint_commit(sid): +def savepoint_commit(sid, using=None): """ Commits the most recent savepoint (if one exists). Does nothing if savepoints are not supported. """ - if thread.get_ident() in savepoint_state: + if using is None: + raise ValueError # TODO use default + connection = connections[using] + thread_ident = thread.get_ident() + if thread_ident in savepoint_state and using in savepoint_state[thread_ident]: connection._savepoint_commit(sid) ############## # DECORATORS # ############## +# TODO update all of these for multi-db + def autocommit(func): """ Decorator that activates commit on save. This is Django's default behavior; diff --git a/django/test/testcases.py b/django/test/testcases.py index 47e9368e6d..b891813a9e 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -201,7 +201,8 @@ class DocTestRunner(doctest.DocTestRunner): example, exc_info) # Rollback, in case of database errors. Otherwise they'd have # side effects on other tests. - transaction.rollback_unless_managed() + for conn in connections: + transaction.rollback_unless_managed(using=conn) class TransactionTestCase(unittest.TestCase): def _pre_setup(self): @@ -446,8 +447,9 @@ class TestCase(TransactionTestCase): if not connections_support_transactions(): return super(TestCase, self)._fixture_setup() - transaction.enter_transaction_management() - transaction.managed(True) + for conn in connections: + transaction.enter_transaction_management(using=conn) + transaction.managed(True, using=conn) disable_transaction_methods() from django.contrib.sites.models import Site @@ -464,7 +466,8 @@ class TestCase(TransactionTestCase): return super(TestCase, self)._fixture_teardown() restore_transaction_methods() - transaction.rollback() - transaction.leave_transaction_management() + for conn in connections: + transaction.rollback(using=conn) + transaction.leave_transaction_management(using=conn) for connection in connections.all(): connection.close() diff --git a/tests/regressiontests/extra_regress/models.py b/tests/regressiontests/extra_regress/models.py index 5d22d6cc07..f208a2febc 100644 --- a/tests/regressiontests/extra_regress/models.py +++ b/tests/regressiontests/extra_regress/models.py @@ -15,11 +15,11 @@ class RevisionableModel(models.Model): def __unicode__(self): return u"%s (%s, %s)" % (self.title, self.id, self.base.id) - def save(self, force_insert=False, force_update=False): - super(RevisionableModel, self).save(force_insert, force_update) + def save(self, *args, **kwargs): + super(RevisionableModel, self).save(*args, **kwargs) if not self.base: self.base = self - super(RevisionableModel, self).save() + super(RevisionableModel, self).save(using=kwargs.pop('using', None)) def new_revision(self): new_revision = copy.copy(self)