From c76200a0bfbbec747fa420cc6ef21868dcc3bf69 Mon Sep 17 00:00:00 2001
From: Jannis Leidel <jannis@leidel.info>
Date: Sat, 3 Mar 2012 01:06:37 +0000
Subject: [PATCH] Fixed #17819 -- Convinced the NamedUrlWizardView to stop
 dropping files when stepping through the forms.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@17634 bcc190cf-cafb-0310-a4f2-bffc1f526a37
---
 .../tests/wizard/namedwizardtests/forms.py    | 10 ++++++
 .../tests/wizard/namedwizardtests/tests.py    | 34 +++++++++++++++----
 django/contrib/formtools/wizard/views.py      |  4 +--
 3 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/django/contrib/formtools/tests/wizard/namedwizardtests/forms.py b/django/contrib/formtools/tests/wizard/namedwizardtests/forms.py
index ae981269f8..39e914d05d 100644
--- a/django/contrib/formtools/tests/wizard/namedwizardtests/forms.py
+++ b/django/contrib/formtools/tests/wizard/namedwizardtests/forms.py
@@ -1,4 +1,8 @@
+import os
+import tempfile
+
 from django import forms
+from django.core.files.storage import FileSystemStorage
 from django.forms.formsets import formset_factory
 from django.http import HttpResponse
 from django.template import Template, Context
@@ -7,6 +11,9 @@ from django.contrib.auth.models import User
 
 from django.contrib.formtools.wizard.views import NamedUrlWizardView
 
+temp_storage_location = tempfile.mkdtemp(dir=os.environ.get('DJANGO_TEST_TEMP_DIR'))
+temp_storage = FileSystemStorage(location=temp_storage_location)
+
 class Page1(forms.Form):
     name = forms.CharField(max_length=100)
     user = forms.ModelChoiceField(queryset=User.objects.all())
@@ -15,6 +22,7 @@ class Page1(forms.Form):
 class Page2(forms.Form):
     address1 = forms.CharField(max_length=100)
     address2 = forms.CharField(max_length=100)
+    file1 = forms.FileField()
 
 class Page3(forms.Form):
     random_crap = forms.CharField(max_length=100)
@@ -22,6 +30,8 @@ class Page3(forms.Form):
 Page4 = formset_factory(Page3, extra=2)
 
 class ContactWizard(NamedUrlWizardView):
+    file_storage = temp_storage
+
     def done(self, form_list, **kwargs):
         c = Context({
             'form_list': [x.cleaned_data for x in form_list],
diff --git a/django/contrib/formtools/tests/wizard/namedwizardtests/tests.py b/django/contrib/formtools/tests/wizard/namedwizardtests/tests.py
index 1238c3e618..550622180f 100644
--- a/django/contrib/formtools/tests/wizard/namedwizardtests/tests.py
+++ b/django/contrib/formtools/tests/wizard/namedwizardtests/tests.py
@@ -8,6 +8,7 @@ from django.contrib.formtools.wizard.views import (NamedUrlSessionWizardView,
                                                    NamedUrlCookieWizardView)
 from django.contrib.formtools.tests.wizard.forms import get_request, Step1, Step2
 
+
 class NamedWizardTests(object):
     urls = 'django.contrib.formtools.tests.wizard.namedwizardtests.urls'
 
@@ -30,7 +31,6 @@ class NamedWizardTests(object):
         self.assertEqual(wizard['steps'].count, 4)
         self.assertEqual(wizard['url_name'], self.wizard_urlname)
 
-
     def test_initial_call_with_params(self):
         get_params = {'getvar1': 'getval1', 'getvar2': 'getval2'}
         response = self.client.get(reverse('%s_start' % self.wizard_urlname),
@@ -119,10 +119,12 @@ class NamedWizardTests(object):
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.context['wizard']['steps'].current, 'form2')
 
+        post_data = self.wizard_step_data[1]
+        post_data['form2-file1'] = open(__file__)
         response = self.client.post(
             reverse(self.wizard_urlname,
                     kwargs={'step': response.context['wizard']['steps'].current}),
-            self.wizard_step_data[1])
+            post_data)
         response = self.client.get(response['Location'])
 
         self.assertEqual(response.status_code, 200)
@@ -144,7 +146,10 @@ class NamedWizardTests(object):
         response = self.client.get(response['Location'])
         self.assertEqual(response.status_code, 200)
 
-        self.assertEqual(response.context['form_list'], [
+        all_data = response.context['form_list']
+        self.assertEqual(all_data[1]['file1'].read(), open(__file__).read())
+        del all_data[1]['file1']
+        self.assertEqual(all_data, [
             {'name': u'Pony', 'thirsty': True, 'user': self.testuser},
             {'address1': u'123 Main St', 'address2': u'Djangoland'},
             {'random_crap': u'blah blah'},
@@ -162,13 +167,21 @@ class NamedWizardTests(object):
         response = self.client.get(response['Location'])
         self.assertEqual(response.status_code, 200)
 
+        post_data = self.wizard_step_data[1]
+        post_data['form2-file1'] = open(__file__)
         response = self.client.post(
             reverse(self.wizard_urlname,
                     kwargs={'step': response.context['wizard']['steps'].current}),
-            self.wizard_step_data[1])
+            post_data)
         response = self.client.get(response['Location'])
         self.assertEqual(response.status_code, 200)
 
+        step2_url = reverse(self.wizard_urlname, kwargs={'step': 'form2'})
+        response = self.client.get(step2_url)
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.context['wizard']['steps'].current, 'form2')
+        self.assertEqual(response.context['wizard']['form'].files['form2-file1'].read(), open(__file__).read())
+
         response = self.client.post(
             reverse(self.wizard_urlname,
                     kwargs={'step': response.context['wizard']['steps'].current}),
@@ -183,8 +196,11 @@ class NamedWizardTests(object):
         response = self.client.get(response['Location'])
         self.assertEqual(response.status_code, 200)
 
+        all_data = response.context['all_cleaned_data']
+        self.assertEqual(all_data['file1'].read(), open(__file__).read())
+        del all_data['file1']
         self.assertEqual(
-            response.context['all_cleaned_data'],
+            all_data,
             {'name': u'Pony', 'thirsty': True, 'user': self.testuser,
              'address1': u'123 Main St', 'address2': u'Djangoland',
              'random_crap': u'blah blah', 'formset-form4': [
@@ -204,10 +220,12 @@ class NamedWizardTests(object):
         response = self.client.get(response['Location'])
         self.assertEqual(response.status_code, 200)
 
+        post_data = self.wizard_step_data[1]
+        post_data['form2-file1'] = open(__file__)
         response = self.client.post(
             reverse(self.wizard_urlname,
                     kwargs={'step': response.context['wizard']['steps'].current}),
-            self.wizard_step_data[1])
+            post_data)
         response = self.client.get(response['Location'])
         self.assertEqual(response.status_code, 200)
 
@@ -246,6 +264,7 @@ class NamedWizardTests(object):
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.context['wizard']['steps'].current, 'form1')
 
+
 class NamedSessionWizardTests(NamedWizardTests, TestCase):
     wizard_urlname = 'nwiz_session'
     wizard_step_1_data = {
@@ -276,6 +295,7 @@ class NamedSessionWizardTests(NamedWizardTests, TestCase):
         }
     )
 
+
 class NamedCookieWizardTests(NamedWizardTests, TestCase):
     wizard_urlname = 'nwiz_cookie'
     wizard_step_1_data = {
@@ -321,12 +341,14 @@ class NamedFormTests(object):
         instance.render_done(None)
         self.assertEqual(instance.storage.current_step, 'start')
 
+
 class TestNamedUrlSessionWizardView(NamedUrlSessionWizardView):
 
     def dispatch(self, request, *args, **kwargs):
         response = super(TestNamedUrlSessionWizardView, self).dispatch(request, *args, **kwargs)
         return response, self
 
+
 class TestNamedUrlCookieWizardView(NamedUrlCookieWizardView):
 
     def dispatch(self, request, *args, **kwargs):
diff --git a/django/contrib/formtools/wizard/views.py b/django/contrib/formtools/wizard/views.py
index 4104eaf50b..06a03984a7 100644
--- a/django/contrib/formtools/wizard/views.py
+++ b/django/contrib/formtools/wizard/views.py
@@ -624,14 +624,14 @@ class NamedUrlWizardView(WizardView):
             # URL step name and storage step name are equal, render!
             return self.render(self.get_form(
                 data=self.storage.current_step_data,
-                files=self.storage.current_step_data,
+                files=self.storage.current_step_files,
             ), **kwargs)
 
         elif step_url in self.get_form_list():
             self.storage.current_step = step_url
             return self.render(self.get_form(
                 data=self.storage.current_step_data,
-                files=self.storage.current_step_data,
+                files=self.storage.current_step_files,
             ), **kwargs)
 
         # invalid step name, reset to first and redirect.