Skip to content

Commit

Permalink
Merge pull request #35 from OxWearables/hot-fix
Browse files Browse the repository at this point in the history
Handles edge case during data transformation when the remainder is 0.
  • Loading branch information
angerhang authored Sep 28, 2023
2 parents 8363ddb + 86ae741 commit 8839562
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ name = "asleep" # Required
#
# For a discussion on single-sourcing the version, see
# https://packaging.python.org/guides/single-sourcing-package-version/
version = "0.4.3" # Required
version = "0.4.4" # Required

# This is a one-line description or tagline of what your project does. This
# corresponds to the "Summary" metadata field:
Expand Down
2 changes: 1 addition & 1 deletion src/asleep/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def predict(self, X, groups=None):
model.to(self.device)

_, y_pred, _ = sslmodel.predict(
model, dataloader, self.device, output_logits=False, name='prediction')
model, dataloader, self.device, output_logits=False)

y_pred = self.hmms.predict(y_pred, groups=groups)

Expand Down
5 changes: 1 addition & 4 deletions src/asleep/sslmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,7 @@ def predict(model, data_loader, device,
pid_list.extend(pid)

predictions_list = torch.cat(predictions_list)
if name == 'prediction':
true_list = predictions_list
else:
true_list = torch.Tensor([1, 2, 3])
true_list = torch.cat(true_list)

if output_logits:
return (
Expand Down
7 changes: 4 additions & 3 deletions src/asleep/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,10 @@ def data_long2wide(X, times, non_wear):
"""
# get multiple of 900
remainder = X.shape[0] % 900
X = X[:-remainder]
times = times[:-remainder]
non_wear = non_wear[:-remainder]
if remainder != 0:
X = X[:-remainder]
times = times[:-remainder]
non_wear = non_wear[:-remainder]

x = X[:, 0]
y = X[:, 1]
Expand Down

0 comments on commit 8839562

Please sign in to comment.