diff --git a/django/contrib/gis/tests/layermap/tests.py b/django/contrib/gis/tests/layermap/tests.py index 85b8d0c8b5..4912e645c2 100644 --- a/django/contrib/gis/tests/layermap/tests.py +++ b/django/contrib/gis/tests/layermap/tests.py @@ -8,6 +8,7 @@ from django.contrib.gis.gdal import DataSource from django.contrib.gis.tests.utils import mysql from django.contrib.gis.utils.layermapping import (LayerMapping, LayerMapError, InvalidDecimal, MissingForeignKey) +from django.db import router from django.test import TestCase from .models import ( @@ -26,6 +27,7 @@ NAMES = ['Bexar', 'Galveston', 'Harris', 'Honolulu', 'Pueblo'] NUMS = [1, 2, 1, 19, 1] # Number of polygons for each. STATES = ['Texas', 'Texas', 'Texas', 'Hawaii', 'Colorado'] + class LayerMapTest(TestCase): def test_init(self): @@ -281,3 +283,31 @@ class LayerMapTest(TestCase): lm.save(silent=True, strict=True) self.assertEqual(City.objects.count(), 3) self.assertEqual(City.objects.all().order_by('name_txt')[0].name_txt, "Houston") + + +class OtherRouter(object): + def db_for_read(self, model, **hints): + return 'other' + + def db_for_write(self, model, **hints): + return self.db_for_read(model, **hints) + + def allow_relation(self, obj1, obj2, **hints): + return None + + def allow_syncdb(self, db, model): + return True + + +class LayerMapRouterTest(TestCase): + + def setUp(self): + self.old_routers = router.routers + router.routers = [OtherRouter()] + + def tearDown(self): + router.routers = self.old_routers + + def test_layermapping_default_db(self): + lm = LayerMapping(City, city_shp, city_mapping) + self.assertEqual(lm.using, 'other') diff --git a/django/contrib/gis/utils/layermapping.py b/django/contrib/gis/utils/layermapping.py index e898f6de2e..9511815426 100644 --- a/django/contrib/gis/utils/layermapping.py +++ b/django/contrib/gis/utils/layermapping.py @@ -9,7 +9,7 @@ import sys from decimal import Decimal from django.core.exceptions import ObjectDoesNotExist -from django.db import connections, DEFAULT_DB_ALIAS +from django.db import connections, router from django.contrib.gis.db.models import GeometryField from django.contrib.gis.gdal import (CoordTransform, DataSource, OGRException, OGRGeometry, OGRGeomType, SpatialReference) @@ -67,7 +67,7 @@ class LayerMapping(object): def __init__(self, model, data, mapping, layer=0, source_srs=None, encoding=None, transaction_mode='commit_on_success', - transform=True, unique=None, using=DEFAULT_DB_ALIAS): + transform=True, unique=None, using=None): """ A LayerMapping object is initialized using the given Model (not an instance), a DataSource (or string path to an OGR-supported data file), and a mapping @@ -81,8 +81,8 @@ class LayerMapping(object): self.ds = data self.layer = self.ds[layer] - self.using = using - self.spatial_backend = connections[using].ops + self.using = using if using is not None else router.db_for_write(model) + self.spatial_backend = connections[self.using].ops # Setting the mapping & model attributes. self.mapping = mapping