1
0
mirror of https://github.com/django/django.git synced 2025-08-08 19:09:15 +00:00

Refactor all uses of thread locals to be more consistant and sane.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@15232 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2011-01-17 09:52:47 +00:00
parent 964cf1be86
commit fcbf881d82
6 changed files with 83 additions and 100 deletions

View File

@ -8,6 +8,7 @@ a string) and returns a tuple in this format:
""" """
import re import re
from threading import local
from django.http import Http404 from django.http import Http404
from django.conf import settings from django.conf import settings
@ -17,7 +18,6 @@ from django.utils.encoding import iri_to_uri, force_unicode, smart_str
from django.utils.functional import memoize from django.utils.functional import memoize
from django.utils.importlib import import_module from django.utils.importlib import import_module
from django.utils.regex_helper import normalize from django.utils.regex_helper import normalize
from django.utils.thread_support import currentThread
_resolver_cache = {} # Maps URLconf modules to RegexURLResolver instances. _resolver_cache = {} # Maps URLconf modules to RegexURLResolver instances.
_callable_cache = {} # Maps view and url pattern names to their view functions. _callable_cache = {} # Maps view and url pattern names to their view functions.
@ -25,10 +25,11 @@ _callable_cache = {} # Maps view and url pattern names to their view functions.
# SCRIPT_NAME prefixes for each thread are stored here. If there's no entry for # SCRIPT_NAME prefixes for each thread are stored here. If there's no entry for
# the current thread (which is the only one we ever access), it is assumed to # the current thread (which is the only one we ever access), it is assumed to
# be empty. # be empty.
_prefixes = {} _prefixes = local()
# Overridden URLconfs for each thread are stored here. # Overridden URLconfs for each thread are stored here.
_urlconfs = {} _urlconfs = local()
class ResolverMatch(object): class ResolverMatch(object):
def __init__(self, func, args, kwargs, url_name=None, app_name=None, namespaces=None): def __init__(self, func, args, kwargs, url_name=None, app_name=None, namespaces=None):
@ -401,7 +402,7 @@ def set_script_prefix(prefix):
""" """
if not prefix.endswith('/'): if not prefix.endswith('/'):
prefix += '/' prefix += '/'
_prefixes[currentThread()] = prefix _prefixes.value = prefix
def get_script_prefix(): def get_script_prefix():
""" """
@ -409,27 +410,22 @@ def get_script_prefix():
wishes to construct their own URLs manually (although accessing the request wishes to construct their own URLs manually (although accessing the request
instance is normally going to be a lot cleaner). instance is normally going to be a lot cleaner).
""" """
return _prefixes.get(currentThread(), u'/') return getattr(_prefixes, "value", u'/')
def set_urlconf(urlconf_name): def set_urlconf(urlconf_name):
""" """
Sets the URLconf for the current thread (overriding the default one in Sets the URLconf for the current thread (overriding the default one in
settings). Set to None to revert back to the default. settings). Set to None to revert back to the default.
""" """
thread = currentThread()
if urlconf_name: if urlconf_name:
_urlconfs[thread] = urlconf_name _urlconfs.value = urlconf_name
else: else:
# faster than wrapping in a try/except if hasattr(_urlconfs, "value"):
if thread in _urlconfs: del _urlconfs.value
del _urlconfs[thread]
def get_urlconf(default=None): def get_urlconf(default=None):
""" """
Returns the root URLconf to use for the current thread if it has been Returns the root URLconf to use for the current thread if it has been
changed from the default one. changed from the default one.
""" """
thread = currentThread() return getattr(_urlconfs, "value", default)
if thread in _urlconfs:
return _urlconfs[thread]
return default

View File

@ -25,6 +25,11 @@ class BaseDatabaseWrapper(local):
self.alias = alias self.alias = alias
self.use_debug_cursor = None self.use_debug_cursor = None
# Transaction related attributes
self.transaction_state = []
self.savepoint_state = 0
self.dirty = None
def __eq__(self, other): def __eq__(self, other):
return self.alias == other.alias return self.alias == other.alias

View File

@ -25,6 +25,7 @@ except ImportError:
from django.conf import settings from django.conf import settings
from django.db import connections, DEFAULT_DB_ALIAS from django.db import connections, DEFAULT_DB_ALIAS
class TransactionManagementError(Exception): class TransactionManagementError(Exception):
""" """
This exception is thrown when something bad happens with transaction This exception is thrown when something bad happens with transaction
@ -32,19 +33,6 @@ class TransactionManagementError(Exception):
""" """
pass pass
# The states are dictionaries of dictionaries of lists. The key to the outer
# dict is the current thread, and the key to the inner dictionary is the
# connection alias and the list is handled as a stack of values.
state = {}
savepoint_state = {}
# The dirty flag is set by *_unless_managed functions to denote that the
# code under transaction management has changed things to require a
# database commit.
# This is a dictionary mapping thread to a dictionary mapping connection
# alias to a boolean.
dirty = {}
def enter_transaction_management(managed=True, using=None): def enter_transaction_management(managed=True, using=None):
""" """
Enters transaction management for a running thread. It must be balanced with Enters transaction management for a running thread. It must be balanced with
@ -58,15 +46,14 @@ def enter_transaction_management(managed=True, using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]
thread_ident = thread.get_ident()
if thread_ident in state and state[thread_ident].get(using): if connection.transaction_state:
state[thread_ident][using].append(state[thread_ident][using][-1]) connection.transaction_state.append(connection.transaction_state[-1])
else: else:
state.setdefault(thread_ident, {}) connection.transaction_state.append(settings.TRANSACTIONS_MANAGED)
state[thread_ident][using] = [settings.TRANSACTIONS_MANAGED]
if thread_ident not in dirty or using not in dirty[thread_ident]: if connection.dirty is None:
dirty.setdefault(thread_ident, {}) connection.dirty = False
dirty[thread_ident][using] = False
connection._enter_transaction_management(managed) connection._enter_transaction_management(managed)
def leave_transaction_management(using=None): def leave_transaction_management(using=None):
@ -78,16 +65,18 @@ def leave_transaction_management(using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]
connection._leave_transaction_management(is_managed(using=using)) connection._leave_transaction_management(is_managed(using=using))
thread_ident = thread.get_ident() if connection.transaction_state:
if thread_ident in state and state[thread_ident].get(using): del connection.transaction_state[-1]
del state[thread_ident][using][-1]
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction "
if dirty.get(thread_ident, {}).get(using, False): "management")
if connection.dirty:
rollback(using=using) rollback(using=using)
raise TransactionManagementError("Transaction managed block ended with pending COMMIT/ROLLBACK") raise TransactionManagementError("Transaction managed block ended with "
dirty[thread_ident][using] = False "pending COMMIT/ROLLBACK")
connection.dirty = False
def is_dirty(using=None): def is_dirty(using=None):
""" """
@ -96,7 +85,9 @@ def is_dirty(using=None):
""" """
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
return dirty.get(thread.get_ident(), {}).get(using, False) connection = connections[using]
return connection.dirty
def set_dirty(using=None): def set_dirty(using=None):
""" """
@ -106,11 +97,13 @@ def set_dirty(using=None):
""" """
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
thread_ident = thread.get_ident() connection = connections[using]
if thread_ident in dirty and using in dirty[thread_ident]:
dirty[thread_ident][using] = True if connection.dirty is not None:
connection.dirty = True
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction "
"management")
def set_clean(using=None): def set_clean(using=None):
""" """
@ -120,9 +113,10 @@ def set_clean(using=None):
""" """
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
thread_ident = thread.get_ident() connection = connections[using]
if thread_ident in dirty and using in dirty[thread_ident]:
dirty[thread_ident][using] = False if connection.dirty is not None:
connection.dirty = False
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction management")
clean_savepoints(using=using) clean_savepoints(using=using)
@ -130,9 +124,8 @@ def set_clean(using=None):
def clean_savepoints(using=None): def clean_savepoints(using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
thread_ident = thread.get_ident() connection = connections[using]
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]: connection.savepoint_state = 0
del savepoint_state[thread_ident][using]
def is_managed(using=None): def is_managed(using=None):
""" """
@ -140,10 +133,9 @@ def is_managed(using=None):
""" """
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
thread_ident = thread.get_ident() connection = connections[using]
if thread_ident in state and using in state[thread_ident]: if connection.transaction_state:
if state[thread_ident][using]: return connection.transaction_state[-1]
return state[thread_ident][using][-1]
return settings.TRANSACTIONS_MANAGED return settings.TRANSACTIONS_MANAGED
def managed(flag=True, using=None): def managed(flag=True, using=None):
@ -156,15 +148,16 @@ def managed(flag=True, using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]
thread_ident = thread.get_ident()
top = state.get(thread_ident, {}).get(using, None) top = connection.transaction_state
if top: if top:
top[-1] = flag top[-1] = flag
if not flag and is_dirty(using=using): if not flag and is_dirty(using=using):
connection._commit() connection._commit()
set_clean(using=using) set_clean(using=using)
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction "
"management")
def commit_unless_managed(using=None): def commit_unless_managed(using=None):
""" """
@ -221,13 +214,11 @@ def savepoint(using=None):
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]
thread_ident = thread.get_ident() thread_ident = thread.get_ident()
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
savepoint_state[thread_ident][using].append(None) connection.savepoint_state += 1
else:
savepoint_state.setdefault(thread_ident, {})
savepoint_state[thread_ident][using] = [None]
tid = str(thread_ident).replace('-', '') tid = str(thread_ident).replace('-', '')
sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident][using])) sid = "s%s_x%d" % (tid, connection.savepoint_state)
connection._savepoint(sid) connection._savepoint(sid)
return sid return sid
@ -239,8 +230,8 @@ def savepoint_rollback(sid, using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]
thread_ident = thread.get_ident()
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]: if connection.savepoint_state:
connection._savepoint_rollback(sid) connection._savepoint_rollback(sid)
def savepoint_commit(sid, using=None): def savepoint_commit(sid, using=None):
@ -251,8 +242,8 @@ def savepoint_commit(sid, using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]
thread_ident = thread.get_ident()
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]: if connection.savepoint_state:
connection._savepoint_commit(sid) connection._savepoint_commit(sid)
############## ##############

View File

@ -1,12 +0,0 @@
"""
Code used in a couple of places to work with the current thread's environment.
Current users include i18n and request prefix handling.
"""
try:
import threading
currentThread = threading.currentThread
except ImportError:
def currentThread():
return "no threading"

View File

@ -7,15 +7,16 @@ import sys
import warnings import warnings
import gettext as gettext_module import gettext as gettext_module
from cStringIO import StringIO from cStringIO import StringIO
from threading import local
from django.utils.importlib import import_module from django.utils.importlib import import_module
from django.utils.safestring import mark_safe, SafeData from django.utils.safestring import mark_safe, SafeData
from django.utils.thread_support import currentThread
# Translations are cached in a dictionary for every language+app tuple. # Translations are cached in a dictionary for every language+app tuple.
# The active translations are stored by threadid to make them thread local. # The active translations are stored by threadid to make them thread local.
_translations = {} _translations = {}
_active = {} _active = local()
# The default translation is based on the settings file. # The default translation is based on the settings file.
_default = None _default = None
@ -197,16 +198,15 @@ def activate(language):
"Please use the 'nb' translation instead.", "Please use the 'nb' translation instead.",
DeprecationWarning DeprecationWarning
) )
_active[currentThread()] = translation(language) _active.value = translation(language)
def deactivate(): def deactivate():
""" """
Deinstalls the currently active translation object so that further _ calls Deinstalls the currently active translation object so that further _ calls
will resolve against the default translation object, again. will resolve against the default translation object, again.
""" """
global _active if hasattr(_active, "value"):
if currentThread() in _active: del _active.value
del _active[currentThread()]
def deactivate_all(): def deactivate_all():
""" """
@ -214,11 +214,11 @@ def deactivate_all():
useful when we want delayed translations to appear as the original string useful when we want delayed translations to appear as the original string
for some reason. for some reason.
""" """
_active[currentThread()] = gettext_module.NullTranslations() _active.value = gettext_module.NullTranslations()
def get_language(): def get_language():
"""Returns the currently selected language.""" """Returns the currently selected language."""
t = _active.get(currentThread(), None) t = getattr(_active, "value", None)
if t is not None: if t is not None:
try: try:
return t.to_language() return t.to_language()
@ -246,8 +246,9 @@ def catalog():
This can be used if you need to modify the catalog or want to access the This can be used if you need to modify the catalog or want to access the
whole message catalog instead of just translating one string. whole message catalog instead of just translating one string.
""" """
global _default, _active global _default
t = _active.get(currentThread(), None)
t = getattr(_active, "value", None)
if t is not None: if t is not None:
return t return t
if _default is None: if _default is None:
@ -262,9 +263,10 @@ def do_translate(message, translation_function):
translation object to use. If no current translation is activated, the translation object to use. If no current translation is activated, the
message will be run through the default translation object. message will be run through the default translation object.
""" """
global _default
eol_message = message.replace('\r\n', '\n').replace('\r', '\n') eol_message = message.replace('\r\n', '\n').replace('\r', '\n')
global _default, _active t = getattr(_active, "value", None)
t = _active.get(currentThread(), None)
if t is not None: if t is not None:
result = getattr(t, translation_function)(eol_message) result = getattr(t, translation_function)(eol_message)
else: else:
@ -300,9 +302,9 @@ def gettext_noop(message):
return message return message
def do_ntranslate(singular, plural, number, translation_function): def do_ntranslate(singular, plural, number, translation_function):
global _default, _active global _default
t = _active.get(currentThread(), None) t = getattr(_active, "value", None)
if t is not None: if t is not None:
return getattr(t, translation_function)(singular, plural, number) return getattr(t, translation_function)(singular, plural, number)
if _default is None: if _default is None:
@ -587,4 +589,3 @@ def get_partial_date_formats():
if month_day_format == 'MONTH_DAY_FORMAT': if month_day_format == 'MONTH_DAY_FORMAT':
month_day_format = settings.MONTH_DAY_FORMAT month_day_format = settings.MONTH_DAY_FORMAT
return year_month_format, month_day_format return year_month_format, month_day_format

View File

@ -4,10 +4,12 @@ import decimal
import os import os
import sys import sys
import pickle import pickle
from threading import local
from django.conf import settings from django.conf import settings
from django.template import Template, Context from django.template import Template, Context
from django.utils.formats import get_format, date_format, time_format, localize, localize_input, iter_format_modules from django.utils.formats import (get_format, date_format, time_format,
localize, localize_input, iter_format_modules)
from django.utils.importlib import import_module from django.utils.importlib import import_module
from django.utils.numberformat import format as nformat from django.utils.numberformat import format as nformat
from django.utils.safestring import mark_safe, SafeString, SafeUnicode from django.utils.safestring import mark_safe, SafeString, SafeUnicode
@ -61,7 +63,7 @@ class TranslationTests(TestCase):
self.old_locale_paths = settings.LOCALE_PATHS self.old_locale_paths = settings.LOCALE_PATHS
settings.LOCALE_PATHS += (os.path.join(os.path.dirname(os.path.abspath(__file__)), 'other', 'locale'),) settings.LOCALE_PATHS += (os.path.join(os.path.dirname(os.path.abspath(__file__)), 'other', 'locale'),)
from django.utils.translation import trans_real from django.utils.translation import trans_real
trans_real._active = {} trans_real._active = local()
trans_real._translations = {} trans_real._translations = {}
activate('de') activate('de')
@ -649,7 +651,7 @@ class ResolutionOrderI18NTests(TestCase):
from django.utils.translation import trans_real from django.utils.translation import trans_real
# Okay, this is brutal, but we have no other choice to fully reset # Okay, this is brutal, but we have no other choice to fully reset
# the translation framework # the translation framework
trans_real._active = {} trans_real._active = local()
trans_real._translations = {} trans_real._translations = {}
activate('de') activate('de')