diff --git a/django/utils/feedgenerator.py b/django/utils/feedgenerator.py index da6e0a8dc8..705dd84502 100644 --- a/django/utils/feedgenerator.py +++ b/django/utils/feedgenerator.py @@ -178,13 +178,16 @@ class RssFeed(SyndicationFeed): def write(self, outfile, encoding): handler = SimplerXMLGenerator(outfile, encoding) handler.startDocument() - handler.startElement(u"rss", {u"version": self._version}) + handler.startElement(u"rss", self.rss_attributes()) handler.startElement(u"channel", self.root_attributes()) self.add_root_elements(handler) self.write_items(handler) self.endChannelElement(handler) handler.endElement(u"rss") + def rss_attributes(self): + return {u"version": self._version} + def write_items(self, handler): for item in self.items: handler.startElement(u'item', self.item_attributes(item)) @@ -266,7 +269,7 @@ class Atom1Feed(SyndicationFeed): self.write_items(handler) handler.endElement(u"feed") - def root_element_attributes(self): + def root_attributes(self): if self.feed['language'] is not None: return {u"xmlns": self.ns, u"xml:lang": self.feed['language']} else: diff --git a/tests/regressiontests/syndication/tests.py b/tests/regressiontests/syndication/tests.py index 0938f69e5b..caf5e4f04d 100644 --- a/tests/regressiontests/syndication/tests.py +++ b/tests/regressiontests/syndication/tests.py @@ -20,9 +20,19 @@ class SyndicationFeedTest(TestCase): def test_rss_feed(self): response = self.client.get('/syndication/feeds/rss/') doc = minidom.parseString(response.content) - self.assertEqual(len(doc.getElementsByTagName('channel')), 1) - - chan = doc.getElementsByTagName('channel')[0] + + # Making sure there's only 1 `rss` element and that the correct + # RSS version was specified. + feed_elem = doc.getElementsByTagName('rss') + self.assertEqual(len(feed_elem), 1) + feed = feed_elem[0] + self.assertEqual(feed.getAttribute('version'), '2.0') + + # Making sure there's only one `channel` element w/in the + # `rss` element. + chan_elem = feed.getElementsByTagName('channel') + self.assertEqual(len(chan_elem), 1) + chan = chan_elem[0] self.assertChildNodes(chan, ['title', 'link', 'description', 'language', 'lastBuildDate', 'item']) items = chan.getElementsByTagName('item') @@ -36,6 +46,7 @@ class SyndicationFeedTest(TestCase): feed = doc.firstChild self.assertEqual(feed.nodeName, 'feed') + self.assertEqual(feed.getAttribute('xmlns'), 'http://www.w3.org/2005/Atom') self.assertChildNodes(feed, ['title', 'link', 'id', 'updated', 'entry']) entries = feed.getElementsByTagName('entry')