Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhangyanbo committed Apr 29, 2021
1 parent 84a1277 commit 9c7e2f4
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/INN/INN.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def forward(self, x, log_p0=0, log_det_J=0):
log_det = self.logdet(x)
if len(x.shape) == 2:
# [batch, dim]
log_det = self.logdet(x).repeat(x.shape[0])
log_det = self.logdet(x)
x = super(Linear, self).forward(x)

if self.compute_p:
Expand Down
5 changes: 3 additions & 2 deletions src/INN/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,9 @@ def init_weights(self, m):
nonlinearity = 'sigmoid'

if type(m) == nn.Linear:
# doing Kaiming initialization
torch.nn.init.kaiming_normal_(m.weight.data, nonlinearity=nonlinearity)
# doing xavier initialization
# NOTE: Kaiming initialization will make the output too high, which leads to nan
torch.nn.init.xavier_uniform_(m.weight.data)
torch.nn.init.zeros_(m.bias.data)

def forward(self, x):
Expand Down

0 comments on commit 9c7e2f4

Please sign in to comment.