diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 9be17a4a84..1bf396723e 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -175,6 +175,19 @@ class Q(tree.Node): def __hash__(self): return hash(self.identity) + @cached_property + def referenced_base_fields(self): + """ + Retrieve all base fields referenced directly or through F expressions + excluding any fields referenced through joins. + """ + # Avoid circular imports. + from django.db.models.sql import query + + return { + child.split(LOOKUP_SEP, 1)[0] for child in query.get_children_from_q(self) + } + class DeferredAttribute: """ diff --git a/tests/queries/test_q.py b/tests/queries/test_q.py index d3bab1f2a0..f7192a430a 100644 --- a/tests/queries/test_q.py +++ b/tests/queries/test_q.py @@ -10,6 +10,7 @@ from django.db.models import ( ) from django.db.models.expressions import NegatedExpression, RawSQL from django.db.models.functions import Lower +from django.db.models.lookups import Exact, IsNull from django.db.models.sql.where import NothingNode from django.test import SimpleTestCase, TestCase @@ -263,6 +264,33 @@ class QTests(SimpleTestCase): Q(*items, _connector=connector), ) + def test_referenced_base_fields(self): + # Make sure Q.referenced_base_fields retrieves all base fields from + # both filters and F expressions. + tests = [ + (Q(field_1=1) & Q(field_2=1), {"field_1", "field_2"}), + ( + Q(Exact(F("field_3"), IsNull(F("field_4"), True))), + {"field_3", "field_4"}, + ), + (Q(Exact(Q(field_5=F("field_6")), True)), {"field_5", "field_6"}), + (Q(field_2=1), {"field_2"}), + (Q(field_7__lookup=True), {"field_7"}), + (Q(field_7__joined_field__lookup=True), {"field_7"}), + ] + combined_q = Q(1) + combined_q_base_fields = set() + for q, expected_base_fields in tests: + combined_q &= q + combined_q_base_fields |= expected_base_fields + tests.append((combined_q, combined_q_base_fields)) + for q, expected_base_fields in tests: + with self.subTest(q=q): + self.assertEqual( + q.referenced_base_fields, + expected_base_fields, + ) + class QCheckTests(TestCase): def test_basic(self):