From c0132e88f6cdb6dbaa843e9075ae97419b8709d9 Mon Sep 17 00:00:00 2001 From: Jason Pellerin Date: Tue, 11 Jul 2006 03:16:28 +0000 Subject: [PATCH] [multi-db] Added preliminary drop-table generation to django.db.backends.ansi.sql.SchemaBuilder. git-svn-id: http://code.djangoproject.com/svn/django/branches/multiple-db-support@3320 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/backends/ansi/sql.py | 77 ++++++++++++++++++++--- django/db/models/manager.py | 5 +- tests/othertests/ansi_sql.py | 109 ++++++++++++++++++++++----------- 3 files changed, 143 insertions(+), 48 deletions(-) diff --git a/django/db/backends/ansi/sql.py b/django/db/backends/ansi/sql.py index 6268515831..9f13b7761f 100644 --- a/django/db/backends/ansi/sql.py +++ b/django/db/backends/ansi/sql.py @@ -43,8 +43,13 @@ class SchemaBuilder(object): or other constraints. """ def __init__(self): + # models that I have created self.models_already_seen = set() - + # model references, keyed by the referrent model + self.references = {} + # table cache; set to short-circuit table lookups + self.tables = None + def get_create_table(self, model, style=None): """Construct and return the SQL expression(s) needed to create the table for the given model, and any constraints on that @@ -218,8 +223,8 @@ class SchemaBuilder(object): def get_drop_table(self, model, cascade=False, style=None): """Construct and return the SQL statment(s) needed to drop a model's table. If cascade is true, then output additional statments to drop any - dependant man-many tables and drop any foreign keys that reference - this table. + many-to-many tables that this table created and any foreign keys that + reference this table. """ if style is None: style = default_style @@ -227,16 +232,45 @@ class SchemaBuilder(object): info = opts.connection_info db_table = opts.db_table backend = info.backend + qn = backend.quote_name output = [] output.append(BoundStatement( '%s %s;' % (style.SQL_KEYWORD('DROP TABLE'), - style.SQL_TABLE(backend.quote_name(db_table))), + style.SQL_TABLE(qn(db_table))), info.connection)) if cascade: - # FIXME deal with my foreign keys, others that might have a foreign - # key TO me, and many-many - pass + # deal with others that might have a foreign key TO me: alter + # their tables to drop the constraint + if backend.supports_constraints: + references_to_delete = self.get_references() + if model in references_to_delete: + for rel_class, f in references_to_delete[model]: + table = rel_class._meta.db_table + if not self.table_exists(info, table): + continue + col = f.column + r_table = opts.db_table + r_col = opts.get_field(f.rel.field_name).column + output.append(BoundStatement( + '%s %s %s %s;' % + (style.SQL_KEYWORD('ALTER TABLE'), + style.SQL_TABLE(qn(table)), + style.SQL_KEYWORD( + backend.get_drop_foreignkey_sql()), + style.SQL_FIELD(qn("%s_referencing_%s_%s" % + (col, r_table, r_col)))), + info.connection)) + del references_to_delete[model] + # many to many: drop any many-many tables that are my + # responsiblity + for f in opts.many_to_many: + if not isinstance(f.rel, models.GenericRel): + output.append(BoundStatement( + '%s %s;' % + (style.SQL_KEYWORD('DROP TABLE'), + style.SQL_TABLE(qn(f.m2m_db_table()))), + info.connection)) # Reverse it, to deal with table dependencies. output.reverse() return output @@ -273,11 +307,36 @@ class SchemaBuilder(object): def get_initialdata_path(self, model): """Get the path from which to load sql initial data files for a model. """ - return os.path.normpath(os.path.join(os.path.dirname(models.get_app(model._meta.app_label).__file__), 'sql')) + return os.path.normpath(os.path.join(os.path.dirname( + models.get_app(model._meta.app_label).__file__), 'sql')) def get_rel_data_type(self, f): return (f.get_internal_type() in ('AutoField', 'PositiveIntegerField', 'PositiveSmallIntegerField')) \ and 'IntegerField' \ or f.get_internal_type() - + + def get_references(self): + """Fill (if needed) and return the reference cache. + """ + if self.references: + return self.references + for klass in models.get_models(): + for f in klass._meta.fields: + if f.rel: + self.references.setdefault(f.rel.to, []).append((klass, f)) + return self.references + + def get_table_list(self, connection_info): + """Get list of tables accessible via the connection described by + connection_info. + """ + if self.tables is not None: + return self.tables + cursor = info.connection.cursor() + introspection = connection_info.get_introspection_module() + return introspection.get_table_list(cursor) + + def table_exists(self, connection_info, table): + tables = self.get_table_list(connection_info) + return table in tables diff --git a/django/db/models/manager.py b/django/db/models/manager.py index 04debd7cc6..450218650c 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -159,9 +159,8 @@ class Manager(object): """Get list of tables accessible via my model's connection. """ info = self.model._meta.connection_info - cursor = info.connection.cursor() - introspect = info.get_introspection_module() - return introspect.get_table_list(cursor) + builder = info.get_creation_module.builder() + return builder.get_table_list(info) class ManagerDescriptor(object): # This class ensures managers aren't accessible via model instances. diff --git a/tests/othertests/ansi_sql.py b/tests/othertests/ansi_sql.py index de12eaf590..f51ba7d0b2 100644 --- a/tests/othertests/ansi_sql.py +++ b/tests/othertests/ansi_sql.py @@ -1,52 +1,24 @@ -# For Python 2.3 -if not hasattr(__builtins__, 'set'): - from sets import Set as set - """ ->>> from django.db import models >>> from django.db.backends.ansi import sql -# test models ->>> class Car(models.Model): -... make = models.CharField(maxlength=32) -... model = models.CharField(maxlength=32) -... year = models.IntegerField() -... condition = models.CharField(maxlength=32) -... -... class Meta: -... app_label = 'ansi_sql' - ->>> class Collector(models.Model): -... name = models.CharField(maxlength=32) -... cars = models.ManyToManyField(Car) -... -... class Meta: -... app_label = 'ansi_sql' - ->>> class Mod(models.Model): -... car = models.ForeignKey(Car) -... part = models.CharField(maxlength=32, db_index=True) -... description = models.TextField() -... -... class Meta: -... app_label = 'ansi_sql' +# so we can test with a predicatable constraint setting +>>> real_cnst = Mod._meta.connection_info.backend.supports_constraints +>>> Mod._meta.connection_info.backend.supports_constraints = True # generate create sql >>> builder = sql.SchemaBuilder() >>> builder.get_create_table(Car) -([BoundStatement('CREATE TABLE "ansi_sql_car" (...);')], []) +([BoundStatement('CREATE TABLE "ansi_sql_car" (...);')], {}) >>> builder.models_already_seen -[] +Set([]) >>> builder.models_already_seen = set() # test that styles are used >>> builder.get_create_table(Car, style=mockstyle()) -([BoundStatement('SQL_KEYWORD(CREATE TABLE) SQL_TABLE("ansi_sql_car") (...SQL_FIELD("id")...);')], []) +([BoundStatement('SQL_KEYWORD(CREATE TABLE) SQL_TABLE("ansi_sql_car") (...SQL_FIELD("id")...);')], {}) # test pending relationships >>> builder.models_already_seen = set() ->>> real_cnst = Mod._meta.connection_info.backend.supports_constraints ->>> Mod._meta.connection_info.backend.supports_constraints = True >>> builder.get_create_table(Mod) ([BoundStatement('CREATE TABLE "ansi_sql_mod" (..."car_id" integer NOT NULL,...);')], {: [BoundStatement('ALTER TABLE "ansi_sql_mod" ADD CONSTRAINT ... FOREIGN KEY ("car_id") REFERENCES "ansi_sql_car" ("id");')]}) >>> builder.models_already_seen = set() @@ -54,7 +26,6 @@ if not hasattr(__builtins__, 'set'): ([BoundStatement('CREATE TABLE "ansi_sql_car" (...);')], {}) >>> builder.get_create_table(Mod) ([BoundStatement('CREATE TABLE "ansi_sql_mod" (..."car_id" integer NOT NULL REFERENCES "ansi_sql_car" ("id"),...);')], {}) ->>> Mod._meta.connection_info.backend.supports_constraints = real_cnst # test many-many >>> builder.get_create_table(Collector) @@ -75,16 +46,82 @@ if not hasattr(__builtins__, 'set'): >>> builder.get_initialdata_path = othertests_sql >>> builder.get_initialdata(Car) [BoundStatement('insert into ansi_sql_car (...)...values (...);')] + +# test drop +>>> builder.get_drop_table(Mod) +[BoundStatement('DROP TABLE "ansi_sql_mod";')] +>>> builder.get_drop_table(Mod, cascade=True) +[BoundStatement('DROP TABLE "ansi_sql_mod";')] +>>> builder.get_drop_table(Car) +[BoundStatement('DROP TABLE "ansi_sql_car";')] +>>> builder.get_drop_table(Car, cascade=True) +[BoundStatement('DROP TABLE "ansi_sql_car";')] + +>>> builder.tables = ['ansi_sql_car', 'ansi_sql_mod', 'ansi_sql_collector'] +>>> Mod._meta.connection_info.backend.supports_constraints = False +>>> builder.get_drop_table(Car, cascade=True) +[BoundStatement('DROP TABLE "ansi_sql_car";')] +>>> Mod._meta.connection_info.backend.supports_constraints = True +>>> builder.get_drop_table(Car, cascade=True) +[BoundStatement('ALTER TABLE "ansi_sql_mod" ...'), BoundStatement('DROP TABLE "ansi_sql_car";')] +>>> builder.get_drop_table(Collector) +[BoundStatement('DROP TABLE "ansi_sql_collector";')] +>>> builder.get_drop_table(Collector, cascade=True) +[BoundStatement('DROP TABLE "ansi_sql_collector_cars";'), BoundStatement('DROP TABLE "ansi_sql_collector";')] +>>> Mod._meta.connection_info.backend.supports_constraints = real_cnst + """ import os +from django.db import models +from django.core.management import install + +# For Python 2.3 +if not hasattr(__builtins__, 'set'): + from sets import Set as set + + +# test models +class Car(models.Model): + make = models.CharField(maxlength=32) + model = models.CharField(maxlength=32) + year = models.IntegerField() + condition = models.CharField(maxlength=32) + + class Meta: + app_label = 'ansi_sql' + + +class Collector(models.Model): + name = models.CharField(maxlength=32) + cars = models.ManyToManyField(Car) + + class Meta: + app_label = 'ansi_sql' + + +class Mod(models.Model): + car = models.ForeignKey(Car) + part = models.CharField(maxlength=32, db_index=True) + description = models.TextField() + + class Meta: + app_label = 'ansi_sql' + -# mock style that wraps text in STYLE(text), for testing class mockstyle: + """mock style that wraps text in STYLE(text), for testing""" def __getattr__(self, attr): if attr in ('ERROR', 'ERROR_OUTPUT', 'SQL_FIELD', 'SQL_COLTYPE', 'SQL_KEYWORD', 'SQL_TABLE'): return lambda text: "%s(%s)" % (attr, text) + def othertests_sql(mod): """Look in othertests/sql for sql initialdata""" return os.path.normpath(os.path.join(os.path.dirname(__file__), 'sql')) + + +# install my stuff +Car.objects.install() +Collector.objects.install() +Mod.objects.install()