1
0
mirror of https://github.com/django/django.git synced 2025-10-26 15:16:09 +00:00

Ensure cursors are closed when no longer needed.

This commit touchs various parts of the code base and test framework. Any
found usage of opening a cursor for the sake of initializing a connection
has been replaced with 'ensure_connection()'.
This commit is contained in:
Michael Manfre
2014-01-09 10:05:15 -05:00
parent 0837eacc4e
commit 3ffeb93186
31 changed files with 657 additions and 615 deletions

View File

@@ -11,10 +11,10 @@ class PostGISCreation(DatabaseCreation):
@cached_property @cached_property
def template_postgis(self): def template_postgis(self):
template_postgis = getattr(settings, 'POSTGIS_TEMPLATE', 'template_postgis') template_postgis = getattr(settings, 'POSTGIS_TEMPLATE', 'template_postgis')
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,)) cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,))
if cursor.fetchone(): if cursor.fetchone():
return template_postgis return template_postgis
return None return None
def sql_indexes_for_field(self, model, f, style): def sql_indexes_for_field(self, model, f, style):
@@ -88,8 +88,8 @@ class PostGISCreation(DatabaseCreation):
# Connect to the test database in order to create the postgis extension # Connect to the test database in order to create the postgis extension
self.connection.close() self.connection.close()
self.connection.settings_dict["NAME"] = test_database_name self.connection.settings_dict["NAME"] = test_database_name
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis") cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
cursor.connection.commit() cursor.connection.commit()
return test_database_name return test_database_name

View File

@@ -55,9 +55,8 @@ class SpatiaLiteCreation(DatabaseCreation):
call_command('createcachetable', database=self.connection.alias) call_command('createcachetable', database=self.connection.alias)
# Get a cursor (even though we don't need one yet). This has # Ensure a connection for the side effect of initializing the test database.
# the side effect of initializing the test database. self.connection.ensure_connection()
self.connection.cursor()
return test_database_name return test_database_name

View File

@@ -33,9 +33,9 @@ def create_default_site(app_config, verbosity=2, interactive=True, db=DEFAULT_DB
if sequence_sql: if sequence_sql:
if verbosity >= 2: if verbosity >= 2:
print("Resetting sequence") print("Resetting sequence")
cursor = connections[db].cursor() with connections[db].cursor() as cursor:
for command in sequence_sql: for command in sequence_sql:
cursor.execute(command) cursor.execute(command)
Site.objects.clear_cache() Site.objects.clear_cache()

View File

@@ -59,11 +59,11 @@ class DatabaseCache(BaseDatabaseCache):
self.validate_key(key) self.validate_key(key)
db = router.db_for_read(self.cache_model_class) db = router.db_for_read(self.cache_model_class)
table = connections[db].ops.quote_name(self._table) table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
cursor.execute("SELECT cache_key, value, expires FROM %s " with connections[db].cursor() as cursor:
"WHERE cache_key = %%s" % table, [key]) cursor.execute("SELECT cache_key, value, expires FROM %s "
row = cursor.fetchone() "WHERE cache_key = %%s" % table, [key])
row = cursor.fetchone()
if row is None: if row is None:
return default return default
now = timezone.now() now = timezone.now()
@@ -75,9 +75,9 @@ class DatabaseCache(BaseDatabaseCache):
expires = typecast_timestamp(str(expires)) expires = typecast_timestamp(str(expires))
if expires < now: if expires < now:
db = router.db_for_write(self.cache_model_class) db = router.db_for_write(self.cache_model_class)
cursor = connections[db].cursor() with connections[db].cursor() as cursor:
cursor.execute("DELETE FROM %s " cursor.execute("DELETE FROM %s "
"WHERE cache_key = %%s" % table, [key]) "WHERE cache_key = %%s" % table, [key])
return default return default
value = connections[db].ops.process_clob(row[1]) value = connections[db].ops.process_clob(row[1])
return pickle.loads(base64.b64decode(force_bytes(value))) return pickle.loads(base64.b64decode(force_bytes(value)))
@@ -96,55 +96,55 @@ class DatabaseCache(BaseDatabaseCache):
timeout = self.get_backend_timeout(timeout) timeout = self.get_backend_timeout(timeout)
db = router.db_for_write(self.cache_model_class) db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table) table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
cursor.execute("SELECT COUNT(*) FROM %s" % table) with connections[db].cursor() as cursor:
num = cursor.fetchone()[0] cursor.execute("SELECT COUNT(*) FROM %s" % table)
now = timezone.now() num = cursor.fetchone()[0]
now = now.replace(microsecond=0) now = timezone.now()
if timeout is None: now = now.replace(microsecond=0)
exp = datetime.max if timeout is None:
elif settings.USE_TZ: exp = datetime.max
exp = datetime.utcfromtimestamp(timeout) elif settings.USE_TZ:
else: exp = datetime.utcfromtimestamp(timeout)
exp = datetime.fromtimestamp(timeout) else:
exp = exp.replace(microsecond=0) exp = datetime.fromtimestamp(timeout)
if num > self._max_entries: exp = exp.replace(microsecond=0)
self._cull(db, cursor, now) if num > self._max_entries:
pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL) self._cull(db, cursor, now)
b64encoded = base64.b64encode(pickled) pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
# The DB column is expecting a string, so make sure the value is a b64encoded = base64.b64encode(pickled)
# string, not bytes. Refs #19274. # The DB column is expecting a string, so make sure the value is a
if six.PY3: # string, not bytes. Refs #19274.
b64encoded = b64encoded.decode('latin1') if six.PY3:
try: b64encoded = b64encoded.decode('latin1')
# Note: typecasting for datetimes is needed by some 3rd party try:
# database backends. All core backends work without typecasting, # Note: typecasting for datetimes is needed by some 3rd party
# so be careful about changes here - test suite will NOT pick # database backends. All core backends work without typecasting,
# regressions. # so be careful about changes here - test suite will NOT pick
with transaction.atomic(using=db): # regressions.
cursor.execute("SELECT cache_key, expires FROM %s " with transaction.atomic(using=db):
"WHERE cache_key = %%s" % table, [key]) cursor.execute("SELECT cache_key, expires FROM %s "
result = cursor.fetchone() "WHERE cache_key = %%s" % table, [key])
if result: result = cursor.fetchone()
current_expires = result[1] if result:
if (connections[db].features.needs_datetime_string_cast and not current_expires = result[1]
isinstance(current_expires, datetime)): if (connections[db].features.needs_datetime_string_cast and not
current_expires = typecast_timestamp(str(current_expires)) isinstance(current_expires, datetime)):
exp = connections[db].ops.value_to_db_datetime(exp) current_expires = typecast_timestamp(str(current_expires))
if result and (mode == 'set' or (mode == 'add' and current_expires < now)): exp = connections[db].ops.value_to_db_datetime(exp)
cursor.execute("UPDATE %s SET value = %%s, expires = %%s " if result and (mode == 'set' or (mode == 'add' and current_expires < now)):
"WHERE cache_key = %%s" % table, cursor.execute("UPDATE %s SET value = %%s, expires = %%s "
[b64encoded, exp, key]) "WHERE cache_key = %%s" % table,
else: [b64encoded, exp, key])
cursor.execute("INSERT INTO %s (cache_key, value, expires) " else:
"VALUES (%%s, %%s, %%s)" % table, cursor.execute("INSERT INTO %s (cache_key, value, expires) "
[key, b64encoded, exp]) "VALUES (%%s, %%s, %%s)" % table,
except DatabaseError: [key, b64encoded, exp])
# To be threadsafe, updates/inserts are allowed to fail silently except DatabaseError:
return False # To be threadsafe, updates/inserts are allowed to fail silently
else: return False
return True else:
return True
def delete(self, key, version=None): def delete(self, key, version=None):
key = self.make_key(key, version=version) key = self.make_key(key, version=version)
@@ -152,9 +152,9 @@ class DatabaseCache(BaseDatabaseCache):
db = router.db_for_write(self.cache_model_class) db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table) table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key]) with connections[db].cursor() as cursor:
cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key])
def has_key(self, key, version=None): def has_key(self, key, version=None):
key = self.make_key(key, version=version) key = self.make_key(key, version=version)
@@ -162,17 +162,18 @@ class DatabaseCache(BaseDatabaseCache):
db = router.db_for_read(self.cache_model_class) db = router.db_for_read(self.cache_model_class)
table = connections[db].ops.quote_name(self._table) table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor()
if settings.USE_TZ: if settings.USE_TZ:
now = datetime.utcnow() now = datetime.utcnow()
else: else:
now = datetime.now() now = datetime.now()
now = now.replace(microsecond=0) now = now.replace(microsecond=0)
cursor.execute("SELECT cache_key FROM %s "
"WHERE cache_key = %%s and expires > %%s" % table, with connections[db].cursor() as cursor:
[key, connections[db].ops.value_to_db_datetime(now)]) cursor.execute("SELECT cache_key FROM %s "
return cursor.fetchone() is not None "WHERE cache_key = %%s and expires > %%s" % table,
[key, connections[db].ops.value_to_db_datetime(now)])
return cursor.fetchone() is not None
def _cull(self, db, cursor, now): def _cull(self, db, cursor, now):
if self._cull_frequency == 0: if self._cull_frequency == 0:
@@ -197,8 +198,8 @@ class DatabaseCache(BaseDatabaseCache):
def clear(self): def clear(self):
db = router.db_for_write(self.cache_model_class) db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table) table = connections[db].ops.quote_name(self._table)
cursor = connections[db].cursor() with connections[db].cursor() as cursor:
cursor.execute('DELETE FROM %s' % table) cursor.execute('DELETE FROM %s' % table)
# For backwards compatibility # For backwards compatibility

View File

@@ -72,14 +72,14 @@ class Command(BaseCommand):
full_statement.append(' %s%s' % (line, ',' if i < len(table_output) - 1 else '')) full_statement.append(' %s%s' % (line, ',' if i < len(table_output) - 1 else ''))
full_statement.append(');') full_statement.append(');')
with transaction.commit_on_success_unless_managed(): with transaction.commit_on_success_unless_managed():
curs = connection.cursor() with connection.cursor() as curs:
try: try:
curs.execute("\n".join(full_statement)) curs.execute("\n".join(full_statement))
except DatabaseError as e: except DatabaseError as e:
raise CommandError( raise CommandError(
"Cache table '%s' could not be created.\nThe error was: %s." % "Cache table '%s' could not be created.\nThe error was: %s." %
(tablename, force_text(e))) (tablename, force_text(e)))
for statement in index_output: for statement in index_output:
curs.execute(statement) curs.execute(statement)
if self.verbosity > 1: if self.verbosity > 1:
self.stdout.write("Cache table '%s' created." % tablename) self.stdout.write("Cache table '%s' created." % tablename)

View File

@@ -64,9 +64,9 @@ Are you sure you want to do this?
if confirm == 'yes': if confirm == 'yes':
try: try:
with transaction.commit_on_success_unless_managed(): with transaction.commit_on_success_unless_managed():
cursor = connection.cursor() with connection.cursor() as cursor:
for sql in sql_list: for sql in sql_list:
cursor.execute(sql) cursor.execute(sql)
except Exception as e: except Exception as e:
new_msg = ( new_msg = (
"Database %s couldn't be flushed. Possible reasons:\n" "Database %s couldn't be flushed. Possible reasons:\n"

View File

@@ -37,108 +37,108 @@ class Command(NoArgsCommand):
table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '') table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
strip_prefix = lambda s: s[1:] if s.startswith("u'") else s strip_prefix = lambda s: s[1:] if s.startswith("u'") else s
cursor = connection.cursor() with connection.cursor() as cursor:
yield "# This is an auto-generated Django model module." yield "# This is an auto-generated Django model module."
yield "# You'll have to do the following manually to clean this up:" yield "# You'll have to do the following manually to clean this up:"
yield "# * Rearrange models' order" yield "# * Rearrange models' order"
yield "# * Make sure each model has one field with primary_key=True" yield "# * Make sure each model has one field with primary_key=True"
yield "# * Remove `managed = False` lines for those models you wish to give write DB access" yield "# * Remove `managed = False` lines for those models you wish to give write DB access"
yield "# Feel free to rename the models, but don't rename db_table values or field names." yield "# Feel free to rename the models, but don't rename db_table values or field names."
yield "#" yield "#"
yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [app_label]'" yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [app_label]'"
yield "# into your database." yield "# into your database."
yield "from __future__ import unicode_literals" yield "from __future__ import unicode_literals"
yield ''
yield 'from %s import models' % self.db_module
known_models = []
for table_name in connection.introspection.table_names(cursor):
if table_name_filter is not None and callable(table_name_filter):
if not table_name_filter(table_name):
continue
yield '' yield ''
yield '' yield 'from %s import models' % self.db_module
yield 'class %s(models.Model):' % table2model(table_name) known_models = []
known_models.append(table2model(table_name)) for table_name in connection.introspection.table_names(cursor):
try: if table_name_filter is not None and callable(table_name_filter):
relations = connection.introspection.get_relations(cursor, table_name) if not table_name_filter(table_name):
except NotImplementedError:
relations = {}
try:
indexes = connection.introspection.get_indexes(cursor, table_name)
except NotImplementedError:
indexes = {}
used_column_names = [] # Holds column names used in the table so far
for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
extra_params = OrderedDict() # Holds Field parameters such as 'db_column'.
column_name = row[0]
is_relation = i in relations
att_name, params, notes = self.normalize_col_name(
column_name, used_column_names, is_relation)
extra_params.update(params)
comment_notes.extend(notes)
used_column_names.append(att_name)
# Add primary_key and unique, if necessary.
if column_name in indexes:
if indexes[column_name]['primary_key']:
extra_params['primary_key'] = True
elif indexes[column_name]['unique']:
extra_params['unique'] = True
if is_relation:
rel_to = "self" if relations[i][1] == table_name else table2model(relations[i][1])
if rel_to in known_models:
field_type = 'ForeignKey(%s' % rel_to
else:
field_type = "ForeignKey('%s'" % rel_to
else:
# Calling `get_field_type` to get the field type string and any
# additional paramters and notes.
field_type, field_params, field_notes = self.get_field_type(connection, table_name, row)
extra_params.update(field_params)
comment_notes.extend(field_notes)
field_type += '('
# Don't output 'id = meta.AutoField(primary_key=True)', because
# that's assumed if it doesn't exist.
if att_name == 'id' and extra_params == {'primary_key': True}:
if field_type == 'AutoField(':
continue continue
elif field_type == 'IntegerField(' and not connection.features.can_introspect_autofield: yield ''
comment_notes.append('AutoField?') yield ''
yield 'class %s(models.Model):' % table2model(table_name)
known_models.append(table2model(table_name))
try:
relations = connection.introspection.get_relations(cursor, table_name)
except NotImplementedError:
relations = {}
try:
indexes = connection.introspection.get_indexes(cursor, table_name)
except NotImplementedError:
indexes = {}
used_column_names = [] # Holds column names used in the table so far
for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
extra_params = OrderedDict() # Holds Field parameters such as 'db_column'.
column_name = row[0]
is_relation = i in relations
# Add 'null' and 'blank', if the 'null_ok' flag was present in the att_name, params, notes = self.normalize_col_name(
# table description. column_name, used_column_names, is_relation)
if row[6]: # If it's NULL... extra_params.update(params)
if field_type == 'BooleanField(': comment_notes.extend(notes)
field_type = 'NullBooleanField('
used_column_names.append(att_name)
# Add primary_key and unique, if necessary.
if column_name in indexes:
if indexes[column_name]['primary_key']:
extra_params['primary_key'] = True
elif indexes[column_name]['unique']:
extra_params['unique'] = True
if is_relation:
rel_to = "self" if relations[i][1] == table_name else table2model(relations[i][1])
if rel_to in known_models:
field_type = 'ForeignKey(%s' % rel_to
else:
field_type = "ForeignKey('%s'" % rel_to
else: else:
extra_params['blank'] = True # Calling `get_field_type` to get the field type string and any
if not field_type in ('TextField(', 'CharField('): # additional paramters and notes.
extra_params['null'] = True field_type, field_params, field_notes = self.get_field_type(connection, table_name, row)
extra_params.update(field_params)
comment_notes.extend(field_notes)
field_desc = '%s = %s%s' % ( field_type += '('
att_name,
# Custom fields will have a dotted path # Don't output 'id = meta.AutoField(primary_key=True)', because
'' if '.' in field_type else 'models.', # that's assumed if it doesn't exist.
field_type, if att_name == 'id' and extra_params == {'primary_key': True}:
) if field_type == 'AutoField(':
if extra_params: continue
if not field_desc.endswith('('): elif field_type == 'IntegerField(' and not connection.features.can_introspect_autofield:
field_desc += ', ' comment_notes.append('AutoField?')
field_desc += ', '.join([
'%s=%s' % (k, strip_prefix(repr(v))) # Add 'null' and 'blank', if the 'null_ok' flag was present in the
for k, v in extra_params.items()]) # table description.
field_desc += ')' if row[6]: # If it's NULL...
if comment_notes: if field_type == 'BooleanField(':
field_desc += ' # ' + ' '.join(comment_notes) field_type = 'NullBooleanField('
yield ' %s' % field_desc else:
for meta_line in self.get_meta(table_name): extra_params['blank'] = True
yield meta_line if not field_type in ('TextField(', 'CharField('):
extra_params['null'] = True
field_desc = '%s = %s%s' % (
att_name,
# Custom fields will have a dotted path
'' if '.' in field_type else 'models.',
field_type,
)
if extra_params:
if not field_desc.endswith('('):
field_desc += ', '
field_desc += ', '.join([
'%s=%s' % (k, strip_prefix(repr(v)))
for k, v in extra_params.items()])
field_desc += ')'
if comment_notes:
field_desc += ' # ' + ' '.join(comment_notes)
yield ' %s' % field_desc
for meta_line in self.get_meta(table_name):
yield meta_line
def normalize_col_name(self, col_name, used_column_names, is_relation): def normalize_col_name(self, col_name, used_column_names, is_relation):
""" """

View File

@@ -100,10 +100,9 @@ class Command(BaseCommand):
if sequence_sql: if sequence_sql:
if self.verbosity >= 2: if self.verbosity >= 2:
self.stdout.write("Resetting sequences\n") self.stdout.write("Resetting sequences\n")
cursor = connection.cursor() with connection.cursor() as cursor:
for line in sequence_sql: for line in sequence_sql:
cursor.execute(line) cursor.execute(line)
cursor.close()
if self.verbosity >= 1: if self.verbosity >= 1:
if self.fixture_object_count == self.loaded_object_count: if self.fixture_object_count == self.loaded_object_count:

View File

@@ -171,105 +171,110 @@ class Command(BaseCommand):
"Runs the old syncdb-style operation on a list of app_labels." "Runs the old syncdb-style operation on a list of app_labels."
cursor = connection.cursor() cursor = connection.cursor()
# Get a list of already installed *models* so that references work right. try:
tables = connection.introspection.table_names() # Get a list of already installed *models* so that references work right.
seen_models = connection.introspection.installed_models(tables) tables = connection.introspection.table_names(cursor)
created_models = set() seen_models = connection.introspection.installed_models(tables)
pending_references = {} created_models = set()
pending_references = {}
# Build the manifest of apps and models that are to be synchronized # Build the manifest of apps and models that are to be synchronized
all_models = [ all_models = [
(app_config.label, (app_config.label,
router.get_migratable_models(app_config, connection.alias, include_auto_created=True)) router.get_migratable_models(app_config, connection.alias, include_auto_created=True))
for app_config in apps.get_app_configs() for app_config in apps.get_app_configs()
if app_config.models_module is not None and app_config.label in app_labels if app_config.models_module is not None and app_config.label in app_labels
] ]
def model_installed(model): def model_installed(model):
opts = model._meta opts = model._meta
converter = connection.introspection.table_name_converter converter = connection.introspection.table_name_converter
# Note that if a model is unmanaged we short-circuit and never try to install it # Note that if a model is unmanaged we short-circuit and never try to install it
return not ((converter(opts.db_table) in tables) or return not ((converter(opts.db_table) in tables) or
(opts.auto_created and converter(opts.auto_created._meta.db_table) in tables)) (opts.auto_created and converter(opts.auto_created._meta.db_table) in tables))
manifest = OrderedDict( manifest = OrderedDict(
(app_name, list(filter(model_installed, model_list))) (app_name, list(filter(model_installed, model_list)))
for app_name, model_list in all_models for app_name, model_list in all_models
) )
create_models = set(itertools.chain(*manifest.values())) create_models = set(itertools.chain(*manifest.values()))
emit_pre_migrate_signal(create_models, self.verbosity, self.interactive, connection.alias) emit_pre_migrate_signal(create_models, self.verbosity, self.interactive, connection.alias)
# Create the tables for each model # Create the tables for each model
if self.verbosity >= 1: if self.verbosity >= 1:
self.stdout.write(" Creating tables...\n") self.stdout.write(" Creating tables...\n")
with transaction.atomic(using=connection.alias, savepoint=False): with transaction.atomic(using=connection.alias, savepoint=False):
for app_name, model_list in manifest.items(): for app_name, model_list in manifest.items():
for model in model_list: for model in model_list:
# Create the model's database table, if it doesn't already exist. # Create the model's database table, if it doesn't already exist.
if self.verbosity >= 3: if self.verbosity >= 3:
self.stdout.write(" Processing %s.%s model\n" % (app_name, model._meta.object_name)) self.stdout.write(" Processing %s.%s model\n" % (app_name, model._meta.object_name))
sql, references = connection.creation.sql_create_model(model, no_style(), seen_models) sql, references = connection.creation.sql_create_model(model, no_style(), seen_models)
seen_models.add(model) seen_models.add(model)
created_models.add(model) created_models.add(model)
for refto, refs in references.items(): for refto, refs in references.items():
pending_references.setdefault(refto, []).extend(refs) pending_references.setdefault(refto, []).extend(refs)
if refto in seen_models: if refto in seen_models:
sql.extend(connection.creation.sql_for_pending_references(refto, no_style(), pending_references)) sql.extend(connection.creation.sql_for_pending_references(refto, no_style(), pending_references))
sql.extend(connection.creation.sql_for_pending_references(model, no_style(), pending_references)) sql.extend(connection.creation.sql_for_pending_references(model, no_style(), pending_references))
if self.verbosity >= 1 and sql: if self.verbosity >= 1 and sql:
self.stdout.write(" Creating table %s\n" % model._meta.db_table) self.stdout.write(" Creating table %s\n" % model._meta.db_table)
for statement in sql: for statement in sql:
cursor.execute(statement) cursor.execute(statement)
tables.append(connection.introspection.table_name_converter(model._meta.db_table)) tables.append(connection.introspection.table_name_converter(model._meta.db_table))
# We force a commit here, as that was the previous behaviour. # We force a commit here, as that was the previous behaviour.
# If you can prove we don't need this, remove it. # If you can prove we don't need this, remove it.
transaction.set_dirty(using=connection.alias) transaction.set_dirty(using=connection.alias)
finally:
cursor.close()
# The connection may have been closed by a syncdb handler. # The connection may have been closed by a syncdb handler.
cursor = connection.cursor() cursor = connection.cursor()
try:
# Install custom SQL for the app (but only if this
# is a model we've just created)
if self.verbosity >= 1:
self.stdout.write(" Installing custom SQL...\n")
for app_name, model_list in manifest.items():
for model in model_list:
if model in created_models:
custom_sql = custom_sql_for_model(model, no_style(), connection)
if custom_sql:
if self.verbosity >= 2:
self.stdout.write(" Installing custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
try:
with transaction.commit_on_success_unless_managed(using=connection.alias):
for sql in custom_sql:
cursor.execute(sql)
except Exception as e:
self.stderr.write(" Failed to install custom SQL for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
if self.show_traceback:
traceback.print_exc()
else:
if self.verbosity >= 3:
self.stdout.write(" No custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
# Install custom SQL for the app (but only if this if self.verbosity >= 1:
# is a model we've just created) self.stdout.write(" Installing indexes...\n")
if self.verbosity >= 1:
self.stdout.write(" Installing custom SQL...\n")
for app_name, model_list in manifest.items():
for model in model_list:
if model in created_models:
custom_sql = custom_sql_for_model(model, no_style(), connection)
if custom_sql:
if self.verbosity >= 2:
self.stdout.write(" Installing custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
try:
with transaction.commit_on_success_unless_managed(using=connection.alias):
for sql in custom_sql:
cursor.execute(sql)
except Exception as e:
self.stderr.write(" Failed to install custom SQL for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
if self.show_traceback:
traceback.print_exc()
else:
if self.verbosity >= 3:
self.stdout.write(" No custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
if self.verbosity >= 1: # Install SQL indices for all newly created models
self.stdout.write(" Installing indexes...\n") for app_name, model_list in manifest.items():
for model in model_list:
# Install SQL indices for all newly created models if model in created_models:
for app_name, model_list in manifest.items(): index_sql = connection.creation.sql_indexes_for_model(model, no_style())
for model in model_list: if index_sql:
if model in created_models: if self.verbosity >= 2:
index_sql = connection.creation.sql_indexes_for_model(model, no_style()) self.stdout.write(" Installing index for %s.%s model\n" % (app_name, model._meta.object_name))
if index_sql: try:
if self.verbosity >= 2: with transaction.commit_on_success_unless_managed(using=connection.alias):
self.stdout.write(" Installing index for %s.%s model\n" % (app_name, model._meta.object_name)) for sql in index_sql:
try: cursor.execute(sql)
with transaction.commit_on_success_unless_managed(using=connection.alias): except Exception as e:
for sql in index_sql: self.stderr.write(" Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
cursor.execute(sql) finally:
except Exception as e: cursor.close()
self.stderr.write(" Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
# Load initial_data fixtures (unless that has been disabled) # Load initial_data fixtures (unless that has been disabled)
if self.load_initial_data: if self.load_initial_data:

View File

@@ -67,38 +67,39 @@ def sql_delete(app_config, style, connection):
except Exception: except Exception:
cursor = None cursor = None
# Figure out which tables already exist try:
if cursor: # Figure out which tables already exist
table_names = connection.introspection.table_names(cursor) if cursor:
else: table_names = connection.introspection.table_names(cursor)
table_names = [] else:
table_names = []
output = [] output = []
# Output DROP TABLE statements for standard application tables. # Output DROP TABLE statements for standard application tables.
to_delete = set() to_delete = set()
references_to_delete = {} references_to_delete = {}
app_models = router.get_migratable_models(app_config, connection.alias, include_auto_created=True) app_models = router.get_migratable_models(app_config, connection.alias, include_auto_created=True)
for model in app_models: for model in app_models:
if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names: if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names:
# The table exists, so it needs to be dropped # The table exists, so it needs to be dropped
opts = model._meta opts = model._meta
for f in opts.local_fields: for f in opts.local_fields:
if f.rel and f.rel.to not in to_delete: if f.rel and f.rel.to not in to_delete:
references_to_delete.setdefault(f.rel.to, []).append((model, f)) references_to_delete.setdefault(f.rel.to, []).append((model, f))
to_delete.add(model) to_delete.add(model)
for model in app_models: for model in app_models:
if connection.introspection.table_name_converter(model._meta.db_table) in table_names: if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style)) output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
finally:
# Close database connection explicitly, in case this output is being piped # Close database connection explicitly, in case this output is being piped
# directly into a database client, to avoid locking issues. # directly into a database client, to avoid locking issues.
if cursor: if cursor:
cursor.close() cursor.close()
connection.close() connection.close()
return output[::-1] # Reverse it, to deal with table dependencies. return output[::-1] # Reverse it, to deal with table dependencies.

View File

@@ -194,13 +194,16 @@ class BaseDatabaseWrapper(object):
##### Backend-specific savepoint management methods ##### ##### Backend-specific savepoint management methods #####
def _savepoint(self, sid): def _savepoint(self, sid):
self.cursor().execute(self.ops.savepoint_create_sql(sid)) with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_create_sql(sid))
def _savepoint_rollback(self, sid): def _savepoint_rollback(self, sid):
self.cursor().execute(self.ops.savepoint_rollback_sql(sid)) with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_rollback_sql(sid))
def _savepoint_commit(self, sid): def _savepoint_commit(self, sid):
self.cursor().execute(self.ops.savepoint_commit_sql(sid)) with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_commit_sql(sid))
def _savepoint_allowed(self): def _savepoint_allowed(self):
# Savepoints cannot be created outside a transaction # Savepoints cannot be created outside a transaction
@@ -688,15 +691,15 @@ class BaseDatabaseFeatures(object):
# otherwise autocommit will cause the confimation to # otherwise autocommit will cause the confimation to
# fail. # fail.
self.connection.enter_transaction_management() self.connection.enter_transaction_management()
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)') cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
self.connection.commit() self.connection.commit()
cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)') cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
self.connection.rollback() self.connection.rollback()
cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST') cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
count, = cursor.fetchone() count, = cursor.fetchone()
cursor.execute('DROP TABLE ROLLBACK_TEST') cursor.execute('DROP TABLE ROLLBACK_TEST')
self.connection.commit() self.connection.commit()
finally: finally:
self.connection.leave_transaction_management() self.connection.leave_transaction_management()
return count == 0 return count == 0
@@ -1253,7 +1256,8 @@ class BaseDatabaseIntrospection(object):
in sorting order between databases. in sorting order between databases.
""" """
if cursor is None: if cursor is None:
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
return sorted(self.get_table_list(cursor))
return sorted(self.get_table_list(cursor)) return sorted(self.get_table_list(cursor))
def get_table_list(self, cursor): def get_table_list(self, cursor):

View File

@@ -378,9 +378,8 @@ class BaseDatabaseCreation(object):
call_command('createcachetable', database=self.connection.alias) call_command('createcachetable', database=self.connection.alias)
# Get a cursor (even though we don't need one yet). This has # Ensure a connection for the side effect of initializing the test database.
# the side effect of initializing the test database. self.connection.ensure_connection()
self.connection.cursor()
return test_database_name return test_database_name
@@ -406,34 +405,34 @@ class BaseDatabaseCreation(object):
qn = self.connection.ops.quote_name qn = self.connection.ops.quote_name
# Create the test database and connect to it. # Create the test database and connect to it.
cursor = self._nodb_connection.cursor() with self._nodb_connection.cursor() as cursor:
try: try:
cursor.execute( cursor.execute(
"CREATE DATABASE %s %s" % (qn(test_database_name), suffix)) "CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
except Exception as e: except Exception as e:
sys.stderr.write( sys.stderr.write(
"Got an error creating the test database: %s\n" % e) "Got an error creating the test database: %s\n" % e)
if not autoclobber: if not autoclobber:
confirm = input( confirm = input(
"Type 'yes' if you would like to try deleting the test " "Type 'yes' if you would like to try deleting the test "
"database '%s', or 'no' to cancel: " % test_database_name) "database '%s', or 'no' to cancel: " % test_database_name)
if autoclobber or confirm == 'yes': if autoclobber or confirm == 'yes':
try: try:
if verbosity >= 1: if verbosity >= 1:
print("Destroying old test database '%s'..." print("Destroying old test database '%s'..."
% self.connection.alias) % self.connection.alias)
cursor.execute( cursor.execute(
"DROP DATABASE %s" % qn(test_database_name)) "DROP DATABASE %s" % qn(test_database_name))
cursor.execute( cursor.execute(
"CREATE DATABASE %s %s" % (qn(test_database_name), "CREATE DATABASE %s %s" % (qn(test_database_name),
suffix)) suffix))
except Exception as e: except Exception as e:
sys.stderr.write( sys.stderr.write(
"Got an error recreating the test database: %s\n" % e) "Got an error recreating the test database: %s\n" % e)
sys.exit(2) sys.exit(2)
else: else:
print("Tests cancelled.") print("Tests cancelled.")
sys.exit(1) sys.exit(1)
return test_database_name return test_database_name
@@ -461,11 +460,11 @@ class BaseDatabaseCreation(object):
# ourselves. Connect to the previous database (not the test database) # ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being # to do so, because it's not allowed to delete a database while being
# connected to it. # connected to it.
cursor = self._nodb_connection.cursor() with self._nodb_connection.cursor() as cursor:
# Wait to avoid "database is being accessed by other users" errors. # Wait to avoid "database is being accessed by other users" errors.
time.sleep(1) time.sleep(1)
cursor.execute("DROP DATABASE %s" cursor.execute("DROP DATABASE %s"
% self.connection.ops.quote_name(test_database_name)) % self.connection.ops.quote_name(test_database_name))
def set_autocommit(self): def set_autocommit(self):
""" """

View File

@@ -180,15 +180,15 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property @cached_property
def _mysql_storage_engine(self): def _mysql_storage_engine(self):
"Internal method used in Django tests. Don't rely on this from your code" "Internal method used in Django tests. Don't rely on this from your code"
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)') cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)')
# This command is MySQL specific; the second column # This command is MySQL specific; the second column
# will tell you the default table type of the created # will tell you the default table type of the created
# table. Since all Django's test tables will have the same # table. Since all Django's test tables will have the same
# table type, that's enough to evaluate the feature. # table type, that's enough to evaluate the feature.
cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'") cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'")
result = cursor.fetchone() result = cursor.fetchone()
cursor.execute('DROP TABLE INTROSPECT_TEST') cursor.execute('DROP TABLE INTROSPECT_TEST')
return result[1] return result[1]
@cached_property @cached_property
@@ -207,9 +207,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return False return False
# Test if the time zone definitions are installed. # Test if the time zone definitions are installed.
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1") cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
return cursor.fetchone() is not None return cursor.fetchone() is not None
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
@@ -461,13 +461,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return conn return conn
def init_connection_state(self): def init_connection_state(self):
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
# SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
# on a recently-inserted row will return when the field is tested for # on a recently-inserted row will return when the field is tested for
# NULL. Disabling this value brings this aspect of MySQL in line with # NULL. Disabling this value brings this aspect of MySQL in line with
# SQL standards. # SQL standards.
cursor.execute('SET SQL_AUTO_IS_NULL = 0') cursor.execute('SET SQL_AUTO_IS_NULL = 0')
cursor.close()
def create_cursor(self): def create_cursor(self):
cursor = self.connection.cursor() cursor = self.connection.cursor()

View File

@@ -353,8 +353,8 @@ WHEN (new.%(col_name)s IS NULL)
def regex_lookup(self, lookup_type): def regex_lookup(self, lookup_type):
# If regex_lookup is called before it's been initialized, then create # If regex_lookup is called before it's been initialized, then create
# a cursor to initialize it and recur. # a cursor to initialize it and recur.
self.connection.cursor() with self.connection.cursor():
return self.connection.ops.regex_lookup(lookup_type) return self.connection.ops.regex_lookup(lookup_type)
def return_insert_id(self): def return_insert_id(self):
return "RETURNING %s INTO %%s", (InsertIdVar(),) return "RETURNING %s INTO %%s", (InsertIdVar(),)

View File

@@ -149,8 +149,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if conn_tz != tz: if conn_tz != tz:
cursor = self.connection.cursor() cursor = self.connection.cursor()
cursor.execute(self.ops.set_time_zone_sql(), [tz]) try:
cursor.close() cursor.execute(self.ops.set_time_zone_sql(), [tz])
finally:
cursor.close()
# Commit after setting the time zone (see #17062) # Commit after setting the time zone (see #17062)
if not self.get_autocommit(): if not self.get_autocommit():
self.connection.commit() self.connection.commit()

View File

@@ -39,6 +39,6 @@ def get_version(connection):
if hasattr(connection, 'server_version'): if hasattr(connection, 'server_version'):
return connection.server_version return connection.server_version
else: else:
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("SELECT version()") cursor.execute("SELECT version()")
return _parse_version(cursor.fetchone()[0]) return _parse_version(cursor.fetchone()[0])

View File

@@ -86,14 +86,13 @@ class BaseDatabaseSchemaEditor(object):
""" """
Executes the given SQL statement, with optional parameters. Executes the given SQL statement, with optional parameters.
""" """
# Get the cursor
cursor = self.connection.cursor()
# Log the command we're running, then run it # Log the command we're running, then run it
logger.debug("%s; (params %r)" % (sql, params)) logger.debug("%s; (params %r)" % (sql, params))
if self.collect_sql: if self.collect_sql:
self.collected_sql.append((sql % tuple(map(self.connection.ops.quote_parameter, params))) + ";") self.collected_sql.append((sql % tuple(map(self.connection.ops.quote_parameter, params))) + ";")
else: else:
cursor.execute(sql, params) with self.connection.cursor() as cursor:
cursor.execute(sql, params)
def quote_name(self, name): def quote_name(self, name):
return self.connection.ops.quote_name(name) return self.connection.ops.quote_name(name)
@@ -791,7 +790,8 @@ class BaseDatabaseSchemaEditor(object):
Returns all constraint names matching the columns and conditions Returns all constraint names matching the columns and conditions
""" """
column_names = list(column_names) if column_names else None column_names = list(column_names) if column_names else None
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table) with self.connection.cursor() as cursor:
constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table)
result = [] result = []
for name, infodict in constraints.items(): for name, infodict in constraints.items():
if column_names is None or column_names == infodict['columns']: if column_names is None or column_names == infodict['columns']:

View File

@@ -122,14 +122,14 @@ class DatabaseFeatures(BaseDatabaseFeatures):
rule out support for STDDEV. We need to manually check rule out support for STDDEV. We need to manually check
whether the call works. whether the call works.
""" """
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE STDDEV_TEST (X INT)') cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
try: try:
cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST') cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')
has_support = True has_support = True
except utils.DatabaseError: except utils.DatabaseError:
has_support = False has_support = False
cursor.execute('DROP TABLE STDDEV_TEST') cursor.execute('DROP TABLE STDDEV_TEST')
return has_support return has_support
@cached_property @cached_property

View File

@@ -1522,54 +1522,59 @@ class RawQuerySet(object):
query = iter(self.query) query = iter(self.query)
# Find out which columns are model's fields, and which ones should be try:
# annotated to the model. # Find out which columns are model's fields, and which ones should be
for pos, column in enumerate(self.columns): # annotated to the model.
if column in self.model_fields: for pos, column in enumerate(self.columns):
model_init_field_names[self.model_fields[column].attname] = pos if column in self.model_fields:
else: model_init_field_names[self.model_fields[column].attname] = pos
annotation_fields.append((column, pos)) else:
annotation_fields.append((column, pos))
# Find out which model's fields are not present in the query. # Find out which model's fields are not present in the query.
skip = set() skip = set()
for field in self.model._meta.fields:
if field.attname not in model_init_field_names:
skip.add(field.attname)
if skip:
if self.model._meta.pk.attname in skip:
raise InvalidQuery('Raw query must include the primary key')
model_cls = deferred_class_factory(self.model, skip)
else:
model_cls = self.model
# All model's fields are present in the query. So, it is possible
# to use *args based model instantation. For each field of the model,
# record the query column position matching that field.
model_init_field_pos = []
for field in self.model._meta.fields: for field in self.model._meta.fields:
model_init_field_pos.append(model_init_field_names[field.attname]) if field.attname not in model_init_field_names:
if need_resolv_columns: skip.add(field.attname)
fields = [self.model_fields.get(c, None) for c in self.columns]
# Begin looping through the query values.
for values in query:
if need_resolv_columns:
values = compiler.resolve_columns(values, fields)
# Associate fields to values
if skip: if skip:
model_init_kwargs = {} if self.model._meta.pk.attname in skip:
for attname, pos in six.iteritems(model_init_field_names): raise InvalidQuery('Raw query must include the primary key')
model_init_kwargs[attname] = values[pos] model_cls = deferred_class_factory(self.model, skip)
instance = model_cls(**model_init_kwargs)
else: else:
model_init_args = [values[pos] for pos in model_init_field_pos] model_cls = self.model
instance = model_cls(*model_init_args) # All model's fields are present in the query. So, it is possible
if annotation_fields: # to use *args based model instantation. For each field of the model,
for column, pos in annotation_fields: # record the query column position matching that field.
setattr(instance, column, values[pos]) model_init_field_pos = []
for field in self.model._meta.fields:
model_init_field_pos.append(model_init_field_names[field.attname])
if need_resolv_columns:
fields = [self.model_fields.get(c, None) for c in self.columns]
# Begin looping through the query values.
for values in query:
if need_resolv_columns:
values = compiler.resolve_columns(values, fields)
# Associate fields to values
if skip:
model_init_kwargs = {}
for attname, pos in six.iteritems(model_init_field_names):
model_init_kwargs[attname] = values[pos]
instance = model_cls(**model_init_kwargs)
else:
model_init_args = [values[pos] for pos in model_init_field_pos]
instance = model_cls(*model_init_args)
if annotation_fields:
for column, pos in annotation_fields:
setattr(instance, column, values[pos])
instance._state.db = db instance._state.db = db
instance._state.adding = False instance._state.adding = False
yield instance yield instance
finally:
# Done iterating the Query. If it has its own cursor, close it.
if hasattr(self.query, 'cursor') and self.query.cursor:
self.query.cursor.close()
def __repr__(self): def __repr__(self):
text = self.raw_query text = self.raw_query

View File

@@ -1,4 +1,5 @@
import datetime import datetime
import sys
from django.conf import settings from django.conf import settings
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
@@ -777,7 +778,7 @@ class SQLCompiler(object):
cursor = self.connection.cursor() cursor = self.connection.cursor()
try: try:
cursor.execute(sql, params) cursor.execute(sql, params)
except: except Exception:
cursor.close() cursor.close()
raise raise
@@ -908,15 +909,15 @@ class SQLInsertCompiler(SQLCompiler):
def execute_sql(self, return_id=False): def execute_sql(self, return_id=False):
assert not (return_id and len(self.query.objs) != 1) assert not (return_id and len(self.query.objs) != 1)
self.return_id = return_id self.return_id = return_id
cursor = self.connection.cursor() with self.connection.cursor() as cursor:
for sql, params in self.as_sql(): for sql, params in self.as_sql():
cursor.execute(sql, params) cursor.execute(sql, params)
if not (return_id and cursor): if not (return_id and cursor):
return return
if self.connection.features.can_return_id_from_insert: if self.connection.features.can_return_id_from_insert:
return self.connection.ops.fetch_returned_insert_id(cursor) return self.connection.ops.fetch_returned_insert_id(cursor)
return self.connection.ops.last_insert_id(cursor, return self.connection.ops.last_insert_id(cursor,
self.query.get_meta().db_table, self.query.get_meta().pk.column) self.query.get_meta().db_table, self.query.get_meta().pk.column)
class SQLDeleteCompiler(SQLCompiler): class SQLDeleteCompiler(SQLCompiler):

View File

@@ -59,9 +59,9 @@ class OracleChecks(unittest.TestCase):
# stored procedure through our cursor wrapper. # stored procedure through our cursor wrapper.
from django.db.backends.oracle.base import convert_unicode from django.db.backends.oracle.base import convert_unicode
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'), cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
[convert_unicode('_django_testing!')]) [convert_unicode('_django_testing!')])
@unittest.skipUnless(connection.vendor == 'oracle', @unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle cursor semantics") "No need to check Oracle cursor semantics")
@@ -70,31 +70,31 @@ class OracleChecks(unittest.TestCase):
# as query parameters. # as query parameters.
from django.db.backends.oracle.base import Database from django.db.backends.oracle.base import Database
cursor = connection.cursor() with connection.cursor() as cursor:
var = cursor.var(Database.STRING) var = cursor.var(Database.STRING)
cursor.execute("BEGIN %s := 'X'; END; ", [var]) cursor.execute("BEGIN %s := 'X'; END; ", [var])
self.assertEqual(var.getvalue(), 'X') self.assertEqual(var.getvalue(), 'X')
@unittest.skipUnless(connection.vendor == 'oracle', @unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle cursor semantics") "No need to check Oracle cursor semantics")
def test_long_string(self): def test_long_string(self):
# If the backend is Oracle, test that we can save a text longer # If the backend is Oracle, test that we can save a text longer
# than 4000 chars and read it properly # than 4000 chars and read it properly
c = connection.cursor() with connection.cursor() as cursor:
c.execute('CREATE TABLE ltext ("TEXT" NCLOB)') cursor.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
long_str = ''.join(six.text_type(x) for x in xrange(4000)) long_str = ''.join(six.text_type(x) for x in xrange(4000))
c.execute('INSERT INTO ltext VALUES (%s)', [long_str]) cursor.execute('INSERT INTO ltext VALUES (%s)', [long_str])
c.execute('SELECT text FROM ltext') cursor.execute('SELECT text FROM ltext')
row = c.fetchone() row = cursor.fetchone()
self.assertEqual(long_str, row[0].read()) self.assertEqual(long_str, row[0].read())
c.execute('DROP TABLE ltext') cursor.execute('DROP TABLE ltext')
@unittest.skipUnless(connection.vendor == 'oracle', @unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle connection semantics") "No need to check Oracle connection semantics")
def test_client_encoding(self): def test_client_encoding(self):
# If the backend is Oracle, test that the client encoding is set # If the backend is Oracle, test that the client encoding is set
# correctly. This was broken under Cygwin prior to r14781. # correctly. This was broken under Cygwin prior to r14781.
connection.cursor() # Ensure the connection is initialized. self.connection.ensure_connection()
self.assertEqual(connection.connection.encoding, "UTF-8") self.assertEqual(connection.connection.encoding, "UTF-8")
self.assertEqual(connection.connection.nencoding, "UTF-8") self.assertEqual(connection.connection.nencoding, "UTF-8")
@@ -103,12 +103,12 @@ class OracleChecks(unittest.TestCase):
def test_order_of_nls_parameters(self): def test_order_of_nls_parameters(self):
# an 'almost right' datetime should work with configured # an 'almost right' datetime should work with configured
# NLS parameters as per #18465. # NLS parameters as per #18465.
c = connection.cursor() with connection.cursor() as cursor:
query = "select 1 from dual where '1936-12-29 00:00' < sysdate" query = "select 1 from dual where '1936-12-29 00:00' < sysdate"
# Test that the query succeeds without errors - pre #18465 this # Test that the query succeeds without errors - pre #18465 this
# wasn't the case. # wasn't the case.
c.execute(query) cursor.execute(query)
self.assertEqual(c.fetchone()[0], 1) self.assertEqual(cursor.fetchone()[0], 1)
class SQLiteTests(TestCase): class SQLiteTests(TestCase):
@@ -328,6 +328,12 @@ class PostgresVersionTest(TestCase):
def fetchone(self): def fetchone(self):
return ["PostgreSQL 8.3"] return ["PostgreSQL 8.3"]
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
pass
class OlderConnectionMock(object): class OlderConnectionMock(object):
"Mock of psycopg2 (< 2.0.12) connection" "Mock of psycopg2 (< 2.0.12) connection"
def cursor(self): def cursor(self):

View File

@@ -896,10 +896,9 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase):
management.call_command('createcachetable', verbosity=0, interactive=False) management.call_command('createcachetable', verbosity=0, interactive=False)
def drop_table(self): def drop_table(self):
cursor = connection.cursor() with connection.cursor() as cursor:
table_name = connection.ops.quote_name('test cache table') table_name = connection.ops.quote_name('test cache table')
cursor.execute('DROP TABLE %s' % table_name) cursor.execute('DROP TABLE %s' % table_name)
cursor.close()
def test_zero_cull(self): def test_zero_cull(self):
self._perform_cull_test(caches['zero_cull'], 50, 18) self._perform_cull_test(caches['zero_cull'], 50, 18)

View File

@@ -30,11 +30,11 @@ class Article(models.Model):
database query for the sake of demonstration. database query for the sake of demonstration.
""" """
from django.db import connection from django.db import connection
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute(""" cursor.execute("""
SELECT id, headline, pub_date SELECT id, headline, pub_date
FROM custom_methods_article FROM custom_methods_article
WHERE pub_date = %s WHERE pub_date = %s
AND id != %s""", [connection.ops.value_to_db_date(self.pub_date), AND id != %s""", [connection.ops.value_to_db_date(self.pub_date),
self.id]) self.id])
return [self.__class__(*row) for row in cursor.fetchall()] return [self.__class__(*row) for row in cursor.fetchall()]

View File

@@ -28,9 +28,9 @@ class InitialSQLTests(TestCase):
connection = connections[DEFAULT_DB_ALIAS] connection = connections[DEFAULT_DB_ALIAS]
custom_sql = custom_sql_for_model(Simple, no_style(), connection) custom_sql = custom_sql_for_model(Simple, no_style(), connection)
self.assertEqual(len(custom_sql), 9) self.assertEqual(len(custom_sql), 9)
cursor = connection.cursor() with connection.cursor() as cursor:
for sql in custom_sql: for sql in custom_sql:
cursor.execute(sql) cursor.execute(sql)
self.assertEqual(Simple.objects.count(), 9) self.assertEqual(Simple.objects.count(), 9)
self.assertEqual( self.assertEqual(
Simple.objects.get(name__contains='placeholders').name, Simple.objects.get(name__contains='placeholders').name,

View File

@@ -23,17 +23,17 @@ class IntrospectionTests(TestCase):
"'%s' isn't in table_list()." % Article._meta.db_table) "'%s' isn't in table_list()." % Article._meta.db_table)
def test_django_table_names(self): def test_django_table_names(self):
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);') cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
tl = connection.introspection.django_table_names() tl = connection.introspection.django_table_names()
cursor.execute("DROP TABLE django_ixn_test_table;") cursor.execute("DROP TABLE django_ixn_test_table;")
self.assertTrue('django_ixn_testcase_table' not in tl, self.assertTrue('django_ixn_testcase_table' not in tl,
"django_table_names() returned a non-Django table") "django_table_names() returned a non-Django table")
def test_django_table_names_retval_type(self): def test_django_table_names_retval_type(self):
# Ticket #15216 # Ticket #15216
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);') cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
tl = connection.introspection.django_table_names(only_existing=True) tl = connection.introspection.django_table_names(only_existing=True)
self.assertIs(type(tl), list) self.assertIs(type(tl), list)
@@ -53,14 +53,14 @@ class IntrospectionTests(TestCase):
'Reporter sequence not found in sequence_list()') 'Reporter sequence not found in sequence_list()')
def test_get_table_description_names(self): def test_get_table_description_names(self):
cursor = connection.cursor() with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual([r[0] for r in desc], self.assertEqual([r[0] for r in desc],
[f.column for f in Reporter._meta.fields]) [f.column for f in Reporter._meta.fields])
def test_get_table_description_types(self): def test_get_table_description_types(self):
cursor = connection.cursor() with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
# The MySQL exception is due to the cursor.description returning the same constant for # The MySQL exception is due to the cursor.description returning the same constant for
# text and blob columns. TODO: use information_schema database to retrieve the proper # text and blob columns. TODO: use information_schema database to retrieve the proper
# field type on MySQL # field type on MySQL
@@ -75,8 +75,8 @@ class IntrospectionTests(TestCase):
# inspect the length of character columns). # inspect the length of character columns).
@expectedFailureOnOracle @expectedFailureOnOracle
def test_get_table_description_col_lengths(self): def test_get_table_description_col_lengths(self):
cursor = connection.cursor() with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual( self.assertEqual(
[r[3] for r in desc if datatype(r[1], r) == 'CharField'], [r[3] for r in desc if datatype(r[1], r) == 'CharField'],
[30, 30, 75] [30, 30, 75]
@@ -87,8 +87,8 @@ class IntrospectionTests(TestCase):
# so its idea about null_ok in cursor.description is different from ours. # so its idea about null_ok in cursor.description is different from ours.
@skipIfDBFeature('interprets_empty_strings_as_nulls') @skipIfDBFeature('interprets_empty_strings_as_nulls')
def test_get_table_description_nullable(self): def test_get_table_description_nullable(self):
cursor = connection.cursor() with connection.cursor() as cursor:
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table) desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual( self.assertEqual(
[r[6] for r in desc], [r[6] for r in desc],
[False, False, False, False, True, True] [False, False, False, False, True, True]
@@ -97,15 +97,15 @@ class IntrospectionTests(TestCase):
# Regression test for #9991 - 'real' types in postgres # Regression test for #9991 - 'real' types in postgres
@skipUnlessDBFeature('has_real_datatype') @skipUnlessDBFeature('has_real_datatype')
def test_postgresql_real_type(self): def test_postgresql_real_type(self):
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);") cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);")
desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table') desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table')
cursor.execute('DROP TABLE django_ixn_real_test_table;') cursor.execute('DROP TABLE django_ixn_real_test_table;')
self.assertEqual(datatype(desc[0][1], desc[0]), 'FloatField') self.assertEqual(datatype(desc[0][1], desc[0]), 'FloatField')
def test_get_relations(self): def test_get_relations(self):
cursor = connection.cursor() with connection.cursor() as cursor:
relations = connection.introspection.get_relations(cursor, Article._meta.db_table) relations = connection.introspection.get_relations(cursor, Article._meta.db_table)
# Older versions of MySQL don't have the chops to report on this stuff, # Older versions of MySQL don't have the chops to report on this stuff,
# so just skip it if no relations come back. If they do, though, we # so just skip it if no relations come back. If they do, though, we
@@ -117,21 +117,21 @@ class IntrospectionTests(TestCase):
@skipUnlessDBFeature('can_introspect_foreign_keys') @skipUnlessDBFeature('can_introspect_foreign_keys')
def test_get_key_columns(self): def test_get_key_columns(self):
cursor = connection.cursor() with connection.cursor() as cursor:
key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table) key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table)
self.assertEqual( self.assertEqual(
set(key_columns), set(key_columns),
set([('reporter_id', Reporter._meta.db_table, 'id'), set([('reporter_id', Reporter._meta.db_table, 'id'),
('response_to_id', Article._meta.db_table, 'id')])) ('response_to_id', Article._meta.db_table, 'id')]))
def test_get_primary_key_column(self): def test_get_primary_key_column(self):
cursor = connection.cursor() with connection.cursor() as cursor:
primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table) primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table)
self.assertEqual(primary_key_column, 'id') self.assertEqual(primary_key_column, 'id')
def test_get_indexes(self): def test_get_indexes(self):
cursor = connection.cursor() with connection.cursor() as cursor:
indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table) indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table)
self.assertEqual(indexes['reporter_id'], {'unique': False, 'primary_key': False}) self.assertEqual(indexes['reporter_id'], {'unique': False, 'primary_key': False})
def test_get_indexes_multicol(self): def test_get_indexes_multicol(self):
@@ -139,8 +139,8 @@ class IntrospectionTests(TestCase):
Test that multicolumn indexes are not included in the introspection Test that multicolumn indexes are not included in the introspection
results. results.
""" """
cursor = connection.cursor() with connection.cursor() as cursor:
indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table) indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table)
self.assertNotIn('first_name', indexes) self.assertNotIn('first_name', indexes)
self.assertIn('id', indexes) self.assertIn('id', indexes)

View File

@@ -9,33 +9,40 @@ class MigrationTestBase(TransactionTestCase):
available_apps = ["migrations"] available_apps = ["migrations"]
def get_table_description(self, table):
with connection.cursor() as cursor:
return connection.introspection.get_table_description(cursor, table)
def assertTableExists(self, table): def assertTableExists(self, table):
self.assertIn(table, connection.introspection.get_table_list(connection.cursor())) with connection.cursor() as cursor:
self.assertIn(table, connection.introspection.get_table_list(cursor))
def assertTableNotExists(self, table): def assertTableNotExists(self, table):
self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor())) with connection.cursor() as cursor:
self.assertNotIn(table, connection.introspection.get_table_list(cursor))
def assertColumnExists(self, table, column): def assertColumnExists(self, table, column):
self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)]) self.assertIn(column, [c.name for c in self.get_table_description(table)])
def assertColumnNotExists(self, table, column): def assertColumnNotExists(self, table, column):
self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)]) self.assertNotIn(column, [c.name for c in self.get_table_description(table)])
def assertColumnNull(self, table, column): def assertColumnNull(self, table, column):
self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], True) self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], True)
def assertColumnNotNull(self, table, column): def assertColumnNotNull(self, table, column):
self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], False) self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], False)
def assertIndexExists(self, table, columns, value=True): def assertIndexExists(self, table, columns, value=True):
self.assertEqual( with connection.cursor() as cursor:
value, self.assertEqual(
any( value,
c["index"] any(
for c in connection.introspection.get_constraints(connection.cursor(), table).values() c["index"]
if c['columns'] == list(columns) for c in connection.introspection.get_constraints(cursor, table).values()
), if c['columns'] == list(columns)
) ),
)
def assertIndexNotExists(self, table, columns): def assertIndexNotExists(self, table, columns):
return self.assertIndexExists(table, columns, False) return self.assertIndexExists(table, columns, False)

View File

@@ -19,15 +19,15 @@ class OperationTests(MigrationTestBase):
Creates a test model state and database table. Creates a test model state and database table.
""" """
# Delete the tables if they already exist # Delete the tables if they already exist
cursor = connection.cursor() with connection.cursor() as cursor:
try: try:
cursor.execute("DROP TABLE %s_pony" % app_label) cursor.execute("DROP TABLE %s_pony" % app_label)
except: except:
pass pass
try: try:
cursor.execute("DROP TABLE %s_stable" % app_label) cursor.execute("DROP TABLE %s_stable" % app_label)
except: except:
pass pass
# Make the "current" state # Make the "current" state
operations = [migrations.CreateModel( operations = [migrations.CreateModel(
"Pony", "Pony",
@@ -348,21 +348,21 @@ class OperationTests(MigrationTestBase):
operation.state_forwards("test_alflpkfk", new_state) operation.state_forwards("test_alflpkfk", new_state)
self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField) self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField)
self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField) self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField)
def assertIdTypeEqualsFkType(self):
with connection.cursor() as cursor:
id_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
assertIdTypeEqualsFkType()
# Test the database alteration # Test the database alteration
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
operation.database_forwards("test_alflpkfk", editor, project_state, new_state) operation.database_forwards("test_alflpkfk", editor, project_state, new_state)
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] assertIdTypeEqualsFkType()
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
# And test reversal # And test reversal
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
operation.database_backwards("test_alflpkfk", editor, new_state, project_state) operation.database_backwards("test_alflpkfk", editor, new_state, project_state)
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] assertIdTypeEqualsFkType()
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
def test_rename_field(self): def test_rename_field(self):
""" """
@@ -400,24 +400,24 @@ class OperationTests(MigrationTestBase):
self.assertEqual(len(project_state.models["test_alunto", "pony"].options.get("unique_together", set())), 0) self.assertEqual(len(project_state.models["test_alunto", "pony"].options.get("unique_together", set())), 0)
self.assertEqual(len(new_state.models["test_alunto", "pony"].options.get("unique_together", set())), 1) self.assertEqual(len(new_state.models["test_alunto", "pony"].options.get("unique_together", set())), 1)
# Make sure we can insert duplicate rows # Make sure we can insert duplicate rows
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony") cursor.execute("DELETE FROM test_alunto_pony")
# Test the database alteration # Test the database alteration
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
operation.database_forwards("test_alunto", editor, project_state, new_state) operation.database_forwards("test_alunto", editor, project_state, new_state)
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
with self.assertRaises(IntegrityError): with self.assertRaises(IntegrityError):
with atomic(): with atomic():
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony") cursor.execute("DELETE FROM test_alunto_pony")
# And test reversal # And test reversal
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
operation.database_backwards("test_alunto", editor, new_state, project_state) operation.database_backwards("test_alunto", editor, new_state, project_state)
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony") cursor.execute("DELETE FROM test_alunto_pony")
# Test flat unique_together # Test flat unique_together
operation = migrations.AlterUniqueTogether("Pony", ("pink", "weight")) operation = migrations.AlterUniqueTogether("Pony", ("pink", "weight"))
operation.state_forwards("test_alunto", new_state) operation.state_forwards("test_alunto", new_state)

View File

@@ -725,7 +725,7 @@ class DatabaseConnectionHandlingTests(TransactionTestCase):
# request_finished signal. # request_finished signal.
response = self.client.get('/') response = self.client.get('/')
# Make sure there is an open connection # Make sure there is an open connection
connection.cursor() self.connection.ensure_connection()
connection.enter_transaction_management() connection.enter_transaction_management()
signals.request_finished.send(sender=response._handler_class) signals.request_finished.send(sender=response._handler_class)
self.assertEqual(len(connection.transaction_state), 0) self.assertEqual(len(connection.transaction_state), 0)

View File

@@ -37,38 +37,38 @@ class SchemaTests(TransactionTestCase):
def delete_tables(self): def delete_tables(self):
"Deletes all model tables for our models for a clean test environment" "Deletes all model tables for our models for a clean test environment"
cursor = connection.cursor() with connection.cursor() as cursor:
connection.disable_constraint_checking() connection.disable_constraint_checking()
table_names = connection.introspection.table_names(cursor) table_names = connection.introspection.table_names(cursor)
for model in self.models: for model in self.models:
# Remove any M2M tables first # Remove any M2M tables first
for field in model._meta.local_many_to_many: for field in model._meta.local_many_to_many:
with atomic():
tbl = field.rel.through._meta.db_table
if tbl in table_names:
cursor.execute(connection.schema_editor().sql_delete_table % {
"table": connection.ops.quote_name(tbl),
})
table_names.remove(tbl)
# Then remove the main tables
with atomic(): with atomic():
tbl = field.rel.through._meta.db_table tbl = model._meta.db_table
if tbl in table_names: if tbl in table_names:
cursor.execute(connection.schema_editor().sql_delete_table % { cursor.execute(connection.schema_editor().sql_delete_table % {
"table": connection.ops.quote_name(tbl), "table": connection.ops.quote_name(tbl),
}) })
table_names.remove(tbl) table_names.remove(tbl)
# Then remove the main tables
with atomic():
tbl = model._meta.db_table
if tbl in table_names:
cursor.execute(connection.schema_editor().sql_delete_table % {
"table": connection.ops.quote_name(tbl),
})
table_names.remove(tbl)
connection.enable_constraint_checking() connection.enable_constraint_checking()
def column_classes(self, model): def column_classes(self, model):
cursor = connection.cursor() with connection.cursor() as cursor:
columns = dict( columns = dict(
(d[0], (connection.introspection.get_field_type(d[1], d), d)) (d[0], (connection.introspection.get_field_type(d[1], d), d))
for d in connection.introspection.get_table_description( for d in connection.introspection.get_table_description(
cursor, cursor,
model._meta.db_table, model._meta.db_table,
)
) )
)
# SQLite has a different format for field_type # SQLite has a different format for field_type
for name, (type, desc) in columns.items(): for name, (type, desc) in columns.items():
if isinstance(type, tuple): if isinstance(type, tuple):
@@ -78,6 +78,20 @@ class SchemaTests(TransactionTestCase):
raise DatabaseError("Table does not exist (empty pragma)") raise DatabaseError("Table does not exist (empty pragma)")
return columns return columns
def get_indexes(self, table):
"""
Get the indexes on the table using a new cursor.
"""
with connection.cursor() as cursor:
return connection.introspection.get_indexes(cursor, table)
def get_constraints(self, table):
"""
Get the constraints on a table using a new cursor.
"""
with connection.cursor() as cursor:
return connection.introspection.get_constraints(cursor, table)
# Tests # Tests
def test_creation_deletion(self): def test_creation_deletion(self):
@@ -127,7 +141,7 @@ class SchemaTests(TransactionTestCase):
strict=True, strict=True,
) )
# Make sure the new FK constraint is present # Make sure the new FK constraint is present
constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table) constraints = self.get_constraints(Book._meta.db_table)
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == ["author_id"] and details['foreign_key']: if details['columns'] == ["author_id"] and details['foreign_key']:
self.assertEqual(details['foreign_key'], ('schema_tag', 'id')) self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
@@ -342,7 +356,7 @@ class SchemaTests(TransactionTestCase):
editor.create_model(TagM2MTest) editor.create_model(TagM2MTest)
editor.create_model(UniqueTest) editor.create_model(UniqueTest)
# Ensure the M2M exists and points to TagM2MTest # Ensure the M2M exists and points to TagM2MTest
constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table) constraints = self.get_constraints(BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
if connection.features.supports_foreign_keys: if connection.features.supports_foreign_keys:
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == ["tagm2mtest_id"] and details['foreign_key']: if details['columns'] == ["tagm2mtest_id"] and details['foreign_key']:
@@ -363,7 +377,7 @@ class SchemaTests(TransactionTestCase):
# Ensure old M2M is gone # Ensure old M2M is gone
self.assertRaises(DatabaseError, self.column_classes, BookWithM2M._meta.get_field_by_name("tags")[0].rel.through) self.assertRaises(DatabaseError, self.column_classes, BookWithM2M._meta.get_field_by_name("tags")[0].rel.through)
# Ensure the new M2M exists and points to UniqueTest # Ensure the new M2M exists and points to UniqueTest
constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table) constraints = self.get_constraints(new_field.rel.through._meta.db_table)
if connection.features.supports_foreign_keys: if connection.features.supports_foreign_keys:
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == ["uniquetest_id"] and details['foreign_key']: if details['columns'] == ["uniquetest_id"] and details['foreign_key']:
@@ -388,7 +402,7 @@ class SchemaTests(TransactionTestCase):
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.create_model(Author) editor.create_model(Author)
# Ensure the constraint exists # Ensure the constraint exists
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']: if details['columns'] == ["height"] and details['check']:
break break
@@ -404,7 +418,7 @@ class SchemaTests(TransactionTestCase):
new_field, new_field,
strict=True, strict=True,
) )
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']: if details['columns'] == ["height"] and details['check']:
self.fail("Check constraint for height found") self.fail("Check constraint for height found")
@@ -416,7 +430,7 @@ class SchemaTests(TransactionTestCase):
Author._meta.get_field_by_name("height")[0], Author._meta.get_field_by_name("height")[0],
strict=True, strict=True,
) )
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) constraints = self.get_constraints(Author._meta.db_table)
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == ["height"] and details['check']: if details['columns'] == ["height"] and details['check']:
break break
@@ -527,7 +541,7 @@ class SchemaTests(TransactionTestCase):
False, False,
any( any(
c["index"] c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values() for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"] if c['columns'] == ["slug", "title"]
), ),
) )
@@ -543,7 +557,7 @@ class SchemaTests(TransactionTestCase):
True, True,
any( any(
c["index"] c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values() for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"] if c['columns'] == ["slug", "title"]
), ),
) )
@@ -561,7 +575,7 @@ class SchemaTests(TransactionTestCase):
False, False,
any( any(
c["index"] c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values() for c in self.get_constraints("schema_tag").values()
if c['columns'] == ["slug", "title"] if c['columns'] == ["slug", "title"]
), ),
) )
@@ -578,7 +592,7 @@ class SchemaTests(TransactionTestCase):
True, True,
any( any(
c["index"] c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tagindexed").values() for c in self.get_constraints("schema_tagindexed").values()
if c['columns'] == ["slug", "title"] if c['columns'] == ["slug", "title"]
), ),
) )
@@ -627,7 +641,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has the right index # Ensure the table is there and has the right index
self.assertIn( self.assertIn(
"title", "title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), self.get_indexes(Book._meta.db_table),
) )
# Alter to remove the index # Alter to remove the index
new_field = CharField(max_length=100, db_index=False) new_field = CharField(max_length=100, db_index=False)
@@ -642,7 +656,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has no index # Ensure the table is there and has no index
self.assertNotIn( self.assertNotIn(
"title", "title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), self.get_indexes(Book._meta.db_table),
) )
# Alter to re-add the index # Alter to re-add the index
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
@@ -655,7 +669,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has the index again # Ensure the table is there and has the index again
self.assertIn( self.assertIn(
"title", "title",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), self.get_indexes(Book._meta.db_table),
) )
# Add a unique column, verify that creates an implicit index # Add a unique column, verify that creates an implicit index
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
@@ -665,7 +679,7 @@ class SchemaTests(TransactionTestCase):
) )
self.assertIn( self.assertIn(
"slug", "slug",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), self.get_indexes(Book._meta.db_table),
) )
# Remove the unique, check the index goes with it # Remove the unique, check the index goes with it
new_field2 = CharField(max_length=20, unique=False) new_field2 = CharField(max_length=20, unique=False)
@@ -679,7 +693,7 @@ class SchemaTests(TransactionTestCase):
) )
self.assertNotIn( self.assertNotIn(
"slug", "slug",
connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), self.get_indexes(Book._meta.db_table),
) )
def test_primary_key(self): def test_primary_key(self):
@@ -691,7 +705,7 @@ class SchemaTests(TransactionTestCase):
editor.create_model(Tag) editor.create_model(Tag)
# Ensure the table is there and has the right PK # Ensure the table is there and has the right PK
self.assertTrue( self.assertTrue(
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['id']['primary_key'], self.get_indexes(Tag._meta.db_table)['id']['primary_key'],
) )
# Alter to change the PK # Alter to change the PK
new_field = SlugField(primary_key=True) new_field = SlugField(primary_key=True)
@@ -707,10 +721,10 @@ class SchemaTests(TransactionTestCase):
# Ensure the PK changed # Ensure the PK changed
self.assertNotIn( self.assertNotIn(
'id', 'id',
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table), self.get_indexes(Tag._meta.db_table),
) )
self.assertTrue( self.assertTrue(
connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['slug']['primary_key'], self.get_indexes(Tag._meta.db_table)['slug']['primary_key'],
) )
def test_context_manager_exit(self): def test_context_manager_exit(self):
@@ -741,7 +755,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the table is there and has an index on the column # Ensure the table is there and has an index on the column
self.assertIn( self.assertIn(
column_name, column_name,
connection.introspection.get_indexes(connection.cursor(), BookWithLongName._meta.db_table), self.get_indexes(BookWithLongName._meta.db_table),
) )
def test_creation_deletion_reserved_names(self): def test_creation_deletion_reserved_names(self):

View File

@@ -202,8 +202,9 @@ class AtomicTests(TransactionTestCase):
# trigger a database error inside an inner atomic without savepoint # trigger a database error inside an inner atomic without savepoint
with self.assertRaises(DatabaseError): with self.assertRaises(DatabaseError):
with transaction.atomic(savepoint=False): with transaction.atomic(savepoint=False):
connection.cursor().execute( with connection.cursor() as cursor:
"SELECT no_such_col FROM transactions_reporter") cursor.execute(
"SELECT no_such_col FROM transactions_reporter")
# prevent atomic from rolling back since we're recovering manually # prevent atomic from rolling back since we're recovering manually
self.assertTrue(transaction.get_rollback()) self.assertTrue(transaction.get_rollback())
transaction.set_rollback(False) transaction.set_rollback(False)
@@ -534,8 +535,8 @@ class TransactionRollbackTests(IgnoreDeprecationWarningsMixin, TransactionTestCa
available_apps = ['transactions'] available_apps = ['transactions']
def execute_bad_sql(self): def execute_bad_sql(self):
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');") cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
@skipUnlessDBFeature('requires_rollback_on_dirty_transaction') @skipUnlessDBFeature('requires_rollback_on_dirty_transaction')
def test_bad_sql(self): def test_bad_sql(self):
@@ -678,6 +679,6 @@ class TransactionContextManagerTests(IgnoreDeprecationWarningsMixin, Transaction
""" """
with self.assertRaises(IntegrityError): with self.assertRaises(IntegrityError):
with transaction.commit_on_success(): with transaction.commit_on_success():
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');") cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
transaction.rollback() transaction.rollback()

View File

@@ -54,8 +54,8 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
@commit_on_success @commit_on_success
def raw_sql(): def raw_sql():
"Write a record using raw sql under a commit_on_success decorator" "Write a record using raw sql under a commit_on_success decorator"
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("INSERT into transactions_regress_mod (fld) values (18)") cursor.execute("INSERT into transactions_regress_mod (fld) values (18)")
raw_sql() raw_sql()
# Rollback so that if the decorator didn't commit, the record is unwritten # Rollback so that if the decorator didn't commit, the record is unwritten
@@ -143,10 +143,10 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
(reference). All this under commit_on_success, so the second insert should (reference). All this under commit_on_success, so the second insert should
be committed. be committed.
""" """
cursor = connection.cursor() with connection.cursor() as cursor:
cursor.execute("INSERT into transactions_regress_mod (fld) values (2)") cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
transaction.rollback() transaction.rollback()
cursor.execute("INSERT into transactions_regress_mod (fld) values (2)") cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
reuse_cursor_ref() reuse_cursor_ref()
# Rollback so that if the decorator didn't commit, the record is unwritten # Rollback so that if the decorator didn't commit, the record is unwritten