generic_utils: add OrderedDefaultDict
diff --git a/generic_utils.py b/generic_utils.py
index eeb26d6..0bf289e 100644
--- a/generic_utils.py
+++ b/generic_utils.py
@@ -31,6 +31,7 @@
Intended to be imported into another namespace
"""
+import collections
import functools
import sys
import of_g
@@ -99,8 +100,6 @@
#
################################################################
-import collections
-
class OrderedSet(collections.MutableSet):
"""
A set implementations that retains insertion order. From the receipe
@@ -164,6 +163,59 @@
return len(self) == len(other) and list(self) == list(other)
return set(self) == set(other)
+################################################################
+#
+# OrderedDefaultDict
+#
+################################################################
+
+class OrderedDefaultDict(collections.OrderedDict):
+ """
+ A Dictionary that maintains insertion order where missing values
+ are provided by a factory function, i.e., a combination of
+ the semantics of collections.defaultdict and collections.OrderedDict.
+ """
+ def __init__(self, default_factory=None, *a, **kw):
+ if (default_factory is not None and
+ not callable(default_factory)):
+ raise TypeError('first argument must be callable')
+ collections.OrderedDict.__init__(self, *a, **kw)
+ self.default_factory = default_factory
+
+ def __getitem__(self, key):
+ try:
+ return collections.OrderedDict.__getitem__(self, key)
+ except KeyError:
+ return self.__missing__(key)
+
+ def __missing__(self, key):
+ if self.default_factory is None:
+ raise KeyError(key)
+ self[key] = value = self.default_factory()
+ return value
+
+ def __reduce__(self):
+ if self.default_factory is None:
+ args = tuple()
+ else:
+ args = self.default_factory,
+ return type(self), args, None, None, self.items()
+
+ def copy(self):
+ return self.__copy__()
+
+ def __copy__(self):
+ return type(self)(self.default_factory, self)
+
+ def __deepcopy__(self, memo):
+ import copy
+ return type(self)(self.default_factory,
+ copy.deepcopy(self.items()))
+ def __repr__(self):
+ return 'OrderedDefaultDict(%s, %s)' % (self.default_factory,
+ collections.OrderedDict.__repr__(self))
+
+
def find(iterable, func):
"""
find the first item in iterable for which func returns something true'ish.