mirror of
https://github.com/django/django.git
synced 2025-07-05 18:29:11 +00:00
queryset-refactor: Made update() work with inherited models.
git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7179 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
fa11a6a128
commit
2f8f588df1
@ -503,7 +503,7 @@ class Query(object):
|
||||
name, order = get_order_dir(name, default_order)
|
||||
pieces = name.split(LOOKUP_SEP)
|
||||
if not alias:
|
||||
alias = self.join((None, opts.db_table, None, None))
|
||||
alias = self.get_initial_alias()
|
||||
field, target, opts, joins = self.setup_joins(pieces, opts, alias,
|
||||
False)
|
||||
alias = joins[-1][-1]
|
||||
@ -581,6 +581,62 @@ class Query(object):
|
||||
return True
|
||||
return False
|
||||
|
||||
def change_alias(self, old_alias, new_alias):
|
||||
"""
|
||||
Changes old_alias to new_alias, relabelling any references to it in
|
||||
select columns and the where clause.
|
||||
"""
|
||||
assert new_alias not in self.alias_map
|
||||
|
||||
# 1. Update references in "select" and "where".
|
||||
change_map = {old_alias: new_alias}
|
||||
self.where.relabel_aliases(change_map)
|
||||
for pos, col in enumerate(self.select):
|
||||
if isinstance(col, (list, tuple)):
|
||||
if col[0] == old_alias:
|
||||
self.select[pos] = (new_alias, col[1])
|
||||
else:
|
||||
col.relabel_aliases(change_map)
|
||||
|
||||
# 2. Rename the alias in the internal table/alias datastructures.
|
||||
alias_data = self.alias_map[old_alias]
|
||||
alias_data[ALIAS_JOIN][RHS_ALIAS] = new_alias
|
||||
table_aliases = self.table_map[alias_data[ALIAS_TABLE]]
|
||||
for pos, alias in enumerate(table_aliases):
|
||||
if alias == old_alias:
|
||||
table_aliases[pos] = new_alias
|
||||
break
|
||||
self.alias_map[new_alias] = alias_data
|
||||
del self.alias_map[old_alias]
|
||||
for pos, alias in enumerate(self.tables):
|
||||
if alias == old_alias:
|
||||
self.tables[pos] = new_alias
|
||||
break
|
||||
|
||||
# 3. Update any joins that refer to the old alias.
|
||||
for data in self.alias_map.values():
|
||||
if data[ALIAS_JOIN][LHS_ALIAS] == old_alias:
|
||||
data[ALIAS_JOIN][LHS_ALIAS] = new_alias
|
||||
|
||||
def get_initial_alias(self):
|
||||
"""
|
||||
Returns the first alias for this query, after increasing its reference
|
||||
count.
|
||||
"""
|
||||
if self.tables:
|
||||
alias = self.tables[0]
|
||||
self.ref_alias(alias)
|
||||
else:
|
||||
alias = self.join((None, self.model._meta.db_table, None, None))
|
||||
return alias
|
||||
|
||||
def count_active_tables(self):
|
||||
"""
|
||||
Returns the number of tables in this query with a non-zero reference
|
||||
count.
|
||||
"""
|
||||
return len([1 for o in self.alias_map.values() if o[ALIAS_REFCOUNT]])
|
||||
|
||||
def join(self, connection, always_create=False, exclusions=(),
|
||||
promote=False, outer_if_first=False, nullable=False):
|
||||
"""
|
||||
@ -728,7 +784,7 @@ class Query(object):
|
||||
value = value()
|
||||
|
||||
opts = self.get_meta()
|
||||
alias = self.join((None, opts.db_table, None, None))
|
||||
alias = self.get_initial_alias()
|
||||
allow_many = trim or not negate
|
||||
|
||||
result = self.setup_joins(parts, opts, alias, (connector == AND),
|
||||
@ -1021,8 +1077,12 @@ class Query(object):
|
||||
Adds the given column names to the select set, assuming they come from
|
||||
the root model (the one given in self.model).
|
||||
"""
|
||||
table = self.model._meta.db_table
|
||||
self.select.extend([(table, col) for col in columns])
|
||||
for alias in self.tables:
|
||||
if self.alias_map[alias][ALIAS_REFCOUNT]:
|
||||
break
|
||||
else:
|
||||
alias = self.get_initial_alias()
|
||||
self.select.extend([(alias, col) for col in columns])
|
||||
|
||||
def add_ordering(self, *ordering):
|
||||
"""
|
||||
@ -1111,7 +1171,7 @@ class Query(object):
|
||||
against the join table of many-to-many relation in a subquery.
|
||||
"""
|
||||
opts = self.model._meta
|
||||
alias = self.join((None, opts.db_table, None, None))
|
||||
alias = self.get_initial_alias()
|
||||
field, col, opts, joins = self.setup_joins(start.split(LOOKUP_SEP),
|
||||
opts, alias, False)
|
||||
alias = joins[-1][0]
|
||||
@ -1141,6 +1201,8 @@ class Query(object):
|
||||
"""
|
||||
try:
|
||||
sql, params = self.as_sql()
|
||||
if not sql:
|
||||
raise EmptyResultSet
|
||||
except EmptyResultSet:
|
||||
if result_type == MULTI:
|
||||
raise StopIteration
|
||||
@ -1173,6 +1235,9 @@ def get_order_dir(field, default='ASC'):
|
||||
return field, dirn[0]
|
||||
|
||||
def results_iter(cursor):
|
||||
"""
|
||||
An iterator over the result set that returns a chunk of rows at a time.
|
||||
"""
|
||||
while 1:
|
||||
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
|
||||
if not rows:
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""
|
||||
Query subclasses which provide extra functionality beyond simple data retrieval.
|
||||
"""
|
||||
from copy import deepcopy
|
||||
|
||||
from django.contrib.contenttypes import generic
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.sql.constants import *
|
||||
@ -94,31 +96,32 @@ class UpdateQuery(Query):
|
||||
|
||||
def _setup_query(self):
|
||||
"""
|
||||
Run on initialisation and after cloning.
|
||||
Runs on initialisation and after cloning. Any attributes that would
|
||||
normally be set in __init__ should go in here, instead, so that they
|
||||
are also set up after a clone() call.
|
||||
"""
|
||||
self.values = []
|
||||
self.related_updates = {}
|
||||
self.related_ids = None
|
||||
|
||||
def clone(self, klass=None, **kwargs):
|
||||
return super(UpdateQuery, self).clone(klass,
|
||||
related_updates=self.related_updates.copy, **kwargs)
|
||||
|
||||
def execute_sql(self, result_type=None):
|
||||
super(UpdateQuery, self).execute_sql(result_type)
|
||||
for query in self.get_related_updates():
|
||||
query.execute_sql(result_type)
|
||||
|
||||
def as_sql(self):
|
||||
"""
|
||||
Creates the SQL for this query. Returns the SQL string and list of
|
||||
parameters.
|
||||
"""
|
||||
self.select_related = False
|
||||
self.pre_sql_setup()
|
||||
|
||||
if len(self.tables) != 1:
|
||||
# We can only update one table at a time, so we need to check that
|
||||
# only one alias has a nonzero refcount.
|
||||
table = None
|
||||
for alias_list in self.table_map.values():
|
||||
for alias in alias_list:
|
||||
if self.alias_map[alias][ALIAS_REFCOUNT]:
|
||||
if table:
|
||||
raise FieldError('Updates can only access a single database table at a time.')
|
||||
table = alias
|
||||
else:
|
||||
table = self.tables[0]
|
||||
|
||||
if not self.values:
|
||||
return '', ()
|
||||
table = self.tables[0]
|
||||
qn = self.quote_name_unless_alias
|
||||
result = ['UPDATE %s' % qn(table)]
|
||||
result.append('SET')
|
||||
@ -135,6 +138,55 @@ class UpdateQuery(Query):
|
||||
result.append('WHERE %s' % where)
|
||||
return ' '.join(result), tuple(update_params + params)
|
||||
|
||||
def pre_sql_setup(self):
|
||||
"""
|
||||
If the update depends on results from other tables, we need to do some
|
||||
munging of the "where" conditions to match the format required for
|
||||
(portable) SQL updates. That is done here.
|
||||
|
||||
Further, if we are going to be running multiple updates, we pull out
|
||||
the id values to update at this point so that they don't change as a
|
||||
result of the progressive updates.
|
||||
"""
|
||||
self.select_related = False
|
||||
self.clear_ordering(True)
|
||||
super(UpdateQuery, self).pre_sql_setup()
|
||||
count = self.count_active_tables()
|
||||
if not self.related_updates and count == 1:
|
||||
return
|
||||
|
||||
# We need to use a sub-select in the where clause to filter on things
|
||||
# from other tables.
|
||||
query = self.clone(klass=Query)
|
||||
main_alias = query.tables[0]
|
||||
if count != 1:
|
||||
query.unref_alias(main_alias)
|
||||
if query.alias_map[main_alias][ALIAS_REFCOUNT]:
|
||||
alias = '%s0' % self.alias_prefix
|
||||
query.change_alias(main_alias, alias)
|
||||
col = query.model._meta.pk.column
|
||||
else:
|
||||
for model in query.model._meta.get_parent_list():
|
||||
for alias in query.table_map.get(model._meta.db_table, []):
|
||||
if query.alias_map[alias][ALIAS_REFCOUNT]:
|
||||
col = model._meta.pk.column
|
||||
break
|
||||
query.add_local_columns([col])
|
||||
|
||||
# Now we adjust the current query: reset the where clause and get rid
|
||||
# of all the tables we don't need (since they're in the sub-select).
|
||||
self.where = self.where_class()
|
||||
if self.related_updates:
|
||||
idents = []
|
||||
for rows in query.execute_sql(MULTI):
|
||||
idents.extend([r[0] for r in rows])
|
||||
self.add_filter(('pk__in', idents))
|
||||
self.related_ids = idents
|
||||
else:
|
||||
self.add_filter(('pk__in', query))
|
||||
for alias in self.tables[1:]:
|
||||
self.alias_map[alias][ALIAS_REFCOUNT] = 0
|
||||
|
||||
def clear_related(self, related_field, pk_list):
|
||||
"""
|
||||
Set up and execute an update query that clears related entries for the
|
||||
@ -156,11 +208,42 @@ class UpdateQuery(Query):
|
||||
for name, val in values.items():
|
||||
field, model, direct, m2m = self.model._meta.get_field_by_name(name)
|
||||
if not direct or m2m:
|
||||
# Can only update non-relation fields and foreign keys.
|
||||
raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field)
|
||||
# FIXME: Some sort of db_prep_* is probably more appropriate here.
|
||||
if field.rel and isinstance(val, Model):
|
||||
val = val.pk
|
||||
self.values.append((field.column, val))
|
||||
if model:
|
||||
self.add_related_update(model, field.column, val)
|
||||
else:
|
||||
self.values.append((field.column, val))
|
||||
|
||||
def add_related_update(self, model, column, value):
|
||||
"""
|
||||
Adds (name, value) to an update query for an ancestor model.
|
||||
|
||||
Updates are coalesced so that we only run one update query per ancestor.
|
||||
"""
|
||||
try:
|
||||
self.related_updates[model].append((column, value))
|
||||
except KeyError:
|
||||
self.related_updates[model] = [(column, value)]
|
||||
|
||||
def get_related_updates(self):
|
||||
"""
|
||||
Returns a list of query objects: one for each update required to an
|
||||
ancestor model. Each query will have the same filtering conditions as
|
||||
the current query but will only update a single table.
|
||||
"""
|
||||
if not self.related_updates:
|
||||
return []
|
||||
result = []
|
||||
for model, values in self.related_updates.items():
|
||||
query = UpdateQuery(model, self.connection)
|
||||
query.values = values
|
||||
if self.related_ids:
|
||||
query.add_filter(('pk__in', self.related_ids))
|
||||
result.append(query)
|
||||
return result
|
||||
|
||||
class InsertQuery(Query):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -216,5 +216,13 @@ DoesNotExist: Restaurant matching query does not exist.
|
||||
>>> Restaurant.objects.get(lot__name='Well Lit')
|
||||
<Restaurant: Ristorante Miron the restaurant>
|
||||
|
||||
# The update() command can update fields in parent and child classes at once
|
||||
# (although it executed multiple SQL queries to do so).
|
||||
>>> Restaurant.objects.filter(serves_hot_dogs=True, name__contains='D').update(name='Demon Puppies', serves_hot_dogs=False)
|
||||
>>> r1 = Restaurant.objects.get(pk=r.pk)
|
||||
>>> r1.serves_hot_dogs == False
|
||||
True
|
||||
>>> r1.name
|
||||
u'Demon Puppies'
|
||||
|
||||
"""}
|
||||
|
Loading…
x
Reference in New Issue
Block a user