mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Fixed #24343 -- Ensure db converters are used for foreign keys.
Joint effort between myself, Josh, Anssi and Shai.
This commit is contained in:
		| @@ -585,10 +585,10 @@ class Random(ExpressionNode): | |||||||
|  |  | ||||||
|  |  | ||||||
| class Col(ExpressionNode): | class Col(ExpressionNode): | ||||||
|     def __init__(self, alias, target, source=None): |     def __init__(self, alias, target, output_field=None): | ||||||
|         if source is None: |         if output_field is None: | ||||||
|             source = target |             output_field = target | ||||||
|         super(Col, self).__init__(output_field=source) |         super(Col, self).__init__(output_field=output_field) | ||||||
|         self.alias, self.target = alias, target |         self.alias, self.target = alias, target | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
| @@ -606,7 +606,10 @@ class Col(ExpressionNode): | |||||||
|         return [self] |         return [self] | ||||||
|  |  | ||||||
|     def get_db_converters(self, connection): |     def get_db_converters(self, connection): | ||||||
|         return self.output_field.get_db_converters(connection) |         if self.target == self.output_field: | ||||||
|  |             return self.output_field.get_db_converters(connection) | ||||||
|  |         return (self.output_field.get_db_converters(connection) + | ||||||
|  |                 self.target.get_db_converters(connection)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Ref(ExpressionNode): | class Ref(ExpressionNode): | ||||||
|   | |||||||
| @@ -330,12 +330,12 @@ class Field(RegisterLookupMixin): | |||||||
|             ] |             ] | ||||||
|         return [] |         return [] | ||||||
|  |  | ||||||
|     def get_col(self, alias, source=None): |     def get_col(self, alias, output_field=None): | ||||||
|         if source is None: |         if output_field is None: | ||||||
|             source = self |             output_field = self | ||||||
|         if alias != self.model._meta.db_table or source != self: |         if alias != self.model._meta.db_table or output_field != self: | ||||||
|             from django.db.models.expressions import Col |             from django.db.models.expressions import Col | ||||||
|             return Col(alias, self, source) |             return Col(alias, self, output_field) | ||||||
|         else: |         else: | ||||||
|             return self.cached_col |             return self.cached_col | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2064,6 +2064,20 @@ class ForeignKey(ForeignObject): | |||||||
|     def db_parameters(self, connection): |     def db_parameters(self, connection): | ||||||
|         return {"type": self.db_type(connection), "check": []} |         return {"type": self.db_type(connection), "check": []} | ||||||
|  |  | ||||||
|  |     def convert_empty_strings(self, value, connection, context): | ||||||
|  |         if (not value) and isinstance(value, six.string_types): | ||||||
|  |             return None | ||||||
|  |         return value | ||||||
|  |  | ||||||
|  |     def get_db_converters(self, connection): | ||||||
|  |         converters = super(ForeignKey, self).get_db_converters(connection) | ||||||
|  |         if connection.features.interprets_empty_strings_as_nulls: | ||||||
|  |             converters += [self.convert_empty_strings] | ||||||
|  |         return converters | ||||||
|  |  | ||||||
|  |     def get_col(self, alias, output_field=None): | ||||||
|  |         return super(ForeignKey, self).get_col(alias, output_field or self.related_field) | ||||||
|  |  | ||||||
|  |  | ||||||
| class OneToOneField(ForeignKey): | class OneToOneField(ForeignKey): | ||||||
|     """ |     """ | ||||||
|   | |||||||
| @@ -57,7 +57,7 @@ class ModelIterator(BaseIterator): | |||||||
|         model_cls = klass_info['model'] |         model_cls = klass_info['model'] | ||||||
|         select_fields = klass_info['select_fields'] |         select_fields = klass_info['select_fields'] | ||||||
|         model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1 |         model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1 | ||||||
|         init_list = [f[0].output_field.attname |         init_list = [f[0].target.attname | ||||||
|                      for f in select[model_fields_start:model_fields_end]] |                      for f in select[model_fields_start:model_fields_end]] | ||||||
|         if len(init_list) != len(model_cls._meta.concrete_fields): |         if len(init_list) != len(model_cls._meta.concrete_fields): | ||||||
|             init_set = set(init_list) |             init_set = set(init_list) | ||||||
| @@ -1618,7 +1618,7 @@ class RelatedPopulator(object): | |||||||
|             self.cols_start = select_fields[0] |             self.cols_start = select_fields[0] | ||||||
|             self.cols_end = select_fields[-1] + 1 |             self.cols_end = select_fields[-1] + 1 | ||||||
|             self.init_list = [ |             self.init_list = [ | ||||||
|                 f[0].output_field.attname for f in select[self.cols_start:self.cols_end] |                 f[0].target.attname for f in select[self.cols_start:self.cols_end] | ||||||
|             ] |             ] | ||||||
|             self.reorder_for_init = None |             self.reorder_for_init = None | ||||||
|         else: |         else: | ||||||
| @@ -1627,7 +1627,7 @@ class RelatedPopulator(object): | |||||||
|             ] |             ] | ||||||
|             reorder_map = [] |             reorder_map = [] | ||||||
|             for idx in select_fields: |             for idx in select_fields: | ||||||
|                 field = select[idx][0].output_field |                 field = select[idx][0].target | ||||||
|                 init_pos = model_init_attnames.index(field.attname) |                 init_pos = model_init_attnames.index(field.attname) | ||||||
|                 reorder_map.append((init_pos, field.attname, idx)) |                 reorder_map.append((init_pos, field.attname, idx)) | ||||||
|             reorder_map.sort() |             reorder_map.sort() | ||||||
|   | |||||||
| @@ -1458,7 +1458,7 @@ class Query(object): | |||||||
|         # database from tripping over IN (...,NULL,...) selects and returning |         # database from tripping over IN (...,NULL,...) selects and returning | ||||||
|         # nothing |         # nothing | ||||||
|         col = query.select[0] |         col = query.select[0] | ||||||
|         select_field = col.field |         select_field = col.target | ||||||
|         alias = col.alias |         alias = col.alias | ||||||
|         if self.is_nullable(select_field): |         if self.is_nullable(select_field): | ||||||
|             lookup_class = select_field.get_lookup('isnull') |             lookup_class = select_field.get_lookup('isnull') | ||||||
|   | |||||||
| @@ -369,6 +369,10 @@ class PrimaryKeyUUIDModel(models.Model): | |||||||
|     id = models.UUIDField(primary_key=True, default=uuid.uuid4) |     id = models.UUIDField(primary_key=True, default=uuid.uuid4) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RelatedToUUIDModel(models.Model): | ||||||
|  |     uuid_fk = models.ForeignKey('PrimaryKeyUUIDModel') | ||||||
|  |  | ||||||
|  |  | ||||||
| ############################################################################### | ############################################################################### | ||||||
|  |  | ||||||
| # See ticket #24215. | # See ticket #24215. | ||||||
|   | |||||||
| @@ -5,7 +5,9 @@ from django.core import exceptions, serializers | |||||||
| from django.db import models | from django.db import models | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  |  | ||||||
| from .models import NullableUUIDModel, PrimaryKeyUUIDModel, UUIDModel | from .models import ( | ||||||
|  |     NullableUUIDModel, PrimaryKeyUUIDModel, RelatedToUUIDModel, UUIDModel, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestSaveLoad(TestCase): | class TestSaveLoad(TestCase): | ||||||
| @@ -121,3 +123,9 @@ class TestAsPrimaryKey(TestCase): | |||||||
|         self.assertTrue(u1_found) |         self.assertTrue(u1_found) | ||||||
|         self.assertTrue(u2_found) |         self.assertTrue(u2_found) | ||||||
|         self.assertEqual(PrimaryKeyUUIDModel.objects.count(), 2) |         self.assertEqual(PrimaryKeyUUIDModel.objects.count(), 2) | ||||||
|  |  | ||||||
|  |     def test_underlying_field(self): | ||||||
|  |         pk_model = PrimaryKeyUUIDModel.objects.create() | ||||||
|  |         RelatedToUUIDModel.objects.create(uuid_fk=pk_model) | ||||||
|  |         related = RelatedToUUIDModel.objects.get() | ||||||
|  |         self.assertEqual(related.uuid_fk.pk, related.uuid_fk_id) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user