1
0
mirror of https://github.com/django/django.git synced 2025-03-31 19:46:42 +00:00

Fixed #25172 -- Fixed check framework to work with multiple databases.

This commit is contained in:
Ion Scerbatiuc 2015-08-01 07:46:25 -07:00 committed by Tim Graham
parent d0bd533043
commit 0cc059cd10
4 changed files with 70 additions and 24 deletions

View File

@ -19,7 +19,7 @@ from django.core import checks, exceptions, validators
# django.core.exceptions. It is retained here for backwards compatibility # django.core.exceptions. It is retained here for backwards compatibility
# purposes. # purposes.
from django.core.exceptions import FieldDoesNotExist # NOQA from django.core.exceptions import FieldDoesNotExist # NOQA
from django.db import connection from django.db import connection, connections, router
from django.db.models.lookups import ( from django.db.models.lookups import (
Lookup, RegisterLookupMixin, Transform, default_lookups, Lookup, RegisterLookupMixin, Transform, default_lookups,
) )
@ -315,7 +315,11 @@ class Field(RegisterLookupMixin):
return [] return []
def _check_backend_specific_checks(self, **kwargs): def _check_backend_specific_checks(self, **kwargs):
return connection.validation.check_field(self, **kwargs) app_label = self.model._meta.app_label
for db in connections:
if router.allow_migrate(db, app_label, model=self.model):
return connections[db].validation.check_field(self, **kwargs)
return []
def _check_deprecation_details(self): def _check_deprecation_details(self):
if self.system_check_removed_details is not None: if self.system_check_removed_details is not None:

View File

@ -0,0 +1,43 @@
from django.db import connections, models
from django.test import TestCase, mock
from django.test.utils import override_settings
from .tests import IsolateModelsMixin
class TestRouter(object):
"""
Routes to the 'other' database if the model name starts with 'Other'.
"""
def allow_migrate(self, db, app_label, model=None, **hints):
return db == ('other' if model._meta.verbose_name.startswith('other') else 'default')
@override_settings(DATABASE_ROUTERS=[TestRouter()])
class TestMultiDBChecks(IsolateModelsMixin, TestCase):
multi_db = True
def _patch_check_field_on(self, db):
return mock.patch.object(connections[db].validation, 'check_field')
def test_checks_called_on_the_default_database(self):
class Model(models.Model):
field = models.CharField(max_length=100)
model = Model()
with self._patch_check_field_on('default') as mock_check_field_default:
with self._patch_check_field_on('other') as mock_check_field_other:
model.check()
self.assertTrue(mock_check_field_default.called)
self.assertFalse(mock_check_field_other.called)
def test_checks_called_on_the_other_database(self):
class OtherModel(models.Model):
field = models.CharField(max_length=100)
model = OtherModel()
with self._patch_check_field_on('other') as mock_check_field_other:
with self._patch_check_field_on('default') as mock_check_field_default:
model.check()
self.assertTrue(mock_check_field_other.called)
self.assertFalse(mock_check_field_default.called)

View File

@ -1,37 +1,31 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
from __future__ import unicode_literals from __future__ import unicode_literals
from types import MethodType
from django.core.checks import Error from django.core.checks import Error
from django.db import connection, models from django.db import connections, models
from django.test import mock
from .base import IsolatedModelsTestCase from .base import IsolatedModelsTestCase
def dummy_allow_migrate(db, app_label, **hints):
# Prevent checks from being run on the 'other' database, which doesn't have
# its check_field() method mocked in the test.
return db == 'default'
class BackendSpecificChecksTests(IsolatedModelsTestCase): class BackendSpecificChecksTests(IsolatedModelsTestCase):
@mock.patch('django.db.models.fields.router.allow_migrate', new=dummy_allow_migrate)
def test_check_field(self): def test_check_field(self):
""" Test if backend specific checks are performed. """ """ Test if backend specific checks are performed. """
error = Error('an error', hint=None) error = Error('an error', hint=None)
def mock(self, field, **kwargs):
return [error]
class Model(models.Model): class Model(models.Model):
field = models.IntegerField() field = models.IntegerField()
field = Model._meta.get_field('field') field = Model._meta.get_field('field')
with mock.patch.object(connections['default'].validation, 'check_field', return_value=[error]):
# Mock connection.validation.check_field method.
v = connection.validation
old_check_field = v.check_field
v.check_field = MethodType(mock, v)
try:
errors = field.check() errors = field.check()
finally:
# Unmock connection.validation.check_field method.
v.check_field = old_check_field
self.assertEqual(errors, [error]) self.assertEqual(errors, [error])

View File

@ -14,6 +14,7 @@ from . import PostgreSQLTestCase
from .models import ( from .models import (
ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel, ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel,
NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel,
PostgreSQLModel,
) )
try: try:
@ -246,16 +247,20 @@ class TestQuerying(PostgreSQLTestCase):
class TestChecks(PostgreSQLTestCase): class TestChecks(PostgreSQLTestCase):
def test_field_checks(self): def test_field_checks(self):
field = ArrayField(models.CharField()) class MyModel(PostgreSQLModel):
field.set_attributes_from_name('field') field = ArrayField(models.CharField())
errors = field.check()
model = MyModel()
errors = model.check()
self.assertEqual(len(errors), 1) self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].id, 'postgres.E001') self.assertEqual(errors[0].id, 'postgres.E001')
def test_invalid_base_fields(self): def test_invalid_base_fields(self):
field = ArrayField(models.ManyToManyField('postgres_tests.IntegerArrayModel')) class MyModel(PostgreSQLModel):
field.set_attributes_from_name('field') field = ArrayField(models.ManyToManyField('postgres_tests.IntegerArrayModel'))
errors = field.check()
model = MyModel()
errors = model.check()
self.assertEqual(len(errors), 1) self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].id, 'postgres.E002') self.assertEqual(errors[0].id, 'postgres.E002')