1
0
mirror of https://github.com/django/django.git synced 2025-01-22 00:02:15 +00:00

Fixed #24343 -- Ensure db converters are used for foreign keys.

Joint effort between myself, Josh, Anssi and Shai.
This commit is contained in:
Marc Tamlyn 2015-02-14 19:37:12 +00:00
parent dbacbc729a
commit 4755f8fc25
7 changed files with 44 additions and 15 deletions

View File

@ -585,10 +585,10 @@ class Random(ExpressionNode):
class Col(ExpressionNode):
def __init__(self, alias, target, source=None):
if source is None:
source = target
super(Col, self).__init__(output_field=source)
def __init__(self, alias, target, output_field=None):
if output_field is None:
output_field = target
super(Col, self).__init__(output_field=output_field)
self.alias, self.target = alias, target
def __repr__(self):
@ -606,7 +606,10 @@ class Col(ExpressionNode):
return [self]
def get_db_converters(self, connection):
return self.output_field.get_db_converters(connection)
if self.target == self.output_field:
return self.output_field.get_db_converters(connection)
return (self.output_field.get_db_converters(connection) +
self.target.get_db_converters(connection))
class Ref(ExpressionNode):

View File

@ -330,12 +330,12 @@ class Field(RegisterLookupMixin):
]
return []
def get_col(self, alias, source=None):
if source is None:
source = self
if alias != self.model._meta.db_table or source != self:
def get_col(self, alias, output_field=None):
if output_field is None:
output_field = self
if alias != self.model._meta.db_table or output_field != self:
from django.db.models.expressions import Col
return Col(alias, self, source)
return Col(alias, self, output_field)
else:
return self.cached_col

View File

@ -2064,6 +2064,20 @@ class ForeignKey(ForeignObject):
def db_parameters(self, connection):
return {"type": self.db_type(connection), "check": []}
def convert_empty_strings(self, value, connection, context):
if (not value) and isinstance(value, six.string_types):
return None
return value
def get_db_converters(self, connection):
converters = super(ForeignKey, self).get_db_converters(connection)
if connection.features.interprets_empty_strings_as_nulls:
converters += [self.convert_empty_strings]
return converters
def get_col(self, alias, output_field=None):
return super(ForeignKey, self).get_col(alias, output_field or self.related_field)
class OneToOneField(ForeignKey):
"""

View File

@ -57,7 +57,7 @@ class ModelIterator(BaseIterator):
model_cls = klass_info['model']
select_fields = klass_info['select_fields']
model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
init_list = [f[0].output_field.attname
init_list = [f[0].target.attname
for f in select[model_fields_start:model_fields_end]]
if len(init_list) != len(model_cls._meta.concrete_fields):
init_set = set(init_list)
@ -1618,7 +1618,7 @@ class RelatedPopulator(object):
self.cols_start = select_fields[0]
self.cols_end = select_fields[-1] + 1
self.init_list = [
f[0].output_field.attname for f in select[self.cols_start:self.cols_end]
f[0].target.attname for f in select[self.cols_start:self.cols_end]
]
self.reorder_for_init = None
else:
@ -1627,7 +1627,7 @@ class RelatedPopulator(object):
]
reorder_map = []
for idx in select_fields:
field = select[idx][0].output_field
field = select[idx][0].target
init_pos = model_init_attnames.index(field.attname)
reorder_map.append((init_pos, field.attname, idx))
reorder_map.sort()

View File

@ -1458,7 +1458,7 @@ class Query(object):
# database from tripping over IN (...,NULL,...) selects and returning
# nothing
col = query.select[0]
select_field = col.field
select_field = col.target
alias = col.alias
if self.is_nullable(select_field):
lookup_class = select_field.get_lookup('isnull')

View File

@ -369,6 +369,10 @@ class PrimaryKeyUUIDModel(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4)
class RelatedToUUIDModel(models.Model):
uuid_fk = models.ForeignKey('PrimaryKeyUUIDModel')
###############################################################################
# See ticket #24215.

View File

@ -5,7 +5,9 @@ from django.core import exceptions, serializers
from django.db import models
from django.test import TestCase
from .models import NullableUUIDModel, PrimaryKeyUUIDModel, UUIDModel
from .models import (
NullableUUIDModel, PrimaryKeyUUIDModel, RelatedToUUIDModel, UUIDModel,
)
class TestSaveLoad(TestCase):
@ -121,3 +123,9 @@ class TestAsPrimaryKey(TestCase):
self.assertTrue(u1_found)
self.assertTrue(u2_found)
self.assertEqual(PrimaryKeyUUIDModel.objects.count(), 2)
def test_underlying_field(self):
pk_model = PrimaryKeyUUIDModel.objects.create()
RelatedToUUIDModel.objects.create(uuid_fk=pk_model)
related = RelatedToUUIDModel.objects.get()
self.assertEqual(related.uuid_fk.pk, related.uuid_fk_id)