mirror of
https://github.com/django/django.git
synced 2025-07-05 10:19:20 +00:00
[soc2009/multidb] Bring this branch up to date with my external work. This means implementing the using method on querysets as well as a using kwarg on save and delete, plus many internal changes to facilitate this.
git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@10904 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent f4bcbbfa8b
commit 23da5c0ac1
TODO.TXT (35 changed lines)
@@ -4,6 +4,11 @@ TODO
 The follow is a list, more or less in the order I intend to do them of things
 that need to be done. I'm trying to be as granular as possible.
 
+***
+Immediate TODOs:
+Finish refactor of WhereNode so that it just takes connection in as_sql.
+***
+
 2) Update all old references to ``settings.DATABASE_*`` to reference
 ``settings.DATABASES``. This includes the following locations
 
@@ -20,29 +25,21 @@ that need to be done. I'm trying to be as granular as possible.
 
 * ``dumpdata``: By default dump the ``default`` database. Later add a
 ``--database`` flag.
-* ``loaddata``: Leave as is, will use the standard mechanisms for
-determining what DB to save the objects to. Should it get a
-``--database`` flag to overide that?
 
 These items will be fixed pending both community consensus, and the API
-that will go in that's actually necessary for these to happen. Due to
-internal APIs loaddata probably will need an update to load stuff into a
-specific DB.
+that will go in that's actually necessary for these to happen.
 
-4) Rig up the test harness to work with multiple databases. This includes:
+flush, reset, and syncdb need to not prompt the user multiple times.
 
-* The current strategy is to test on N dbs, where N is however many the
-user defines and ensuring the data all stays seperate and no exceptions
-are raised. Practically speaking this means we're only going to have
-good coverage if we write a lot of tests that can break. That's life.
+9) Fix transaction support. In practice this means changing all the
+dictionaries that currently map thread to a boolean to being a dictionary
+mapping thread to a set of the affected DBs, and changing all the functions
+that use these dictionaries to handle the changes appropriately.
 
 7) Remove any references to the global ``django.db.connection`` object in the
 SQL creation process. This includes(but is probably not limited to):
 
 * The way we create ``Query`` from ``BaseQuery`` is awkward and hacky.
-* ``django.db.models.query.delete_objects``
-* ``django.db.models.query.insert_query``
-* ``django.db.models.base.Model`` -- in ``save_base``
 * ``django.db.models.fields.Field`` This uses it, as do it's subclasses.
 * ``django.db.models.fields.related`` It's used all over the place here,
 including opening a cursor and executing queries, so that's going to
@@ -50,13 +47,9 @@ that need to be done. I'm trying to be as granular as possible.
 raw SQL and execution to ``Query``/``QuerySet`` so hopefully that makes
 it in before I need to tackle this.
 
 
 5) Add the ``using`` Meta option. Tests and docs(these are to be assumed at
 each stage from here on out).
-5) Implement using kwarg on save() method.
-6) Add the ``using`` method to ``QuerySet``. This will more or less "just
-work" across multiple databases that use the same backend. However, it
-will fail gratuitously when trying to use 2 different backends.
 
 8) Implement some way to create a new ``Query`` for a different backend when
 we switch. There are several checks against ``self.connection`` prior to
 SQL construction, so we either need to defer all these(which will be
@@ -80,8 +73,4 @@ that need to be done. I'm trying to be as granular as possible.
 every single one of them, then when it's time excecute the query just
 pick the right ``Query`` object to use. This *does* not scale, though it
 could probably be done fairly easily.
-9) Fix transaction support. In practice this means changing all the
-dictionaries that currently map thread to a boolean to being a dictionary
-mapping thread to a set of the affected DBs, and changing all the functions
-that use these dictionaries to handle the changes appropriately.
 10) Time permitting add support for a ``DatabaseManager``.
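TODO item 9 describes replacing the per-thread boolean "dirty" flag with a per-thread set of affected database aliases. A minimal sketch of that data-structure change, in plain Python (the names `dirty`, `set_dirty`, `is_dirty`, and `clean_all` are illustrative, not Django's actual internals):

```python
import threading

# thread ident -> set of db aliases with pending changes (was: thread -> bool)
dirty = {}

def set_dirty(using):
    # Record that the current thread has modified the given database alias.
    thread_id = threading.current_thread().ident
    dirty.setdefault(thread_id, set()).add(using)

def is_dirty(using):
    # A thread is only "dirty" with respect to a specific alias now.
    thread_id = threading.current_thread().ident
    return using in dirty.get(thread_id, set())

def clean_all():
    # Clear all pending-change state for the current thread.
    dirty.pop(threading.current_thread().ident, None)

set_dirty('default')
set_dirty('replica')
assert is_dirty('default') and is_dirty('replica')
assert not is_dirty('other')
```

The point of the set is that commit/rollback helpers can then act per alias instead of assuming a single global connection.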
@@ -151,7 +151,7 @@ class GenericRelation(RelatedField, Field):
     def get_internal_type(self):
         return "ManyToManyField"
 
-    def db_type(self):
+    def db_type(self, connection):
         # Since we're simulating a ManyToManyField, in effect, best return the
         # same db_type as well.
         return None
@@ -1,28 +1,27 @@
-from django.conf import settings
 from django.db.models.fields import Field
 
 class USStateField(Field):
     def get_internal_type(self):
         return "USStateField"
 
-    def db_type(self):
-        if settings.DATABASE_ENGINE == 'oracle':
+    def db_type(self, connection):
+        if connection.settings_dict['DATABASE_ENGINE'] == 'oracle':
             return 'CHAR(2)'
         else:
             return 'varchar(2)'
 
     def formfield(self, **kwargs):
         from django.contrib.localflavor.us.forms import USStateSelect
         defaults = {'widget': USStateSelect}
         defaults.update(kwargs)
         return super(USStateField, self).formfield(**defaults)
 
 class PhoneNumberField(Field):
     def get_internal_type(self):
         return "PhoneNumberField"
 
-    def db_type(self):
-        if settings.DATABASE_ENGINE == 'oracle':
+    def db_type(self, connection):
+        if connection.settings_dict['DATABASE_ENGINE'] == 'oracle':
             return 'VARCHAR2(20)'
         else:
             return 'varchar(20)'
@@ -32,4 +31,3 @@ class PhoneNumberField(Field):
         defaults = {'form_class': USPhoneNumberField}
         defaults.update(kwargs)
         return super(PhoneNumberField, self).formfield(**defaults)
-
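The change above is the heart of the multidb refactor: ``db_type`` now receives the connection it is generating SQL for, instead of consulting the single global ``settings.DATABASE_ENGINE``. A self-contained sketch of the same dispatch, with ``FakeConnection`` as a stand-in for a real Django connection object:

```python
class FakeConnection(object):
    """Stand-in for a Django connection; only settings_dict is modeled."""
    def __init__(self, engine):
        self.settings_dict = {'DATABASE_ENGINE': engine}

class USStateFieldSketch(object):
    def db_type(self, connection):
        # Column type is chosen per connection, so two databases with
        # different backends can coexist in one process.
        if connection.settings_dict['DATABASE_ENGINE'] == 'oracle':
            return 'CHAR(2)'
        return 'varchar(2)'

field = USStateFieldSketch()
assert field.db_type(FakeConnection('oracle')) == 'CHAR(2)'
assert field.db_type(FakeConnection('postgresql')) == 'varchar(2)'
```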
@@ -2,7 +2,7 @@ import datetime
 from django.contrib.sessions.models import Session
 from django.contrib.sessions.backends.base import SessionBase, CreateError
 from django.core.exceptions import SuspiciousOperation
-from django.db import IntegrityError, transaction
+from django.db import IntegrityError, transaction, DEFAULT_DB_ALIAS
 from django.utils.encoding import force_unicode
 
 class SessionStore(SessionBase):
@@ -53,12 +53,13 @@ class SessionStore(SessionBase):
             session_data = self.encode(self._get_session(no_load=must_create)),
             expire_date = self.get_expiry_date()
         )
-        sid = transaction.savepoint()
+        # TODO update for multidb
+        sid = transaction.savepoint(using=DEFAULT_DB_ALIAS)
         try:
             obj.save(force_insert=must_create)
         except IntegrityError:
             if must_create:
-                transaction.savepoint_rollback(sid)
+                transaction.savepoint_rollback(sid, using=DEFAULT_DB_ALIAS)
                 raise CreateError
             raise
 
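The session backend's savepoint-then-rollback-on-IntegrityError pattern can be illustrated with the stdlib ``sqlite3`` module instead of Django's transaction API (the table and key here are illustrative):

```python
import sqlite3

# Autocommit mode; SAVEPOINT opens its own transaction in SQLite.
conn = sqlite3.connect(':memory:', isolation_level=None)
conn.execute("CREATE TABLE session (key TEXT PRIMARY KEY)")
conn.execute("INSERT INTO session VALUES ('abc')")

conn.execute("SAVEPOINT sid")        # like transaction.savepoint(using=...)
try:
    conn.execute("INSERT INTO session VALUES ('abc')")  # duplicate key
except sqlite3.IntegrityError:
    # Like transaction.savepoint_rollback(sid, using=...): undo only the
    # failed insert, leaving earlier work intact.
    conn.execute("ROLLBACK TO SAVEPOINT sid")
conn.execute("RELEASE SAVEPOINT sid")

rows = conn.execute("SELECT COUNT(*) FROM session").fetchone()[0]
assert rows == 1  # original row survives, failed insert was rolled back
```

Pinning ``using=DEFAULT_DB_ALIAS`` in the real code is an interim step; the ``# TODO update for multidb`` comment marks that the session store should eventually pick its alias properly.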
@@ -6,12 +6,12 @@ from django.db.models import signals
 from django.contrib.sites.models import Site
 from django.contrib.sites import models as site_app
 
-def create_default_site(app, created_models, verbosity, **kwargs):
+def create_default_site(app, created_models, verbosity, db, **kwargs):
     if Site in created_models:
         if verbosity >= 2:
             print "Creating example.com Site object"
         s = Site(domain="example.com", name="example.com")
-        s.save()
+        s.save(using=db)
     Site.objects.clear_cache()
 
 signals.post_syncdb.connect(create_default_site, sender=site_app)
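The handler above grows a ``db`` argument because ``post_syncdb`` receivers now need to know which database alias was just synced. A toy dispatcher showing the shape of that contract (this stands in for ``django.dispatch``; the function names are illustrative):

```python
# Minimal observer pattern: receivers get the db alias as a keyword argument.
receivers = []

def connect(func):
    receivers.append(func)

def send_post_syncdb(db, **kwargs):
    # Every receiver is called with the alias of the database just synced.
    return [func(db=db, **kwargs) for func in receivers]

def create_default_site(db, **kwargs):
    # Stand-in for the real handler: it would call s.save(using=db).
    return 'saved default Site using %s' % db

connect(create_default_site)
results = send_post_syncdb(db='replica')
assert results == ['saved default Site using replica']
```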
@@ -27,9 +27,9 @@ class Command(LabelCommand):
         )
         table_output = []
         index_output = []
-        qn = connections.ops.quote_name
+        qn = connection.ops.quote_name
         for f in fields:
-            field_output = [qn(f.name), f.db_type()]
+            field_output = [qn(f.name), f.db_type(connection)]
             field_output.append("%sNULL" % (not f.null and "NOT " or ""))
             if f.primary_key:
                 field_output.append("PRIMARY KEY")
@@ -15,14 +15,14 @@ class Command(NoArgsCommand):
         make_option('--noinput', action='store_false', dest='interactive', default=True,
             help='Tells Django to NOT prompt the user for input of any kind.'),
         make_option('--database', action='store', dest='database',
-            default='', help='Nominates a database to flush. Defaults to '
+            default=None, help='Nominates a database to flush. Defaults to '
             'flushing all databases.'),
     )
     help = "Executes ``sqlflush`` on the current database."
 
     def handle_noargs(self, **options):
         if not options['database']:
-            dbs = connections.all()
+            dbs = connections
         else:
             dbs = [options['database']]
 
@@ -31,57 +31,14 @@ class Command(NoArgsCommand):
 
         self.style = no_style()
 
-        # Import the 'management' module within each installed app, to register
-        # dispatcher events.
-        for app_name in settings.INSTALLED_APPS:
-            try:
-                import_module('.management', app_name)
-            except ImportError:
-                pass
-
-        sql_list = sql_flush(self.style, connection, only_django=True)
-
-        if interactive:
-            confirm = raw_input("""You have requested a flush of the database.
-This will IRREVERSIBLY DESTROY all data currently in the %r database,
-and return each table to the state it was in after syncdb.
-Are you sure you want to do this?
-
-    Type 'yes' to continue, or 'no' to cancel: """ % connection.settings_dict['DATABASE_NAME'])
-        else:
-            confirm = 'yes'
-
-        if confirm == 'yes':
-            try:
-                cursor = connection.cursor()
-                for sql in sql_list:
-                    cursor.execute(sql)
-            except Exception, e:
-                transaction.rollback_unless_managed()
-                raise CommandError("""Database %s couldn't be flushed. Possible reasons:
-  * The database isn't running or isn't configured correctly.
-  * At least one of the expected database tables doesn't exist.
-  * The SQL was invalid.
-Hint: Look at the output of 'django-admin.py sqlflush'. That's the SQL this command wasn't able to run.
-The full error: %s""" % (connection.settings_dict.DATABASE_NAME, e))
-            transaction.commit_unless_managed()
-
-            # Emit the post sync signal. This allows individual
-            # applications to respond as if the database had been
-            # sync'd from scratch.
-            emit_post_sync_signal(models.get_models(), verbosity, interactive, connection)
-
-            # Reinstall the initial_data fixture.
-            call_command('loaddata', 'initial_data', **options)
-
         for app_name in settings.INSTALLED_APPS:
             try:
                 import_module('.management', app_name)
             except ImportError:
                 pass
 
-        for connection in dbs:
+        for db in dbs:
+            connection = connections[db]
             # Import the 'management' module within each installed app, to register
             # dispatcher events.
             sql_list = sql_flush(self.style, connection, only_django=True)
@@ -102,22 +59,24 @@ class Command(NoArgsCommand):
                 for sql in sql_list:
                     cursor.execute(sql)
             except Exception, e:
-                transaction.rollback_unless_managed()
+                transaction.rollback_unless_managed(using=db)
                 raise CommandError("""Database %s couldn't be flushed. Possible reasons:
   * The database isn't running or isn't configured correctly.
   * At least one of the expected database tables doesn't exist.
   * The SQL was invalid.
 Hint: Look at the output of 'django-admin.py sqlflush'. That's the SQL this command wasn't able to run.
 The full error: %s""" % (connection.settings_dict.DATABASE_NAME, e))
-            transaction.commit_unless_managed()
+            transaction.commit_unless_managed(using=db)
 
             # Emit the post sync signal. This allows individual
             # applications to respond as if the database had been
             # sync'd from scratch.
-            emit_post_sync_signal(models.get_models(), verbosity, interactive, connection)
+            emit_post_sync_signal(models.get_models(), verbosity, interactive, db)
 
             # Reinstall the initial_data fixture.
-            call_command('loaddata', 'initial_data', **options)
+            kwargs = options.copy()
+            kwargs['database'] = db
+            call_command('loaddata', 'initial_data', **kwargs)
 
         else:
             print "Flush cancelled."
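The new flush flow iterates every configured alias when no ``--database`` is given, and pins each follow-up ``loaddata`` call to the alias just flushed via a copied options dict. A toy version of that loop, with stand-ins for ``connections`` and ``call_command``:

```python
# Illustrative stand-ins; the real objects live in django.db / django.core.
connections = {'default': object(), 'replica': object()}

calls = []
def call_command(name, *args, **kwargs):
    # Record which database each command invocation targeted.
    calls.append((name, kwargs.get('database')))

def flush(options):
    # No --database given: operate on every configured alias.
    dbs = [options['database']] if options['database'] else list(connections)
    for db in dbs:
        kwargs = options.copy()
        kwargs['database'] = db  # pin loaddata to the db just flushed
        call_command('loaddata', 'initial_data', **kwargs)

flush({'database': None})
assert sorted(db for _, db in calls) == ['default', 'replica']
```

Copying the options dict matters: mutating the shared ``options`` in place would leak the last alias into later iterations and into the caller.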
@@ -4,8 +4,13 @@ import gzip
 import zipfile
 from optparse import make_option
 
+from django.conf import settings
+from django.core import serializers
 from django.core.management.base import BaseCommand
 from django.core.management.color import no_style
+from django.db import connections, transaction, DEFAULT_DB_ALIAS
+from django.db.models import get_apps
+
 
 try:
     set
@@ -22,12 +27,15 @@ class Command(BaseCommand):
     help = 'Installs the named fixture(s) in the database.'
     args = "fixture [fixture ...]"
 
-    def handle(self, *fixture_labels, **options):
-        from django.db.models import get_apps
-        from django.core import serializers
-        from django.db import connection, transaction
-        from django.conf import settings
+    option_list = BaseCommand.option_list + (
+        make_option('--database', action='store', dest='database',
+            default=DEFAULT_DB_ALIAS, help='Nominates a specific database to load '
+            'fixtures into. By default uses the "default" database.'),
+    )
 
+    def handle(self, *fixture_labels, **options):
+        using = options['database']
+        connection = connections[using]
         self.style = no_style()
 
         verbosity = int(options.get('verbosity', 1))
@@ -56,9 +64,9 @@ class Command(BaseCommand):
         # Start transaction management. All fixtures are installed in a
         # single transaction to ensure that all references are resolved.
         if commit:
-            transaction.commit_unless_managed()
-            transaction.enter_transaction_management()
-            transaction.managed(True)
+            transaction.commit_unless_managed(using=using)
+            transaction.enter_transaction_management(using=using)
+            transaction.managed(True, using=using)
 
 class SingleZipReader(zipfile.ZipFile):
     def __init__(self, *args, **kwargs):
@@ -103,8 +111,8 @@ class Command(BaseCommand):
                 sys.stderr.write(
                     self.style.ERROR("Problem installing fixture '%s': %s is not a known serialization format." %
                         (fixture_name, format)))
-                transaction.rollback()
-                transaction.leave_transaction_management()
+                transaction.rollback(using=using)
+                transaction.leave_transaction_management(using=using)
                 return
 
             if os.path.isabs(fixture_name):
@@ -119,25 +127,25 @@ class Command(BaseCommand):
                 label_found = False
                 for format in formats:
                     for compression_format in compression_formats:
                         if compression_format:
                             file_name = '.'.join([fixture_name, format,
                                                   compression_format])
                         else:
                             file_name = '.'.join([fixture_name, format])
 
                         if verbosity > 1:
                             print "Trying %s for %s fixture '%s'..." % \
                                 (humanize(fixture_dir), file_name, fixture_name)
                         full_path = os.path.join(fixture_dir, file_name)
                         open_method = compression_types[compression_format]
                         try:
                             fixture = open_method(full_path, 'r')
                             if label_found:
                                 fixture.close()
                                 print self.style.ERROR("Multiple fixtures named '%s' in %s. Aborting." %
                                     (fixture_name, humanize(fixture_dir)))
-                                transaction.rollback()
-                                transaction.leave_transaction_management()
+                                transaction.rollback(using=using)
+                                transaction.leave_transaction_management(using=using)
                                 return
                             else:
                                 fixture_count += 1
@@ -150,7 +158,7 @@ class Command(BaseCommand):
                                 for obj in objects:
                                     objects_in_fixture += 1
                                     models.add(obj.object.__class__)
-                                    obj.save()
+                                    obj.save(using=using)
                                 object_count += objects_in_fixture
                                 label_found = True
                             except (SystemExit, KeyboardInterrupt):
@@ -158,15 +166,15 @@ class Command(BaseCommand):
                             except Exception:
                                 import traceback
                                 fixture.close()
-                                transaction.rollback()
-                                transaction.leave_transaction_management()
+                                transaction.rollback(using=using)
+                                transaction.leave_transaction_management(using=using)
                                 if show_traceback:
                                     traceback.print_exc()
                                 else:
                                     sys.stderr.write(
                                         self.style.ERROR("Problem installing fixture '%s': %s\n" %
                                             (full_path, ''.join(traceback.format_exception(sys.exc_type,
                                                 sys.exc_value, sys.exc_traceback)))))
                                 return
                             fixture.close()
 
@@ -176,8 +184,8 @@ class Command(BaseCommand):
             sys.stderr.write(
                 self.style.ERROR("No fixture data found for '%s'. (File format may be invalid.)" %
                     (fixture_name)))
-            transaction.rollback()
-            transaction.leave_transaction_management()
+            transaction.rollback(using=using)
+            transaction.leave_transaction_management(using=using)
             return
 
         except Exception, e:
@@ -196,8 +204,8 @@ class Command(BaseCommand):
                     cursor.execute(line)
 
             if commit:
-                transaction.commit(using=using)
-                transaction.leave_transaction_management()
+                transaction.commit(using=using)
+                transaction.leave_transaction_management(using=using)
 
         if object_count == 0:
             if verbosity > 1:
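The ``--database`` option added to loaddata above is plain ``optparse`` machinery. Its behavior in isolation (``DEFAULT_DB_ALIAS`` hardcoded here for illustration; in Django it comes from ``django.db``):

```python
from optparse import make_option, OptionParser

DEFAULT_DB_ALIAS = 'default'  # stand-in for django.db.DEFAULT_DB_ALIAS

# Same option definition as the command's option_list entry.
option = make_option('--database', action='store', dest='database',
    default=DEFAULT_DB_ALIAS, help='Nominates a specific database to load '
    'fixtures into. By default uses the "default" database.')

parser = OptionParser(option_list=[option])

# Without the flag, the default alias is used.
opts, _ = parser.parse_args([])
assert opts.database == 'default'

# With the flag, the named alias wins.
opts, _ = parser.parse_args(['--database', 'replica'])
assert opts.database == 'replica'
```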
@@ -10,7 +10,7 @@ class Command(AppCommand):
         make_option('--database', action='store', dest='database',
             default='default', help='Nominates a database to print the SQL '
             'for. Defaults to the "default" database.'),
     )
 
     output_transaction = True
 
@@ -32,9 +32,9 @@ class Command(NoArgsCommand):
         self.style = no_style()
 
         if not options['database']:
-            dbs = connections.all()
+            dbs = connections
         else:
-            dbs = [connections[options['database']]]
+            dbs = [options['database']]
 
         # Import the 'management' module within each installed app, to register
         # dispatcher events.
@@ -55,7 +55,8 @@ class Command(NoArgsCommand):
             if not msg.startswith('No module named') or 'management' not in msg:
                 raise
 
-        for connection in dbs:
+        for db in dbs:
+            connection = connections[db]
             cursor = connection.cursor()
 
             # Get a list of already installed *models* so that references work right.
@@ -102,11 +103,11 @@ class Command(NoArgsCommand):
             for statement in sql:
                 cursor.execute(statement)
 
-            transaction.commit_unless_managed()
+            transaction.commit_unless_managed(using=db)
 
             # Send the post_syncdb signal, so individual apps can do whatever they need
             # to do at this point.
-            emit_post_sync_signal(created_models, verbosity, interactive, connection)
+            emit_post_sync_signal(created_models, verbosity, interactive, db)
 
             # The connection may have been closed by a syncdb handler.
             cursor = connection.cursor()
@@ -130,9 +131,9 @@ class Command(NoArgsCommand):
                         if show_traceback:
                             import traceback
                             traceback.print_exc()
-                        transaction.rollback_unless_managed()
+                        transaction.rollback_unless_managed(using=db)
                     else:
-                        transaction.commit_unless_managed()
+                        transaction.commit_unless_managed(using=db)
                 else:
                     if verbosity >= 2:
                         print "No custom SQL for %s.%s model" % (app_name, model._meta.object_name)
@@ -151,13 +152,9 @@ class Command(NoArgsCommand):
                     except Exception, e:
                         sys.stderr.write("Failed to install index for %s.%s model: %s\n" % \
                             (app_name, model._meta.object_name, e))
-                        transaction.rollback_unless_managed()
+                        transaction.rollback_unless_managed(using=db)
                     else:
-                        transaction.commit_unless_managed()
+                        transaction.commit_unless_managed(using=db)
 
-        # Install the 'initial_data' fixture, using format discovery
-        # FIXME we only load the fixture data for one DB right now, since we
-        # can't control what DB it does into, once we can control this we
-        # should move it back into the DB loop
-        from django.core.management import call_command
-        call_command('loaddata', 'initial_data', verbosity=verbosity)
+            from django.core.management import call_command
+            call_command('loaddata', 'initial_data', verbosity=verbosity, database=db)
@@ -188,7 +188,7 @@ def custom_sql_for_model(model, style, connection):
     return output
 
 
-def emit_post_sync_signal(created_models, verbosity, interactive, connection):
+def emit_post_sync_signal(created_models, verbosity, interactive, db):
     # Emit the post_sync signal for every application.
     for app in models.get_apps():
         app_name = app.__name__.split('.')[-2]
@@ -196,4 +196,4 @@ def emit_post_sync_signal(created_models, verbosity, interactive, db):
         print "Running post-sync handlers for application", app_name
     models.signals.post_syncdb.send(sender=app, app=app,
         created_models=created_models, verbosity=verbosity,
-        interactive=interactive, connection=connection)
+        interactive=interactive, db=db)
@@ -154,13 +154,13 @@ class DeserializedObject(object):
     def __repr__(self):
         return "<DeserializedObject: %s>" % smart_str(self.object)
 
-    def save(self, save_m2m=True):
+    def save(self, save_m2m=True, using=None):
         # Call save on the Model baseclass directly. This bypasses any
         # model-defined save. The save is also forced to be raw.
         # This ensures that the data that is deserialized is literally
        # what came from the file, not post-processed by pre_save/save
         # methods.
-        models.Model.save_base(self.object, raw=True)
+        models.Model.save_base(self.object, using=using, raw=True)
         if self.m2m_data and save_m2m:
             for accessor_name, object_list in self.m2m_data.items():
                 setattr(self.object, accessor_name, object_list)
@@ -54,12 +54,13 @@ def reset_queries(**kwargs):
     connection.queries = []
 signals.request_started.connect(reset_queries)

-# Register an event that rolls back the connection
+# Register an event that rolls back the connections
 # when a Django request has an exception.
 def _rollback_on_exception(**kwargs):
     from django.db import transaction
-    try:
-        transaction.rollback_unless_managed()
-    except DatabaseError:
-        pass
+    for conn in connections:
+        try:
+            transaction.rollback_unless_managed(using=conn)
+        except DatabaseError:
+            pass
 signals.got_request_exception.connect(_rollback_on_exception)
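The hunk above swaps a single global rollback for a loop over every configured connection, swallowing a failed rollback so it cannot block the rest. A standalone sketch of that pattern, with `DatabaseError` and the `rollback` callable as illustrative stand-ins rather than Django's real objects:

```python
class DatabaseError(Exception):
    pass

def rollback_all(conn_aliases, rollback):
    # Roll back each connection; one failing rollback must not
    # prevent the remaining connections from being rolled back.
    rolled_back = []
    for conn in conn_aliases:
        try:
            rollback(conn)
            rolled_back.append(conn)
        except DatabaseError:
            pass
    return rolled_back
```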
@@ -90,7 +90,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         super(DatabaseWrapper, self).__init__(*args, **kwargs)

         self.features = DatabaseFeatures()
-        self.ops = DatabaseOperations()
+        self.ops = DatabaseOperations(self)
         self.client = DatabaseClient(self)
         self.creation = DatabaseCreation(self)
         self.introspection = DatabaseIntrospection(self)
@@ -6,14 +6,14 @@ from django.db.backends import BaseDatabaseOperations
 # used by both the 'postgresql' and 'postgresql_psycopg2' backends.

 class DatabaseOperations(BaseDatabaseOperations):
-    def __init__(self):
+    def __init__(self, connection):
         self._postgres_version = None
+        self.connection = connection

     def _get_postgres_version(self):
         if self._postgres_version is None:
-            from django.db import connection
             from django.db.backends.postgresql.version import get_version
-            cursor = connection.cursor()
+            cursor = self.connection.cursor()
             self._postgres_version = get_version(cursor)
         return self._postgres_version
     postgres_version = property(_get_postgres_version)
@@ -66,7 +66,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         autocommit = self.settings_dict["DATABASE_OPTIONS"].get('autocommit', False)
         self.features.uses_autocommit = autocommit
         self._set_isolation_level(int(not autocommit))
-        self.ops = DatabaseOperations()
+        self.ops = DatabaseOperations(self)
         self.client = DatabaseClient(self)
         self.creation = DatabaseCreation(self)
         self.introspection = DatabaseIntrospection(self)
@@ -15,7 +15,7 @@ from django.db.models.fields.related import OneToOneRel, ManyToOneRel, OneToOneF
 from django.db.models.query import delete_objects, Q
 from django.db.models.query_utils import CollectedObjects, DeferredAttribute
 from django.db.models.options import Options
-from django.db import connection, transaction, DatabaseError
+from django.db import connections, transaction, DatabaseError, DEFAULT_DB_ALIAS
 from django.db.models import signals
 from django.db.models.loading import register_models, get_model
 from django.utils.functional import curry
@@ -395,7 +395,7 @@ class Model(object):
             return getattr(self, field_name)
         return getattr(self, field.attname)

-    def save(self, force_insert=False, force_update=False):
+    def save(self, force_insert=False, force_update=False, using=None):
         """
         Saves the current instance. Override this in a subclass if you want to
         control the saving process.
|
|||||||
if force_insert and force_update:
|
if force_insert and force_update:
|
||||||
raise ValueError("Cannot force both insert and updating in "
|
raise ValueError("Cannot force both insert and updating in "
|
||||||
"model saving.")
|
"model saving.")
|
||||||
self.save_base(force_insert=force_insert, force_update=force_update)
|
self.save_base(using=using, force_insert=force_insert, force_update=force_update)
|
||||||
|
|
||||||
save.alters_data = True
|
save.alters_data = True
|
||||||
|
|
||||||
def save_base(self, raw=False, cls=None, force_insert=False,
|
def save_base(self, raw=False, cls=None, force_insert=False,
|
||||||
force_update=False):
|
force_update=False, using=None):
|
||||||
"""
|
"""
|
||||||
Does the heavy-lifting involved in saving. Subclasses shouldn't need to
|
Does the heavy-lifting involved in saving. Subclasses shouldn't need to
|
||||||
override this method. It's separate from save() in order to hide the
|
override this method. It's separate from save() in order to hide the
|
||||||
need for overrides of save() to pass around internal-only parameters
|
need for overrides of save() to pass around internal-only parameters
|
||||||
('raw' and 'cls').
|
('raw' and 'cls').
|
||||||
"""
|
"""
|
||||||
|
if using is None:
|
||||||
|
using = DEFAULT_DB_ALIAS
|
||||||
|
connection = connections[using]
|
||||||
assert not (force_insert and force_update)
|
assert not (force_insert and force_update)
|
||||||
if not cls:
|
if not cls:
|
||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
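The `save_base()` change above introduces the idiom the rest of the branch repeats: default the alias to `DEFAULT_DB_ALIAS`, then look the connection up by name. A minimal sketch of that resolution step, where `DEFAULT_DB_ALIAS` mirrors the branch's constant but the contents of `connections` are an illustrative stand-in for the settings-driven registry:

```python
DEFAULT_DB_ALIAS = 'default'

# Hypothetical registry; in Django this is built from settings.DATABASES.
connections = {
    'default': 'default-connection',
    'replica': 'replica-connection',
}

def resolve_connection(using=None):
    # Fall back to the default alias when no database is specified.
    if using is None:
        using = DEFAULT_DB_ALIAS
    return connections[using]
```

Callers that pass no alias get the default database, so existing single-database code keeps working unchanged.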
@@ -441,7 +444,7 @@ class Model(object):
                 if field and getattr(self, parent._meta.pk.attname) is None and getattr(self, field.attname) is not None:
                     setattr(self, parent._meta.pk.attname, getattr(self, field.attname))

-                self.save_base(cls=parent)
+                self.save_base(cls=parent, using=using)
                 if field:
                     setattr(self, field.attname, self._get_pk_val(parent._meta))
             if meta.proxy:
@@ -457,12 +460,13 @@ class Model(object):
             manager = cls._base_manager
             if pk_set:
                 # Determine whether a record with the primary key already exists.
+                # FIXME work with the using parameter
                 if (force_update or (not force_insert and
-                        manager.filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by())):
+                        manager.using(using).filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by())):
                     # It does already exist, so do an UPDATE.
                     if force_update or non_pks:
                         values = [(f, None, (raw and getattr(self, f.attname) or f.pre_save(self, False))) for f in non_pks]
-                        rows = manager.filter(pk=pk_val)._update(values)
+                        rows = manager.using(using).filter(pk=pk_val)._update(values)
                         if force_update and not rows:
                             raise DatabaseError("Forced update did not affect any rows.")
             else:
@@ -477,20 +481,20 @@ class Model(object):

             if meta.order_with_respect_to:
                 field = meta.order_with_respect_to
-                values.append((meta.get_field_by_name('_order')[0], manager.filter(**{field.name: getattr(self, field.attname)}).count()))
+                values.append((meta.get_field_by_name('_order')[0], manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count()))
             record_exists = False

         update_pk = bool(meta.has_auto_field and not pk_set)
         if values:
             # Create a new record.
-            result = manager._insert(values, return_id=update_pk)
+            result = manager._insert(values, return_id=update_pk, using=using)
         else:
             # Create a new record with defaults for everything.
-            result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True)
+            result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True, using=using)

         if update_pk:
             setattr(self, meta.pk.attname, result)
-        transaction.commit_unless_managed()
+        transaction.commit_unless_managed(using=using)

         if signal:
             signals.post_save.send(sender=self.__class__, instance=self,
@@ -549,7 +553,10 @@ class Model(object):
                 # delete it and all its descendents.
                 parent_obj._collect_sub_objects(seen_objs)

-    def delete(self):
+    def delete(self, using=None):
+        if using is None:
+            using = DEFAULT_DB_ALIAS
+        connection = connections[using]
         assert self._get_pk_val() is not None, "%s object can't be deleted because its %s attribute is set to None." % (self._meta.object_name, self._meta.pk.attname)

         # Find all the objects than need to be deleted.
@@ -557,7 +564,7 @@ class Model(object):
         self._collect_sub_objects(seen_objs)

         # Actually delete the objects.
-        delete_objects(seen_objs)
+        delete_objects(seen_objs, using)

     delete.alters_data = True

@@ -610,7 +617,8 @@ def method_set_order(ordered_obj, self, id_list):
     # for situations like this.
     for i, j in enumerate(id_list):
         ordered_obj.objects.filter(**{'pk': j, order_name: rel_val}).update(_order=i)
-    transaction.commit_unless_managed()
+    # TODO
+    transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS)


 def method_get_order(ordered_obj, self):
@@ -41,8 +41,8 @@ class ExpressionNode(tree.Node):
     def prepare(self, evaluator, query, allow_joins):
         return evaluator.prepare_node(self, query, allow_joins)

-    def evaluate(self, evaluator, qn):
-        return evaluator.evaluate_node(self, qn)
+    def evaluate(self, evaluator, qn, connection):
+        return evaluator.evaluate_node(self, qn, connection)

 #############
 # OPERATORS #
@@ -109,5 +109,5 @@ class F(ExpressionNode):
     def prepare(self, evaluator, query, allow_joins):
         return evaluator.prepare_leaf(self, query, allow_joins)

-    def evaluate(self, evaluator, qn):
-        return evaluator.evaluate_leaf(self, qn)
+    def evaluate(self, evaluator, qn, connection):
+        return evaluator.evaluate_leaf(self, qn, connection)
@@ -233,6 +233,24 @@ class Field(object):

         raise TypeError("Field has invalid lookup: %s" % lookup_type)

+    def validate(self, lookup_type, value):
+        """
+        Validate that the data is valid, as much so as possible without knowing
+        what connection we are using.  Returns True if the value was
+        successfully validated and false if the value wasn't validated (this
+        doesn't consider whether the value was actually valid, an exception is
+        raised in those circumstances).
+        """
+        if hasattr(value, 'validate') or hasattr(value, '_validate'):
+            if hasattr(value, 'validate'):
+                value.validate()
+            else:
+                value._validate()
+            return True
+        if lookup_type == 'isnull':
+            return True
+        return False
+
     def has_default(self):
         "Returns a boolean of whether this field has a default value."
         return self.default is not NOT_PROVIDED
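The new `Field.validate()` hook above is pure duck-typing, so its behavior can be exercised without any database: a value that knows how to validate itself is delegated to, an `isnull` lookup is always safe, and anything else is left for connection-aware checks later. A stand-in mirroring that logic (the `Probe` class is hypothetical, used only to demonstrate the delegation branch):

```python
def field_validate(lookup_type, value):
    # Mirror of Field.validate(): delegate to the value's own
    # validate()/_validate() hook when present; otherwise only
    # 'isnull' is known-valid without a connection.
    if hasattr(value, 'validate') or hasattr(value, '_validate'):
        if hasattr(value, 'validate'):
            value.validate()
        else:
            value._validate()
        return True
    if lookup_type == 'isnull':
        return True
    return False

class Probe(object):
    # Hypothetical self-validating value, e.g. a nested queryset.
    def __init__(self):
        self.validated = False
    def validate(self):
        self.validated = True
```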
@@ -360,6 +378,17 @@ class AutoField(Field):
             return None
         return int(value)

+    def validate(self, lookup_type, value):
+        if super(AutoField, self).validate(lookup_type, value):
+            return
+        if value is None or hasattr(value, 'as_sql'):
+            return
+        if lookup_type in ('range', 'in'):
+            for val in value:
+                int(val)
+        else:
+            int(value)
+
     def contribute_to_class(self, cls, name):
         assert not cls._meta.has_auto_field, "A model can't have more than one AutoField."
         super(AutoField, self).contribute_to_class(cls, name)
@@ -396,6 +425,13 @@ class BooleanField(Field):
             value = bool(int(value))
         return super(BooleanField, self).get_db_prep_lookup(lookup_type, value)

+    def validate(self, lookup_type, value):
+        if super(BooleanField, self).validate(lookup_type, value):
+            return
+        if value in ('1', '0'):
+            value = int(value)
+        bool(value)
+
     def get_db_prep_value(self, value):
         if value is None:
             return None
@@ -510,6 +546,20 @@ class DateField(Field):
         # Casts dates into the format expected by the backend
         return connection.ops.value_to_db_date(self.to_python(value))

+    def validate(self, lookup_type, value):
+        if super(DateField, self).validate(lookup_type, value):
+            return
+        if value is None:
+            return
+        if lookup_type in ('month', 'day', 'year', 'week_day'):
+            int(value)
+            return
+        if lookup_type in ('in', 'range'):
+            for val in value:
+                self.to_python(val)
+            return
+        self.to_python(value)
+
     def value_to_string(self, obj):
         val = self._get_val_from_obj(obj)
         if val is None:
@@ -754,6 +804,11 @@ class NullBooleanField(Field):
             value = bool(int(value))
         return super(NullBooleanField, self).get_db_prep_lookup(lookup_type, value)

+    def validate(self, lookup_type, value):
+        if value in ('1', '0'):
+            value = int(value)
+        bool(value)
+
     def get_db_prep_value(self, value):
         if value is None:
             return None
@@ -1,4 +1,4 @@
-from django.db import connection, transaction
+from django.db import connection, transaction, DEFAULT_DB_ALIAS
 from django.db.backends import util
 from django.db.models import signals, get_model
 from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist
@@ -478,7 +478,9 @@ def create_many_related_manager(superclass, through=False):
             cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
                 (self.join_table, source_col_name, target_col_name),
                 [self._pk_val, obj_id])
-            transaction.commit_unless_managed()
+            # FIXME, once this isn't in related.py it should conditionally
+            # use the right DB.
+            transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS)

         def _remove_items(self, source_col_name, target_col_name, *objs):
             # source_col_name: the PK colname in join_table for the source object
@@ -508,7 +510,8 @@ def create_many_related_manager(superclass, through=False):
             cursor.execute("DELETE FROM %s WHERE %s = %%s" % \
                 (self.join_table, source_col_name),
                 [self._pk_val])
-            transaction.commit_unless_managed()
+            # TODO
+            transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS)

     return ManyRelatedManager

@@ -173,6 +173,9 @@ class Manager(object):
     def only(self, *args, **kwargs):
         return self.get_query_set().only(*args, **kwargs)

+    def using(self, *args, **kwargs):
+        return self.get_query_set().using(*args, **kwargs)
+
     def _insert(self, values, **kwargs):
         return insert_query(self.model, values, **kwargs)

@@ -7,7 +7,7 @@ try:
 except NameError:
     from sets import Set as set     # Python 2.3 fallback

-from django.db import connection, transaction, IntegrityError
+from django.db import connections, transaction, IntegrityError, DEFAULT_DB_ALIAS
 from django.db.models.aggregates import Aggregate
 from django.db.models.fields import DateField
 from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory
@@ -31,10 +31,13 @@ class QuerySet(object):
     """
     def __init__(self, model=None, query=None):
         self.model = model
+        connection = connections[DEFAULT_DB_ALIAS]
         self.query = query or sql.Query(self.model, connection)
         self._result_cache = None
         self._iter = None
         self._sticky_filter = False
+        self._using = DEFAULT_DB_ALIAS # this will be wrong if a custom Query
+        # is provided with a non default connection

     ########################
     # PYTHON MAGIC METHODS #
@@ -300,12 +303,12 @@ class QuerySet(object):
             params = dict([(k, v) for k, v in kwargs.items() if '__' not in k])
             params.update(defaults)
             obj = self.model(**params)
-            sid = transaction.savepoint()
-            obj.save(force_insert=True)
-            transaction.savepoint_commit(sid)
+            sid = transaction.savepoint(using=self._using)
+            obj.save(force_insert=True, using=self._using)
+            transaction.savepoint_commit(sid, using=self._using)
             return obj, True
         except IntegrityError, e:
-            transaction.savepoint_rollback(sid)
+            transaction.savepoint_rollback(sid, using=self._using)
             try:
                 return self.get(**kwargs), False
             except self.model.DoesNotExist:
|
|||||||
|
|
||||||
if not seen_objs:
|
if not seen_objs:
|
||||||
break
|
break
|
||||||
delete_objects(seen_objs)
|
delete_objects(seen_objs, del_query._using)
|
||||||
|
|
||||||
# Clear the result cache, in case this QuerySet gets reused.
|
# Clear the result cache, in case this QuerySet gets reused.
|
||||||
self._result_cache = None
|
self._result_cache = None
|
||||||
@@ -379,20 +382,20 @@ class QuerySet(object):
             "Cannot update a query once a slice has been taken."
         query = self.query.clone(sql.UpdateQuery)
         query.add_update_values(kwargs)
-        if not transaction.is_managed():
-            transaction.enter_transaction_management()
+        if not transaction.is_managed(using=self._using):
+            transaction.enter_transaction_management(using=self._using)
             forced_managed = True
         else:
             forced_managed = False
         try:
             rows = query.execute_sql(None)
             if forced_managed:
-                transaction.commit()
+                transaction.commit(using=self._using)
             else:
-                transaction.commit_unless_managed()
+                transaction.commit_unless_managed(using=self._using)
         finally:
             if forced_managed:
-                transaction.leave_transaction_management()
+                transaction.leave_transaction_management(using=self._using)
         self._result_cache = None
         return rows
     update.alters_data = True
@@ -616,6 +619,16 @@ class QuerySet(object):
         clone.query.add_immediate_loading(fields)
         return clone

+    def using(self, alias):
+        """
+        Selects which database this QuerySet should excecute it's query against.
+        """
+        clone = self._clone()
+        clone._using = alias
+        connection = connections[alias]
+        clone.query.set_connection(connection)
+        return clone
+
     ###################################
     # PUBLIC INTROSPECTION ATTRIBUTES #
     ###################################
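The `using()` method added above follows the queryset convention of returning a rebound clone rather than mutating in place, so chaining stays side-effect free. A toy sketch of that clone-and-rebind shape; `FakeQuerySet` is purely illustrative and omits the real branch's step of swapping the query's connection:

```python
import copy

class FakeQuerySet(object):
    def __init__(self, using='default'):
        self._using = using

    def _clone(self):
        # The real _clone() rebuilds the query; a shallow copy is
        # enough to show the aliasing behavior.
        return copy.copy(self)

    def using(self, alias):
        # Return a clone bound to the requested database alias,
        # leaving the original queryset untouched.
        clone = self._clone()
        clone._using = alias
        return clone
```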
@@ -644,6 +657,7 @@ class QuerySet(object):
         if self._sticky_filter:
             query.filter_is_sticky = True
         c = klass(model=self.model, query=query)
+        c._using = self._using
         c.__dict__.update(kwargs)
         if setup and hasattr(c, '_setup_query'):
             c._setup_query()
@@ -700,6 +714,13 @@ class QuerySet(object):
         obj = self.values("pk")
         return obj.query.as_nested_sql()

+    def _validate(self):
+        """
+        A normal QuerySet is always valid when used as the RHS of a filter,
+        since it automatically gets filtered down to 1 field.
+        """
+        pass
+
     # When used as part of a nested query, a queryset will never be an "always
     # empty" result.
     value_annotation = True
@@ -818,6 +839,17 @@ class ValuesQuerySet(QuerySet):
                             % self.__class__.__name__)
         return self._clone().query.as_nested_sql()

+    def _validate(self):
+        """
+        Validates that we aren't trying to do a query like
+        value__in=qs.values('value1', 'value2'), which isn't valid.
+        """
+        if ((self._fields and len(self._fields) > 1) or
+            (not self._fields and len(self.model._meta.fields) > 1)):
+            raise TypeError('Cannot use a multi-field %s as a filter value.'
+                            % self.__class__.__name__)
+
+
 class ValuesListQuerySet(ValuesQuerySet):
     def iterator(self):
         if self.flat and len(self._fields) == 1:
@@ -970,13 +1002,14 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
                 setattr(obj, f.get_cache_name(), rel_obj)
     return obj, index_end

-def delete_objects(seen_objs):
+def delete_objects(seen_objs, using):
     """
     Iterate through a list of seen classes, and remove any instances that are
     referred to.
     """
-    if not transaction.is_managed():
-        transaction.enter_transaction_management()
+    connection = connections[using]
+    if not transaction.is_managed(using=using):
+        transaction.enter_transaction_management(using=using)
         forced_managed = True
     else:
         forced_managed = False
|
|||||||
setattr(instance, cls._meta.pk.attname, None)
|
setattr(instance, cls._meta.pk.attname, None)
|
||||||
|
|
||||||
if forced_managed:
|
if forced_managed:
|
||||||
transaction.commit()
|
transaction.commit(using=using)
|
||||||
else:
|
else:
|
||||||
transaction.commit_unless_managed()
|
transaction.commit_unless_managed(using=using)
|
||||||
finally:
|
finally:
|
||||||
if forced_managed:
|
if forced_managed:
|
||||||
transaction.leave_transaction_management()
|
transaction.leave_transaction_management(using=using)
|
||||||
|
|
||||||
|
|
||||||
def insert_query(model, values, return_id=False, raw_values=False):
|
def insert_query(model, values, return_id=False, raw_values=False, using=None):
|
||||||
"""
|
"""
|
||||||
Inserts a new record for the given model. This provides an interface to
|
Inserts a new record for the given model. This provides an interface to
|
||||||
the InsertQuery class and is how Model.save() is implemented. It is not
|
the InsertQuery class and is how Model.save() is implemented. It is not
|
||||||
part of the public API.
|
part of the public API.
|
||||||
"""
|
"""
|
||||||
|
connection = connections[using]
|
||||||
query = sql.InsertQuery(model, connection)
|
query = sql.InsertQuery(model, connection)
|
||||||
query.insert_values(values, raw_values)
|
query.insert_values(values, raw_values)
|
||||||
return query.execute_sql(return_id)
|
return query.execute_sql(return_id)
|
||||||
|
@@ -3,25 +3,23 @@ from django.db.models.fields import FieldDoesNotExist
 from django.db.models.sql.constants import LOOKUP_SEP

 class SQLEvaluator(object):
+    as_sql_takes_connection = True
+
     def __init__(self, expression, query, allow_joins=True):
         self.expression = expression
         self.opts = query.get_meta()
         self.cols = {}

         self.contains_aggregate = False
-        self.connection = query.connection
         self.expression.prepare(self, query, allow_joins)

-    def as_sql(self, qn=None):
-        return self.expression.evaluate(self, qn)
+    def as_sql(self, qn, connection):
+        return self.expression.evaluate(self, qn, connection)

     def relabel_aliases(self, change_map):
         for node, col in self.cols.items():
             self.cols[node] = (change_map.get(col[0], col[0]), col[1])

-    def update_connection(self, connection):
-        self.connection = connection
-
     #####################################################
     # Vistor methods for initial expression preparation #
     #####################################################
@ -57,15 +55,12 @@ class SQLEvaluator(object):
|
|||||||
# Vistor methods for final expression evaluation #
|
# Vistor methods for final expression evaluation #
|
||||||
##################################################
|
##################################################
|
||||||
|
|
||||||
def evaluate_node(self, node, qn):
|
def evaluate_node(self, node, qn, connection):
|
||||||
if not qn:
|
|
||||||
qn = self.connection.ops.quote_name
|
|
||||||
|
|
||||||
expressions = []
|
expressions = []
|
||||||
expression_params = []
|
expression_params = []
|
||||||
for child in node.children:
|
for child in node.children:
|
||||||
if hasattr(child, 'evaluate'):
|
if hasattr(child, 'evaluate'):
|
||||||
sql, params = child.evaluate(self, qn)
|
sql, params = child.evaluate(self, qn, connection)
|
||||||
else:
|
else:
|
||||||
sql, params = '%s', (child,)
|
sql, params = '%s', (child,)
|
||||||
|
|
||||||
@ -78,12 +73,9 @@ class SQLEvaluator(object):
|
|||||||
expressions.append(format % sql)
|
expressions.append(format % sql)
|
||||||
expression_params.extend(params)
|
expression_params.extend(params)
|
||||||
|
|
||||||
return self.connection.ops.combine_expression(node.connector, expressions), expression_params
|
return connection.ops.combine_expression(node.connector, expressions), expression_params
|
||||||
|
|
||||||
def evaluate_leaf(self, node, qn):
|
|
||||||
if not qn:
|
|
||||||
qn = self.connection.ops.quote_name
|
|
||||||
|
|
||||||
|
def evaluate_leaf(self, node, qn, connection):
|
||||||
col = self.cols[node]
|
col = self.cols[node]
|
||||||
if hasattr(col, 'as_sql'):
|
if hasattr(col, 'as_sql'):
|
||||||
return col.as_sql(qn), ()
|
return col.as_sql(qn), ()
|
||||||
|
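The hunks above lean on a transitional dispatch idiom: nodes that have been converted to the new signature advertise `as_sql_takes_connection = True`, and callers probe that flag with `getattr` so converted and unconverted nodes can coexist during the refactor. A minimal standalone sketch of the idiom (the class names here are illustrative stand-ins, not code from the patch):

```python
# Sketch of the transitional dispatch idiom used throughout this commit:
# converted nodes set as_sql_takes_connection = True, and callers probe the
# flag so old-style and new-style as_sql() signatures can coexist.

class LegacyNode(object):
    """Old-style node: as_sql(qn) only."""
    def as_sql(self, qn):
        return qn('legacy'), []

class ConnectionAwareNode(object):
    """New-style node: as_sql(qn, connection)."""
    as_sql_takes_connection = True

    def as_sql(self, qn, connection):
        # A real node would consult connection.ops here.
        return qn('aware'), []

def compile_node(node, qn, connection):
    # Mirrors the branching in WhereNode.as_sql and UpdateQuery.as_sql.
    if getattr(node, 'as_sql_takes_connection', False):
        return node.as_sql(qn, connection)
    return node.as_sql(qn)
```

Once every node grows the two-argument signature, the flag and the `getattr` probe can be deleted outright, which is what the "Immediate TODOs" note in TODO.TXT anticipates.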
@@ -67,10 +67,10 @@ class BaseQuery(object):
         # SQL-related attributes
         self.select = []
         self.tables = []    # Aliases in the order they are created.
-        self.where = where(connection=self.connection)
+        self.where = where()
         self.where_class = where
         self.group_by = None
-        self.having = where(connection=self.connection)
+        self.having = where()
         self.order_by = []
         self.low_mark, self.high_mark = 0, None  # Used for offset/limit
         self.distinct = False
@@ -151,8 +151,6 @@ class BaseQuery(object):
         # supported. It's the only class-reference to the module-level
         # connection variable.
         self.connection = connection
-        self.where.update_connection(self.connection)
-        self.having.update_connection(self.connection)

     def get_meta(self):
         """
@@ -245,8 +243,6 @@ class BaseQuery(object):
         obj.used_aliases = set()
         obj.filter_is_sticky = False
         obj.__dict__.update(kwargs)
-        obj.where.update_connection(obj.connection) # where and having track their own connection
-        obj.having.update_connection(obj.connection)# we need to keep this up to date
         if hasattr(obj, '_setup_query'):
             obj._setup_query()
         return obj
@@ -405,8 +401,8 @@ class BaseQuery(object):
         from_, f_params = self.get_from_clause()

         qn = self.quote_name_unless_alias
-        where, w_params = self.where.as_sql(qn=qn)
-        having, h_params = self.having.as_sql(qn=qn)
+        where, w_params = self.where.as_sql(qn=qn, connection=self.connection)
+        having, h_params = self.having.as_sql(qn=qn, connection=self.connection)
         params = []
         for val in self.extra_select.itervalues():
             params.extend(val[1])
@@ -534,10 +530,10 @@ class BaseQuery(object):
             self.where.add(EverythingNode(), AND)
         elif self.where:
             # rhs has an empty where clause.
-            w = self.where_class(connection=self.connection)
+            w = self.where_class()
             w.add(EverythingNode(), AND)
         else:
-            w = self.where_class(connection=self.connection)
+            w = self.where_class()
         self.where.add(w, connector)

         # Selection columns and extra extensions are those provided by 'rhs'.
@@ -1550,7 +1546,7 @@ class BaseQuery(object):

         for alias, aggregate in self.aggregates.items():
             if alias == parts[0]:
-                entry = self.where_class(connection=self.connection)
+                entry = self.where_class()
                 entry.add((aggregate, lookup_type, value), AND)
                 if negate:
                     entry.negate()
@@ -1618,7 +1614,7 @@ class BaseQuery(object):
                 for alias in join_list:
                     if self.alias_map[alias][JOIN_TYPE] == self.LOUTER:
                         j_col = self.alias_map[alias][RHS_JOIN_COL]
-                        entry = self.where_class(connection=self.connection)
+                        entry = self.where_class()
                         entry.add((Constraint(alias, j_col, None), 'isnull', True), AND)
                         entry.negate()
                         self.where.add(entry, AND)
@@ -1627,7 +1623,7 @@ class BaseQuery(object):
                 # Leaky abstraction artifact: We have to specifically
                 # exclude the "foo__in=[]" case from this handling, because
                 # it's short-circuited in the Where class.
-                entry = self.where_class(connection=self.connection)
+                entry = self.where_class()
                 entry.add((Constraint(alias, col, None), 'isnull', True), AND)
                 entry.negate()
                 self.where.add(entry, AND)
@@ -2337,6 +2333,9 @@ class BaseQuery(object):
         self.select = [(select_alias, select_col)]
         self.remove_inherited_models()

+    def set_connection(self, connection):
+        self.connection = connection
+
     def execute_sql(self, result_type=MULTI):
         """
         Run the query against the database and returns the result(s). The
@@ -24,8 +24,9 @@ class DeleteQuery(Query):
         """
         assert len(self.tables) == 1, \
                 "Can only delete from one table at a time."
-        result = ['DELETE FROM %s' % self.quote_name_unless_alias(self.tables[0])]
-        where, params = self.where.as_sql()
+        qn = self.quote_name_unless_alias
+        result = ['DELETE FROM %s' % qn(self.tables[0])]
+        where, params = self.where.as_sql(qn=qn, connection=self.connection)
         result.append('WHERE %s' % where)
         return ' '.join(result), tuple(params)

@@ -48,7 +49,7 @@ class DeleteQuery(Query):
         for related in cls._meta.get_all_related_many_to_many_objects():
             if not isinstance(related.field, generic.GenericRelation):
                 for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
-                    where = self.where_class(connection=self.connection)
+                    where = self.where_class()
                     where.add((Constraint(None,
                             related.field.m2m_reverse_name(), related.field),
                             'in',
@@ -57,14 +58,14 @@ class DeleteQuery(Query):
                     self.do_query(related.field.m2m_db_table(), where)

         for f in cls._meta.many_to_many:
-            w1 = self.where_class(connection=self.connection)
+            w1 = self.where_class()
             if isinstance(f, generic.GenericRelation):
                 from django.contrib.contenttypes.models import ContentType
                 field = f.rel.to._meta.get_field(f.content_type_field_name)
                 w1.add((Constraint(None, field.column, field), 'exact',
                         ContentType.objects.get_for_model(cls).id), AND)
             for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
-                where = self.where_class(connection=self.connection)
+                where = self.where_class()
                 where.add((Constraint(None, f.m2m_column_name(), f), 'in',
                         pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
                         AND)
@@ -81,7 +82,7 @@ class DeleteQuery(Query):
        lot of values in pk_list.
        """
        for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
-            where = self.where_class(connection=self.connection)
+            where = self.where_class()
            field = self.model._meta.pk
            where.add((Constraint(None, field.column, field), 'in',
                    pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND)
@@ -143,7 +144,10 @@ class UpdateQuery(Query):
         values, update_params = [], []
         for name, val, placeholder in self.values:
             if hasattr(val, 'as_sql'):
-                sql, params = val.as_sql(qn)
+                if getattr(val, 'as_sql_takes_connection', False):
+                    sql, params = val.as_sql(qn, self.connection)
+                else:
+                    sql, params = val.as_sql(qn)
                 values.append('%s = %s' % (qn(name), sql))
                 update_params.extend(params)
             elif val is not None:
@@ -152,7 +156,7 @@ class UpdateQuery(Query):
             else:
                 values.append('%s = NULL' % qn(name))
         result.append(', '.join(values))
-        where, params = self.where.as_sql()
+        where, params = self.where.as_sql(qn=qn, connection=self.connection)
         if where:
             result.append('WHERE %s' % where)
         return ' '.join(result), tuple(update_params + params)
@@ -185,7 +189,7 @@ class UpdateQuery(Query):

         # Now we adjust the current query: reset the where clause and get rid
         # of all the tables we don't need (since they're in the sub-select).
-        self.where = self.where_class(connection=self.connection)
+        self.where = self.where_class()
         if self.related_updates or must_pre_select:
             # Either we're using the idents in multiple update queries (so
             # don't want them to change), or the db backend doesn't support
@@ -209,7 +213,7 @@ class UpdateQuery(Query):
        This is used by the QuerySet.delete_objects() method.
        """
        for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
-            self.where = self.where_class(connection=self.connection)
+            self.where = self.where_class()
            f = self.model._meta.pk
            self.where.add((Constraint(None, f.column, f), 'in',
                    pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
@@ -32,18 +32,7 @@ class WhereNode(tree.Node):
     relabel_aliases() methods.
     """
     default = AND
-
-    def __init__(self, *args, **kwargs):
-        self.connection = kwargs.pop('connection', None)
-        super(WhereNode, self).__init__(*args, **kwargs)
-
-    def __getstate__(self):
-        """
-        Don't try to pickle the connection, our Query will restore it for us.
-        """
-        data = self.__dict__.copy()
-        del data['connection']
-        return data
+    as_sql_takes_connection = True

     def add(self, data, connector):
         """
@@ -62,20 +51,6 @@ class WhereNode(tree.Node):
             # Consume any generators immediately, so that we can determine
             # emptiness and transform any non-empty values correctly.
             value = list(value)
-        if hasattr(obj, "process"):
-            try:
-                # FIXME We're calling process too early, the connection could
-                # change
-                obj, params = obj.process(lookup_type, value, self.connection)
-            except (EmptyShortCircuit, EmptyResultSet):
-                # There are situations where we want to short-circuit any
-                # comparisons and make sure that nothing is returned. One
-                # example is when checking for a NULL pk value, or the
-                # equivalent.
-                super(WhereNode, self).add(NothingNode(), connector)
-                return
-        else:
-            params = Field().get_db_prep_lookup(lookup_type, value)

         # The "annotation" parameter is used to pass auxilliary information
         # about the value(s) to the query construction. Specifically, datetime
@@ -88,18 +63,19 @@ class WhereNode(tree.Node):
         else:
             annotation = bool(value)

+        if hasattr(obj, "process"):
+            obj.validate(lookup_type, value)
+            super(WhereNode, self).add((obj, lookup_type, annotation, value),
+                connector)
+            return
+        else:
+            # TODO: Make this lazy just like the above code for constraints.
+            params = Field().get_db_prep_lookup(lookup_type, value)
+
         super(WhereNode, self).add((obj, lookup_type, annotation, params),
             connector)

-    def update_connection(self, connection):
-        self.connection = connection
-        for child in self.children:
-            if hasattr(child, 'update_connection'):
-                child.update_connection(connection)
-            elif hasattr(child[3], 'update_connection'):
-                child[3].update_connection(connection)
-
-    def as_sql(self, qn=None):
+    def as_sql(self, qn, connection):
         """
         Returns the SQL version of the where clause and the value to be
         substituted in. Returns None, None if this node is empty.
@@ -108,8 +84,6 @@ class WhereNode(tree.Node):
         (generally not needed except by the internal implementation for
         recursion).
         """
-        if not qn:
-            qn = self.connection.ops.quote_name
         if not self.children:
             return None, []
         result = []
@@ -118,10 +92,13 @@ class WhereNode(tree.Node):
         for child in self.children:
             try:
                 if hasattr(child, 'as_sql'):
-                    sql, params = child.as_sql(qn=qn)
+                    if getattr(child, 'as_sql_takes_connection', False):
+                        sql, params = child.as_sql(qn=qn, connection=connection)
+                    else:
+                        sql, params = child.as_sql(qn=qn)
                 else:
                     # A leaf node in the tree.
-                    sql, params = self.make_atom(child, qn)
+                    sql, params = self.make_atom(child, qn, connection)

             except EmptyResultSet:
                 if self.connector == AND and not self.negated:
@@ -157,7 +134,7 @@ class WhereNode(tree.Node):
             sql_string = '(%s)' % sql_string
         return sql_string, result_params

-    def make_atom(self, child, qn):
+    def make_atom(self, child, qn, connection):
         """
         Turn a tuple (table_alias, column_name, db_type, lookup_type,
         value_annot, params) into valid SQL.
@@ -165,29 +142,39 @@ class WhereNode(tree.Node):
         Returns the string for the SQL fragment and the parameters to use for
         it.
         """
-        lvalue, lookup_type, value_annot, params = child
+        lvalue, lookup_type, value_annot, params_or_value = child
+        if hasattr(lvalue, 'process'):
+            try:
+                lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
+            except EmptyShortCircuit:
+                raise EmptyResultSet
+        else:
+            params = params_or_value
         if isinstance(lvalue, tuple):
             # A direct database column lookup.
-            field_sql = self.sql_for_columns(lvalue, qn)
+            field_sql = self.sql_for_columns(lvalue, qn, connection)
         else:
             # A smart object with an as_sql() method.
             field_sql = lvalue.as_sql(quote_func=qn)

         if value_annot is datetime.datetime:
-            cast_sql = self.connection.ops.datetime_cast_sql()
+            cast_sql = connection.ops.datetime_cast_sql()
         else:
             cast_sql = '%s'

         if hasattr(params, 'as_sql'):
-            extra, params = params.as_sql(qn)
+            if getattr(params, 'as_sql_takes_connection', False):
+                extra, params = params.as_sql(qn, connection)
+            else:
+                extra, params = params.as_sql(qn)
             cast_sql = ''
         else:
             extra = ''

-        if lookup_type in self.connection.operators:
-            format = "%s %%s %%s" % (self.connection.ops.lookup_cast(lookup_type),)
+        if lookup_type in connection.operators:
+            format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
             return (format % (field_sql,
-                              self.connection.operators[lookup_type] % cast_sql,
+                              connection.operators[lookup_type] % cast_sql,
                               extra), params)

         if lookup_type == 'in':
@@ -200,19 +187,19 @@ class WhereNode(tree.Node):
         elif lookup_type in ('range', 'year'):
             return ('%s BETWEEN %%s and %%s' % field_sql, params)
         elif lookup_type in ('month', 'day', 'week_day'):
-            return ('%s = %%s' % self.connection.ops.date_extract_sql(lookup_type, field_sql),
+            return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql),
                     params)
         elif lookup_type == 'isnull':
             return ('%s IS %sNULL' % (field_sql,
                 (not value_annot and 'NOT ' or '')), ())
         elif lookup_type == 'search':
-            return (self.connection.ops.fulltext_search_sql(field_sql), params)
+            return (connection.ops.fulltext_search_sql(field_sql), params)
         elif lookup_type in ('regex', 'iregex'):
-            return self.connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params
+            return connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params

         raise TypeError('Invalid lookup_type: %r' % lookup_type)

-    def sql_for_columns(self, data, qn):
+    def sql_for_columns(self, data, qn, connection):
         """
         Returns the SQL fragment used for the left-hand side of a column
         constraint (for example, the "T1.foo" portion in the clause
@@ -223,7 +210,7 @@ class WhereNode(tree.Node):
             lhs = '%s.%s' % (qn(table_alias), qn(name))
         else:
             lhs = qn(name)
-        return self.connection.ops.field_cast_sql(db_type) % lhs
+        return connection.ops.field_cast_sql(db_type) % lhs

     def relabel_aliases(self, change_map, node=None):
         """
@@ -299,3 +286,11 @@ class Constraint(object):
             raise EmptyShortCircuit

         return (self.alias, self.col, db_type), params
+
+    def relabel_aliases(self, change_map):
+        if self.alias in change_map:
+            self.alias = change_map[self.alias]
+
+    def validate(self, lookup_type, value):
+        if hasattr(self.field, 'validate'):
+            self.field.validate(lookup_type, value)
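A second theme in the where.py changes: `Constraint.process()` used to run inside `add()`, when the eventual connection could still change (the old FIXME notes exactly this), and the patch instead stores the raw value and defers processing to `make_atom()`, once the final connection is known. A simplified model of that deferral, with hypothetical stand-in classes rather than the real `WhereNode`/`Constraint`:

```python
# Simplified model of the laziness introduced in WhereNode: add() stores the
# raw value, and connection-specific processing happens only at SQL time.
# FakeConstraint is a hypothetical stand-in for sql.where.Constraint.

class FakeConstraint(object):
    def __init__(self, col):
        self.col = col

    def process(self, lookup_type, value, connection):
        # The real Constraint does connection-aware db-prep of `value` here.
        return self.col, [value]

class LazyWhere(object):
    def __init__(self):
        self.children = []

    def add(self, constraint, lookup_type, value):
        # No connection needed yet; just remember the raw pieces.
        self.children.append((constraint, lookup_type, value))

    def as_sql(self, qn, connection):
        # The connection is only required once SQL is actually generated.
        sqls = []
        for constraint, lookup_type, value in self.children:
            col, params = constraint.process(lookup_type, value, connection)
            sqls.append(('%s = %%s' % qn(col), params))
        return sqls
```

This is why a query can now be re-pointed at another database after filters are added: nothing connection-specific has been baked into the tree yet.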
@ -20,7 +20,7 @@ try:
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from django.utils.functional import wraps # Python 2.3, 2.4 fallback.
|
from django.utils.functional import wraps # Python 2.3, 2.4 fallback.
|
||||||
from django.db import connection
|
from django.db import connections
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
|
||||||
class TransactionManagementError(Exception):
|
class TransactionManagementError(Exception):
|
||||||
@ -30,17 +30,20 @@ class TransactionManagementError(Exception):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# The states are dictionaries of lists. The key to the dict is the current
|
# The states are dictionaries of dictionaries of lists. The key to the outer
|
||||||
# thread and the list is handled as a stack of values.
|
# dict is the current thread, and the key to the inner dictionary is the
|
||||||
|
# connection alias and the list is handled as a stack of values.
|
||||||
state = {}
|
state = {}
|
||||||
savepoint_state = {}
|
savepoint_state = {}
|
||||||
|
|
||||||
# The dirty flag is set by *_unless_managed functions to denote that the
|
# The dirty flag is set by *_unless_managed functions to denote that the
|
||||||
# code under transaction management has changed things to require a
|
# code under transaction management has changed things to require a
|
||||||
# database commit.
|
# database commit.
|
||||||
|
# This is a dictionary mapping thread to a dictionary mapping connection
|
||||||
|
# alias to a boolean.
|
||||||
dirty = {}
|
dirty = {}
|
||||||
|
|
||||||
def enter_transaction_management(managed=True):
|
def enter_transaction_management(managed=True, using=None):
|
||||||
"""
|
"""
|
||||||
Enters transaction management for a running thread. It must be balanced with
|
Enters transaction management for a running thread. It must be balanced with
|
||||||
the appropriate leave_transaction_management call, since the actual state is
|
the appropriate leave_transaction_management call, since the actual state is
|
||||||
@ -50,166 +53,212 @@ def enter_transaction_management(managed=True):
|
|||||||
from the settings, if there is no surrounding block (dirty is always false
|
from the settings, if there is no surrounding block (dirty is always false
|
||||||
when no current block is running).
|
when no current block is running).
|
||||||
"""
|
"""
|
||||||
|
if using is None:
|
||||||
|
raise ValueError # TODO use default
|
||||||
|
connection = connections[using]
|
||||||
thread_ident = thread.get_ident()
|
thread_ident = thread.get_ident()
|
||||||
if thread_ident in state and state[thread_ident]:
|
if thread_ident in state and state[thread_ident].get(using):
|
||||||
state[thread_ident].append(state[thread_ident][-1])
|
state[thread_ident][using].append(state[thread_ident][using][-1])
|
||||||
else:
|
else:
|
||||||
state[thread_ident] = []
|
state.setdefault(thread_ident, {})
|
||||||
state[thread_ident].append(settings.TRANSACTIONS_MANAGED)
|
state[thread_ident][using] = [settings.TRANSACTIONS_MANAGED]
|
||||||
if thread_ident not in dirty:
|
if thread_ident not in dirty or using not in dirty[thread_ident]:
|
||||||
dirty[thread_ident] = False
|
dirty.setdefault(thread_ident, {})
|
||||||
|
dirty[thread_ident][using] = False
|
||||||
connection._enter_transaction_management(managed)
|
connection._enter_transaction_management(managed)
|
||||||
|
|
||||||
def leave_transaction_management():
|
def leave_transaction_management(using=None):
|
||||||
"""
|
"""
|
||||||
Leaves transaction management for a running thread. A dirty flag is carried
|
Leaves transaction management for a running thread. A dirty flag is carried
|
||||||
over to the surrounding block, as a commit will commit all changes, even
|
over to the surrounding block, as a commit will commit all changes, even
|
||||||
those from outside. (Commits are on connection level.)
|
those from outside. (Commits are on connection level.)
|
||||||
"""
|
"""
|
||||||
connection._leave_transaction_management(is_managed())
|
if using is None:
|
||||||
|
raise ValueError # TODO use default
|
||||||
|
connection = connections[using]
|
||||||
|
connection._leave_transaction_management(is_managed(using=using))
|
||||||
thread_ident = thread.get_ident()
|
thread_ident = thread.get_ident()
|
||||||
if thread_ident in state and state[thread_ident]:
|
if thread_ident in state and state[thread_ident].get(using):
|
||||||
del state[thread_ident][-1]
|
del state[thread_ident][using][-1]
|
||||||
     else:
         raise TransactionManagementError("This code isn't under transaction management")
-    if dirty.get(thread_ident, False):
-        rollback()
+    if dirty.get(thread_ident, {}).get(using, False):
+        rollback(using=using)
         raise TransactionManagementError("Transaction managed block ended with pending COMMIT/ROLLBACK")
-    dirty[thread_ident] = False
+    dirty[thread_ident][using] = False
 
-def is_dirty():
+def is_dirty(using=None):
     """
     Returns True if the current transaction requires a commit for changes to
     happen.
     """
-    return dirty.get(thread.get_ident(), False)
+    if using is None:
+        raise ValueError # TODO use default
+    return dirty.get(thread.get_ident(), {}).get(using, False)
 
-def set_dirty():
+def set_dirty(using=None):
     """
     Sets a dirty flag for the current thread and code streak. This can be used
     to decide in a managed block of code to decide whether there are open
     changes waiting for commit.
     """
+    if using is None:
+        raise ValueError # TODO use default
     thread_ident = thread.get_ident()
-    if thread_ident in dirty:
-        dirty[thread_ident] = True
+    if thread_ident in dirty and using in dirty[thread_ident]:
+        dirty[thread_ident][using] = True
     else:
         raise TransactionManagementError("This code isn't under transaction management")
 
-def set_clean():
+def set_clean(using=None):
     """
     Resets a dirty flag for the current thread and code streak. This can be used
     to decide in a managed block of code to decide whether a commit or rollback
     should happen.
     """
+    if using is None:
+        raise ValueError # TODO use default
     thread_ident = thread.get_ident()
-    if thread_ident in dirty:
-        dirty[thread_ident] = False
+    if thread_ident in dirty and using in dirty[thread_ident]:
+        dirty[thread_ident][using] = False
     else:
         raise TransactionManagementError("This code isn't under transaction management")
-    clean_savepoints()
+    clean_savepoints(using=using)
 
-def clean_savepoints():
+def clean_savepoints(using=None):
+    if using is None:
+        raise ValueError # TODO use default
     thread_ident = thread.get_ident()
-    if thread_ident in savepoint_state:
-        del savepoint_state[thread_ident]
+    if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
+        del savepoint_state[thread_ident][using]
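The hunks above change the per-thread dirty flag from a flat `{thread_ident: bool}` dict into a nested mapping keyed by thread ident and then by database alias. A standalone sketch of that data-structure change (not the branch's actual module — it uses Python 2's `thread.get_ident()`; this sketch uses the modern `threading` equivalent):

```python
import threading

# Sketch: the multidb change makes each database alias track its own
# pending-write flag within the current thread's entry.
dirty = {}  # {thread_ident: {db_alias: bool}}

def set_dirty(using):
    # setdefault gives this thread its own per-alias dict on first use
    dirty.setdefault(threading.get_ident(), {})[using] = True

def is_dirty(using):
    # two chained .get() calls mirror dirty.get(ident, {}).get(using, False)
    return dirty.get(threading.get_ident(), {}).get(using, False)

set_dirty('default')
print(is_dirty('default'), is_dirty('other'))  # True False
```

The chained `.get(..., {}).get(..., False)` lookup is what lets `is_dirty` stay a single expression even though the state is now two levels deep.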
 
-def is_managed():
+def is_managed(using=None):
     """
     Checks whether the transaction manager is in manual or in auto state.
     """
+    if using is None:
+        raise ValueError # TODO use default
     thread_ident = thread.get_ident()
-    if thread_ident in state:
-        if state[thread_ident]:
-            return state[thread_ident][-1]
+    if thread_ident in state and using in state[thread_ident]:
+        if state[thread_ident][using]:
+            return state[thread_ident][using][-1]
     return settings.TRANSACTIONS_MANAGED
 
-def managed(flag=True):
+def managed(flag=True, using=None):
     """
     Puts the transaction manager into a manual state: managed transactions have
     to be committed explicitly by the user. If you switch off transaction
     management and there is a pending commit/rollback, the data will be
     commited.
     """
+    if using is None:
+        raise ValueError # TODO use default
+    connection = connections[using]
     thread_ident = thread.get_ident()
-    top = state.get(thread_ident, None)
+    top = state.get(thread_ident, {}).get(using, None)
     if top:
         top[-1] = flag
-        if not flag and is_dirty():
+        if not flag and is_dirty(using=using):
             connection._commit()
-            set_clean()
+            set_clean(using=using)
     else:
         raise TransactionManagementError("This code isn't under transaction management")
 
-def commit_unless_managed():
+def commit_unless_managed(using=None):
     """
     Commits changes if the system is not in managed transaction mode.
     """
-    if not is_managed():
+    if using is None:
+        raise ValueError # TODO use default
+    connection = connections[using]
+    if not is_managed(using=using):
         connection._commit()
-        clean_savepoints()
+        clean_savepoints(using=using)
     else:
-        set_dirty()
+        set_dirty(using=using)
 
-def rollback_unless_managed():
+def rollback_unless_managed(using=None):
     """
     Rolls back changes if the system is not in managed transaction mode.
     """
-    if not is_managed():
+    if using is None:
+        raise ValueError # TODO use default
+    connection = connections[using]
+    if not is_managed(using=using):
         connection._rollback()
     else:
-        set_dirty()
+        set_dirty(using=using)
 
-def commit():
+def commit(using=None):
     """
     Does the commit itself and resets the dirty flag.
     """
+    if using is None:
+        raise ValueError # TODO use default
+    connection = connections[using]
     connection._commit()
-    set_clean()
+    set_clean(using=using)
 
-def rollback():
+def rollback(using=None):
     """
     This function does the rollback itself and resets the dirty flag.
     """
+    if using is None:
+        raise ValueError # TODO use default
+    connection = connections[using]
     connection._rollback()
-    set_clean()
+    set_clean(using=using)
 
-def savepoint():
+def savepoint(using=None):
     """
     Creates a savepoint (if supported and required by the backend) inside the
     current transaction. Returns an identifier for the savepoint that will be
     used for the subsequent rollback or commit.
     """
+    if using is None:
+        raise ValueError # TODO use default
+    connection = connections[using]
     thread_ident = thread.get_ident()
-    if thread_ident in savepoint_state:
-        savepoint_state[thread_ident].append(None)
+    if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
+        savepoint_state[thread_ident][using].append(None)
     else:
-        savepoint_state[thread_ident] = [None]
+        savepoint_state.setdefault(thread_ident, {})
+        savepoint_state[thread_ident][using] = [None]
     tid = str(thread_ident).replace('-', '')
-    sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident]))
+    sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident][using]))
     connection._savepoint(sid)
     return sid
 
-def savepoint_rollback(sid):
+def savepoint_rollback(sid, using=None):
     """
     Rolls back the most recent savepoint (if one exists). Does nothing if
     savepoints are not supported.
     """
-    if thread.get_ident() in savepoint_state:
+    if using is None:
+        raise ValueError # TODO use default
+    connection = connections[using]
+    thread_ident = thread.get_ident()
+    if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
         connection._savepoint_rollback(sid)
 
-def savepoint_commit(sid):
+def savepoint_commit(sid, using=None):
     """
     Commits the most recent savepoint (if one exists). Does nothing if
     savepoints are not supported.
     """
-    if thread.get_ident() in savepoint_state:
+    if using is None:
+        raise ValueError # TODO use default
+    connection = connections[using]
+    thread_ident = thread.get_ident()
+    if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
         connection._savepoint_commit(sid)
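The `savepoint()` hunk above gives each (thread, alias) pair its own savepoint stack, with the savepoint identifier derived from the thread ident and that alias's stack depth. A standalone sketch of just the bookkeeping (no real connection object — the branch's version also calls `connection._savepoint(sid)`):

```python
import threading

# Sketch: one savepoint stack per (thread, database alias); the id encodes
# the thread and the depth within that alias's stack.
savepoint_state = {}  # {thread_ident: {db_alias: [None, ...]}}

def savepoint(using):
    ident = threading.get_ident()
    # nested setdefault collapses the branch's two-step "if alias present,
    # append; else create" logic
    stack = savepoint_state.setdefault(ident, {}).setdefault(using, [])
    stack.append(None)  # depth marker only
    tid = str(ident).replace('-', '')
    return "s%s_x%d" % (tid, len(stack))

first = savepoint('default')
second = savepoint('default')
other = savepoint('other')
# Depth counts independently per alias: _x1, _x2 on 'default'; _x1 on 'other'.
```

Because depth is per alias, savepoint ids on different databases no longer share a counter, which the old flat `savepoint_state[thread_ident]` list could not express.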
 
 ##############
 # DECORATORS #
 ##############
 
+# TODO update all of these for multi-db
+
 def autocommit(func):
     """
     Decorator that activates commit on save. This is Django's default behavior;
@@ -201,7 +201,8 @@ class DocTestRunner(doctest.DocTestRunner):
             example, exc_info)
         # Rollback, in case of database errors. Otherwise they'd have
         # side effects on other tests.
-        transaction.rollback_unless_managed()
+        for conn in connections:
+            transaction.rollback_unless_managed(using=conn)
 
 class TransactionTestCase(unittest.TestCase):
     def _pre_setup(self):
@@ -446,8 +447,9 @@ class TestCase(TransactionTestCase):
         if not connections_support_transactions():
            return super(TestCase, self)._fixture_setup()
 
-        transaction.enter_transaction_management()
-        transaction.managed(True)
+        for conn in connections:
+            transaction.enter_transaction_management(using=conn)
+            transaction.managed(True, using=conn)
         disable_transaction_methods()
 
         from django.contrib.sites.models import Site
@@ -464,7 +466,8 @@ class TestCase(TransactionTestCase):
             return super(TestCase, self)._fixture_teardown()
 
         restore_transaction_methods()
-        transaction.rollback()
-        transaction.leave_transaction_management()
+        for conn in connections:
+            transaction.rollback(using=conn)
+            transaction.leave_transaction_management(using=conn)
         for connection in connections.all():
             connection.close()
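The test-case hunks above iterate the connection handler directly (`for conn in connections`) and pass each item as `using=conn`. That works because iterating a dict-like handler yields alias *names*, which is exactly what the `using` parameter expects. A plain-Python sketch of that shape (the stand-in handler and recording list are hypothetical, not the branch's code):

```python
# Stand-in for the dict-like connection handler: iteration yields alias names,
# not connection objects.
connections = {'default': object(), 'other': object()}

rolled_back = []

def rollback_unless_managed(using=None):
    # the real function would look up connections[using] and roll it back
    rolled_back.append(using)

# Same loop shape as the _fixture_teardown hunk: one call per alias.
for conn in connections:
    rollback_unless_managed(using=conn)

assert sorted(rolled_back) == ['default', 'other']
```

This is why the loops in the diff never touch connection objects directly: the alias string is the unit the transaction API operates on, while `connections.all()` is reserved for operations (like `close()`) that need the objects themselves.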
@@ -15,11 +15,11 @@ class RevisionableModel(models.Model):
     def __unicode__(self):
         return u"%s (%s, %s)" % (self.title, self.id, self.base.id)
 
-    def save(self, force_insert=False, force_update=False):
-        super(RevisionableModel, self).save(force_insert, force_update)
+    def save(self, *args, **kwargs):
+        super(RevisionableModel, self).save(*args, **kwargs)
         if not self.base:
             self.base = self
-            super(RevisionableModel, self).save()
+            super(RevisionableModel, self).save(using=kwargs.pop('using', None))
 
     def new_revision(self):
         new_revision = copy.copy(self)
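The hunk above switches the test model's `save()` override to `*args, **kwargs` forwarding, so keyword arguments that multidb adds to `Model.save` (such as `using`) pass through without the override hard-coding the parent's signature. A plain-Python sketch of the forwarding pattern, with no Django involved (class and return values are illustrative only):

```python
# Sketch: the override forwards everything, so new parent kwargs like
# ``using`` flow through an override written before they existed.
class Model(object):
    def save(self, force_insert=False, force_update=False, using=None):
        return ('saved', using)

class Revisionable(Model):
    def save(self, *args, **kwargs):
        # no named parameters repeated here, so the signature can't drift
        # out of sync with the parent
        return super(Revisionable, self).save(*args, **kwargs)

assert Revisionable().save(using='other') == ('saved', 'other')
assert Revisionable().save() == ('saved', None)
```

This is the same reason the commit message mentions "a using kwarg on save and delete": overrides that spell out `force_insert=False, force_update=False` would silently drop the new argument.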