forked from mozilla/DeepSpeech
-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate_tflite.py
108 lines (86 loc) · 3.87 KB
/
evaluate_tflite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import argparse
import numpy as np
import wave
import csv
import sys
import os
from six.moves import zip, range
from multiprocessing import JoinableQueue, Pool, Process, Queue, cpu_count
from deepspeech import Model
from util.text import levenshtein
from util.evaluate_tools import process_decode_result, calculate_report
r'''
This module should be self-contained:
  - build libdeepspeech.so with TFLite:
    - bazel build [...] --define=runtime=tflite [...] //native_client:libdeepspeech.so
  - make -C native_client/python/ TFDIR=... bindings
  - set up a virtualenv
  - pip install native_client/python/dist/deepspeech*.whl
  - pip install -r requirements_eval_tflite.txt

Then run with a TF Lite model, alphabet, LM/trie and a CSV test file
(the CSV must provide `wav_filename` and `transcript` columns).
'''
# Settings handed to the deepspeech bindings by every worker process.
# Passed as the last argument to Model(); presumably the decoder beam
# width -- confirm against the deepspeech Python bindings.
BEAM_WIDTH = 500
# Passed to enableDecoderWithLM() after the trie path; presumably the
# language-model weight -- confirm.
LM_ALPHA = 0.75
# Passed to enableDecoderWithLM() after LM_ALPHA; presumably the
# word-insertion weight -- confirm.
LM_BETA = 1.85
# Passed to Model() right after the model path; presumably the number of
# input features per time step -- confirm.
N_FEATURES = 26
# Passed to Model() after N_FEATURES; presumably the size of the context
# window -- confirm.
N_CONTEXT = 9
def tflite_worker(model, alphabet, lm, trie, queue_in, queue_out, gpu_mask):
    '''Worker-process loop: transcribe WAV files received on a queue.

    Builds one ``Model`` with the LM decoder enabled, then loops forever:
    take a job from ``queue_in``, run speech-to-text on the WAV file it
    names, and push the result to ``queue_out``. The function never
    returns; it is meant to run in a daemon process.

    Args:
        model: Model path, forwarded to ``Model``.
        alphabet: Alphabet file path, forwarded to ``Model`` and to
            ``enableDecoderWithLM``.
        lm: Language model path, forwarded to ``enableDecoderWithLM``.
        trie: Trie path, forwarded to ``enableDecoderWithLM``.
        queue_in: Joinable queue of dicts with ``'filename'`` and
            ``'transcript'`` keys. ``task_done()`` is called exactly once
            per item taken, whether or not transcription succeeded.
        queue_out: Queue that receives one dict with ``'prediction'`` and
            ``'ground_truth'`` keys per successfully transcribed file.
        gpu_mask: Value written (as a string) to ``CUDA_VISIBLE_DEVICES``
            before the model is created.
    '''
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask)
    ds = Model(model, N_FEATURES, N_CONTEXT, alphabet, BEAM_WIDTH)
    ds.enableDecoderWithLM(alphabet, lm, trie, LM_ALPHA, LM_BETA)

    while True:
        msg = queue_in.get()
        try:
            fin = wave.open(msg['filename'], 'rb')
            try:
                fs = fin.getframerate()
                audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
            finally:
                # Release the file handle even if reading the frames fails.
                fin.close()

            decoded = ds.stt(audio, fs)
            queue_out.put({'prediction': decoded, 'ground_truth': msg['transcript']})
        except (OSError, EOFError, wave.Error) as ex:
            # A missing, truncated or malformed WAV file must not kill the
            # worker: report it and move on to the next job.
            print('Skipping {}: {}'.format(msg['filename'], ex), file=sys.stderr)
        finally:
            # Always acknowledge the job; otherwise the producer's
            # queue_in.join() would block forever on a failed item.
            queue_in.task_done()
def main():
    '''Evaluate a TFLite model on a CSV test set and print WER/CER.

    Parses the command line, starts ``--proc`` daemon worker processes
    running ``tflite_worker``, feeds them one job per CSV row (columns
    ``wav_filename`` and ``transcript``), waits for all jobs to be
    acknowledged, then collects the predictions and prints the report
    computed by ``calculate_report``.

    Exits through ``parser.error`` when ``--proc`` is smaller than 1.
    '''
    # Local import: six is already a dependency of this file, and this is
    # the exception multiprocessing.Queue.get() raises on timeout.
    from six.moves.queue import Empty

    parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
    parser.add_argument('--model', required=True,
                        help='Path to the model (protocol buffer binary file)')
    parser.add_argument('--alphabet', required=True,
                        help='Path to the configuration file specifying the alphabet used by the network')
    parser.add_argument('--lm', required=True,
                        help='Path to the language model binary file')
    parser.add_argument('--trie', required=True,
                        help='Path to the language model trie file created with native_client/generate_trie')
    parser.add_argument('--csv', required=True,
                        help='Path to the CSV source file')
    parser.add_argument('--proc', required=False, default=cpu_count(), type=int,
                        help='Number of processes to spawn, defaulting to number of CPUs')
    args = parser.parse_args()

    # Without at least one worker nothing consumes the jobs and the
    # join() below would block forever.
    if args.proc < 1:
        parser.error('--proc must be at least 1')

    work_todo = JoinableQueue()   # jobs for the workers
    work_done = Queue()           # results coming back from the workers

    processes = []
    for i in range(args.proc):
        worker_process = Process(
            target=tflite_worker,
            args=(args.model, args.alphabet, args.lm, args.trie, work_todo, work_done, i),
            daemon=True,
            name='tflite_process_{}'.format(i))
        worker_process.start()
        processes.append(worker_process)

    print([x.name for x in processes])

    submitted = 0
    with open(args.csv, 'r') as csvfile:
        for row in csv.DictReader(csvfile):
            work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']})
            submitted += 1
    work_todo.join()

    # Queue.empty() is not reliable across processes: a result may still be
    # in a worker's feeder thread when join() returns. Ask for as many
    # results as jobs were submitted, with a bounded wait so that a job a
    # worker produced no result for cannot stall the script indefinitely.
    drain_timeout = 10
    ground_truths = []
    predictions = []
    for _ in range(submitted):
        try:
            msg = work_done.get(timeout=drain_timeout)
        except Empty:
            break
        ground_truths.append(msg['ground_truth'])
        predictions.append(msg['prediction'])

    if len(predictions) < submitted:
        print('Warning: received {} results for {} submitted files'.format(
            len(predictions), submitted), file=sys.stderr)

    # No loss is computed in this script; a zero per sample is passed on to
    # calculate_report and averaged for the final printout.
    losses = [0.0] * len(predictions)

    distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
    wer, cer, _ = calculate_report(ground_truths, predictions, distances, losses)
    mean_loss = np.mean(losses)

    print('Test - WER: %f, CER: %f, loss: %f' %
          (wer, cer, mean_loss))
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()