-
Notifications
You must be signed in to change notification settings - Fork 1
/
encoder.py
121 lines (100 loc) · 4.61 KB
/
encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import sys
import math
import logging
from collections import defaultdict
from automaton import Automaton
from corpus import read_corpus, normalize_corpus
from quantizer import LogLinQuantizer
def compute_state_entropy(automaton, which):
    """Return the entropy, in nats, of how often states occur on edges.

    Every ``(source, target)`` pair stored in ``automaton.m`` counts as one
    edge.  Each edge is credited to one of its endpoints, the resulting
    counts are normalised into a distribution, and the Shannon entropy of
    that distribution is returned.

    The natural logarithm is used, so the value is in nats; divide it by
    ``math.log(2)`` to express it in bits.

    Args:
        automaton: object whose ``m`` attribute maps each source state to a
            mapping keyed by target state.
        which: ``"src"`` to credit edges to their source state, ``"tgt"`` to
            credit them to their target state.

    Returns:
        The entropy of the chosen endpoint distribution; 0 when the
        automaton has no edges.

    Raises:
        ValueError: if ``which`` is neither ``"src"`` nor ``"tgt"``.
    """
    if which not in ("src", "tgt"):
        # Previously an unknown selector counted nothing and silently
        # yielded an entropy of 0, which hid typos at the call site.
        raise ValueError(
            "which must be 'src' or 'tgt', got {0!r}".format(which))

    counts = defaultdict(int)
    for src in automaton.m:
        for tgt in automaton.m[src]:
            counts[src if which == "src" else tgt] += 1

    # ``values()`` (rather than the Python 2 only ``itervalues()``) keeps
    # this function usable on both Python 2 and Python 3.
    total = float(sum(counts.values()))
    return sum(-c / total * math.log(c / total) for c in counts.values())
class Encoder(object):
    """Compute the description length, in bits, of an automaton and corpus.

    Attributes are set directly from the constructor arguments:

    * ``entropy`` -- cost multiplied by the length of each state's emission
      (presumably bits per emitted symbol; confirm against callers).
    * ``state_bits`` -- how a single state reference is priced: ``"u"``
      charges ``log2(number of states)`` for every reference, ``"e"``
      charges the entropy of the source/target state distribution as
      computed by ``compute_state_entropy``.
    """

    def __init__(self, entropy, state_bits="u"):
        self.entropy = entropy
        self.state_bits = state_bits

    def automaton_bits(self, automaton):
        """Return ``(emission_bits, transition_bits)`` for ``automaton``.

        The automaton is first passed through its own
        ``round_and_normalize()``.  Then, for every state in ``automaton.m``:

        * emission cost is ``self.entropy * len(automaton.emissions[state])``;
          the start state ``"^"`` and states whose name starts with
          ``"EPSILON"`` are charged nothing;
        * a state with exactly one outgoing transition is charged for the
          source and target references only (no edge weight), and nothing
          at all when that single target is the final state ``"$"``;
        * otherwise every transition whose quantized weight differs from the
          quantized ``neg_cutoff`` is charged source + edge + target bits,
          where the edge costs ``log2(quantizer.levels)``.

        References to ``"^"`` as a source and to ``"$"`` as a target are
        always free.

        Raises:
            ValueError: if ``self.state_bits`` is neither ``"u"`` nor ``"e"``.
        """
        if self.state_bits not in ("u", "e"):
            # Fail before touching the automaton; previously an unknown mode
            # surfaced later as an unbound-variable error inside the loop.
            raise ValueError(
                "state_bits must be 'u' or 'e', got {0!r}".format(
                    self.state_bits))

        automaton.round_and_normalize()

        if self.state_bits == "u":
            n_states = len(automaton.m)
            # log(0) is undefined; an automaton without states has nothing
            # to reference, so a state reference costs nothing.
            uniform_bits = math.log(n_states, 2) if n_states else 0.0
            source_onestate_bits = uniform_bits
            target_onestate_bits = uniform_bits
        else:
            # compute_state_entropy uses the natural logarithm, so convert
            # its result from nats to bits before mixing it with the
            # log2-based costs below.
            nats_per_bit = math.log(2)
            source_onestate_bits = (
                compute_state_entropy(automaton, "src") / nats_per_bit)
            target_onestate_bits = (
                compute_state_entropy(automaton, "tgt") / nats_per_bit)

        automaton_emission_bits = 0.0
        automaton_trans_bits = 0.0
        q = automaton.quantizer
        edge_bits = math.log(q.levels, 2)

        for state in automaton.m:
            logging.debug("State %s", state)

            # bits for emission
            emits_nothing = state.startswith("EPSILON") or state == "^"
            s_len = 0 if emits_nothing else len(automaton.emissions[state])
            s_bits = self.entropy * s_len
            automaton_emission_bits += s_bits
            logging.debug("Emit bits=%s", s_bits)

            # bits for transition
            transitions = automaton.m[state]
            source_bits = source_onestate_bits if state != "^" else 0.0

            if len(transitions) == 1:
                only_target = next(iter(transitions))
                if only_target == "$":
                    # if target is $, and only one transition, we won't
                    # encode the source either, because we assume that
                    # there are no trimmed states
                    source_bits = 0.0
                    target_bits = 0.0
                else:
                    target_bits = target_onestate_bits
                automaton_trans_bits += source_bits + target_bits
                logging.debug("Only one transition from here, bits=%s",
                              source_bits + target_bits)
                continue

            for target, prob in transitions.items():
                # we only need target state and string and probs
                target_bits = target_onestate_bits if target != "$" else 0.0
                if q.representer(prob) == q.representer(q.neg_cutoff):
                    # we don't wanna encode those transitions at all
                    continue
                cost = source_bits + edge_bits + target_bits
                automaton_trans_bits += cost
                logging.debug(
                    "transition is encoded in %s bits(%s-%s-%s)",
                    cost, source_bits, edge_bits, target_bits)

        return automaton_emission_bits, automaton_trans_bits

    def encode(self, automaton, corpus, reverse=False):
        """Return the cost breakdown of coding ``corpus`` with ``automaton``.

        Returns:
            A 6-tuple ``(automaton_bits, emit_bits, trans_bits, err_bits,
            gen_lang_entropy, total_cost)`` where

            * ``emit_bits`` and ``trans_bits`` come from ``automaton_bits``
              and ``automaton_bits`` is their sum;
            * ``err_bits`` is whatever ``automaton.distance_from_corpus``
              returns for ``Automaton.kullback`` and ``reverse``;
            * ``gen_lang_entropy`` is the entropy, converted to bits, of the
              values found at the ``"$"`` index of each entry of
              ``automaton.language()`` (the code exponentiates them, i.e.
              treats them as natural-log probabilities); only strictly
              negative values contribute;
            * ``total_cost`` is
              ``automaton_bits + len(corpus) * (gen_lang_entropy + err_bits)``.
        """
        emit_bits, trans_bits = self.automaton_bits(automaton)
        automaton_bits = emit_bits + trans_bits

        err_bits = automaton.distance_from_corpus(
            corpus, Automaton.kullback, reverse)

        # computing entropy of generated language
        lang = automaton.language()
        di = automaton.state_indices["$"]
        gen_lang_entropy = sum(
            -math.exp(l[di]) * l[di]
            for l in lang.values() if l[di] < 0.0)
        # the sum above is in nats; convert to bits
        gen_lang_entropy /= math.log(2)

        total_cost = (automaton_bits +
                      len(corpus) * (gen_lang_entropy + err_bits))
        return (automaton_bits, emit_bits, trans_bits, err_bits,
                gen_lang_entropy, total_cost)
def main():
    """Command-line entry point: print the encoding costs of an automaton.

    Arguments (``sys.argv``):
        1. path of an automaton dump, read with ``Automaton.create_from_dump``
        2. path of a corpus file, read with ``read_corpus`` and then passed
           through ``normalize_corpus``
        3. the ``entropy`` value handed to ``Encoder`` (parsed as float)
        4. optional state-bits mode handed to ``Encoder``; defaults to ``"u"``

    The tuple returned by ``Encoder.encode`` is printed to standard output.
    """
    if len(sys.argv) < 4:
        # sys.exit with a string prints it to stderr and exits with status 1,
        # which is friendlier than the bare IndexError raised before.
        sys.exit("usage: {0} AUTOMATON_DUMP CORPUS ENTROPY [STATE_BITS]".format(
            sys.argv[0]))

    # Close both input files deterministically instead of leaking the handles.
    # NOTE(review): assumes create_from_dump/read_corpus consume the file
    # eagerly (the corpus is later used with len()) -- confirm.
    with open(sys.argv[1]) as dump_file:
        automaton = Automaton.create_from_dump(dump_file)
    with open(sys.argv[2]) as corpus_file:
        corpus = read_corpus(corpus_file)
    normalize_corpus(corpus)

    entropy = float(sys.argv[3])
    state_bits = sys.argv[4] if len(sys.argv) > 4 else "u"

    automaton.quantizer = LogLinQuantizer(10, -20)
    encoder = Encoder(entropy, state_bits)
    # Parenthesised form prints the same tuple on Python 2 and Python 3.
    print(encoder.encode(automaton, corpus))


if __name__ == "__main__":
    main()