diff --git a/git_common.py b/git_common.py index 5663a1c17..3d90d9ffd 100644 --- a/git_common.py +++ b/git_common.py @@ -41,7 +41,10 @@ import tempfile import textwrap import time import typing +from typing import Any from typing import AnyStr +from typing import Callable +from typing import ContextManager from typing import Optional from typing import Tuple @@ -154,6 +157,40 @@ class BadCommitRefException(Exception): super(BadCommitRefException, self).__init__(msg) +class _MemoizeWrapper(object): + + def __init__(self, f: Callable[[Any], Any], *, threadsafe: bool): + self._f: Callable[[Any], Any] = f + self._cache: dict[Any, Any] = {} + self._lock: ContextManager = contextlib.nullcontext() + if threadsafe: + self._lock = threading.Lock() + + def __call__(self, arg: Any) -> Any: + ret = self.get(arg) + if ret is None: + ret = self._f(arg) + if ret is not None: + self.set(arg, ret) + return ret + + def get(self, key: Any, default: Any = None) -> Any: + with self._lock: + return self._cache.get(key, default) + + def set(self, key: Any, value: Any) -> None: + with self._lock: + self._cache[key] = value + + def clear(self) -> None: + with self._lock: + self._cache.clear() + + def update(self, other: dict[Any, Any]) -> None: + with self._lock: + self._cache.update(other) + + def memoize_one(*, threadsafe: bool): """Memoizes a single-argument pure function. @@ -172,19 +209,6 @@ def memoize_one(*, threadsafe: bool): unittests. * update(other) - Updates the contents of the cache from another dict. """ - if threadsafe: - - def withlock(lock, f): - def inner(*args, **kwargs): - with lock: - return f(*args, **kwargs) - - return inner - else: - - def withlock(_lock, f): - return f - def decorator(f): # Instantiate the lock in decorator, in case users of memoize_one do: # @@ -195,26 +219,8 @@ def memoize_one(*, threadsafe: bool): # # @memoizer # def fn2(val): ... - - lock = threading.Lock() if threadsafe else None - cache = {} - _get = withlock(lock, cache.get) - _set = withlock(lock, cache.__setitem__) - - @functools.wraps(f) - def inner(arg): - ret = _get(arg) - if ret is None: - ret = f(arg) - if ret is not None: - _set(arg, ret) - return ret - - inner.get = _get - inner.set = _set - inner.clear = withlock(lock, cache.clear) - inner.update = withlock(lock, cache.update) - return inner + wrapped = _MemoizeWrapper(f, threadsafe=threadsafe) + return functools.wraps(f)(wrapped) return decorator