From d91479a287fcda93415569bbf872df15bc6cdbf8 Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Thu, 13 Mar 2008 11:48:25 +0000 Subject: [PATCH] queryset-refactor: Changed the return type of an internal function. Previous polymorphic return type was dumb. git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7235 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/query.py | 2 +- django/db/models/sql/datastructures.py | 4 +++ django/db/models/sql/query.py | 34 ++++++++++++++------------ 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index 76f5eab680..09badef5aa 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -506,7 +506,7 @@ class ValuesQuerySet(QuerySet): # Default to all fields. field_names = [f.attname for f in self.model._meta.fields] - self.query.add_fields(field_names) + self.query.add_fields(field_names, False) self.query.default_cols = False self.field_names = field_names diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index e1ab0eded7..bc21fb3b68 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -9,6 +9,10 @@ class EmptyResultSet(Exception): class FullResultSet(Exception): pass +class JoinError(Exception): + def __init__(self, level): + self.level = level + class Empty(object): pass diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 3afd4ab367..dd6365282b 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -17,7 +17,7 @@ from django.db.models.sql.where import WhereNode, EverythingNode, AND, OR from django.db.models.sql.datastructures import Count from django.db.models.fields import FieldDoesNotExist from django.core.exceptions import FieldError -from datastructures import EmptyResultSet, Empty +from datastructures import EmptyResultSet, Empty, JoinError from constants import * try: @@ -507,10 +507,11 @@ class Query(object): pieces = name.split(LOOKUP_SEP) if not alias: alias = self.get_initial_alias() - result = self.setup_joins(pieces, opts, alias, False, False) - if isinstance(result, int): + try: + field, target, opts, joins = self.setup_joins(pieces, opts, alias, + False, False) + except JoinError: raise FieldError("Cannot order by many-valued field: '%s'" % name) - field, target, opts, joins = result alias = joins[-1][-1] col = target.column @@ -812,12 +813,12 @@ class Query(object): alias = self.get_initial_alias() allow_many = trim or not negate - result = self.setup_joins(parts, opts, alias, (connector == AND), - allow_many) - if isinstance(result, int): - self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:result])) + try: + field, target, opts, join_list = self.setup_joins(parts, opts, + alias, (connector == AND), allow_many) + except JoinError, e: + self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level])) return - field, target, opts, join_list = result if trim and len(join_list) > 1: extra = join_list[-1] join_list = join_list[:-1] @@ -972,7 +973,7 @@ class Query(object): for join in joins: for alias in join: self.unref_alias(alias) - return pos + 1 + raise JoinError(pos + 1) if model: # The field lives on a base class of the current model. alias_list = [] @@ -1128,17 +1129,20 @@ class Query(object): """ return not (self.low_mark or self.high_mark) - def add_fields(self, field_names): + def add_fields(self, field_names, allow_m2m=True): """ Adds the given (model) fields to the select set. The field names are added in the order specified. """ alias = self.get_initial_alias() opts = self.get_meta() - for name in field_names: - u1, target, u2, joins = self.setup_joins(name.split(LOOKUP_SEP), - opts, alias, False, False, True) - self.select.append((joins[-1][-1], target.column)) + try: + for name in field_names: + u1, target, u2, joins = self.setup_joins(name.split(LOOKUP_SEP), + opts, alias, False, allow_m2m, True) + self.select.append((joins[-1][-1], target.column)) + except JoinError: + raise FieldError("Invalid field name: '%s'" % name) def add_ordering(self, *ordering): """