mirror of
https://github.com/django/django.git
synced 2025-07-04 17:59:13 +00:00
[soc2009/multidb] Reorganized sql.InsertQuery to defer the need for a connection till later in the process. Patch from Russell Keith-Magee.
git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11871 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
5afff61b02
commit
11c00d632d
@ -2,7 +2,7 @@ from django.db import connections
|
|||||||
from django.db.models.sql.subqueries import InsertQuery
|
from django.db.models.sql.subqueries import InsertQuery
|
||||||
|
|
||||||
class GeoInsertQuery(InsertQuery):
|
class GeoInsertQuery(InsertQuery):
|
||||||
def insert_values(self, insert_values, connection, raw_values=False):
|
def insert_values(self, insert_values, raw_values=False):
|
||||||
"""
|
"""
|
||||||
Set up the insert query from the 'insert_values' dictionary. The
|
Set up the insert query from the 'insert_values' dictionary. The
|
||||||
dictionary gives the model field names and their target values.
|
dictionary gives the model field names and their target values.
|
||||||
@ -14,19 +14,13 @@ class GeoInsertQuery(InsertQuery):
|
|||||||
"""
|
"""
|
||||||
placeholders, values = [], []
|
placeholders, values = [], []
|
||||||
for field, val in insert_values:
|
for field, val in insert_values:
|
||||||
if hasattr(field, 'get_placeholder'):
|
placeholders.append((field, val))
|
||||||
# Some fields (e.g. geo fields) need special munging before
|
|
||||||
# they can be inserted.
|
|
||||||
placeholders.append(field.get_placeholder(val, connection))
|
|
||||||
else:
|
|
||||||
placeholders.append('%s')
|
|
||||||
|
|
||||||
self.columns.append(field.column)
|
self.columns.append(field.column)
|
||||||
|
|
||||||
if not placeholders[-1] == 'NULL':
|
if not placeholders[-1] == 'NULL':
|
||||||
values.append(val)
|
values.append(val)
|
||||||
if raw_values:
|
if raw_values:
|
||||||
self.values.extend(values)
|
self.values.extend([(None, v) for v in values])
|
||||||
else:
|
else:
|
||||||
self.params += tuple(values)
|
self.params += tuple(values)
|
||||||
self.values.extend(placeholders)
|
self.values.extend(placeholders)
|
||||||
@ -38,6 +32,5 @@ def insert_query(model, values, return_id=False, raw_values=False, using=None):
|
|||||||
part of the public API.
|
part of the public API.
|
||||||
"""
|
"""
|
||||||
query = GeoInsertQuery(model)
|
query = GeoInsertQuery(model)
|
||||||
compiler = query.get_compiler(using=using)
|
query.insert_values(values, raw_values)
|
||||||
query.insert_values(values, compiler.connection, raw_values)
|
return query.get_compiler(using=using).execute_sql(return_id)
|
||||||
return compiler.execute_sql(return_id)
|
|
||||||
|
@ -1149,6 +1149,5 @@ def insert_query(model, values, return_id=False, raw_values=False, using=None):
|
|||||||
part of the public API.
|
part of the public API.
|
||||||
"""
|
"""
|
||||||
query = sql.InsertQuery(model)
|
query = sql.InsertQuery(model)
|
||||||
compiler = query.get_compiler(using=using)
|
query.insert_values(values, raw_values)
|
||||||
query.insert_values(values, compiler.connection, raw_values)
|
return query.get_compiler(using=using).execute_sql(return_id)
|
||||||
return compiler.execute_sql(return_id)
|
|
||||||
|
@ -689,6 +689,18 @@ class SQLCompiler(object):
|
|||||||
|
|
||||||
|
|
||||||
class SQLInsertCompiler(SQLCompiler):
|
class SQLInsertCompiler(SQLCompiler):
|
||||||
|
def placeholder(self, field, val):
|
||||||
|
if field is None:
|
||||||
|
# A field value of None means the value is raw.
|
||||||
|
return val
|
||||||
|
elif hasattr(field, 'get_placeholder'):
|
||||||
|
# Some fields (e.g. geo fields) need special munging before
|
||||||
|
# they can be inserted.
|
||||||
|
return field.get_placeholder(val, self.connection)
|
||||||
|
else:
|
||||||
|
# Return the common case for the placeholder
|
||||||
|
return '%s'
|
||||||
|
|
||||||
def as_sql(self):
|
def as_sql(self):
|
||||||
# We don't need quote_name_unless_alias() here, since these are all
|
# We don't need quote_name_unless_alias() here, since these are all
|
||||||
# going to be column names (so we can avoid the extra overhead).
|
# going to be column names (so we can avoid the extra overhead).
|
||||||
@ -696,18 +708,18 @@ class SQLInsertCompiler(SQLCompiler):
|
|||||||
opts = self.query.model._meta
|
opts = self.query.model._meta
|
||||||
result = ['INSERT INTO %s' % qn(opts.db_table)]
|
result = ['INSERT INTO %s' % qn(opts.db_table)]
|
||||||
result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns]))
|
result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns]))
|
||||||
result.append('VALUES (%s)' % ', '.join(self.query.values))
|
values = [self.placeholder(*v) for v in self.query.values]
|
||||||
|
result.append('VALUES (%s)' % ', '.join(values))
|
||||||
params = self.query.params
|
params = self.query.params
|
||||||
if self.query.return_id and self.connection.features.can_return_id_from_insert:
|
if self.return_id and self.connection.features.can_return_id_from_insert:
|
||||||
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
|
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
|
||||||
r_fmt, r_params = self.connection.ops.return_insert_id()
|
r_fmt, r_params = self.connection.ops.return_insert_id()
|
||||||
result.append(r_fmt % col)
|
result.append(r_fmt % col)
|
||||||
params = params + r_params
|
params = params + r_params
|
||||||
return ' '.join(result), params
|
return ' '.join(result), params
|
||||||
|
|
||||||
|
|
||||||
def execute_sql(self, return_id=False):
|
def execute_sql(self, return_id=False):
|
||||||
self.query.return_id = return_id
|
self.return_id = return_id
|
||||||
cursor = super(SQLInsertCompiler, self).execute_sql(None)
|
cursor = super(SQLInsertCompiler, self).execute_sql(None)
|
||||||
if not (return_id and cursor):
|
if not (return_id and cursor):
|
||||||
return
|
return
|
||||||
|
@ -183,15 +183,17 @@ class InsertQuery(Query):
|
|||||||
self.columns = []
|
self.columns = []
|
||||||
self.values = []
|
self.values = []
|
||||||
self.params = ()
|
self.params = ()
|
||||||
self.return_id = False
|
|
||||||
|
|
||||||
def clone(self, klass=None, **kwargs):
|
def clone(self, klass=None, **kwargs):
|
||||||
extras = {'columns': self.columns[:], 'values': self.values[:],
|
extras = {
|
||||||
'params': self.params, 'return_id': self.return_id}
|
'columns': self.columns[:],
|
||||||
|
'values': self.values[:],
|
||||||
|
'params': self.params
|
||||||
|
}
|
||||||
extras.update(kwargs)
|
extras.update(kwargs)
|
||||||
return super(InsertQuery, self).clone(klass, **extras)
|
return super(InsertQuery, self).clone(klass, **extras)
|
||||||
|
|
||||||
def insert_values(self, insert_values, connection, raw_values=False):
|
def insert_values(self, insert_values, raw_values=False):
|
||||||
"""
|
"""
|
||||||
Set up the insert query from the 'insert_values' dictionary. The
|
Set up the insert query from the 'insert_values' dictionary. The
|
||||||
dictionary gives the model field names and their target values.
|
dictionary gives the model field names and their target values.
|
||||||
@ -203,17 +205,11 @@ class InsertQuery(Query):
|
|||||||
"""
|
"""
|
||||||
placeholders, values = [], []
|
placeholders, values = [], []
|
||||||
for field, val in insert_values:
|
for field, val in insert_values:
|
||||||
if hasattr(field, 'get_placeholder'):
|
placeholders.append((field, val))
|
||||||
# Some fields (e.g. geo fields) need special munging before
|
|
||||||
# they can be inserted.
|
|
||||||
placeholders.append(field.get_placeholder(val, connection))
|
|
||||||
else:
|
|
||||||
placeholders.append('%s')
|
|
||||||
|
|
||||||
self.columns.append(field.column)
|
self.columns.append(field.column)
|
||||||
values.append(val)
|
values.append(val)
|
||||||
if raw_values:
|
if raw_values:
|
||||||
self.values.extend(values)
|
self.values.extend([(None, v) for v in values])
|
||||||
else:
|
else:
|
||||||
self.params += tuple(values)
|
self.params += tuple(values)
|
||||||
self.values.extend(placeholders)
|
self.values.extend(placeholders)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user