From 9c7e2f4cd35477261c4b809f455c3926a747562a Mon Sep 17 00:00:00 2001
From: Zhangyanbo
Date: Wed, 28 Apr 2021 18:28:16 -0600
Subject: [PATCH] Fix log_det shape for batched input; use Xavier init to avoid NaN

---
 src/INN/INN.py       | 2 +-
 src/INN/utilities.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/INN/INN.py b/src/INN/INN.py
index b228e0e..c174f6b 100644
--- a/src/INN/INN.py
+++ b/src/INN/INN.py
@@ -284,7 +284,7 @@ def forward(self, x, log_p0=0, log_det_J=0):
         log_det = self.logdet(x)
         if len(x.shape) == 2:
             # [batch, dim]
-            log_det = self.logdet(x).repeat(x.shape[0])
+            log_det = self.logdet(x)
 
         x = super(Linear, self).forward(x)
         if self.compute_p:
diff --git a/src/INN/utilities.py b/src/INN/utilities.py
index 3deab85..cf9dd93 100644
--- a/src/INN/utilities.py
+++ b/src/INN/utilities.py
@@ -479,8 +479,9 @@ def init_weights(self, m):
             nonlinearity = 'sigmoid'
 
         if type(m) == nn.Linear:
-            # doing Kaiming initialization
-            torch.nn.init.kaiming_normal_(m.weight.data, nonlinearity=nonlinearity)
+            # doing Xavier initialization
+            # NOTE: Kaiming initialization will make the output too high, which leads to NaN
+            torch.nn.init.xavier_uniform_(m.weight.data)
             torch.nn.init.zeros_(m.bias.data)
 
     def forward(self, x):
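
Reviewer note on the INN.py hunk: a minimal sketch of the shape bug, assuming (as the hunk suggests) that self.logdet(x) already returns one log-determinant per sample, i.e. a tensor of shape [batch]. The names weight, logdet, and log_det_J below are hypothetical stand-ins for the real attributes:

    import torch

    batch, dim = 4, 3
    weight = torch.randn(dim, dim)

    def logdet(x):
        # assumed behavior: one log|det W| per sample -> shape [batch]
        return torch.slogdet(weight)[1].repeat(x.shape[0])

    x = torch.randn(batch, dim)
    log_det_J = torch.zeros(batch)

    buggy = logdet(x).repeat(x.shape[0])  # extra repeat -> shape [batch * batch] = [16]
    fixed = logdet(x)                     # shape [batch] = [4]
    print(buggy.shape, fixed.shape)       # torch.Size([16]) torch.Size([4])

    # log_det_J + buggy raises a size-mismatch RuntimeError ([4] vs [16]);
    # the fixed version accumulates cleanly:
    print((log_det_J + fixed).shape)      # torch.Size([4])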
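
Reviewer note on the utilities.py hunk: a hedged illustration of the NOTE in the added comment, not the library's actual failure case. With Kaiming-normal's sqrt(2/fan_in) std (using the ReLU gain as one plausible setting), a deep stack of plain Linear layers multiplies activation variance by roughly 2 per layer, while Xavier keeps it near 1; once such outputs feed an exponentiated scale term, they overflow to inf/NaN. The stack() helper is illustrative only:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    dim, depth = 512, 30
    x = torch.randn(1024, dim)

    def stack(init):
        # push x through `depth` freshly initialized Linear layers
        h = x
        with torch.no_grad():
            for _ in range(depth):
                layer = nn.Linear(dim, dim, bias=False)
                init(layer.weight.data)
                h = layer(h)
        return h.std().item()

    print('kaiming:', stack(lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu')))
    print('xavier: ', stack(nn.init.xavier_uniform_))
    # kaiming std grows like sqrt(2)**depth (~3e4 here); xavier stays near 1.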