diff --git a/feedgen/__main__.py b/feedgen/__main__.py index f072895..2cf1346 100644 --- a/feedgen/__main__.py +++ b/feedgen/__main__.py @@ -44,7 +44,7 @@ def print_enc(s): print(s) -if __name__ == '__main__': +def main(): if len(sys.argv) != 2 or not ( sys.argv[1].endswith('rss') or sys.argv[1].endswith('atom') or @@ -138,3 +138,7 @@ if __name__ == '__main__': elif arg.endswith('rss'): fg.rss_file(arg) + + +if __name__ == '__main__': + main() diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..7f528ef --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- + +''' +Tests for feedgen main +''' + +import os +import sys +import tempfile +import unittest +from feedgen import __main__ + + +class TestSequenceFunctions(unittest.TestCase): + + def test_usage(self): + sys.argv = ['feedgen'] + try: + __main__.main() + except BaseException as e: + assert e.code is None + + def test_feed(self): + for ftype in 'rss', 'atom', 'podcast', 'torrent', 'dc.rss', 'dc.atom',\ + 'syndication.rss', 'syndication.atom': + sys.argv = ['feedgen', ftype] + try: + __main__.main() + except: + assert False + + def test_file(self): + for extemsion in '.atom', '.rss': + _, filename = tempfile.mkstemp(extemsion) + sys.argv = ['feedgen', filename] + try: + __main__.main() + except: + assert False + os.remove(filename)