-
Notifications
You must be signed in to change notification settings - Fork 1
/
preprosess.py
56 lines (43 loc) · 1.71 KB
/
preprosess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*- coding:utf-8 -*-
import numpy as np
import jieba
import codecs
vocab = {}
def get_vocab(input_file):
    """Load a word -> id vocabulary into the module-level ``vocab`` dict.

    Each line of *input_file* is expected to be ``<word>\\t<id>``. Blank or
    malformed lines (fewer than two tab-separated fields) are skipped instead
    of raising IndexError, as the original code did.

    :param input_file: path to the tab-separated vocabulary file
    """
    # Warm-up call so jieba loads its dictionary once, up front, instead of
    # paying the initialization cost on the first real segmentation request.
    list(jieba.cut("I love you!"))
    global vocab
    # NOTE(review): 'utf-8_sig' normalizes to the valid codec alias
    # 'utf_8_sig', so this works, though 'utf-8-sig' is the canonical name.
    with codecs.open(input_file, 'r', encoding='utf-8_sig') as rfile:
        # Iterate the file object directly rather than readlines(), avoiding
        # materializing the entire file in memory.
        for line in rfile:
            parts = line.strip().split('\t')
            if len(parts) < 2:
                continue  # skip blank / malformed lines
            vocab[parts[0]] = int(parts[1])
def padding_sentence(inputs, max_length, word2id=None):
    """Map token sequences to id sequences, padded/truncated to a fixed length.

    Unknown tokens map to the ``UNK`` id; short sequences are right-padded
    with the ``PAD`` id; long sequences are truncated to *max_length*.

    :param inputs: iterable of token lists
    :param max_length: target length of every output id list
    :param word2id: optional vocabulary mapping; defaults to the module-level
                    ``vocab`` for backward compatibility
    :return: list of id lists, each exactly ``max_length`` long
    """
    if word2id is None:
        # Fall back to the module-level vocabulary loaded by get_vocab().
        word2id = vocab
    unk = word2id["UNK"]
    pad = word2id["PAD"]
    result = []
    for tokens in inputs:
        ids = [word2id.get(token, unk) for token in tokens]
        if len(ids) < max_length:
            ids.extend([pad] * (max_length - len(ids)))
        else:
            ids = ids[:max_length]
        result.append(ids)
    return result
def seg_sent(sents):
    """Segment each sentence with jieba, dropping punctuation tokens.

    :param sents: iterable of raw sentence strings
    :return: list of token lists, one per input sentence, with any token
             found in the punctuation set removed
    """
    punc = "!?。，.!。？？""＃©＠＄％＆＇（）()＊＋，－／／：：－；＜＝＞＠［\］＾＿｀｛｜｝～｟｠｢｣、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—''‛""„‟…‧﹏."
    segmented = []
    for sentence in sents:
        tokens = jieba.cut(sentence)
        # Note: membership test against a string also discards any
        # single-character token that appears in `punc`.
        segmented.append([tok for tok in tokens if tok not in punc])
    return segmented
def preprocess(query, docs, max_len=32, vocab_path='./resources/vocab/vocab_seg_with_sw'):
    """Turn one query and a list of candidate docs into paired id matrices.

    The query is segmented, padded, and repeated once per doc so that the
    left/right arrays line up row-for-row for a pairwise matching model.

    :param query: raw query string
    :param docs: list of raw document strings
    :param max_len: fixed sequence length after padding/truncation
    :param vocab_path: vocabulary file loaded on first call (generalized from
                       the previously hard-coded path; default is unchanged)
    :return: tuple ``(x_left, x_right)`` of numpy arrays with shape
             ``(len(docs), max_len)``
    """
    global vocab
    # Lazily load the vocabulary the first time preprocess() is called.
    if not vocab:
        get_vocab(vocab_path)
    query_seg = seg_sent([query])
    docs_seg = seg_sent(docs)
    left_data = padding_sentence(query_seg, max_len)
    # Repeat the single padded query so it pairs 1:1 with every doc row.
    left_data = left_data * len(docs)
    right_data = padding_sentence(docs_seg, max_len)
    return np.array(left_data), np.array(right_data)