Fixed registration of custom extensions

This fixes some minor problems with the pull request from jvanasco so
that it may be included into the next release:

- Only one extension registry is used internally and custom extensions
  do not have to be stored two times.
- Fix the bug that entry extensions were loaded for feeds.
- Do not fail if there is only a feed extension.
- Extensions for entries do not need a feed extension class.

Signed-off-by: Lars Kiesow <lkiesow@uos.de>
This commit is contained in:
Lars Kiesow 2016-08-28 20:54:11 +02:00
parent 1a2c032654
commit 303e74dc7a
No known key found for this signature in database
GPG key ID: 5DAFE8D9C823CE73
3 changed files with 64 additions and 83 deletions

View file

@ -641,33 +641,22 @@ class FeedEntry(object):
# Load extension # Load extension
extname = name[0].upper() + name[1:] + 'EntryExtension' extname = name[0].upper() + name[1:] + 'EntryExtension'
# Try to import extension from dedicated module for entry:
try: try:
supmod = __import__('feedgen.ext.%s_entry' % name) supmod = __import__('feedgen.ext.%s_entry' % name)
extmod = getattr(supmod.ext, name + '_entry') extmod = getattr(supmod.ext, name + '_entry')
except ImportError: except ImportError:
# Try the FeedExtension module instead # Use FeedExtension module instead
supmod = __import__('feedgen.ext.%s' % name) supmod = __import__('feedgen.ext.%s' % name)
extmod = getattr(supmod.ext, name) extmod = getattr(supmod.ext, name)
ext = getattr(extmod, extname)
self.register_extension(name, ext, atom, rss)
ext = getattr(extmod, extname)
extinst = ext()
setattr(self, name, extinst)
self.__extensions[name] = {'inst':extinst,'atom':atom,'rss':rss}
def register_extension( def register_extension(self, namespace, extension_class_entry=None,
self, atom=True, rss=True):
namespace,
extension_class_feed=None,
extension_class_entry=None,
atom=True,
rss=True
):
'''Register a specific extension by classes to a namespace. '''Register a specific extension by classes to a namespace.
:param namespace: namespace for the extension :param namespace: namespace for the extension
:param extension_class_feed: Class of the feed extension to load.
:param extension_class_entry: Class of the entry extension to load. :param extension_class_entry: Class of the entry extension to load.
:param atom: If the extension should be used for ATOM feeds. :param atom: If the extension should be used for ATOM feeds.
:param rss: If the extension should be used for RSS feeds. :param rss: If the extension should be used for RSS feeds.
@ -678,20 +667,16 @@ class FeedEntry(object):
self.__extensions = {} self.__extensions = {}
if namespace in self.__extensions.keys(): if namespace in self.__extensions.keys():
raise ImportError('Extension already loaded') raise ImportError('Extension already loaded')
if not extension_class_entry:
raise ImportError('No extension class')
extinst = extension_class_entry() extinst = extension_class_entry()
setattr(self, namespace, extinst) setattr(self, namespace, extinst)
# `load_extension` registry # `load_extension` registry
self.__extensions[namespace] = {'inst': extinst, self.__extensions[namespace] = {
'atom': atom, 'inst':extinst,
'rss': rss 'extension_class_entry': extension_class_entry,
} 'atom':atom,
'rss':rss
# `register_extension` registry }
self.__extensions_register[namespace] = {
'extension_class_feed': extension_class_feed,
'extension_class_entry': extension_class_entry,
'atom': atom,
'rss': rss,
}

View file

@ -82,7 +82,6 @@ class FeedGenerator(object):
# Extension list: # Extension list:
self.__extensions = {} self.__extensions = {}
self.__extensions_register = {}
def _create_atom(self, extensions=True): def _create_atom(self, extensions=True):
@ -1024,14 +1023,9 @@ class FeedGenerator(object):
# Try to load extensions: # Try to load extensions:
for extname,ext in items: for extname,ext in items:
try: try:
if extname in self.__extensions_register: feedEntry.register_extension(extname,
ext_reg = self.__extensions_register[extname] ext['extension_class_entry'],
feedEntry.register_extension(extname, ext['atom'], ext['rss'] )
ext_reg['extension_class_feed'],
ext_reg['extension_class_entry'],
ext_reg['atom'], ext_reg['rss'] )
else:
feedEntry.load_extension( extname, ext['atom'], ext['rss'] )
except ImportError: except ImportError:
pass pass
@ -1073,14 +1067,9 @@ class FeedGenerator(object):
for e in entry: for e in entry:
for extname,ext in items: for extname,ext in items:
try: try:
if extname in self.__extensions_register: e.register_extension(extname,
ext_reg = self.__extensions_register[extname] ext['extension_class_entry'],
e.register_extension(extname, ext['atom'], ext['rss'] )
ext_reg['extension_class_feed'],
ext_reg['extension_class_entry'],
ext_reg['atom'], ext_reg['rss'] )
else:
e.load_extension( extname, ext['atom'], ext['rss'] )
except ImportError: except ImportError:
pass pass
@ -1127,29 +1116,26 @@ class FeedGenerator(object):
raise ImportError('Extension already loaded') raise ImportError('Extension already loaded')
# Load extension # Load extension
extname = name[0].upper() + name[1:] + 'Extension' extname = name[0].upper() + name[1:]
supmod = __import__('feedgen.ext.%s' % name) feedsupmod = __import__('feedgen.ext.%s' % name)
extmod = getattr(supmod.ext, name) feedextmod = getattr(feedsupmod.ext, name)
ext = getattr(extmod, extname) try:
extinst = ext() entrysupmod = __import__('feedgen.ext.%s_entry' % name)
setattr(self, name, extinst) entryextmod = getattr(entrysupmod.ext, name + '_entry')
self.__extensions[name] = {'inst':extinst,'atom':atom,'rss':rss} except ImportError:
# Use FeedExtension module instead
entrysupmod = feedsupmod
entryextmod = feedextmod
feedext = getattr(feedextmod, extname + 'Extension')
try:
entryext = getattr(entryextmod, extname + 'EntryExtension')
except AttributeError:
entryext = None
self.register_extension(name, feedext, entryext, atom, rss)
# Try to load the extension for already existing entries:
for entry in self.__feed_entries:
try:
entry.load_extension( name, atom, rss )
except ImportError:
pass
def register_extension( def register_extension(self, namespace, extension_class_feed = None,
self, extension_class_entry = None, atom=True, rss=True):
namespace,
extension_class_feed = None,
extension_class_entry = None,
atom=True,
rss=True
):
'''Registers an extension by class. '''Registers an extension by class.
:param namespace: namespace for the extension :param namespace: namespace for the extension
@ -1166,29 +1152,22 @@ class FeedGenerator(object):
raise ImportError('Extension already loaded') raise ImportError('Extension already loaded')
# Load extension # Load extension
extinst = extension_class_entry() extinst = extension_class_feed()
setattr(self, namespace, extinst) setattr(self, namespace, extinst)
# `load_extension` registry # `load_extension` registry
self.__extensions[namespace] = {'inst':extinst, self.__extensions[namespace] = {
'atom':atom, 'inst':extinst,
'rss':rss 'extension_class_feed': extension_class_feed,
} 'extension_class_entry': extension_class_entry,
'atom':atom,
# `register_extension` registry 'rss':rss
self.__extensions_register[namespace] = { }
'extension_class_feed': extension_class_feed,
'extension_class_entry': extension_class_entry,
'atom': atom,
'rss': rss,
}
# Try to load the extension for already existing entries: # Try to load the extension for already existing entries:
for entry in self.__feed_entries: for entry in self.__feed_entries:
try: try:
entry.register_extension(namespace, entry.register_extension(namespace,
extension_class_entry, extension_class_entry, atom, rss)
extension_class_feed,
atom, rss)
except ImportError: except ImportError:
raise pass

View file

@ -10,6 +10,7 @@ A basic feed does not contain entries so far.
import unittest import unittest
from lxml import etree from lxml import etree
from ..feed import FeedGenerator from ..feed import FeedGenerator
from ..ext.dc import DcExtension, DcEntryExtension
class TestSequenceFunctions(unittest.TestCase): class TestSequenceFunctions(unittest.TestCase):
@ -234,11 +235,27 @@ class TestSequenceFunctions(unittest.TestCase):
def test_loadPodcastExtension(self): def test_loadPodcastExtension(self):
fg = self.fg fg = self.fg
fg.add_entry()
fg.load_extension('podcast', atom=True, rss=True) fg.load_extension('podcast', atom=True, rss=True)
fg.add_entry()
def test_loadDcExtension(self): def test_loadDcExtension(self):
fg = self.fg fg = self.fg
fg.add_entry()
fg.load_extension('dc', atom=True, rss=True) fg.load_extension('dc', atom=True, rss=True)
fg.add_entry()
def test_extensionAlreadyLoaded(self):
fg = self.fg
fg.load_extension('dc', atom=True, rss=True)
with self.assertRaises(ImportError) as context:
fg.load_extension('dc')
def test_registerCustomExtension(self):
fg = self.fg
fg.add_entry()
fg.register_extension('dc', DcExtension, DcEntryExtension)
fg.add_entry()
def checkRssString(self, rssString): def checkRssString(self, rssString):