diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 07fb700eaa..b93453bc1e 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -65,6 +65,48 @@ server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})') # standard util.CursorDebugWrapper can be used. Also, using sql_mode # TRADITIONAL will automatically cause most warnings to be treated as errors. +class CursorWrapper(object): + """ + A thin wrapper around MySQLdb's normal cursor class so that we can catch + particular exception instances and reraise them with the right types. + + Implemented as a wrapper, rather than a subclass, so that we aren't stuck + to the particular underlying representation returned by Connection.cursor(). + """ + codes_for_integrityerror = (1048,) + + def __init__(self, cursor): + self.cursor = cursor + + def execute(self, query, args=None): + try: + return self.cursor.execute(query, args) + except Database.OperationalError, e: + # Map some error codes to IntegrityError, since they seem to be + # misclassified and Django would prefer the more logical place. + if e[0] in self.codes_for_integrityerror: + raise Database.IntegrityError(tuple(e)) + raise + + def executemany(self, query, args): + try: + return self.cursor.executemany(query, args) + except Database.OperationalError, e: + # Map some error codes to IntegrityError, since they seem to be + # misclassified and Django would prefer the more logical place. + if e[0] in self.codes_for_integrityerror: + raise Database.IntegrityError(tuple(e)) + raise + + def __getattr__(self, attr): + if attr in self.__dict__: + return self.__dict__[attr] + else: + return getattr(self.cursor, attr) + + def __iter__(self): + return iter(self.cursor) + class DatabaseFeatures(BaseDatabaseFeatures): empty_fetchmany_value = () update_can_self_select = False @@ -207,7 +249,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): kwargs['port'] = int(settings.DATABASE_PORT) kwargs.update(self.options) self.connection = Database.connect(**kwargs) - cursor = self.connection.cursor() + cursor = CursorWrapper(self.connection.cursor()) return cursor def _rollback(self):