bert_embedding.py
# -*- coding:utf-8 -*-
'''
-------------------------------------------------
   Description : bert_embedding implementation (BERT sentence embeddings)
   Author      : machinelp
   Date        : 2020-06-03
-------------------------------------------------
'''
import numpy as np
import tensorflow as tf
from textmatch.config.constant import Constant as const
from textmatch.models.model_base.model_base import ModelBase
from bert4keras.backend import keras, set_gelu
from bert4keras.bert import build_bert_model
from bert4keras.tokenizer import Tokenizer
from bert4keras.snippets import sequence_padding

set_gelu('tanh')  # use the tanh approximation of GELU


class BertEmbedding(ModelBase):
    '''
    Encodes sentences with a pretrained BERT and serves normalized
    pooled vectors for similarity scoring.
    '''
    def __init__(self,
                 config_path=const.BERT_CONFIG_PATH,
                 checkpoint_path=const.BERT_CHECKPOINT_PATH,
                 dict_path=const.BERT_DICT_PATH,
                 train_mode=False):
        # TF 1.x-style session, kept so predictions run in the right graph.
        self.session = tf.Session()
        keras.backend.set_session(self.session)
        self.bert = build_bert_model(
            config_path,
            checkpoint_path,
            with_pool='linear',  # linear pooling over the [CLS] position
            # application='seq2seq',
            return_keras_model=False,
        )
        # Encoder mapping (token_ids, segment_ids) -> pooled sentence vector.
        self.encoder = keras.models.Model(self.bert.model.inputs, self.bert.model.outputs[0])
        self.tokenizer = Tokenizer(dict_path, do_lower_case=True)
    def init(self, words_list=None, update=True):
        # Pre-compute normalized embeddings for the candidate sentences.
        if words_list is not None:
            token_ids_list, segment_ids_list = [], []
            for words in words_list:
                token_ids, segment_ids = self.tokenizer.encode(words)
                token_ids_list.append(token_ids)
                segment_ids_list.append(segment_ids)
            token_ids_list = sequence_padding(token_ids_list)
            segment_ids_list = sequence_padding(segment_ids_list)
            self.words_list_pre = self.encoder.predict([token_ids_list, segment_ids_list])
            self.words_list_pre = self._normalize(self.words_list_pre)
        return self
    def _predict(self, words):
        # Embed a single sentence; returns a (1, dim) normalized vector.
        with self.session.as_default():
            with self.session.graph.as_default():
                token_ids, segment_ids = self.tokenizer.encode(words)
                pre = self.encoder.predict([np.array([token_ids]), np.array([segment_ids])])
                pre = self._normalize(pre)  # normalization helper from ModelBase
        return pre

    # Sentence vector: score `words` against the pre-computed candidates.
    def predict(self, words):
        with self.session.as_default():
            with self.session.graph.as_default():
                token_ids, segment_ids = self.tokenizer.encode(words)
                pre = self.encoder.predict([np.array([token_ids]), np.array([segment_ids])])
                pre = self._normalize(pre)
        # Both sides are normalized, so the dot product is cosine similarity.
        return np.dot(self.words_list_pre[:], pre[0])
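

# Minimal usage sketch (not part of the original file): the BERT paths come
# from textmatch's Constant config, and the example sentences below are
# hypothetical stand-ins. init() caches normalized corpus vectors, and
# predict() then returns one cosine score per corpus sentence.
if __name__ == '__main__':
    words_list = ['我想买保险', '如何退保', '保险理赔要多久']  # hypothetical corpus
    model = BertEmbedding().init(words_list)
    scores = model.predict('怎么买保险')  # shape: (len(words_list),)
    for sent, score in zip(words_list, scores):
        print(sent, float(score))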