diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index 4989ad8820..fd251b63e0 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -1,7 +1,8 @@ from decimal import Decimal -from django.contrib.gis.db.models.fields import GeometryField +from django.contrib.gis.db.models.fields import GeometryField, RasterField from django.contrib.gis.db.models.sql import AreaField +from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.measure import ( Area as AreaMeasure, Distance as DistanceMeasure, ) @@ -40,6 +41,8 @@ class GeoFunc(Func): def as_sql(self, compiler, connection): if self.function is None: self.function = connection.ops.spatial_function_name(self.name) + if any(isinstance(field, RasterField) for field in self.get_source_fields()): + raise TypeError("Geometry functions not supported for raster fields.") return super(GeoFunc, self).as_sql(compiler, connection) def resolve_expression(self, *args, **kwargs): @@ -87,6 +90,8 @@ class GeomValue(Value): class GeoFuncWithGeoParam(GeoFunc): def __init__(self, expression, geom, *expressions, **extra): + if not isinstance(geom, Geometry): + raise TypeError("Please provide a geometry object.") if not hasattr(geom, 'srid') or not geom.srid: raise ValueError("Please provide a geometry attribute with a defined SRID.") super(GeoFuncWithGeoParam, self).__init__(expression, GeomValue(geom), *expressions, **extra) diff --git a/tests/gis_tests/rasterapp/test_rasterfield.py b/tests/gis_tests/rasterapp/test_rasterfield.py index f838f56bdc..5934fcc690 100644 --- a/tests/gis_tests/rasterapp/test_rasterfield.py +++ b/tests/gis_tests/rasterapp/test_rasterfield.py @@ -1,5 +1,6 @@ import json +from django.contrib.gis.db.models.functions import Distance from django.contrib.gis.db.models.lookups import ( DistanceLookupBase, gis_lookups, ) @@ -326,3 +327,18 @@ class RasterFieldTest(TransactionTestCase): msg = "Couldn't create spatial object from lookup value '%s'." % obj with self.assertRaisesMessage(ValueError, msg): RasterModel.objects.filter(geom__intersects=obj) + + def test_db_function_errors(self): + """ + Errors are raised when using DB functions with raster content. + """ + point = GEOSGeometry("SRID=3086;POINT (-697024.9213808845 683729.1705516104)") + rast = GDALRaster(json.loads(JSON_RASTER)) + msg = "Please provide a geometry object." + with self.assertRaisesMessage(TypeError, msg): + RasterModel.objects.annotate(distance_from_point=Distance("geom", rast)) + with self.assertRaisesMessage(TypeError, msg): + RasterModel.objects.annotate(distance_from_point=Distance("rastprojected", rast)) + msg = "Geometry functions not supported for raster fields." + with self.assertRaisesMessage(TypeError, msg): + RasterModel.objects.annotate(distance_from_point=Distance("rastprojected", point)).count()