1
0
mirror of https://github.com/django/django.git synced 2025-10-25 14:46:09 +00:00

Fixed #34838 -- Corrected output_field of resolved columns for GeneratedFields.

Thanks Simon Charette for the implementation idea.
This commit is contained in:
Paolo Melchiorre
2023-09-13 22:11:08 +02:00
committed by Mariusz Felisiak
parent 969ecb8236
commit 68d769e691
2 changed files with 47 additions and 1 deletions

View File

@@ -1,6 +1,7 @@
from django.core import checks from django.core import checks
from django.db import connections, router from django.db import connections, router
from django.db.models.sql import Query from django.db.models.sql import Query
from django.utils.functional import cached_property
from . import NOT_PROVIDED, Field from . import NOT_PROVIDED, Field
@@ -32,6 +33,17 @@ class GeneratedField(Field):
self.db_persist = db_persist self.db_persist = db_persist
super().__init__(**kwargs) super().__init__(**kwargs)
@cached_property
def cached_col(self):
from django.db.models.expressions import Col
return Col(self.model._meta.db_table, self, self.output_field)
def get_col(self, alias, output_field=None):
if alias != self.model._meta.db_table and output_field is None:
output_field = self.output_field
return super().get_col(alias, output_field)
def contribute_to_class(self, *args, **kwargs): def contribute_to_class(self, *args, **kwargs):
super().contribute_to_class(*args, **kwargs) super().contribute_to_class(*args, **kwargs)

View File

@@ -1,6 +1,6 @@
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import IntegrityError, connection from django.db import IntegrityError, connection
from django.db.models import F, GeneratedField, IntegerField from django.db.models import F, FloatField, GeneratedField, IntegerField, Model
from django.db.models.functions import Lower from django.db.models.functions import Lower
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
@@ -49,6 +49,40 @@ class BaseGeneratedFieldTests(SimpleTestCase):
self.assertEqual(args, []) self.assertEqual(args, [])
self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")}) self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")})
def test_get_col(self):
class Square(Model):
side = IntegerField()
area = GeneratedField(expression=F("side") * F("side"), db_persist=True)
col = Square._meta.get_field("area").get_col("alias")
self.assertIsInstance(col.output_field, IntegerField)
class FloatSquare(Model):
side = IntegerField()
area = GeneratedField(
expression=F("side") * F("side"),
db_persist=True,
output_field=FloatField(),
)
col = FloatSquare._meta.get_field("area").get_col("alias")
self.assertIsInstance(col.output_field, FloatField)
def test_cached_col(self):
class Sum(Model):
a = IntegerField()
b = IntegerField()
total = GeneratedField(expression=F("a") + F("b"), db_persist=True)
field = Sum._meta.get_field("total")
cached_col = field.cached_col
self.assertIs(field.get_col(Sum._meta.db_table), cached_col)
self.assertIs(field.get_col(Sum._meta.db_table, field), cached_col)
self.assertIsNot(field.get_col("alias"), cached_col)
self.assertIsNot(field.get_col(Sum._meta.db_table, IntegerField()), cached_col)
self.assertIs(cached_col.target, field)
self.assertIsInstance(cached_col.output_field, IntegerField)
class GeneratedFieldTestMixin: class GeneratedFieldTestMixin:
def _refresh_if_needed(self, m): def _refresh_if_needed(self, m):