From 83955447747561e6b89e84a062164d94910c6d7d Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Fri, 17 Jul 2009 15:57:43 +0000 Subject: [PATCH] [soc2009/multidb] Added connection parameter to the get_db_prep_* family of functions. This allows us to generate the lookup and save values for Fields in a backend specific manner. git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11264 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/base.py | 7 ++- django/db/models/fields/__init__.py | 49 ++++++++++++--------- django/db/models/fields/files.py | 8 ++-- django/db/models/fields/related.py | 10 +++-- django/db/models/related.py | 2 +- django/db/models/sql/subqueries.py | 4 +- django/db/models/sql/where.py | 9 ++-- django/forms/models.py | 4 +- docs/howto/custom-model-fields.txt | 13 +++--- tests/regressiontests/model_fields/tests.py | 18 ++++---- 10 files changed, 74 insertions(+), 50 deletions(-) diff --git a/django/db/models/base.py b/django/db/models/base.py index e1c5a8461d..021706e3ae 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -18,6 +18,7 @@ from django.db.models.options import Options from django.db import connections, transaction, DatabaseError, DEFAULT_DB_ALIAS from django.db.models import signals from django.db.models.loading import register_models, get_model +from django.db.utils import call_with_connection from django.utils.functional import curry from django.utils.encoding import smart_str, force_unicode, smart_unicode from django.conf import settings @@ -483,9 +484,11 @@ class Model(object): if not pk_set: if force_update: raise ValueError("Cannot force an update in save() with no primary key.") - values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in meta.local_fields if not isinstance(f, AutoField)] + values = [(f, call_with_connection(f.get_db_prep_save, raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection)) + for f in meta.local_fields if not isinstance(f, AutoField)] else: - values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in meta.local_fields] + values = [(f, call_with_connection(f.get_db_prep_save, raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection)) + for f in meta.local_fields] if meta.order_with_respect_to: field = meta.order_with_respect_to diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index d5c7490781..af2031ce10 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -11,6 +11,7 @@ except ImportError: from django.db import connection from django.db.models import signals from django.db.models.query_utils import QueryWrapper +from django.db.utils import call_with_connection from django.dispatch import dispatcher from django.conf import settings from django import forms @@ -178,7 +179,7 @@ class Field(object): "Returns field's value just before saving." return getattr(model_instance, self.attname) - def get_db_prep_value(self, value): + def get_db_prep_value(self, value, connection): """Returns field's value prepared for interacting with the database backend. @@ -187,11 +188,12 @@ class Field(object): """ return value - def get_db_prep_save(self, value): + def get_db_prep_save(self, value, connection): "Returns field's value prepared for saving into a database." - return self.get_db_prep_value(value) + return call_with_connection(self.get_db_prep_value, value, + connection=connection) - def get_db_prep_lookup(self, lookup_type, value): + def get_db_prep_lookup(self, lookup_type, value, connection): "Returns field's value prepared for database lookup." if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'): # If the value has a relabel_aliases method, it will need to @@ -208,9 +210,9 @@ class Field(object): if lookup_type in ('regex', 'iregex', 'month', 'day', 'week_day', 'search'): return [value] elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'): - return [self.get_db_prep_value(value)] + return [call_with_connection(self.get_db_prep_value, value, connection=connection)] elif lookup_type in ('range', 'in'): - return [self.get_db_prep_value(v) for v in value] + return [call_with_connection(self.get_db_prep_value, v, connection=connection) for v in value] elif lookup_type in ('contains', 'icontains'): return ["%%%s%%" % connection.ops.prep_for_like_query(value)] elif lookup_type == 'iexact': @@ -374,7 +376,7 @@ class AutoField(Field): raise exceptions.ValidationError( _("This value must be an integer.")) - def get_db_prep_value(self, value): + def get_db_prep_value(self, value, connection): if value is None: return None return int(value) @@ -417,14 +419,15 @@ class BooleanField(Field): raise exceptions.ValidationError( _("This value must be either True or False.")) - def get_db_prep_lookup(self, lookup_type, value): + def get_db_prep_lookup(self, lookup_type, value, connection): # Special-case handling for filters coming from a web request (e.g. the # admin interface). Only works for scalar values (not lists). If you're # passing in a list, you might as well make things the right type when # constructing the list. if value in ('1', '0'): value = bool(int(value)) - return super(BooleanField, self).get_db_prep_lookup(lookup_type, value) + return call_with_connection(super(BooleanField, self).get_db_prep_lookup, + lookup_type, value, connection=connection) def validate(self, lookup_type, value): if super(BooleanField, self).validate(lookup_type, value): @@ -433,7 +436,7 @@ class BooleanField(Field): value = int(value) bool(value) - def get_db_prep_value(self, value): + def get_db_prep_value(self, value, connection): if value is None: return None return bool(value) @@ -536,14 +539,15 @@ class DateField(Field): setattr(cls, 'get_previous_by_%s' % self.name, curry(cls._get_next_or_previous_by_FIELD, field=self, is_next=False)) - def get_db_prep_lookup(self, lookup_type, value): + def get_db_prep_lookup(self, lookup_type, value, connection): # For "__month", "__day", and "__week_day" lookups, convert the value # to an int so the database backend always sees a consistent type. if lookup_type in ('month', 'day', 'week_day'): return [int(value)] - return super(DateField, self).get_db_prep_lookup(lookup_type, value) + return call_with_connection(super(DateField, self).get_db_prep_lookup, + lookup_type, value, connection=connection) - def get_db_prep_value(self, value): + def get_db_prep_value(self, value, connection): # Casts dates into the format expected by the backend return connection.ops.value_to_db_date(self.to_python(value)) @@ -615,7 +619,7 @@ class DateTimeField(DateField): raise exceptions.ValidationError( _('Enter a valid date/time in YYYY-MM-DD HH:MM[:ss[.uuuuuu]] format.')) - def get_db_prep_value(self, value): + def get_db_prep_value(self, value, connection): # Casts dates into the format expected by the backend return connection.ops.value_to_db_datetime(self.to_python(value)) @@ -671,11 +675,11 @@ class DecimalField(Field): from django.db.backends import util return util.format_number(value, self.max_digits, self.decimal_places) - def get_db_prep_save(self, value): + def get_db_prep_save(self, value, connection): return connection.ops.value_to_db_decimal(self.to_python(value), self.max_digits, self.decimal_places) - def get_db_prep_value(self, value): + def get_db_prep_value(self, value, connection): return self.to_python(value) def formfield(self, **kwargs): @@ -719,7 +723,7 @@ class FilePathField(Field): class FloatField(Field): empty_strings_allowed = False - def get_db_prep_value(self, value): + def get_db_prep_value(self, value, connection): if value is None: return None return float(value) @@ -743,7 +747,7 @@ class FloatField(Field): class IntegerField(Field): empty_strings_allowed = False - def get_db_prep_value(self, value): + def get_db_prep_value(self, value, connection): if value is None: return None return int(value) @@ -796,21 +800,22 @@ class NullBooleanField(Field): raise exceptions.ValidationError( _("This value must be either None, True or False.")) - def get_db_prep_lookup(self, lookup_type, value): + def get_db_prep_lookup(self, lookup_type, value, connection): # Special-case handling for filters coming from a web request (e.g. the # admin interface). Only works for scalar values (not lists). If you're # passing in a list, you might as well make things the right type when # constructing the list. if value in ('1', '0'): value = bool(int(value)) - return super(NullBooleanField, self).get_db_prep_lookup(lookup_type, value) + return call_with_connection(super(NullBooleanField, self).get_db_prep_lookup, + lookup_type, value, connection=connection) def validate(self, lookup_type, value): if value in ('1', '0'): value = int(value) bool(value) - def get_db_prep_value(self, value): + def get_db_prep_value(self, value, connection): if value is None: return None return bool(value) @@ -926,7 +931,7 @@ class TimeField(Field): else: return super(TimeField, self).pre_save(model_instance, add) - def get_db_prep_value(self, value): + def get_db_prep_value(self, value, connection): # Casts times into the format expected by the backend return connection.ops.value_to_db_time(self.to_python(value)) diff --git a/django/db/models/fields/files.py b/django/db/models/fields/files.py index aab4f3789f..3290645231 100644 --- a/django/db/models/fields/files.py +++ b/django/db/models/fields/files.py @@ -10,6 +10,7 @@ from django.core.files.images import ImageFile, get_image_dimensions from django.core.files.uploadedfile import UploadedFile from django.utils.functional import curry from django.db.models import signals +from django.db.utils import call_with_connection from django.utils.encoding import force_unicode, smart_str from django.utils.translation import ugettext_lazy, ugettext as _ from django import forms @@ -232,12 +233,13 @@ class FileField(Field): def get_internal_type(self): return "FileField" - def get_db_prep_lookup(self, lookup_type, value): + def get_db_prep_lookup(self, lookup_type, value, connection): if hasattr(value, 'name'): value = value.name - return super(FileField, self).get_db_prep_lookup(lookup_type, value) + return call_with_connection(super(FileField, self).get_db_prep_lookup, + lookup_type, value, connection=connection) - def get_db_prep_value(self, value): + def get_db_prep_value(self, value, connection): "Returns field's value prepared for saving into a database." # Need to convert File objects provided via a form to unicode for database insertion if value is None: diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 47f5c68fed..33cef0bd17 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -117,7 +117,7 @@ class RelatedField(object): if not cls._meta.abstract: self.contribute_to_related_class(other, self.related) - def get_db_prep_lookup(self, lookup_type, value): + def get_db_prep_lookup(self, lookup_type, value, connection): # If we are doing a lookup on a Related Field, we must be # comparing object instances. The value should be the PK of value, # not value itself. @@ -137,7 +137,8 @@ class RelatedField(object): if field: if lookup_type in ('range', 'in'): v = [v] - v = field.get_db_prep_lookup(lookup_type, v) + v = call_with_connection(field.get_db_prep_lookup, + lookup_type, v, connection=connection) if isinstance(v, list): v = v[0] return v @@ -720,11 +721,12 @@ class ForeignKey(RelatedField, Field): return getattr(field_default, self.rel.get_related_field().attname) return field_default - def get_db_prep_save(self, value): + def get_db_prep_save(self, value, connection): if value == '' or value == None: return None else: - return self.rel.get_related_field().get_db_prep_save(value) + return call_with_connection(self.rel.get_related_field().get_db_prep_save, + value, connection=connection) def value_to_string(self, obj): if not obj: diff --git a/django/db/models/related.py b/django/db/models/related.py index ff7c787a93..9df143550e 100644 --- a/django/db/models/related.py +++ b/django/db/models/related.py @@ -18,7 +18,7 @@ class RelatedObject(object): self.name = '%s:%s' % (self.opts.app_label, self.opts.module_name) self.var_name = self.opts.object_name.lower() - def get_db_prep_lookup(self, lookup_type, value): + def get_db_prep_lookup(self, lookup_type, value, connection): # Defer to the actual field definition for db prep return self.field.get_db_prep_lookup(lookup_type, value) diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 6b0235c16e..919f180fff 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -8,6 +8,7 @@ from django.db.models.sql.datastructures import Date from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.query import Query from django.db.models.sql.where import AND, Constraint +from django.db.utils import call_with_connection __all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', 'AggregateQuery'] @@ -243,7 +244,8 @@ class UpdateQuery(Query): if hasattr(val, 'prepare_database_save'): val = val.prepare_database_save(field) else: - val = field.get_db_prep_save(val) + val = call_with_connection(field.get_db_prep_save, + val, connection=self.connection) # Getting the placeholder for the field. if hasattr(field, 'get_placeholder'): diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index ec0bfdfe9e..46e7c927e0 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -143,7 +143,8 @@ class WhereNode(tree.Node): except EmptyShortCircuit: raise EmptyResultSet else: - params = Field().get_db_prep_lookup(lookup_type, params_or_value) + params = call_with_connection(Field().get_db_prep_lookup, + lookup_type, params_or_value, connection=connection) if isinstance(lvalue, tuple): # A direct database column lookup. field_sql = self.sql_for_columns(lvalue, qn, connection) @@ -266,13 +267,15 @@ class Constraint(object): from django.db.models.base import ObjectDoesNotExist try: if self.field: - params = self.field.get_db_prep_lookup(lookup_type, value) + params = call_with_connection(self.field.get_db_prep_lookup, + lookup_type, value, connection=connection) db_type = call_with_connection(self.field.db_type, connection=connection) else: # This branch is used at times when we add a comparison to NULL # (we don't really want to waste time looking up the associated # field object at the calling location). - params = Field().get_db_prep_lookup(lookup_type, value) + params = call_with_connection(Field().get_db_prep_lookup, + lookup_type, value, connection=connection) db_type = None except ObjectDoesNotExist: raise EmptyShortCircuit diff --git a/django/forms/models.py b/django/forms/models.py index cc43612bf5..7cb0de9de9 100644 --- a/django/forms/models.py +++ b/django/forms/models.py @@ -3,6 +3,7 @@ Helper functions for creating Form classes from Django models and database field objects. """ +from django.db.utils import call_with_connection from django.utils.encoding import smart_unicode, force_unicode from django.utils.datastructures import SortedDict from django.utils.text import get_text_list, capfirst @@ -474,7 +475,8 @@ class BaseModelFormSet(BaseFormSet): pk_key = "%s-%s" % (self.add_prefix(i), self.model._meta.pk.name) pk = self.data[pk_key] pk_field = self.model._meta.pk - pk = pk_field.get_db_prep_lookup('exact', pk) + pk = call_with_connection(pk_field.get_db_prep_lookup, 'exact', pk, + connection=self.get_queryset().query.connection) if isinstance(pk, list): pk = pk[0] kwargs['instance'] = self._existing_object(pk) diff --git a/docs/howto/custom-model-fields.txt b/docs/howto/custom-model-fields.txt index 307e2c05d7..d76028ec2c 100644 --- a/docs/howto/custom-model-fields.txt +++ b/docs/howto/custom-model-fields.txt @@ -399,24 +399,27 @@ mentioned earlier. Otherwise :meth:`to_python` won't be called automatically. Converting Python objects to database values ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. method:: get_db_prep_value(self, value) +.. method:: get_db_prep_value(self, value, connection) This is the reverse of :meth:`to_python` when working with the database backends (as opposed to serialization). The ``value`` parameter is the current value of the model's attribute (a field has no reference to its containing model, so it cannot retrieve the value itself), and the method should return data in a format -that can be used as a parameter in a query for the database backend. +that can be used as a parameter in a query for the database backend. The +specific connection that will be used for the query is passed as the +``connection`` parameter, this allows you to generate the value in a backend +specific mannner if necessary. For example:: class HandField(models.Field): # ... - def get_db_prep_value(self, value): + def get_db_prep_value(self, value, connection): return ''.join([''.join(l) for l in (value.north, value.east, value.south, value.west)]) -.. method:: get_db_prep_save(self, value) +.. method:: get_db_prep_save(self, value, connection) Same as the above, but called when the Field value must be *saved* to the database. As the default implementation just calls ``get_db_prep_value``, you @@ -450,7 +453,7 @@ correct value. Preparing values for use in database lookups ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. method:: get_db_prep_lookup(self, lookup_type, value) +.. method:: get_db_prep_lookup(self, lookup_type, value, connection) Prepares the ``value`` for passing to the database when used in a lookup (a ``WHERE`` constraint in SQL). The ``lookup_type`` will be one of the valid diff --git a/tests/regressiontests/model_fields/tests.py b/tests/regressiontests/model_fields/tests.py index 7a6fee5a2a..a56706ae91 100644 --- a/tests/regressiontests/model_fields/tests.py +++ b/tests/regressiontests/model_fields/tests.py @@ -44,8 +44,9 @@ class DecimalFieldTests(django.test.TestCase): self.assertEqual(f._format(None), None) def test_get_db_prep_lookup(self): + from django.db import connection f = models.DecimalField(max_digits=5, decimal_places=1) - self.assertEqual(f.get_db_prep_lookup('exact', None), [None]) + self.assertEqual(f.get_db_prep_lookup('exact', None, connection), [None]) def test_filter_with_strings(self): """ @@ -98,13 +99,14 @@ class DateTimeFieldTests(unittest.TestCase): class BooleanFieldTests(unittest.TestCase): def _test_get_db_prep_lookup(self, f): - self.assertEqual(f.get_db_prep_lookup('exact', True), [True]) - self.assertEqual(f.get_db_prep_lookup('exact', '1'), [True]) - self.assertEqual(f.get_db_prep_lookup('exact', 1), [True]) - self.assertEqual(f.get_db_prep_lookup('exact', False), [False]) - self.assertEqual(f.get_db_prep_lookup('exact', '0'), [False]) - self.assertEqual(f.get_db_prep_lookup('exact', 0), [False]) - self.assertEqual(f.get_db_prep_lookup('exact', None), [None]) + from django.db import connection + self.assertEqual(f.get_db_prep_lookup('exact', True, connection), [True]) + self.assertEqual(f.get_db_prep_lookup('exact', '1', connection), [True]) + self.assertEqual(f.get_db_prep_lookup('exact', 1, connection), [True]) + self.assertEqual(f.get_db_prep_lookup('exact', False, connection), [False]) + self.assertEqual(f.get_db_prep_lookup('exact', '0', connection), [False]) + self.assertEqual(f.get_db_prep_lookup('exact', 0, connection), [False]) + self.assertEqual(f.get_db_prep_lookup('exact', None, connection), [None]) def test_booleanfield_get_db_prep_lookup(self): self._test_get_db_prep_lookup(models.BooleanField())