From 22bb040b60c868b77c956333ccbd3c07eb342487 Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Wed, 19 Mar 2008 15:46:20 +0000 Subject: [PATCH] queryset-refactor: Initial pass at fixing the Oracle support. Thanks, Justin Bronn. Fixed #6161 This is untested (by me) and is a slight modification on Justin's original patch, so feedback and bug reports are welcome. git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7321 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/backends/oracle/base.py | 251 ++++++++---------------------- 1 file changed, 65 insertions(+), 186 deletions(-) diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index e5e398920e..39fd2fc3a3 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -25,6 +25,7 @@ IntegrityError = Database.IntegrityError class DatabaseFeatures(BaseDatabaseFeatures): allows_group_by_ordinal = False allows_unique_and_pk = False # Suppress UNIQUE/PK for Oracle (ORA-02259) + empty_fetchmany_value = () needs_datetime_string_cast = False needs_upper_for_iops = True supports_tablespaces = True @@ -99,195 +100,13 @@ class DatabaseOperations(BaseDatabaseOperations): return 30 def query_set_class(self, DefaultQuerySet): - from django.db import connection - from django.db.models.query import EmptyResultSet, GET_ITERATOR_CHUNK_SIZE, quote_only_if_word - - class OracleQuerySet(DefaultQuerySet): - - def iterator(self): - "Performs the SELECT database lookup of this QuerySet." - - from django.db.models.query import get_cached_row - - # self._select is a dictionary, and dictionaries' key order is - # undefined, so we convert it to a list of tuples. - extra_select = self._select.items() - - full_query = None - - try: - try: - select, sql, params, full_query = self._get_sql_clause(get_full_query=True) - except TypeError: - select, sql, params = self._get_sql_clause() - except EmptyResultSet: - raise StopIteration - if not full_query: - full_query = "SELECT %s%s\n%s" % ((self._distinct and "DISTINCT " or ""), ', '.join(select), sql) - - cursor = connection.cursor() - cursor.execute(full_query, params) - - fill_cache = self._select_related - fields = self.model._meta.fields - index_end = len(fields) - - # so here's the logic; - # 1. retrieve each row in turn - # 2. convert NCLOBs - - while 1: - rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) - if not rows: - raise StopIteration - for row in rows: - row = self.resolve_columns(row, fields) - if fill_cache: - obj, index_end = get_cached_row(klass=self.model, row=row, - index_start=0, max_depth=self._max_related_depth) - else: - obj = self.model(*row[:index_end]) - for i, k in enumerate(extra_select): - setattr(obj, k[0], row[index_end+i]) - yield obj - - - def _get_sql_clause(self, get_full_query=False): - from django.db.models.query import fill_table_cache, \ - handle_legacy_orderlist, orderfield2column - - opts = self.model._meta - qn = connection.ops.quote_name - - # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z. - select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields] - tables = [quote_only_if_word(t) for t in self._tables] - joins = SortedDict() - where = self._where[:] - params = self._params[:] - - # Convert self._filters into SQL. - joins2, where2, params2 = self._filters.get_sql(opts) - joins.update(joins2) - where.extend(where2) - params.extend(params2) - - # Add additional tables and WHERE clauses based on select_related. - if self._select_related: - fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table]) - - # Add any additional SELECTs. - if self._select: - select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()]) - - # Start composing the body of the SQL statement. - sql = [" FROM", qn(opts.db_table)] - - # Compose the join dictionary into SQL describing the joins. - if joins: - sql.append(" ".join(["%s %s %s ON %s" % (join_type, table, alias, condition) - for (alias, (table, join_type, condition)) in joins.items()])) - - # Compose the tables clause into SQL. - if tables: - sql.append(", " + ", ".join(tables)) - - # Compose the where clause into SQL. - if where: - sql.append(where and "WHERE " + " AND ".join(where)) - - # ORDER BY clause - order_by = [] - if self._order_by is not None: - ordering_to_use = self._order_by - else: - ordering_to_use = opts.ordering - for f in handle_legacy_orderlist(ordering_to_use): - if f == '?': # Special case. - order_by.append(DatabaseOperations().random_function_sql()) - else: - if f.startswith('-'): - col_name = f[1:] - order = "DESC" - else: - col_name = f - order = "ASC" - if "." in col_name: - table_prefix, col_name = col_name.split('.', 1) - table_prefix = qn(table_prefix) + '.' - else: - # Use the database table as a column prefix if it wasn't given, - # and if the requested column isn't a custom SELECT. - if "." not in col_name and col_name not in (self._select or ()): - table_prefix = qn(opts.db_table) + '.' - else: - table_prefix = '' - order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order)) - if order_by: - sql.append("ORDER BY " + ", ".join(order_by)) - - # Look for column name collisions in the select elements - # and fix them with an AS alias. This allows us to do a - # SELECT * later in the paging query. - cols = [clause.split('.')[-1] for clause in select] - for index, col in enumerate(cols): - if cols.count(col) > 1: - col = '%s%d' % (col.replace('"', ''), index) - cols[index] = col - select[index] = '%s AS %s' % (select[index], col) - - # LIMIT and OFFSET clauses - # To support limits and offsets, Oracle requires some funky rewriting of an otherwise normal looking query. - select_clause = ",".join(select) - distinct = (self._distinct and "DISTINCT " or "") - - if order_by: - order_by_clause = " OVER (ORDER BY %s )" % (", ".join(order_by)) - else: - #Oracle's row_number() function always requires an order-by clause. - #So we need to define a default order-by, since none was provided. - order_by_clause = " OVER (ORDER BY %s.%s)" % \ - (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column)) - # limit_and_offset_clause - if self._limit is None: - assert self._offset is None, "'offset' is not allowed without 'limit'" - - if self._offset is not None: - offset = int(self._offset) - else: - offset = 0 - if self._limit is not None: - limit = int(self._limit) - else: - limit = None - - limit_and_offset_clause = '' - if limit is not None: - limit_and_offset_clause = "WHERE rn > %s AND rn <= %s" % (offset, limit+offset) - elif offset: - limit_and_offset_clause = "WHERE rn > %s" % (offset) - - if len(limit_and_offset_clause) > 0: - fmt = \ - """SELECT * FROM - (SELECT %s%s, - ROW_NUMBER()%s AS rn - %s) - %s""" - full_query = fmt % (distinct, select_clause, - order_by_clause, ' '.join(sql).strip(), - limit_and_offset_clause) - else: - full_query = None - - if get_full_query: - return select, " ".join(sql), params, full_query - else: - return select, " ".join(sql), params + # Getting the base default `Query` object. + DefaultQuery = DefaultQuerySet().query.__class__ + class OracleQuery(DefaultQuery): def resolve_columns(self, row, fields=()): from django.db.models.fields import DateField, DateTimeField, \ - TimeField, BooleanField, NullBooleanField, DecimalField, Field + TimeField, BooleanField, NullBooleanField, DecimalField, Field values = [] for value, field in map(None, row, fields): if isinstance(value, Database.LOB): @@ -331,6 +150,66 @@ class DatabaseOperations(BaseDatabaseOperations): values.append(value) return values + def as_sql(self, with_limits=True): + """ + Creates the SQL for this query. Returns the SQL string and list + of parameters. This is overriden from the original Query class + to accommodate Oracle's limit/offset SQL. + + If 'with_limits' is False, any limit/offset information is not + included in the query. + """ + # The `do_offset` flag indicates whether we need to construct + # the SQL needed to use limit/offset w/Oracle. + do_offset = with_limits and (self.high_mark or self.low_mark) + + # If no offsets, just return the result of the base class + # `as_sql`. + if not do_offset: + return super(OracleQuery, self).as_sql(with_limits=False) + + # `get_columns` needs to be called before `get_ordering` to + # populate `_select_alias`. + self.pre_sql_setup() + out_cols = self.get_columns() + ordering = self.get_ordering() + + # Getting the "ORDER BY" SQL for the ROW_NUMBER() result. + if ordering: + rn_orderby = ', '.join(ordering) + else: + # Oracle's ROW_NUMBER() function always requires an + # order-by clause. So we need to define a default + # order-by, since none was provided. + qn = self.quote_name_unless_alias + opts = self.model._meta + rn_orderby = '%s.%s' % (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column)) + + # Getting the selection SQL and the params, which has the `rn` + # extra selection SQL; we pop `rn` after this completes so we do + # not get the attribute on the returned models. + self.extra_select['rn'] = 'ROW_NUMBER() OVER (ORDER BY %s )' % rn_orderby + sql, params= super(OracleQuery, self).as_sql(with_limits=False) + self.extra_select.pop('rn') + + # Constructing the result SQL, using the initial select SQL + # obtained above. + result = ['SELECT * FROM (%s)' % sql] + + # Place WHERE condition on `rn` for the desired range. + result.append('WHERE rn > %d' % self.low_mark) + if self.high_mark: + result.append('AND rn <= %d' % self.high_mark) + + # Returning the SQL w/params. + return ' '.join(result), params + + from django.db import connection + class OracleQuerySet(DefaultQuerySet): + "The OracleQuerySet is overriden to use OracleQuery." + def __init__(self, model=None, query=None): + super(OracleQuerySet, self).__init__(model=model, query=query) + self.query = query or OracleQuery(self.model, connection) return OracleQuerySet def quote_name(self, name):