From c7bd48cb9f645e5ff07d1e68b86130e3bb2b043f Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Mon, 9 Aug 2010 21:16:59 +0000 Subject: [PATCH] [soc2010/query-refactor] Improved the ListField implementation, and added an EmbeddedModelField. git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2010/query-refactor@13564 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/contrib/mongodb/compiler.py | 4 ++ django/db/models/__init__.py | 2 +- django/db/models/fields/structures.py | 57 ++++++++++++++++++++++++- django/utils/encoding.py | 1 + tests/regressiontests/mongodb/models.py | 20 +++++++++ tests/regressiontests/mongodb/tests.py | 57 ++++++++++++++++++++++++- 6 files changed, 137 insertions(+), 4 deletions(-) diff --git a/django/contrib/mongodb/compiler.py b/django/contrib/mongodb/compiler.py index c2a8c11b23..1851967ba9 100644 --- a/django/contrib/mongodb/compiler.py +++ b/django/contrib/mongodb/compiler.py @@ -172,6 +172,10 @@ class SQLUpdateCompiler(SQLCompiler): vals = {} for field, o, value in self.query.values: + if hasattr(value, 'prepare_database_save'): + value = value.prepare_database_save(field) + else: + value = field.get_db_prep_save(value, connection=self.connection) if hasattr(value, "evaluate"): assert value.connector in (value.ADD, value.SUB) assert not value.negated diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 0736648a7b..6c0161a0e9 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -13,7 +13,7 @@ from django.db.models.fields.subclassing import SubfieldBase from django.db.models.fields.files import FileField, ImageField from django.db.models.fields.related import (ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel) -from django.db.models.fields.structures import ListField +from django.db.models.fields.structures import ListField, EmbeddedModel from django.db.models import signals # Admin stages. diff --git a/django/db/models/fields/structures.py b/django/db/models/fields/structures.py index 2d3822e932..b90164cc28 100644 --- a/django/db/models/fields/structures.py +++ b/django/db/models/fields/structures.py @@ -1,10 +1,15 @@ +from django.core.exceptions import ValidationError +from django.db.models.loading import cache from django.db.models.fields import Field +from django.db.models.fields.subclassing import SubfieldBase class ListField(Field): + __metaclass__ = SubfieldBase + def __init__(self, field_type): self.field_type = field_type - super(ListField, self).__init__() + super(ListField, self).__init__(default=[]) def get_prep_lookup(self, lookup_type, value): return self.field_type.get_prep_lookup(lookup_type, value) @@ -19,3 +24,53 @@ class ListField(Field): return self.field_type.get_db_prep_lookup( lookup_type, value, connection=connection, prepared=prepared ) + + def to_python(self, value): + try: + value = iter(value) + except TypeError: + raise ValidationError("Value should be iterable") + return [ + self.field_type.to_python(v) + for v in value + ] + + +class EmbeddedModel(Field): + __metaclass__ = SubfieldBase + + def __init__(self, to): + self.to = to + super(EmbeddedModel, self).__init__() + + def get_db_prep_save(self, value, connection): + data = {} + if not isinstance(value, self.to): + raise ValidationError("Value must be an instance of %s, got %s " + "instead" % (self.to, value)) + if type(value) is not self.to: + data["_cls"] = (value._meta.app_label, value._meta.object_name) + for field in value._meta.fields: + # If the field is a OneToOneField that makes the inheritance link, + # ignore it. + if field.rel and field.rel.parent_link: + continue + data[field.column] = field.get_db_prep_save( + getattr(value, field.name), connection=connection + ) + return data + + def to_python(self, value): + if isinstance(value, self.to): + return value + try: + value = dict(value) + except TypeError: + raise ValidationError("Value should be a dict") + + if "_cls" in value: + cls = cache.get_model(*value.pop("_cls")) + else: + cls = self.to + + return cls(**value) diff --git a/django/utils/encoding.py b/django/utils/encoding.py index e2d7249903..c4139d9171 100644 --- a/django/utils/encoding.py +++ b/django/utils/encoding.py @@ -47,6 +47,7 @@ def is_protected_type(obj): return isinstance(obj, ( types.NoneType, int, long, + list, datetime.datetime, datetime.date, datetime.time, float, Decimal) ) diff --git a/tests/regressiontests/mongodb/models.py b/tests/regressiontests/mongodb/models.py index f0d950dcbb..9b7e25108e 100644 --- a/tests/regressiontests/mongodb/models.py +++ b/tests/regressiontests/mongodb/models.py @@ -31,3 +31,23 @@ class Post(models.Model): magic_numbers = models.ListField( models.IntegerField() ) + + +class Revision(models.Model): + number = models.IntegerField() + content = models.TextField() + + +class AuthenticatedRevision(Revision): + # This is a really stupid way to add optional authentication, but it serves + # its purpose. + author = models.CharField(max_length=100) + + +class WikiPage(models.Model): + id = models.NativeAutoField(primary_key=True) + title = models.CharField(max_length=255) + + revisions = models.ListField( + models.EmbeddedModel(Revision) + ) diff --git a/tests/regressiontests/mongodb/tests.py b/tests/regressiontests/mongodb/tests.py index 81b780b8ff..fff6f248db 100644 --- a/tests/regressiontests/mongodb/tests.py +++ b/tests/regressiontests/mongodb/tests.py @@ -1,8 +1,9 @@ +from django.core.exceptions import ValidationError from django.db import connection, UnsupportedDatabaseOperation from django.db.models import Count, Sum, F, Q from django.test import TestCase -from models import Artist, Group, Post +from models import Artist, Group, Post, WikiPage, Revision, AuthenticatedRevision class MongoTestCase(TestCase): @@ -398,6 +399,9 @@ class MongoTestCase(TestCase): ) def test_list_field(self): + p = Post() + self.assertEqual(p.tags, []) + p = Post.objects.create( title="Django ORM grows MongoDB support", tags=["python", "django", "mongodb", "web"] @@ -428,7 +432,7 @@ class MongoTestCase(TestCase): lambda p: p.title ) - self.assertRaises(ValueError, + self.assertRaises(ValidationError, lambda: Post.objects.create(magic_numbers=["a"]) ) @@ -448,3 +452,52 @@ class MongoTestCase(TestCase): ], lambda p: p.title, ) + + def test_embedded_model(self): + page = WikiPage(title="Django") + page.revisions.append( + Revision(number=1, content="Django is a Python") + ) + page.revisions.append( + Revision(number=2, content="Django is a Python web framework.") + ) + + page.save() + + page = WikiPage.objects.get(pk=page.pk) + self.assertEqual(len(page.revisions), 2) + self.assertEqual( + [(r.number, r.content) for r in page.revisions], + [(1, "Django is a Python"), (2, "Django is a Python web framework.")] + ) + + self.assertEqual(Revision.objects.count(), 0) + + self.assertRaises(ValidationError, + lambda: WikiPage.objects.create(title="Python", revisions=14) + ) + self.assertRaises(ValidationError, + lambda: WikiPage.objects.create(title="Python", revisions=[14]) + ) + + page = WikiPage.objects.create(title="Python", revisions=[ + Revision(number=1, content="Python was created by Guido van Rossum.") + ]) + page = WikiPage.objects.get(pk=page.pk) + self.assertEqual(len(page.revisions), 1) + + page.revisions.append( + AuthenticatedRevision(number=2, content="Python is a trap.", author="Rasmus Lerdorf"), + ) + + page.save() + self.assertEqual(len(page.revisions), 2) + self.assertEqual(page.revisions[-1].author, "Rasmus Lerdorf") + + page = WikiPage.objects.get(pk=page.pk) + self.assertEqual(len(page.revisions), 2) + self.assertTrue(isinstance(page.revisions[-1], AuthenticatedRevision)) + self.assertEqual(page.revisions[-1].author, "Rasmus Lerdorf") + + page.revisions.append(14) + self.assertRaises(ValidationError, page.save)