diff --git a/trak/projectors.py b/trak/projectors.py index a833da9..18ec010 100644 --- a/trak/projectors.py +++ b/trak/projectors.py @@ -254,7 +254,7 @@ def free_memory(self): def get_generator_states(self): self.generator_states = [] self.seeds = [] - self.jl_size = self.proj_matrix.numel() + self.jl_size = self.grad_dim * self.block_size for i in range(self.num_blocks): s = self.seed + int(1e3) * i + int(1e5) * self.model_id