diff --git a/django/template/context.py b/django/template/context.py
index 1ef7e889bc..3830c1c660 100644
--- a/django/template/context.py
+++ b/django/template/context.py
@@ -12,6 +12,21 @@ class ContextPopException(Exception):
     "pop() has been called more times than push()"
     pass
 
+
+class ContextDict(dict):
+    def __init__(self, context, *args, **kwargs):
+        super(ContextDict, self).__init__(*args, **kwargs)
+
+        context.dicts.append(self)
+        self.context = context
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        self.context.pop()
+
+
 class BaseContext(object):
     def __init__(self, dict_=None):
         self._reset_dicts(dict_)
@@ -34,10 +49,8 @@ class BaseContext(object):
         for d in reversed(self.dicts):
             yield d
 
-    def push(self):
-        d = {}
-        self.dicts.append(d)
-        return d
+    def push(self, *args, **kwargs):
+        return ContextDict(self, *args, **kwargs)
 
     def pop(self):
         if len(self.dicts) == 1:
@@ -83,6 +96,7 @@ class BaseContext(object):
         new_context._reset_dicts(values)
         return new_context
 
+
 class Context(BaseContext):
     "A stack container for variable context"
     def __init__(self, dict_=None, autoescape=True, current_app=None,
@@ -106,6 +120,7 @@ class Context(BaseContext):
         self.dicts.append(other_dict)
         return other_dict
 
+
 class RenderContext(BaseContext):
     """
     A stack container for storing Template state.
diff --git a/django/template/defaulttags.py b/django/template/defaulttags.py
index 5b2a1b9501..5c9490f749 100644
--- a/django/template/defaulttags.py
+++ b/django/template/defaulttags.py
@@ -95,10 +95,9 @@ class FilterNode(Node):
     def render(self, context):
         output = self.nodelist.render(context)
         # Apply filters.
-        context.update({'var': output})
-        filtered = self.filter_expr.resolve(context)
-        context.pop()
-        return filtered
+        with context.push(var=output):
+            return self.filter_expr.resolve(context)
+
 
 class FirstOfNode(Node):
     def __init__(self, variables, escape=False):
@@ -143,71 +142,69 @@ class ForNode(Node):
             parentloop = context['forloop']
         else:
             parentloop = {}
-        context.push()
-        try:
-            values = self.sequence.resolve(context, True)
-        except VariableDoesNotExist:
-            values = []
-        if values is None:
-            values = []
-        if not hasattr(values, '__len__'):
-            values = list(values)
-        len_values = len(values)
-        if len_values < 1:
-            context.pop()
-            return self.nodelist_empty.render(context)
-        nodelist = NodeList()
-        if self.is_reversed:
-            values = reversed(values)
-        unpack = len(self.loopvars) > 1
-        # Create a forloop value in the context.  We'll update counters on each
-        # iteration just below.
-        loop_dict = context['forloop'] = {'parentloop': parentloop}
-        for i, item in enumerate(values):
-            # Shortcuts for current loop iteration number.
-            loop_dict['counter0'] = i
-            loop_dict['counter'] = i+1
-            # Reverse counter iteration numbers.
-            loop_dict['revcounter'] = len_values - i
-            loop_dict['revcounter0'] = len_values - i - 1
-            # Boolean values designating first and last times through loop.
-            loop_dict['first'] = (i == 0)
-            loop_dict['last'] = (i == len_values - 1)
+        with context.push():
+            try:
+                values = self.sequence.resolve(context, True)
+            except VariableDoesNotExist:
+                values = []
+            if values is None:
+                values = []
+            if not hasattr(values, '__len__'):
+                values = list(values)
+            len_values = len(values)
+            if len_values < 1:
+                return self.nodelist_empty.render(context)
+            nodelist = NodeList()
+            if self.is_reversed:
+                values = reversed(values)
+            unpack = len(self.loopvars) > 1
+            # Create a forloop value in the context.  We'll update counters on each
+            # iteration just below.
+            loop_dict = context['forloop'] = {'parentloop': parentloop}
+            for i, item in enumerate(values):
+                # Shortcuts for current loop iteration number.
+                loop_dict['counter0'] = i
+                loop_dict['counter'] = i+1
+                # Reverse counter iteration numbers.
+                loop_dict['revcounter'] = len_values - i
+                loop_dict['revcounter0'] = len_values - i - 1
+                # Boolean values designating first and last times through loop.
+                loop_dict['first'] = (i == 0)
+                loop_dict['last'] = (i == len_values - 1)
 
-            pop_context = False
-            if unpack:
-                # If there are multiple loop variables, unpack the item into
-                # them.
-                try:
-                    unpacked_vars = dict(zip(self.loopvars, item))
-                except TypeError:
-                    pass
-                else:
-                    pop_context = True
-                    context.update(unpacked_vars)
-            else:
-                context[self.loopvars[0]] = item
-            # In TEMPLATE_DEBUG mode provide source of the node which
-            # actually raised the exception
-            if settings.TEMPLATE_DEBUG:
-                for node in self.nodelist_loop:
+                pop_context = False
+                if unpack:
+                    # If there are multiple loop variables, unpack the item into
+                    # them.
                     try:
+                        unpacked_vars = dict(zip(self.loopvars, item))
+                    except TypeError:
+                        pass
+                    else:
+                        pop_context = True
+                        context.update(unpacked_vars)
+                else:
+                    context[self.loopvars[0]] = item
+                # In TEMPLATE_DEBUG mode provide source of the node which
+                # actually raised the exception
+                if settings.TEMPLATE_DEBUG:
+                    for node in self.nodelist_loop:
+                        try:
+                            nodelist.append(node.render(context))
+                        except Exception as e:
+                            if not hasattr(e, 'django_template_source'):
+                                e.django_template_source = node.source
+                            raise
+                else:
+                    for node in self.nodelist_loop:
                         nodelist.append(node.render(context))
-                    except Exception as e:
-                        if not hasattr(e, 'django_template_source'):
-                            e.django_template_source = node.source
-                        raise
-            else:
-                for node in self.nodelist_loop:
-                    nodelist.append(node.render(context))
-            if pop_context:
-                # The loop variables were pushed on to the context so pop them
-                # off again. This is necessary because the tag lets the length
-                # of loopvars differ to the length of each set of items and we
-                # don't want to leave any vars from the previous loop on the
-                # context.
-                context.pop()
-        context.pop()
+                if pop_context:
+                    # The loop variables were pushed on to the context so pop them
+                    # off again. This is necessary because the tag lets the length
+                    # of loopvars differ to the length of each set of items and we
+                    # don't want to leave any vars from the previous loop on the
+                    # context.
+                    context.pop()
         return nodelist.render(context)
 
 class IfChangedNode(Node):
@@ -500,10 +497,9 @@ class WithNode(Node):
     def render(self, context):
         values = dict([(key, val.resolve(context)) for key, val in
                        six.iteritems(self.extra_context)])
-        context.update(values)
-        output = self.nodelist.render(context)
-        context.pop()
-        return output
+        with context.push(**values):
+            return self.nodelist.render(context)
+
 
 @register.tag
 def autoescape(parser, token):
diff --git a/django/template/loader.py b/django/template/loader.py
index 6df4e43c4f..44b8f600fb 100644
--- a/django/template/loader.py
+++ b/django/template/loader.py
@@ -164,11 +164,8 @@ def render_to_string(template_name, dictionary=None, context_instance=None):
         return t.render(Context(dictionary))
     # Add the dictionary to the context stack, ensuring it gets removed again
     # to keep the context_instance in the same state it started in.
-    context_instance.update(dictionary)
-    try:
+    with context_instance.push(dictionary):
         return t.render(context_instance)
-    finally:
-        context_instance.pop()
 
 def select_template(template_name_list):
     "Given a list of template names, returns the first that can be loaded."
diff --git a/django/template/loader_tags.py b/django/template/loader_tags.py
index 767f0e5ff8..406775da9d 100644
--- a/django/template/loader_tags.py
+++ b/django/template/loader_tags.py
@@ -47,22 +47,21 @@ class BlockNode(Node):
 
     def render(self, context):
         block_context = context.render_context.get(BLOCK_CONTEXT_KEY)
-        context.push()
-        if block_context is None:
-            context['block'] = self
-            result = self.nodelist.render(context)
-        else:
-            push = block = block_context.pop(self.name)
-            if block is None:
-                block = self
-            # Create new block so we can store context without thread-safety issues.
-            block = BlockNode(block.name, block.nodelist)
-            block.context = context
-            context['block'] = block
-            result = block.nodelist.render(context)
-            if push is not None:
-                block_context.push(self.name, push)
-        context.pop()
+        with context.push():
+            if block_context is None:
+                context['block'] = self
+                result = self.nodelist.render(context)
+            else:
+                push = block = block_context.pop(self.name)
+                if block is None:
+                    block = self
+                # Create new block so we can store context without thread-safety issues.
+                block = BlockNode(block.name, block.nodelist)
+                block.context = context
+                context['block'] = block
+                result = block.nodelist.render(context)
+                if push is not None:
+                    block_context.push(self.name, push)
         return result
 
     def super(self):
@@ -133,10 +132,9 @@ class BaseIncludeNode(Node):
                        in six.iteritems(self.extra_context)])
         if self.isolated_context:
             return template.render(context.new(values))
-        context.update(values)
-        output = template.render(context)
-        context.pop()
-        return output
+        with context.push(**values):
+            return template.render(context)
+
 
 class ConstantIncludeNode(BaseIncludeNode):
     def __init__(self, template_path, *args, **kwargs):
diff --git a/docs/ref/templates/api.txt b/docs/ref/templates/api.txt
index 6a9efc0811..f7dd0121d1 100644
--- a/docs/ref/templates/api.txt
+++ b/docs/ref/templates/api.txt
@@ -325,6 +325,31 @@ If you ``pop()`` too much, it'll raise
     ...
     django.template.ContextPopException
 
+.. versionadded:: 1.7
+
+You can also use ``push()`` as a context manager to ensure a matching ``pop()``
+is called.
+
+    >>> c = Context()
+    >>> c['foo'] = 'first level'
+    >>> with c.push():
+    >>>     c['foo'] = 'second level'
+    >>>     c['foo']
+    'second level'
+    >>> c['foo']
+    'first level'
+
+All arguments passed to ``push()`` will be passed to the ``dict`` constructor
+used to build the new context level.
+
+    >>> c = Context()
+    >>> c['foo'] = 'first level'
+    >>> with c.push(foo='second level'):
+    >>>     c['foo']
+    'second level'
+    >>> c['foo']
+    'first level'
+
 .. method:: update(other_dict)
 
 In addition to ``push()`` and ``pop()``, the ``Context``
diff --git a/docs/releases/1.7.txt b/docs/releases/1.7.txt
index ae2a80d7a3..f551828455 100644
--- a/docs/releases/1.7.txt
+++ b/docs/releases/1.7.txt
@@ -60,6 +60,13 @@ Minor features
 * :attr:`~django.db.models.Options.app_label` is no longer required for models
   that are defined in a ``models`` package within an app.
 
+* The :meth:`Context.push() <django.template.Context.push>` method now returns
+  a context manager which automatically calls :meth:`pop()
+  <django.template.Context.pop>` upon exiting the ``with`` statement.
+  Additionally, :meth:`push() <django.template.Context.push>` now accepts
+  parameters that are passed to the ``dict`` constructor used to build the new
+  context level.
+
 Backwards incompatible changes in 1.7
 =====================================
 
diff --git a/tests/template_tests/test_context.py b/tests/template_tests/test_context.py
index 224b94d060..ca167a73f3 100644
--- a/tests/template_tests/test_context.py
+++ b/tests/template_tests/test_context.py
@@ -16,3 +16,12 @@ class ContextTests(TestCase):
         self.assertEqual(c.pop(), {"a": 2})
         self.assertEqual(c["a"], 1)
         self.assertEqual(c.get("foo", 42), 42)
+
+        with c.push():
+            c['a'] = 2
+            self.assertEqual(c['a'], 2)
+        self.assertEqual(c['a'], 1)
+
+        with c.push(a=3):
+            self.assertEqual(c['a'], 3)
+        self.assertEqual(c['a'], 1)