diff --git a/django/contrib/postgres/fields/jsonb.py b/django/contrib/postgres/fields/jsonb.py
index be98ff2d48..7d88579ecd 100644
--- a/django/contrib/postgres/fields/jsonb.py
+++ b/django/contrib/postgres/fields/jsonb.py
@@ -107,7 +107,7 @@ class KeyTransform(Transform):
             previous = previous.lhs
         lhs, params = compiler.compile(previous)
         if len(key_transforms) > 1:
-            return "(%s %s %%s)" % (lhs, self.nested_operator), [key_transforms] + params
+            return '(%s %s %%s)' % (lhs, self.nested_operator), params + [key_transforms]
         try:
             lookup = int(self.key_name)
         except ValueError:
diff --git a/tests/postgres_tests/test_json.py b/tests/postgres_tests/test_json.py
index 1360bc85dc..d00f97c36e 100644
--- a/tests/postgres_tests/test_json.py
+++ b/tests/postgres_tests/test_json.py
@@ -6,7 +6,9 @@ from decimal import Decimal
 from django.core import checks, exceptions, serializers
 from django.core.serializers.json import DjangoJSONEncoder
 from django.db import connection
-from django.db.models import Count, Q
+from django.db.models import Count, F, Q
+from django.db.models.expressions import RawSQL
+from django.db.models.functions import Cast
 from django.forms import CharField, Form, widgets
 from django.test.utils import CaptureQueriesContext, isolate_apps
 from django.utils.html import escape
@@ -186,6 +188,23 @@ class TestQuerying(PostgreSQLTestCase):
             operator.itemgetter('key', 'count'),
         )
 
+    def test_nested_key_transform_raw_expression(self):
+        expr = RawSQL('%s::jsonb', ['{"x": {"y": "bar"}}'])
+        self.assertSequenceEqual(
+            JSONModel.objects.filter(field__foo=KeyTransform('y', KeyTransform('x', expr))),
+            [self.objs[-1]],
+        )
+
+    def test_nested_key_transform_expression(self):
+        self.assertSequenceEqual(
+            JSONModel.objects.filter(field__d__0__isnull=False).annotate(
+                key=KeyTransform('d', 'field'),
+                chain=KeyTransform('f', KeyTransform('1', 'key')),
+                expr=KeyTransform('f', KeyTransform('1', Cast('key', JSONField()))),
+            ).filter(chain=F('expr')),
+            [self.objs[8]],
+        )
+
     def test_deep_values(self):
         query = JSONModel.objects.values_list('field__k__l')
         self.assertSequenceEqual(