From 5d91dc8ee3ff7e2230cfacf41005f61b8efbf057 Mon Sep 17 00:00:00 2001 From: Gagaro Date: Mon, 31 Jan 2022 16:04:13 +0100 Subject: [PATCH] Refs #30581 -- Added Q.check() hook. --- django/db/models/query_utils.py | 29 ++++++++++++++++++++++++ tests/queries/test_q.py | 40 ++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 746efce67f..9b5824f89f 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -8,12 +8,16 @@ circular import difficulties. import copy import functools import inspect +import logging from collections import namedtuple from django.core.exceptions import FieldError +from django.db import DEFAULT_DB_ALIAS, DatabaseError from django.db.models.constants import LOOKUP_SEP from django.utils import tree +logger = logging.getLogger("django.db.models") + # PathInfo is used when converting lookups (fk__somecol). The contents # describe the relation in Model terms (model Options and Fields for both # sides of the relation. The join_field is the field backing the relation. @@ -110,6 +114,31 @@ class Q(tree.Node): else: yield child + def check(self, against, using=DEFAULT_DB_ALIAS): + """ + Do a database query to check if the expressions of the Q instance + matches against the expressions. + """ + # Avoid circular imports. + from django.db.models import Value + from django.db.models.sql import Query + from django.db.models.sql.constants import SINGLE + + query = Query(None) + for name, value in against.items(): + if not hasattr(value, "resolve_expression"): + value = Value(value) + query.add_annotation(value, name, select=False) + query.add_annotation(Value(1), "_check") + # This will raise a FieldError if a field is missing in "against". + query.add_q(self) + compiler = query.get_compiler(using=using) + try: + return compiler.execute_sql(SINGLE) is not None + except DatabaseError as e: + logger.warning("Got a database error calling check() on %r: %s", self, e) + return True + def deconstruct(self): path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__) if path.startswith("django.db.models.query_utils"): diff --git a/tests/queries/test_q.py b/tests/queries/test_q.py index 42a00da3eb..4801eb4807 100644 --- a/tests/queries/test_q.py +++ b/tests/queries/test_q.py @@ -1,3 +1,4 @@ +from django.core.exceptions import FieldError from django.db.models import ( BooleanField, Exists, @@ -10,7 +11,7 @@ from django.db.models import ( from django.db.models.expressions import RawSQL from django.db.models.functions import Lower from django.db.models.sql.where import NothingNode -from django.test import SimpleTestCase +from django.test import SimpleTestCase, TestCase from .models import Tag @@ -214,3 +215,40 @@ class QTests(SimpleTestCase): ) flatten = list(q.flatten()) self.assertEqual(len(flatten), 7) + + +class QCheckTests(TestCase): + def test_basic(self): + q = Q(price__gt=20) + self.assertIs(q.check({"price": 30}), True) + self.assertIs(q.check({"price": 10}), False) + + def test_expression(self): + q = Q(name="test") + self.assertIs(q.check({"name": Lower(Value("TeSt"))}), True) + self.assertIs(q.check({"name": Value("other")}), False) + + def test_missing_field(self): + q = Q(description__startswith="prefix") + msg = "Cannot resolve keyword 'description' into field." + with self.assertRaisesMessage(FieldError, msg): + q.check({"name": "test"}) + + def test_boolean_expression(self): + q = Q(ExpressionWrapper(Q(price__gt=20), output_field=BooleanField())) + self.assertIs(q.check({"price": 25}), True) + self.assertIs(q.check({"price": Value(10)}), False) + + def test_rawsql(self): + """ + RawSQL expressions cause a database error because "price" cannot be + replaced by its value. In this case, Q.check() logs a warning and + return True. + """ + q = Q(RawSQL("price > %s", params=(20,), output_field=BooleanField())) + with self.assertLogs("django.db.models", "WARNING") as cm: + self.assertIs(q.check({"price": 10}), True) + self.assertIn( + f"Got a database error calling check() on {q!r}: ", + cm.records[0].getMessage(), + )