-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
166 lines (126 loc) · 4.71 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#
# Source: https://github.com/aimacode/aima-python/blob/master/utils.py
#
# Functions on Sequences and Iterables
import collections
import functools
import heapq
import random
def remove_all(item, seq):
    """Return a copy of seq (or string) with all occurrences of item removed.

    A str input yields a str and a set input yields a set; any other
    iterable yields a list. If item does not occur in seq, an unchanged
    copy is returned. The input itself is never modified.
    """
    if isinstance(seq, str):
        return seq.replace(item, '')
    if isinstance(seq, set):
        rest = seq.copy()
        # discard (not remove): a missing item must be a no-op, like in the
        # str and list branches, rather than raising KeyError.
        rest.discard(item)
        return rest
    return [x for x in seq if x != item]
def unique(seq):
    """Remove duplicate elements from seq. Assumes hashable elements.

    Returns a new list in which each distinct element appears once, in the
    order of its first occurrence in seq.
    """
    # dict keys are de-duplicated like set members but keep insertion order,
    # so the result is deterministic instead of depending on hash order.
    return list(dict.fromkeys(seq))
def count(seq):
    """Return how many elements of seq are truthy."""
    return sum(1 for elt in seq if elt)
def multimap(items):
    """Group (key, val) pairs into a plain dict {key: [val, ...], ...}.

    Values keep the order in which they appear in items, and keys appear in
    order of first occurrence.
    """
    grouped = {}
    for key, val in items:
        # Create the bucket on first sight of a key, then append to it.
        grouped.setdefault(key, []).append(val)
    return grouped
def multimap_items(mmap):
    """Lazily yield every (key, val) pair held in the multimap.

    mmap maps each key to an iterable of values; one pair is produced per
    value, keys in the mapping's iteration order.
    """
    for key, vals in mmap.items():
        yield from ((key, val) for val in vals)
def product(numbers):
    """Multiply all the numbers together, e.g. product([2, 3, 10]) == 60.

    An empty iterable yields 1, the multiplicative identity.
    """
    acc = 1
    for factor in numbers:
        acc *= factor
    return acc
def first(iterable, default=None):
    """Return the first element of iterable, or default if it is empty."""
    for elt in iterable:
        # Leave as soon as one element has been produced.
        return elt
    return default
def is_in(elt, seq):
    """Like (elt in seq), but matches by identity ('is') rather than '=='."""
    for candidate in seq:
        if candidate is elt:
            return True
    return False
def extend(s, var, val):
    """Return a copy of dict s in which var is set to val; s is untouched."""
    updated = {**s}
    updated[var] = val
    return updated
def flatten(seqs):
    """Return a single list with the elements of every sequence in seqs.

    Only one level of nesting is removed. The sub-sequences may be any
    iterables (lists, tuples, generators, ...); the result is always a new
    list, and an empty seqs gives [].
    """
    # One pass over all elements; repeated list concatenation via
    # sum(seqs, []) would copy the growing result for every sub-sequence.
    return [elt for seq in seqs for elt in seq]
def shuffled(iterable):
    """Return a new list with the elements of iterable in random order.

    The input is copied first, so it is never modified.
    """
    result = [*iterable]
    random.shuffle(result)
    return result
def identity(x):
    """Return x unchanged (default key function for argmax_random_tie)."""
    return x
def argmax_random_tie(seq, key=identity):
    """Return an element of seq with the highest key(element) score.

    Ties between equally scored elements are broken at random.
    """
    # max() keeps the first maximal element it meets, so shuffling the
    # candidates beforehand randomizes which of the tied elements wins.
    candidates = shuffled(seq)
    return max(candidates, key=key)
class PriorityQueue:
    """A queue in which the minimum (or maximum) element, as determined by f
    and order, is returned first.

    If order is 'min', the item with minimum f(x) is returned first; if order
    is 'max', then it is the item with maximum f(x). Also supports dict-like
    lookup (`in`, `[]`, `del`); each of those scans the heap linearly.

    Entries are stored in `self.heap` as (score, item) tuples managed by
    heapq. When two entries have equal scores the tuples are ordered by the
    items themselves, so such items must be comparable with each other.
    """

    def __init__(self, order='min', f=lambda x: x):
        """Create an empty queue.

        order -- 'min' or 'max'; anything else raises ValueError.
        f     -- scoring function applied to every inserted item.
        """
        self.heap = []
        if order == 'min':
            self.f = f
        elif order == 'max':
            # heapq is a min-heap: negating the score makes the item with
            # the largest f(x) come out first.
            self.f = lambda x: -f(x)
        else:
            raise ValueError("Order must be either 'min' or 'max'.")

    def append(self, item):
        """Insert item at its correct position."""
        heapq.heappush(self.heap, (self.f(item), item))

    def extend(self, items):
        """Insert each item in items at its correct position."""
        for item in items:
            self.append(item)

    def pop(self):
        """Pop and return the item with min or max f(x) value, depending on
        the order. Raises IndexError if the queue is empty."""
        if not self.heap:
            # IndexError is what list.pop/heapq.heappop raise when empty; it
            # is an Exception subclass, so `except Exception` still matches.
            raise IndexError('Trying to pop from empty PriorityQueue.')
        return heapq.heappop(self.heap)[1]

    def __len__(self):
        """Return the number of items currently in the PriorityQueue."""
        return len(self.heap)

    def __contains__(self, key):
        """Return True if an item equal to key is in the PriorityQueue."""
        # Generator (not a list) so any() can stop at the first match.
        return any(item == key for _, item in self.heap)

    def __getitem__(self, key):
        """Return the stored score of the first heap entry whose item equals
        key (for order 'max' this is the negated f value).
        Raises KeyError if key is not present."""
        for value, item in self.heap:
            if item == key:
                return value
        raise KeyError(str(key) + " is not in the priority queue")

    def __delitem__(self, key):
        """Delete the first heap entry whose item equals key.
        Raises KeyError if key is not present."""
        for index, (_, item) in enumerate(self.heap):
            if item == key:
                del self.heap[index]
                # Removing an arbitrary slot can break the heap invariant.
                heapq.heapify(self.heap)
                return
        raise KeyError(str(key) + " is not in the priority queue")
def memoize(fn, slot=None, maxsize=32):
    """Memoize fn: make it remember the computed value for any argument list.

    If slot is specified, store the result in that attribute of the first
    argument; once the attribute exists it is returned as-is, whatever the
    remaining arguments are.
    If slot is false, use functools.lru_cache (holding at most maxsize
    entries) keyed on the positional arguments, which must be hashable.

    The returned wrapper carries fn's name and docstring.
    """
    if slot:
        @functools.wraps(fn)
        def memoized_fn(obj, *args):
            if hasattr(obj, slot):
                return getattr(obj, slot)
            val = fn(obj, *args)
            setattr(obj, slot, val)
            return val
    else:
        # wraps goes innermost so lru_cache copies fn's metadata onto the
        # cached wrapper while still exposing cache_info()/cache_clear().
        @functools.lru_cache(maxsize=maxsize)
        @functools.wraps(fn)
        def memoized_fn(*args):
            return fn(*args)
    return memoized_fn