mirror of
https://github.com/django/django.git
synced 2024-12-22 09:05:43 +00:00
Refs #27149 -- Fixed sql.Query identity.
By making Query subclass BaseExpression in
3543129822
the former defined it's
identity based off _construct_args which is not appropriate.
This commit is contained in:
parent
556fa4bbba
commit
bbf141bcdc
@ -12,6 +12,7 @@ from django.db.models.query_utils import RegisterLookupMixin
|
||||
from django.utils.datastructures import OrderedSet
|
||||
from django.utils.deprecation import RemovedInDjango40Warning
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.hashable import make_hashable
|
||||
|
||||
|
||||
class Lookup:
|
||||
@ -143,6 +144,18 @@ class Lookup:
|
||||
def is_summary(self):
|
||||
return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return self.__class__, self.lhs, self.rhs
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Lookup):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(make_hashable(self.identity))
|
||||
|
||||
|
||||
class Transform(RegisterLookupMixin, Func):
|
||||
"""
|
||||
|
@ -114,17 +114,28 @@ class Join:
|
||||
self.join_field, self.nullable, filtered_relation=filtered_relation,
|
||||
)
|
||||
|
||||
def equals(self, other, with_filtered_relation):
|
||||
@property
|
||||
def identity(self):
|
||||
return (
|
||||
isinstance(other, self.__class__) and
|
||||
self.table_name == other.table_name and
|
||||
self.parent_alias == other.parent_alias and
|
||||
self.join_field == other.join_field and
|
||||
(not with_filtered_relation or self.filtered_relation == other.filtered_relation)
|
||||
self.__class__,
|
||||
self.table_name,
|
||||
self.parent_alias,
|
||||
self.join_field,
|
||||
self.filtered_relation,
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.equals(other, with_filtered_relation=True)
|
||||
if not isinstance(other, Join):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def equals(self, other, with_filtered_relation):
|
||||
if with_filtered_relation:
|
||||
return self == other
|
||||
return self.identity[:-1] == other.identity[:-1]
|
||||
|
||||
def demote(self):
|
||||
new = self.relabeled_clone({})
|
||||
@ -160,9 +171,17 @@ class BaseTable:
|
||||
def relabeled_clone(self, change_map):
|
||||
return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias))
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return self.__class__, self.table_name, self.table_alias
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, BaseTable):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def equals(self, other, with_filtered_relation):
|
||||
return (
|
||||
isinstance(self, other.__class__) and
|
||||
self.table_name == other.table_name and
|
||||
self.table_alias == other.table_alias
|
||||
)
|
||||
return self.identity == other.identity
|
||||
|
@ -39,6 +39,7 @@ from django.db.models.sql.where import (
|
||||
)
|
||||
from django.utils.deprecation import RemovedInDjango40Warning
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.hashable import make_hashable
|
||||
from django.utils.tree import Node
|
||||
|
||||
__all__ = ['Query', 'RawQuery']
|
||||
@ -246,6 +247,14 @@ class Query(BaseExpression):
|
||||
for alias in self.alias_map:
|
||||
return alias
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
identity = (
|
||||
(arg, make_hashable(value))
|
||||
for arg, value in self.__dict__.items()
|
||||
)
|
||||
return (self.__class__, *identity)
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
Return the query as a string of SQL with the parameter values
|
||||
|
@ -1,10 +1,34 @@
|
||||
from datetime import datetime
|
||||
from unittest import mock
|
||||
|
||||
from django.db.models import DateTimeField, Value
|
||||
from django.db.models.lookups import YearLookup
|
||||
from django.db.models.lookups import Lookup, YearLookup
|
||||
from django.test import SimpleTestCase
|
||||
|
||||
|
||||
class CustomLookup(Lookup):
|
||||
pass
|
||||
|
||||
|
||||
class LookupTests(SimpleTestCase):
|
||||
def test_equality(self):
|
||||
lookup = Lookup(Value(1), Value(2))
|
||||
self.assertEqual(lookup, lookup)
|
||||
self.assertEqual(lookup, Lookup(lookup.lhs, lookup.rhs))
|
||||
self.assertEqual(lookup, mock.ANY)
|
||||
self.assertNotEqual(lookup, Lookup(lookup.lhs, Value(3)))
|
||||
self.assertNotEqual(lookup, Lookup(Value(3), lookup.rhs))
|
||||
self.assertNotEqual(lookup, CustomLookup(lookup.lhs, lookup.rhs))
|
||||
|
||||
def test_hash(self):
|
||||
lookup = Lookup(Value(1), Value(2))
|
||||
self.assertEqual(hash(lookup), hash(lookup))
|
||||
self.assertEqual(hash(lookup), hash(Lookup(lookup.lhs, lookup.rhs)))
|
||||
self.assertNotEqual(hash(lookup), hash(Lookup(lookup.lhs, Value(3))))
|
||||
self.assertNotEqual(hash(lookup), hash(Lookup(Value(3), lookup.rhs)))
|
||||
self.assertNotEqual(hash(lookup), hash(CustomLookup(lookup.lhs, lookup.rhs)))
|
||||
|
||||
|
||||
class YearLookupTests(SimpleTestCase):
|
||||
def test_get_bound_params(self):
|
||||
look_up = YearLookup(
|
||||
|
@ -150,3 +150,31 @@ class TestQuery(SimpleTestCase):
|
||||
msg = 'Cannot filter against a non-conditional expression.'
|
||||
with self.assertRaisesMessage(TypeError, msg):
|
||||
query.build_where(Func(output_field=CharField()))
|
||||
|
||||
def test_equality(self):
|
||||
self.assertNotEqual(
|
||||
Author.objects.all().query,
|
||||
Author.objects.filter(item__name='foo').query,
|
||||
)
|
||||
self.assertEqual(
|
||||
Author.objects.filter(item__name='foo').query,
|
||||
Author.objects.filter(item__name='foo').query,
|
||||
)
|
||||
self.assertEqual(
|
||||
Author.objects.filter(item__name='foo').query,
|
||||
Author.objects.filter(Q(item__name='foo')).query,
|
||||
)
|
||||
|
||||
def test_hash(self):
|
||||
self.assertNotEqual(
|
||||
hash(Author.objects.all().query),
|
||||
hash(Author.objects.filter(item__name='foo').query)
|
||||
)
|
||||
self.assertEqual(
|
||||
hash(Author.objects.filter(item__name='foo').query),
|
||||
hash(Author.objects.filter(item__name='foo').query),
|
||||
)
|
||||
self.assertEqual(
|
||||
hash(Author.objects.filter(item__name='foo').query),
|
||||
hash(Author.objects.filter(Q(item__name='foo')).query),
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user