1
0
mirror of https://github.com/django/django.git synced 2025-07-04 17:59:13 +00:00

[soc2009/multidb] Reorganized sql.InsertQuery to defer the need for a connection till later in the process. Patch from Russell Keith-Magee.

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11871 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2009-12-15 17:58:17 +00:00
parent 5afff61b02
commit 11c00d632d
4 changed files with 32 additions and 32 deletions

View File

@ -2,7 +2,7 @@ from django.db import connections
from django.db.models.sql.subqueries import InsertQuery
class GeoInsertQuery(InsertQuery):
def insert_values(self, insert_values, connection, raw_values=False):
def insert_values(self, insert_values, raw_values=False):
"""
Set up the insert query from the 'insert_values' dictionary. The
dictionary gives the model field names and their target values.
@ -14,19 +14,13 @@ class GeoInsertQuery(InsertQuery):
"""
placeholders, values = [], []
for field, val in insert_values:
if hasattr(field, 'get_placeholder'):
# Some fields (e.g. geo fields) need special munging before
# they can be inserted.
placeholders.append(field.get_placeholder(val, connection))
else:
placeholders.append('%s')
placeholders.append((field, val))
self.columns.append(field.column)
if not placeholders[-1] == 'NULL':
values.append(val)
if raw_values:
self.values.extend(values)
self.values.extend([(None, v) for v in values])
else:
self.params += tuple(values)
self.values.extend(placeholders)
@ -38,6 +32,5 @@ def insert_query(model, values, return_id=False, raw_values=False, using=None):
part of the public API.
"""
query = GeoInsertQuery(model)
compiler = query.get_compiler(using=using)
query.insert_values(values, compiler.connection, raw_values)
return compiler.execute_sql(return_id)
query.insert_values(values, raw_values)
return query.get_compiler(using=using).execute_sql(return_id)

View File

@ -1149,6 +1149,5 @@ def insert_query(model, values, return_id=False, raw_values=False, using=None):
part of the public API.
"""
query = sql.InsertQuery(model)
compiler = query.get_compiler(using=using)
query.insert_values(values, compiler.connection, raw_values)
return compiler.execute_sql(return_id)
query.insert_values(values, raw_values)
return query.get_compiler(using=using).execute_sql(return_id)

View File

@ -689,6 +689,18 @@ class SQLCompiler(object):
class SQLInsertCompiler(SQLCompiler):
def placeholder(self, field, val):
if field is None:
# A field value of None means the value is raw.
return val
elif hasattr(field, 'get_placeholder'):
# Some fields (e.g. geo fields) need special munging before
# they can be inserted.
return field.get_placeholder(val, self.connection)
else:
# Return the common case for the placeholder
return '%s'
def as_sql(self):
# We don't need quote_name_unless_alias() here, since these are all
# going to be column names (so we can avoid the extra overhead).
@ -696,18 +708,18 @@ class SQLInsertCompiler(SQLCompiler):
opts = self.query.model._meta
result = ['INSERT INTO %s' % qn(opts.db_table)]
result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns]))
result.append('VALUES (%s)' % ', '.join(self.query.values))
values = [self.placeholder(*v) for v in self.query.values]
result.append('VALUES (%s)' % ', '.join(values))
params = self.query.params
if self.query.return_id and self.connection.features.can_return_id_from_insert:
if self.return_id and self.connection.features.can_return_id_from_insert:
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
r_fmt, r_params = self.connection.ops.return_insert_id()
result.append(r_fmt % col)
params = params + r_params
return ' '.join(result), params
def execute_sql(self, return_id=False):
self.query.return_id = return_id
self.return_id = return_id
cursor = super(SQLInsertCompiler, self).execute_sql(None)
if not (return_id and cursor):
return

View File

@ -183,15 +183,17 @@ class InsertQuery(Query):
self.columns = []
self.values = []
self.params = ()
self.return_id = False
def clone(self, klass=None, **kwargs):
extras = {'columns': self.columns[:], 'values': self.values[:],
'params': self.params, 'return_id': self.return_id}
extras = {
'columns': self.columns[:],
'values': self.values[:],
'params': self.params
}
extras.update(kwargs)
return super(InsertQuery, self).clone(klass, **extras)
def insert_values(self, insert_values, connection, raw_values=False):
def insert_values(self, insert_values, raw_values=False):
"""
Set up the insert query from the 'insert_values' dictionary. The
dictionary gives the model field names and their target values.
@ -203,17 +205,11 @@ class InsertQuery(Query):
"""
placeholders, values = [], []
for field, val in insert_values:
if hasattr(field, 'get_placeholder'):
# Some fields (e.g. geo fields) need special munging before
# they can be inserted.
placeholders.append(field.get_placeholder(val, connection))
else:
placeholders.append('%s')
placeholders.append((field, val))
self.columns.append(field.column)
values.append(val)
if raw_values:
self.values.extend(values)
self.values.extend([(None, v) for v in values])
else:
self.params += tuple(values)
self.values.extend(placeholders)