From bd7de3cb87b40b264ef9c4ecb8af59501525d9a6 Mon Sep 17 00:00:00 2001 From: Brett Haydon Date: Tue, 7 Jun 2016 10:26:24 +1000 Subject: [PATCH] [1.10.x] Fixed #26716 -- Made CurrentSiteMiddleware compatible with new-style middleware. Backport of 5e3f4c2e53d9dde0fcf9f5f33f63c15c4750019f from master --- django/contrib/sites/middleware.py | 4 +++- tests/sites_tests/tests.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/django/contrib/sites/middleware.py b/django/contrib/sites/middleware.py index fdc7923581..bc3bf20c48 100644 --- a/django/contrib/sites/middleware.py +++ b/django/contrib/sites/middleware.py @@ -1,7 +1,9 @@ +from django.utils.deprecation import MiddlewareMixin + from .shortcuts import get_current_site -class CurrentSiteMiddleware(object): +class CurrentSiteMiddleware(MiddlewareMixin): """ Middleware that sets `site` attribute to request object. """ diff --git a/tests/sites_tests/tests.py b/tests/sites_tests/tests.py index fa618b4acc..7a4215cc4a 100644 --- a/tests/sites_tests/tests.py +++ b/tests/sites_tests/tests.py @@ -11,7 +11,7 @@ from django.contrib.sites.requests import RequestSite from django.contrib.sites.shortcuts import get_current_site from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.db.models.signals import post_migrate -from django.http import HttpRequest +from django.http import HttpRequest, HttpResponse from django.test import TestCase, modify_settings, override_settings from django.test.utils import captured_stdout @@ -305,9 +305,15 @@ class CreateDefaultSiteTests(TestCase): class MiddlewareTest(TestCase): - def test_request(self): + def test_old_style_request(self): """ Makes sure that the request has correct `site` attribute. """ middleware = CurrentSiteMiddleware() request = HttpRequest() middleware.process_request(request) self.assertEqual(request.site.id, settings.SITE_ID) + + def test_request(self): + def get_response(request): + return HttpResponse(str(request.site.id)) + response = CurrentSiteMiddleware(get_response)(HttpRequest()) + self.assertContains(response, settings.SITE_ID)