mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	Refs #29641 -- Extracted reusable CheckConstraint logic into a base class.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							9142bebff2
						
					
				
				
					commit
					24dc7d8940
				
			| @@ -3,22 +3,12 @@ from django.db.models.sql.query import Query | ||||
| __all__ = ['CheckConstraint'] | ||||
|  | ||||
|  | ||||
| class CheckConstraint: | ||||
|     def __init__(self, *, check, name): | ||||
|         self.check = check | ||||
| class BaseConstraint: | ||||
|     def __init__(self, name): | ||||
|         self.name = name | ||||
|  | ||||
|     def constraint_sql(self, model, schema_editor): | ||||
|         query = Query(model) | ||||
|         where = query.build_where(self.check) | ||||
|         connection = schema_editor.connection | ||||
|         compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default') | ||||
|         sql, params = where.as_sql(compiler, connection) | ||||
|         params = tuple(schema_editor.quote_value(p) for p in params) | ||||
|         return schema_editor.sql_check % { | ||||
|             'name': schema_editor.quote_name(self.name), | ||||
|             'check': sql % params, | ||||
|         } | ||||
|         raise NotImplementedError('This method must be implemented by a subclass.') | ||||
|  | ||||
|     def create_sql(self, model, schema_editor): | ||||
|         sql = self.constraint_sql(model, schema_editor) | ||||
| @@ -34,6 +24,33 @@ class CheckConstraint: | ||||
|             'name': quote_name(self.name), | ||||
|         } | ||||
|  | ||||
|     def deconstruct(self): | ||||
|         path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__) | ||||
|         path = path.replace('django.db.models.constraints', 'django.db.models') | ||||
|         return (path, (), {'name': self.name}) | ||||
|  | ||||
|     def clone(self): | ||||
|         _, args, kwargs = self.deconstruct() | ||||
|         return self.__class__(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| class CheckConstraint(BaseConstraint): | ||||
|     def __init__(self, *, check, name): | ||||
|         self.check = check | ||||
|         super().__init__(name) | ||||
|  | ||||
|     def constraint_sql(self, model, schema_editor): | ||||
|         query = Query(model) | ||||
|         where = query.build_where(self.check) | ||||
|         connection = schema_editor.connection | ||||
|         compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default') | ||||
|         sql, params = where.as_sql(compiler, connection) | ||||
|         params = tuple(schema_editor.quote_value(p) for p in params) | ||||
|         return schema_editor.sql_check % { | ||||
|             'name': schema_editor.quote_name(self.name), | ||||
|             'check': sql % params, | ||||
|         } | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name) | ||||
|  | ||||
| @@ -45,10 +62,6 @@ class CheckConstraint: | ||||
|         ) | ||||
|  | ||||
|     def deconstruct(self): | ||||
|         path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__) | ||||
|         path = path.replace('django.db.models.constraints', 'django.db.models') | ||||
|         return (path, (), {'check': self.check, 'name': self.name}) | ||||
|  | ||||
|     def clone(self): | ||||
|         _, args, kwargs = self.deconstruct() | ||||
|         return self.__class__(*args, **kwargs) | ||||
|         path, args, kwargs = super().deconstruct() | ||||
|         kwargs['check'] = self.check | ||||
|         return path, args, kwargs | ||||
|   | ||||
| @@ -1,9 +1,18 @@ | ||||
| from django.db import IntegrityError, models | ||||
| from django.test import TestCase, skipUnlessDBFeature | ||||
| from django.db.models.constraints import BaseConstraint | ||||
| from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature | ||||
|  | ||||
| from .models import Product | ||||
|  | ||||
|  | ||||
| class BaseConstraintTests(SimpleTestCase): | ||||
|     def test_constraint_sql(self): | ||||
|         c = BaseConstraint('name') | ||||
|         msg = 'This method must be implemented by a subclass.' | ||||
|         with self.assertRaisesMessage(NotImplementedError, msg): | ||||
|             c.constraint_sql(None, None) | ||||
|  | ||||
|  | ||||
| class CheckConstraintTests(TestCase): | ||||
|     def test_repr(self): | ||||
|         check = models.Q(price__gt=models.F('discounted_price')) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user