
Create model.py2 #7

Open · wants to merge 1 commit into base: master
Conversation

elonmasai7

another view
import math
import torch
from torch import nn
from torch.nn.utils import weight_norm

class Mamba(nn.Module):
    def __init__(self, d_model, d_state, n_layers, d_inner, dropout=0.1):
        super().__init__()

        # We don't want to learn position embeddings, so we use a fixed positional
        # encoding: the position index scaled by 1 / sqrt(d_model), up to a hard-coded
        # maximum length of 64. Dividing by sqrt(d_model) is the same rescaling you
        # will find across other Transformer implementations and serves the same
        # purpose as with standard attention. The table is kept in float32 and cast
        # to the input dtype in `forward` so the model can be used seamlessly with
        # mixed precision; registering it as a buffer lets it follow the module
        # across devices.
        self.register_buffer(
            "pos_enc", torch.arange(0, 64, dtype=torch.float32) / math.sqrt(d_model)
        )

        layers = []
        for _ in range(n_layers):
            layers.append(MambaLayer(d_model, d_state, d_inner, dropout=dropout))
        self.layers = nn.ModuleList(layers)

        # Final dense layer mapping to the vocabulary (50257 is the GPT-2 BPE
        # vocabulary size). The per-layer outputs are concatenated along the feature
        # dimension in `forward`, hence the `d_model * n_layers` input size.
        self.fc = nn.Linear(d_model * n_layers, 50257)

    def forward(self, x, state_init=None):
        # The input holds `l` groups of sequences of length `L` with batch size `b`,
        # so `x` has shape `(l, b, L, d_model)`; the first dimension is the `l` one.
        l, b, L, d = x.shape

        # The layers keep a state of size `d_model // 2`, so this implementation
        # assumes `d_state == d_model // 2`.
        if state_init is None:
            state_init = torch.zeros(l, b, 1, d // 2, dtype=x.dtype, device=x.device)

        x = x + self.pos_enc[:L, None].to(x.dtype)
        states, outs = [], []

        for layer in self.layers:
            x, state = layer(x, state_init)
            states.append(state)
            # Each element of `outs` has shape `(l, b, L, d)`.
            outs.append(x)

        return self.fc(torch.cat(outs, dim=-1)), torch.cat(states, dim=-2)

class MambaLayer(nn.Module):
    def __init__(self, d_model, d_state, d_inner, dropout=0.1):
        super().__init__()
        d_model_half = d_model // 2

        self.lin_A = nn.Linear(d_model, d_model_half)
        self.lin_D = nn.Linear(d_model, d_model_half)

        self.lin_in = nn.Linear(d_model, d_inner)
        self.lin_B1 = nn.Linear(d_inner, d_model_half)
        self.lin_B2 = nn.Linear(d_state, d_model_half)
        self.lin_C = weight_norm(nn.Linear(d_model_half, d_model_half))
        # Projects the half-width features back to `d_model` so that the layers can
        # be stacked in `Mamba.forward`.
        self.lin_out = nn.Linear(d_model_half, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, state_init):
        # We output both the state AND the transformed sequence (`x`).
        # `x` is expected to have shape `(l, b, L, d_model)`.
        # `state_init` is expected to have shape `(l, b, 1, n)`, with `n == d_model // 2`.

        # We found that a tanh activation works well for A and D.
        A = torch.tanh(self.lin_A(x))
        D = torch.tanh(self.lin_D(x))

        a = self.dropout(self.lin_in(x))
        b1 = self.lin_B1(a)
        b2 = self.dropout(self.lin_B2(state_init))
        B = b1 + b2
        c = self.lin_C(self.dropout(A * B))
        # `state_init` is broadcast over the `L` timesteps, giving shape `(l, b, L, n)`.
        state = D * state_init + c

        # It looks like `state_init` might be off by one timestep from A, B, C and D,
        # but this is not the case because we start the loop on the 2nd timestep. It is
        # perfectly consistent with the equations of Mamba (see [1], Algorithm 2).
        # Intuitively, we also need to use the state at time `t - 1` rather than `t` to
        # compute `x_t`: `state_{t-1}` is a consequence of `x_{t-1}` and `u_{t-1}`.
        # If we were to use `state_t`, this would be equivalent to having `δ_t = 1`
        # instead of `δ_t = 0`, which is the case under the "zero-input" assumption
        # made by the authors (see Equation (7) in [1]).
        x = self.lin_out(A * B + c)

        # We return the new state and the new output sequence `x`.
        return x, state
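
For reference, here is a minimal smoke test of the module as written above. It assumes `d_state == d_model // 2` (which the state and `lin_B2` shapes require) and a sequence length of at most 64 (the hard-coded size of the positional-encoding table); the specific hyperparameter values below are only illustrative, not taken from the PR.

import torch

# Illustrative sizes; the code only requires d_state == d_model // 2 and L <= 64.
model = Mamba(d_model=64, d_state=32, n_layers=2, d_inner=128, dropout=0.1)

l, b, L = 3, 4, 16                 # sequence groups, batch size, sequence length
x = torch.randn(l, b, L, 64)       # (l, b, L, d_model)

logits, states = model(x)
print(logits.shape)                # torch.Size([3, 4, 16, 50257])
print(states.shape)                # torch.Size([3, 4, 32, 32]), i.e. (l, b, L * n_layers, d_state)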
