Cleaner All Pairs Difference #106

Open
0xc1c4da opened this issue Sep 3, 2024 · 3 comments

Comments

@0xc1c4da

0xc1c4da commented Sep 3, 2024

In PyTorch and JAX, an all-pairs difference along an axis can be computed with a one-line broadcasting trick:
dt_segment_sum_jax = dA_cumsum_jax[:, :, :, :, None] - dA_cumsum_jax[:, :, :, None, :]

This pattern is used in the reference implementation of Mamba 2.
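
To make the indexing trick concrete, here is a minimal sketch on small arrays. The names `x`, `pairwise`, `batch` and `batch_pairwise` are illustrative only and do not come from Mamba 2:

```python
import jax.numpy as jnp

# Small stand-in for dA_cumsum_jax; values chosen so results are easy to check.
x = jnp.array([1.0, 3.0, 6.0])

# x[:, None] has shape (3, 1) and x[None, :] has shape (1, 3).
# Broadcasting the subtraction gives pairwise[i, j] = x[i] - x[j].
pairwise = x[:, None] - x[None, :]

# The same idea applied to the last axis of a batched array:
# shape (2, 3) becomes shape (2, 3, 3).
batch = jnp.stack([x, 2.0 * x])
batch_pairwise = batch[:, :, None] - batch[:, None, :]
```

Inserting `None` adds a length-1 axis, and the subtraction broadcasts the two length-1 axes against each other, which is why the result gains one extra trailing dimension.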

The one-liner above is neither readable nor obvious about what it does. It was also not obvious how to express the equivalent in Haliax, because of the constraint that one operand's axes must be a subset of the other's. A potential solution is below:

import jax.numpy as jnp
from jax.random import PRNGKey

import haliax as hax
from haliax import Axis


def test_all_pairs_difference():
    H = Axis("H", 7)
    W = Axis("W", 8)
    D = Axis("D", 9)
    T = Axis("T", 11)

    named1 = hax.random.uniform(PRNGKey(0), (H, W, D, T))
    # Named-axis analogue of:
    # dt_segment_sum_jax = dA_cumsum_jax[:, :, :, :, None] - dA_cumsum_jax[:, :, :, None, :]
    # Broadcasting over T2 first makes the renamed operand's axes a subset of the left operand's.
    named1_diff = named1.broadcast_axis(hax.Axis("T2", 11)) - named1.rename({"T": "T2"})
    # Move T and T2 to the end so the layout matches the plain-array result.
    named1_diff = named1_diff.rearrange((..., "T", "T2"))
    assert named1_diff.axes == (H, W, D, T, Axis("T2", 11))

    vanilla_diff = named1.array[:, :, :, :, None] - named1.array[:, :, :, None, :]

    assert jnp.all(named1_diff.array == vanilla_diff)

This issue exists to request better support for this kind of operation.

@dlwh
Member

dlwh commented Sep 5, 2024

What do you think about:

with hax.auto_broadcast():
    named1_diff = named1 - named1.rename({"T": "T2"})

Basically the only thing stopping this from working is an explicit check I do to avoid accidentally combining arrays where one isn't a subset of the other.

The other thing I could do is relax the check to require only "at least one overlapping axis".
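
For readers comparing the two options, the difference between the current subset check and the relaxed overlap check can be sketched in plain Python over axis names. This is not Haliax's actual code: `is_subset_compatible` and `has_overlap` are hypothetical helpers written only for illustration.

```python
def is_subset_compatible(a, b):
    """Hypothetical strict rule: one operand's axis names are a subset of the other's."""
    sa, sb = set(a), set(b)
    return sa <= sb or sb <= sa


def has_overlap(a, b):
    """Hypothetical relaxed rule: the operands share at least one axis name."""
    return bool(set(a) & set(b))


left = ("H", "W", "D", "T")
right = ("H", "W", "D", "T2")  # the same array after renaming T to T2

strict_ok = is_subset_compatible(left, right)    # False: neither side contains the other
relaxed_ok = has_overlap(left, right)            # True: H, W and D are shared
disjoint_ok = has_overlap(("A",), ("B",))        # False: still rejected under the relaxed rule
```

Under the relaxed rule the all-pairs difference would be accepted, while operands with no axis in common would still be rejected.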

@0xc1c4da
Author

I think it is certainly cleaner, but I wouldn't remove the explicit check. Wouldn't it be better to disable the check explicitly?
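
One way an explicit opt-out could look, sketched in plain Python. Everything here is an assumption made for illustration (`allow_broadcast`, `check_axes`, and the `_strict` flag are hypothetical names); it is not how Haliax implements its check.

```python
from contextlib import contextmanager
from contextvars import ContextVar

# Hypothetical flag: True means the subset check is enforced.
_strict = ContextVar("strict_axis_check", default=True)


@contextmanager
def allow_broadcast():
    """Disable the subset check inside the with-block, then restore it."""
    token = _strict.set(False)
    try:
        yield
    finally:
        _strict.reset(token)


def check_axes(a, b):
    """Raise unless one set of axis names is a subset of the other (when strict)."""
    if not _strict.get():
        return
    sa, sb = set(a), set(b)
    if not (sa <= sb or sb <= sa):
        raise ValueError(f"axes {a} and {b} are not subset-compatible")
```

With this shape the check stays on by default, and the caller has to opt out visibly at the call site, for example with `with allow_broadcast(): check_axes(("H", "T"), ("H", "T2"))`.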

@dlwh
Member

dlwh commented Sep 15, 2024

Meaning you like the with hax.auto_broadcast() approach?
