From f4e4ae96cb69d616af6f7009512cdcd2cb7c1e39 Mon Sep 17 00:00:00 2001
From: Adrian Holovaty <adrian@holovaty.com>
Date: Tue, 2 Aug 2005 22:33:39 +0000
Subject: [PATCH] Improved 'django-admin inspectdb' so that it detects
 ForeignKey relationships -- PostgreSQL only

git-svn-id: http://code.djangoproject.com/svn/django/trunk@395 bcc190cf-cafb-0310-a4f2-bffc1f526a37
---
 django/bin/django-admin.py            |  2 +-
 django/core/db/__init__.py            |  1 +
 django/core/db/backends/mysql.py      |  3 +++
 django/core/db/backends/postgresql.py | 21 +++++++++++++++++
 django/core/db/backends/sqlite3.py    |  3 +++
 django/core/management.py             | 34 ++++++++++++++++++++-------
 6 files changed, 55 insertions(+), 9 deletions(-)

diff --git a/django/bin/django-admin.py b/django/bin/django-admin.py
index 8c4370f427..fc9ffec373 100755
--- a/django/bin/django-admin.py
+++ b/django/bin/django-admin.py
@@ -76,7 +76,7 @@ def main():
             for line in ACTION_MAPPING[action](param):
                 print line
         except NotImplementedError:
-            sys.stderr.write("Error: %r isn't supported for the currently selected database backend." % action)
+            sys.stderr.write("Error: %r isn't supported for the currently selected database backend.\n" % action)
             sys.exit(1)
     elif action in ('startapp', 'startproject'):
         try:
diff --git a/django/core/db/__init__.py b/django/core/db/__init__.py
index 4ecf33770c..43ca2b2619 100644
--- a/django/core/db/__init__.py
+++ b/django/core/db/__init__.py
@@ -38,6 +38,7 @@ get_last_insert_id = dbmod.get_last_insert_id
 get_date_extract_sql = dbmod.get_date_extract_sql
 get_date_trunc_sql = dbmod.get_date_trunc_sql
 get_table_list = dbmod.get_table_list
+get_relations = dbmod.get_relations
 OPERATOR_MAPPING = dbmod.OPERATOR_MAPPING
 DATA_TYPES = dbmod.DATA_TYPES
 DATA_TYPES_REVERSE = dbmod.DATA_TYPES_REVERSE
diff --git a/django/core/db/backends/mysql.py b/django/core/db/backends/mysql.py
index aa99516209..886a0a38b8 100644
--- a/django/core/db/backends/mysql.py
+++ b/django/core/db/backends/mysql.py
@@ -73,6 +73,9 @@ def get_table_list(cursor):
     cursor.execute("SHOW TABLES")
     return [row[0] for row in cursor.fetchall()]
 
+def get_relations(cursor, table_name):
+    raise NotImplementedError
+
 OPERATOR_MAPPING = {
     'exact': '=',
     'iexact': 'LIKE',
diff --git a/django/core/db/backends/postgresql.py b/django/core/db/backends/postgresql.py
index cd8d9a064e..94a6ed35c3 100644
--- a/django/core/db/backends/postgresql.py
+++ b/django/core/db/backends/postgresql.py
@@ -82,6 +82,27 @@ def get_table_list(cursor):
             AND pg_catalog.pg_table_is_visible(c.oid)""")
     return [row[0] for row in cursor.fetchall()]
 
+def get_relations(cursor, table_name):
+    """
+    Returns a dictionary of {field_index: (field_index_other_table, other_table)}
+    representing all relationships to the given table. Indexes are 0-based.
+    """
+    cursor.execute("""
+        SELECT con.conkey, con.confkey, c2.relname
+        FROM pg_constraint con, pg_class c1, pg_class c2
+        WHERE c1.oid = con.conrelid
+            AND c2.oid = con.confrelid
+            AND c1.relname = %s
+            AND con.contype = 'f'""", [table_name])
+    relations = {}
+    for row in cursor.fetchall():
+        try:
+            # row[0] and row[1] are like "{2}", so strip the curly braces.
+            relations[int(row[0][1:-1]) - 1] = (int(row[1][1:-1]) - 1, row[2])
+        except ValueError:
+            continue
+    return relations
+
 # Register these custom typecasts, because Django expects dates/times to be
 # in Python's native (standard-library) datetime/time format, whereas psycopg
 # use mx.DateTime by default.
diff --git a/django/core/db/backends/sqlite3.py b/django/core/db/backends/sqlite3.py
index b719b8f8cd..38ce767dfb 100644
--- a/django/core/db/backends/sqlite3.py
+++ b/django/core/db/backends/sqlite3.py
@@ -112,6 +112,9 @@ def _sqlite_date_trunc(lookup_type, dt):
 def get_table_list(cursor):
     raise NotImplementedError
 
+def get_relations(cursor, table_name):
+    raise NotImplementedError
+
 # Operators and fields ########################################################
 
 OPERATOR_MAPPING = {
diff --git a/django/core/management.py b/django/core/management.py
index b0ba55275a..4def37e0b9 100644
--- a/django/core/management.py
+++ b/django/core/management.py
@@ -434,22 +434,40 @@ def inspectdb(db_name):
     "Generator that introspects the tables in the given database name and returns a Django model, one line at a time."
     from django.core import db
     from django.conf import settings
+
+    def table2model(table_name):
+        object_name = table_name.title().replace('_', '')
+        return object_name.endswith('s') and object_name[:-1] or object_name
+
     settings.DATABASE_NAME = db_name
     cursor = db.db.cursor()
+    yield "# This is an auto-generated Django model module."
+    yield "# You'll have to do the following manually to clean this up:"
+    yield "#     * Rearrange models' order"
+    yield "#     * Add primary_key=True to one field in each model."
+    yield "# Feel free to rename the models, but don't rename db_table values or field names."
+    yield ''
     yield 'from django.core import meta'
     yield ''
     for table_name in db.get_table_list(cursor):
-        object_name = table_name.title().replace('_', '')
-        object_name = object_name.endswith('s') and object_name[:-1] or object_name
-        yield 'class %s(meta.Model):' % object_name
+        yield 'class %s(meta.Model):' % table2model(table_name)
         yield '    db_table = %r' % table_name
         yield '    fields = ('
+        try:
+            relations = db.get_relations(cursor, table_name)
+        except NotImplementedError:
+            relations = {}
         cursor.execute("SELECT * FROM %s LIMIT 1" % table_name)
-        for row in cursor.description:
-            field_type = db.DATA_TYPES_REVERSE[row[1]]
-            field_desc = 'meta.%s(%r' % (field_type, row[0])
-            if field_type == 'CharField':
-                field_desc += ', maxlength=%s' % (row[3])
+        for i, row in enumerate(cursor.description):
+            if relations.has_key(i):
+                rel = relations[i]
+                rel_to = rel[1] == table_name and "'self'" or table2model(rel[1])
+                field_desc = 'meta.ForeignKey(%s, name=%r' % (rel_to, row[0])
+            else:
+                field_type = db.DATA_TYPES_REVERSE[row[1]]
+                field_desc = 'meta.%s(%r' % (field_type, row[0])
+                if field_type == 'CharField':
+                    field_desc += ', maxlength=%s' % (row[3])
             yield '        %s),' % field_desc
         yield '    )'
         yield ''