diff --git a/django/contrib/gis/gdal/raster/source.py b/django/contrib/gis/gdal/raster/source.py index 33d8b3069f..05bbe49372 100644 --- a/django/contrib/gis/gdal/raster/source.py +++ b/django/contrib/gis/gdal/raster/source.py @@ -425,17 +425,24 @@ class GDALRaster(GDALRasterBase): return target - def transform(self, srid, driver=None, name=None, resampling='NearestNeighbour', + def transform(self, srs, driver=None, name=None, resampling='NearestNeighbour', max_error=0.0): """ - Return a copy of this raster reprojected into the given SRID. + Return a copy of this raster reprojected into the given spatial + reference system. """ # Convert the resampling algorithm name into an algorithm id algorithm = GDAL_RESAMPLE_ALGORITHMS[resampling] - # Instantiate target spatial reference system - target_srs = SpatialReference(srid) - + if isinstance(srs, SpatialReference): + target_srs = srs + elif isinstance(srs, (int, str)): + target_srs = SpatialReference(srs) + else: + raise TypeError( + 'Transform only accepts SpatialReference, string, and integer ' + 'objects.' + ) # Create warped virtual dataset in the target reference system target = capi.auto_create_warped_vrt( self._ptr, self.srs.wkt.encode(), target_srs.wkt.encode(), @@ -445,7 +452,7 @@ class GDALRaster(GDALRasterBase): # Construct the target warp dictionary from the virtual raster data = { - 'srid': srid, + 'srid': target_srs.srid, 'width': target.width, 'height': target.height, 'origin': [target.origin.x, target.origin.y], diff --git a/docs/ref/contrib/gis/gdal.txt b/docs/ref/contrib/gis/gdal.txt index 857d8d02f5..aa7e2a7eb8 100644 --- a/docs/ref/contrib/gis/gdal.txt +++ b/docs/ref/contrib/gis/gdal.txt @@ -1368,14 +1368,16 @@ blue. [ 19., 21., 23.], [ 31., 33., 35.]], dtype=float32) - .. method:: transform(srid, driver=None, name=None, resampling='NearestNeighbour', max_error=0.0) + .. method:: transform(srs, driver=None, name=None, resampling='NearestNeighbour', max_error=0.0) - Returns a transformed version of this raster with the specified SRID. + Transforms this raster to a different spatial reference system + (``srs``), which may be a :class:`SpatialReference` object, or any + other input accepted by :class:`SpatialReference` (including spatial + reference WKT and PROJ strings, or an integer SRID). - This function transforms the current raster into a new spatial reference - system that can be specified with an ``srid``. It calculates the bounds - and scale of the current raster in the new spatial reference system and - warps the raster using the :attr:`~GDALRaster.warp` function. + It calculates the bounds and scale of the current raster in the new + spatial reference system and warps the raster using the + :attr:`~GDALRaster.warp` function. By default, the driver of the source raster is used and the name of the raster is the original name appended with @@ -1394,10 +1396,15 @@ blue. ... "scale": [100, -100], ... "bands": [{"data": range(36), "nodata_value": 99}] ... }) - >>> target = rst.transform(4326) + >>> target_srs = SpatialReference(4326) + >>> target = rst.transform(target_srs) >>> target.origin [-82.98492744885776, 27.601924753080144] + .. versionchanged:: 3.2 + + Support for :class:`SpatialReference` ``srs`` was added + .. attribute:: info Returns a string with a summary of the raster. This is equivalent to diff --git a/docs/releases/3.2.txt b/docs/releases/3.2.txt index 9f1d785934..88c4cdc998 100644 --- a/docs/releases/3.2.txt +++ b/docs/releases/3.2.txt @@ -68,7 +68,8 @@ Minor features :mod:`django.contrib.gis` ~~~~~~~~~~~~~~~~~~~~~~~~~ -* ... +* The :meth:`.GDALRaster.transform` method now supports + :class:`~django.contrib.gis.gdal.SpatialReference`. :mod:`django.contrib.messages` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/gis_tests/gdal_tests/test_raster.py b/tests/gis_tests/gdal_tests/test_raster.py index 3369d71ea0..98ac9d6340 100644 --- a/tests/gis_tests/gdal_tests/test_raster.py +++ b/tests/gis_tests/gdal_tests/test_raster.py @@ -3,7 +3,7 @@ import shutil import struct import tempfile -from django.contrib.gis.gdal import GDAL_VERSION, GDALRaster +from django.contrib.gis.gdal import GDAL_VERSION, GDALRaster, SpatialReference from django.contrib.gis.gdal.error import GDALException from django.contrib.gis.gdal.raster.band import GDALBand from django.contrib.gis.shortcuts import numpy @@ -471,62 +471,65 @@ class GDALRasterTests(SimpleTestCase): self.assertEqual(result, [23] * 16) def test_raster_transform(self): - # Prepare tempfile and nodata value - rstfile = tempfile.NamedTemporaryFile(suffix='.tif') - ndv = 99 + tests = [ + 3086, + '3086', + SpatialReference(3086), + ] + for srs in tests: + with self.subTest(srs=srs): + # Prepare tempfile and nodata value. + rstfile = tempfile.NamedTemporaryFile(suffix='.tif') + ndv = 99 + # Create in file based raster. + source = GDALRaster({ + 'datatype': 1, + 'driver': 'tif', + 'name': rstfile.name, + 'width': 5, + 'height': 5, + 'nr_of_bands': 1, + 'srid': 4326, + 'origin': (-5, 5), + 'scale': (2, -2), + 'skew': (0, 0), + 'bands': [{ + 'data': range(25), + 'nodata_value': ndv, + }], + }) - # Create in file based raster - source = GDALRaster({ - 'datatype': 1, - 'driver': 'tif', - 'name': rstfile.name, - 'width': 5, - 'height': 5, - 'nr_of_bands': 1, - 'srid': 4326, - 'origin': (-5, 5), - 'scale': (2, -2), - 'skew': (0, 0), - 'bands': [{ - 'data': range(25), - 'nodata_value': ndv, - }], - }) + target = source.transform(srs) - # Transform raster into srid 4326. - target = source.transform(3086) + # Reload data from disk. + target = GDALRaster(target.name) + self.assertEqual(target.srs.srid, 3086) + self.assertEqual(target.width, 7) + self.assertEqual(target.height, 7) + self.assertEqual(target.bands[0].datatype(), source.bands[0].datatype()) + self.assertAlmostEqual(target.origin[0], 9124842.791079799, 3) + self.assertAlmostEqual(target.origin[1], 1589911.6476407414, 3) + self.assertAlmostEqual(target.scale[0], 223824.82664250192, 3) + self.assertAlmostEqual(target.scale[1], -223824.82664250192, 3) + self.assertEqual(target.skew, [0, 0]) - # Reload data from disk - target = GDALRaster(target.name) - - self.assertEqual(target.srs.srid, 3086) - self.assertEqual(target.width, 7) - self.assertEqual(target.height, 7) - self.assertEqual(target.bands[0].datatype(), source.bands[0].datatype()) - self.assertAlmostEqual(target.origin[0], 9124842.791079799, 3) - self.assertAlmostEqual(target.origin[1], 1589911.6476407414, 3) - self.assertAlmostEqual(target.scale[0], 223824.82664250192, 3) - self.assertAlmostEqual(target.scale[1], -223824.82664250192, 3) - self.assertEqual(target.skew, [0, 0]) - - result = target.bands[0].data() - if numpy: - result = result.flatten().tolist() - - # The reprojection of a raster that spans over a large area - # skews the data matrix and might introduce nodata values. - self.assertEqual( - result, - [ - ndv, ndv, ndv, ndv, 4, ndv, ndv, - ndv, ndv, 2, 3, 9, ndv, ndv, - ndv, 1, 2, 8, 13, 19, ndv, - 0, 6, 6, 12, 18, 18, 24, - ndv, 10, 11, 16, 22, 23, ndv, - ndv, ndv, 15, 21, 22, ndv, ndv, - ndv, ndv, 20, ndv, ndv, ndv, ndv, - ] - ) + result = target.bands[0].data() + if numpy: + result = result.flatten().tolist() + # The reprojection of a raster that spans over a large area + # skews the data matrix and might introduce nodata values. + self.assertEqual( + result, + [ + ndv, ndv, ndv, ndv, 4, ndv, ndv, + ndv, ndv, 2, 3, 9, ndv, ndv, + ndv, 1, 2, 8, 13, 19, ndv, + 0, 6, 6, 12, 18, 18, 24, + ndv, 10, 11, 16, 22, 23, ndv, + ndv, ndv, 15, 21, 22, ndv, ndv, + ndv, ndv, 20, ndv, ndv, ndv, ndv, + ], + ) class GDALBandTests(SimpleTestCase):