From cecef94275118dc49a1b0d89d3ca9734e2ec9776 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Wed, 21 Sep 2016 22:06:09 -0400 Subject: [PATCH] Fixed #27257 -- Fixed builtin text lookups on JSONField keys. Thanks Nick Stefan for the report and Tim for the review. --- django/contrib/postgres/fields/jsonb.py | 67 +++++++++++++++++++++++-- tests/postgres_tests/test_json.py | 22 ++++++++ 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/django/contrib/postgres/fields/jsonb.py b/django/contrib/postgres/fields/jsonb.py index 1a6ad0831b..b4574cc5aa 100644 --- a/django/contrib/postgres/fields/jsonb.py +++ b/django/contrib/postgres/fields/jsonb.py @@ -4,7 +4,9 @@ from psycopg2.extras import Json from django.contrib.postgres import forms, lookups from django.core import exceptions -from django.db.models import Field, Transform +from django.db.models import ( + Field, TextField, Transform, lookups as builtin_lookups, +) from django.utils.translation import ugettext_lazy as _ __all__ = ['JSONField'] @@ -86,6 +88,8 @@ JSONField.register_lookup(lookups.HasAnyKeys) class KeyTransform(Transform): + operator = '->' + nested_operator = '#>' def __init__(self, key_name, *args, **kwargs): super(KeyTransform, self).__init__(*args, **kwargs) @@ -99,14 +103,71 @@ class KeyTransform(Transform): previous = previous.lhs lhs, params = compiler.compile(previous) if len(key_transforms) > 1: - return "{} #> %s".format(lhs), [key_transforms] + params + return "(%s %s %%s)" % (lhs, self.nested_operator), [key_transforms] + params try: int(self.key_name) except ValueError: lookup = "'%s'" % self.key_name else: lookup = "%s" % self.key_name - return "(%s -> %s)" % (lhs, lookup), params + return "(%s %s %s)" % (lhs, self.operator, lookup), params + + +class KeyTextTransform(KeyTransform): + operator = '->>' + nested_operator = '#>>' + _output_field = TextField() + + +class KeyTransformTextLookupMixin(object): + """ + Mixin for combining with a lookup expecting a text lhs from a JSONField + key lookup. Make use of the ->> operator instead of casting key values to + text and performing the lookup on the resulting representation. + """ + def __init__(self, key_transform, *args, **kwargs): + assert isinstance(key_transform, KeyTransform) + key_text_transform = KeyTextTransform( + key_transform.key_name, *key_transform.source_expressions, **key_transform.extra + ) + super(KeyTransformTextLookupMixin, self).__init__(key_text_transform, *args, **kwargs) + + +class KeyTransformIContains(KeyTransformTextLookupMixin, builtin_lookups.IContains): + pass + + +class KeyTransformStartsWith(KeyTransformTextLookupMixin, builtin_lookups.StartsWith): + pass + + +class KeyTransformIStartsWith(KeyTransformTextLookupMixin, builtin_lookups.IStartsWith): + pass + + +class KeyTransformEndsWith(KeyTransformTextLookupMixin, builtin_lookups.EndsWith): + pass + + +class KeyTransformIEndsWith(KeyTransformTextLookupMixin, builtin_lookups.IEndsWith): + pass + + +class KeyTransformRegex(KeyTransformTextLookupMixin, builtin_lookups.Regex): + pass + + +class KeyTransformIRegex(KeyTransformTextLookupMixin, builtin_lookups.IRegex): + pass + + +KeyTransform.register_lookup(KeyTransformIContains) +KeyTransform.register_lookup(KeyTransformStartsWith) +KeyTransform.register_lookup(KeyTransformIStartsWith) +KeyTransform.register_lookup(KeyTransformEndsWith) +KeyTransform.register_lookup(KeyTransformIEndsWith) +KeyTransform.register_lookup(KeyTransformRegex) +KeyTransform.register_lookup(KeyTransformIRegex) class KeyTransformFactory(object): diff --git a/tests/postgres_tests/test_json.py b/tests/postgres_tests/test_json.py index f306284b7f..78dded31a9 100644 --- a/tests/postgres_tests/test_json.py +++ b/tests/postgres_tests/test_json.py @@ -124,6 +124,7 @@ class TestQuerying(PostgreSQLTestCase): 'k': True, 'l': False, }), + JSONModel.objects.create(field={'foo': 'bar'}), ] def test_exact(self): @@ -237,6 +238,27 @@ class TestQuerying(PostgreSQLTestCase): self.objs[7:9] ) + def test_icontains(self): + self.assertFalse(JSONModel.objects.filter(field__foo__icontains='"bar"').exists()) + + def test_startswith(self): + self.assertTrue(JSONModel.objects.filter(field__foo__startswith='b').exists()) + + def test_istartswith(self): + self.assertTrue(JSONModel.objects.filter(field__foo__istartswith='B').exists()) + + def test_endswith(self): + self.assertTrue(JSONModel.objects.filter(field__foo__endswith='r').exists()) + + def test_iendswith(self): + self.assertTrue(JSONModel.objects.filter(field__foo__iendswith='R').exists()) + + def test_regex(self): + self.assertTrue(JSONModel.objects.filter(field__foo__regex=r'^bar$').exists()) + + def test_iregex(self): + self.assertTrue(JSONModel.objects.filter(field__foo__iregex=r'^bAr$').exists()) + @skipUnlessDBFeature('has_jsonb_datatype') class TestSerialization(PostgreSQLTestCase):