diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 870dc10353..0d6b4e3d00 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -233,7 +233,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): # If database is in memory, closing the connection destroys the # database. To prevent accidental data loss, ignore close requests on # an in-memory db. - if not self.is_in_memory_db(self.settings_dict['NAME']): + if not self.is_in_memory_db(): BaseDatabaseWrapper.close(self) def _savepoint_allowed(self): @@ -319,8 +319,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): """ self.cursor().execute("BEGIN") - def is_in_memory_db(self, name): - return name == ":memory:" or "mode=memory" in force_text(name) + def is_in_memory_db(self): + return self.creation.is_in_memory_db(self.settings_dict['NAME']) FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s') diff --git a/django/db/backends/sqlite3/creation.py b/django/db/backends/sqlite3/creation.py index 26ad0d4f95..33661e1357 100644 --- a/django/db/backends/sqlite3/creation.py +++ b/django/db/backends/sqlite3/creation.py @@ -4,11 +4,16 @@ import sys from django.core.exceptions import ImproperlyConfigured from django.db.backends.base.creation import BaseDatabaseCreation +from django.utils.encoding import force_text from django.utils.six.moves import input class DatabaseCreation(BaseDatabaseCreation): + @staticmethod + def is_in_memory_db(database_name): + return database_name == ':memory:' or 'mode=memory' in force_text(database_name) + def _get_test_db_name(self): test_database_name = self.connection.settings_dict['TEST']['NAME'] can_share_in_memory_db = self.connection.features.can_share_in_memory_db @@ -30,7 +35,7 @@ class DatabaseCreation(BaseDatabaseCreation): if keepdb: return test_database_name - if not self.connection.is_in_memory_db(test_database_name): + if not self.is_in_memory_db(test_database_name): # Erase the old test database if verbosity >= 1: print("Destroying old test database for alias %s..." % ( @@ -56,7 +61,7 @@ class DatabaseCreation(BaseDatabaseCreation): def get_test_db_clone_settings(self, number): orig_settings_dict = self.connection.settings_dict source_database_name = orig_settings_dict['NAME'] - if self.connection.is_in_memory_db(source_database_name): + if self.is_in_memory_db(source_database_name): return orig_settings_dict else: new_settings_dict = orig_settings_dict.copy() @@ -68,7 +73,7 @@ class DatabaseCreation(BaseDatabaseCreation): source_database_name = self.connection.settings_dict['NAME'] target_database_name = self.get_test_db_clone_settings(number)['NAME'] # Forking automatically makes a copy of an in-memory database. - if not self.connection.is_in_memory_db(source_database_name): + if not self.is_in_memory_db(source_database_name): # Erase the old test database if os.access(target_database_name, os.F_OK): if keepdb: @@ -89,7 +94,7 @@ class DatabaseCreation(BaseDatabaseCreation): sys.exit(2) def _destroy_test_db(self, test_database_name, verbosity): - if test_database_name and not self.connection.is_in_memory_db(test_database_name): + if test_database_name and not self.is_in_memory_db(test_database_name): # Remove the SQLite database file os.remove(test_database_name) @@ -103,6 +108,6 @@ class DatabaseCreation(BaseDatabaseCreation): """ test_database_name = self._get_test_db_name() sig = [self.connection.settings_dict['NAME']] - if self.connection.is_in_memory_db(test_database_name): + if self.is_in_memory_db(test_database_name): sig.append(self.connection.alias) return tuple(sig) diff --git a/django/test/testcases.py b/django/test/testcases.py index 2220ec7f52..399c8adf57 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -1299,7 +1299,7 @@ class LiveServerTestCase(TransactionTestCase): for conn in connections.all(): # If using in-memory sqlite databases, pass the connections to # the server thread. - if conn.vendor == 'sqlite' and conn.is_in_memory_db(conn.settings_dict['NAME']): + if conn.vendor == 'sqlite' and conn.is_in_memory_db(): # Explicitly enable thread-shareability for this connection conn.allow_thread_sharing = True connections_override[conn.alias] = conn @@ -1339,7 +1339,7 @@ class LiveServerTestCase(TransactionTestCase): # Restore sqlite in-memory database connections' non-shareability for conn in connections.all(): - if conn.vendor == 'sqlite' and conn.is_in_memory_db(conn.settings_dict['NAME']): + if conn.vendor == 'sqlite' and conn.is_in_memory_db(): conn.allow_thread_sharing = False @classmethod