diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index fe5ff6704f..5cd96d2902 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -19,7 +19,7 @@ from django.core import checks, exceptions, validators # django.core.exceptions. It is retained here for backwards compatibility # purposes. from django.core.exceptions import FieldDoesNotExist # NOQA -from django.db import connection +from django.db import connection, connections, router from django.db.models.lookups import ( Lookup, RegisterLookupMixin, Transform, default_lookups, ) @@ -315,7 +315,11 @@ class Field(RegisterLookupMixin): return [] def _check_backend_specific_checks(self, **kwargs): - return connection.validation.check_field(self, **kwargs) + app_label = self.model._meta.app_label + for db in connections: + if router.allow_migrate(db, app_label, model=self.model): + return connections[db].validation.check_field(self, **kwargs) + return [] def _check_deprecation_details(self): if self.system_check_removed_details is not None: diff --git a/tests/check_framework/test_multi_db.py b/tests/check_framework/test_multi_db.py new file mode 100644 index 0000000000..168b2ea33f --- /dev/null +++ b/tests/check_framework/test_multi_db.py @@ -0,0 +1,43 @@ +from django.db import connections, models +from django.test import TestCase, mock +from django.test.utils import override_settings + +from .tests import IsolateModelsMixin + + +class TestRouter(object): + """ + Routes to the 'other' database if the model name starts with 'Other'. + """ + def allow_migrate(self, db, app_label, model=None, **hints): + return db == ('other' if model._meta.verbose_name.startswith('other') else 'default') + + +@override_settings(DATABASE_ROUTERS=[TestRouter()]) +class TestMultiDBChecks(IsolateModelsMixin, TestCase): + multi_db = True + + def _patch_check_field_on(self, db): + return mock.patch.object(connections[db].validation, 'check_field') + + def test_checks_called_on_the_default_database(self): + class Model(models.Model): + field = models.CharField(max_length=100) + + model = Model() + with self._patch_check_field_on('default') as mock_check_field_default: + with self._patch_check_field_on('other') as mock_check_field_other: + model.check() + self.assertTrue(mock_check_field_default.called) + self.assertFalse(mock_check_field_other.called) + + def test_checks_called_on_the_other_database(self): + class OtherModel(models.Model): + field = models.CharField(max_length=100) + + model = OtherModel() + with self._patch_check_field_on('other') as mock_check_field_other: + with self._patch_check_field_on('default') as mock_check_field_default: + model.check() + self.assertTrue(mock_check_field_other.called) + self.assertFalse(mock_check_field_default.called) diff --git a/tests/invalid_models_tests/test_backend_specific.py b/tests/invalid_models_tests/test_backend_specific.py index 2d3ea7e6c7..5be1d6f691 100644 --- a/tests/invalid_models_tests/test_backend_specific.py +++ b/tests/invalid_models_tests/test_backend_specific.py @@ -1,37 +1,31 @@ # -*- encoding: utf-8 -*- from __future__ import unicode_literals -from types import MethodType - from django.core.checks import Error -from django.db import connection, models +from django.db import connections, models +from django.test import mock from .base import IsolatedModelsTestCase +def dummy_allow_migrate(db, app_label, **hints): + # Prevent checks from being run on the 'other' database, which doesn't have + # its check_field() method mocked in the test. + return db == 'default' + + class BackendSpecificChecksTests(IsolatedModelsTestCase): + @mock.patch('django.db.models.fields.router.allow_migrate', new=dummy_allow_migrate) def test_check_field(self): """ Test if backend specific checks are performed. """ - error = Error('an error', hint=None) - def mock(self, field, **kwargs): - return [error] - class Model(models.Model): field = models.IntegerField() field = Model._meta.get_field('field') - - # Mock connection.validation.check_field method. - v = connection.validation - old_check_field = v.check_field - v.check_field = MethodType(mock, v) - try: + with mock.patch.object(connections['default'].validation, 'check_field', return_value=[error]): errors = field.check() - finally: - # Unmock connection.validation.check_field method. - v.check_field = old_check_field self.assertEqual(errors, [error]) diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 0a41bf1bdd..626a879b77 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -14,6 +14,7 @@ from . import PostgreSQLTestCase from .models import ( ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, + PostgreSQLModel, ) try: @@ -246,16 +247,20 @@ class TestQuerying(PostgreSQLTestCase): class TestChecks(PostgreSQLTestCase): def test_field_checks(self): - field = ArrayField(models.CharField()) - field.set_attributes_from_name('field') - errors = field.check() + class MyModel(PostgreSQLModel): + field = ArrayField(models.CharField()) + + model = MyModel() + errors = model.check() self.assertEqual(len(errors), 1) self.assertEqual(errors[0].id, 'postgres.E001') def test_invalid_base_fields(self): - field = ArrayField(models.ManyToManyField('postgres_tests.IntegerArrayModel')) - field.set_attributes_from_name('field') - errors = field.check() + class MyModel(PostgreSQLModel): + field = ArrayField(models.ManyToManyField('postgres_tests.IntegerArrayModel')) + + model = MyModel() + errors = model.check() self.assertEqual(len(errors), 1) self.assertEqual(errors[0].id, 'postgres.E002')