1
0
mirror of https://github.com/django/django.git synced 2025-01-25 01:30:48 +00:00
django/tests/postgres_tests/test_apps.py
2023-09-25 08:17:03 +02:00

92 lines
3.6 KiB
Python

import unittest
from decimal import Decimal
from django.db import connection
from django.db.backends.signals import connection_created
from django.db.migrations.writer import MigrationWriter
from django.test import TestCase
from django.test.utils import CaptureQueriesContext, modify_settings, override_settings
try:
from django.contrib.postgres.fields import (
DateRangeField,
DateTimeRangeField,
DecimalRangeField,
IntegerRangeField,
)
from django.contrib.postgres.signals import get_hstore_oids
from django.db.backends.postgresql.psycopg_any import (
DateRange,
DateTimeRange,
DateTimeTZRange,
NumericRange,
is_psycopg3,
)
except ImportError:
pass
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgresConfigTests(TestCase):
def test_install_app_no_warning(self):
# Clear cache to force queries when (re)initializing the
# "django.contrib.postgres" app.
get_hstore_oids.cache_clear()
with CaptureQueriesContext(connection) as captured_queries:
with override_settings(INSTALLED_APPS=["django.contrib.postgres"]):
pass
self.assertGreaterEqual(len(captured_queries), 1)
def test_register_type_handlers_connection(self):
from django.contrib.postgres.signals import register_type_handlers
self.assertNotIn(
register_type_handlers, connection_created._live_receivers(None)[0]
)
with modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}):
self.assertIn(
register_type_handlers, connection_created._live_receivers(None)[0]
)
self.assertNotIn(
register_type_handlers, connection_created._live_receivers(None)[0]
)
def test_register_serializer_for_migrations(self):
tests = (
(DateRange(empty=True), DateRangeField),
(DateTimeRange(empty=True), DateRangeField),
(DateTimeTZRange(None, None, "[]"), DateTimeRangeField),
(NumericRange(Decimal("1.0"), Decimal("5.0"), "()"), DecimalRangeField),
(NumericRange(1, 10), IntegerRangeField),
)
def assertNotSerializable():
for default, test_field in tests:
with self.subTest(default=default):
field = test_field(default=default)
with self.assertRaisesMessage(
ValueError, "Cannot serialize: %s" % default.__class__.__name__
):
MigrationWriter.serialize(field)
assertNotSerializable()
import_name = "psycopg.types.range" if is_psycopg3 else "psycopg2.extras"
with self.modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}):
for default, test_field in tests:
with self.subTest(default=default):
field = test_field(default=default)
serialized_field, imports = MigrationWriter.serialize(field)
self.assertEqual(
imports,
{
"import django.contrib.postgres.fields.ranges",
f"import {import_name}",
},
)
self.assertIn(
f"{field.__module__}.{field.__class__.__name__}"
f"(default={import_name}.{default!r})",
serialized_field,
)
assertNotSerializable()