from __future__ import with_statement

import mdp
import sys
import py.test

def teardown_function(function):
    """Deactivate all extensions and remove testing extensions."""
    mdp.deactivate_extensions(mdp.get_active_extensions())
    for key in mdp.get_extensions().copy():
        if key.startswith("__test"):
            del mdp.get_extensions()[key]

def testSimpleExtension():
    """Test for a single new extension."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _testtest(self):
            pass
        _testtest_attr = 1337
    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42
        _testtest_attr = 1338
    sfa_node = mdp.nodes.SFANode()
    mdp.activate_extension("__test")
    assert sfa_node._testtest() == 42
    assert sfa_node._testtest_attr == 1338
    mdp.deactivate_extension("__test")
    assert not hasattr(mdp.nodes.SFANode, "_testtest")

def testContextDecorator():
    """Test the with_extension function decorator."""

    class Test1ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test1"
        def _testtest(self):
            pass

    @mdp.with_extension("__test1")
    def test():
        return mdp.get_active_extensions()

    # check that the extension is activated
    assert mdp.get_active_extensions() == []
    active = test()
    assert active == ["__test1"]
    assert mdp.get_active_extensions() == []

    # check that it is only deactiveted if it was activated there
    mdp.activate_extension("__test1")
    active = test()
    assert active == ["__test1"]
    assert mdp.get_active_extensions() == ["__test1"]

def testContextManager1():
    """Test that the context manager activates extensions."""

    class Test1ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test1"
        def _testtest(self):
            pass
    class Test2ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test2"
        def _testtest(self):
            pass

    assert mdp.get_active_extensions() == []
    with mdp.extension('__test1'):
        assert mdp.get_active_extensions() == ['__test1']
    assert mdp.get_active_extensions() == []
    # with multiple extensions
    with mdp.extension(['__test1', '__test2']):
        active = mdp.get_active_extensions()
        assert '__test1' in active
        assert '__test2' in active
    assert mdp.get_active_extensions() == []
    mdp.activate_extension("__test1")
    # Test that only activated extensions are deactiveted.
    with mdp.extension(['__test1', '__test2']):
        active = mdp.get_active_extensions()
        assert '__test1' in active
        assert '__test2' in active
    assert mdp.get_active_extensions() == ["__test1"]

def testDecoratorExtension():
    """Test extension decorator with a single new extension."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _testtest(self):
            pass
    @mdp.extension_method("__test", mdp.nodes.SFANode, "_testtest")
    def _sfa_testtest(self):
        return 42
    @mdp.extension_method("__test", mdp.nodes.SFA2Node)
    def _testtest(self):
        return 42 + _sfa_testtest(self)
    sfa_node = mdp.nodes.SFANode()
    sfa2_node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert sfa_node._testtest() == 42
    assert sfa2_node._testtest() == 84
    mdp.deactivate_extension("__test")
    assert not hasattr(mdp.nodes.SFANode, "_testtest")
    assert not hasattr(mdp.nodes.SFA2Node, "_testtest")

def testDecoratorInheritance():
    """Test inhertiance with decorators for a single new extension."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _testtest(self):
            pass
    @mdp.extension_method("__test", mdp.nodes.SFANode, "_testtest")
    def _sfa_testtest(self):
        return 42
    @mdp.extension_method("__test", mdp.nodes.SFA2Node)
    def _testtest(self):
        return 42 + super(mdp.nodes.SFA2Node, self)._testtest()
    sfa_node = mdp.nodes.SFANode()
    sfa2_node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert sfa_node._testtest() == 42
    assert sfa2_node._testtest() == 84

def testExtensionInheritance():
    """Test inheritance of extension nodes."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _testtest(self):
            pass
    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42
        _testtest_attr = 1337
    class TestSFA2Node(TestSFANode, mdp.nodes.SFA2Node):
        def _testtest(self):
            if sys.version_info[0] < 3:
                return TestSFANode._testtest.im_func(self)
            else:
                return TestSFANode._testtest(self)
    sfa2_node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert sfa2_node._testtest() == 42
    assert sfa2_node._testtest_attr == 1337

def testExtensionInheritance2():
    """Test inheritance of extension nodes, using super."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _testtest(self):
            pass
    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42
    class TestSFA2Node(mdp.nodes.SFA2Node, TestSFANode):
        def _testtest(self):
            return super(mdp.nodes.SFA2Node, self)._testtest()
    sfa2_node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert sfa2_node._testtest() == 42

def testExtensionInheritance3():
    """Test explicit use of extension nodes and inheritance."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _testtest(self):
            pass
    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42
    # Note the inheritance order, otherwise this would not work.
    class TestSFA2Node(mdp.nodes.SFA2Node, TestSFANode):
        def _testtest(self):
            return super(mdp.nodes.SFA2Node, self)._testtest()
    sfa2_node = TestSFA2Node()
    assert sfa2_node._testtest() == 42

def testMultipleExtensions():
    """Test behavior of multiple extensions."""
    class Test1ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test1"
        def _testtest1(self):
            pass
    class Test2ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test2"
        def _testtest2(self):
            pass
    mdp.activate_extension("__test1")
    node = mdp.Node()
    node._testtest1()
    mdp.activate_extension("__test2")
    node._testtest2()
    mdp.deactivate_extension("__test1")
    assert not hasattr(mdp.nodes.SFANode, "_testtest1")
    mdp.activate_extension("__test1")
    node._testtest1()
    mdp.deactivate_extensions(["__test1", "__test2"])
    assert not hasattr(mdp.nodes.SFANode, "_testtest1")
    assert not hasattr(mdp.nodes.SFANode, "_testtest2")

def testExtCollision():
    """Test the check for method name collision."""
    class Test1ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test1"
        def _testtest(self):
            pass
    class Test2ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test2"
        def _testtest(self):
            pass
    py.test.raises(mdp.ExtensionException,
                   mdp.activate_extensions, ["__test1", "__test2"])
    # none of the extension should be active after the exception
    assert not hasattr(mdp.Node, "_testtest")

def testExtensionInheritanceInjection():
    """Test the injection of inherited methods"""
    class TestNode(object):
        def _test1(self):
            return 0
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _test1(self):
            return 1
        def _test2(self):
            return 2
        def _test3(self):
            return 3
    class TestNodeExt(TestExtensionNode, TestNode):
        def _test2(self):
            return "2b"
    @mdp.extension_method("__test", TestNode)
    def _test4(self):
        return 4
    test_node = TestNode()
    mdp.activate_extension("__test")
    assert test_node._test1() == 1
    assert test_node._test2() == "2b"
    assert test_node._test3() == 3
    assert test_node._test4() == 4
    mdp.deactivate_extension("__test")
    assert test_node._test1() == 0
    assert not hasattr(test_node, "_test2")
    assert not hasattr(test_node, "_test3")
    assert not hasattr(test_node, "_test4")

def testExtensionInheritanceInjectionNonExtension():
    """Test non_extension method injection."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _execute(self):
            return 0
    class TestNode(mdp.Node):
        # no _execute method
        pass
    class ExtendedTestNode(TestExtensionNode, TestNode):
        pass
    test_node = TestNode()
    mdp.activate_extension('__test')
    assert hasattr(test_node, "_non_extension__execute")
    mdp.deactivate_extension('__test')
    assert not hasattr(test_node, "_non_extension__execute")
    assert not hasattr(test_node, "_extension_for__execute")
    # test that the non-native _execute has been completely removed
    assert "_execute" not in test_node.__class__.__dict__

def testExtensionInheritanceInjectionNonExtension2():
    """Test non_extension method injection."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _execute(self):
            return 0
    class TestNode(mdp.Node):
        def _execute(self):
            return 1
    class ExtendedTestNode(TestExtensionNode, TestNode):
        pass
    test_node = TestNode()
    mdp.activate_extension('__test')
    # test that non-extended attribute has been added as well
    assert hasattr(test_node, "_non_extension__execute")
    mdp.deactivate_extension('__test')
    assert not hasattr(test_node, "_non_extension__execute")
    assert not hasattr(test_node, "_extension_for__execute")
    # test that the native _execute has been preserved
    assert "_execute" in test_node.__class__.__dict__

def testExtensionInheritanceTwoExtensions():
    """Test non_extension injection for multiple extensions."""
    class Test1ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test1"
        def _execute(self):
            return 1
    class Test2ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test2"
    class Test3ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test3"
        def _execute(self):
            return "3a"
    class TestNode1(mdp.Node):
        pass
    class TestNode2(TestNode1):
        pass
    class ExtendedTest1Node2(Test1ExtensionNode, TestNode2):
        pass
    class ExtendedTest2Node1(Test2ExtensionNode, TestNode1):
        def _execute(self):
            return 2
    class ExtendedTest3Node1(Test3ExtensionNode, TestNode1):
        def _execute(self):
            return "3b"
    test_node = TestNode2()
    mdp.activate_extension('__test2')
    assert test_node._execute() == 2
    mdp.deactivate_extension('__test2')
    # in this order TestNode2 should get execute from __test1,
    # the later addition by __test1 to TestNode1 doesn't matter
    mdp.activate_extensions(['__test1', '__test2'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test2', '__test1'])
    # now activate in inverse order
    # TestNode2 already gets _execute from __test2, but that is still
    # overriden by __test1, thats how its registered in _extensions
    mdp.activate_extensions(['__test2', '__test1'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test2', '__test1'])
    ## now the same with extension 3
    mdp.activate_extension('__test3')
    assert test_node._execute() == "3b"
    mdp.deactivate_extension('__test3')
    # __test3 does not override, since the _execute slot for Node2
    # was first filled by __test1
    mdp.activate_extensions(['__test3', '__test1'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test3', '__test1'])
    # inverse order
    mdp.activate_extensions(['__test1', '__test3'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test2', '__test1'])