From 57c700d550343a00494d288e4582ccaaa6b0964e Mon Sep 17 00:00:00 2001 From: Justin Bronn Date: Thu, 1 May 2008 18:17:50 +0000 Subject: [PATCH] gis: Fixed #7126 (with tests); moved `GeoQuery` and `GeoWhereNode` into `sql` submodule; the `GeoQuerySet.transform` may now be used on geometry fields related via foreign key. git-svn-id: http://code.djangoproject.com/svn/django/branches/gis@7512 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/contrib/gis/db/backend/__init__.py | 38 +-- django/contrib/gis/db/models/query.py | 185 +++---------- django/contrib/gis/db/models/sql/__init__.py | 2 + django/contrib/gis/db/models/sql/query.py | 244 ++++++++++++++++++ django/contrib/gis/db/models/sql/where.py | 24 ++ django/contrib/gis/tests/__init__.py | 2 +- .../contrib/gis/tests/relatedapp/__init__.py | 0 django/contrib/gis/tests/relatedapp/models.py | 12 + django/contrib/gis/tests/relatedapp/tests.py | 74 ++++++ .../gis/tests/relatedapp/tests_mysql.py | 1 + 10 files changed, 410 insertions(+), 172 deletions(-) create mode 100644 django/contrib/gis/db/models/sql/__init__.py create mode 100644 django/contrib/gis/db/models/sql/query.py create mode 100644 django/contrib/gis/db/models/sql/where.py create mode 100644 django/contrib/gis/tests/relatedapp/__init__.py create mode 100644 django/contrib/gis/tests/relatedapp/models.py create mode 100644 django/contrib/gis/tests/relatedapp/tests.py create mode 100644 django/contrib/gis/tests/relatedapp/tests_mysql.py diff --git a/django/contrib/gis/db/backend/__init__.py b/django/contrib/gis/db/backend/__init__.py index e687e85ead..cb027b9eec 100644 --- a/django/contrib/gis/db/backend/__init__.py +++ b/django/contrib/gis/db/backend/__init__.py @@ -2,16 +2,19 @@ This module provides the backend for spatial SQL construction with Django. Specifically, this module will import the correct routines and modules - needed for GeoDjango. + needed for GeoDjango to interface with the spatial database. + Some of the more important classes and routines from the spatial backend + include: + (1) `GeoBackEndField`, a base class needed for GeometryField. - (2) `GeoWhereNode`, a subclass of `WhereNode` used to contruct spatial SQL. - (3) `SpatialBackend`, a container object for information specific to the + (2) `get_geo_where_clause`, a routine used by `GeoWhereNode`. + (3) `GIS_TERMS`, a listing of all valid GeoDjango lookup types. + (4) `SpatialBackend`, a container object for information specific to the spatial backend. """ from django.conf import settings from django.db.models.sql.query import QUERY_TERMS -from django.db.models.sql.where import WhereNode from django.contrib.gis.db.backend.util import gqn # These routines (needed by GeoManager), default to False. @@ -61,28 +64,6 @@ elif settings.DATABASE_ENGINE == 'mysql': else: raise NotImplementedError('No Geographic Backend exists for %s' % settings.DATABASE_ENGINE) -class GeoWhereNode(WhereNode): - """ - The GeoWhereNode calls the `get_geo_where_clause` from the appropriate - spatial backend in order to construct correct spatial SQL. - """ - def make_atom(self, child, qn): - table_alias, name, field, lookup_type, value = child - if hasattr(field, '_geom'): - if lookup_type in GIS_TERMS: - # Getting the geographic where clause; substitution parameters - # will be populated in the GeoFieldSQL object returned by the - # GeometryField. - gwc = get_geo_where_clause(lookup_type, table_alias, field, value) - where, params = field.get_db_prep_lookup(lookup_type, value) - return gwc % tuple(where), params - else: - raise TypeError('Invalid lookup type: %r' % lookup_type) - else: - # If not a GeometryField, call the `make_atom` from the - # base class. - return super(GeoWhereNode, self).make_atom(child, qn) - class SpatialBackend(object): "A container for properties of the SpatialBackend." # Stored procedure names used by the `GeoManager`. @@ -106,6 +87,11 @@ class SpatialBackend(object): # Lookup types where additional WHERE parameters are excluded. limited_where = LIMITED_WHERE + # Shortcut booleans. + mysql = SPATIAL_BACKEND == 'mysql' + oracle = SPATIAL_BACKEND == 'oracle' + postgis = SPATIAL_BACKEND == 'postgis' + # Class for the backend field. Field = GeoBackendField diff --git a/django/contrib/gis/db/models/query.py b/django/contrib/gis/db/models/query.py index 3c2a533ba4..65b011c7a8 100644 --- a/django/contrib/gis/db/models/query.py +++ b/django/contrib/gis/db/models/query.py @@ -1,22 +1,16 @@ +from itertools import izip from django.core.exceptions import ImproperlyConfigured from django.db import connection from django.db.models.query import sql, QuerySet, Q -from django.db.models.fields import FieldDoesNotExist -from django.contrib.gis.db.backend import gqn, GeoWhereNode, SpatialBackend, QUERY_TERMS + +from django.contrib.gis.db.backend import SpatialBackend from django.contrib.gis.db.models.fields import GeometryField, PointField +from django.contrib.gis.db.models.sql import GeoQuery, GeoWhereNode from django.contrib.gis.geos import GEOSGeometry, Point - -# Aliases. qn = connection.ops.quote_name -oracle = SpatialBackend.name == 'oracle' -postgis = SpatialBackend.name == 'postgis' - -# All valid lookup terms. -ALL_TERMS = QUERY_TERMS.copy() -ALL_TERMS.update(dict((term, None) for term in SpatialBackend.gis_terms)) # For backwards-compatibility; Q object should work just fine -# using queryset-refactor. +# after queryset-refactor. class GeoQ(Q): pass class GeomSQL(object): @@ -27,128 +21,6 @@ class GeomSQL(object): def as_sql(self, *args, **kwargs): return self.sql -# Getting the `Query` base class from the backend (needed specifically -# for Oracle backends). -Query = QuerySet().query.__class__ - -class GeoQuery(Query): - "The Geographic Query, needed to construct spatial SQL." - - # Overridding the valid query terms. - query_terms = ALL_TERMS - - #### Methods overridden from the base Query class #### - def __init__(self, model, conn): - super(GeoQuery, self).__init__(model, conn, where=GeoWhereNode) - # The following attributes are customized for the GeoQuerySet. - # The GeoWhereNode and SpatialBackend classes contain backend-specific - # routines and functions. - self.custom_select = {} - self.ewkt = None - - def clone(self, *args, **kwargs): - obj = super(GeoQuery, self).clone(*args, **kwargs) - # Customized selection dictionary and EWKT flag have to be added to obj. - obj.custom_select = self.custom_select.copy() - obj.ewkt = self.ewkt - return obj - - def get_default_columns(self, with_aliases=False, col_aliases=None): - """ - Computes the default columns for selecting every field in the base - model. - - Returns a list of strings, quoted appropriately for use in SQL - directly, as well as a set of aliases used in the select statement. - - This routine is overridden from Query to handle customized selection of - geometry columns. - """ - result = [] - table_alias = self.tables[0] - root_pk = self.model._meta.pk.column - seen = {None: table_alias} - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - aliases = set() - for field, model in self.model._meta.get_fields_with_model(): - try: - alias = seen[model] - except KeyError: - alias = self.join((table_alias, model._meta.db_table, - root_pk, model._meta.pk.column)) - seen[model] = alias - - # This part of the function is customized for GeoQuerySet. We - # see if there was any custom selection specified in the - # dictionary, and set up the selection format appropriately. - sel_fmt = self.get_select_format(field) - if field.column in self.custom_select: - field_sel = sel_fmt % self.custom_select[field.column] - else: - field_sel = sel_fmt % self._field_column(field, alias) - - if with_aliases and field.column in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s AS %s' % (field_sel, c_alias)) - col_aliases.add(c_alias) - aliases.add(c_alias) - else: - r = field_sel - result.append(r) - aliases.add(r) - if with_aliases: - col_aliases.add(field.column) - return result, aliases - - #### Routines unique to GeoQuery #### - def get_select_format(self, fld): - """ - Returns the selection format string, depending on the requirements - of the spatial backend. For example, Oracle and MySQL require custom - selection formats in order to retrieve geometries in OGC WKT. For all - other fields a simple '%s' format string is returned. - """ - if SpatialBackend.select and hasattr(fld, '_geom'): - # This allows operations to be done on fields in the SELECT, - # overriding their values -- used by the Oracle and MySQL - # spatial backends to get database values as WKT, and by the - # `transform` method. - sel_fmt = SpatialBackend.select - - # Because WKT doesn't contain spatial reference information, - # the SRID is prefixed to the returned WKT to ensure that the - # transformed geometries have an SRID different than that of the - # field -- this is only used by `transform` for Oracle backends. - if self.ewkt and oracle: - sel_fmt = "'SRID=%d;'||%s" % (self.ewkt, sel_fmt) - else: - sel_fmt = '%s' - return sel_fmt - - def _field_column(self, field, table_alias=None): - """ - Helper function that returns the database column for the given field. - The table and column are returned (quoted) in the proper format, e.g., - `"geoapp_city"."point"`. - """ - if table_alias is None: table_alias = self.model._meta.db_table - return "%s.%s" % (self.quote_name_unless_alias(table_alias), qn(field.column)) - - def _geo_field(self, field_name=None): - """ - Returns the first Geometry field encountered; or specified via the - `field_name` keyword. - """ - for field in self.model._meta.fields: - if isinstance(field, GeometryField): - fname = field.name - if field_name: - if field_name == field.name: return field - else: - return field - return False - class GeoQuerySet(QuerySet): "The Geographic QuerySet." @@ -187,7 +59,7 @@ class GeoQuerySet(QuerySet): # transformation SQL -- we pass in a 'dummy' `contains` # `distance_lte` lookup type. where, params = geo_field.get_db_prep_lookup('distance_lte', (geom, 0)) - if oracle: + if SpatialBackend.oracle: # The `tolerance` keyword may be used for Oracle; the tolerance is # in meters -- a default of 5 centimeters is used. tolerance = kwargs.get('tolerance', 0.05) @@ -229,6 +101,7 @@ class GeoQuerySet(QuerySet): extent_sql = '%s(%s)' % (EXTENT, geo_col) self.query.select = [GeomSQL(extent_sql)] + self.query.select_fields = [None] try: esql, params = self.query.as_sql() except sql.datastructures.EmptyResultSet: @@ -267,9 +140,9 @@ class GeoQuerySet(QuerySet): raise TypeError('GML output only available on GeometryFields.') geo_col = self.query._field_column(geo_field) - if oracle: + if SpatialBackend.oracle: gml_select = {'gml':'%s(%s)' % (ASGML, geo_col)} - elif postgis: + elif SpatialBackend.postgis: # PostGIS AsGML() aggregate function parameter order depends on the # version -- uggh. major, minor1, minor2 = SpatialBackend.version @@ -295,11 +168,12 @@ class GeoQuerySet(QuerySet): geo_field = self.query._geo_field(field_name) if not geo_field: raise TypeError('KML output only available on GeometryFields.') + geo_col = self.query._field_column(geo_field) # Adding the AsKML function call to SELECT part of the SQL. return self.extra(select={'kml':'%s(%s,%s)' % (ASKML, geo_col, precision)}) - + def transform(self, field_name=None, srid=4326): """ Transforms the given geometry field to the given SRID. If no SRID is @@ -320,22 +194,24 @@ class GeoQuerySet(QuerySet): geo_field = self.query._geo_field(field_name) if not geo_field: raise TypeError('%s() only available for GeometryFields' % TRANSFORM) - + + # Getting the selection SQL for the given geograph + field_col = self._geocol_select(geo_field, field_name) + # Why cascading substitutions? Because spatial backends like # Oracle and MySQL already require a function call to convert to text, thus # when there's also a transformation we need to cascade the substitutions. # For example, 'SDO_UTIL.TO_WKTGEOMETRY(SDO_CS.TRANSFORM( ... )' - geo_col = self.query.custom_select.get(geo_field.column, self.query._field_column(geo_field)) + geo_col = self.query.custom_select.get(geo_field, field_col) # Setting the key for the field's column with the custom SELECT SQL to # override the geometry column returned from the database. - if oracle: + if SpatialBackend.oracle: custom_sel = '%s(%s, %s)' % (TRANSFORM, geo_col, srid) self.query.ewkt = srid else: - custom_sel = '(%s(%s, %s)) AS %s' % \ - (TRANSFORM, geo_col, srid, qn(geo_field.column)) - self.query.custom_select[geo_field.column] = custom_sel + custom_sel = '%s(%s, %s)' % (TRANSFORM, geo_col, srid) + self.query.custom_select[geo_field] = custom_sel return self._clone() def union(self, field_name=None, tolerance=0.0005): @@ -357,7 +233,7 @@ class GeoQuerySet(QuerySet): # Replacing the select with a call to the ST_Union stored procedure # on the geographic field column. - if oracle: + if SpatialBackend.oracle: union_sql = '%s' % SpatialBackend.select union_sql = union_sql % ('%s(SDOAGGRTYPE(%s,%s))' % (UNION, geo_col, tolerance)) else: @@ -365,6 +241,7 @@ class GeoQuerySet(QuerySet): # Only want the union SQL to be selected. self.query.select = [GeomSQL(union_sql)] + self.query.select_fields = [GeometryField] try: usql, params = self.query.as_sql() except sql.datastructures.EmptyResultSet: @@ -373,7 +250,7 @@ class GeoQuerySet(QuerySet): # Getting a cursor, executing the query. cursor = connection.cursor() cursor.execute(usql, params) - if oracle: + if SpatialBackend.oracle: # On Oracle have to read out WKT from CLOB first. clob = cursor.fetchone()[0] if clob: u = clob.read() @@ -383,3 +260,21 @@ class GeoQuerySet(QuerySet): if u: return GEOSGeometry(u) else: return None + + # Private API utilities, subject to change. + def _geocol_select(self, geo_field, field_name): + """ + Helper routine for constructing the SQL to select the geographic + column. Takes into account if the geographic field is in a + ForeignKey relation to the current model. + """ + # Is this operation going to be on a related geographic field? + if not geo_field in self.model._meta.fields: + # If so, it'll have to be added to the select related information + # (e.g., if 'location__point' was given as the field name). + self.query.add_select_related([field_name]) + self.query.pre_sql_setup() + rel_table, rel_col = self.query.related_select_cols[self.query.related_select_fields.index(geo_field)] + return self.query._field_column(geo_field, rel_table) + else: + return self.query._field_column(geo_field) diff --git a/django/contrib/gis/db/models/sql/__init__.py b/django/contrib/gis/db/models/sql/__init__.py new file mode 100644 index 0000000000..8e87f6b2a2 --- /dev/null +++ b/django/contrib/gis/db/models/sql/__init__.py @@ -0,0 +1,2 @@ +from django.contrib.gis.db.models.sql.query import GeoQuery +from django.contrib.gis.db.models.sql.where import GeoWhereNode diff --git a/django/contrib/gis/db/models/sql/query.py b/django/contrib/gis/db/models/sql/query.py new file mode 100644 index 0000000000..9cd691bea9 --- /dev/null +++ b/django/contrib/gis/db/models/sql/query.py @@ -0,0 +1,244 @@ +from itertools import izip +from django.db.models.query import sql +from django.db.models.fields import FieldDoesNotExist +from django.db.models.fields.related import ForeignKey + +from django.contrib.gis.db.backend import SpatialBackend +from django.contrib.gis.db.models.fields import GeometryField +from django.contrib.gis.db.models.sql.where import GeoWhereNode + +# Valid GIS query types. +ALL_TERMS = sql.constants.QUERY_TERMS.copy() +ALL_TERMS.update(dict([(term, None) for term in SpatialBackend.gis_terms])) + +class GeoQuery(sql.Query): + """ + A single spatial SQL query. + """ + # Overridding the valid query terms. + query_terms = ALL_TERMS + + #### Methods overridden from the base Query class #### + def __init__(self, model, conn): + super(GeoQuery, self).__init__(model, conn, where=GeoWhereNode) + # The following attributes are customized for the GeoQuerySet. + # The GeoWhereNode and SpatialBackend classes contain backend-specific + # routines and functions. + self.custom_select = {} + self.ewkt = None + + def clone(self, *args, **kwargs): + obj = super(GeoQuery, self).clone(*args, **kwargs) + # Customized selection dictionary and EWKT flag have to be added to obj. + obj.custom_select = self.custom_select.copy() + obj.ewkt = self.ewkt + return obj + + def get_columns(self, with_aliases=False): + """ + Return the list of columns to use in the select statement. If no + columns have been specified, returns all columns relating to fields in + the model. + + If 'with_aliases' is true, any column names that are duplicated + (without the table names) are given unique aliases. This is needed in + some cases to avoid ambiguitity with nested queries. + + This routine is overridden from Query to handle customized selection of + geometry columns. + """ + qn = self.quote_name_unless_alias + qn2 = self.connection.ops.quote_name + result = ['(%s) AS %s' % (col, qn2(alias)) for alias, col in self.extra_select.iteritems()] + aliases = set(self.extra_select.keys()) + if with_aliases: + col_aliases = aliases.copy() + else: + col_aliases = set() + if self.select: + # This loop customized for GeoQuery. + for col, field in izip(self.select, self.select_fields): + if isinstance(col, (list, tuple)): + r = self.get_field_select(field, col[0]) + if with_aliases and col[1] in col_aliases: + c_alias = 'Col%d' % len(col_aliases) + result.append('%s AS %s' % (r, c_alias)) + aliases.add(c_alias) + col_aliases.add(c_alias) + else: + result.append(r) + aliases.add(r) + col_aliases.add(col[1]) + else: + result.append(col.as_sql(quote_func=qn)) + if hasattr(col, 'alias'): + aliases.add(col.alias) + col_aliases.add(col.alias) + elif self.default_cols: + cols, new_aliases = self.get_default_columns(with_aliases, + col_aliases) + result.extend(cols) + aliases.update(new_aliases) + # This loop customized for GeoQuery. + for (table, col), field in izip(self.related_select_cols, self.related_select_fields): + r = self.get_field_select(field, table) + if with_aliases and col in col_aliases: + c_alias = 'Col%d' % len(col_aliases) + result.append('%s AS %s' % (r, c_alias)) + aliases.add(c_alias) + col_aliases.add(c_alias) + else: + result.append(r) + aliases.add(r) + col_aliases.add(col) + + self._select_aliases = aliases + return result + + def get_default_columns(self, with_aliases=False, col_aliases=None): + """ + Computes the default columns for selecting every field in the base + model. + + Returns a list of strings, quoted appropriately for use in SQL + directly, as well as a set of aliases used in the select statement. + + This routine is overridden from Query to handle customized selection of + geometry columns. + """ + result = [] + table_alias = self.tables[0] + root_pk = self.model._meta.pk.column + seen = {None: table_alias} + qn = self.quote_name_unless_alias + qn2 = self.connection.ops.quote_name + aliases = set() + for field, model in self.model._meta.get_fields_with_model(): + try: + alias = seen[model] + except KeyError: + alias = self.join((table_alias, model._meta.db_table, + root_pk, model._meta.pk.column)) + seen[model] = alias + + # This part of the function is customized for GeoQuery. We + # see if there was any custom selection specified in the + # dictionary, and set up the selection format appropriately. + field_sel = self.get_field_select(field, alias) + + if with_aliases and field.column in col_aliases: + c_alias = 'Col%d' % len(col_aliases) + result.append('%s AS %s' % (field_sel, c_alias)) + col_aliases.add(c_alias) + aliases.add(c_alias) + else: + r = field_sel + result.append(r) + aliases.add(r) + if with_aliases: + col_aliases.add(field.column) + return result, aliases + + #### Routines unique to GeoQuery #### + def get_field_select(self, fld, alias=None): + """ + Returns the SELECT SQL string for the given field. Figures out + if any custom selection SQL is needed for the column The `alias` + keyword may be used to manually specify the database table where + the column exists, if not in the model associated with this + `GeoQuery`. + """ + sel_fmt = self.get_select_format(fld) + if fld in self.custom_select: + field_sel = sel_fmt % self.custom_select[fld] + else: + field_sel = sel_fmt % self._field_column(fld, alias) + return field_sel + + def get_select_format(self, fld): + """ + Returns the selection format string, depending on the requirements + of the spatial backend. For example, Oracle and MySQL require custom + selection formats in order to retrieve geometries in OGC WKT. For all + other fields a simple '%s' format string is returned. + """ + if SpatialBackend.select and hasattr(fld, '_geom'): + # This allows operations to be done on fields in the SELECT, + # overriding their values -- used by the Oracle and MySQL + # spatial backends to get database values as WKT, and by the + # `transform` method. + sel_fmt = SpatialBackend.select + + # Because WKT doesn't contain spatial reference information, + # the SRID is prefixed to the returned WKT to ensure that the + # transformed geometries have an SRID different than that of the + # field -- this is only used by `transform` for Oracle backends. + if self.ewkt and SpatialBackend.oracle: + sel_fmt = "'SRID=%d;'||%s" % (self.ewkt, sel_fmt) + else: + sel_fmt = '%s' + return sel_fmt + + # Private API utilities, subject to change. + def _check_geo_field(self, model, name_param): + """ + Recursive utility routine for checking the given name parameter + on the given model. Initially, the name parameter is a string, + of the field on the given model e.g., 'point', 'the_geom'. + Related model field strings like 'address__point', may also be + used. + + If a GeometryField exists according to the given name + parameter it will be returned, otherwise returns False. + """ + if isinstance(name_param, basestring): + # This takes into account the situation where the name is a + # lookup to a related geographic field, e.g., 'address__point'. + name_param = name_param.split(sql.constants.LOOKUP_SEP) + name_param.reverse() # Reversing so list operates like a queue of related lookups. + elif not isinstance(name_param, list): + raise TypeError + try: + # Getting the name of the field for the model (by popping the first + # name from the `name_param` list created above). + fld, mod, direct, m2m = model._meta.get_field_by_name(name_param.pop()) + except (FieldDoesNotExist, IndexError): + return False + # TODO: ManyToManyField? + if isinstance(fld, GeometryField): + return fld # A-OK. + elif isinstance(fld, ForeignKey): + # ForeignKey encountered, return the output of this utility called + # on the _related_ model with the remaining name parameters. + return self._check_geo_field(fld.rel.to, name_param) # Recurse to check ForeignKey relation. + else: + return False + + def _field_column(self, field, table_alias=None): + """ + Helper function that returns the database column for the given field. + The table and column are returned (quoted) in the proper format, e.g., + `"geoapp_city"."point"`. If `table_alias` is not specified, the + database table associated with the model of this `GeoQuery` will be + used. + """ + if table_alias is None: table_alias = self.model._meta.db_table + return "%s.%s" % (self.quote_name_unless_alias(table_alias), + self.connection.ops.quote_name(field.column)) + + def _geo_field(self, field_name=None): + """ + Returns the first Geometry field encountered; or specified via the + `field_name` keyword. The `field_name` may be a string specifying + the geometry field on this GeoQuery's model, or a lookup string + to a geometry field via a ForeignKey relation. + """ + if field_name is None: + # Incrementing until the first geographic field is found. + for fld in self.model._meta.fields: + if isinstance(fld, GeometryField): return fld + return False + else: + # Otherwise, check by the given field name -- which may be + # a lookup to a _related_ geographic field. + return self._check_geo_field(self.model, field_name) diff --git a/django/contrib/gis/db/models/sql/where.py b/django/contrib/gis/db/models/sql/where.py new file mode 100644 index 0000000000..3c37232cfd --- /dev/null +++ b/django/contrib/gis/db/models/sql/where.py @@ -0,0 +1,24 @@ +from django.db.models.sql.where import WhereNode +from django.contrib.gis.db.backend import get_geo_where_clause, GIS_TERMS + +class GeoWhereNode(WhereNode): + """ + The GeoWhereNode calls the `get_geo_where_clause` from the appropriate + spatial backend in order to construct correct spatial SQL. + """ + def make_atom(self, child, qn): + table_alias, name, field, lookup_type, value = child + if hasattr(field, '_geom'): + if lookup_type in GIS_TERMS: + # Getting the geographic where clause; substitution parameters + # will be populated in the GeoFieldSQL object returned by the + # GeometryField. + gwc = get_geo_where_clause(lookup_type, table_alias, field, value) + where, params = field.get_db_prep_lookup(lookup_type, value) + return gwc % tuple(where), params + else: + raise TypeError('Invalid lookup type: %r' % lookup_type) + else: + # If not a GeometryField, call the `make_atom` from the + # base class. + return super(GeoWhereNode, self).make_atom(child, qn) diff --git a/django/contrib/gis/tests/__init__.py b/django/contrib/gis/tests/__init__.py index 273d5625a0..b3afe9baa7 100644 --- a/django/contrib/gis/tests/__init__.py +++ b/django/contrib/gis/tests/__init__.py @@ -11,7 +11,7 @@ from django.conf import settings if not settings._target: settings.configure() # Tests that require use of a spatial database (e.g., creation of models) -test_models = ['geoapp'] +test_models = ['geoapp', 'relatedapp'] # Tests that do not require setting up and tearing down a spatial database. test_suite_names = [ diff --git a/django/contrib/gis/tests/relatedapp/__init__.py b/django/contrib/gis/tests/relatedapp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/django/contrib/gis/tests/relatedapp/models.py b/django/contrib/gis/tests/relatedapp/models.py new file mode 100644 index 0000000000..8484054e0d --- /dev/null +++ b/django/contrib/gis/tests/relatedapp/models.py @@ -0,0 +1,12 @@ +from django.contrib.gis.db import models + +class Location(models.Model): + name = models.CharField(max_length=50) + point = models.PointField() + objects = models.GeoManager() + +class City(models.Model): + name = models.CharField(max_length=50) + state = models.USStateField() + location = models.ForeignKey(Location) + objects = models.GeoManager() diff --git a/django/contrib/gis/tests/relatedapp/tests.py b/django/contrib/gis/tests/relatedapp/tests.py new file mode 100644 index 0000000000..d92249e9d1 --- /dev/null +++ b/django/contrib/gis/tests/relatedapp/tests.py @@ -0,0 +1,74 @@ +import os, unittest +from django.contrib.gis.geos import * +from django.contrib.gis.tests.utils import no_mysql, postgis +from django.conf import settings +from models import City, Location + +cities = (('Aurora', 'TX', -97.516111, 33.058333), + ('Roswell', 'NM', -104.528056, 33.387222), + ('Kecksburg', 'PA', -79.460734, 40.18476), + ) + +class RelatedGeoModelTest(unittest.TestCase): + + def test01_setup(self): + "Setting up for related model tests." + for name, state, lon, lat in cities: + loc = Location(point=Point(lon, lat)) + loc.save() + c = City(name=name, state=state, location=loc) + c.save() + + def test02_select_related(self): + "Testing `select_related` on geographic models (see #7126)." + qs1 = City.objects.all() + qs2 = City.objects.select_related() + qs3 = City.objects.select_related('location') + + for qs in (qs1, qs2, qs3): + for ref, c in zip(cities, qs): + nm, st, lon, lat = ref + self.assertEqual(nm, c.name) + self.assertEqual(st, c.state) + self.assertEqual(Point(lon, lat), c.location.point) + + @no_mysql + def test03_transform_related(self): + "Testing the `transform` GeoManager method on related geographic models." + # All the transformations are to state plane coordinate systems using + # US Survey Feet (thus a tolerance of 0 implies error w/in 1 survey foot). + if postgis: + tol = 3 + nqueries = 4 # +1 for `postgis_lib_version` + else: + tol = 0 + nqueries = 3 + + def check_pnt(ref, pnt): + self.assertAlmostEqual(ref.x, pnt.x, tol) + self.assertAlmostEqual(ref.y, pnt.y, tol) + + # Turning on debug so we can manually verify the number of SQL queries issued. + dbg = settings.DEBUG + settings.DEBUG = True + from django.db import connection + + # Each city transformed to the SRID of their state plane coordinate system. + transformed = (('Kecksburg', 2272, 'POINT(1490553.98959621 314792.131023984)'), + ('Roswell', 2257, 'POINT(481902.189077221 868477.766629735)'), + ('Aurora', 2276, 'POINT(2269923.2484839 7069381.28722222)'), + ) + + for name, srid, wkt in transformed: + # Doing this implicitly sets `select_related` select the location. + qs = list(City.objects.filter(name=name).transform('location__point', srid)) + check_pnt(GEOSGeometry(wkt), qs[0].location.point) + settings.DEBUG= dbg + + # Verifying the number of issued SQL queries. + self.assertEqual(nqueries, len(connection.queries)) + +def suite(): + s = unittest.TestSuite() + s.addTest(unittest.makeSuite(RelatedGeoModelTest)) + return s diff --git a/django/contrib/gis/tests/relatedapp/tests_mysql.py b/django/contrib/gis/tests/relatedapp/tests_mysql.py new file mode 100644 index 0000000000..ecadf745f3 --- /dev/null +++ b/django/contrib/gis/tests/relatedapp/tests_mysql.py @@ -0,0 +1 @@ +from tests import *