From 3cd611b89ccbd9d01bd1e965fb5060dfa3879057 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Sat, 1 Aug 2009 05:26:02 +0000 Subject: [PATCH] [soc2009/multidb] Switched from using an ugly hacky wrapper to a Metaclass for maitaing backwards compatibility in the get_db_prep_* and db_type methods. Thanks to Jacob for the good idea. git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11375 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- .../management/commands/createcachetable.py | 3 +- django/core/management/sql.py | 2 +- django/db/backends/creation.py | 11 +++-- django/db/models/base.py | 5 +-- django/db/models/fields/__init__.py | 22 +++++----- django/db/models/fields/files.py | 5 +-- django/db/models/fields/related.py | 12 +++--- django/db/models/fields/subclassing.py | 43 +++++++++++++++++-- django/db/models/sql/subqueries.py | 4 +- django/db/models/sql/where.py | 15 +++---- django/db/utils.py | 6 --- django/forms/models.py | 3 +- tests/regressiontests/model_fields/tests.py | 16 +++---- 13 files changed, 84 insertions(+), 63 deletions(-) diff --git a/django/core/management/commands/createcachetable.py b/django/core/management/commands/createcachetable.py index 30eaa06c21..2a87f27087 100644 --- a/django/core/management/commands/createcachetable.py +++ b/django/core/management/commands/createcachetable.py @@ -2,7 +2,6 @@ from optparse import make_option from django.core.management.base import LabelCommand from django.db import connections, transaction, models, DEFAULT_DB_ALIAS -from django.db.utils import call_with_connection class Command(LabelCommand): help = "Creates the table needed to use the SQL cache backend." @@ -31,7 +30,7 @@ class Command(LabelCommand): index_output = [] qn = connection.ops.quote_name for f in fields: - field_output = [qn(f.name), call_with_connection(f.db_type, connection=connection)] + field_output = [qn(f.name), f.db_type(connection=connection)] field_output.append("%sNULL" % (not f.null and "NOT " or "")) if f.primary_key: field_output.append("PRIMARY KEY") diff --git a/django/core/management/sql.py b/django/core/management/sql.py index 63122ec96d..ca229844a2 100644 --- a/django/core/management/sql.py +++ b/django/core/management/sql.py @@ -116,7 +116,7 @@ def sql_delete(app, style, connection): def sql_reset(app, style, connection): "Returns a list of the DROP TABLE SQL, then the CREATE TABLE SQL, for the given module." - return sql_delete(app, style) + sql_all(app, style) + return sql_delete(app, style, connection) + sql_all(app, style, connection) def sql_flush(style, connection, only_django=False): """ diff --git a/django/db/backends/creation.py b/django/db/backends/creation.py index 618ca25ea9..5327c6c2df 100644 --- a/django/db/backends/creation.py +++ b/django/db/backends/creation.py @@ -8,7 +8,6 @@ except NameError: from django.conf import settings from django.core.management import call_command -from django.db.utils import call_with_connection # The prefix to put on the default database name when creating # the test database. @@ -48,7 +47,7 @@ class BaseDatabaseCreation(object): pending_references = {} qn = self.connection.ops.quote_name for f in opts.local_fields: - col_type = call_with_connection(f.db_type, connection=self.connection) + col_type = f.db_type(connection=self.connection) tablespace = f.db_tablespace or opts.db_tablespace if col_type is None: # Skip ManyToManyFields, because they're not represented as @@ -76,7 +75,7 @@ class BaseDatabaseCreation(object): table_output.append(' '.join(field_output)) if opts.order_with_respect_to: table_output.append(style.SQL_FIELD(qn('_order')) + ' ' + \ - style.SQL_COLTYPE(models.IntegerField().db_type(self.connection))) + style.SQL_COLTYPE(models.IntegerField().db_type(connection=self.connection))) for field_constraints in opts.unique_together: table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \ ", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints])) @@ -174,7 +173,7 @@ class BaseDatabaseCreation(object): style.SQL_TABLE(qn(f.m2m_db_table())) + ' ('] table_output.append(' %s %s %s%s,' % (style.SQL_FIELD(qn('id')), - style.SQL_COLTYPE(models.AutoField(primary_key=True).db_type(self.connection)), + style.SQL_COLTYPE(models.AutoField(primary_key=True).db_type(connection=self.connection)), style.SQL_KEYWORD('NOT NULL PRIMARY KEY'), tablespace_sql)) @@ -218,14 +217,14 @@ class BaseDatabaseCreation(object): table_output = [ ' %s %s %s %s (%s)%s,' % (style.SQL_FIELD(qn(field.m2m_column_name())), - style.SQL_COLTYPE(models.ForeignKey(model).db_type(self.connection)), + style.SQL_COLTYPE(models.ForeignKey(model).db_type(connection=self.connection)), style.SQL_KEYWORD('NOT NULL REFERENCES'), style.SQL_TABLE(qn(opts.db_table)), style.SQL_FIELD(qn(opts.pk.column)), self.connection.ops.deferrable_sql()), ' %s %s %s %s (%s)%s,' % (style.SQL_FIELD(qn(field.m2m_reverse_name())), - style.SQL_COLTYPE(models.ForeignKey(field.rel.to).db_type(self.connection)), + style.SQL_COLTYPE(models.ForeignKey(field.rel.to).db_type(connection=self.connection)), style.SQL_KEYWORD('NOT NULL REFERENCES'), style.SQL_TABLE(qn(field.rel.to._meta.db_table)), style.SQL_FIELD(qn(field.rel.to._meta.pk.column)), diff --git a/django/db/models/base.py b/django/db/models/base.py index 021706e3ae..590cafe535 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -18,7 +18,6 @@ from django.db.models.options import Options from django.db import connections, transaction, DatabaseError, DEFAULT_DB_ALIAS from django.db.models import signals from django.db.models.loading import register_models, get_model -from django.db.utils import call_with_connection from django.utils.functional import curry from django.utils.encoding import smart_str, force_unicode, smart_unicode from django.conf import settings @@ -484,10 +483,10 @@ class Model(object): if not pk_set: if force_update: raise ValueError("Cannot force an update in save() with no primary key.") - values = [(f, call_with_connection(f.get_db_prep_save, raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection)) + values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection)) for f in meta.local_fields if not isinstance(f, AutoField)] else: - values = [(f, call_with_connection(f.get_db_prep_save, raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection)) + values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection)) for f in meta.local_fields] if meta.order_with_respect_to: diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 8d327170cf..d4b3499d0c 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -10,8 +10,8 @@ except ImportError: from django.db import connection from django.db.models import signals +from django.db.models.fields.subclassing import LegacyConnection from django.db.models.query_utils import QueryWrapper -from django.db.utils import call_with_connection from django.dispatch import dispatcher from django.conf import settings from django import forms @@ -50,6 +50,7 @@ class FieldDoesNotExist(Exception): # getattr(obj, opts.pk.attname) class Field(object): + __metaclass__ = LegacyConnection # Designates whether empty strings fundamentally are allowed at the # database level. empty_strings_allowed = True @@ -190,8 +191,7 @@ class Field(object): def get_db_prep_save(self, value, connection): "Returns field's value prepared for saving into a database." - return call_with_connection(self.get_db_prep_value, value, - connection=connection) + return self.get_db_prep_value(value, connection=connection) def get_db_prep_lookup(self, lookup_type, value, connection): "Returns field's value prepared for database lookup." @@ -210,9 +210,9 @@ class Field(object): if lookup_type in ('regex', 'iregex', 'month', 'day', 'week_day', 'search'): return [value] elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'): - return [call_with_connection(self.get_db_prep_value, value, connection=connection)] + return [self.get_db_prep_value(value, connection=connection)] elif lookup_type in ('range', 'in'): - return [call_with_connection(self.get_db_prep_value, v, connection=connection) for v in value] + return [self.get_db_prep_value(v, connection=connection) for v in value] elif lookup_type in ('contains', 'icontains'): return ["%%%s%%" % connection.ops.prep_for_like_query(value)] elif lookup_type == 'iexact': @@ -426,8 +426,8 @@ class BooleanField(Field): # constructing the list. if value in ('1', '0'): value = bool(int(value)) - return call_with_connection(super(BooleanField, self).get_db_prep_lookup, - lookup_type, value, connection=connection) + return super(BooleanField, self).get_db_prep_lookup(lookup_type, value, + connection=connection) def validate(self, lookup_type, value): if super(BooleanField, self).validate(lookup_type, value): @@ -544,8 +544,8 @@ class DateField(Field): # to an int so the database backend always sees a consistent type. if lookup_type in ('month', 'day', 'week_day'): return [int(value)] - return call_with_connection(super(DateField, self).get_db_prep_lookup, - lookup_type, value, connection=connection) + return super(DateField, self).get_db_prep_lookup(lookup_type, value, + connection=connection) def get_db_prep_value(self, value, connection): # Casts dates into the format expected by the backend @@ -807,8 +807,8 @@ class NullBooleanField(Field): # constructing the list. if value in ('1', '0'): value = bool(int(value)) - return call_with_connection(super(NullBooleanField, self).get_db_prep_lookup, - lookup_type, value, connection=connection) + return super(NullBooleanField, self).get_db_prep_lookup(lookup_type, + value, connection=connection) def validate(self, lookup_type, value): if value in ('1', '0'): diff --git a/django/db/models/fields/files.py b/django/db/models/fields/files.py index 3290645231..f112bbd593 100644 --- a/django/db/models/fields/files.py +++ b/django/db/models/fields/files.py @@ -10,7 +10,6 @@ from django.core.files.images import ImageFile, get_image_dimensions from django.core.files.uploadedfile import UploadedFile from django.utils.functional import curry from django.db.models import signals -from django.db.utils import call_with_connection from django.utils.encoding import force_unicode, smart_str from django.utils.translation import ugettext_lazy, ugettext as _ from django import forms @@ -236,8 +235,8 @@ class FileField(Field): def get_db_prep_lookup(self, lookup_type, value, connection): if hasattr(value, 'name'): value = value.name - return call_with_connection(super(FileField, self).get_db_prep_lookup, - lookup_type, value, connection=connection) + return super(FileField, self).get_db_prep_lookup(lookup_type, value, + connection=connection) def get_db_prep_value(self, value, connection): "Returns field's value prepared for saving into a database." diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 5d43243912..1db4156268 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -5,7 +5,6 @@ from django.db.models.fields import AutoField, Field, IntegerField, PositiveInte from django.db.models.related import RelatedObject from django.db.models.query import QuerySet from django.db.models.query_utils import QueryWrapper -from django.db.utils import call_with_connection from django.utils.encoding import smart_unicode from django.utils.translation import ugettext_lazy, string_concat, ungettext, ugettext as _ from django.utils.functional import curry @@ -137,8 +136,7 @@ class RelatedField(object): if field: if lookup_type in ('range', 'in'): v = [v] - v = call_with_connection(field.get_db_prep_lookup, - lookup_type, v, connection=connection) + v = field.get_db_prep_lookup(lookup_type, v, connection=connection) if isinstance(v, list): v = v[0] return v @@ -725,8 +723,8 @@ class ForeignKey(RelatedField, Field): if value == '' or value == None: return None else: - return call_with_connection(self.rel.get_related_field().get_db_prep_save, - value, connection=connection) + return self.rel.get_related_field().get_db_prep_save(value, + connection=connection) def value_to_string(self, obj): if not obj: @@ -774,8 +772,8 @@ class ForeignKey(RelatedField, Field): (not connection.features.related_fields_match_type and isinstance(rel_field, (PositiveIntegerField, PositiveSmallIntegerField)))): - return IntegerField().db_type(connection) - return call_with_connection(rel_field.db_type, connection=connection) + return IntegerField().db_type(connection=connection) + return rel_field.db_type(connection=connection) class OneToOneField(ForeignKey): """ diff --git a/django/db/models/fields/subclassing.py b/django/db/models/fields/subclassing.py index 10add10739..d80c5805e0 100644 --- a/django/db/models/fields/subclassing.py +++ b/django/db/models/fields/subclassing.py @@ -1,11 +1,48 @@ """ -Convenience routines for creating non-trivial Field subclasses. +Convenience routines for creating non-trivial Field subclasses, as well as +backwards compatibility utilities. Add SubfieldBase as the __metaclass__ for your Field subclass, implement to_python() and the other necessary methods and everything will work seamlessly. """ -class SubfieldBase(type): +from inspect import getargspec +from warnings import warn + +def call_with_connection(func): + arg_names, varargs, varkwargs, defaults = getargspec(func) + takes_connection = 'connection' in arg_names or varkwargs + if not takes_connection: + warn("A Field class who's %s method doesn't take connection has been " + "defined, please add a connection argument" % func.__name__, + PendingDeprecationWarning, depth=2) + def inner(*args, **kwargs): + if 'connection' not in kwargs: + from django.db import connection + kwargs['connection'] = connection + warn("%s has been called without providing a connection argument, " + "please provide one" % func.__name__, PendingDeprecationWarning, + depth=1) + if takes_connection: + return func(*args, **kwargs) + if 'connection' in kwargs: + del kwargs['connection'] + return func(*args, **kwargs) + return inner + +class LegacyConnection(type): + """ + A metaclass to normalize arguments give to the get_db_prep_* and db_type + methods on fields. + """ + def __new__(cls, names, bases, attrs): + new_cls = super(LegacyConnection, cls).__new__(cls, names, bases, attrs) + for attr in ('db_type', 'get_db_prep_save', 'get_db_prep_lookup', + 'get_db_prep_value'): + setattr(new_cls, attr, call_with_connection(getattr(new_cls, attr))) + return new_cls + +class SubfieldBase(LegacyConnection): """ A metaclass for custom Field subclasses. This ensures the model's attribute has the descriptor protocol attached to it. @@ -26,7 +63,7 @@ class Creator(object): def __get__(self, obj, type=None): if obj is None: raise AttributeError('Can only be accessed via an instance.') - return obj.__dict__[self.field.name] + return obj.__dict__[self.field.name] def __set__(self, obj, value): obj.__dict__[self.field.name] = self.field.to_python(value) diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 919f180fff..03cfd12a1d 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -8,7 +8,6 @@ from django.db.models.sql.datastructures import Date from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.query import Query from django.db.models.sql.where import AND, Constraint -from django.db.utils import call_with_connection __all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', 'AggregateQuery'] @@ -244,8 +243,7 @@ class UpdateQuery(Query): if hasattr(val, 'prepare_database_save'): val = val.prepare_database_save(field) else: - val = call_with_connection(field.get_db_prep_save, - val, connection=self.connection) + val = field.get_db_prep_save(val, connection=self.connection) # Getting the placeholder for the field. if hasattr(field, 'get_placeholder'): diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 46e7c927e0..825b0ff7e6 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -6,7 +6,6 @@ import datetime from django.utils import tree from django.db.models.fields import Field from django.db.models.query_utils import QueryWrapper -from django.db.utils import call_with_connection from datastructures import EmptyResultSet, FullResultSet # Connection types @@ -143,8 +142,8 @@ class WhereNode(tree.Node): except EmptyShortCircuit: raise EmptyResultSet else: - params = call_with_connection(Field().get_db_prep_lookup, - lookup_type, params_or_value, connection=connection) + params = Field().get_db_prep_lookup(lookup_type, params_or_value, + connection=connection) if isinstance(lvalue, tuple): # A direct database column lookup. field_sql = self.sql_for_columns(lvalue, qn, connection) @@ -267,15 +266,15 @@ class Constraint(object): from django.db.models.base import ObjectDoesNotExist try: if self.field: - params = call_with_connection(self.field.get_db_prep_lookup, - lookup_type, value, connection=connection) - db_type = call_with_connection(self.field.db_type, connection=connection) + params = self.field.get_db_prep_lookup(lookup_type, value, + connection=connection) + db_type = self.field.db_type(connection=connection) else: # This branch is used at times when we add a comparison to NULL # (we don't really want to waste time looking up the associated # field object at the calling location). - params = call_with_connection(Field().get_db_prep_lookup, - lookup_type, value, connection=connection) + params = Field().get_db_prep_lookup(lookup_type, value, + connection=connection) db_type = None except ObjectDoesNotExist: raise EmptyShortCircuit diff --git a/django/db/utils.py b/django/db/utils.py index cee3c392a6..c3ee34031a 100644 --- a/django/db/utils.py +++ b/django/db/utils.py @@ -32,12 +32,6 @@ def load_backend(backend_name): else: raise # If there's some other error, this must be an error in Django itself. -def call_with_connection(func, *args, **kwargs): - arg_names, varargs, varkwargs, defaults = inspect.getargspec(func) - if 'connection' not in arg_names and varkwargs is None: - del kwargs['connection'] - return func(*args, **kwargs) - class ConnectionHandler(object): def __init__(self, databases): diff --git a/django/forms/models.py b/django/forms/models.py index 7cb0de9de9..b3b472384d 100644 --- a/django/forms/models.py +++ b/django/forms/models.py @@ -3,7 +3,6 @@ Helper functions for creating Form classes from Django models and database field objects. """ -from django.db.utils import call_with_connection from django.utils.encoding import smart_unicode, force_unicode from django.utils.datastructures import SortedDict from django.utils.text import get_text_list, capfirst @@ -475,7 +474,7 @@ class BaseModelFormSet(BaseFormSet): pk_key = "%s-%s" % (self.add_prefix(i), self.model._meta.pk.name) pk = self.data[pk_key] pk_field = self.model._meta.pk - pk = call_with_connection(pk_field.get_db_prep_lookup, 'exact', pk, + pk = pk_field.get_db_prep_lookup('exact', pk, connection=self.get_queryset().query.connection) if isinstance(pk, list): pk = pk[0] diff --git a/tests/regressiontests/model_fields/tests.py b/tests/regressiontests/model_fields/tests.py index a56706ae91..ce18c3cd98 100644 --- a/tests/regressiontests/model_fields/tests.py +++ b/tests/regressiontests/model_fields/tests.py @@ -46,7 +46,7 @@ class DecimalFieldTests(django.test.TestCase): def test_get_db_prep_lookup(self): from django.db import connection f = models.DecimalField(max_digits=5, decimal_places=1) - self.assertEqual(f.get_db_prep_lookup('exact', None, connection), [None]) + self.assertEqual(f.get_db_prep_lookup('exact', None, connection=connection), [None]) def test_filter_with_strings(self): """ @@ -100,13 +100,13 @@ class DateTimeFieldTests(unittest.TestCase): class BooleanFieldTests(unittest.TestCase): def _test_get_db_prep_lookup(self, f): from django.db import connection - self.assertEqual(f.get_db_prep_lookup('exact', True, connection), [True]) - self.assertEqual(f.get_db_prep_lookup('exact', '1', connection), [True]) - self.assertEqual(f.get_db_prep_lookup('exact', 1, connection), [True]) - self.assertEqual(f.get_db_prep_lookup('exact', False, connection), [False]) - self.assertEqual(f.get_db_prep_lookup('exact', '0', connection), [False]) - self.assertEqual(f.get_db_prep_lookup('exact', 0, connection), [False]) - self.assertEqual(f.get_db_prep_lookup('exact', None, connection), [None]) + self.assertEqual(f.get_db_prep_lookup('exact', True, connection=connection), [True]) + self.assertEqual(f.get_db_prep_lookup('exact', '1', connection=connection), [True]) + self.assertEqual(f.get_db_prep_lookup('exact', 1, connection=connection), [True]) + self.assertEqual(f.get_db_prep_lookup('exact', False, connection=connection), [False]) + self.assertEqual(f.get_db_prep_lookup('exact', '0', connection=connection), [False]) + self.assertEqual(f.get_db_prep_lookup('exact', 0, connection=connection), [False]) + self.assertEqual(f.get_db_prep_lookup('exact', None, connection=connection), [None]) def test_booleanfield_get_db_prep_lookup(self): self._test_get_db_prep_lookup(models.BooleanField())