Merge pull request #2 from recombee/remove_compile
Remove compile method, make all attributes private, merge two imports of torch
Kasape authored Jul 30, 2023
2 parents ef16d25 + 9f19360 commit b663e79
Showing 4 changed files with 22 additions and 29 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -37,7 +37,6 @@ num_epochs = 5
batch_size = 128

model = ELSA(n_items=items_cnt, device=device, n_dims=factors)
model.compile()

model.fit(X_csr, batch_size=batch_size, epochs=num_epochs)

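
After this commit, model.compile() is gone and the learning rate moves into the constructor (lr, default 0.1). A minimal usage sketch under that assumption, reusing items_cnt, factors, device and X_csr from the README example above:

from elsa import ELSA

model = ELSA(n_items=items_cnt, device=device, n_dims=factors, lr=0.1)  # lr is set here now, no compile step
model.fit(X_csr, batch_size=128, epochs=5)
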
2 changes: 1 addition & 1 deletion elsa/__init__.py
@@ -1,5 +1,5 @@
from .elsa import ELSA

__version__ = "0.1.0"
__version__ = "0.1.4"

__all__ = ["ELSA", "__version__"]
45 changes: 19 additions & 26 deletions elsa/elsa.py
@@ -4,21 +4,17 @@

import numpy as np
import torch
import torch.nn as nn
import scipy

logger = logging.getLogger(__name__)

# TODO:
# Implement partial_fit_items method


class ELSA(nn.Module):
class ELSA(torch.nn.Module):
"""
Scalable Linear Shallow Autoencoder for Collaborative Filtering
"""

def __init__(self, n_items: int, n_dims: int, device: torch.device = None):
def __init__(self, n_items: int, n_dims: int, device: torch.device = None, lr: float = 0.1):
"""
Train model with given training data
@@ -32,25 +28,22 @@ def __init__(self, n_items: int, n_dims: int, device: torch.device = None):
ELSA's weights will be allocated on this device
"""
super(ELSA, self).__init__()
W = nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty([n_items, n_dims])).detach().clone())
W.requires_grad = True
self.W_list = nn.ParameterList([W])
W = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty([n_items, n_dims])).detach().clone())
self.__W_list = torch.nn.ParameterList([W])
self.__device = device or torch.device("cuda")
self.__items_cnt = n_items

def compile(self, lr: float = 0.1):
self.to(self.__device)
self.optimizer = torch.optim.NAdam(self.parameters(), lr=lr)
self.__optimizer = torch.optim.NAdam(self.parameters(), lr=lr)
# Loss function of ELSA is NMSE, but PyTorch implements only MSE. Normalization is done manually in train_step function
self.criterion = nn.MSELoss()
self.cosine = torch.nn.CosineSimilarity(dim=1, eps=1e-08)
self.__criterion = torch.nn.MSELoss()
self.__cosine = torch.nn.CosineSimilarity(dim=1, eps=1e-08)
self.to(self.__device)

def train_step(self, x, y):
self.zero_grad()
output = self(x)
loss = self.criterion(nn.functional.normalize(output, dim=-1), nn.functional.normalize(y, dim=-1))
loss = self.__criterion(torch.nn.functional.normalize(output, dim=-1), torch.nn.functional.normalize(y, dim=-1))
loss.backward()
self.optimizer.step()
self.__optimizer.step()
return loss, output
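
An aside, not part of the diff: the NMSE comment above means the stock torch.nn.MSELoss is applied to row-wise L2-normalized tensors, which is exactly what train_step does. A self-contained sketch with hypothetical toy tensors:

import torch

# Hypothetical 2 x 4 interaction batch, purely illustrative.
y_true = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                       [0.0, 1.0, 0.0, 1.0]])
y_pred = torch.tensor([[0.8, 0.1, 0.9, 0.0],
                       [0.2, 0.7, 0.1, 0.9]])

criterion = torch.nn.MSELoss()
# "NMSE": plain MSE between row-wise L2-normalized predictions and targets.
nmse = criterion(
    torch.nn.functional.normalize(y_pred, dim=-1),
    torch.nn.functional.normalize(y_true, dim=-1),
)
print(nmse.item())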

def fit(
@@ -124,7 +117,7 @@ def fit(
)
loss, predictions = self.train_step(io_batch, io_batch) # Input is also output, since ELSA is an autoencoder
nmse_losses_per_epoch.append(loss.item())
cosine_losses_per_epoch.append(1 - torch.mean(self.cosine(io_batch, predictions), dim=-1).item())
cosine_losses_per_epoch.append(1 - torch.mean(self.__cosine(io_batch, predictions), dim=-1).item())
if verbose:
log_dict = {
"Epoch": f"{epoch_index}/{epochs}",
@@ -154,12 +147,12 @@ def fit(
for step, io_batch in enumerate(validation_dataloader, start=1):
output = self(io_batch)
nmse_losses_per_epoch.append(
self.criterion(
nn.functional.normalize(output, dim=-1),
nn.functional.normalize(io_batch, dim=-1),
self.__criterion(
torch.nn.functional.normalize(output, dim=-1),
torch.nn.functional.normalize(io_batch, dim=-1),
).item()
)
cosine_losses_per_epoch.append(1 - torch.mean(self.cosine(io_batch, output), dim=-1).item())
cosine_losses_per_epoch.append(1 - torch.mean(self.__cosine(io_batch, output), dim=-1).item())

losses["nmse_val"].append(np.mean(nmse_losses_per_epoch))
losses["cosine_val"].append(np.mean(cosine_losses_per_epoch))
@@ -193,9 +186,9 @@ def set_device(self, device: typing.Union[str, torch.device]) -> None:
elif not isinstance(device, torch.device):
raise ValueError(f"Device must be specified by string or by torch.device instance, but '{type(device)}' was given.")
self.__device = device
for i, W in enumerate(self.W_list):
for i, W in enumerate(self.__W_list):
if W.device != self.__device:
self.W_list[i] = W.to(self.__device)
self.__W_list[i] = W.to(self.__device)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
@@ -211,7 +204,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.Tensor
Predicted tensor with the same shape as input tensor 'x'
"""
A = nn.functional.normalize(self.__get_weights(), dim=-1)
A = torch.nn.functional.normalize(self.__get_weights(), dim=-1)
xA = torch.matmul(x, A)
xAAT = torch.matmul(xA, A.T)
return xAAT - x
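
For orientation, an illustrative sketch (not part of the commit): forward computes x A A^T - x, where A is the row-normalized item-embedding matrix, so the output keeps the shape of the input batch. Hypothetical toy sizes:

import torch

n_items, n_dims, batch = 6, 3, 2                  # hypothetical sizes
W = torch.randn(n_items, n_dims)                  # stand-in for the model's weight matrix
A = torch.nn.functional.normalize(W, dim=-1)      # row-normalized item embeddings
x = torch.rand(batch, n_items)                    # toy interaction batch
out = torch.matmul(torch.matmul(x, A), A.T) - x   # same computation as forward above
print(out.shape)                                  # torch.Size([2, 6]), same shape as x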
@@ -391,7 +384,7 @@ def predict_generator(
yield self.forward(input_batch).detach()

def __get_weights(self):
return torch.vstack([param.to(self.__device) for param in self.W_list])
return torch.vstack([param.to(self.__device) for param in self.__W_list])

@staticmethod
def __convert_data_to_dataloader(
3 changes: 2 additions & 1 deletion setup.py
@@ -6,7 +6,7 @@

setup(
name="elsarec",
version="0.1.3",
version="0.1.4",
description="Scalable Linear Shallow Autoencoder for Collaborative Filtering",
author="Recombee",
author_email="vojtech.vancura@recombee.com",
@@ -24,6 +24,7 @@
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
],
packages=find_packages(),
# PyTorch needs to be installed manually since it is not on pypi
