mirror of
https://github.com/django/django.git
synced 2025-07-06 02:39:12 +00:00
queryset-refactor: Made update() work with inherited models.
git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7179 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
fa11a6a128
commit
2f8f588df1
@ -503,7 +503,7 @@ class Query(object):
|
|||||||
name, order = get_order_dir(name, default_order)
|
name, order = get_order_dir(name, default_order)
|
||||||
pieces = name.split(LOOKUP_SEP)
|
pieces = name.split(LOOKUP_SEP)
|
||||||
if not alias:
|
if not alias:
|
||||||
alias = self.join((None, opts.db_table, None, None))
|
alias = self.get_initial_alias()
|
||||||
field, target, opts, joins = self.setup_joins(pieces, opts, alias,
|
field, target, opts, joins = self.setup_joins(pieces, opts, alias,
|
||||||
False)
|
False)
|
||||||
alias = joins[-1][-1]
|
alias = joins[-1][-1]
|
||||||
@ -581,6 +581,62 @@ class Query(object):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def change_alias(self, old_alias, new_alias):
|
||||||
|
"""
|
||||||
|
Changes old_alias to new_alias, relabelling any references to it in
|
||||||
|
select columns and the where clause.
|
||||||
|
"""
|
||||||
|
assert new_alias not in self.alias_map
|
||||||
|
|
||||||
|
# 1. Update references in "select" and "where".
|
||||||
|
change_map = {old_alias: new_alias}
|
||||||
|
self.where.relabel_aliases(change_map)
|
||||||
|
for pos, col in enumerate(self.select):
|
||||||
|
if isinstance(col, (list, tuple)):
|
||||||
|
if col[0] == old_alias:
|
||||||
|
self.select[pos] = (new_alias, col[1])
|
||||||
|
else:
|
||||||
|
col.relabel_aliases(change_map)
|
||||||
|
|
||||||
|
# 2. Rename the alias in the internal table/alias datastructures.
|
||||||
|
alias_data = self.alias_map[old_alias]
|
||||||
|
alias_data[ALIAS_JOIN][RHS_ALIAS] = new_alias
|
||||||
|
table_aliases = self.table_map[alias_data[ALIAS_TABLE]]
|
||||||
|
for pos, alias in enumerate(table_aliases):
|
||||||
|
if alias == old_alias:
|
||||||
|
table_aliases[pos] = new_alias
|
||||||
|
break
|
||||||
|
self.alias_map[new_alias] = alias_data
|
||||||
|
del self.alias_map[old_alias]
|
||||||
|
for pos, alias in enumerate(self.tables):
|
||||||
|
if alias == old_alias:
|
||||||
|
self.tables[pos] = new_alias
|
||||||
|
break
|
||||||
|
|
||||||
|
# 3. Update any joins that refer to the old alias.
|
||||||
|
for data in self.alias_map.values():
|
||||||
|
if data[ALIAS_JOIN][LHS_ALIAS] == old_alias:
|
||||||
|
data[ALIAS_JOIN][LHS_ALIAS] = new_alias
|
||||||
|
|
||||||
|
def get_initial_alias(self):
|
||||||
|
"""
|
||||||
|
Returns the first alias for this query, after increasing its reference
|
||||||
|
count.
|
||||||
|
"""
|
||||||
|
if self.tables:
|
||||||
|
alias = self.tables[0]
|
||||||
|
self.ref_alias(alias)
|
||||||
|
else:
|
||||||
|
alias = self.join((None, self.model._meta.db_table, None, None))
|
||||||
|
return alias
|
||||||
|
|
||||||
|
def count_active_tables(self):
|
||||||
|
"""
|
||||||
|
Returns the number of tables in this query with a non-zero reference
|
||||||
|
count.
|
||||||
|
"""
|
||||||
|
return len([1 for o in self.alias_map.values() if o[ALIAS_REFCOUNT]])
|
||||||
|
|
||||||
def join(self, connection, always_create=False, exclusions=(),
|
def join(self, connection, always_create=False, exclusions=(),
|
||||||
promote=False, outer_if_first=False, nullable=False):
|
promote=False, outer_if_first=False, nullable=False):
|
||||||
"""
|
"""
|
||||||
@ -728,7 +784,7 @@ class Query(object):
|
|||||||
value = value()
|
value = value()
|
||||||
|
|
||||||
opts = self.get_meta()
|
opts = self.get_meta()
|
||||||
alias = self.join((None, opts.db_table, None, None))
|
alias = self.get_initial_alias()
|
||||||
allow_many = trim or not negate
|
allow_many = trim or not negate
|
||||||
|
|
||||||
result = self.setup_joins(parts, opts, alias, (connector == AND),
|
result = self.setup_joins(parts, opts, alias, (connector == AND),
|
||||||
@ -1021,8 +1077,12 @@ class Query(object):
|
|||||||
Adds the given column names to the select set, assuming they come from
|
Adds the given column names to the select set, assuming they come from
|
||||||
the root model (the one given in self.model).
|
the root model (the one given in self.model).
|
||||||
"""
|
"""
|
||||||
table = self.model._meta.db_table
|
for alias in self.tables:
|
||||||
self.select.extend([(table, col) for col in columns])
|
if self.alias_map[alias][ALIAS_REFCOUNT]:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
alias = self.get_initial_alias()
|
||||||
|
self.select.extend([(alias, col) for col in columns])
|
||||||
|
|
||||||
def add_ordering(self, *ordering):
|
def add_ordering(self, *ordering):
|
||||||
"""
|
"""
|
||||||
@ -1111,7 +1171,7 @@ class Query(object):
|
|||||||
against the join table of many-to-many relation in a subquery.
|
against the join table of many-to-many relation in a subquery.
|
||||||
"""
|
"""
|
||||||
opts = self.model._meta
|
opts = self.model._meta
|
||||||
alias = self.join((None, opts.db_table, None, None))
|
alias = self.get_initial_alias()
|
||||||
field, col, opts, joins = self.setup_joins(start.split(LOOKUP_SEP),
|
field, col, opts, joins = self.setup_joins(start.split(LOOKUP_SEP),
|
||||||
opts, alias, False)
|
opts, alias, False)
|
||||||
alias = joins[-1][0]
|
alias = joins[-1][0]
|
||||||
@ -1141,6 +1201,8 @@ class Query(object):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
sql, params = self.as_sql()
|
sql, params = self.as_sql()
|
||||||
|
if not sql:
|
||||||
|
raise EmptyResultSet
|
||||||
except EmptyResultSet:
|
except EmptyResultSet:
|
||||||
if result_type == MULTI:
|
if result_type == MULTI:
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
@ -1173,6 +1235,9 @@ def get_order_dir(field, default='ASC'):
|
|||||||
return field, dirn[0]
|
return field, dirn[0]
|
||||||
|
|
||||||
def results_iter(cursor):
|
def results_iter(cursor):
|
||||||
|
"""
|
||||||
|
An iterator over the result set that returns a chunk of rows at a time.
|
||||||
|
"""
|
||||||
while 1:
|
while 1:
|
||||||
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
|
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
|
||||||
if not rows:
|
if not rows:
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
Query subclasses which provide extra functionality beyond simple data retrieval.
|
Query subclasses which provide extra functionality beyond simple data retrieval.
|
||||||
"""
|
"""
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
from django.contrib.contenttypes import generic
|
from django.contrib.contenttypes import generic
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db.models.sql.constants import *
|
from django.db.models.sql.constants import *
|
||||||
@ -94,31 +96,32 @@ class UpdateQuery(Query):
|
|||||||
|
|
||||||
def _setup_query(self):
|
def _setup_query(self):
|
||||||
"""
|
"""
|
||||||
Run on initialisation and after cloning.
|
Runs on initialisation and after cloning. Any attributes that would
|
||||||
|
normally be set in __init__ should go in here, instead, so that they
|
||||||
|
are also set up after a clone() call.
|
||||||
"""
|
"""
|
||||||
self.values = []
|
self.values = []
|
||||||
|
self.related_updates = {}
|
||||||
|
self.related_ids = None
|
||||||
|
|
||||||
|
def clone(self, klass=None, **kwargs):
|
||||||
|
return super(UpdateQuery, self).clone(klass,
|
||||||
|
related_updates=self.related_updates.copy, **kwargs)
|
||||||
|
|
||||||
|
def execute_sql(self, result_type=None):
|
||||||
|
super(UpdateQuery, self).execute_sql(result_type)
|
||||||
|
for query in self.get_related_updates():
|
||||||
|
query.execute_sql(result_type)
|
||||||
|
|
||||||
def as_sql(self):
|
def as_sql(self):
|
||||||
"""
|
"""
|
||||||
Creates the SQL for this query. Returns the SQL string and list of
|
Creates the SQL for this query. Returns the SQL string and list of
|
||||||
parameters.
|
parameters.
|
||||||
"""
|
"""
|
||||||
self.select_related = False
|
|
||||||
self.pre_sql_setup()
|
self.pre_sql_setup()
|
||||||
|
if not self.values:
|
||||||
if len(self.tables) != 1:
|
return '', ()
|
||||||
# We can only update one table at a time, so we need to check that
|
|
||||||
# only one alias has a nonzero refcount.
|
|
||||||
table = None
|
|
||||||
for alias_list in self.table_map.values():
|
|
||||||
for alias in alias_list:
|
|
||||||
if self.alias_map[alias][ALIAS_REFCOUNT]:
|
|
||||||
if table:
|
|
||||||
raise FieldError('Updates can only access a single database table at a time.')
|
|
||||||
table = alias
|
|
||||||
else:
|
|
||||||
table = self.tables[0]
|
table = self.tables[0]
|
||||||
|
|
||||||
qn = self.quote_name_unless_alias
|
qn = self.quote_name_unless_alias
|
||||||
result = ['UPDATE %s' % qn(table)]
|
result = ['UPDATE %s' % qn(table)]
|
||||||
result.append('SET')
|
result.append('SET')
|
||||||
@ -135,6 +138,55 @@ class UpdateQuery(Query):
|
|||||||
result.append('WHERE %s' % where)
|
result.append('WHERE %s' % where)
|
||||||
return ' '.join(result), tuple(update_params + params)
|
return ' '.join(result), tuple(update_params + params)
|
||||||
|
|
||||||
|
def pre_sql_setup(self):
|
||||||
|
"""
|
||||||
|
If the update depends on results from other tables, we need to do some
|
||||||
|
munging of the "where" conditions to match the format required for
|
||||||
|
(portable) SQL updates. That is done here.
|
||||||
|
|
||||||
|
Further, if we are going to be running multiple updates, we pull out
|
||||||
|
the id values to update at this point so that they don't change as a
|
||||||
|
result of the progressive updates.
|
||||||
|
"""
|
||||||
|
self.select_related = False
|
||||||
|
self.clear_ordering(True)
|
||||||
|
super(UpdateQuery, self).pre_sql_setup()
|
||||||
|
count = self.count_active_tables()
|
||||||
|
if not self.related_updates and count == 1:
|
||||||
|
return
|
||||||
|
|
||||||
|
# We need to use a sub-select in the where clause to filter on things
|
||||||
|
# from other tables.
|
||||||
|
query = self.clone(klass=Query)
|
||||||
|
main_alias = query.tables[0]
|
||||||
|
if count != 1:
|
||||||
|
query.unref_alias(main_alias)
|
||||||
|
if query.alias_map[main_alias][ALIAS_REFCOUNT]:
|
||||||
|
alias = '%s0' % self.alias_prefix
|
||||||
|
query.change_alias(main_alias, alias)
|
||||||
|
col = query.model._meta.pk.column
|
||||||
|
else:
|
||||||
|
for model in query.model._meta.get_parent_list():
|
||||||
|
for alias in query.table_map.get(model._meta.db_table, []):
|
||||||
|
if query.alias_map[alias][ALIAS_REFCOUNT]:
|
||||||
|
col = model._meta.pk.column
|
||||||
|
break
|
||||||
|
query.add_local_columns([col])
|
||||||
|
|
||||||
|
# Now we adjust the current query: reset the where clause and get rid
|
||||||
|
# of all the tables we don't need (since they're in the sub-select).
|
||||||
|
self.where = self.where_class()
|
||||||
|
if self.related_updates:
|
||||||
|
idents = []
|
||||||
|
for rows in query.execute_sql(MULTI):
|
||||||
|
idents.extend([r[0] for r in rows])
|
||||||
|
self.add_filter(('pk__in', idents))
|
||||||
|
self.related_ids = idents
|
||||||
|
else:
|
||||||
|
self.add_filter(('pk__in', query))
|
||||||
|
for alias in self.tables[1:]:
|
||||||
|
self.alias_map[alias][ALIAS_REFCOUNT] = 0
|
||||||
|
|
||||||
def clear_related(self, related_field, pk_list):
|
def clear_related(self, related_field, pk_list):
|
||||||
"""
|
"""
|
||||||
Set up and execute an update query that clears related entries for the
|
Set up and execute an update query that clears related entries for the
|
||||||
@ -156,12 +208,43 @@ class UpdateQuery(Query):
|
|||||||
for name, val in values.items():
|
for name, val in values.items():
|
||||||
field, model, direct, m2m = self.model._meta.get_field_by_name(name)
|
field, model, direct, m2m = self.model._meta.get_field_by_name(name)
|
||||||
if not direct or m2m:
|
if not direct or m2m:
|
||||||
# Can only update non-relation fields and foreign keys.
|
|
||||||
raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field)
|
raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field)
|
||||||
|
# FIXME: Some sort of db_prep_* is probably more appropriate here.
|
||||||
if field.rel and isinstance(val, Model):
|
if field.rel and isinstance(val, Model):
|
||||||
val = val.pk
|
val = val.pk
|
||||||
|
if model:
|
||||||
|
self.add_related_update(model, field.column, val)
|
||||||
|
else:
|
||||||
self.values.append((field.column, val))
|
self.values.append((field.column, val))
|
||||||
|
|
||||||
|
def add_related_update(self, model, column, value):
|
||||||
|
"""
|
||||||
|
Adds (name, value) to an update query for an ancestor model.
|
||||||
|
|
||||||
|
Updates are coalesced so that we only run one update query per ancestor.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.related_updates[model].append((column, value))
|
||||||
|
except KeyError:
|
||||||
|
self.related_updates[model] = [(column, value)]
|
||||||
|
|
||||||
|
def get_related_updates(self):
|
||||||
|
"""
|
||||||
|
Returns a list of query objects: one for each update required to an
|
||||||
|
ancestor model. Each query will have the same filtering conditions as
|
||||||
|
the current query but will only update a single table.
|
||||||
|
"""
|
||||||
|
if not self.related_updates:
|
||||||
|
return []
|
||||||
|
result = []
|
||||||
|
for model, values in self.related_updates.items():
|
||||||
|
query = UpdateQuery(model, self.connection)
|
||||||
|
query.values = values
|
||||||
|
if self.related_ids:
|
||||||
|
query.add_filter(('pk__in', self.related_ids))
|
||||||
|
result.append(query)
|
||||||
|
return result
|
||||||
|
|
||||||
class InsertQuery(Query):
|
class InsertQuery(Query):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(InsertQuery, self).__init__(*args, **kwargs)
|
super(InsertQuery, self).__init__(*args, **kwargs)
|
||||||
|
@ -216,5 +216,13 @@ DoesNotExist: Restaurant matching query does not exist.
|
|||||||
>>> Restaurant.objects.get(lot__name='Well Lit')
|
>>> Restaurant.objects.get(lot__name='Well Lit')
|
||||||
<Restaurant: Ristorante Miron the restaurant>
|
<Restaurant: Ristorante Miron the restaurant>
|
||||||
|
|
||||||
|
# The update() command can update fields in parent and child classes at once
|
||||||
|
# (although it executed multiple SQL queries to do so).
|
||||||
|
>>> Restaurant.objects.filter(serves_hot_dogs=True, name__contains='D').update(name='Demon Puppies', serves_hot_dogs=False)
|
||||||
|
>>> r1 = Restaurant.objects.get(pk=r.pk)
|
||||||
|
>>> r1.serves_hot_dogs == False
|
||||||
|
True
|
||||||
|
>>> r1.name
|
||||||
|
u'Demon Puppies'
|
||||||
|
|
||||||
"""}
|
"""}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user