From 782d85b6dfa191e67c0f1d572641d8236c79174c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pave=C5=82=20Ty=C5=9Blacki?= <pavel.tyslacki@gmail.com>
Date: Thu, 28 Feb 2019 00:47:29 +0300
Subject: [PATCH] Fixed #30183 -- Added introspection of inline SQLite
 constraints.

---
 django/db/backends/sqlite3/introspection.py | 179 ++++++++++++++------
 tests/backends/sqlite/test_introspection.py | 115 +++++++++++++
 tests/introspection/models.py               |  18 ++
 tests/introspection/tests.py                |  59 ++++++-
 tests/schema/tests.py                       |  12 +-
 5 files changed, 332 insertions(+), 51 deletions(-)

diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py
index 1104571bb2..faf9328415 100644
--- a/django/db/backends/sqlite3/introspection.py
+++ b/django/db/backends/sqlite3/introspection.py
@@ -217,50 +217,124 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
             }
         return constraints
 
-    def _parse_table_constraints(self, sql):
+    def _parse_column_or_constraint_definition(self, tokens, columns):
+        token = None
+        is_constraint_definition = None
+        field_name = None
+        constraint_name = None
+        unique = False
+        unique_columns = []
+        check = False
+        check_columns = []
+        braces_deep = 0
+        for token in tokens:
+            if token.match(sqlparse.tokens.Punctuation, '('):
+                braces_deep += 1
+            elif token.match(sqlparse.tokens.Punctuation, ')'):
+                braces_deep -= 1
+                if braces_deep < 0:
+                    # End of columns and constraints for table definition.
+                    break
+            elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','):
+                # End of current column or constraint definition.
+                break
+            # Detect column or constraint definition by first token.
+            if is_constraint_definition is None:
+                is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT')
+                if is_constraint_definition:
+                    continue
+            if is_constraint_definition:
+                # Detect constraint name by second token.
+                if constraint_name is None:
+                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
+                        constraint_name = token.value
+                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
+                        constraint_name = token.value[1:-1]
+                # Start constraint columns parsing after UNIQUE keyword.
+                if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
+                    unique = True
+                    unique_braces_deep = braces_deep
+                elif unique:
+                    if unique_braces_deep == braces_deep:
+                        if unique_columns:
+                            # Stop constraint parsing.
+                            unique = False
+                        continue
+                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
+                        unique_columns.append(token.value)
+                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
+                        unique_columns.append(token.value[1:-1])
+            else:
+                # Detect field name by first token.
+                if field_name is None:
+                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
+                        field_name = token.value
+                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
+                        field_name = token.value[1:-1]
+                if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
+                    unique_columns = [field_name]
+            # Start constraint columns parsing after CHECK keyword.
+            if token.match(sqlparse.tokens.Keyword, 'CHECK'):
+                check = True
+                check_braces_deep = braces_deep
+            elif check:
+                if check_braces_deep == braces_deep:
+                    if check_columns:
+                        # Stop constraint parsing.
+                        check = False
+                    continue
+                if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
+                    if token.value in columns:
+                        check_columns.append(token.value)
+                elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
+                    if token.value[1:-1] in columns:
+                        check_columns.append(token.value[1:-1])
+        unique_constraint = {
+            'unique': True,
+            'columns': unique_columns,
+            'primary_key': False,
+            'foreign_key': None,
+            'check': False,
+            'index': False,
+        } if unique_columns else None
+        check_constraint = {
+            'check': True,
+            'columns': check_columns,
+            'primary_key': False,
+            'unique': False,
+            'foreign_key': None,
+            'index': False,
+        } if check_columns else None
+        return constraint_name, unique_constraint, check_constraint, token
+
+    def _parse_table_constraints(self, sql, columns):
         # Check constraint parsing is based of SQLite syntax diagram.
         # https://www.sqlite.org/syntaxdiagrams.html#table-constraint
-        def next_ttype(ttype):
-            for token in tokens:
-                if token.ttype == ttype:
-                    return token
-
         statement = sqlparse.parse(sql)[0]
         constraints = {}
-        tokens = statement.flatten()
+        unnamed_constrains_index = 0
+        tokens = (token for token in statement.flatten() if not token.is_whitespace)
+        # Skip ahead to the opening parenthesis of the column and constraint definitions.
         for token in tokens:
-            name = None
-            if token.match(sqlparse.tokens.Keyword, 'CONSTRAINT'):
-                # Table constraint
-                name_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
-                name = name_token.value[1:-1]
-                token = next_ttype(sqlparse.tokens.Keyword)
-            if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
-                constraints[name] = {
-                    'unique': True,
-                    'columns': [],
-                    'primary_key': False,
-                    'foreign_key': None,
-                    'check': False,
-                    'index': False,
-                }
-            if token.match(sqlparse.tokens.Keyword, 'CHECK'):
-                # Column check constraint
-                if name is None:
-                    column_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
-                    column = column_token.value[1:-1]
-                    name = '__check__%s' % column
-                    columns = [column]
+            if token.match(sqlparse.tokens.Punctuation, '('):
+                break
+        # Parse one column or constraint definition at a time until the closing parenthesis.
+        while True:
+            constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns)
+            if unique:
+                if constraint_name:
+                    constraints[constraint_name] = unique
                 else:
-                    columns = []
-                constraints[name] = {
-                    'check': True,
-                    'columns': columns,
-                    'primary_key': False,
-                    'unique': False,
-                    'foreign_key': None,
-                    'index': False,
-                }
+                    unnamed_constrains_index += 1
+                    constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = unique
+            if check:
+                if constraint_name:
+                    constraints[constraint_name] = check
+                else:
+                    unnamed_constrains_index += 1
+                    constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = check
+            if end_token.match(sqlparse.tokens.Punctuation, ')'):
+                break
         return constraints
 
     def get_constraints(self, cursor, table_name):
@@ -280,7 +354,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
             # table_name is a view.
             pass
         else:
-            constraints.update(self._parse_table_constraints(table_schema))
+            columns = {info.name for info in self.get_table_description(cursor, table_name)}
+            constraints.update(self._parse_table_constraints(table_schema, columns))
 
         # Get the index info
         cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name))
@@ -288,6 +363,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
             # SQLite 3.8.9+ has 5 columns, however older versions only give 3
             # columns. Discard last 2 columns if there.
             number, index, unique = row[:3]
+            cursor.execute(
+                "SELECT sql FROM sqlite_master "
+                "WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
+            )
+            # There's at most one row.
+            sql, = cursor.fetchone() or (None,)
+            # Inline constraints are already detected in
+            # _parse_table_constraints(). The reasons to avoid fetching inline
+            # constraints from `PRAGMA index_list` are:
+            # - Inline constraints can have a different name and information
+            #   than what `PRAGMA index_list` gives.
+            # - Not all inline constraints may appear in `PRAGMA index_list`.
+            if not sql:
+                # No SQL is stored in sqlite_master for this index, so it backs an inline constraint.
+                continue
             # Get the index info for that index
             cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index))
             for index_rank, column_rank, column in cursor.fetchall():
@@ -305,15 +395,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
             if constraints[index]['index'] and not constraints[index]['unique']:
                 # SQLite doesn't support any index type other than b-tree
                 constraints[index]['type'] = Index.suffix
-                cursor.execute(
-                    "SELECT sql FROM sqlite_master "
-                    "WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
-                )
-                orders = []
-                # There would be only 1 row to loop over
-                for sql, in cursor.fetchall():
-                    order_info = sql.split('(')[-1].split(')')[0].split(',')
-                    orders = ['DESC' if info.endswith('DESC') else 'ASC' for info in order_info]
+                order_info = sql.split('(')[-1].split(')')[0].split(',')
+                orders = ['DESC' if info.endswith('DESC') else 'ASC' for info in order_info]
                 constraints[index]['orders'] = orders
         # Get the PK
         pk_column = self.get_primary_key_column(cursor, table_name)
diff --git a/tests/backends/sqlite/test_introspection.py b/tests/backends/sqlite/test_introspection.py
index 1695ee549e..e378e0ee56 100644
--- a/tests/backends/sqlite/test_introspection.py
+++ b/tests/backends/sqlite/test_introspection.py
@@ -1,5 +1,7 @@
 import unittest
 
+import sqlparse
+
 from django.db import connection
 from django.test import TestCase
 
@@ -25,3 +27,116 @@ class IntrospectionTests(TestCase):
                         self.assertEqual(field, expected_string)
                     finally:
                         cursor.execute('DROP TABLE test_primary')
+
+
+@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests')
+class ParsingTests(TestCase):
+    def parse_definition(self, sql, columns):
+        """Parse a column or constraint definition."""
+        statement = sqlparse.parse(sql)[0]
+        tokens = (token for token in statement.flatten() if not token.is_whitespace)
+        with connection.cursor():
+            return connection.introspection._parse_column_or_constraint_definition(tokens, set(columns))
+
+    def assertConstraint(self, constraint_details, cols, unique=False, check=False):
+        self.assertEqual(constraint_details, {
+            'unique': unique,
+            'columns': cols,
+            'primary_key': False,
+            'foreign_key': None,
+            'check': check,
+            'index': False,
+        })
+
+    def test_unique_column(self):
+        tests = (
+            ('"ref" integer UNIQUE,', ['ref']),
+            ('ref integer UNIQUE,', ['ref']),
+            ('"customname" integer UNIQUE,', ['customname']),
+            ('customname integer UNIQUE,', ['customname']),
+        )
+        for sql, columns in tests:
+            with self.subTest(sql=sql):
+                constraint, details, check, _ = self.parse_definition(sql, columns)
+                self.assertIsNone(constraint)
+                self.assertConstraint(details, columns, unique=True)
+                self.assertIsNone(check)
+
+    def test_unique_constraint(self):
+        tests = (
+            ('CONSTRAINT "ref" UNIQUE ("ref"),', 'ref', ['ref']),
+            ('CONSTRAINT ref UNIQUE (ref),', 'ref', ['ref']),
+            ('CONSTRAINT "customname1" UNIQUE ("customname2"),', 'customname1', ['customname2']),
+            ('CONSTRAINT customname1 UNIQUE (customname2),', 'customname1', ['customname2']),
+        )
+        for sql, constraint_name, columns in tests:
+            with self.subTest(sql=sql):
+                constraint, details, check, _ = self.parse_definition(sql, columns)
+                self.assertEqual(constraint, constraint_name)
+                self.assertConstraint(details, columns, unique=True)
+                self.assertIsNone(check)
+
+    def test_unique_constraint_multicolumn(self):
+        tests = (
+            ('CONSTRAINT "ref" UNIQUE ("ref", "customname"),', 'ref', ['ref', 'customname']),
+            ('CONSTRAINT ref UNIQUE (ref, customname),', 'ref', ['ref', 'customname']),
+        )
+        for sql, constraint_name, columns in tests:
+            with self.subTest(sql=sql):
+                constraint, details, check, _ = self.parse_definition(sql, columns)
+                self.assertEqual(constraint, constraint_name)
+                self.assertConstraint(details, columns, unique=True)
+                self.assertIsNone(check)
+
+    def test_check_column(self):
+        tests = (
+            ('"ref" varchar(255) CHECK ("ref" != \'test\'),', ['ref']),
+            ('ref varchar(255) CHECK (ref != \'test\'),', ['ref']),
+            ('"customname1" varchar(255) CHECK ("customname2" != \'test\'),', ['customname2']),
+            ('customname1 varchar(255) CHECK (customname2 != \'test\'),', ['customname2']),
+        )
+        for sql, columns in tests:
+            with self.subTest(sql=sql):
+                constraint, details, check, _ = self.parse_definition(sql, columns)
+                self.assertIsNone(constraint)
+                self.assertIsNone(details)
+                self.assertConstraint(check, columns, check=True)
+
+    def test_check_constraint(self):
+        tests = (
+            ('CONSTRAINT "ref" CHECK ("ref" != \'test\'),', 'ref', ['ref']),
+            ('CONSTRAINT ref CHECK (ref != \'test\'),', 'ref', ['ref']),
+            ('CONSTRAINT "customname1" CHECK ("customname2" != \'test\'),', 'customname1', ['customname2']),
+            ('CONSTRAINT customname1 CHECK (customname2 != \'test\'),', 'customname1', ['customname2']),
+        )
+        for sql, constraint_name, columns in tests:
+            with self.subTest(sql=sql):
+                constraint, details, check, _ = self.parse_definition(sql, columns)
+                self.assertEqual(constraint, constraint_name)
+                self.assertIsNone(details)
+                self.assertConstraint(check, columns, check=True)
+
+    def test_check_column_with_operators_and_functions(self):
+        tests = (
+            ('"ref" integer CHECK ("ref" BETWEEN 1 AND 10),', ['ref']),
+            ('"ref" varchar(255) CHECK ("ref" LIKE \'test%\'),', ['ref']),
+            ('"ref" varchar(255) CHECK (LENGTH(ref) > "max_length"),', ['ref', 'max_length']),
+        )
+        for sql, columns in tests:
+            with self.subTest(sql=sql):
+                constraint, details, check, _ = self.parse_definition(sql, columns)
+                self.assertIsNone(constraint)
+                self.assertIsNone(details)
+                self.assertConstraint(check, columns, check=True)
+
+    def test_check_and_unique_column(self):
+        tests = (
+            ('"ref" varchar(255) CHECK ("ref" != \'test\') UNIQUE,', ['ref']),
+            ('ref varchar(255) UNIQUE CHECK (ref != \'test\'),', ['ref']),
+        )
+        for sql, columns in tests:
+            with self.subTest(sql=sql):
+                constraint, details, check, _ = self.parse_definition(sql, columns)
+                self.assertIsNone(constraint)
+                self.assertConstraint(details, columns, unique=True)
+                self.assertConstraint(check, columns, check=True)
diff --git a/tests/introspection/models.py b/tests/introspection/models.py
index 32acc323bd..fa663de2fd 100644
--- a/tests/introspection/models.py
+++ b/tests/introspection/models.py
@@ -58,3 +58,21 @@ class ArticleReporter(models.Model):
 
     class Meta:
         managed = False
+
+
+class Comment(models.Model):
+    ref = models.UUIDField(unique=True)
+    article = models.ForeignKey(Article, models.CASCADE, db_index=True)
+    email = models.EmailField()
+    pub_date = models.DateTimeField()
+    up_votes = models.PositiveIntegerField()
+    body = models.TextField()
+
+    class Meta:
+        constraints = [
+            models.CheckConstraint(name='up_votes_gte_0_check', check=models.Q(up_votes__gte=0)),
+            models.UniqueConstraint(fields=['article', 'email', 'pub_date'], name='article_email_pub_date_uniq'),
+        ]
+        indexes = [
+            models.Index(fields=['email', 'pub_date'], name='email_pub_date_idx'),
+        ]
diff --git a/tests/introspection/tests.py b/tests/introspection/tests.py
index d851352cae..10524cdacb 100644
--- a/tests/introspection/tests.py
+++ b/tests/introspection/tests.py
@@ -5,7 +5,7 @@ from django.db.models import Index
 from django.db.utils import DatabaseError
 from django.test import TransactionTestCase, skipUnlessDBFeature
 
-from .models import Article, ArticleReporter, City, District, Reporter
+from .models import Article, ArticleReporter, City, Comment, District, Reporter
 
 
 class IntrospectionTests(TransactionTestCase):
@@ -211,3 +211,60 @@ class IntrospectionTests(TransactionTestCase):
                 self.assertEqual(val['orders'], ['ASC'] * len(val['columns']))
                 indexes_verified += 1
         self.assertEqual(indexes_verified, 4)
+
+    def test_get_constraints(self):
+        def assertDetails(details, cols, primary_key=False, unique=False, index=False, check=False, foreign_key=None):
+            # Different backends have different values for same constraints:
+            #               PRIMARY KEY     UNIQUE CONSTRAINT    UNIQUE INDEX
+            # MySQL      pk=1 uniq=1 idx=1  pk=0 uniq=1 idx=1  pk=0 uniq=1 idx=1
+            # PostgreSQL pk=1 uniq=1 idx=0  pk=0 uniq=1 idx=0  pk=0 uniq=1 idx=1
+            # SQLite     pk=1 uniq=0 idx=0  pk=0 uniq=1 idx=0  pk=0 uniq=1 idx=1
+            if details['primary_key']:
+                details['unique'] = True
+            if details['unique']:
+                details['index'] = False
+            self.assertEqual(details['columns'], cols)
+            self.assertEqual(details['primary_key'], primary_key)
+            self.assertEqual(details['unique'], unique)
+            self.assertEqual(details['index'], index)
+            self.assertEqual(details['check'], check)
+            self.assertEqual(details['foreign_key'], foreign_key)
+
+        with connection.cursor() as cursor:
+            constraints = connection.introspection.get_constraints(cursor, Comment._meta.db_table)
+        # Test custom constraints
+        custom_constraints = {
+            'article_email_pub_date_uniq',
+            'email_pub_date_idx',
+        }
+        if connection.features.supports_column_check_constraints:
+            custom_constraints.add('up_votes_gte_0_check')
+            assertDetails(constraints['up_votes_gte_0_check'], ['up_votes'], check=True)
+        assertDetails(constraints['article_email_pub_date_uniq'], ['article_id', 'email', 'pub_date'], unique=True)
+        assertDetails(constraints['email_pub_date_idx'], ['email', 'pub_date'], index=True)
+        # Test field constraints
+        field_constraints = set()
+        for name, details in constraints.items():
+            if name in custom_constraints:
+                continue
+            elif details['columns'] == ['up_votes'] and details['check']:
+                assertDetails(details, ['up_votes'], check=True)
+                field_constraints.add(name)
+            elif details['columns'] == ['ref'] and details['unique']:
+                assertDetails(details, ['ref'], unique=True)
+                field_constraints.add(name)
+            elif details['columns'] == ['article_id'] and details['index']:
+                assertDetails(details, ['article_id'], index=True)
+                field_constraints.add(name)
+            elif details['columns'] == ['id'] and details['primary_key']:
+                assertDetails(details, ['id'], primary_key=True, unique=True)
+                field_constraints.add(name)
+            elif details['columns'] == ['article_id'] and details['foreign_key']:
+                assertDetails(details, ['article_id'], foreign_key=('introspection_article', 'id'))
+                field_constraints.add(name)
+            elif details['check']:
+                # Some databases (e.g. Oracle) include additional check
+                # constraints.
+                field_constraints.add(name)
+        # All constraints are accounted for.
+        self.assertEqual(constraints.keys() ^ (custom_constraints | field_constraints), set())
diff --git a/tests/schema/tests.py b/tests/schema/tests.py
index 9b40a43523..00ce2e494e 100644
--- a/tests/schema/tests.py
+++ b/tests/schema/tests.py
@@ -129,6 +129,14 @@ class SchemaTests(TransactionTestCase):
                 if c['index'] and len(c['columns']) == 1
             ]
 
+    def get_uniques(self, table):
+        with connection.cursor() as cursor:
+            return [
+                c['columns'][0]
+                for c in connection.introspection.get_constraints(cursor, table).values()
+                if c['unique'] and len(c['columns']) == 1
+            ]
+
     def get_constraints(self, table):
         """
         Get the constraints on a table using a new cursor.
@@ -1971,7 +1979,7 @@ class SchemaTests(TransactionTestCase):
             editor.add_field(Book, new_field3)
         self.assertIn(
             "slug",
-            self.get_indexes(Book._meta.db_table),
+            self.get_uniques(Book._meta.db_table),
         )
         # Remove the unique, check the index goes with it
         new_field4 = CharField(max_length=20, unique=False)
@@ -1980,7 +1988,7 @@ class SchemaTests(TransactionTestCase):
             editor.alter_field(BookWithSlug, new_field3, new_field4, strict=True)
         self.assertNotIn(
             "slug",
-            self.get_indexes(Book._meta.db_table),
+            self.get_uniques(Book._meta.db_table),
         )
 
     def test_text_field_with_db_index(self):