Fixed registration of custom extensions

This fixes some minor problems with the pull request from jvanasco so
that it may be included into the next release:

- Only one extension registry is used internally and custom extensions
  do not have to be stored two times.
- Fix the bug that entry extensions were loaded for feeds.
- Do not fail if there is only a feed extension.
- Extensions for entries do not need a feed extension class.

Signed-off-by: Lars Kiesow <lkiesow@uos.de>
This commit is contained in:
Lars Kiesow 2016-08-28 20:54:11 +02:00
parent 1a2c032654
commit 303e74dc7a
No known key found for this signature in database
GPG key ID: 5DAFE8D9C823CE73
3 changed files with 64 additions and 83 deletions

View file

@ -641,33 +641,22 @@ class FeedEntry(object):
# Load extension
extname = name[0].upper() + name[1:] + 'EntryExtension'
# Try to import extension from dedicated module for entry:
try:
supmod = __import__('feedgen.ext.%s_entry' % name)
extmod = getattr(supmod.ext, name + '_entry')
except ImportError:
# Try the FeedExtension module instead
# Use FeedExtension module instead
supmod = __import__('feedgen.ext.%s' % name)
extmod = getattr(supmod.ext, name)
ext = getattr(extmod, extname)
extinst = ext()
setattr(self, name, extinst)
self.__extensions[name] = {'inst':extinst,'atom':atom,'rss':rss}
self.register_extension(name, ext, atom, rss)
def register_extension(
self,
namespace,
extension_class_feed=None,
extension_class_entry=None,
atom=True,
rss=True
):
def register_extension(self, namespace, extension_class_entry=None,
atom=True, rss=True):
'''Register a specific extension by classes to a namespace.
:param namespace: namespace for the extension
:param extension_class_feed: Class of the feed extension to load.
:param extension_class_entry: Class of the entry extension to load.
:param atom: If the extension should be used for ATOM feeds.
:param rss: If the extension should be used for RSS feeds.
@ -678,20 +667,16 @@ class FeedEntry(object):
self.__extensions = {}
if namespace in self.__extensions.keys():
raise ImportError('Extension already loaded')
if not extension_class_entry:
raise ImportError('No extension class')
extinst = extension_class_entry()
setattr(self, namespace, extinst)
# `load_extension` registry
self.__extensions[namespace] = {'inst': extinst,
self.__extensions[namespace] = {
'inst':extinst,
'extension_class_entry': extension_class_entry,
'atom':atom,
'rss':rss
}
# `register_extension` registry
self.__extensions_register[namespace] = {
'extension_class_feed': extension_class_feed,
'extension_class_entry': extension_class_entry,
'atom': atom,
'rss': rss,
}

View file

@ -82,7 +82,6 @@ class FeedGenerator(object):
# Extension list:
self.__extensions = {}
self.__extensions_register = {}
def _create_atom(self, extensions=True):
@ -1024,14 +1023,9 @@ class FeedGenerator(object):
# Try to load extensions:
for extname,ext in items:
try:
if extname in self.__extensions_register:
ext_reg = self.__extensions_register[extname]
feedEntry.register_extension(extname,
ext_reg['extension_class_feed'],
ext_reg['extension_class_entry'],
ext_reg['atom'], ext_reg['rss'] )
else:
feedEntry.load_extension( extname, ext['atom'], ext['rss'] )
ext['extension_class_entry'],
ext['atom'], ext['rss'] )
except ImportError:
pass
@ -1073,14 +1067,9 @@ class FeedGenerator(object):
for e in entry:
for extname,ext in items:
try:
if extname in self.__extensions_register:
ext_reg = self.__extensions_register[extname]
e.register_extension(extname,
ext_reg['extension_class_feed'],
ext_reg['extension_class_entry'],
ext_reg['atom'], ext_reg['rss'] )
else:
e.load_extension( extname, ext['atom'], ext['rss'] )
ext['extension_class_entry'],
ext['atom'], ext['rss'] )
except ImportError:
pass
@ -1127,29 +1116,26 @@ class FeedGenerator(object):
raise ImportError('Extension already loaded')
# Load extension
extname = name[0].upper() + name[1:] + 'Extension'
supmod = __import__('feedgen.ext.%s' % name)
extmod = getattr(supmod.ext, name)
ext = getattr(extmod, extname)
extinst = ext()
setattr(self, name, extinst)
self.__extensions[name] = {'inst':extinst,'atom':atom,'rss':rss}
# Try to load the extension for already existing entries:
for entry in self.__feed_entries:
extname = name[0].upper() + name[1:]
feedsupmod = __import__('feedgen.ext.%s' % name)
feedextmod = getattr(feedsupmod.ext, name)
try:
entry.load_extension( name, atom, rss )
entrysupmod = __import__('feedgen.ext.%s_entry' % name)
entryextmod = getattr(entrysupmod.ext, name + '_entry')
except ImportError:
pass
# Use FeedExtension module instead
entrysupmod = feedsupmod
entryextmod = feedextmod
feedext = getattr(feedextmod, extname + 'Extension')
try:
entryext = getattr(entryextmod, extname + 'EntryExtension')
except AttributeError:
entryext = None
self.register_extension(name, feedext, entryext, atom, rss)
def register_extension(
self,
namespace,
extension_class_feed = None,
extension_class_entry = None,
atom=True,
rss=True
):
def register_extension(self, namespace, extension_class_feed = None,
extension_class_entry = None, atom=True, rss=True):
'''Registers an extension by class.
:param namespace: namespace for the extension
@ -1166,29 +1152,22 @@ class FeedGenerator(object):
raise ImportError('Extension already loaded')
# Load extension
extinst = extension_class_entry()
extinst = extension_class_feed()
setattr(self, namespace, extinst)
# `load_extension` registry
self.__extensions[namespace] = {'inst':extinst,
'atom':atom,
'rss':rss
}
# `register_extension` registry
self.__extensions_register[namespace] = {
self.__extensions[namespace] = {
'inst':extinst,
'extension_class_feed': extension_class_feed,
'extension_class_entry': extension_class_entry,
'atom':atom,
'rss': rss,
'rss':rss
}
# Try to load the extension for already existing entries:
for entry in self.__feed_entries:
try:
entry.register_extension(namespace,
extension_class_entry,
extension_class_feed,
atom, rss)
extension_class_entry, atom, rss)
except ImportError:
raise
pass

View file

@ -10,6 +10,7 @@ A basic feed does not contain entries so far.
import unittest
from lxml import etree
from ..feed import FeedGenerator
from ..ext.dc import DcExtension, DcEntryExtension
class TestSequenceFunctions(unittest.TestCase):
@ -234,11 +235,27 @@ class TestSequenceFunctions(unittest.TestCase):
def test_loadPodcastExtension(self):
fg = self.fg
fg.add_entry()
fg.load_extension('podcast', atom=True, rss=True)
fg.add_entry()
def test_loadDcExtension(self):
fg = self.fg
fg.add_entry()
fg.load_extension('dc', atom=True, rss=True)
fg.add_entry()
def test_extensionAlreadyLoaded(self):
fg = self.fg
fg.load_extension('dc', atom=True, rss=True)
with self.assertRaises(ImportError) as context:
fg.load_extension('dc')
def test_registerCustomExtension(self):
fg = self.fg
fg.add_entry()
fg.register_extension('dc', DcExtension, DcEntryExtension)
fg.add_entry()
def checkRssString(self, rssString):