1
0
mirror of https://github.com/django/django.git synced 2025-10-24 06:06:09 +00:00

Refs #27149 -- Fixed sql.Query identity.

By making Query subclass BaseExpression in
3543129822 the former defined it's
identity based off _construct_args which is not appropriate.
This commit is contained in:
Simon Charette
2020-10-25 17:41:06 -04:00
committed by Mariusz Felisiak
parent 556fa4bbba
commit bbf141bcdc
5 changed files with 106 additions and 13 deletions

View File

@@ -12,6 +12,7 @@ from django.db.models.query_utils import RegisterLookupMixin
from django.utils.datastructures import OrderedSet from django.utils.datastructures import OrderedSet
from django.utils.deprecation import RemovedInDjango40Warning from django.utils.deprecation import RemovedInDjango40Warning
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.hashable import make_hashable
class Lookup: class Lookup:
@@ -143,6 +144,18 @@ class Lookup:
def is_summary(self): def is_summary(self):
return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False) return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)
@property
def identity(self):
return self.__class__, self.lhs, self.rhs
def __eq__(self, other):
if not isinstance(other, Lookup):
return NotImplemented
return self.identity == other.identity
def __hash__(self):
return hash(make_hashable(self.identity))
class Transform(RegisterLookupMixin, Func): class Transform(RegisterLookupMixin, Func):
""" """

View File

@@ -114,17 +114,28 @@ class Join:
self.join_field, self.nullable, filtered_relation=filtered_relation, self.join_field, self.nullable, filtered_relation=filtered_relation,
) )
def equals(self, other, with_filtered_relation): @property
def identity(self):
return ( return (
isinstance(other, self.__class__) and self.__class__,
self.table_name == other.table_name and self.table_name,
self.parent_alias == other.parent_alias and self.parent_alias,
self.join_field == other.join_field and self.join_field,
(not with_filtered_relation or self.filtered_relation == other.filtered_relation) self.filtered_relation,
) )
def __eq__(self, other): def __eq__(self, other):
return self.equals(other, with_filtered_relation=True) if not isinstance(other, Join):
return NotImplemented
return self.identity == other.identity
def __hash__(self):
return hash(self.identity)
def equals(self, other, with_filtered_relation):
if with_filtered_relation:
return self == other
return self.identity[:-1] == other.identity[:-1]
def demote(self): def demote(self):
new = self.relabeled_clone({}) new = self.relabeled_clone({})
@@ -160,9 +171,17 @@ class BaseTable:
def relabeled_clone(self, change_map): def relabeled_clone(self, change_map):
return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias)) return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias))
@property
def identity(self):
return self.__class__, self.table_name, self.table_alias
def __eq__(self, other):
if not isinstance(other, BaseTable):
return NotImplemented
return self.identity == other.identity
def __hash__(self):
return hash(self.identity)
def equals(self, other, with_filtered_relation): def equals(self, other, with_filtered_relation):
return ( return self.identity == other.identity
isinstance(self, other.__class__) and
self.table_name == other.table_name and
self.table_alias == other.table_alias
)

View File

@@ -39,6 +39,7 @@ from django.db.models.sql.where import (
) )
from django.utils.deprecation import RemovedInDjango40Warning from django.utils.deprecation import RemovedInDjango40Warning
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.hashable import make_hashable
from django.utils.tree import Node from django.utils.tree import Node
__all__ = ['Query', 'RawQuery'] __all__ = ['Query', 'RawQuery']
@@ -246,6 +247,14 @@ class Query(BaseExpression):
for alias in self.alias_map: for alias in self.alias_map:
return alias return alias
@property
def identity(self):
identity = (
(arg, make_hashable(value))
for arg, value in self.__dict__.items()
)
return (self.__class__, *identity)
def __str__(self): def __str__(self):
""" """
Return the query as a string of SQL with the parameter values Return the query as a string of SQL with the parameter values

View File

@@ -1,10 +1,34 @@
from datetime import datetime from datetime import datetime
from unittest import mock
from django.db.models import DateTimeField, Value from django.db.models import DateTimeField, Value
from django.db.models.lookups import YearLookup from django.db.models.lookups import Lookup, YearLookup
from django.test import SimpleTestCase from django.test import SimpleTestCase
class CustomLookup(Lookup):
pass
class LookupTests(SimpleTestCase):
def test_equality(self):
lookup = Lookup(Value(1), Value(2))
self.assertEqual(lookup, lookup)
self.assertEqual(lookup, Lookup(lookup.lhs, lookup.rhs))
self.assertEqual(lookup, mock.ANY)
self.assertNotEqual(lookup, Lookup(lookup.lhs, Value(3)))
self.assertNotEqual(lookup, Lookup(Value(3), lookup.rhs))
self.assertNotEqual(lookup, CustomLookup(lookup.lhs, lookup.rhs))
def test_hash(self):
lookup = Lookup(Value(1), Value(2))
self.assertEqual(hash(lookup), hash(lookup))
self.assertEqual(hash(lookup), hash(Lookup(lookup.lhs, lookup.rhs)))
self.assertNotEqual(hash(lookup), hash(Lookup(lookup.lhs, Value(3))))
self.assertNotEqual(hash(lookup), hash(Lookup(Value(3), lookup.rhs)))
self.assertNotEqual(hash(lookup), hash(CustomLookup(lookup.lhs, lookup.rhs)))
class YearLookupTests(SimpleTestCase): class YearLookupTests(SimpleTestCase):
def test_get_bound_params(self): def test_get_bound_params(self):
look_up = YearLookup( look_up = YearLookup(

View File

@@ -150,3 +150,31 @@ class TestQuery(SimpleTestCase):
msg = 'Cannot filter against a non-conditional expression.' msg = 'Cannot filter against a non-conditional expression.'
with self.assertRaisesMessage(TypeError, msg): with self.assertRaisesMessage(TypeError, msg):
query.build_where(Func(output_field=CharField())) query.build_where(Func(output_field=CharField()))
def test_equality(self):
self.assertNotEqual(
Author.objects.all().query,
Author.objects.filter(item__name='foo').query,
)
self.assertEqual(
Author.objects.filter(item__name='foo').query,
Author.objects.filter(item__name='foo').query,
)
self.assertEqual(
Author.objects.filter(item__name='foo').query,
Author.objects.filter(Q(item__name='foo')).query,
)
def test_hash(self):
self.assertNotEqual(
hash(Author.objects.all().query),
hash(Author.objects.filter(item__name='foo').query)
)
self.assertEqual(
hash(Author.objects.filter(item__name='foo').query),
hash(Author.objects.filter(item__name='foo').query),
)
self.assertEqual(
hash(Author.objects.filter(item__name='foo').query),
hash(Author.objects.filter(Q(item__name='foo')).query),
)