From e13ea3c70dc1a217d4e88e0a66f4311167f156c3 Mon Sep 17 00:00:00 2001 From: Adrian Holovaty Date: Mon, 20 Aug 2007 02:39:05 +0000 Subject: [PATCH] Refactored get_query_set_class() to DatabaseOperations.query_set_class(). Also added BaseDatabaseFeatures.uses_custom_queryset. Refs #5106 git-svn-id: http://code.djangoproject.com/svn/django/trunk@5976 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/backends/__init__.py | 10 + django/db/backends/oracle/base.py | 478 +++++++++++++++--------------- django/db/models/query.py | 6 +- 3 files changed, 251 insertions(+), 243 deletions(-) diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 2415d453ce..0a86ed1476 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -48,6 +48,7 @@ class BaseDatabaseFeatures(object): supports_constraints = True supports_tablespaces = False uses_case_insensitive_names = False + uses_custom_queryset = False class BaseDatabaseOperations(object): """ @@ -144,6 +145,15 @@ class BaseDatabaseOperations(object): """ return 'DEFAULT' + def query_set_class(self, DefaultQuerySet): + """ + Given the default QuerySet class, returns a custom QuerySet class + to use for this backend. Returns None if a custom QuerySet isn't used. + See also BaseDatabaseFeatures.uses_custom_queryset, which regulates + whether this method is called at all. + """ + return None + def quote_name(self, name): """ Returns a quoted version of the given table, index or column name. Does diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index acebd75cb3..0f1c475871 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -28,6 +28,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): needs_upper_for_iops = True supports_tablespaces = True uses_case_insensitive_names = True + uses_custom_queryset = True class DatabaseOperations(BaseDatabaseOperations): def autoinc_sql(self, table): @@ -78,6 +79,243 @@ class DatabaseOperations(BaseDatabaseOperations): def max_name_length(self): return 30 + def query_set_class(self, DefaultQuerySet): + from django.db import connection + from django.db.models.query import EmptyResultSet, GET_ITERATOR_CHUNK_SIZE, quote_only_if_word + + class OracleQuerySet(DefaultQuerySet): + + def iterator(self): + "Performs the SELECT database lookup of this QuerySet." + + from django.db.models.query import get_cached_row + + # self._select is a dictionary, and dictionaries' key order is + # undefined, so we convert it to a list of tuples. + extra_select = self._select.items() + + full_query = None + + try: + try: + select, sql, params, full_query = self._get_sql_clause(get_full_query=True) + except TypeError: + select, sql, params = self._get_sql_clause() + except EmptyResultSet: + raise StopIteration + if not full_query: + full_query = "SELECT %s%s\n%s" % \ + ((self._distinct and "DISTINCT " or ""), + ', '.join(select), sql) + + cursor = connection.cursor() + cursor.execute(full_query, params) + + fill_cache = self._select_related + fields = self.model._meta.fields + index_end = len(fields) + + # so here's the logic; + # 1. retrieve each row in turn + # 2. convert NCLOBs + + while 1: + rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) + if not rows: + raise StopIteration + for row in rows: + row = self.resolve_columns(row, fields) + if fill_cache: + obj, index_end = get_cached_row(klass=self.model, row=row, + index_start=0, max_depth=self._max_related_depth) + else: + obj = self.model(*row[:index_end]) + for i, k in enumerate(extra_select): + setattr(obj, k[0], row[index_end+i]) + yield obj + + + def _get_sql_clause(self, get_full_query=False): + from django.db.models.query import fill_table_cache, \ + handle_legacy_orderlist, orderfield2column + + opts = self.model._meta + qn = connection.ops.quote_name + + # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z. + select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields] + tables = [quote_only_if_word(t) for t in self._tables] + joins = SortedDict() + where = self._where[:] + params = self._params[:] + + # Convert self._filters into SQL. + joins2, where2, params2 = self._filters.get_sql(opts) + joins.update(joins2) + where.extend(where2) + params.extend(params2) + + # Add additional tables and WHERE clauses based on select_related. + if self._select_related: + fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table]) + + # Add any additional SELECTs. + if self._select: + select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()]) + + # Start composing the body of the SQL statement. + sql = [" FROM", qn(opts.db_table)] + + # Compose the join dictionary into SQL describing the joins. + if joins: + sql.append(" ".join(["%s %s %s ON %s" % (join_type, table, alias, condition) + for (alias, (table, join_type, condition)) in joins.items()])) + + # Compose the tables clause into SQL. + if tables: + sql.append(", " + ", ".join(tables)) + + # Compose the where clause into SQL. + if where: + sql.append(where and "WHERE " + " AND ".join(where)) + + # ORDER BY clause + order_by = [] + if self._order_by is not None: + ordering_to_use = self._order_by + else: + ordering_to_use = opts.ordering + for f in handle_legacy_orderlist(ordering_to_use): + if f == '?': # Special case. + order_by.append(DatabaseOperations().random_function_sql()) + else: + if f.startswith('-'): + col_name = f[1:] + order = "DESC" + else: + col_name = f + order = "ASC" + if "." in col_name: + table_prefix, col_name = col_name.split('.', 1) + table_prefix = qn(table_prefix) + '.' + else: + # Use the database table as a column prefix if it wasn't given, + # and if the requested column isn't a custom SELECT. + if "." not in col_name and col_name not in (self._select or ()): + table_prefix = qn(opts.db_table) + '.' + else: + table_prefix = '' + order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order)) + if order_by: + sql.append("ORDER BY " + ", ".join(order_by)) + + # Look for column name collisions in the select elements + # and fix them with an AS alias. This allows us to do a + # SELECT * later in the paging query. + cols = [clause.split('.')[-1] for clause in select] + for index, col in enumerate(cols): + if cols.count(col) > 1: + col = '%s%d' % (col.replace('"', ''), index) + cols[index] = col + select[index] = '%s AS %s' % (select[index], col) + + # LIMIT and OFFSET clauses + # To support limits and offsets, Oracle requires some funky rewriting of an otherwise normal looking query. + select_clause = ",".join(select) + distinct = (self._distinct and "DISTINCT " or "") + + if order_by: + order_by_clause = " OVER (ORDER BY %s )" % (", ".join(order_by)) + else: + #Oracle's row_number() function always requires an order-by clause. + #So we need to define a default order-by, since none was provided. + order_by_clause = " OVER (ORDER BY %s.%s)" % \ + (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column)) + # limit_and_offset_clause + if self._limit is None: + assert self._offset is None, "'offset' is not allowed without 'limit'" + + if self._offset is not None: + offset = int(self._offset) + else: + offset = 0 + if self._limit is not None: + limit = int(self._limit) + else: + limit = None + + limit_and_offset_clause = '' + if limit is not None: + limit_and_offset_clause = "WHERE rn > %s AND rn <= %s" % (offset, limit+offset) + elif offset: + limit_and_offset_clause = "WHERE rn > %s" % (offset) + + if len(limit_and_offset_clause) > 0: + fmt = \ + """SELECT * FROM + (SELECT %s%s, + ROW_NUMBER()%s AS rn + %s) + %s""" + full_query = fmt % (distinct, select_clause, + order_by_clause, ' '.join(sql).strip(), + limit_and_offset_clause) + else: + full_query = None + + if get_full_query: + return select, " ".join(sql), params, full_query + else: + return select, " ".join(sql), params + + def resolve_columns(self, row, fields=()): + from django.db.models.fields import DateField, DateTimeField, \ + TimeField, BooleanField, NullBooleanField, DecimalField, Field + values = [] + for value, field in map(None, row, fields): + if isinstance(value, Database.LOB): + value = value.read() + # Oracle stores empty strings as null. We need to undo this in + # order to adhere to the Django convention of using the empty + # string instead of null, but only if the field accepts the + # empty string. + if value is None and isinstance(field, Field) and field.empty_strings_allowed: + value = '' + # Convert 1 or 0 to True or False + elif value in (1, 0) and isinstance(field, (BooleanField, NullBooleanField)): + value = bool(value) + # Convert floats to decimals + elif value is not None and isinstance(field, DecimalField): + value = util.typecast_decimal(field.format_number(value)) + # cx_Oracle always returns datetime.datetime objects for + # DATE and TIMESTAMP columns, but Django wants to see a + # python datetime.date, .time, or .datetime. We use the type + # of the Field to determine which to cast to, but it's not + # always available. + # As a workaround, we cast to date if all the time-related + # values are 0, or to time if the date is 1/1/1900. + # This could be cleaned a bit by adding a method to the Field + # classes to normalize values from the database (the to_python + # method is used for validation and isn't what we want here). + elif isinstance(value, Database.Timestamp): + # In Python 2.3, the cx_Oracle driver returns its own + # Timestamp object that we must convert to a datetime class. + if not isinstance(value, datetime.datetime): + value = datetime.datetime(value.year, value.month, value.day, value.hour, + value.minute, value.second, value.fsecond) + if isinstance(field, DateTimeField): + pass # DateTimeField subclasses DateField so must be checked first. + elif isinstance(field, DateField): + value = value.date() + elif isinstance(field, TimeField) or (value.year == 1900 and value.month == value.day == 1): + value = value.time() + elif value.hour == value.minute == value.second == value.microsecond == 0: + value = value.date() + values.append(value) + return values + + return OracleQuerySet + def quote_name(self, name): # SQL92 requires delimited (quoted) names to be case-sensitive. When # not quoted, Oracle has case-insensitive behavior for identifiers, but @@ -261,246 +499,6 @@ def get_trigger_name(table): name_length = DatabaseOperations().max_name_length() - 3 return '%s_TR' % util.truncate_name(table, name_length).upper() -def get_query_set_class(DefaultQuerySet): - "Create a custom QuerySet class for Oracle." - - from django.db import connection - from django.db.models.query import EmptyResultSet, GET_ITERATOR_CHUNK_SIZE, quote_only_if_word - - class OracleQuerySet(DefaultQuerySet): - - def iterator(self): - "Performs the SELECT database lookup of this QuerySet." - - from django.db.models.query import get_cached_row - - # self._select is a dictionary, and dictionaries' key order is - # undefined, so we convert it to a list of tuples. - extra_select = self._select.items() - - full_query = None - - try: - try: - select, sql, params, full_query = self._get_sql_clause(get_full_query=True) - except TypeError: - select, sql, params = self._get_sql_clause() - except EmptyResultSet: - raise StopIteration - if not full_query: - full_query = "SELECT %s%s\n%s" % \ - ((self._distinct and "DISTINCT " or ""), - ', '.join(select), sql) - - cursor = connection.cursor() - cursor.execute(full_query, params) - - fill_cache = self._select_related - fields = self.model._meta.fields - index_end = len(fields) - - # so here's the logic; - # 1. retrieve each row in turn - # 2. convert NCLOBs - - while 1: - rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) - if not rows: - raise StopIteration - for row in rows: - row = self.resolve_columns(row, fields) - if fill_cache: - obj, index_end = get_cached_row(klass=self.model, row=row, - index_start=0, max_depth=self._max_related_depth) - else: - obj = self.model(*row[:index_end]) - for i, k in enumerate(extra_select): - setattr(obj, k[0], row[index_end+i]) - yield obj - - - def _get_sql_clause(self, get_full_query=False): - from django.db.models.query import fill_table_cache, \ - handle_legacy_orderlist, orderfield2column - - opts = self.model._meta - qn = connection.ops.quote_name - - # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z. - select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields] - tables = [quote_only_if_word(t) for t in self._tables] - joins = SortedDict() - where = self._where[:] - params = self._params[:] - - # Convert self._filters into SQL. - joins2, where2, params2 = self._filters.get_sql(opts) - joins.update(joins2) - where.extend(where2) - params.extend(params2) - - # Add additional tables and WHERE clauses based on select_related. - if self._select_related: - fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table]) - - # Add any additional SELECTs. - if self._select: - select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()]) - - # Start composing the body of the SQL statement. - sql = [" FROM", qn(opts.db_table)] - - # Compose the join dictionary into SQL describing the joins. - if joins: - sql.append(" ".join(["%s %s %s ON %s" % (join_type, table, alias, condition) - for (alias, (table, join_type, condition)) in joins.items()])) - - # Compose the tables clause into SQL. - if tables: - sql.append(", " + ", ".join(tables)) - - # Compose the where clause into SQL. - if where: - sql.append(where and "WHERE " + " AND ".join(where)) - - # ORDER BY clause - order_by = [] - if self._order_by is not None: - ordering_to_use = self._order_by - else: - ordering_to_use = opts.ordering - for f in handle_legacy_orderlist(ordering_to_use): - if f == '?': # Special case. - order_by.append(DatabaseOperations().random_function_sql()) - else: - if f.startswith('-'): - col_name = f[1:] - order = "DESC" - else: - col_name = f - order = "ASC" - if "." in col_name: - table_prefix, col_name = col_name.split('.', 1) - table_prefix = qn(table_prefix) + '.' - else: - # Use the database table as a column prefix if it wasn't given, - # and if the requested column isn't a custom SELECT. - if "." not in col_name and col_name not in (self._select or ()): - table_prefix = qn(opts.db_table) + '.' - else: - table_prefix = '' - order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order)) - if order_by: - sql.append("ORDER BY " + ", ".join(order_by)) - - # Look for column name collisions in the select elements - # and fix them with an AS alias. This allows us to do a - # SELECT * later in the paging query. - cols = [clause.split('.')[-1] for clause in select] - for index, col in enumerate(cols): - if cols.count(col) > 1: - col = '%s%d' % (col.replace('"', ''), index) - cols[index] = col - select[index] = '%s AS %s' % (select[index], col) - - # LIMIT and OFFSET clauses - # To support limits and offsets, Oracle requires some funky rewriting of an otherwise normal looking query. - select_clause = ",".join(select) - distinct = (self._distinct and "DISTINCT " or "") - - if order_by: - order_by_clause = " OVER (ORDER BY %s )" % (", ".join(order_by)) - else: - #Oracle's row_number() function always requires an order-by clause. - #So we need to define a default order-by, since none was provided. - order_by_clause = " OVER (ORDER BY %s.%s)" % \ - (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column)) - # limit_and_offset_clause - if self._limit is None: - assert self._offset is None, "'offset' is not allowed without 'limit'" - - if self._offset is not None: - offset = int(self._offset) - else: - offset = 0 - if self._limit is not None: - limit = int(self._limit) - else: - limit = None - - limit_and_offset_clause = '' - if limit is not None: - limit_and_offset_clause = "WHERE rn > %s AND rn <= %s" % (offset, limit+offset) - elif offset: - limit_and_offset_clause = "WHERE rn > %s" % (offset) - - if len(limit_and_offset_clause) > 0: - fmt = \ -"""SELECT * FROM - (SELECT %s%s, - ROW_NUMBER()%s AS rn - %s) -%s""" - full_query = fmt % (distinct, select_clause, - order_by_clause, ' '.join(sql).strip(), - limit_and_offset_clause) - else: - full_query = None - - if get_full_query: - return select, " ".join(sql), params, full_query - else: - return select, " ".join(sql), params - - def resolve_columns(self, row, fields=()): - from django.db.models.fields import DateField, DateTimeField, \ - TimeField, BooleanField, NullBooleanField, DecimalField, Field - values = [] - for value, field in map(None, row, fields): - if isinstance(value, Database.LOB): - value = value.read() - # Oracle stores empty strings as null. We need to undo this in - # order to adhere to the Django convention of using the empty - # string instead of null, but only if the field accepts the - # empty string. - if value is None and isinstance(field, Field) and field.empty_strings_allowed: - value = '' - # Convert 1 or 0 to True or False - elif value in (1, 0) and isinstance(field, (BooleanField, NullBooleanField)): - value = bool(value) - # Convert floats to decimals - elif value is not None and isinstance(field, DecimalField): - value = util.typecast_decimal(field.format_number(value)) - # cx_Oracle always returns datetime.datetime objects for - # DATE and TIMESTAMP columns, but Django wants to see a - # python datetime.date, .time, or .datetime. We use the type - # of the Field to determine which to cast to, but it's not - # always available. - # As a workaround, we cast to date if all the time-related - # values are 0, or to time if the date is 1/1/1900. - # This could be cleaned a bit by adding a method to the Field - # classes to normalize values from the database (the to_python - # method is used for validation and isn't what we want here). - elif isinstance(value, Database.Timestamp): - # In Python 2.3, the cx_Oracle driver returns its own - # Timestamp object that we must convert to a datetime class. - if not isinstance(value, datetime.datetime): - value = datetime.datetime(value.year, value.month, value.day, value.hour, - value.minute, value.second, value.fsecond) - if isinstance(field, DateTimeField): - pass # DateTimeField subclasses DateField so must be checked first. - elif isinstance(field, DateField): - value = value.date() - elif isinstance(field, TimeField) or (value.year == 1900 and value.month == value.day == 1): - value = value.time() - elif value.hour == value.minute == value.second == value.microsecond == 0: - value = value.date() - values.append(value) - return values - - return OracleQuerySet - - OPERATOR_MAPPING = { 'exact': '= %s', 'iexact': '= UPPER(%s)', diff --git a/django/db/models/query.py b/django/db/models/query.py index 767942c6a5..f4ebca2ac5 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -564,9 +564,9 @@ class _QuerySet(object): return select, " ".join(sql), params -# Use the backend's QuerySet class if it defines one, otherwise use _QuerySet. -if hasattr(backend, 'get_query_set_class'): - QuerySet = backend.get_query_set_class(_QuerySet) +# Use the backend's QuerySet class if it defines one. Otherwise, use _QuerySet. +if connection.features.uses_custom_queryset: + QuerySet = connection.ops.query_set_class(_QuerySet) else: QuerySet = _QuerySet