diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index dbbc139e48..2bad8db221 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -3,22 +3,12 @@ from django.db.models.sql.query import Query __all__ = ['CheckConstraint'] -class CheckConstraint: - def __init__(self, *, check, name): - self.check = check +class BaseConstraint: + def __init__(self, name): self.name = name def constraint_sql(self, model, schema_editor): - query = Query(model) - where = query.build_where(self.check) - connection = schema_editor.connection - compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default') - sql, params = where.as_sql(compiler, connection) - params = tuple(schema_editor.quote_value(p) for p in params) - return schema_editor.sql_check % { - 'name': schema_editor.quote_name(self.name), - 'check': sql % params, - } + raise NotImplementedError('This method must be implemented by a subclass.') def create_sql(self, model, schema_editor): sql = self.constraint_sql(model, schema_editor) @@ -34,6 +24,33 @@ class CheckConstraint: 'name': quote_name(self.name), } + def deconstruct(self): + path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__) + path = path.replace('django.db.models.constraints', 'django.db.models') + return (path, (), {'name': self.name}) + + def clone(self): + _, args, kwargs = self.deconstruct() + return self.__class__(*args, **kwargs) + + +class CheckConstraint(BaseConstraint): + def __init__(self, *, check, name): + self.check = check + super().__init__(name) + + def constraint_sql(self, model, schema_editor): + query = Query(model) + where = query.build_where(self.check) + connection = schema_editor.connection + compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default') + sql, params = where.as_sql(compiler, connection) + params = tuple(schema_editor.quote_value(p) for p in params) + return schema_editor.sql_check % { + 'name': schema_editor.quote_name(self.name), + 'check': sql % params, + } + def __repr__(self): return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name) @@ -45,10 +62,6 @@ class CheckConstraint: ) def deconstruct(self): - path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__) - path = path.replace('django.db.models.constraints', 'django.db.models') - return (path, (), {'check': self.check, 'name': self.name}) - - def clone(self): - _, args, kwargs = self.deconstruct() - return self.__class__(*args, **kwargs) + path, args, kwargs = super().deconstruct() + kwargs['check'] = self.check + return path, args, kwargs diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py index 6bff8ce042..28a5c4ba34 100644 --- a/tests/constraints/tests.py +++ b/tests/constraints/tests.py @@ -1,9 +1,18 @@ from django.db import IntegrityError, models -from django.test import TestCase, skipUnlessDBFeature +from django.db.models.constraints import BaseConstraint +from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from .models import Product +class BaseConstraintTests(SimpleTestCase): + def test_constraint_sql(self): + c = BaseConstraint('name') + msg = 'This method must be implemented by a subclass.' + with self.assertRaisesMessage(NotImplementedError, msg): + c.constraint_sql(None, None) + + class CheckConstraintTests(TestCase): def test_repr(self): check = models.Q(price__gt=models.F('discounted_price'))