diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index bb9a25e4d7..8d5724089c 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -856,21 +856,19 @@ class BasicExpressionsTests(TestCase): self.assertEqual(qs.get(), self.gmbh) def test_subquery_with_custom_template(self): - companies = Company.objects.annotate( - ceo_manager_count=Subquery( - Employee.objects.filter( - lastname=OuterRef("ceo__lastname"), - ).values("manager"), - template="(SELECT count(*) FROM (%(subquery)s) _count)", - ) + custom_subquery = Company.objects.filter( + ceo__salary__in=Subquery( + Employee.objects.all().values("salary"), + salary=20, + template="(SELECT salary FROM (%(subquery)s) _subquery " + "WHERE salary = %(salary)s)", + ), ) - expected_results = [ - {"name": "Example Inc.", "ceo_manager_count": 1}, - {"name": "Foobar Ltd.", "ceo_manager_count": 1}, - {"name": "Test GmbH", "ceo_manager_count": 1}, - ] - self.assertListEqual( - list(companies.values("name", "ceo_manager_count")), expected_results + expected_companies = Company.objects.filter(name="Foobar Ltd.") + self.assertQuerySetEqual( + custom_subquery.order_by("name"), + expected_companies.order_by("name"), + transform=lambda x: x, ) def test_aggregate_subquery_annotation(self): diff --git a/tests/expressions_case/tests.py b/tests/expressions_case/tests.py index c63cc3f252..8c740eda0c 100644 --- a/tests/expressions_case/tests.py +++ b/tests/expressions_case/tests.py @@ -726,16 +726,13 @@ class CaseExpressionTests(TestCase): case_expression = Case( When(integer=1, then=Value(10)), When(integer=2, then=Value(20)), - default=Value(0), - template="CASE %(cases)s ELSE %(default)s + 5 END", + custom_default=5, + template="CASE %(cases)s ELSE %(custom_default)s END", ) - self.assertListEqual( - list( - CaseTestModel.objects.annotate(values=case_expression).values_list( - "values", flat=True - ) - ), + self.assertQuerySetEqual( + CaseTestModel.objects.annotate(values=case_expression).order_by("pk"), [10, 20, 5, 20, 5, 5, 5], + transform=attrgetter("values"), ) def test_update(self):