diff --git a/feedgen/entry.py b/feedgen/entry.py index 18d63ca..d1286c1 100644 --- a/feedgen/entry.py +++ b/feedgen/entry.py @@ -641,33 +641,22 @@ class FeedEntry(object): # Load extension extname = name[0].upper() + name[1:] + 'EntryExtension' - - # Try to import extension from dedicated module for entry: try: supmod = __import__('feedgen.ext.%s_entry' % name) extmod = getattr(supmod.ext, name + '_entry') except ImportError: - # Try the FeedExtension module instead + # Use FeedExtension module instead supmod = __import__('feedgen.ext.%s' % name) extmod = getattr(supmod.ext, name) + ext = getattr(extmod, extname) + self.register_extension(name, ext, atom, rss) - ext = getattr(extmod, extname) - extinst = ext() - setattr(self, name, extinst) - self.__extensions[name] = {'inst':extinst,'atom':atom,'rss':rss} - def register_extension( - self, - namespace, - extension_class_feed=None, - extension_class_entry=None, - atom=True, - rss=True - ): + def register_extension(self, namespace, extension_class_entry=None, + atom=True, rss=True): '''Register a specific extension by classes to a namespace. :param namespace: namespace for the extension - :param extension_class_feed: Class of the feed extension to load. :param extension_class_entry: Class of the entry extension to load. :param atom: If the extension should be used for ATOM feeds. :param rss: If the extension should be used for RSS feeds. @@ -678,20 +667,16 @@ class FeedEntry(object): self.__extensions = {} if namespace in self.__extensions.keys(): raise ImportError('Extension already loaded') + if not extension_class_entry: + raise ImportError('No extension class') extinst = extension_class_entry() setattr(self, namespace, extinst) # `load_extension` registry - self.__extensions[namespace] = {'inst': extinst, - 'atom': atom, - 'rss': rss - } - - # `register_extension` registry - self.__extensions_register[namespace] = { - 'extension_class_feed': extension_class_feed, - 'extension_class_entry': extension_class_entry, - 'atom': atom, - 'rss': rss, - } \ No newline at end of file + self.__extensions[namespace] = { + 'inst':extinst, + 'extension_class_entry': extension_class_entry, + 'atom':atom, + 'rss':rss + } diff --git a/feedgen/feed.py b/feedgen/feed.py index e79e7d9..98da277 100644 --- a/feedgen/feed.py +++ b/feedgen/feed.py @@ -82,7 +82,6 @@ class FeedGenerator(object): # Extension list: self.__extensions = {} - self.__extensions_register = {} def _create_atom(self, extensions=True): @@ -1024,14 +1023,9 @@ class FeedGenerator(object): # Try to load extensions: for extname,ext in items: try: - if extname in self.__extensions_register: - ext_reg = self.__extensions_register[extname] - feedEntry.register_extension(extname, - ext_reg['extension_class_feed'], - ext_reg['extension_class_entry'], - ext_reg['atom'], ext_reg['rss'] ) - else: - feedEntry.load_extension( extname, ext['atom'], ext['rss'] ) + feedEntry.register_extension(extname, + ext['extension_class_entry'], + ext['atom'], ext['rss'] ) except ImportError: pass @@ -1073,14 +1067,9 @@ class FeedGenerator(object): for e in entry: for extname,ext in items: try: - if extname in self.__extensions_register: - ext_reg = self.__extensions_register[extname] - e.register_extension(extname, - ext_reg['extension_class_feed'], - ext_reg['extension_class_entry'], - ext_reg['atom'], ext_reg['rss'] ) - else: - e.load_extension( extname, ext['atom'], ext['rss'] ) + e.register_extension(extname, + ext['extension_class_entry'], + ext['atom'], ext['rss'] ) except ImportError: pass @@ -1127,29 +1116,26 @@ class FeedGenerator(object): raise ImportError('Extension already loaded') # Load extension - extname = name[0].upper() + name[1:] + 'Extension' - supmod = __import__('feedgen.ext.%s' % name) - extmod = getattr(supmod.ext, name) - ext = getattr(extmod, extname) - extinst = ext() - setattr(self, name, extinst) - self.__extensions[name] = {'inst':extinst,'atom':atom,'rss':rss} + extname = name[0].upper() + name[1:] + feedsupmod = __import__('feedgen.ext.%s' % name) + feedextmod = getattr(feedsupmod.ext, name) + try: + entrysupmod = __import__('feedgen.ext.%s_entry' % name) + entryextmod = getattr(entrysupmod.ext, name + '_entry') + except ImportError: + # Use FeedExtension module instead + entrysupmod = feedsupmod + entryextmod = feedextmod + feedext = getattr(feedextmod, extname + 'Extension') + try: + entryext = getattr(entryextmod, extname + 'EntryExtension') + except AttributeError: + entryext = None + self.register_extension(name, feedext, entryext, atom, rss) - # Try to load the extension for already existing entries: - for entry in self.__feed_entries: - try: - entry.load_extension( name, atom, rss ) - except ImportError: - pass - def register_extension( - self, - namespace, - extension_class_feed = None, - extension_class_entry = None, - atom=True, - rss=True - ): + def register_extension(self, namespace, extension_class_feed = None, + extension_class_entry = None, atom=True, rss=True): '''Registers an extension by class. :param namespace: namespace for the extension @@ -1166,29 +1152,22 @@ class FeedGenerator(object): raise ImportError('Extension already loaded') # Load extension - extinst = extension_class_entry() + extinst = extension_class_feed() setattr(self, namespace, extinst) # `load_extension` registry - self.__extensions[namespace] = {'inst':extinst, - 'atom':atom, - 'rss':rss - } - - # `register_extension` registry - self.__extensions_register[namespace] = { - 'extension_class_feed': extension_class_feed, - 'extension_class_entry': extension_class_entry, - 'atom': atom, - 'rss': rss, - } + self.__extensions[namespace] = { + 'inst':extinst, + 'extension_class_feed': extension_class_feed, + 'extension_class_entry': extension_class_entry, + 'atom':atom, + 'rss':rss + } # Try to load the extension for already existing entries: for entry in self.__feed_entries: try: entry.register_extension(namespace, - extension_class_entry, - extension_class_feed, - atom, rss) + extension_class_entry, atom, rss) except ImportError: - raise \ No newline at end of file + pass diff --git a/feedgen/tests/test_feed.py b/feedgen/tests/test_feed.py index bcfe506..83245d3 100644 --- a/feedgen/tests/test_feed.py +++ b/feedgen/tests/test_feed.py @@ -10,6 +10,7 @@ A basic feed does not contain entries so far. import unittest from lxml import etree from ..feed import FeedGenerator +from ..ext.dc import DcExtension, DcEntryExtension class TestSequenceFunctions(unittest.TestCase): @@ -234,11 +235,27 @@ class TestSequenceFunctions(unittest.TestCase): def test_loadPodcastExtension(self): fg = self.fg + fg.add_entry() fg.load_extension('podcast', atom=True, rss=True) + fg.add_entry() def test_loadDcExtension(self): fg = self.fg + fg.add_entry() fg.load_extension('dc', atom=True, rss=True) + fg.add_entry() + + def test_extensionAlreadyLoaded(self): + fg = self.fg + fg.load_extension('dc', atom=True, rss=True) + with self.assertRaises(ImportError) as context: + fg.load_extension('dc') + + def test_registerCustomExtension(self): + fg = self.fg + fg.add_entry() + fg.register_extension('dc', DcExtension, DcEntryExtension) + fg.add_entry() def checkRssString(self, rssString):