mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Refs #29641 -- Extracted reusable CheckConstraint logic into a base class.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							9142bebff2
						
					
				
				
					commit
					24dc7d8940
				
			| @@ -3,22 +3,12 @@ from django.db.models.sql.query import Query | |||||||
| __all__ = ['CheckConstraint'] | __all__ = ['CheckConstraint'] | ||||||
|  |  | ||||||
|  |  | ||||||
| class CheckConstraint: | class BaseConstraint: | ||||||
|     def __init__(self, *, check, name): |     def __init__(self, name): | ||||||
|         self.check = check |  | ||||||
|         self.name = name |         self.name = name | ||||||
|  |  | ||||||
|     def constraint_sql(self, model, schema_editor): |     def constraint_sql(self, model, schema_editor): | ||||||
|         query = Query(model) |         raise NotImplementedError('This method must be implemented by a subclass.') | ||||||
|         where = query.build_where(self.check) |  | ||||||
|         connection = schema_editor.connection |  | ||||||
|         compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default') |  | ||||||
|         sql, params = where.as_sql(compiler, connection) |  | ||||||
|         params = tuple(schema_editor.quote_value(p) for p in params) |  | ||||||
|         return schema_editor.sql_check % { |  | ||||||
|             'name': schema_editor.quote_name(self.name), |  | ||||||
|             'check': sql % params, |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|     def create_sql(self, model, schema_editor): |     def create_sql(self, model, schema_editor): | ||||||
|         sql = self.constraint_sql(model, schema_editor) |         sql = self.constraint_sql(model, schema_editor) | ||||||
| @@ -34,6 +24,33 @@ class CheckConstraint: | |||||||
|             'name': quote_name(self.name), |             'name': quote_name(self.name), | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |     def deconstruct(self): | ||||||
|  |         path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__) | ||||||
|  |         path = path.replace('django.db.models.constraints', 'django.db.models') | ||||||
|  |         return (path, (), {'name': self.name}) | ||||||
|  |  | ||||||
|  |     def clone(self): | ||||||
|  |         _, args, kwargs = self.deconstruct() | ||||||
|  |         return self.__class__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CheckConstraint(BaseConstraint): | ||||||
|  |     def __init__(self, *, check, name): | ||||||
|  |         self.check = check | ||||||
|  |         super().__init__(name) | ||||||
|  |  | ||||||
|  |     def constraint_sql(self, model, schema_editor): | ||||||
|  |         query = Query(model) | ||||||
|  |         where = query.build_where(self.check) | ||||||
|  |         connection = schema_editor.connection | ||||||
|  |         compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default') | ||||||
|  |         sql, params = where.as_sql(compiler, connection) | ||||||
|  |         params = tuple(schema_editor.quote_value(p) for p in params) | ||||||
|  |         return schema_editor.sql_check % { | ||||||
|  |             'name': schema_editor.quote_name(self.name), | ||||||
|  |             'check': sql % params, | ||||||
|  |         } | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name) |         return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name) | ||||||
|  |  | ||||||
| @@ -45,10 +62,6 @@ class CheckConstraint: | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def deconstruct(self): |     def deconstruct(self): | ||||||
|         path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__) |         path, args, kwargs = super().deconstruct() | ||||||
|         path = path.replace('django.db.models.constraints', 'django.db.models') |         kwargs['check'] = self.check | ||||||
|         return (path, (), {'check': self.check, 'name': self.name}) |         return path, args, kwargs | ||||||
|  |  | ||||||
|     def clone(self): |  | ||||||
|         _, args, kwargs = self.deconstruct() |  | ||||||
|         return self.__class__(*args, **kwargs) |  | ||||||
|   | |||||||
| @@ -1,9 +1,18 @@ | |||||||
| from django.db import IntegrityError, models | from django.db import IntegrityError, models | ||||||
| from django.test import TestCase, skipUnlessDBFeature | from django.db.models.constraints import BaseConstraint | ||||||
|  | from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature | ||||||
|  |  | ||||||
| from .models import Product | from .models import Product | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BaseConstraintTests(SimpleTestCase): | ||||||
|  |     def test_constraint_sql(self): | ||||||
|  |         c = BaseConstraint('name') | ||||||
|  |         msg = 'This method must be implemented by a subclass.' | ||||||
|  |         with self.assertRaisesMessage(NotImplementedError, msg): | ||||||
|  |             c.constraint_sql(None, None) | ||||||
|  |  | ||||||
|  |  | ||||||
| class CheckConstraintTests(TestCase): | class CheckConstraintTests(TestCase): | ||||||
|     def test_repr(self): |     def test_repr(self): | ||||||
|         check = models.Q(price__gt=models.F('discounted_price')) |         check = models.Q(price__gt=models.F('discounted_price')) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user