diff --git a/django/contrib/gis/gdal/raster/source.py b/django/contrib/gis/gdal/raster/source.py index cb7b4cf3ac..03fa1f8f22 100644 --- a/django/contrib/gis/gdal/raster/source.py +++ b/django/contrib/gis/gdal/raster/source.py @@ -250,7 +250,7 @@ class GDALRaster(GDALBase): @geotransform.setter def geotransform(self, values): "Set the geotransform for the data source." - if sum([isinstance(x, (int, float)) for x in values]) != 6: + if len(values) != 6 or not all(isinstance(x, (int, float)) for x in values): raise ValueError('Geotransform must consist of 6 numeric values.') # Create ctypes double array with input and write data values = (c_double * 6)(*values) diff --git a/tests/gis_tests/gdal_tests/test_raster.py b/tests/gis_tests/gdal_tests/test_raster.py index 9e23d78c68..5495178a13 100644 --- a/tests/gis_tests/gdal_tests/test_raster.py +++ b/tests/gis_tests/gdal_tests/test_raster.py @@ -103,6 +103,9 @@ class GDALRasterTests(SimpleTestCase): self.assertEqual(self.rs.skew.y, 0) # Create in-memory rasters and change gtvalues rsmem = GDALRaster(JSON_RASTER) + # geotransform accepts both floats and ints + rsmem.geotransform = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] + self.assertEqual(rsmem.geotransform, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) rsmem.geotransform = range(6) self.assertEqual(rsmem.geotransform, [float(x) for x in range(6)]) self.assertEqual(rsmem.origin, [0, 3]) @@ -117,6 +120,18 @@ class GDALRasterTests(SimpleTestCase): self.assertEqual(rsmem.width, 5) self.assertEqual(rsmem.height, 5) + def test_geotransform_bad_inputs(self): + rsmem = GDALRaster(JSON_RASTER) + error_geotransforms = [ + [1, 2], + [1, 2, 3, 4, 5, 'foo'], + [1, 2, 3, 4, 5, 6, 'foo'], + ] + msg = 'Geotransform must consist of 6 numeric values.' + for geotransform in error_geotransforms: + with self.subTest(i=geotransform), self.assertRaisesMessage(ValueError, msg): + rsmem.geotransform = geotransform + def test_rs_extent(self): self.assertEqual( self.rs.extent,