diff --git a/wekws/bin/export_onnx.py b/wekws/bin/export_onnx.py
index ad1f101..f998774 100644
--- a/wekws/bin/export_onnx.py
+++ b/wekws/bin/export_onnx.py
@@ -41,6 +41,8 @@ def main():
         configs = yaml.load(fin, Loader=yaml.FullLoader)
     feature_dim = configs['model']['input_dim']
     model = init_model(configs['model'])
+    is_fsmn = configs['model']['backbone']['type'] == 'fsmn'
+    num_layers = configs['model']['backbone']['num_layers']
     if configs['training_config'].get('criterion', 'max_pooling') == 'ctc':
         # if we use ctc_loss, the logits need to be convert into probs
         model.forward = model.forward_softmax
@@ -54,6 +56,8 @@
                         model.hdim,
                         model.backbone.padding,
                         dtype=torch.float)
+    if is_fsmn:
+        cache = cache.unsqueeze(-1).expand(-1, -1, -1, num_layers)
     torch.onnx.export(model, (dummy_input, cache),
                       args.onnx_model,
                       input_names=['input', 'cache'],
diff --git a/wekws/model/fsmn.py b/wekws/model/fsmn.py
index 430ed0f..2d6a581 100644
--- a/wekws/model/fsmn.py
+++ b/wekws/model/fsmn.py
@@ -477,6 +477,8 @@ def forward(
         if in_cache is None or len(in_cache) == 0 :
             in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float)
                         for _ in range(len(self.fsmn))]
+        else:
+            in_cache = [in_cache[:, :, :, i:i+1] for i in range(in_cache.size(-1))]
         input = (input, in_cache)
         x1 = self.in_linear1(input)
         x2 = self.in_linear2(x1)
@@ -490,7 +492,7 @@
         # x7 = self.softmax(x6)
         x7, _ = x6
         # return x7, None
-        return x7, in_cache
+        return x7, torch.cat(in_cache, dim=-1)

     def to_kaldi_net(self):
         re_str = ''
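
Note for reviewers: the snippet below is a standalone sketch (not part of the diff) of the cache round-trip this change introduces. The sizes batch, hdim, padding and num_layers are illustrative assumptions standing in for the values the exporter reads from the model and config.

import torch

# Hypothetical sizes; in export_onnx.py they come from the model / config.
batch, hdim, padding, num_layers = 1, 128, 7, 4

# Export side: use one stacked 4-D tensor instead of a Python list of
# per-layer caches, so the exported ONNX graph exposes a single
# fixed-rank 'cache' input.
cache = torch.zeros(batch, hdim, padding, dtype=torch.float)
cache = cache.unsqueeze(-1).expand(-1, -1, -1, num_layers)

# Model side: split the stacked tensor back into one slice per FSMN layer ...
per_layer = [cache[:, :, :, i:i + 1] for i in range(cache.size(-1))]
assert all(c.shape == (batch, hdim, padding, 1) for c in per_layer)

# ... and concatenate the per-layer caches again before returning them.
restacked = torch.cat(per_layer, dim=-1)
assert restacked.shape == (batch, hdim, padding, num_layers)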