From 11c00d632d5b3cb2fdbece6760c8b221e3c7daff Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Tue, 15 Dec 2009 17:58:17 +0000 Subject: [PATCH] [soc2009/multidb] Reorganized sql.InsertQuery to defer the need for a connection till later in the process. Patch from Russell Keith-Magee. git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11871 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- .../contrib/gis/db/models/sql/subqueries.py | 19 ++++++------------ django/db/models/query.py | 5 ++--- django/db/models/sql/compiler.py | 20 +++++++++++++++---- django/db/models/sql/subqueries.py | 20 ++++++++----------- 4 files changed, 32 insertions(+), 32 deletions(-) diff --git a/django/contrib/gis/db/models/sql/subqueries.py b/django/contrib/gis/db/models/sql/subqueries.py index 851ca35fb7..21185a2967 100644 --- a/django/contrib/gis/db/models/sql/subqueries.py +++ b/django/contrib/gis/db/models/sql/subqueries.py @@ -2,7 +2,7 @@ from django.db import connections from django.db.models.sql.subqueries import InsertQuery class GeoInsertQuery(InsertQuery): - def insert_values(self, insert_values, connection, raw_values=False): + def insert_values(self, insert_values, raw_values=False): """ Set up the insert query from the 'insert_values' dictionary. The dictionary gives the model field names and their target values. @@ -14,19 +14,13 @@ class GeoInsertQuery(InsertQuery): """ placeholders, values = [], [] for field, val in insert_values: - if hasattr(field, 'get_placeholder'): - # Some fields (e.g. geo fields) need special munging before - # they can be inserted. - placeholders.append(field.get_placeholder(val, connection)) - else: - placeholders.append('%s') - + placeholders.append((field, val)) self.columns.append(field.column) - + if not placeholders[-1] == 'NULL': values.append(val) if raw_values: - self.values.extend(values) + self.values.extend([(None, v) for v in values]) else: self.params += tuple(values) self.values.extend(placeholders) @@ -38,6 +32,5 @@ def insert_query(model, values, return_id=False, raw_values=False, using=None): part of the public API. """ query = GeoInsertQuery(model) - compiler = query.get_compiler(using=using) - query.insert_values(values, compiler.connection, raw_values) - return compiler.execute_sql(return_id) + query.insert_values(values, raw_values) + return query.get_compiler(using=using).execute_sql(return_id) diff --git a/django/db/models/query.py b/django/db/models/query.py index 4d92abc92d..1059352945 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1149,6 +1149,5 @@ def insert_query(model, values, return_id=False, raw_values=False, using=None): part of the public API. """ query = sql.InsertQuery(model) - compiler = query.get_compiler(using=using) - query.insert_values(values, compiler.connection, raw_values) - return compiler.execute_sql(return_id) + query.insert_values(values, raw_values) + return query.get_compiler(using=using).execute_sql(return_id) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 2b45c5d92e..d44a07a8e3 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -689,6 +689,18 @@ class SQLCompiler(object): class SQLInsertCompiler(SQLCompiler): + def placeholder(self, field, val): + if field is None: + # A field value of None means the value is raw. + return val + elif hasattr(field, 'get_placeholder'): + # Some fields (e.g. geo fields) need special munging before + # they can be inserted. + return field.get_placeholder(val, self.connection) + else: + # Return the common case for the placeholder + return '%s' + def as_sql(self): # We don't need quote_name_unless_alias() here, since these are all # going to be column names (so we can avoid the extra overhead). @@ -696,18 +708,18 @@ class SQLInsertCompiler(SQLCompiler): opts = self.query.model._meta result = ['INSERT INTO %s' % qn(opts.db_table)] result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns])) - result.append('VALUES (%s)' % ', '.join(self.query.values)) + values = [self.placeholder(*v) for v in self.query.values] + result.append('VALUES (%s)' % ', '.join(values)) params = self.query.params - if self.query.return_id and self.connection.features.can_return_id_from_insert: + if self.return_id and self.connection.features.can_return_id_from_insert: col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) r_fmt, r_params = self.connection.ops.return_insert_id() result.append(r_fmt % col) params = params + r_params return ' '.join(result), params - def execute_sql(self, return_id=False): - self.query.return_id = return_id + self.return_id = return_id cursor = super(SQLInsertCompiler, self).execute_sql(None) if not (return_id and cursor): return diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index f9b565181d..e80a023699 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -183,15 +183,17 @@ class InsertQuery(Query): self.columns = [] self.values = [] self.params = () - self.return_id = False def clone(self, klass=None, **kwargs): - extras = {'columns': self.columns[:], 'values': self.values[:], - 'params': self.params, 'return_id': self.return_id} + extras = { + 'columns': self.columns[:], + 'values': self.values[:], + 'params': self.params + } extras.update(kwargs) return super(InsertQuery, self).clone(klass, **extras) - def insert_values(self, insert_values, connection, raw_values=False): + def insert_values(self, insert_values, raw_values=False): """ Set up the insert query from the 'insert_values' dictionary. The dictionary gives the model field names and their target values. @@ -203,17 +205,11 @@ class InsertQuery(Query): """ placeholders, values = [], [] for field, val in insert_values: - if hasattr(field, 'get_placeholder'): - # Some fields (e.g. geo fields) need special munging before - # they can be inserted. - placeholders.append(field.get_placeholder(val, connection)) - else: - placeholders.append('%s') - + placeholders.append((field, val)) self.columns.append(field.column) values.append(val) if raw_values: - self.values.extend(values) + self.values.extend([(None, v) for v in values]) else: self.params += tuple(values) self.values.extend(placeholders)