diff --git a/django/db/models/functions/text.py b/django/db/models/functions/text.py index 2b49f54328..500fbea194 100644 --- a/django/db/models/functions/text.py +++ b/django/db/models/functions/text.py @@ -73,7 +73,7 @@ class ConcatPair(Func): function = "CONCAT" - def as_sqlite(self, compiler, connection, **extra_context): + def pipes_concat_sql(self, compiler, connection, **extra_context): coalesced = self.coalesce() return super(ConcatPair, coalesced).as_sql( compiler, @@ -83,19 +83,19 @@ class ConcatPair(Func): **extra_context, ) + as_sqlite = pipes_concat_sql + def as_postgresql(self, compiler, connection, **extra_context): - copy = self.copy() - copy.set_source_expressions( + c = self.copy() + c.set_source_expressions( [ - Cast(expression, TextField()) - for expression in copy.get_source_expressions() + expression + if isinstance(expression.output_field, (CharField, TextField)) + else Cast(expression, TextField()) + for expression in c.get_source_expressions() ] ) - return super(ConcatPair, copy).as_sql( - compiler, - connection, - **extra_context, - ) + return c.pipes_concat_sql(compiler, connection, **extra_context) def as_mysql(self, compiler, connection, **extra_context): # Use CONCAT_WS with an empty separator so that NULLs are ignored. @@ -132,16 +132,20 @@ class Concat(Func): def __init__(self, *expressions, **extra): if len(expressions) < 2: raise ValueError("Concat must take at least two expressions") - paired = self._paired(expressions) + paired = self._paired(expressions, output_field=extra.get("output_field")) super().__init__(paired, **extra) - def _paired(self, expressions): + def _paired(self, expressions, output_field): # wrap pairs of expressions in successive concat functions # exp = [a, b, c, d] # -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))) if len(expressions) == 2: - return ConcatPair(*expressions) - return ConcatPair(expressions[0], self._paired(expressions[1:])) + return ConcatPair(*expressions, output_field=output_field) + return ConcatPair( + expressions[0], + self._paired(expressions[1:], output_field=output_field), + output_field=output_field, + ) class Left(Func): diff --git a/tests/db_functions/text/test_concat.py b/tests/db_functions/text/test_concat.py index 1441e31e97..6e4cb91d3a 100644 --- a/tests/db_functions/text/test_concat.py +++ b/tests/db_functions/text/test_concat.py @@ -75,7 +75,10 @@ class ConcatTests(TestCase): expected = article.title + " - " + article.text self.assertEqual(expected.upper(), article.title_text) - @skipUnless(connection.vendor == "sqlite", "sqlite specific implementation detail.") + @skipUnless( + connection.vendor in ("sqlite", "postgresql"), + "SQLite and PostgreSQL specific implementation detail.", + ) def test_coalesce_idempotent(self): pair = ConcatPair(V("a"), V("b")) # Check nodes counts @@ -89,3 +92,18 @@ class ConcatTests(TestCase): qs = Article.objects.annotate(description=Concat("title", V(": "), "summary")) # Multiple compilations should not alter the generated query. self.assertEqual(str(qs.query), str(qs.all().query)) + + def test_concat_non_str(self): + Author.objects.create(name="The Name", age=42) + with self.assertNumQueries(1) as ctx: + author = Author.objects.annotate( + name_text=Concat( + "name", V(":"), "alias", V(":"), "age", output_field=TextField() + ), + ).get() + self.assertEqual(author.name_text, "The Name::42") + # Only non-string columns are casted on PostgreSQL. + self.assertEqual( + ctx.captured_queries[0]["sql"].count("::text"), + 1 if connection.vendor == "postgresql" else 0, + )