diff --git a/tests/backend/pytorch_backend/test_torch_ops.py b/tests/backend/pytorch_backend/test_torch_ops.py index 68c4657..e22f8a6 100644 --- a/tests/backend/pytorch_backend/test_torch_ops.py +++ b/tests/backend/pytorch_backend/test_torch_ops.py @@ -52,7 +52,7 @@ def __init__(self): super().__init__() def forward(self, batch): - output_size = len(batch) + output_size = len(batch["input_ids"]) return { "a": torch.ones(output_size, device="cuda") * 1, "b": torch.ones(output_size, device="cuda") * 2,