mirror of
				https://github.com/django/django.git
				synced 2025-10-31 09:41:08 +00:00 
			
		
		
		
	Fixed #26788 -- Fixed QuerySet.update() crash when updating a geometry to another one.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							64264c9a19
						
					
				
				
					commit
					e7afef13f5
				
			| @@ -71,7 +71,29 @@ class BaseSpatialOperations: | ||||
|         stored procedure call to the transformation function of the spatial | ||||
|         backend. | ||||
|         """ | ||||
|         raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method') | ||||
|         def transform_value(value, field): | ||||
|             return ( | ||||
|                 not (value is None or value.srid == field.srid) and | ||||
|                 self.connection.features.supports_transform | ||||
|             ) | ||||
|  | ||||
|         if hasattr(value, 'as_sql'): | ||||
|             return ( | ||||
|                 '%s(%%s, %s)' % (self.spatial_function_name('Transform'), f.srid) | ||||
|                 if transform_value(value.output_field, f) | ||||
|                 else '%s' | ||||
|             ) | ||||
|         if transform_value(value, f): | ||||
|             # Add Transform() to the SQL placeholder. | ||||
|             return '%s(%s(%%s,%s), %s)' % ( | ||||
|                 self.spatial_function_name('Transform'), | ||||
|                 self.from_text, value.srid, f.srid, | ||||
|             ) | ||||
|         elif self.connection.features.has_spatialrefsys_table: | ||||
|             return '%s(%%s,%s)' % (self.from_text, f.srid) | ||||
|         else: | ||||
|             # For backwards compatibility on MySQL (#27464). | ||||
|             return '%s(%%s)' % self.from_text | ||||
|  | ||||
|     def check_expression_support(self, expression): | ||||
|         if isinstance(expression, self.disallowed_aggregates): | ||||
|   | ||||
| @@ -86,18 +86,6 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations): | ||||
|     def geo_db_type(self, f): | ||||
|         return f.geom_type | ||||
|  | ||||
|     def get_geom_placeholder(self, f, value, compiler): | ||||
|         """ | ||||
|         The placeholder here has to include MySQL's WKT constructor.  Because | ||||
|         MySQL does not support spatial transformations, there is no need to | ||||
|         modify the placeholder based on the contents of the given value. | ||||
|         """ | ||||
|         if hasattr(value, 'as_sql'): | ||||
|             placeholder, _ = compiler.compile(value) | ||||
|         else: | ||||
|             placeholder = '%s(%%s)' % self.from_text | ||||
|         return placeholder | ||||
|  | ||||
|     def get_db_converters(self, expression): | ||||
|         converters = super().get_db_converters(expression) | ||||
|         if isinstance(expression.output_field, GeometryField) and self.uses_invalid_empty_geometry_collection: | ||||
|   | ||||
| @@ -187,33 +187,9 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): | ||||
|         return [dist_param] | ||||
|  | ||||
|     def get_geom_placeholder(self, f, value, compiler): | ||||
|         """ | ||||
|         Provide a proper substitution value for Geometries that are not in the | ||||
|         SRID of the field.  Specifically, this routine will substitute in the | ||||
|         SDO_CS.TRANSFORM() function call. | ||||
|         """ | ||||
|         tranform_func = self.spatial_function_name('Transform') | ||||
|  | ||||
|         if value is None: | ||||
|             return 'NULL' | ||||
|  | ||||
|         def transform_value(val, srid): | ||||
|             return val.srid != srid | ||||
|  | ||||
|         if hasattr(value, 'as_sql'): | ||||
|             if transform_value(value, f.srid): | ||||
|                 placeholder = '%s(%%s, %s)' % (tranform_func, f.srid) | ||||
|             else: | ||||
|                 placeholder = '%s' | ||||
|             # No geometry value used for F expression, substitute in | ||||
|             # the column name instead. | ||||
|             sql, _ = compiler.compile(value) | ||||
|             return placeholder % sql | ||||
|         else: | ||||
|             if transform_value(value, f.srid): | ||||
|                 return '%s(SDO_GEOMETRY(%%s, %s), %s)' % (tranform_func, value.srid, f.srid) | ||||
|             else: | ||||
|                 return 'SDO_GEOMETRY(%%s, %s)' % f.srid | ||||
|         return super().get_geom_placeholder(f, value, compiler) | ||||
|  | ||||
|     def spatial_aggregate_name(self, agg_name): | ||||
|         """ | ||||
|   | ||||
| @@ -292,6 +292,12 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): | ||||
|         substitute in the ST_Transform() function call. | ||||
|         """ | ||||
|         tranform_func = self.spatial_function_name('Transform') | ||||
|         if hasattr(value, 'as_sql'): | ||||
|             if value.field.srid == f.srid: | ||||
|                 placeholder = '%s' | ||||
|             else: | ||||
|                 placeholder = '%s(%%s, %s)' % (tranform_func, f.srid) | ||||
|             return placeholder | ||||
|  | ||||
|         # Get the srid for this object | ||||
|         if value is None: | ||||
| @@ -310,13 +316,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): | ||||
|         else: | ||||
|             placeholder = '%s(%%s, %s)' % (tranform_func, f.srid) | ||||
|  | ||||
|         if hasattr(value, 'as_sql'): | ||||
|             # If this is an F expression, then we don't really want | ||||
|             # a placeholder and instead substitute in the column | ||||
|             # of the expression. | ||||
|             sql, _ = compiler.compile(value) | ||||
|             placeholder = placeholder % sql | ||||
|  | ||||
|         return placeholder | ||||
|  | ||||
|     def _get_postgis_func(self, func): | ||||
|   | ||||
| @@ -152,32 +152,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): | ||||
|             dist_param = value | ||||
|         return [dist_param] | ||||
|  | ||||
|     def get_geom_placeholder(self, f, value, compiler): | ||||
|         """ | ||||
|         Provide a proper substitution value for Geometries that are not in the | ||||
|         SRID of the field.  Specifically, this routine will substitute in the | ||||
|         Transform() and GeomFromText() function call(s). | ||||
|         """ | ||||
|         tranform_func = self.spatial_function_name('Transform') | ||||
|  | ||||
|         def transform_value(value, srid): | ||||
|             return not (value is None or value.srid == srid) | ||||
|         if hasattr(value, 'as_sql'): | ||||
|             if transform_value(value, f.srid): | ||||
|                 placeholder = '%s(%%s, %s)' % (tranform_func, f.srid) | ||||
|             else: | ||||
|                 placeholder = '%s' | ||||
|             # No geometry value used for F expression, substitute in | ||||
|             # the column name instead. | ||||
|             sql, _ = compiler.compile(value) | ||||
|             return placeholder % sql | ||||
|         else: | ||||
|             if transform_value(value, f.srid): | ||||
|                 # Adding Transform() to the SQL placeholder. | ||||
|                 return '%s(%s(%%s,%s), %s)' % (tranform_func, self.from_text, value.srid, f.srid) | ||||
|             else: | ||||
|                 return '%s(%%s,%s)' % (self.from_text, f.srid) | ||||
|  | ||||
|     def _get_spatialite_func(self, func): | ||||
|         """ | ||||
|         Helper routine for calling SpatiaLite functions and returning | ||||
|   | ||||
| @@ -68,6 +68,8 @@ class GISLookup(Lookup): | ||||
|             if not hasattr(geo_fld, 'srid'): | ||||
|                 raise ValueError('No geographic field found in expression.') | ||||
|             self.rhs.srid = geo_fld.srid | ||||
|             sql, _ = compiler.compile(geom) | ||||
|             return connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) % sql, [] | ||||
|         elif isinstance(self.rhs, Expression): | ||||
|             raise ValueError('Complex expressions not supported for spatial fields.') | ||||
|         elif isinstance(self.rhs, (list, tuple)): | ||||
|   | ||||
| @@ -233,7 +233,7 @@ class DatabaseOperations(BaseDatabaseOperations): | ||||
|         return value | ||||
|  | ||||
|     def binary_placeholder_sql(self, value): | ||||
|         return '_binary %s' if value is not None else '%s' | ||||
|         return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s' | ||||
|  | ||||
|     def subtract_temporals(self, internal_type, lhs, rhs): | ||||
|         lhs_sql, lhs_params = lhs | ||||
|   | ||||
| @@ -1167,7 +1167,7 @@ class SQLUpdateCompiler(SQLCompiler): | ||||
|             name = field.column | ||||
|             if hasattr(val, 'as_sql'): | ||||
|                 sql, params = self.compile(val) | ||||
|                 values.append('%s = %s' % (qn(name), sql)) | ||||
|                 values.append('%s = %s' % (qn(name), placeholder % sql)) | ||||
|                 update_params.extend(params) | ||||
|             elif val is not None: | ||||
|                 values.append('%s = %s' % (qn(name), placeholder)) | ||||
|   | ||||
| @@ -101,3 +101,9 @@ class NonConcreteField(models.IntegerField): | ||||
| class NonConcreteModel(NamedModel): | ||||
|     non_concrete = NonConcreteField() | ||||
|     point = models.PointField(geography=True) | ||||
|  | ||||
|  | ||||
| class ManyPointModel(NamedModel): | ||||
|     point1 = models.PointField() | ||||
|     point2 = models.PointField() | ||||
|     point3 = models.PointField(srid=3857) | ||||
|   | ||||
| @@ -1,11 +1,12 @@ | ||||
| from unittest import skipUnless | ||||
|  | ||||
| from django.contrib.gis.db.models import GeometryField, Value, functions | ||||
| from django.contrib.gis.db.models import F, GeometryField, Value, functions | ||||
| from django.contrib.gis.geos import Point, Polygon | ||||
| from django.db import connection | ||||
| from django.test import TestCase, skipUnlessDBFeature | ||||
|  | ||||
| from ..utils import postgis | ||||
| from .models import City | ||||
| from .models import City, ManyPointModel | ||||
|  | ||||
|  | ||||
| @skipUnlessDBFeature('gis_enabled') | ||||
| @@ -29,3 +30,28 @@ class GeoExpressionsTests(TestCase): | ||||
|         p = Polygon(((1, 1), (1, 2), (2, 2), (2, 1), (1, 1))) | ||||
|         area = City.objects.annotate(a=functions.Area(Value(p, GeometryField(srid=4326, geography=True)))).first().a | ||||
|         self.assertAlmostEqual(area.sq_km, 12305.1, 0) | ||||
|  | ||||
|     def test_update_from_other_field(self): | ||||
|         p1 = Point(1, 1, srid=4326) | ||||
|         p2 = Point(2, 2, srid=4326) | ||||
|         obj = ManyPointModel.objects.create( | ||||
|             point1=p1, | ||||
|             point2=p2, | ||||
|             point3=p2.transform(3857, clone=True), | ||||
|         ) | ||||
|         # Updating a point to a point of the same SRID. | ||||
|         ManyPointModel.objects.filter(pk=obj.pk).update(point2=F('point1')) | ||||
|         obj.refresh_from_db() | ||||
|         self.assertEqual(obj.point2, p1) | ||||
|         # Updating a point to a point with a different SRID. | ||||
|         if connection.features.supports_transform: | ||||
|             ManyPointModel.objects.filter(pk=obj.pk).update(point3=F('point1')) | ||||
|             obj.refresh_from_db() | ||||
|             self.assertTrue(obj.point3.equals_exact(p1.transform(3857, clone=True), 0.1)) | ||||
|  | ||||
|     @skipUnlessDBFeature('has_Translate_function') | ||||
|     def test_update_with_expression(self): | ||||
|         city = City.objects.create(point=Point(1, 1, srid=4326)) | ||||
|         City.objects.filter(pk=city.pk).update(point=functions.Translate('point', 1, 1)) | ||||
|         city.refresh_from_db() | ||||
|         self.assertEqual(city.point, Point(2, 2, srid=4326)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user