diff --git a/django/core/management/commands/inspectdb.py b/django/core/management/commands/inspectdb.py index 3194aecacb..5c2ed53db8 100644 --- a/django/core/management/commands/inspectdb.py +++ b/django/core/management/commands/inspectdb.py @@ -56,9 +56,6 @@ class Command(BaseCommand): # 'table_name_filter' is a stealth option table_name_filter = options.get("table_name_filter") - def table2model(table_name): - return re.sub(r"[^a-zA-Z0-9]", "", table_name.title()) - with connection.cursor() as cursor: yield "# This is an auto-generated Django model module." yield "# You'll have to do the following manually to clean this up:" @@ -125,7 +122,7 @@ class Command(BaseCommand): yield "# The error was: %s" % e continue - model_name = table2model(table_name) + model_name = self.normalize_table_name(table_name) yield "" yield "" yield "class %s(models.Model):" % model_name @@ -180,7 +177,7 @@ class Command(BaseCommand): rel_to = ( "self" if ref_db_table == table_name - else table2model(ref_db_table) + else self.normalize_table_name(ref_db_table) ) if rel_to in known_models: field_type = "%s(%s" % (rel_type, rel_to) @@ -322,6 +319,10 @@ class Command(BaseCommand): return new_name, field_params, field_notes + def normalize_table_name(self, table_name): + """Translate the table name to a Python-compatible model name.""" + return re.sub(r"[^a-zA-Z0-9]", "", table_name.title()) + def get_field_type(self, connection, table_name, row): """ Given the database connection, the table name, and the cursor row diff --git a/tests/inspectdb/models.py b/tests/inspectdb/models.py index 9e6871ce46..25714cb086 100644 --- a/tests/inspectdb/models.py +++ b/tests/inspectdb/models.py @@ -51,6 +51,11 @@ class SpecialName(models.Model): db_table = "inspectdb_special.table name" +class PascalCaseName(models.Model): + class Meta: + db_table = "inspectdb_pascal.PascalCase" + + class ColumnTypes(models.Model): id = models.AutoField(primary_key=True) big_int_field = models.BigIntegerField() diff --git a/tests/inspectdb/tests.py b/tests/inspectdb/tests.py index 4f44190686..1be4efc430 100644 --- a/tests/inspectdb/tests.py +++ b/tests/inspectdb/tests.py @@ -3,6 +3,7 @@ from io import StringIO from unittest import mock, skipUnless from django.core.management import call_command +from django.core.management.commands import inspectdb from django.db import connection from django.db.backends.base.introspection import TableInfo from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature @@ -354,6 +355,25 @@ class InspectDBTestCase(TestCase): output = out.getvalue() self.assertIn("class InspectdbSpecialTableName(models.Model):", output) + def test_custom_normalize_table_name(self): + def pascal_case_table_only(table_name): + return table_name.startswith("inspectdb_pascal") + + class MyCommand(inspectdb.Command): + def normalize_table_name(self, table_name): + normalized_name = table_name.split(".")[1] + if connection.features.ignores_table_name_case: + normalized_name = normalized_name.lower() + return normalized_name + + out = StringIO() + call_command(MyCommand(), table_name_filter=pascal_case_table_only, stdout=out) + if connection.features.ignores_table_name_case: + expected_model_name = "pascalcase" + else: + expected_model_name = "PascalCase" + self.assertIn(f"class {expected_model_name}(models.Model):", out.getvalue()) + @skipUnlessDBFeature("supports_expression_indexes") def test_table_with_func_unique_constraint(self): out = StringIO()