aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim Harder <radhermit@gmail.com>2017-12-21 05:14:09 -0500
committerTim Harder <radhermit@gmail.com>2018-04-11 14:39:12 -0400
commit3db0d6447d8bd7e25bfdd97a31414ea44ef610bf (patch)
treeef7d140ad188d4a5307502dcbb4d50f077805ede /tests/test_klass.py
parentremove unsupported py2 and deprecated py3 code and fallbacks (diff)
downloadsnakeoil-3db0d6447d8bd7e25bfdd97a31414ea44ef610bf.tar.gz
snakeoil-3db0d6447d8bd7e25bfdd97a31414ea44ef610bf.tar.bz2
snakeoil-3db0d6447d8bd7e25bfdd97a31414ea44ef610bf.zip
move tests to root level dir
Diffstat (limited to 'tests/test_klass.py')
-rw-r--r--tests/test_klass.py608
1 files changed, 608 insertions, 0 deletions
diff --git a/tests/test_klass.py b/tests/test_klass.py
new file mode 100644
index 00000000..56e0d49e
--- /dev/null
+++ b/tests/test_klass.py
@@ -0,0 +1,608 @@
+# Copyright: 2006-2007 Brian Harring <ferringb@gmail.com>
+# License: BSD/GPL2
+
+from functools import partial
+import math
+import re
+from time import time
+
+from snakeoil import klass
+from snakeoil.test import TestCase, mk_cpy_loadable_testcase, not_a_test
+
+
+class Test_native_GetAttrProxy(TestCase):
+ kls = staticmethod(klass.native_GetAttrProxy)
+
+ def test_it(self):
+ class foo1(object):
+ def __init__(self, obj):
+ self.obj = obj
+ __getattr__ = self.kls('obj')
+
+ class foo2(object):
+ pass
+
+ o2 = foo2()
+ o = foo1(o2)
+ self.assertRaises(AttributeError, getattr, o, "blah")
+ self.assertEqual(o.obj, o2)
+ o2.foon = "dar"
+ self.assertEqual(o.foon, "dar")
+ o.foon = "foo"
+ self.assertEqual(o.foon, 'foo')
+
+ def test_attrlist(self):
+ def make_class(attr_list=None):
+ class foo(object, metaclass=self.kls):
+ if attr_list is not None:
+ locals()['__attr_comparison__'] = attr_list
+
+ self.assertRaises(TypeError, make_class)
+ self.assertRaises(TypeError, make_class, ['foon'])
+ self.assertRaises(TypeError, make_class, [None])
+
+ def test_instancemethod(self):
+ class foo(object):
+ bar = "baz"
+
+ class Test(object):
+ method = self.kls('test')
+ test = foo()
+
+ test = Test()
+ self.assertEqual(test.method('bar'), foo.bar)
+
+
+class Test_CPY_GetAttrProxy(Test_native_GetAttrProxy):
+
+ kls = staticmethod(klass.GetAttrProxy)
+ if klass.GetAttrProxy is klass.native_GetAttrProxy:
+ skip = "cpython extension isn't available"
+
+ def test_sane_recursion_bail(self):
+ # people are stupid; if protection isn't in place, we wind up blowing
+ # the c stack, which doesn't result in a friendly Exception being
+ # thrown.
+ # results in a segfault.. so if it's horked, this will bail the test
+ # runner.
+
+ class c(object):
+ __getattr__ = self.kls("obj")
+
+ o = c()
+ o.obj = o
+ # now it's cyclical.
+ self.assertRaises((AttributeError, RuntimeError), getattr, o, "hooey")
+
+
+class TestDirProxy(TestCase):
+
+ @staticmethod
+ def noninternal_attrs(obj):
+ return sorted(x for x in dir(obj) if not re.match(r'__\w+__', x))
+
+ def test_combined(self):
+ class foo1(object):
+ def __init__(self, obj):
+ self.obj = obj
+ __dir__ = klass.DirProxy('obj')
+
+ class foo2(object):
+ def __init__(self):
+ self.attr = 'foo'
+
+ o2 = foo2()
+ o = foo1(o2)
+ self.assertEqual(self.noninternal_attrs(o), ['attr', 'obj'])
+
+ def test_empty(self):
+ class foo1(object):
+ def __init__(self, obj):
+ self.obj = obj
+ __dir__ = klass.DirProxy('obj')
+
+ class foo2(object):
+ pass
+
+ o2 = foo2()
+ o = foo1(o2)
+ self.assertEqual(self.noninternal_attrs(o2), [])
+ self.assertEqual(self.noninternal_attrs(o), ['obj'])
+
+
+class Test_native_contains(TestCase):
+ func = staticmethod(klass.native_contains)
+
+ def test_it(self):
+ class c(dict):
+ __contains__ = self.func
+ d = c({"1": 2})
+ self.assertIn("1", d)
+ self.assertNotIn(1, d)
+
+
+class Test_CPY_contains(Test_native_contains):
+ func = staticmethod(klass.contains)
+
+ if klass.contains is klass.native_contains:
+ skip = "cpython extension isn't available"
+
+
+class Test_native_get(TestCase):
+ func = staticmethod(klass.native_get)
+
+ def test_it(self):
+ class c(dict):
+ get = self.func
+ d = c({"1": 2})
+ self.assertEqual(d.get("1"), 2)
+ self.assertEqual(d.get("1", 3), 2)
+ self.assertEqual(d.get(1), None)
+ self.assertEqual(d.get(1, 3), 3)
+
+class Test_CPY_get(Test_native_get):
+ func = staticmethod(klass.get)
+
+ if klass.get is klass.native_get:
+ skip = "cpython extension isn't available"
+
+
+class Test_chained_getter(TestCase):
+
+ kls = klass.chained_getter
+
+ def test_hash(self):
+ self.assertEqual(hash(self.kls("foon")), hash("foon"))
+ self.assertEqual(hash(self.kls("foon.dar")), hash("foon.dar"))
+
+ def test_caching(self):
+ l = [id(self.kls("fa2341f%s" % x)) for x in "abcdefghij"]
+ self.assertEqual(id(self.kls("fa2341fa")), l[0])
+
+ def test_eq(self):
+ self.assertEqual(self.kls("asdf", disable_inst_caching=True),
+ self.kls("asdf", disable_inst_caching=True))
+
+ self.assertNotEqual(self.kls("asdf2", disable_inst_caching=True),
+ self.kls("asdf", disable_inst_caching=True))
+
+ def test_it(self):
+ class maze(object):
+ def __init__(self, kwargs):
+ self.__data__ = kwargs
+
+ def __getattr__(self, attr):
+ return self.__data__.get(attr, self)
+
+ d = {}
+ m = maze(d)
+ f = self.kls
+ self.assertEqual(f('foon')(m), m)
+ d["foon"] = 1
+ self.assertEqual(f('foon')(m), 1)
+ self.assertEqual(f('dar.foon')(m), 1)
+ self.assertEqual(f('.'.join(['blah']*10))(m), m)
+ self.assertRaises(AttributeError, f('foon.dar'), m)
+
+
+class Test_native_jit_attr(TestCase):
+
+ kls = staticmethod(klass._native_internal_jit_attr)
+
+ @property
+ def jit_attr(self):
+ return partial(klass.jit_attr, kls=self.kls)
+
+ @property
+ def jit_attr_named(self):
+ return partial(klass.jit_attr_named, kls=self.kls)
+
+ @property
+ def jit_attr_ext_method(self):
+ return partial(klass.jit_attr_ext_method, kls=self.kls)
+
+ def mk_inst(self, attrname='_attr', method_lookup=False,
+ use_cls_setattr=False, func=None,
+ singleton=klass._uncached_singleton):
+
+ f = func
+ if not func:
+ def f(self):
+ self._invokes.append(self)
+ return 54321
+
+ class cls(object):
+
+ def __init__(self):
+ sf = partial(object.__setattr__, self)
+ sf('_sets', [])
+ sf('_reflects', [])
+ sf('_invokes', [])
+
+ attr = self.kls(f, attrname, singleton, use_cls_setattr)
+
+ def __setattr__(self, attr, value):
+ self._sets.append(self)
+ object.__setattr__(self, attr, value)
+
+ def reflect(self):
+ self._reflects.append(self)
+ return 12345
+
+ return cls()
+
+ def assertState(self, instance, sets=0, reflects=0, invokes=0, value=54321):
+ self.assertEqual(instance.attr, value)
+ sets = [instance] * sets
+ reflects = [instance] * reflects
+ invokes = [instance] * invokes
+ msg = ("checking %s: got(%r), expected(%r); state was sets=%r, "
+ "reflects=%r, invokes=%r" % (
+ "%s", "%s", "%s", instance._sets, instance._reflects,
+ instance._invokes))
+ self.assertEqual(instance._sets, sets,
+ msg=(msg % ("sets", instance._sets, sets,)))
+ self.assertEqual(instance._reflects, reflects,
+ msg=(msg % ("reflects", instance._reflects,
+ reflects,)))
+ self.assertEqual(instance._invokes, invokes,
+ msg=(msg % ("invokes", instance._invokes, invokes,)))
+
+ def test_implementation(self):
+ obj = self.mk_inst()
+
+ # default state is use_cls_setattr = False
+ self.assertState(obj, invokes=1)
+ self.assertState(obj, invokes=1)
+ del obj._attr
+ self.assertState(obj, invokes=2)
+ self.assertState(obj, invokes=2)
+
+ # basic caching is now verified.
+ obj = self.mk_inst(use_cls_setattr=True)
+ self.assertState(obj, sets=1, invokes=1)
+ self.assertState(obj, sets=1, invokes=1)
+ del obj._attr
+ self.assertState(obj, sets=2, invokes=2)
+ self.assertState(obj, sets=2, invokes=2)
+
+ def test_jit_attr(self):
+ now = time()
+
+ class cls(object):
+ @self.jit_attr
+ def my_attr(self):
+ return now
+
+ o = cls()
+ self.assertEqual(o.my_attr, now)
+ self.assertEqual(o._my_attr, now)
+
+ class cls(object):
+ @self.jit_attr
+ def attr2(self):
+ return now
+
+ def __setattr__(self, attr, value):
+ raise AssertionError("setattr was invoked")
+
+ o = cls()
+ self.assertEqual(o.attr2, now)
+ self.assertEqual(o._attr2, now)
+ del o._attr2
+ self.assertEqual(o.attr2, now)
+ self.assertEqual(o._attr2, now)
+
+ def test_jit_attr_named(self):
+ now = time()
+
+ # check attrname control and default object.__setattr__ avoidance
+ class cls(object):
+ @self.jit_attr_named("_blah")
+ def my_attr(self):
+ return now
+
+ def __setattr__(self, attr, value):
+ raise AssertionError("setattr was invoked")
+
+ o = cls()
+ self.assertEqual(o.my_attr, now)
+ self.assertEqual(o._blah, now)
+
+ class cls(object):
+ @self.jit_attr_named("_blah2", use_cls_setattr=True)
+ def my_attr(self):
+ return now
+
+ def __setattr__(self, attr, value):
+ object.__setattr__(self, "invoked", True)
+ object.__setattr__(self, attr, value)
+
+ o = cls()
+ self.assertFalse(hasattr(o, 'invoked'))
+ self.assertEqual(o.my_attr, now)
+ self.assertEqual(o._blah2, now)
+ self.assertTrue(o.invoked)
+
+ def test_jit_attr_ext_method(self):
+ now = time()
+ now2 = now + 100
+
+ class base(object):
+ def f1(self):
+ return now
+
+ def f2(self):
+ return now2
+
+ def __setattr__(self, attr, value):
+ if not getattr(self, '_setattr_allowed', False):
+ raise TypeError("setattr isn't allowed for %s" % attr)
+ object.__setattr__(self, attr, value)
+
+ base.attr = self.jit_attr_ext_method('f1', '_attr')
+ o = base()
+ self.assertEqual(o.attr, now)
+ self.assertEqual(o._attr, now)
+ self.assertEqual(o.attr, now)
+
+ base.attr = self.jit_attr_ext_method('f1', '_attr', use_cls_setattr=True)
+ o = base()
+ self.assertRaises(TypeError, getattr, o, 'attr')
+ base._setattr_allowed = True
+ self.assertEqual(o.attr, now)
+
+ base.attr = self.jit_attr_ext_method('f2', '_attr2')
+ o = base()
+ self.assertEqual(o.attr, now2)
+ self.assertEqual(o._attr2, now2)
+
+ # finally, check that it's doing lookups rather then storing the func.
+ base.attr = self.jit_attr_ext_method('func', '_attr2')
+ o = base()
+ # no func...
+ self.assertRaises(AttributeError, getattr, o, 'attr')
+ base.func = base.f1
+ self.assertEqual(o.attr, now)
+ self.assertEqual(o._attr2, now)
+ # check caching...
+ base.func = base.f2
+ self.assertEqual(o.attr, now)
+ del o._attr2
+ self.assertEqual(o.attr, now2)
+
+ def test_check_singleton_is_compare(self):
+ def throw_assert(*args, **kwds):
+ raise AssertionError("I shouldn't be invoked: %s, %s" % (args, kwds,))
+
+ class puker(object):
+ __eq__ = throw_assert
+
+ puker_singleton = puker()
+
+ obj = self.mk_inst(singleton=puker_singleton)
+ obj._attr = puker_singleton
+ # force attr access. if it's done wrong, it'll puke.
+ # pylint: disable=pointless-statement
+ obj.attr
+
+ def test_cached_property(self):
+ l = []
+ class foo(object):
+ @klass.cached_property
+ def blah(self, l=l, i=iter(range(5))):
+ l.append(None)
+ return next(i)
+ f = foo()
+ self.assertEqual(f.blah, 0)
+ self.assertEqual(len(l), 1)
+ self.assertEqual(f.blah, 0)
+ self.assertEqual(len(l), 1)
+ del f.blah
+ self.assertEqual(f.blah, 1)
+ self.assertEqual(len(l), 2)
+
+ def test_cached_property(self):
+ l = []
+
+ def named(self, l=l, i=iter(range(5))):
+ l.append(None)
+ return next(i)
+
+ class foo(object):
+ blah = klass.cached_property_named("blah")(named)
+
+ f = foo()
+ self.assertEqual(f.blah, 0)
+ self.assertEqual(len(l), 1)
+ self.assertEqual(f.blah, 0)
+ self.assertEqual(len(l), 1)
+ del f.blah
+ self.assertEqual(f.blah, 1)
+ self.assertEqual(len(l), 2)
+
+
+class Test_cpy_jit_attr(Test_native_jit_attr):
+
+ kls = staticmethod(klass._internal_jit_attr)
+ if klass._internal_jit_attr is klass._native_internal_jit_attr:
+ skip = "extension is missing"
+
+
+class test_aliased_attr(TestCase):
+
+ func = staticmethod(klass.alias_attr)
+
+ def test_it(self):
+ class cls(object):
+ attr = self.func("dar.blah")
+
+ o = cls()
+ self.assertRaises(AttributeError, getattr, o, 'attr')
+ o.dar = "foon"
+
+ self.assertRaises(AttributeError, getattr, o, 'attr')
+ o.dar = o
+ o.blah = "monkey"
+
+ self.assertEqual(o.attr, 'monkey')
+
+ # verify it'll cross properties...
+ class blah(object):
+ target = object()
+
+ class cls(object):
+ @property
+ def foon(self):
+ return blah()
+
+ alias = self.func("foon.target")
+ o = cls()
+ self.assertIdentical(o.alias, blah.target)
+
+
+class test_cached_hash(TestCase):
+ func = staticmethod(klass.cached_hash)
+
+ def test_it(self):
+ now = int(time())
+ class cls(object):
+ invoked = []
+ @self.func
+ def __hash__(self):
+ self.invoked.append(self)
+ return now
+ o = cls()
+ self.assertEqual(hash(o), now)
+ self.assertEqual(o.invoked, [o])
+ # ensure it cached...
+ self.assertEqual(hash(o), now)
+ self.assertEqual(o.invoked, [o])
+ self.assertEqual(o._hash, now)
+
+
+class test_native_reflective_hash(TestCase):
+ func = staticmethod(klass.native_reflective_hash)
+
+ def test_it(self):
+ class cls(object):
+ __hash__ = self.func('_hash')
+
+ obj = cls()
+ self.assertRaises(AttributeError, hash, obj)
+ obj._hash = 1
+ self.assertEqual(hash(obj), 1)
+ obj._hash = 123123123
+ self.assertEqual(hash(obj), 123123123)
+ # verify it's not caching in any form
+ del obj._hash
+ self.assertRaises(AttributeError, hash, obj)
+
+ class cls2(object):
+ __hash__ = self.func('_dar')
+ obj = cls2()
+ self.assertRaises(AttributeError, hash, obj)
+ obj._dar = 4
+ self.assertEqual(hash(obj), 4)
+
+
+class test_cpy_reflective_hash(test_native_reflective_hash):
+
+ kls = staticmethod(klass.reflective_hash)
+ if klass.reflective_hash is klass.native_reflective_hash:
+ skip = "cpython extension isn't available"
+
+
+cpy_loaded_Test = mk_cpy_loadable_testcase(
+ "snakeoil._klass", "snakeoil.klass", "reflective_hash", "reflective_hash")
+
+
+class TestImmutableInstance(TestCase):
+
+ def test_metaclass(self):
+ def f(scope):
+ scope["__metaclass__"] = klass.immutable_instance
+
+ self.common_test(f)
+
+ def test_injection(self):
+
+ def f(scope):
+ klass.inject_immutable_instance(scope)
+
+ self.common_test(f)
+
+ @not_a_test
+ def common_test(self, modify_kls):
+ class kls(object):
+ modify_kls(locals())
+
+ o = kls()
+ self.assertRaises(AttributeError, setattr, o, "dar", "foon")
+ self.assertRaises(AttributeError, delattr, o, "dar")
+
+ object.__setattr__(o, 'dar', 'foon')
+ self.assertRaises(AttributeError, delattr, o, "dar")
+
+ # ensure it only sets it if nothing is in place already.
+
+ class kls(object):
+ def __setattr__(self, attr, value):
+ raise TypeError(self)
+
+ modify_kls(locals())
+
+ o = kls()
+ self.assertRaises(TypeError, setattr, o, "dar", "foon")
+ self.assertRaises(AttributeError, delattr, o, "dar")
+
+
+class TestAliasMethod(TestCase):
+
+ func = staticmethod(klass.alias_method)
+
+ def test_alias_method(self):
+ class kls(object):
+ __len__ = lambda s: 3
+ lfunc = self.func("__len__")
+
+ c = kls()
+ self.assertEqual(c.__len__(), c.lfunc())
+ c.__len__ = lambda: 4
+ self.assertEqual(c.__len__(), c.lfunc())
+
+
+class TestPatch(TestCase):
+
+ def setUp(self):
+ # cache original methods
+ self._math_ceil = math.ceil
+ self._math_floor = math.floor
+
+ def tearDown(self):
+ # restore original methods
+ math.ceil = self._math_ceil
+ math.floor = self._math_floor
+
+ def test_patch(self):
+ n = 0.1
+ self.assertEqual(math.ceil(n), 1)
+
+ @klass.patch('math.ceil')
+ def ceil(orig_ceil, n):
+ return math.floor(n)
+
+ self.assertEqual(math.ceil(n), 0)
+
+ def test_multiple_patches(self):
+ n = 1.1
+ self.assertEqual(math.ceil(n), 2)
+ self.assertEqual(math.floor(n), 1)
+
+ @klass.patch('math.ceil')
+ @klass.patch('math.floor')
+ def zero(orig_func, n):
+ return 0
+
+ self.assertEqual(math.ceil(n), 0)
+ self.assertEqual(math.floor(n), 0)