diff --git a/tests/composite_pk/tests.py b/tests/composite_pk/tests.py index 71522cb836..1a0a327baf 100644 --- a/tests/composite_pk/tests.py +++ b/tests/composite_pk/tests.py @@ -1,5 +1,4 @@ import json -import unittest from uuid import UUID import yaml @@ -35,9 +34,9 @@ class CompositePKTests(TestCase): cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user) @staticmethod - def get_constraints(table): + def get_primary_key_columns(table): with connection.cursor() as cursor: - return connection.introspection.get_constraints(cursor, table) + return connection.introspection.get_primary_key_columns(cursor, table) def test_pk_updated_if_field_updated(self): user = User.objects.get(pk=self.user.pk) @@ -125,53 +124,15 @@ class CompositePKTests(TestCase): with self.assertRaises(IntegrityError): Comment.objects.create(tenant=self.tenant, id=self.comment.id) - @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test") - def test_get_constraints_postgresql(self): - user_constraints = self.get_constraints(User._meta.db_table) - user_pk = user_constraints["composite_pk_user_pkey"] - self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) - self.assertIs(user_pk["primary_key"], True) - - comment_constraints = self.get_constraints(Comment._meta.db_table) - comment_pk = comment_constraints["composite_pk_comment_pkey"] - self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"]) - self.assertIs(comment_pk["primary_key"], True) - - @unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test") - def test_get_constraints_sqlite(self): - user_constraints = self.get_constraints(User._meta.db_table) - user_pk = user_constraints["__primary__"] - self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) - self.assertIs(user_pk["primary_key"], True) - - comment_constraints = self.get_constraints(Comment._meta.db_table) - comment_pk = comment_constraints["__primary__"] - self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"]) - self.assertIs(comment_pk["primary_key"], True) - - @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific test") - def test_get_constraints_mysql(self): - user_constraints = self.get_constraints(User._meta.db_table) - user_pk = user_constraints["PRIMARY"] - self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) - self.assertIs(user_pk["primary_key"], True) - - comment_constraints = self.get_constraints(Comment._meta.db_table) - comment_pk = comment_constraints["PRIMARY"] - self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"]) - self.assertIs(comment_pk["primary_key"], True) - - @unittest.skipUnless(connection.vendor == "oracle", "Oracle specific test") - def test_get_constraints_oracle(self): - user_constraints = self.get_constraints(User._meta.db_table) - user_pk = next(c for c in user_constraints.values() if c["primary_key"]) - self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) - self.assertEqual(user_pk["primary_key"], 1) - - comment_constraints = self.get_constraints(Comment._meta.db_table) - comment_pk = next(c for c in comment_constraints.values() if c["primary_key"]) - self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"]) - self.assertEqual(comment_pk["primary_key"], 1) + def test_get_primary_key_columns(self): + self.assertEqual( + self.get_primary_key_columns(User._meta.db_table), + ["tenant_id", "id"], + ) + self.assertEqual( + self.get_primary_key_columns(Comment._meta.db_table), + ["tenant_id", "comment_id"], + ) def test_in_bulk(self): """