mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	Fixed #28161 -- Fixed return type of ArrayField(CITextField()).
Thanks Tim for the review.
This commit is contained in:
		| @@ -5,7 +5,7 @@ from django.db.models import CharField, TextField | |||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  |  | ||||||
| from .lookups import SearchLookup, TrigramSimilar, Unaccent | from .lookups import SearchLookup, TrigramSimilar, Unaccent | ||||||
| from .signals import register_hstore_handler | from .signals import register_type_handlers | ||||||
|  |  | ||||||
|  |  | ||||||
| class PostgresConfig(AppConfig): | class PostgresConfig(AppConfig): | ||||||
| @@ -16,8 +16,8 @@ class PostgresConfig(AppConfig): | |||||||
|         # Connections may already exist before we are called. |         # Connections may already exist before we are called. | ||||||
|         for conn in connections.all(): |         for conn in connections.all(): | ||||||
|             if conn.connection is not None: |             if conn.connection is not None: | ||||||
|                 register_hstore_handler(conn) |                 register_type_handlers(conn) | ||||||
|         connection_created.connect(register_hstore_handler) |         connection_created.connect(register_type_handlers) | ||||||
|         CharField.register_lookup(Unaccent) |         CharField.register_lookup(Unaccent) | ||||||
|         TextField.register_lookup(Unaccent) |         TextField.register_lookup(Unaccent) | ||||||
|         CharField.register_lookup(SearchLookup) |         CharField.register_lookup(SearchLookup) | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| from django.contrib.postgres.signals import register_hstore_handler | from django.contrib.postgres.signals import register_type_handlers | ||||||
| from django.db.migrations.operations.base import Operation | from django.db.migrations.operations.base import Operation | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -15,6 +15,10 @@ class CreateExtension(Operation): | |||||||
|         if schema_editor.connection.vendor != 'postgresql': |         if schema_editor.connection.vendor != 'postgresql': | ||||||
|             return |             return | ||||||
|         schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name)) |         schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name)) | ||||||
|  |         # Registering new type handlers cannot be done before the extension is | ||||||
|  |         # installed, otherwise a subsequent data migration would use the same | ||||||
|  |         # connection. | ||||||
|  |         register_type_handlers(schema_editor.connection) | ||||||
|  |  | ||||||
|     def database_backwards(self, app_label, schema_editor, from_state, to_state): |     def database_backwards(self, app_label, schema_editor, from_state, to_state): | ||||||
|         schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name)) |         schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name)) | ||||||
| @@ -46,13 +50,6 @@ class HStoreExtension(CreateExtension): | |||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self.name = 'hstore' |         self.name = 'hstore' | ||||||
|  |  | ||||||
|     def database_forwards(self, app_label, schema_editor, from_state, to_state): |  | ||||||
|         super().database_forwards(app_label, schema_editor, from_state, to_state) |  | ||||||
|         # Register hstore straight away as it cannot be done before the |  | ||||||
|         # extension is installed, a subsequent data migration would use the |  | ||||||
|         # same connection |  | ||||||
|         register_hstore_handler(schema_editor.connection) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TrigramExtension(CreateExtension): | class TrigramExtension(CreateExtension): | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,8 +1,9 @@ | |||||||
|  | import psycopg2 | ||||||
| from psycopg2 import ProgrammingError | from psycopg2 import ProgrammingError | ||||||
| from psycopg2.extras import register_hstore | from psycopg2.extras import register_hstore | ||||||
|  |  | ||||||
|  |  | ||||||
| def register_hstore_handler(connection, **kwargs): | def register_type_handlers(connection, **kwargs): | ||||||
|     if connection.vendor != 'postgresql': |     if connection.vendor != 'postgresql': | ||||||
|         return |         return | ||||||
|  |  | ||||||
| @@ -18,3 +19,17 @@ def register_hstore_handler(connection, **kwargs): | |||||||
|         # This is also needed in order to create the connection in order to |         # This is also needed in order to create the connection in order to | ||||||
|         # install the hstore extension. |         # install the hstore extension. | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         with connection.cursor() as cursor: | ||||||
|  |             # Retrieve oids of citext arrays. | ||||||
|  |             cursor.execute("SELECT typarray FROM pg_type WHERE typname = 'citext'") | ||||||
|  |             oids = tuple(row[0] for row in cursor) | ||||||
|  |         array_type = psycopg2.extensions.new_array_type(oids, 'citext[]', psycopg2.STRING) | ||||||
|  |         psycopg2.extensions.register_type(array_type, None) | ||||||
|  |     except ProgrammingError: | ||||||
|  |         # citext is not available on the database. | ||||||
|  |         # | ||||||
|  |         # The same comments in the except block of the above call to | ||||||
|  |         # register_hstore() also apply here. | ||||||
|  |         pass | ||||||
|   | |||||||
| @@ -88,3 +88,6 @@ Bugfixes | |||||||
| * Fixed a regression where ``Model._state.db`` wasn't set correctly on | * Fixed a regression where ``Model._state.db`` wasn't set correctly on | ||||||
|   multi-table inheritance parent models after saving a child model |   multi-table inheritance parent models after saving a child model | ||||||
|   (:ticket:`28166`). |   (:ticket:`28166`). | ||||||
|  |  | ||||||
|  | * Corrected the return type of ``ArrayField(CITextField())`` values retrieved | ||||||
|  |   from the database (:ticket:`28161`). | ||||||
|   | |||||||
| @@ -12,14 +12,14 @@ class PostgreSQLTestCase(TestCase): | |||||||
|     @classmethod |     @classmethod | ||||||
|     def tearDownClass(cls): |     def tearDownClass(cls): | ||||||
|         # No need to keep that signal overhead for non PostgreSQL-related tests. |         # No need to keep that signal overhead for non PostgreSQL-related tests. | ||||||
|         from django.contrib.postgres.signals import register_hstore_handler |         from django.contrib.postgres.signals import register_type_handlers | ||||||
|  |  | ||||||
|         connection_created.disconnect(register_hstore_handler) |         connection_created.disconnect(register_type_handlers) | ||||||
|         super().tearDownClass() |         super().tearDownClass() | ||||||
|  |  | ||||||
|  |  | ||||||
| @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests") | @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests") | ||||||
| # To locate the widget's template. | # To locate the widget's template. | ||||||
| @modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) | @modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) | ||||||
| class PostgreSQLWidgetTestCase(WidgetTest): | class PostgreSQLWidgetTestCase(WidgetTest, PostgreSQLTestCase): | ||||||
|     pass |     pass | ||||||
|   | |||||||
| @@ -142,6 +142,7 @@ class Migration(migrations.Migration): | |||||||
|                 ('name', CICharField(primary_key=True, max_length=255)), |                 ('name', CICharField(primary_key=True, max_length=255)), | ||||||
|                 ('email', CIEmailField()), |                 ('email', CIEmailField()), | ||||||
|                 ('description', CITextField()), |                 ('description', CITextField()), | ||||||
|  |                 ('array_field', ArrayField(CITextField(), null=True)), | ||||||
|             ], |             ], | ||||||
|             options={ |             options={ | ||||||
|                 'required_db_vendor': 'postgresql', |                 'required_db_vendor': 'postgresql', | ||||||
|   | |||||||
| @@ -106,6 +106,7 @@ class CITestModel(PostgreSQLModel): | |||||||
|     name = CICharField(primary_key=True, max_length=255) |     name = CICharField(primary_key=True, max_length=255) | ||||||
|     email = CIEmailField() |     email = CIEmailField() | ||||||
|     description = CITextField() |     description = CITextField() | ||||||
|  |     array_field = ArrayField(CITextField(), null=True) | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return self.name |         return self.name | ||||||
|   | |||||||
| @@ -4,11 +4,13 @@ strings and thus eliminates the need for operations such as iexact and other | |||||||
| modifiers to enforce use of an index. | modifiers to enforce use of an index. | ||||||
| """ | """ | ||||||
| from django.db import IntegrityError | from django.db import IntegrityError | ||||||
|  | from django.test.utils import modify_settings | ||||||
|  |  | ||||||
| from . import PostgreSQLTestCase | from . import PostgreSQLTestCase | ||||||
| from .models import CITestModel | from .models import CITestModel | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) | ||||||
| class CITextTestCase(PostgreSQLTestCase): | class CITextTestCase(PostgreSQLTestCase): | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
| @@ -17,6 +19,7 @@ class CITextTestCase(PostgreSQLTestCase): | |||||||
|             name='JoHn', |             name='JoHn', | ||||||
|             email='joHn@johN.com', |             email='joHn@johN.com', | ||||||
|             description='Average Joe named JoHn', |             description='Average Joe named JoHn', | ||||||
|  |             array_field=['JoE', 'jOhn'], | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_equal_lowercase(self): |     def test_equal_lowercase(self): | ||||||
| @@ -34,3 +37,8 @@ class CITextTestCase(PostgreSQLTestCase): | |||||||
|         """ |         """ | ||||||
|         with self.assertRaises(IntegrityError): |         with self.assertRaises(IntegrityError): | ||||||
|             CITestModel.objects.create(name='John') |             CITestModel.objects.create(name='John') | ||||||
|  |  | ||||||
|  |     def test_array_field(self): | ||||||
|  |         instance = CITestModel.objects.get() | ||||||
|  |         self.assertEqual(instance.array_field, self.john.array_field) | ||||||
|  |         self.assertTrue(CITestModel.objects.filter(array_field__contains=['joe']).exists()) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user