mirror of
https://github.com/django/django.git
synced 2024-12-22 17:16:24 +00:00
Fixed #31347 -- Checked allow_migrate() in CreateExtension operation.
This commit is contained in:
parent
d88365708c
commit
ec292f261d
@ -1,7 +1,7 @@
|
|||||||
from django.contrib.postgres.signals import (
|
from django.contrib.postgres.signals import (
|
||||||
get_citext_oids, get_hstore_oids, register_type_handlers,
|
get_citext_oids, get_hstore_oids, register_type_handlers,
|
||||||
)
|
)
|
||||||
from django.db import NotSupportedError
|
from django.db import NotSupportedError, router
|
||||||
from django.db.migrations import AddIndex, RemoveIndex
|
from django.db.migrations import AddIndex, RemoveIndex
|
||||||
from django.db.migrations.operations.base import Operation
|
from django.db.migrations.operations.base import Operation
|
||||||
|
|
||||||
@ -16,7 +16,10 @@ class CreateExtension(Operation):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||||
if schema_editor.connection.vendor != 'postgresql':
|
if (
|
||||||
|
schema_editor.connection.vendor != 'postgresql' or
|
||||||
|
not router.allow_migrate(schema_editor.connection.alias, app_label)
|
||||||
|
):
|
||||||
return
|
return
|
||||||
schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name))
|
schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name))
|
||||||
# Clear cached, stale oids.
|
# Clear cached, stale oids.
|
||||||
@ -28,6 +31,8 @@ class CreateExtension(Operation):
|
|||||||
register_type_handlers(schema_editor.connection)
|
register_type_handlers(schema_editor.connection)
|
||||||
|
|
||||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||||
|
if not router.allow_migrate(schema_editor.connection.alias, app_label):
|
||||||
|
return
|
||||||
schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name))
|
schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name))
|
||||||
# Clear cached, stale oids.
|
# Clear cached, stale oids.
|
||||||
get_hstore_oids.cache_clear()
|
get_hstore_oids.cache_clear()
|
||||||
|
@ -3,12 +3,16 @@ import unittest
|
|||||||
from migrations.test_base import OperationTestBase
|
from migrations.test_base import OperationTestBase
|
||||||
|
|
||||||
from django.db import NotSupportedError, connection
|
from django.db import NotSupportedError, connection
|
||||||
|
from django.db.migrations.state import ProjectState
|
||||||
from django.db.models import Index
|
from django.db.models import Index
|
||||||
from django.test import modify_settings
|
from django.test import modify_settings, override_settings
|
||||||
|
from django.test.utils import CaptureQueriesContext
|
||||||
|
|
||||||
|
from . import PostgreSQLTestCase
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from django.contrib.postgres.operations import (
|
from django.contrib.postgres.operations import (
|
||||||
AddIndexConcurrently, RemoveIndexConcurrently,
|
AddIndexConcurrently, CreateExtension, RemoveIndexConcurrently,
|
||||||
)
|
)
|
||||||
from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
|
from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -141,3 +145,44 @@ class RemoveIndexConcurrentlyTests(OperationTestBase):
|
|||||||
self.assertEqual(name, 'RemoveIndexConcurrently')
|
self.assertEqual(name, 'RemoveIndexConcurrently')
|
||||||
self.assertEqual(args, [])
|
self.assertEqual(args, [])
|
||||||
self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'})
|
self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'})
|
||||||
|
|
||||||
|
|
||||||
|
class NoExtensionRouter():
|
||||||
|
def allow_migrate(self, db, app_label, **hints):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
|
||||||
|
class CreateExtensionTests(PostgreSQLTestCase):
|
||||||
|
app_label = 'test_allow_create_extention'
|
||||||
|
|
||||||
|
@override_settings(DATABASE_ROUTERS=[NoExtensionRouter()])
|
||||||
|
def test_no_allow_migrate(self):
|
||||||
|
operation = CreateExtension('uuid-ossp')
|
||||||
|
project_state = ProjectState()
|
||||||
|
new_state = project_state.clone()
|
||||||
|
# Don't create an extension.
|
||||||
|
with CaptureQueriesContext(connection) as captured_queries:
|
||||||
|
with connection.schema_editor(atomic=False) as editor:
|
||||||
|
operation.database_forwards(self.app_label, editor, project_state, new_state)
|
||||||
|
self.assertEqual(len(captured_queries), 0)
|
||||||
|
# Reversal.
|
||||||
|
with CaptureQueriesContext(connection) as captured_queries:
|
||||||
|
with connection.schema_editor(atomic=False) as editor:
|
||||||
|
operation.database_backwards(self.app_label, editor, new_state, project_state)
|
||||||
|
self.assertEqual(len(captured_queries), 0)
|
||||||
|
|
||||||
|
def test_allow_migrate(self):
|
||||||
|
operation = CreateExtension('uuid-ossp')
|
||||||
|
project_state = ProjectState()
|
||||||
|
new_state = project_state.clone()
|
||||||
|
# Create an extension.
|
||||||
|
with CaptureQueriesContext(connection) as captured_queries:
|
||||||
|
with connection.schema_editor(atomic=False) as editor:
|
||||||
|
operation.database_forwards(self.app_label, editor, project_state, new_state)
|
||||||
|
self.assertIn('CREATE EXTENSION', captured_queries[0]['sql'])
|
||||||
|
# Reversal.
|
||||||
|
with CaptureQueriesContext(connection) as captured_queries:
|
||||||
|
with connection.schema_editor(atomic=False) as editor:
|
||||||
|
operation.database_backwards(self.app_label, editor, new_state, project_state)
|
||||||
|
self.assertIn('DROP EXTENSION', captured_queries[0]['sql'])
|
||||||
|
Loading…
Reference in New Issue
Block a user