From 30b3d51599aed4cf08ccace9ba5b1277ded78439 Mon Sep 17 00:00:00 2001 From: Ramiro Morales Date: Tue, 5 Apr 2011 00:19:17 +0000 Subject: [PATCH] Fixed #13630 -- Made __init__ methods of all DB backends' DatabaseOperations classes take a `connection` argument. Thanks calexium for the report. git-svn-id: http://code.djangoproject.com/svn/django/trunk@16016 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- .../gis/db/backends/oracle/operations.py | 3 +- .../gis/db/backends/spatialite/operations.py | 3 +- django/db/backends/__init__.py | 3 +- django/db/backends/dummy/base.py | 2 +- django/db/backends/mysql/base.py | 4 +-- django/db/backends/oracle/base.py | 34 +++++++++---------- .../postgresql_psycopg2/operations.py | 3 +- django/db/backends/sqlite3/base.py | 8 ++--- tests/regressiontests/backends/tests.py | 6 ++++ 9 files changed, 34 insertions(+), 32 deletions(-) diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index 97f7b6c20b..e7ad04840e 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -133,8 +133,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): truncate_params = {'relate' : None} def __init__(self, connection): - super(OracleOperations, self).__init__() - self.connection = connection + super(OracleOperations, self).__init__(connection) def convert_extent(self, clob): if clob: diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index e6f8409fdb..f3bf3e5a56 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -110,8 +110,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): geometry_functions.update(distance_functions) def __init__(self, connection): - super(DatabaseOperations, self).__init__() - self.connection = connection + super(DatabaseOperations, self).__init__(connection) # Determine the version of the SpatiaLite library. try: diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index b7d0c8cb90..4b845d5135 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -388,7 +388,8 @@ class BaseDatabaseOperations(object): """ compiler_module = "django.db.models.sql.compiler" - def __init__(self): + def __init__(self, connection): + self.connection = connection self._cache = None def autoinc_sql(self, table, column): diff --git a/django/db/backends/dummy/base.py b/django/db/backends/dummy/base.py index 04a8d850bf..7de48c8b00 100644 --- a/django/db/backends/dummy/base.py +++ b/django/db/backends/dummy/base.py @@ -59,7 +59,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): super(DatabaseWrapper, self).__init__(*args, **kwargs) self.features = BaseDatabaseFeatures(self) - self.ops = DatabaseOperations() + self.ops = DatabaseOperations(self) self.client = DatabaseClient(self) self.creation = BaseDatabaseCreation(self) self.introspection = DatabaseIntrospection(self) diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 98233c7329..630b9805c3 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -23,7 +23,7 @@ if (version < (1,2,1) or (version[:3] == (1, 2, 1) and raise ImproperlyConfigured("MySQLdb-1.2.1p2 or newer is required; you have %s" % Database.__version__) from MySQLdb.converters import conversions -from MySQLdb.constants import FIELD_TYPE, FLAG, CLIENT +from MySQLdb.constants import FIELD_TYPE, CLIENT from django.db import utils from django.db.backends import * @@ -279,7 +279,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.server_version = None self.features = DatabaseFeatures(self) - self.ops = DatabaseOperations() + self.ops = DatabaseOperations(self) self.client = DatabaseClient(self) self.creation = DatabaseCreation(self) self.introspection = DatabaseIntrospection(self) diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index b327e453e0..a477c8d974 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -84,8 +84,8 @@ class DatabaseOperations(BaseDatabaseOperations): def autoinc_sql(self, table, column): # To simulate auto-incrementing primary keys in Oracle, we have to # create a sequence and a trigger. - sq_name = get_sequence_name(table) - tr_name = get_trigger_name(table) + sq_name = self._get_sequence_name(table) + tr_name = self._get_trigger_name(table) tbl_name = self.quote_name(table) col_name = self.quote_name(column) sequence_sql = """ @@ -197,7 +197,7 @@ WHEN (new.%(col_name)s IS NULL) return " DEFERRABLE INITIALLY DEFERRED" def drop_sequence_sql(self, table): - return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table)) + return "DROP SEQUENCE %s;" % self.quote_name(self._get_sequence_name(table)) def fetch_returned_insert_id(self, cursor): return long(cursor._insert_id_var.getvalue()) @@ -209,7 +209,7 @@ WHEN (new.%(col_name)s IS NULL) return "%s" def last_insert_id(self, cursor, table_name, pk_name): - sq_name = get_sequence_name(table_name) + sq_name = self._get_sequence_name(table_name) cursor.execute('SELECT "%s".currval FROM dual' % sq_name) return cursor.fetchone()[0] @@ -285,7 +285,7 @@ WHEN (new.%(col_name)s IS NULL) # Since we've just deleted all the rows, running our sequence # ALTER code will reset the sequence to 0. for sequence_info in sequences: - sequence_name = get_sequence_name(sequence_info['table']) + sequence_name = self._get_sequence_name(sequence_info['table']) table_name = self.quote_name(sequence_info['table']) column_name = self.quote_name(sequence_info['column'] or 'id') query = _get_sequence_reset_sql() % {'sequence': sequence_name, @@ -304,7 +304,7 @@ WHEN (new.%(col_name)s IS NULL) for f in model._meta.local_fields: if isinstance(f, models.AutoField): table_name = self.quote_name(model._meta.db_table) - sequence_name = get_sequence_name(model._meta.db_table) + sequence_name = self._get_sequence_name(model._meta.db_table) column_name = self.quote_name(f.column) output.append(query % {'sequence': sequence_name, 'table': table_name, @@ -315,7 +315,7 @@ WHEN (new.%(col_name)s IS NULL) for f in model._meta.many_to_many: if not f.rel.through: table_name = self.quote_name(f.m2m_db_table()) - sequence_name = get_sequence_name(f.m2m_db_table()) + sequence_name = self._get_sequence_name(f.m2m_db_table()) column_name = self.quote_name('id') output.append(query % {'sequence': sequence_name, 'table': table_name, @@ -365,6 +365,14 @@ WHEN (new.%(col_name)s IS NULL) raise NotImplementedError("Bit-wise or is not supported in Oracle.") return super(DatabaseOperations, self).combine_expression(connector, sub_expressions) + def _get_sequence_name(self, table): + name_length = self.max_name_length() - 3 + return '%s_SQ' % util.truncate_name(table, name_length).upper() + + def _get_trigger_name(self, table): + name_length = self.max_name_length() - 3 + return '%s_TR' % util.truncate_name(table, name_length).upper() + class _UninitializedOperatorsDescriptor(object): @@ -415,7 +423,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.features = DatabaseFeatures(self) use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True) self.features.can_return_id_from_insert = use_returning_into - self.ops = DatabaseOperations() + self.ops = DatabaseOperations(self) self.client = DatabaseClient(self) self.creation = DatabaseCreation(self) self.introspection = DatabaseIntrospection(self) @@ -776,13 +784,3 @@ BEGIN END LOOP; END; /""" - - -def get_sequence_name(table): - name_length = DatabaseOperations().max_name_length() - 3 - return '%s_SQ' % util.truncate_name(table, name_length).upper() - - -def get_trigger_name(table): - name_length = DatabaseOperations().max_name_length() - 3 - return '%s_TR' % util.truncate_name(table, name_length).upper() diff --git a/django/db/backends/postgresql_psycopg2/operations.py b/django/db/backends/postgresql_psycopg2/operations.py index 537fa45981..4efdb8fa3e 100644 --- a/django/db/backends/postgresql_psycopg2/operations.py +++ b/django/db/backends/postgresql_psycopg2/operations.py @@ -5,9 +5,8 @@ from django.db.backends import BaseDatabaseOperations class DatabaseOperations(BaseDatabaseOperations): def __init__(self, connection): - super(DatabaseOperations, self).__init__() + super(DatabaseOperations, self).__init__(connection) self._postgres_version = None - self.connection = connection def _get_postgres_version(self): if self._postgres_version is None: diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index c7f8db0a6e..511e5e9635 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -88,10 +88,10 @@ class DatabaseOperations(BaseDatabaseOperations): # It would be more straightforward if we could use the sqlite strftime # function, but it does not allow for keeping six digits of fractional # second information, nor does it allow for formatting date and datetime - # values differently. So instead we register our own function that - # formats the datetime combined with the delta in a manner suitable + # values differently. So instead we register our own function that + # formats the datetime combined with the delta in a manner suitable # for comparisons. - return u'django_format_dtdelta(%s, "%s", "%d", "%d", "%d")' % (sql, + return u'django_format_dtdelta(%s, "%s", "%d", "%d", "%d")' % (sql, connector, timedelta.days, timedelta.seconds, timedelta.microseconds) def date_trunc_sql(self, lookup_type, field_name): @@ -179,7 +179,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): super(DatabaseWrapper, self).__init__(*args, **kwargs) self.features = DatabaseFeatures(self) - self.ops = DatabaseOperations() + self.ops = DatabaseOperations(self) self.client = DatabaseClient(self) self.creation = DatabaseCreation(self) self.introspection = DatabaseIntrospection(self) diff --git a/tests/regressiontests/backends/tests.py b/tests/regressiontests/backends/tests.py index 0b6c568297..a6bac5a009 100644 --- a/tests/regressiontests/backends/tests.py +++ b/tests/regressiontests/backends/tests.py @@ -232,6 +232,12 @@ class BackendTestCase(TestCase): self.assertEqual(list(cursor.fetchmany(2)), [(u'Jane', u'Doe'), (u'John', u'Doe')]) self.assertEqual(list(cursor.fetchall()), [(u'Mary', u'Agnelline'), (u'Peter', u'Parker')]) + def test_database_operations_helper_class(self): + # Ticket #13630 + self.assertTrue(hasattr(connection, 'ops')) + self.assertTrue(hasattr(connection.ops, 'connection')) + self.assertEqual(connection, connection.ops.connection) + # We don't make these tests conditional because that means we would need to # check and differentiate between: