From b45b2f9e4589dcaa2bbf325bb68f98831d56d53b Mon Sep 17 00:00:00 2001 From: Bendeguz Csirmaz Date: Mon, 27 May 2024 14:06:00 +0800 Subject: [PATCH] Fixed #373 - Added composite foreign keys --- django/contrib/contenttypes/fields.py | 3 + django/db/backends/base/schema.py | 43 ++++-- django/db/backends/ddl_references.py | 2 +- django/db/backends/mysql/introspection.py | 7 +- django/db/backends/oracle/introspection.py | 14 +- .../db/backends/postgresql/introspection.py | 19 ++- django/db/models/fields/related.py | 47 +++++- docs/ref/models/fields.txt | 20 +++ tests/backends/test_ddl_references.py | 13 +- tests/composite_fk/__init__.py | 0 tests/composite_fk/models/__init__.py | 7 + tests/composite_fk/models/tenant.py | 24 +++ tests/composite_fk/test_checks.py | 115 ++++++++++++++ tests/composite_fk/tests.py | 141 ++++++++++++++++++ tests/migrations/test_writer.py | 25 ++++ tests/model_meta/models.py | 7 +- tests/model_meta/results.py | 7 + 17 files changed, 453 insertions(+), 41 deletions(-) create mode 100644 tests/composite_fk/__init__.py create mode 100644 tests/composite_fk/models/__init__.py create mode 100644 tests/composite_fk/models/tenant.py create mode 100644 tests/composite_fk/test_checks.py create mode 100644 tests/composite_fk/tests.py diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index a3e87f6ed4..1d58a538ed 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -362,6 +362,9 @@ class GenericRelation(ForeignObject): *self._check_generic_foreign_key_existence(), ] + def _check_from_fields_exist(self): + return [] + def _is_matching_generic_foreign_key(self, field): """ Return True if field is a GenericForeignKey whose content type and diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index de4886837e..9e7edd5c26 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -232,6 +232,17 @@ class BaseDatabaseSchemaEditor: params.extend(extra_params) # FK. if field.remote_field and field.db_constraint: + if len(field.to_fields) > 1: + if ( + self.sql_create_fk + and self.connection.features.supports_foreign_keys + ): + self.deferred_sql.append( + self._create_fk_sql( + model, field, "_fk_%(to_table)s_%(to_column)s" + ) + ) + continue to_table = field.remote_field.model._meta.db_table to_column = field.remote_field.model._meta.get_field( field.remote_field.field_name @@ -1655,7 +1666,11 @@ class BaseDatabaseSchemaEditor: """ output = [] if self._field_should_be_indexed(model, field): - output.append(self._create_index_sql(model, fields=[field])) + if hasattr(field, "local_related_fields"): + fields = field.local_related_fields + else: + fields = [field] + output.append(self._create_index_sql(model, fields=fields)) return output def _field_should_be_altered(self, old_field, new_field, ignore=None): @@ -1717,23 +1732,22 @@ class BaseDatabaseSchemaEditor: } def _create_fk_sql(self, model, field, suffix): - table = Table(model._meta.db_table, self.quote_name) + meta = model._meta + target_meta = field.target_field.model._meta + table = Table(meta.db_table, self.quote_name) name = self._fk_constraint_name(model, field, suffix) - column = Columns(model._meta.db_table, [field.column], self.quote_name) - to_table = Table(field.target_field.model._meta.db_table, self.quote_name) - to_column = Columns( - field.target_field.model._meta.db_table, - [field.target_field.column], - self.quote_name, - ) + from_columns = [field.column for field in field.local_related_fields] + to_columns = [field.column for field in field.foreign_related_fields] + to_table = Table(target_meta.db_table, self.quote_name) deferrable = self.connection.ops.deferrable_sql() + return Statement( self.sql_create_fk, table=table, name=name, - column=column, + column=Columns(meta.db_table, from_columns, self.quote_name), to_table=to_table, - to_column=to_column, + to_column=Columns(target_meta.db_table, to_columns, self.quote_name), deferrable=deferrable, ) @@ -1741,11 +1755,14 @@ class BaseDatabaseSchemaEditor: def create_fk_name(*args, **kwargs): return self.quote_name(self._create_index_name(*args, **kwargs)) + from_columns = [field.column for field in field.local_related_fields] + to_columns = [field.column for field in field.foreign_related_fields] + return ForeignKeyName( model._meta.db_table, - [field.column], + from_columns, split_identifier(field.target_field.model._meta.db_table)[1], - [field.target_field.column], + to_columns, suffix, create_fk_name, ) diff --git a/django/db/backends/ddl_references.py b/django/db/backends/ddl_references.py index cb8d2defd2..3853959278 100644 --- a/django/db/backends/ddl_references.py +++ b/django/db/backends/ddl_references.py @@ -186,7 +186,7 @@ class ForeignKeyName(TableColumns): def __str__(self): suffix = self.suffix_template % { "to_table": self.to_reference.table, - "to_column": self.to_reference.columns[0], + "to_column": "_".join(self.to_reference.columns), } return self.create_fk_name(self.table, self.columns, suffix) diff --git a/django/db/backends/mysql/introspection.py b/django/db/backends/mysql/introspection.py index ea1d0aa187..6737c23533 100644 --- a/django/db/backends/mysql/introspection.py +++ b/django/db/backends/mysql/introspection.py @@ -277,11 +277,16 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "unique": kind in {"PRIMARY KEY", "UNIQUE"}, "index": False, "check": False, - "foreign_key": (ref_table, ref_column) if ref_column else None, + "foreign_key": None, } if self.connection.features.supports_index_column_ordering: constraints[constraint]["orders"] = [] constraints[constraint]["columns"].add(column) + if ref_column: + if constraints[constraint]["foreign_key"]: + constraints[constraint]["foreign_key"] += (ref_column,) + else: + constraints[constraint]["foreign_key"] = (ref_table, ref_column) # Add check constraints. if self.connection.features.can_introspect_check_constraints: unnamed_constraints_index = 0 diff --git a/django/db/backends/oracle/introspection.py b/django/db/backends/oracle/introspection.py index 07c2d9bded..0afacc4669 100644 --- a/django/db/backends/oracle/introspection.py +++ b/django/db/backends/oracle/introspection.py @@ -347,31 +347,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): """ SELECT cons.constraint_name, + LOWER(rcols.table_name), LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.position), - LOWER(rcols.table_name), - LOWER(rcols.column_name) + LISTAGG(LOWER(rcols.column_name), ',') + WITHIN GROUP (ORDER BY rcols.position) FROM user_constraints cons INNER JOIN user_cons_columns rcols - ON rcols.constraint_name = cons.r_constraint_name AND rcols.position = 1 + ON rcols.constraint_name = cons.r_constraint_name LEFT OUTER JOIN user_cons_columns cols ON cons.constraint_name = cols.constraint_name + AND cols.position = rcols.position WHERE cons.constraint_type = 'R' AND cons.table_name = UPPER(%s) - GROUP BY cons.constraint_name, rcols.table_name, rcols.column_name + GROUP BY cons.constraint_name, rcols.table_name """, [table_name], ) - for constraint, columns, other_table, other_column in cursor.fetchall(): + for constraint, other_table, columns, other_columns in cursor.fetchall(): constraint = self.identifier_converter(constraint) constraints[constraint] = { "primary_key": False, "unique": False, - "foreign_key": (other_table, other_column), + "foreign_key": (other_table, *other_columns.split(",")), "check": False, "index": False, "columns": columns.split(","), diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py index 69bc8712bd..a2dc8af610 100644 --- a/django/db/backends/postgresql/introspection.py +++ b/django/db/backends/postgresql/introspection.py @@ -199,10 +199,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): ORDER BY cols.arridx ), c.contype, - (SELECT fkc.relname || '.' || fka.attname - FROM pg_attribute AS fka - JOIN pg_class AS fkc ON fka.attrelid = fkc.oid - WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]), + array( + SELECT fkc.relname || '.' || fka.attname + FROM unnest(c.confkey) WITH ORDINALITY cols(colid, arridx) + JOIN pg_attribute AS fka ON cols.colid = fka.attnum + JOIN pg_class AS fkc ON fka.attrelid = fkc.oid + WHERE fka.attrelid = c.confrelid + ORDER BY cols.arridx + ), cl.reloptions FROM pg_constraint AS c JOIN pg_class AS cl ON c.conrelid = cl.oid @@ -211,11 +215,16 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): [table_name], ) for constraint, columns, kind, used_cols, options in cursor.fetchall(): + foreign_key = None + if kind == "f": + cols = tuple(tuple(col.split(".", 1)) for col in used_cols) + foreign_key = tuple([cols[0][0]] + [col[1] for col in cols]) + constraints[constraint] = { "columns": columns, "primary_key": kind == "p", "unique": kind in ["p", "u"], - "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None, + "foreign_key": foreign_key, "check": kind == "c", "index": False, "definition": None, diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 9ef2d29024..124176f7dd 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -580,6 +580,7 @@ class ForeignObject(RelatedField): return [ *super().check(**kwargs), *self._check_to_fields_exist(), + *self._check_from_fields_exist(), *self._check_unique_target(), ] @@ -605,6 +606,29 @@ class ForeignObject(RelatedField): ) return errors + def _check_from_fields_exist(self): + errors = [] + + if not self.from_fields or self.from_fields == [ + RECURSIVE_RELATIONSHIP_CONSTANT + ]: + return errors + + for from_field in self.from_fields: + try: + self.model._meta.get_field(from_field) + except exceptions.FieldDoesNotExist: + errors.append( + checks.Error( + "The from_field '%s' doesn't exist on the model '%s'." + % (from_field, self.model._meta.label), + obj=self, + id="fields.E312", + ) + ) + + return errors + def _check_unique_target(self): rel_is_string = isinstance(self.remote_field.model, str) if rel_is_string or not self.requires_unique_target: @@ -959,8 +983,12 @@ class ForeignKey(ForeignObject): parent_link=False, to_field=None, db_constraint=True, + from_fields=None, + to_fields=None, **kwargs, ): + if to_field is not None and to_fields is not None: + raise ValueError("Cannot specify both 'to_field' and 'to_fields'.") try: to._meta.model_name except AttributeError: @@ -1000,8 +1028,12 @@ class ForeignKey(ForeignObject): related_name=related_name, related_query_name=related_query_name, limit_choices_to=limit_choices_to, - from_fields=[RECURSIVE_RELATIONSHIP_CONSTANT], - to_fields=[to_field], + from_fields=( + [RECURSIVE_RELATIONSHIP_CONSTANT] + if from_fields is None + else from_fields + ), + to_fields=[to_field] if to_fields is None else to_fields, **kwargs, ) self.db_constraint = db_constraint @@ -1062,8 +1094,9 @@ class ForeignKey(ForeignObject): def deconstruct(self): name, path, args, kwargs = super().deconstruct() - del kwargs["to_fields"] - del kwargs["from_fields"] + if len(kwargs["from_fields"]) == len(kwargs["to_fields"]) <= 1: + del kwargs["to_fields"] + del kwargs["from_fields"] # Handle the simpler arguments if self.db_index: del kwargs["db_index"] @@ -1131,10 +1164,14 @@ class ForeignKey(ForeignObject): return related_fields def get_attname(self): + if len(self.from_fields) > 1: + return self.name return "%s_id" % self.name def get_attname_column(self): attname = self.get_attname() + if len(self.from_fields) > 1: + return attname, None column = self.db_column or attname return attname, column @@ -1161,6 +1198,8 @@ class ForeignKey(ForeignObject): return self.target_field.get_db_prep_value(value, connection, prepared) def get_prep_value(self, value): + if len(self.from_fields) > 1: + return super().get_prep_value(value) return self.target_field.get_prep_value(value) def contribute_to_related_class(self, cls, related): diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index 5b0f127c6f..77bbc33ffd 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -1910,6 +1910,26 @@ The possible values for :attr:`~ForeignKey.on_delete` are found in uses the primary key of the related object. If you reference a different field, that field must have ``unique=True``. +.. attribute:: ForeignKey.to_fields + + The fields on the remote side of the composite foreign key. + + If ``to_fields`` is set, :attr:`from_fields` must be set too. + + A composite foreign key is similar to a regular foreign key, except that + Django doesn't create any columns implicitly - the columns have to be + defined with ``from_fields`` and ``to_fields`` explicitly instead. + +.. attribute:: ForeignKey.from_fields + + The fields on the local side of the composite foreign key. + + If ``from_fields`` is set, :attr:`to_fields` must be set too. + + A composite foreign key is similar to a regular foreign key, except that + Django doesn't create any columns implicitly - the columns have to be + defined with ``from_fields`` and ``to_fields`` explicitly instead. + .. attribute:: ForeignKey.db_constraint Controls whether or not a constraint should be created in the database for diff --git a/tests/backends/test_ddl_references.py b/tests/backends/test_ddl_references.py index 8975b97124..7a8ac90122 100644 --- a/tests/backends/test_ddl_references.py +++ b/tests/backends/test_ddl_references.py @@ -98,10 +98,7 @@ class IndexNameTests(ColumnsTests): class ForeignKeyNameTests(IndexNameTests): def setUp(self): def create_foreign_key_name(table_name, column_names, suffix): - return ", ".join( - "%s_%s_%s" % (table_name, column_name, suffix) - for column_name in column_names - ) + return "%s_%s_%s" % (table_name, "_".join(column_names), suffix) self.reference = ForeignKeyName( "table", @@ -153,15 +150,15 @@ class ForeignKeyNameTests(IndexNameTests): def test_repr(self): self.assertEqual( repr(self.reference), - "", + "", ) def test_str(self): self.assertEqual( str(self.reference), - "table_first_column_to_table_to_first_column_fk, " - "table_second_column_to_table_to_first_column_fk", + "table_first_column_second_column_" + "to_table_to_first_column_to_second_column_fk", ) diff --git a/tests/composite_fk/__init__.py b/tests/composite_fk/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/composite_fk/models/__init__.py b/tests/composite_fk/models/__init__.py new file mode 100644 index 0000000000..beb137af76 --- /dev/null +++ b/tests/composite_fk/models/__init__.py @@ -0,0 +1,7 @@ +from .tenant import Comment, Tenant, User + +__all__ = [ + "Tenant", + "User", + "Comment", +] diff --git a/tests/composite_fk/models/tenant.py b/tests/composite_fk/models/tenant.py new file mode 100644 index 0000000000..6ca18f79e4 --- /dev/null +++ b/tests/composite_fk/models/tenant.py @@ -0,0 +1,24 @@ +from django.db import models + + +class Tenant(models.Model): + pass + + +class User(models.Model): + tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) + + class Meta: + unique_together = ("tenant_id", "id") + + +class Comment(models.Model): + tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) + user_id = models.IntegerField() + user = models.ForeignKey( + User, + on_delete=models.CASCADE, + from_fields=("tenant_id", "user_id"), + to_fields=("tenant_id", "id"), + related_name="comments", + ) diff --git a/tests/composite_fk/test_checks.py b/tests/composite_fk/test_checks.py new file mode 100644 index 0000000000..c67ac1feeb --- /dev/null +++ b/tests/composite_fk/test_checks.py @@ -0,0 +1,115 @@ +from django.core import checks +from django.db import models +from django.test import TestCase +from django.test.utils import isolate_apps + + +@isolate_apps("composite_fk") +class CompositeFKChecksTests(TestCase): + def test_from_and_to_fields_must_be_same_length(self): + test_cases = [ + {"to_fields": ("foo_id", "id")}, + {"from_fields": ("foo_id", "id")}, + {"from_fields": ("id",), "to_fields": ("foo_id", "id")}, + {"from_fields": (), "to_fields": ()}, + ] + + for kwargs in test_cases: + with ( + self.subTest(kwargs=kwargs), + self.assertRaisesMessage( + ValueError, + "Foreign Object from and to fields must be the same non-zero " + "length", + ), + ): + fk = models.ForeignKey("Foo", on_delete=models.CASCADE, **kwargs) + self.assertIsNotNone(fk.related_fields) + + def test_to_field_conflicts_with_to_fields(self): + with self.assertRaisesMessage( + ValueError, "Cannot specify both 'to_field' and 'to_fields'." + ): + self.assertIsNotNone( + models.ForeignKey( + "Foo", + on_delete=models.CASCADE, + to_field="foo_id", + to_fields=["bar_id"], + ) + ) + + def test_to_fields_doesnt_exist(self): + class Foo(models.Model): + pass + + class Bar(models.Model): + foo_id = models.IntegerField() + foo = models.ForeignKey( + Foo, + on_delete=models.CASCADE, + from_fields=["foo_id", "id"], + to_fields=["id", "bar_id"], + ) + + self.assertEqual( + Bar.check(), + [ + checks.Error( + "The to_field 'bar_id' doesn't exist on the related model " + "'composite_fk.Foo'.", + obj=Bar._meta.get_field("foo"), + id="fields.E312", + ) + ], + ) + + def test_from_fields_doesnt_exist(self): + class Foo(models.Model): + bar_id = models.IntegerField() + + class Bar(models.Model): + foo_id = models.IntegerField() + foo = models.ForeignKey( + Foo, + on_delete=models.CASCADE, + from_fields=["foo_id", "baz_id"], + to_fields=["id", "bar_id"], + ) + + self.assertEqual( + Bar.check(), + [ + checks.Error( + "The from_field 'baz_id' doesn't exist on the model " + "'composite_fk.Bar'.", + obj=Bar._meta.get_field("foo"), + id="fields.E312", + ) + ], + ) + + def test_self_cant_be_used_in_from_fields(self): + class Foo(models.Model): + bar_id = models.IntegerField() + + class Bar(models.Model): + foo_id = models.IntegerField() + foo = models.ForeignKey( + Foo, + on_delete=models.CASCADE, + from_fields=["self", "foo_id"], + to_fields=["bar_id", "id"], + ) + + self.assertEqual( + Bar.check(), + [ + checks.Error( + "The from_field 'self' doesn't exist on the model " + "'composite_fk.Bar'.", + obj=Bar._meta.get_field("foo"), + id="fields.E312", + ) + ], + ) diff --git a/tests/composite_fk/tests.py b/tests/composite_fk/tests.py new file mode 100644 index 0000000000..de256b55fa --- /dev/null +++ b/tests/composite_fk/tests.py @@ -0,0 +1,141 @@ +import re +from unittest import skipUnless + +from django.db import connection +from django.test import TestCase + +from .models import Comment, Tenant, User + + +class CompositeFKTests(TestCase): + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.tenant_1 = Tenant.objects.create() + cls.tenant_2 = Tenant.objects.create() + cls.tenant_3 = Tenant.objects.create() + cls.user_1 = User.objects.create(tenant=cls.tenant_1) + cls.user_2 = User.objects.create(tenant=cls.tenant_1) + cls.user_3 = User.objects.create(tenant=cls.tenant_2) + cls.user_4 = User.objects.create(tenant=cls.tenant_2) + cls.comment_1 = Comment.objects.create(user=cls.user_1) + cls.comment_2 = Comment.objects.create(user=cls.user_1) + cls.comment_3 = Comment.objects.create(user=cls.user_2) + cls.comment_4 = Comment.objects.create(user=cls.user_3) + + @staticmethod + def get_constraints(table): + with connection.cursor() as cursor: + return connection.introspection.get_constraints(cursor, table) + + @staticmethod + def get_table_description(table): + with connection.cursor() as cursor: + return connection.introspection.get_table_description(cursor, table) + + @skipUnless(connection.vendor == "postgresql", "PostgreSQL specific SQL") + def test_get_constraints_postgresql(self): + constraints = self.get_constraints("composite_fk_comment") + keys = list(constraints.keys()) + + fk_pattern = re.compile( + r"composite_fk_comment_tenant_id_user_id_[\w]{8}_fk_composite" + ) + fk_key = next(key for key in keys if fk_pattern.fullmatch(key)) + fk_constraint = constraints[fk_key] + self.assertEqual(fk_constraint["columns"], ["tenant_id", "user_id"]) + self.assertEqual( + fk_constraint["foreign_key"], ("composite_fk_user", "tenant_id", "id") + ) + + idx_pattern = re.compile(r"composite_fk_comment_tenant_id_user_id_[\w]{8}") + idx_key = next(key for key in keys if idx_pattern.fullmatch(key)) + idx_constraint = constraints[idx_key] + self.assertEqual(idx_constraint["columns"], ["tenant_id", "user_id"]) + self.assertTrue(idx_constraint["index"]) + self.assertEqual(idx_constraint["orders"], ["ASC", "ASC"]) + + @skipUnless(connection.vendor == "mysql", "MySQL specific SQL") + def test_get_constraints_mysql(self): + constraints = self.get_constraints("composite_fk_comment") + keys = list(constraints.keys()) + + fk_pattern = re.compile( + r"composite_fk_comment_tenant_id_user_id_[\w]{8}_fk_composite" + ) + fk_key = next(key for key in keys if fk_pattern.fullmatch(key)) + fk_constraint = constraints[fk_key] + self.assertEqual(fk_constraint["columns"], ["tenant_id", "user_id"]) + self.assertTrue(fk_constraint["index"]) + self.assertEqual( + fk_constraint["foreign_key"], ("composite_fk_user", "tenant_id", "id") + ) + + @skipUnless(connection.vendor == "oracle", "Oracle specific SQL") + def test_get_constraints_oracle(self): + constraints = self.get_constraints("composite_fk_comment") + keys = list(constraints.keys()) + + fk_pattern = re.compile(r"composite_tenant_id_[\w]{8}_f") + fk_key = next( + key + for key in keys + if fk_pattern.fullmatch(key) and len(constraints[key]["columns"]) == 2 + ) + fk_constraint = constraints[fk_key] + self.assertEqual(fk_constraint["columns"], ["tenant_id", "user_id"]) + self.assertEqual( + fk_constraint["foreign_key"], ("composite_fk_user", "tenant_id", "id") + ) + + idx_pattern = re.compile(r"composite__tenant_id__[\w]{8}") + idx_key = next(key for key in keys if idx_pattern.fullmatch(key)) + idx_constraint = constraints[idx_key] + self.assertEqual(idx_constraint["columns"], ["tenant_id", "user_id"]) + self.assertTrue(idx_constraint["index"]) + self.assertEqual(idx_constraint["orders"], ["ASC", "ASC"]) + + def test_table_description(self): + table_description = self.get_table_description("composite_fk_comment") + self.assertEqual( + ["id", "tenant_id", "user_id"], + [field_info.name for field_info in table_description], + ) + + def test_get_field(self): + user = Comment._meta.get_field("user") + user_id = Comment._meta.get_field("user_id") + self.assertEqual(user.get_internal_type(), "ForeignKey") + self.assertEqual(user_id.get_internal_type(), "IntegerField") + + def test_fields(self): + # user_1 + self.assertSequenceEqual( + self.user_1.comments.all(), (self.comment_1, self.comment_2) + ) + # user_2 + self.assertSequenceEqual(self.user_2.comments.all(), (self.comment_3,)) + # user_3 + self.assertSequenceEqual(self.user_3.comments.all(), (self.comment_4,)) + # user_4 + self.assertSequenceEqual(self.user_4.comments.all(), ()) + # comment_1 + self.assertEqual(self.comment_1.user, self.user_1) + self.assertEqual(self.comment_1.user_id, self.user_1.id) + self.assertEqual(self.comment_1.tenant_id, self.tenant_1.id) + self.assertEqual(self.comment_1.tenant, self.tenant_1) + # comment_2 + self.assertEqual(self.comment_2.user, self.user_1) + self.assertEqual(self.comment_2.user_id, self.user_1.id) + self.assertEqual(self.comment_2.tenant_id, self.tenant_1.id) + self.assertEqual(self.comment_2.tenant, self.tenant_1) + # comment_3 + self.assertEqual(self.comment_3.user, self.user_2) + self.assertEqual(self.comment_3.user_id, self.user_2.id) + self.assertEqual(self.comment_3.tenant_id, self.tenant_1.id) + self.assertEqual(self.comment_3.tenant, self.tenant_1) + # comment_4 + self.assertEqual(self.comment_4.user, self.user_3) + self.assertEqual(self.comment_4.user_id, self.user_3.id) + self.assertEqual(self.comment_4.tenant_id, self.tenant_2.id) + self.assertEqual(self.comment_4.tenant, self.tenant_2) diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index 953a3cdb6c..fcb2a630ec 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -1157,3 +1157,28 @@ class WriterTests(SimpleTestCase): output = writer.as_string() self.assertEqual(output.count("import"), 1) self.assertIn("from django.db import migrations, models", output) + + def test_composite_fk(self): + migration = type( + "Migration", + (migrations.Migration,), + { + "operations": [ + migrations.AddField( + "comment", + "user", + models.ForeignKey( + "testapp.User", + models.CASCADE, + from_fields=("tenant_id", "user_id"), + to_fields=("tenant_id", "id"), + ), + ), + ] + }, + ) + + writer = MigrationWriter(migration) + output = writer.as_string() + self.assertIn("from_fields", output) + self.assertIn("to_fields", output) diff --git a/tests/model_meta/models.py b/tests/model_meta/models.py index 20a75baf4f..e8016709e4 100644 --- a/tests/model_meta/models.py +++ b/tests/model_meta/models.py @@ -30,10 +30,11 @@ class AbstractPerson(models.Model): ) # VIRTUAL fields + relation_id = models.IntegerField() data_not_concrete_abstract = models.ForeignObject( Relation, on_delete=models.CASCADE, - from_fields=["abstract_non_concrete_id"], + from_fields=["relation_id"], to_fields=["id"], related_name="fo_abstract_rel", ) @@ -76,7 +77,7 @@ class BasePerson(AbstractPerson): data_not_concrete_base = models.ForeignObject( Relation, on_delete=models.CASCADE, - from_fields=["base_non_concrete_id"], + from_fields=["relation_id"], to_fields=["id"], related_name="fo_base_rel", ) @@ -108,7 +109,7 @@ class Person(BasePerson): data_not_concrete_inherited = models.ForeignObject( Relation, on_delete=models.CASCADE, - from_fields=["model_non_concrete_id"], + from_fields=["relation_id"], to_fields=["id"], related_name="fo_concrete_rel", ) diff --git a/tests/model_meta/results.py b/tests/model_meta/results.py index 2b942ee814..681e11f50b 100644 --- a/tests/model_meta/results.py +++ b/tests/model_meta/results.py @@ -115,6 +115,7 @@ TEST_RESULTS = { "id", "data_abstract", "fk_abstract_id", + "relation_id", "data_not_concrete_abstract", "content_type_abstract_id", "object_id_abstract", @@ -134,6 +135,7 @@ TEST_RESULTS = { "id", "data_abstract", "fk_abstract_id", + "relation_id", "data_not_concrete_abstract", "content_type_abstract_id", "object_id_abstract", @@ -146,6 +148,7 @@ TEST_RESULTS = { AbstractPerson: [ "data_abstract", "fk_abstract_id", + "relation_id", "data_not_concrete_abstract", "content_type_abstract_id", "object_id_abstract", @@ -173,6 +176,7 @@ TEST_RESULTS = { "id", "data_abstract", "fk_abstract_id", + "relation_id", "data_not_concrete_abstract", "content_type_abstract_id", "object_id_abstract", @@ -185,6 +189,7 @@ TEST_RESULTS = { AbstractPerson: [ "data_abstract", "fk_abstract_id", + "relation_id", "data_not_concrete_abstract", "content_type_abstract_id", "object_id_abstract", @@ -211,6 +216,7 @@ TEST_RESULTS = { "id", "data_abstract", "fk_abstract_id", + "relation_id", "content_type_abstract_id", "object_id_abstract", "data_base", @@ -221,6 +227,7 @@ TEST_RESULTS = { AbstractPerson: [ "data_abstract", "fk_abstract_id", + "relation_id", "content_type_abstract_id", "object_id_abstract", ],