1
0
mirror of https://github.com/django/django.git synced 2025-10-24 14:16:09 +00:00

Refs #28900 -- Made SELECT respect the order specified by values(*selected).

Previously the order was always extra_fields + model_fields + annotations with
respective local ordering inferred from the insertion order of *selected.

This commits introduces a new `Query.selected` propery that keeps tracks of the
global select order as specified by on values assignment. This is crucial
feature to allow the combination of queries mixing annotations and table
references.

It also allows the removal of the re-ordering shenanigans perform by
ValuesListIterable in order to re-map the tuples returned from the database
backend to the order specified by values_list() as they'll be in the right
order at query compilation time.

Refs #28553 as the initially reported issue that was only partially fixed
for annotations by d6b6e5d0fd.

Thanks Mariusz Felisiak and Sarah Boyce for review.
This commit is contained in:
Simon Charette
2023-03-28 00:13:00 -04:00
committed by Sarah Boyce
parent 2e47dde438
commit 65ad4ade74
10 changed files with 125 additions and 62 deletions

View File

@@ -200,6 +200,9 @@ class ValuesIterable(BaseIterable):
query = queryset.query
compiler = query.get_compiler(queryset.db)
if query.selected:
names = list(query.selected)
else:
# extra(select=...) cols are always at the start of the row.
names = [
*query.extra_select,
@@ -223,28 +226,6 @@ class ValuesListIterable(BaseIterable):
queryset = self.queryset
query = queryset.query
compiler = query.get_compiler(queryset.db)
if queryset._fields:
# extra(select=...) cols are always at the start of the row.
names = [
*query.extra_select,
*query.values_select,
*query.annotation_select,
]
fields = [
*queryset._fields,
*(f for f in query.annotation_select if f not in queryset._fields),
]
if fields != names:
# Reorder according to fields.
index_map = {name: idx for idx, name in enumerate(names)}
rowfactory = operator.itemgetter(*[index_map[f] for f in fields])
return map(
rowfactory,
compiler.results_iter(
chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
),
)
return compiler.results_iter(
tuple_expected=True,
chunked_fetch=self.chunked_fetch,

View File

@@ -247,11 +247,6 @@ class SQLCompiler:
select = []
klass_info = None
annotations = {}
select_idx = 0
for alias, (sql, params) in self.query.extra_select.items():
annotations[alias] = select_idx
select.append((RawSQL(sql, params), alias))
select_idx += 1
assert not (self.query.select and self.query.default_cols)
select_mask = self.query.get_select_mask()
if self.query.default_cols:
@@ -261,19 +256,39 @@ class SQLCompiler:
# any model.
cols = self.query.select
if cols:
select_list = []
for col in cols:
select_list.append(select_idx)
select.append((col, None))
select_idx += 1
klass_info = {
"model": self.query.model,
"select_fields": select_list,
"select_fields": list(
range(
len(self.query.extra_select),
len(self.query.extra_select) + len(cols),
)
),
}
for alias, annotation in self.query.annotation_select.items():
selected = []
if self.query.selected is None:
selected = [
*(
(alias, RawSQL(*args))
for alias, args in self.query.extra_select.items()
),
*((None, col) for col in cols),
*self.query.annotation_select.items(),
]
else:
for alias, expression in self.query.selected.items():
# Reference to an annotation.
if isinstance(expression, str):
expression = self.query.annotations[expression]
# Reference to a column.
elif isinstance(expression, int):
expression = cols[expression]
selected.append((alias, expression))
for select_idx, (alias, expression) in enumerate(selected):
if alias:
annotations[alias] = select_idx
select.append((annotation, alias))
select_idx += 1
select.append((expression, alias))
if self.query.select_related:
related_klass_infos = self.get_related_selections(select, select_mask)

View File

@@ -26,6 +26,7 @@ from django.db.models.expressions import (
Exists,
F,
OuterRef,
RawSQL,
Ref,
ResolvedOuterRef,
Value,
@@ -265,6 +266,7 @@ class Query(BaseExpression):
# Holds the selects defined by a call to values() or values_list()
# excluding annotation_select and extra_select.
values_select = ()
selected = None
# SQL annotation-related attributes.
annotation_select_mask = None
@@ -584,6 +586,7 @@ class Query(BaseExpression):
else:
outer_query = self
self.select = ()
self.selected = None
self.default_cols = False
self.extra = {}
if self.annotations:
@@ -1194,13 +1197,10 @@ class Query(BaseExpression):
if select:
self.append_annotation_mask([alias])
else:
annotation_mask = (
value
for value in dict.fromkeys(self.annotation_select)
if value != alias
)
self.set_annotation_mask(annotation_mask)
self.set_annotation_mask(set(self.annotation_select).difference({alias}))
self.annotations[alias] = annotation
if self.selected:
self.selected[alias] = alias
def resolve_expression(self, query, *args, **kwargs):
clone = self.clone()
@@ -2153,6 +2153,7 @@ class Query(BaseExpression):
self.select_related = False
self.set_extra_mask(())
self.set_annotation_mask(())
self.selected = None
def clear_select_fields(self):
"""
@@ -2162,10 +2163,12 @@ class Query(BaseExpression):
"""
self.select = ()
self.values_select = ()
self.selected = None
def add_select_col(self, col, name):
self.select += (col,)
self.values_select += (name,)
self.selected[name] = len(self.select) - 1
def set_select(self, cols):
self.default_cols = False
@@ -2416,12 +2419,23 @@ class Query(BaseExpression):
if names is None:
self.annotation_select_mask = None
else:
self.annotation_select_mask = list(dict.fromkeys(names))
self.annotation_select_mask = set(names)
if self.selected:
# Prune the masked annotations.
self.selected = {
key: value
for key, value in self.selected.items()
if not isinstance(value, str)
or value in self.annotation_select_mask
}
# Append the unmasked annotations.
for name in names:
self.selected[name] = name
self._annotation_select_cache = None
def append_annotation_mask(self, names):
if self.annotation_select_mask is not None:
self.set_annotation_mask((*self.annotation_select_mask, *names))
self.set_annotation_mask(self.annotation_select_mask.union(names))
def set_extra_mask(self, names):
"""
@@ -2440,6 +2454,7 @@ class Query(BaseExpression):
self.clear_select_fields()
self.has_select_fields = True
selected = {}
if fields:
field_names = []
extra_names = []
@@ -2448,13 +2463,16 @@ class Query(BaseExpression):
# Shortcut - if there are no extra or annotations, then
# the values() clause must be just field names.
field_names = list(fields)
selected = dict(zip(fields, range(len(fields))))
else:
self.default_cols = False
for f in fields:
if f in self.extra_select:
if extra := self.extra_select.get(f):
extra_names.append(f)
selected[f] = RawSQL(*extra)
elif f in self.annotation_select:
annotation_names.append(f)
selected[f] = f
elif f in self.annotations:
raise FieldError(
f"Cannot select the '{f}' alias. Use annotate() to "
@@ -2466,13 +2484,13 @@ class Query(BaseExpression):
# `f` is not resolvable.
if self.annotation_select:
self.names_to_path(f.split(LOOKUP_SEP), self.model._meta)
selected[f] = len(field_names)
field_names.append(f)
self.set_extra_mask(extra_names)
self.set_annotation_mask(annotation_names)
selected = frozenset(field_names + extra_names + annotation_names)
else:
field_names = [f.attname for f in self.model._meta.concrete_fields]
selected = frozenset(field_names)
selected = dict.fromkeys(field_names, None)
# Selected annotations must be known before setting the GROUP BY
# clause.
if self.group_by is True:
@@ -2495,6 +2513,7 @@ class Query(BaseExpression):
self.values_select = tuple(field_names)
self.add_fields(field_names, True)
self.selected = selected if fields else None
@property
def annotation_select(self):
@@ -2508,9 +2527,9 @@ class Query(BaseExpression):
return {}
elif self.annotation_select_mask is not None:
self._annotation_select_cache = {
k: self.annotations[k]
for k in self.annotation_select_mask
if k in self.annotations
k: v
for k, v in self.annotations.items()
if k in self.annotation_select_mask
}
return self._annotation_select_cache
else:

View File

@@ -745,6 +745,11 @@ You can also refer to fields on related models with reverse relations through
``"true"``, ``"false"``, and ``"null"`` strings for
:class:`~django.db.models.JSONField` key transforms.
.. versionchanged:: 5.2
The ``SELECT`` clause generated when using ``values()`` was updated to
respect the order of the specified ``*fields`` and ``**expressions``.
``values_list()``
~~~~~~~~~~~~~~~~~
@@ -835,6 +840,11 @@ not having any author:
``"true"``, ``"false"``, and ``"null"`` strings for
:class:`~django.db.models.JSONField` key transforms.
.. versionchanged:: 5.2
The ``SELECT`` clause generated when using ``values_list()`` was updated to
respect the order of the specified ``*fields``.
``dates()``
~~~~~~~~~~~

View File

@@ -195,7 +195,13 @@ Migrations
Models
~~~~~~
* ...
* The ``SELECT`` clause generated when using
:meth:`QuerySet.values()<django.db.models.query.QuerySet.values>` and
:meth:`~django.db.models.query.QuerySet.values_list` now matches the
specified order of the referenced expressions. Previously the order was based
of a set of counterintuitive rules which made query combination through
methods such as
:meth:`QuerySet.union()<django.db.models.query.QuerySet.union>` unpredictable.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -96,6 +96,7 @@ contenttypes
contrib
coroutine
coroutines
counterintuitive
criticals
cron
crontab

View File

@@ -568,6 +568,16 @@ class NonAggregateAnnotationTestCase(TestCase):
self.assertEqual(book["other_rating"], 4)
self.assertEqual(book["other_isbn"], "155860191")
def test_values_fields_annotations_order(self):
qs = Book.objects.annotate(other_rating=F("rating") - 1).values(
"other_rating", "rating"
)
book = qs.get(pk=self.b1.pk)
self.assertEqual(
list(book.items()),
[("other_rating", self.b1.rating - 1), ("rating", self.b1.rating)],
)
def test_values_with_pk_annotation(self):
# annotate references a field in values() with pk
publishers = Publisher.objects.values("id", "book__rating").annotate(

View File

@@ -466,8 +466,8 @@ class TestQuerying(PostgreSQLTestCase):
],
)
sql = ctx[0]["sql"]
self.assertIn("GROUP BY 2", sql)
self.assertIn("ORDER BY 2", sql)
self.assertIn("GROUP BY 1", sql)
self.assertIn("ORDER BY 1", sql)
def test_order_by_arrayagg_index(self):
qs = (

View File

@@ -257,6 +257,23 @@ class QuerySetSetOperationTests(TestCase):
)
self.assertCountEqual(qs1.union(qs2), [(1, 0), (1, 2)])
def test_union_with_field_and_annotation_values(self):
qs1 = (
Number.objects.filter(num=1)
.annotate(
zero=Value(0, IntegerField()),
)
.values_list("num", "zero")
)
qs2 = (
Number.objects.filter(num=2)
.annotate(
zero=Value(0, IntegerField()),
)
.values_list("zero", "num")
)
self.assertCountEqual(qs1.union(qs2), [(1, 0), (0, 2)])
def test_union_with_extra_and_values_list(self):
qs1 = (
Number.objects.filter(num=1)
@@ -265,7 +282,11 @@ class QuerySetSetOperationTests(TestCase):
)
.values_list("num", "count")
)
qs2 = Number.objects.filter(num=2).extra(select={"count": 1})
qs2 = (
Number.objects.filter(num=2)
.extra(select={"count": 1})
.values_list("num", "count")
)
self.assertCountEqual(qs1.union(qs2), [(1, 0), (2, 1)])
def test_union_with_values_list_on_annotated_and_unannotated(self):

View File

@@ -2200,7 +2200,7 @@ class Queries6Tests(TestCase):
{"tag_per_parent__max": 2},
)
sql = captured_queries[0]["sql"]
self.assertIn("AS %s" % connection.ops.quote_name("col1"), sql)
self.assertIn("AS %s" % connection.ops.quote_name("parent"), sql)
def test_xor_subquery(self):
self.assertSequenceEqual(