1
0
mirror of https://github.com/django/django.git synced 2025-01-03 06:55:47 +00:00

Fixed #34587 -- Allowed customizing table name normalization in inspectdb command.

This commit is contained in:
Andrii Kohut 2023-05-22 00:00:25 +03:00 committed by Mariusz Felisiak
parent c3862735cd
commit f8172f45fc
3 changed files with 31 additions and 5 deletions

View File

@ -56,9 +56,6 @@ class Command(BaseCommand):
# 'table_name_filter' is a stealth option # 'table_name_filter' is a stealth option
table_name_filter = options.get("table_name_filter") table_name_filter = options.get("table_name_filter")
def table2model(table_name):
return re.sub(r"[^a-zA-Z0-9]", "", table_name.title())
with connection.cursor() as cursor: with connection.cursor() as cursor:
yield "# This is an auto-generated Django model module." yield "# This is an auto-generated Django model module."
yield "# You'll have to do the following manually to clean this up:" yield "# You'll have to do the following manually to clean this up:"
@ -125,7 +122,7 @@ class Command(BaseCommand):
yield "# The error was: %s" % e yield "# The error was: %s" % e
continue continue
model_name = table2model(table_name) model_name = self.normalize_table_name(table_name)
yield "" yield ""
yield "" yield ""
yield "class %s(models.Model):" % model_name yield "class %s(models.Model):" % model_name
@ -180,7 +177,7 @@ class Command(BaseCommand):
rel_to = ( rel_to = (
"self" "self"
if ref_db_table == table_name if ref_db_table == table_name
else table2model(ref_db_table) else self.normalize_table_name(ref_db_table)
) )
if rel_to in known_models: if rel_to in known_models:
field_type = "%s(%s" % (rel_type, rel_to) field_type = "%s(%s" % (rel_type, rel_to)
@ -322,6 +319,10 @@ class Command(BaseCommand):
return new_name, field_params, field_notes return new_name, field_params, field_notes
def normalize_table_name(self, table_name):
"""Translate the table name to a Python-compatible model name."""
return re.sub(r"[^a-zA-Z0-9]", "", table_name.title())
def get_field_type(self, connection, table_name, row): def get_field_type(self, connection, table_name, row):
""" """
Given the database connection, the table name, and the cursor row Given the database connection, the table name, and the cursor row

View File

@ -51,6 +51,11 @@ class SpecialName(models.Model):
db_table = "inspectdb_special.table name" db_table = "inspectdb_special.table name"
class PascalCaseName(models.Model):
class Meta:
db_table = "inspectdb_pascal.PascalCase"
class ColumnTypes(models.Model): class ColumnTypes(models.Model):
id = models.AutoField(primary_key=True) id = models.AutoField(primary_key=True)
big_int_field = models.BigIntegerField() big_int_field = models.BigIntegerField()

View File

@ -3,6 +3,7 @@ from io import StringIO
from unittest import mock, skipUnless from unittest import mock, skipUnless
from django.core.management import call_command from django.core.management import call_command
from django.core.management.commands import inspectdb
from django.db import connection from django.db import connection
from django.db.backends.base.introspection import TableInfo from django.db.backends.base.introspection import TableInfo
from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature
@ -354,6 +355,25 @@ class InspectDBTestCase(TestCase):
output = out.getvalue() output = out.getvalue()
self.assertIn("class InspectdbSpecialTableName(models.Model):", output) self.assertIn("class InspectdbSpecialTableName(models.Model):", output)
def test_custom_normalize_table_name(self):
def pascal_case_table_only(table_name):
return table_name.startswith("inspectdb_pascal")
class MyCommand(inspectdb.Command):
def normalize_table_name(self, table_name):
normalized_name = table_name.split(".")[1]
if connection.features.ignores_table_name_case:
normalized_name = normalized_name.lower()
return normalized_name
out = StringIO()
call_command(MyCommand(), table_name_filter=pascal_case_table_only, stdout=out)
if connection.features.ignores_table_name_case:
expected_model_name = "pascalcase"
else:
expected_model_name = "PascalCase"
self.assertIn(f"class {expected_model_name}(models.Model):", out.getvalue())
@skipUnlessDBFeature("supports_expression_indexes") @skipUnlessDBFeature("supports_expression_indexes")
def test_table_with_func_unique_constraint(self): def test_table_with_func_unique_constraint(self):
out = StringIO() out = StringIO()