# Source code for curver.kernel.structures


''' A module of data structures. '''

from collections import defaultdict, namedtuple
from itertools import chain, islice

import curver

class UnionFind(object):
    ''' A fast union--find data structure. Given items must be hashable. '''
    def __init__(self, items):
        # Build parent first and derive rank from its keys so that items is
        # iterated exactly once (a generator would be empty on a second pass).
        self.parent = {item: item for item in items}
        self.rank = {item: 0 for item in self.parent}

    def __iter__(self):
        ''' Iterate through the groups of self. '''
        groups = defaultdict(list)
        for item in self.parent:
            groups[self(item)].append(item)
        return iter(groups.values())

    def __len__(self):
        ''' Return the number of groups (the number of roots). '''
        return sum(1 for item, parent in self.parent.items() if parent == item)

    def __repr__(self):
        return str(self)

    def __str__(self):
        return ', '.join('{' + ', '.join(str(item) for item in g) + '}' for g in self)

    def __call__(self, x):
        ''' Find the root of x.

        Two items are in the same group iff they have the same root. Every item
        on the path from x to the root is repointed directly at the root (path
        compression). '''
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        while x != root:
            next_x = self.parent[x]
            self.parent[x] = root
            x = next_x
        return root

    def union2(self, x, y):
        ''' Combine the class containing x and the class containing y. '''
        rx, ry = self(x), self(y)
        if rx == ry:
            return
        # Union by rank: the ranks that matter are those of the roots, since
        # only roots have their rank updated.
        if self.rank[rx] > self.rank[ry]:
            self.parent[ry] = rx
        elif self.rank[rx] < self.rank[ry]:
            self.parent[rx] = ry
        else:
            self.parent[ry] = rx
            self.rank[rx] += 1

    def union(self, *args):
        ''' Combine all of the classes containing the given items.

        The items may be given as separate arguments, union(a, b, c), or as a
        single iterable, union([a, b, c]). Any iterable (set, generator, ...)
        is accepted for the single-argument form. '''
        items = list(args[0]) if len(args) == 1 else list(args)
        for item in items[1:]:
            self.union2(items[0], item)
# A leaf of a StraightLineProgram graph: it wraps one value of the represented
# sequence, which distinguishes it from a child that is the integer index of
# another node.
Terminal = namedtuple('Terminal', ('value',))
class StraightLineProgram(object):
    ''' This represents a straight line program.

    The program is stored as a list of nodes, self.graph, with node 0 as the
    root. Each node is a tuple of children; a child is either a Terminal,
    holding one value of the sequence, or the integer index of another node.
    A node expands to the concatenation of the expansions of its children, so
    node 0 expands to the whole sequence. '''
    def __init__(self, data):
        ''' Build a StraightLineProgram from data, which may be:

        - another StraightLineProgram, whose graph is copied;
        - a non-empty list / tuple of nodes, each a list / tuple whose entries
          are all Terminals or node indices;
        - any other list / tuple (including an empty one), whose items become
          the Terminals of a single root node.

        Raises ValueError if data is of any other type. '''
        if isinstance(data, StraightLineProgram):  # Copy.
            data = data.graph
        elif isinstance(data, (list, tuple)):  # Wrap.
            is_graph = (
                bool(data)
                and all(isinstance(children, (list, tuple)) for children in data)
                and all(isinstance(child, (Terminal, curver.IntegerType)) for children in data for child in children)
            )
            if not is_graph:
                data = [[Terminal(child) for child in data]]
        else:
            raise ValueError('Unknown data.')

        self.graph = [tuple(children) for children in data]
        # The values held by every Terminal in the graph, including those of
        # any nodes that are not reachable from the root.
        self.sinks = [item.value for lst in self.graph for item in lst if isinstance(item, Terminal)]

        # If w is a descendant of v then w appears before v in self.indices.
        self.indices = self._postorder()

        # num_children[v] is the length of the sequence that node v expands to
        # (None for nodes not reachable from the root). Processing in
        # self.indices order guarantees children are computed before parents.
        self.num_children = [None] * self.size()
        for index in self.indices:
            self.num_children[index] = sum(1 if isinstance(item, Terminal) else self.num_children[item] for item in self(index))

    def _postorder(self):
        ''' Return the indices of the nodes reachable from node 0, each after all of its descendants.

        This is an iterative depth-first traversal, so deep programs (for
        example long chains of additions) do not hit Python's recursion limit.
        Raises ValueError if a cycle is reachable from the root. '''
        order = []
        finished = set()
        active = {0}  # Nodes currently on the stack.
        stack = [(0, iter(self(0)))]
        while stack:
            node, children = stack[-1]
            for child in children:
                if isinstance(child, Terminal) or child in finished:
                    continue
                if child in active:
                    raise ValueError('Graph contains a cycle.')
                active.add(child)
                stack.append((child, iter(self(child))))
                break  # Descend; this node's iterator resumes afterwards.
            else:
                # All children are finished, so this node can be emitted.
                stack.pop()
                active.discard(node)
                finished.add(node)
                order.append(node)
        return order

    def __str__(self):
        if len(self) <= 7:
            return str(list(self))
        # Only show the first three and the last three values.
        return '[%s, %s, %s, ..., %s, %s, %s]' % tuple(chain(islice(self, 3), reversed(list(islice(reversed(self), 3)))))

    def __repr__(self):
        strn = []
        for index, item in enumerate(self.graph):
            strn.append('%d --> %s' % (index, item))
        return '\n'.join(strn)

    def __call__(self, item):
        ''' Return the children of node item. '''
        return self.graph[item]

    def __len__(self):
        ''' Return the length of the represented sequence. '''
        return self.num_children[0]

    def size(self):
        ''' Return the number of nodes in this graph. '''
        return len(self.graph)

    def __getitem__(self, value):
        ''' Return the value at the given position of the sequence without expanding the whole program.

        Negative positions count from the end. Raises IndexError if the
        position is out of range and NotImplementedError for slices. '''
        if isinstance(value, slice):
            # TODO: 2) Implement this.
            raise NotImplementedError('Slicing a StraightLineProgram is not supported.')

        # We are returning a single item.
        if value >= len(self) or value < -len(self):
            raise IndexError('index out of range')
        if value < 0:
            value = len(self) + value

        # Walk down from the root, using num_children to skip over whole
        # subtrees that end before the requested position.
        index = 0
        while True:
            for image in self(index):
                if isinstance(image, Terminal):
                    if value == 0:
                        return image.value
                    value -= 1
                elif self.num_children[image] > value:
                    index = image
                    break
                else:
                    value -= self.num_children[image]
            else:
                # Every child was skipped, which is only possible if
                # num_children is inconsistent with the graph.
                raise RuntimeError('Should not be able to reach here.')

    def __iter__(self):
        # Explicit stack; children are pushed in reverse so they pop in order.
        todo = [0]
        while todo:
            v = todo.pop()
            if isinstance(v, Terminal):
                yield v.value
            else:  # isinstance(v, curver.IntegerType):
                todo.extend(reversed(self(v)))

    def __lshift__(self, index):
        ''' Return the graph (as lists) with every node index increased by index. '''
        return [[item if isinstance(item, Terminal) else item + index for item in lst] for lst in self.graph]

    def __rshift__(self, index):
        ''' Return the graph (as lists) with every node index decreased by index. '''
        return [[item if isinstance(item, Terminal) else item - index for item in lst] for lst in self.graph]

    def __add__(self, other):
        ''' Return the program for the concatenation self followed by other. '''
        if not isinstance(other, StraightLineProgram):
            return NotImplemented
        offset = self.size() + 1
        # New root 0 -> [root of self (now at 1), root of other (now at offset)].
        return StraightLineProgram([[1, offset]] + (self << 1) + (other << offset))

    def __radd__(self, other):
        ''' Return the program for the concatenation other followed by self. '''
        if not isinstance(other, StraightLineProgram):
            return NotImplemented
        offset = other.size() + 1
        return StraightLineProgram([[1, offset]] + (other << 1) + (self << offset))

    def __mul__(self, other):
        ''' Return the program for self repeated other times.

        Uses repeated doubling, so the result has O(log(other)) extra nodes.
        Raises ValueError if other is negative. '''
        if other < 0:
            raise ValueError('Cannot multiply by a negative number.')
        if other == 0:
            return StraightLineProgram([])
        bits = bin(other)[2:]
        # Node i + 1 (0 <= i < len(bits) - 1) is node i + 2 twice, and the
        # copy of self is rooted at node len(bits); hence node j expands to
        # self repeated 2 ** (len(bits) - j) times.
        doubling = [[i + 2, i + 2] for i in range(len(bits) - 1)]
        # The root concatenates the nodes matching the set bits of other.
        expansion = [i + 1 for i, bit in enumerate(bits) if bit == '1']
        return StraightLineProgram([expansion] + doubling + (self << (len(doubling) + 1)))

    def __rmul__(self, other):
        return self * other

    def __contains__(self, value):
        ''' Return whether value is held by some Terminal of the graph. '''
        # self.sinks stores the unwrapped values, so compare against value
        # itself rather than Terminal(value).
        return value in self.sinks

    def reverse(self):
        ''' Return the StraightLineProgram that returns self[::-1]. '''
        return StraightLineProgram([lst[::-1] for lst in self.graph])

    def __reversed__(self):
        return iter(self.reverse())

    def map(self, function=lambda x: x):
        ''' Return the StraightLineProgram obtained by mapping the values of this one under the given function. '''
        return StraightLineProgram([[Terminal(function(child.value)) if isinstance(child, Terminal) else child for child in children] for children in self.graph])