from django.db import DatabaseError, NotSupportedError, connection from django.db.models import Exists, F, IntegerField, OuterRef, Value from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from .models import Number, ReservedName @skipUnlessDBFeature('supports_select_union') class QuerySetSetOperationTests(TestCase): @classmethod def setUpTestData(cls): Number.objects.bulk_create(Number(num=i, other_num=10 - i) for i in range(10)) def number_transform(self, value): return value.num def assertNumbersEqual(self, queryset, expected_numbers, ordered=True): self.assertQuerysetEqual(queryset, expected_numbers, self.number_transform, ordered) def test_simple_union(self): qs1 = Number.objects.filter(num__lte=1) qs2 = Number.objects.filter(num__gte=8) qs3 = Number.objects.filter(num=5) self.assertNumbersEqual(qs1.union(qs2, qs3), [0, 1, 5, 8, 9], ordered=False) @skipUnlessDBFeature('supports_select_intersection') def test_simple_intersection(self): qs1 = Number.objects.filter(num__lte=5) qs2 = Number.objects.filter(num__gte=5) qs3 = Number.objects.filter(num__gte=4, num__lte=6) self.assertNumbersEqual(qs1.intersection(qs2, qs3), [5], ordered=False) @skipUnlessDBFeature('supports_select_intersection') def test_intersection_with_values(self): ReservedName.objects.create(name='a', order=2) qs1 = ReservedName.objects.all() reserved_name = qs1.intersection(qs1).values('name', 'order', 'id').get() self.assertEqual(reserved_name['name'], 'a') self.assertEqual(reserved_name['order'], 2) reserved_name = qs1.intersection(qs1).values_list('name', 'order', 'id').get() self.assertEqual(reserved_name[:2], ('a', 2)) @skipUnlessDBFeature('supports_select_difference') def test_simple_difference(self): qs1 = Number.objects.filter(num__lte=5) qs2 = Number.objects.filter(num__lte=4) self.assertNumbersEqual(qs1.difference(qs2), [5], ordered=False) def test_union_distinct(self): qs1 = Number.objects.all() qs2 = Number.objects.all() self.assertEqual(len(list(qs1.union(qs2, all=True))), 20) self.assertEqual(len(list(qs1.union(qs2))), 10) @skipUnlessDBFeature('supports_select_intersection') def test_intersection_with_empty_qs(self): qs1 = Number.objects.all() qs2 = Number.objects.none() qs3 = Number.objects.filter(pk__in=[]) self.assertEqual(len(qs1.intersection(qs2)), 0) self.assertEqual(len(qs1.intersection(qs3)), 0) self.assertEqual(len(qs2.intersection(qs1)), 0) self.assertEqual(len(qs3.intersection(qs1)), 0) self.assertEqual(len(qs2.intersection(qs2)), 0) self.assertEqual(len(qs3.intersection(qs3)), 0) @skipUnlessDBFeature('supports_select_difference') def test_difference_with_empty_qs(self): qs1 = Number.objects.all() qs2 = Number.objects.none() qs3 = Number.objects.filter(pk__in=[]) self.assertEqual(len(qs1.difference(qs2)), 10) self.assertEqual(len(qs1.difference(qs3)), 10) self.assertEqual(len(qs2.difference(qs1)), 0) self.assertEqual(len(qs3.difference(qs1)), 0) self.assertEqual(len(qs2.difference(qs2)), 0) self.assertEqual(len(qs3.difference(qs3)), 0) @skipUnlessDBFeature('supports_select_difference') def test_difference_with_values(self): ReservedName.objects.create(name='a', order=2) qs1 = ReservedName.objects.all() qs2 = ReservedName.objects.none() reserved_name = qs1.difference(qs2).values('name', 'order', 'id').get() self.assertEqual(reserved_name['name'], 'a') self.assertEqual(reserved_name['order'], 2) reserved_name = qs1.difference(qs2).values_list('name', 'order', 'id').get() self.assertEqual(reserved_name[:2], ('a', 2)) def test_union_with_empty_qs(self): qs1 = Number.objects.all() qs2 = Number.objects.none() qs3 = Number.objects.filter(pk__in=[]) self.assertEqual(len(qs1.union(qs2)), 10) self.assertEqual(len(qs2.union(qs1)), 10) self.assertEqual(len(qs1.union(qs3)), 10) self.assertEqual(len(qs3.union(qs1)), 10) self.assertEqual(len(qs2.union(qs1, qs1, qs1)), 10) self.assertEqual(len(qs2.union(qs1, qs1, all=True)), 20) self.assertEqual(len(qs2.union(qs2)), 0) self.assertEqual(len(qs3.union(qs3)), 0) def test_limits(self): qs1 = Number.objects.all() qs2 = Number.objects.all() self.assertEqual(len(list(qs1.union(qs2)[:2])), 2) def test_ordering(self): qs1 = Number.objects.filter(num__lte=1) qs2 = Number.objects.filter(num__gte=2, num__lte=3) self.assertNumbersEqual(qs1.union(qs2).order_by('-num'), [3, 2, 1, 0]) def test_ordering_by_f_expression(self): qs1 = Number.objects.filter(num__lte=1) qs2 = Number.objects.filter(num__gte=2, num__lte=3) self.assertNumbersEqual(qs1.union(qs2).order_by(F('num').desc()), [3, 2, 1, 0]) def test_union_with_values(self): ReservedName.objects.create(name='a', order=2) qs1 = ReservedName.objects.all() reserved_name = qs1.union(qs1).values('name', 'order', 'id').get() self.assertEqual(reserved_name['name'], 'a') self.assertEqual(reserved_name['order'], 2) reserved_name = qs1.union(qs1).values_list('name', 'order', 'id').get() self.assertEqual(reserved_name[:2], ('a', 2)) # List of columns can be changed. reserved_name = qs1.union(qs1).values_list('order').get() self.assertEqual(reserved_name, (2,)) def test_union_with_two_annotated_values_list(self): qs1 = Number.objects.filter(num=1).annotate( count=Value(0, IntegerField()), ).values_list('num', 'count') qs2 = Number.objects.filter(num=2).values('pk').annotate( count=F('num'), ).annotate( num=Value(1, IntegerField()), ).values_list('num', 'count') self.assertCountEqual(qs1.union(qs2), [(1, 0), (2, 1)]) def test_union_with_extra_and_values_list(self): qs1 = Number.objects.filter(num=1).extra( select={'count': 0}, ).values_list('num', 'count') qs2 = Number.objects.filter(num=2).extra(select={'count': 1}) self.assertCountEqual(qs1.union(qs2), [(1, 0), (2, 1)]) def test_union_with_values_list_on_annotated_and_unannotated(self): ReservedName.objects.create(name='rn1', order=1) qs1 = Number.objects.annotate( has_reserved_name=Exists(ReservedName.objects.filter(order=OuterRef('num'))) ).filter(has_reserved_name=True) qs2 = Number.objects.filter(num=9) self.assertCountEqual(qs1.union(qs2).values_list('num', flat=True), [1, 9]) def test_union_with_values_list_and_order(self): ReservedName.objects.bulk_create([ ReservedName(name='rn1', order=7), ReservedName(name='rn2', order=5), ReservedName(name='rn0', order=6), ReservedName(name='rn9', order=-1), ]) qs1 = ReservedName.objects.filter(order__gte=6) qs2 = ReservedName.objects.filter(order__lte=5) union_qs = qs1.union(qs2) for qs, expected_result in ( # Order by a single column. (union_qs.order_by('-pk').values_list('order', flat=True), [-1, 6, 5, 7]), (union_qs.order_by('pk').values_list('order', flat=True), [7, 5, 6, -1]), (union_qs.values_list('order', flat=True).order_by('-pk'), [-1, 6, 5, 7]), (union_qs.values_list('order', flat=True).order_by('pk'), [7, 5, 6, -1]), # Order by multiple columns. (union_qs.order_by('-name', 'pk').values_list('order', flat=True), [-1, 5, 7, 6]), (union_qs.values_list('order', flat=True).order_by('-name', 'pk'), [-1, 5, 7, 6]), ): with self.subTest(qs=qs): self.assertEqual(list(qs), expected_result) def test_count_union(self): qs1 = Number.objects.filter(num__lte=1).values('num') qs2 = Number.objects.filter(num__gte=2, num__lte=3).values('num') self.assertEqual(qs1.union(qs2).count(), 4) def test_count_union_empty_result(self): qs = Number.objects.filter(pk__in=[]) self.assertEqual(qs.union(qs).count(), 0) @skipUnlessDBFeature('supports_select_difference') def test_count_difference(self): qs1 = Number.objects.filter(num__lt=10) qs2 = Number.objects.filter(num__lt=9) self.assertEqual(qs1.difference(qs2).count(), 1) @skipUnlessDBFeature('supports_select_intersection') def test_count_intersection(self): qs1 = Number.objects.filter(num__gte=5) qs2 = Number.objects.filter(num__lte=5) self.assertEqual(qs1.intersection(qs2).count(), 1) @skipUnlessDBFeature('supports_slicing_ordering_in_compound') def test_ordering_subqueries(self): qs1 = Number.objects.order_by('num')[:2] qs2 = Number.objects.order_by('-num')[:2] self.assertNumbersEqual(qs1.union(qs2).order_by('-num')[:4], [9, 8, 1, 0]) @skipIfDBFeature('supports_slicing_ordering_in_compound') def test_unsupported_ordering_slicing_raises_db_error(self): qs1 = Number.objects.all() qs2 = Number.objects.all() msg = 'LIMIT/OFFSET not allowed in subqueries of compound statements' with self.assertRaisesMessage(DatabaseError, msg): list(qs1.union(qs2[:10])) msg = 'ORDER BY not allowed in subqueries of compound statements' with self.assertRaisesMessage(DatabaseError, msg): list(qs1.order_by('id').union(qs2)) @skipIfDBFeature('supports_select_intersection') def test_unsupported_intersection_raises_db_error(self): qs1 = Number.objects.all() qs2 = Number.objects.all() msg = 'intersection is not supported on this database backend' with self.assertRaisesMessage(NotSupportedError, msg): list(qs1.intersection(qs2)) def test_combining_multiple_models(self): ReservedName.objects.create(name='99 little bugs', order=99) qs1 = Number.objects.filter(num=1).values_list('num', flat=True) qs2 = ReservedName.objects.values_list('order') self.assertEqual(list(qs1.union(qs2).order_by('num')), [1, 99]) def test_order_raises_on_non_selected_column(self): qs1 = Number.objects.filter().annotate( annotation=Value(1, IntegerField()), ).values('annotation', num2=F('num')) qs2 = Number.objects.filter().values('id', 'num') # Should not raise list(qs1.union(qs2).order_by('annotation')) list(qs1.union(qs2).order_by('num2')) msg = 'ORDER BY term does not match any column in the result set' # 'id' is not part of the select with self.assertRaisesMessage(DatabaseError, msg): list(qs1.union(qs2).order_by('id')) # 'num' got realiased to num2 with self.assertRaisesMessage(DatabaseError, msg): list(qs1.union(qs2).order_by('num')) # switched order, now 'exists' again: list(qs2.union(qs1).order_by('num')) @skipUnlessDBFeature('supports_select_difference', 'supports_select_intersection') def test_qs_with_subcompound_qs(self): qs1 = Number.objects.all() qs2 = Number.objects.intersection(Number.objects.filter(num__gt=1)) self.assertEqual(qs1.difference(qs2).count(), 2) def test_order_by_same_type(self): qs = Number.objects.all() union = qs.union(qs) numbers = list(range(10)) self.assertNumbersEqual(union.order_by('num'), numbers) self.assertNumbersEqual(union.order_by('other_num'), reversed(numbers)) def test_unsupported_operations_on_combined_qs(self): qs = Number.objects.all() msg = 'Calling QuerySet.%s() after %s() is not supported.' combinators = ['union'] if connection.features.supports_select_difference: combinators.append('difference') if connection.features.supports_select_intersection: combinators.append('intersection') for combinator in combinators: for operation in ( 'annotate', 'defer', 'delete', 'distinct', 'exclude', 'extra', 'filter', 'only', 'prefetch_related', 'select_related', 'update', ): with self.subTest(combinator=combinator, operation=operation): with self.assertRaisesMessage( NotSupportedError, msg % (operation, combinator), ): getattr(getattr(qs, combinator)(qs), operation)()