diff --git a/common/lib/capa/capa/safe_exec/safe_exec.py b/common/lib/capa/capa/safe_exec/safe_exec.py
index 3ae9567fa9..da74f3aaf5 100644
--- a/common/lib/capa/capa/safe_exec/safe_exec.py
+++ b/common/lib/capa/capa/safe_exec/safe_exec.py
@@ -1,6 +1,9 @@
 """Capa's specialized use of codejail.safe_exec."""
 
-import codejail.safe_exec
+import hashlib
+
+from codejail.safe_exec import safe_exec as codejail_safe_exec
+from codejail.safe_exec import json_safe
 from . import lazymod
 
 # Establish the Python environment for Capa.
@@ -43,13 +46,40 @@ for name, modname in ASSUMED_IMPORTS:
 LAZY_IMPORTS = "".join(LAZY_IMPORTS)
 
 
-def safe_exec(code, globals_dict, random_seed=None, python_path=None):
+def safe_exec(code, globals_dict, random_seed=None, python_path=None, cache=None):
     """Exec python code safely.
 
+    `cache` is an object with .get(key) and .set(key, value) methods.  If
+    provided, results are keyed on the seed, the code, and the JSON-safe
+    globals, and a cache hit updates `globals_dict` without running the code.
+
     """
+    # Check the cache for a previous result.
+    if cache:
+        canonical_globals = sorted(json_safe(globals_dict).iteritems())
+        # Hash the inputs so the key stays short and free of whitespace and
+        # control characters, as memcached-style backends require.
+        md5er = hashlib.md5()
+        md5er.update(repr(random_seed))
+        md5er.update(repr(code))
+        md5er.update(repr(canonical_globals))
+        key = "safe_exec.%s" % md5er.hexdigest()
+        cached = cache.get(key)
+        if cached is not None:
+            globals_dict.update(cached)
+            return
+
+    # Create the complete code we'll run.
     code_prolog = CODE_PROLOG % random_seed
 
-    codejail.safe_exec.safe_exec(
+    # Run the code! Results are side effects in globals_dict.
+    codejail_safe_exec(
         code_prolog + LAZY_IMPORTS + code, globals_dict,
         python_path=python_path,
     )
+
+    # Put the result back in the cache. This is complicated by the fact that
+    # the globals dict might not be entirely serializable.
+    if cache:
+        cleaned_results = json_safe(globals_dict)
+        cache.set(key, cleaned_results)
diff --git a/common/lib/capa/capa/safe_exec/tests/test_safe_exec.py b/common/lib/capa/capa/safe_exec/tests/test_safe_exec.py
index 7ed44a69a1..37f86383c2 100644
--- a/common/lib/capa/capa/safe_exec/tests/test_safe_exec.py
+++ b/common/lib/capa/capa/safe_exec/tests/test_safe_exec.py
@@ -56,3 +56,37 @@ class TestSafeExec(unittest.TestCase):
             "import constant; a = constant.THE_CONST",
             g, python_path=[pylib]
         )
+
+
+class DictCache(object):
+    """A cache implementation over a simple dict, for testing."""
+
+    def __init__(self, d):
+        self.cache = d
+
+    def get(self, key):
+        return self.cache.get(key)
+
+    def set(self, key, value):
+        self.cache[key] = value
+
+
+class TestSafeExecCaching(unittest.TestCase):
+    """Test that caching works on safe_exec."""
+
+    def test_cache_miss_then_hit(self):
+        g = {}
+        cache = {}
+
+        # Cache miss
+        safe_exec("a = int(math.pi)", g, cache=DictCache(cache))
+        self.assertEqual(g['a'], 3)
+        # A result has been cached
+        self.assertEqual(cache.values(), [{'a': 3}])
+
+        # Fiddle with the cache, then try it again.
+        cache[cache.keys()[0]] = {'a': 17}
+
+        g = {}
+        safe_exec("a = int(math.pi)", g, cache=DictCache(cache))
+        self.assertEqual(g['a'], 17)