shape do not match in slowfast #230

Open
xwjBupt opened this issue Dec 30, 2022 · 0 comments

xwjBupt commented Dec 30, 2022


🐛 Bugs / Unexpected behaviors

Hello, I want to use SlowFast to fine-tune on my own data. I extracted the code that SlowFast depends on and put it in a separate file (`play2.py`, shown in the code section below). Following the original paper, I set the slow view's input size to 1 * 3 * 4 * 224 * 224 and the fast view's input size to 1 * 3 * 32 * 224 * 224. When I run it, I get an error saying that the temporal depths of the slow view and the fast view do not match when the two views are fused. What is the cause? Thank you for your reply!
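One way to reason about the mismatch: along the time axis, a `Conv3d` produces floor((T + 2p - d(k - 1) - 1) / s) + 1 frames, per the PyTorch documentation. The sketch below uses a hypothetical helper, `conv_out_len`, to check whether a fast-to-slow fusion convolution maps the 32 fast frames onto the 4 slow frames. The fusion kernel (7, 1, 1), the temporal padding 3, and the two candidate strides are assumptions, because the fusion code is not part of the excerpt below. pytorchvideo's `create_slowfast` is believed to default to a temporal fusion stride of 4, which pairs 8 slow frames with 32 fast frames. Under these assumptions, only a stride equal to the frame ratio alpha = 32 / 4 = 8 lines the two views up.

```python
def conv_out_len(
    length: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1
) -> int:
    """Output length of a Conv3d along one axis (formula from the PyTorch docs)."""
    return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


SLOW_T, FAST_T = 4, 32
ALPHA = FAST_T // SLOW_T  # frame ratio between the two views: 8

# Assumed fusion conv: temporal kernel 7, padding 3. Try two candidate strides.
results = {}
for stride in (4, 8):
    fused_t = conv_out_len(FAST_T, kernel=7, stride=stride, padding=3)
    results[stride] = fused_t
    verdict = "matches" if fused_t == SLOW_T else "does NOT match"
    print(f"fusion stride {stride}: fast T {FAST_T} -> {fused_t}, {verdict} slow T {SLOW_T}")
```

If the fusion stride in your copy is 4, the two usual remedies are to feed 8 slow frames or to raise the fusion stride to 8. Either choice keeps T_fast / stride equal to T_slow.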


Instructions To Reproduce the Issue:

Please include the following (depending on what the issue is):

  1. Any changes you made (git diff) or code you wrote
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from fvcore.nn.weight_init import c2_msra_fill, c2_xavier_fill
import numpy as np


def set_attributes(self, params: Optional[dict] = None) -> None:
    """
    A utility function used in classes to set attributes from a dict of parameters
    (typically the caller's ``locals()``). The ``self`` entry is skipped.
    Args:
        params (dict): mapping from attribute name to value.
    """
    if params:
        for k, v in params.items():
            if k != "self":
                setattr(self, k, v)
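
`set_attributes` copies every entry of the dict it receives onto the instance and skips `self`. The classes below pass `locals()` from inside `__init__`, so each keyword argument becomes an attribute with the same name. Here is a minimal, torch-free sketch of that pattern. The `Demo` class is hypothetical.

```python
from typing import Optional


def set_attributes(self, params: Optional[dict] = None) -> None:
    # Same logic as above: copy every entry except "self" onto the instance.
    if params:
        for k, v in params.items():
            if k != "self":
                setattr(self, k, v)


class Demo:
    # Hypothetical class that mirrors how ResNetBasicStem passes locals().
    def __init__(self, *, conv=None, norm=None) -> None:
        set_attributes(self, locals())


d = Demo(conv="conv3d", norm=None)
print(d.conv, d.norm)  # conv3d None
```

In the `nn.Module` subclasses below, `super().__init__()` runs before `set_attributes`. The order matters because `nn.Module.__setattr__` requires the module to be initialized before it can register submodule values as children.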


class ResNetBasicStem(nn.Module):
    """
    ResNet basic 3D stem module. Performs spatiotemporal Convolution, BN, and activation
    followed by a spatiotemporal pooling.
    ::
                                        Conv3d
                                           ↓
                                     Normalization
                                           ↓
                                       Activation
                                           ↓
                                        Pool3d
    The builder can be found in `create_res_basic_stem`.
    """

    def __init__(
        self,
        *,
        conv: nn.Module = None,
        norm: nn.Module = None,
        activation: nn.Module = None,
        pool: nn.Module = None,
    ) -> None:
        """
        Args:
            conv (torch.nn.modules): convolutional module.
            norm (torch.nn.modules): normalization module.
            activation (torch.nn.modules): activation module.
            pool (torch.nn.modules): pooling module.
        """
        super().__init__()
        set_attributes(self, locals())
        assert self.conv is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        if self.pool is not None:
            x = self.pool(x)
        return x


def create_res_basic_stem(
    *,
    # Conv configs.
    in_channels: int,
    out_channels: int,
    conv_kernel_size: Tuple[int] = (3, 7, 7),
    conv_stride: Tuple[int] = (1, 2, 2),
    conv_padding: Tuple[int] = (1, 3, 3),
    conv_bias: bool = False,
    conv: Callable = nn.Conv3d,
    # Pool configs.
    pool: Callable = nn.MaxPool3d,
    pool_kernel_size: Tuple[int] = (1, 3, 3),
    pool_stride: Tuple[int] = (1, 2, 2),
    pool_padding: Tuple[int] = (0, 1, 1),
    # BN configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation: Callable = nn.ReLU,
) -> nn.Module:
    """
    Creates the basic resnet stem layer. It performs spatiotemporal Convolution, BN, and
    ReLU followed by a spatiotemporal pooling.
    ::
                                        Conv3d
                                           ↓
                                     Normalization
                                           ↓
                                       Activation
                                           ↓
                                        Pool3d
    Normalization options include: BatchNorm3d and None (no normalization).
    Activation options include: ReLU, Softmax, Sigmoid, and None (no activation).
    Pool3d options include: AvgPool3d, MaxPool3d, and None (no pooling).
    Args:
        in_channels (int): input channel size of the convolution.
        out_channels (int): output channel size of the convolution.
        conv_kernel_size (tuple): convolutional kernel size(s).
        conv_stride (tuple): convolutional stride size(s).
        conv_padding (tuple): convolutional padding size(s).
        conv_bias (bool): convolutional bias. If true, adds a learnable bias to the
            output.
        conv (callable): Callable used to build the convolution layer.
        pool (callable): a callable that constructs pooling layer, options include:
            nn.AvgPool3d, nn.MaxPool3d, and None (not performing pooling).
        pool_kernel_size (tuple): pooling kernel size(s).
        pool_stride (tuple): pooling stride size(s).
        pool_padding (tuple): pooling padding size(s).
        norm (callable): a callable that constructs normalization layer, options
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.
        activation (callable): a callable that constructs activation layer, options
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).
    Returns:
        (nn.Module): resnet basic stem layer.
    """
    conv_module = conv(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=conv_kernel_size,
        stride=conv_stride,
        padding=conv_padding,
        bias=conv_bias,
    )
    norm_module = (
        None
        if norm is None
        else norm(num_features=out_channels, eps=norm_eps, momentum=norm_momentum)
    )
    activation_module = None if activation is None else activation()
    pool_module = (
        None
        if pool is None
        else pool(
            kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding
        )
    )

    return ResNetBasicStem(
        conv=conv_module,
        norm=norm_module,
        activation=activation_module,
        pool=pool_module,
    )
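
With the defaults of `create_res_basic_stem`, the stem keeps the temporal length unchanged (temporal stride 1 in both the conv and the pool) and divides height and width by 4. As a result, the 4 : 32 frame ratio of the two views reaches the first fusion intact. The torch-free sketch below works through that shape arithmetic. The helper names and the stem widths of 64 (slow) and 8 (fast) are assumptions for illustration.

```python
from typing import Tuple

Triple = Tuple[int, int, int]


def out_len(n: int, k: int, s: int, p: int) -> int:
    # Conv3d / MaxPool3d output length along one axis (dilation 1, floor mode).
    return (n + 2 * p - k) // s + 1


def stem_output_shape(
    shape: Tuple[int, ...],
    out_channels: int,
    conv_k: Triple = (3, 7, 7),
    conv_s: Triple = (1, 2, 2),
    conv_p: Triple = (1, 3, 3),
    pool_k: Triple = (1, 3, 3),
    pool_s: Triple = (1, 2, 2),
    pool_p: Triple = (0, 1, 1),
) -> Tuple[int, ...]:
    # shape is (N, C, T, H, W); the defaults mirror create_res_basic_stem.
    n, _, t, h, w = shape
    thw = [out_len(d, k, s, p) for d, k, s, p in zip((t, h, w), conv_k, conv_s, conv_p)]
    thw = [out_len(d, k, s, p) for d, k, s, p in zip(thw, pool_k, pool_s, pool_p)]
    return (n, out_channels, *thw)


slow = stem_output_shape((1, 3, 4, 224, 224), out_channels=64)  # assumed slow width
fast = stem_output_shape((1, 3, 32, 224, 224), out_channels=8)  # assumed fast width
print(slow, fast)  # (1, 64, 4, 56, 56) (1, 8, 32, 56, 56)
```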


_MODEL_STAGE_DEPTH = {
    18: (1, 1, 1, 1),
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
    152: (3, 8, 36, 3),
}
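
The dictionary maps a nominal depth to the number of bottleneck blocks in each stage. Counting three weighted convolutions per bottleneck, plus the stem convolution and the final linear layer, recovers the nominal depth for 50, 101, and 152. Under this count, the 18 entry gives 14. It therefore presumably describes a reduced configuration rather than the standard ResNet-18, which uses two-conv basic blocks in a (2, 2, 2, 2) layout. A quick sketch of the check:

```python
_MODEL_STAGE_DEPTH = {
    18: (1, 1, 1, 1),
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
    152: (3, 8, 36, 3),
}

# Weighted layers = 3 convs per bottleneck block + stem conv + final linear layer.
layer_counts = {d: 3 * sum(blocks) + 2 for d, blocks in _MODEL_STAGE_DEPTH.items()}
for depth, count in layer_counts.items():
    print(f"depth {depth}: {count} weighted layers")
```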


class BottleneckBlock(nn.Module):
    """
    Bottleneck block: a sequence of spatiotemporal Convolution, Normalization,
    and Activations repeated in the following order:
    ::
                                    Conv3d (conv_a)
                                           ↓
                                 Normalization (norm_a)
                                           ↓
                                   Activation (act_a)
                                           ↓
                                    Conv3d (conv_b)
                                           ↓
                                 Normalization (norm_b)
                                           ↓
                                   Activation (act_b)
                                           ↓
                                    Conv3d (conv_c)
                                           ↓
                                 Normalization (norm_c)
    The builder can be found in `create_bottleneck_block`.
    """

    def __init__(
        self,
        *,
        conv_a: nn.Module = None,
        norm_a: nn.Module = None,
        act_a: nn.Module = None,
        conv_b: nn.Module = None,
        norm_b: nn.Module = None,
        act_b: nn.Module = None,
        conv_c: nn.Module = None,
        norm_c: nn.Module = None,
    ) -> None:
        """
        Args:
            conv_a (torch.nn.modules): convolutional module.
            norm_a (torch.nn.modules): normalization module.
            act_a (torch.nn.modules): activation module.
            conv_b (torch.nn.modules): convolutional module.
            norm_b (torch.nn.modules): normalization module.
            act_b (torch.nn.modules): activation module.
            conv_c (torch.nn.modules): convolutional module.
            norm_c (torch.nn.modules): normalization module.
        """
        super().__init__()
        set_attributes(self, locals())
        assert all(op is not None for op in (self.conv_a, self.conv_b, self.conv_c))
        if self.norm_c is not None:
            # This flag is used for weight initialization.
            self.norm_c.block_final_bn = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Explicitly forward every layer.
        # Branch2a, for example Tx1x1, BN, ReLU.
        x = self.conv_a(x)
        if self.norm_a is not None:
            x = self.norm_a(x)
        if self.act_a is not None:
            x = self.act_a(x)

        # Branch2b, for example 1xHxW, BN, ReLU.
        x = self.conv_b(x)
        if self.norm_b is not None:
            x = self.norm_b(x)
        if self.act_b is not None:
            x = self.act_b(x)

        # Branch2c, for example 1x1x1, BN.
        x = self.conv_c(x)
        if self.norm_c is not None:
            x = self.norm_c(x)
        return x


def create_bottleneck_block(
    *,
    # Convolution configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    conv_a_kernel_size: Tuple[int] = (3, 1, 1),
    conv_a_stride: Tuple[int] = (2, 1, 1),
    conv_a_padding: Tuple[int] = (1, 0, 0),
    conv_a: Callable = nn.Conv3d,
    conv_b_kernel_size: Tuple[int] = (1, 3, 3),
    conv_b_stride: Tuple[int] = (1, 2, 2),
    conv_b_padding: Tuple[int] = (0, 1, 1),
    conv_b_num_groups: int = 1,
    conv_b_dilation: Tuple[int] = (1, 1, 1),
    conv_b: Callable = nn.Conv3d,
    conv_c: Callable = nn.Conv3d,
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation: Callable = nn.ReLU,
) -> nn.Module:
    """
    Bottleneck block: a sequence of spatiotemporal Convolution, Normalization,
    and Activations repeated in the following order:
    ::
                                    Conv3d (conv_a)
                                           ↓
                                 Normalization (norm_a)
                                           ↓
                                   Activation (act_a)
                                           ↓
                                    Conv3d (conv_b)
                                           ↓
                                 Normalization (norm_b)
                                           ↓
                                   Activation (act_b)
                                           ↓
                                    Conv3d (conv_c)
                                           ↓
                                 Normalization (norm_c)
    Normalization examples include: BatchNorm3d and None (no normalization).
    Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
    Args:
        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
        conv_a_stride (tuple): convolutional stride size(s) for conv_a.
        conv_a_padding (tuple): convolutional padding(s) for conv_a.
        conv_a (callable): a callable that constructs the conv_a conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_b_stride (tuple): convolutional stride size(s) for conv_b.
        conv_b_padding (tuple): convolutional padding(s) for conv_b.
        conv_b_num_groups (int): number of groups for groupwise convolution for
            conv_b.
        conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
        conv_b (callable): a callable that constructs the conv_b conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_c (callable): a callable that constructs the conv_c conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        norm (callable): a callable that constructs normalization layer, examples
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.
        activation (callable): a callable that constructs activation layer, examples
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).
    Returns:
        (nn.Module): resnet bottleneck block.
    """
    conv_a = conv_a(
        in_channels=dim_in,
        out_channels=dim_inner,
        kernel_size=conv_a_kernel_size,
        stride=conv_a_stride,
        padding=conv_a_padding,
        bias=False,
    )
    norm_a = (
        None
        if norm is None
        else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
    )
    act_a = None if activation is None else activation()

    conv_b = conv_b(
        in_channels=dim_inner,
        out_channels=dim_inner,
        kernel_size=conv_b_kernel_size,
        stride=conv_b_stride,
        padding=conv_b_padding,
        bias=False,
        groups=conv_b_num_groups,
        dilation=conv_b_dilation,
    )
    norm_b = (
        None
        if norm is None
        else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
    )
    act_b = None if activation is None else activation()

    conv_c = conv_c(
        in_channels=dim_inner, out_channels=dim_out, kernel_size=(1, 1, 1), bias=False
    )
    norm_c = (
        None
        if norm is None
        else norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum)
    )

    return BottleneckBlock(
        conv_a=conv_a,
        norm_a=norm_a,
        act_a=act_a,
        conv_b=conv_b,
        norm_b=norm_b,
        act_b=act_b,
        conv_c=conv_c,
        norm_c=norm_c,
    )
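
With its defaults, `create_bottleneck_block` halves T in `conv_a` (stride (2, 1, 1)) and halves H and W in `conv_b` (stride (1, 2, 2)). The torch-free sketch below uses hypothetical helpers to show what would happen to the temporal axis if every block kept the `conv_a` default. The slow view bottoms out at T = 1 after two blocks while the fast view keeps halving, so the 8 : 1 ratio that the fusion expects is gone after the third block. This excerpt does not show whether the stage builder in your copy overrides these strides, so treat this as a hypothesis to check rather than a diagnosis.

```python
def out_len(n: int, k: int, s: int, p: int) -> int:
    # Convolution output length along one axis (dilation 1).
    return (n + 2 * p - k) // s + 1


def bottleneck_t(t: int) -> int:
    # Temporal axis through conv_a with kernel 3, stride 2, padding 1 (the defaults).
    # conv_b and conv_c use temporal kernel 1, stride 1, padding 0, so T is unchanged.
    return out_len(t, k=3, s=2, p=1)


slow_t, fast_t = 4, 32
history = [(slow_t, fast_t)]
for _ in range(4):
    slow_t, fast_t = bottleneck_t(slow_t), bottleneck_t(fast_t)
    history.append((slow_t, fast_t))

for i, (s, f) in enumerate(history):
    print(f"after {i} blocks: slow T={s}, fast T={f}, ratio={f // s}")
```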


class ResNetBasicHead(nn.Module):
    """
    ResNet basic head. This layer performs an optional pooling operation followed by an
    optional dropout, a fully-connected projection, an optional activation layer and a
    global spatiotemporal averaging.
    ::
                                        Pool3d
                                           ↓
                                        Dropout
                                           ↓
                                       Projection
                                           ↓
                                       Activation
                                           ↓
                                       Averaging
    The builder can be found in `create_res_basic_head`.
    """

    def __init__(
        self,
        pool: nn.Module = None,
        dropout: nn.Module = None,
        proj: nn.Module = None,
        activation: nn.Module = None,
        output_pool: nn.Module = None,
    ) -> None:
        """
        Args:
            pool (torch.nn.modules): pooling module.
            dropout (torch.nn.modules): dropout module.
            proj (torch.nn.modules): projection module.
            activation (torch.nn.modules): activation module.
            output_pool (torch.nn.Module): pooling module for output.
        """
        super().__init__()
        set_attributes(self, locals())
        assert self.proj is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Performs pooling.
        if self.pool is not None:
            x = self.pool(x)
        # Performs dropout.
        if self.dropout is not None:
            x = self.dropout(x)
        # Performs projection.
        if self.proj is not None:
            x = x.permute((0, 2, 3, 4, 1))
            x = self.proj(x)
            x = x.permute((0, 4, 1, 2, 3))
        # Performs activation.
        if self.activation is not None:
            x = self.activation(x)

        if self.output_pool is not None:
            # Performs global averaging.
            x = self.output_pool(x)
            x = x.view(x.shape[0], -1)
        return x


def create_res_basic_head(
    *,
    # Projection configs.
    in_features: int,
    out_features: int,
    # Pooling configs.
    pool: Callable = nn.AvgPool3d,
    output_size: Tuple[int] = (1, 1, 1),
    pool_kernel_size: Tuple[int] = (1, 7, 7),
    pool_stride: Tuple[int] = (1, 1, 1),
    pool_padding: Tuple[int] = (0, 0, 0),
    # Dropout configs.
    dropout_rate: float = 0.5,
    # Activation configs.
    activation: Callable = None,
    # Output configs.
    output_with_global_average: bool = True,
) -> nn.Module:
    """
    Creates ResNet basic head. This layer performs an optional pooling operation
    followed by an optional dropout, a fully-connected projection, an activation layer
    and a global spatiotemporal averaging.
    ::
                                        Pooling
                                           ↓
                                        Dropout
                                           ↓
                                       Projection
                                           ↓
                                       Activation
                                           ↓
                                       Averaging
    Activation examples include: ReLU, Softmax, Sigmoid, and None.
    Pool3d examples include: AvgPool3d, MaxPool3d, AdaptiveAvgPool3d, and None.
    Args:
        in_features: input channel size of the resnet head.
        out_features: output channel size of the resnet head.
        pool (callable): a callable that constructs resnet head pooling layer,
            examples include: nn.AvgPool3d, nn.MaxPool3d, nn.AdaptiveAvgPool3d, and
            None (not applying pooling).
        pool_kernel_size (tuple): pooling kernel size(s) when not using adaptive
            pooling.
        pool_stride (tuple): pooling stride size(s) when not using adaptive pooling.
        pool_padding (tuple): pooling padding size(s) when not using adaptive
            pooling.
        output_size (tuple): spatial temporal output size when using adaptive
            pooling.
        activation (callable): a callable that constructs resnet head activation
            layer, examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not
            applying activation).
        dropout_rate (float): dropout rate.
        output_with_global_average (bool): if True, perform global averaging on temporal
            and spatial dimensions and reshape output to batch_size x out_features.
    """

    if activation is None:
        activation_model = None
    elif activation == nn.Softmax:
        activation_model = activation(dim=1)
    else:
        activation_model = activation()

    if pool is None:
        pool_model = None
    elif pool == nn.AdaptiveAvgPool3d:
        pool_model = pool(output_size)
    else:
        pool_model = pool(
            kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding
        )

    if output_with_global_average:
        output_pool = nn.AdaptiveAvgPool3d(1)
    else:
        output_pool = None

    return ResNetBasicHead(
        proj=nn.Linear(in_features, out_features),
        activation=activation_model,
        pool=pool_model,
        dropout=nn.Dropout(dropout_rate) if dropout_rate > 0 else None,
        output_pool=output_pool,
    )
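
`nn.Linear` acts on the last dimension only. For that reason, `ResNetBasicHead.forward` first moves channels to the end with `permute((0, 2, 3, 4, 1))`, applies the projection, and then moves them back with `permute((0, 4, 1, 2, 3))`. The torch-free sketch below tracks the shape through each step. The input shape (2, 2048, 4, 7, 7) and the 400 output classes are illustrative assumptions.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def out_len(n: int, k: int, s: int = 1, p: int = 0) -> int:
    # AvgPool3d output length along one axis (floor mode).
    return (n + 2 * p - k) // s + 1


def permute(shape: Shape, dims: Tuple[int, ...]) -> Shape:
    # Reorders a shape tuple the way torch.Tensor.permute reorders axes.
    return tuple(shape[d] for d in dims)


def head_shapes(x: Shape, out_features: int, pool_k=(1, 7, 7)) -> List[Tuple[str, Shape]]:
    n, c, t, h, w = x
    steps = []
    # AvgPool3d(kernel=pool_k, stride=1, padding=0): the create_res_basic_head defaults.
    x = (n, c, *(out_len(d, k) for d, k in zip((t, h, w), pool_k)))
    steps.append(("pool", x))
    x = permute(x, (0, 2, 3, 4, 1))  # N, T, H, W, C
    steps.append(("channels last", x))
    x = x[:-1] + (out_features,)  # nn.Linear maps only the last dim
    steps.append(("proj", x))
    x = permute(x, (0, 4, 1, 2, 3))  # back to N, C, T, H, W
    steps.append(("channels first", x))
    x = x[:2] + (1, 1, 1)  # AdaptiveAvgPool3d(1)
    steps.append(("global average", x))
    x = (x[0], x[1] * x[2] * x[3] * x[4])  # view(N, -1)
    steps.append(("flatten", x))
    return steps


for name, shape in head_shapes((2, 2048, 4, 7, 7), out_features=400):
    print(f"{name:>15}: {shape}")
```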


def _init_resnet_weights(model: nn.Module, fc_init_std: float = 0.01) -> None:
    """
    Performs ResNet style weight initialization. That is, recursively initialize the
    given model in the following way for each type:
        Conv - Follow the initialization of kaiming_normal:
            https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
        BatchNorm - Set weight and bias of last BatchNorm at every residual bottleneck
            to 0.
        Linear - Set weight to 0 mean Gaussian with std deviation fc_init_std and bias
            to 0.
    Args:
        model (nn.Module): Model to be initialized.
        fc_init_std (float): the expected standard deviation for fully-connected layer.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)):
            """
            Follow the initialization method proposed in:
            {He, Kaiming, et al.
            "Delving deep into rectifiers: Surpassing human-level
            performance on imagenet classification."
            arXiv preprint arXiv:1502.01852 (2015)}
            """
            c2_msra_fill(m)
        elif isinstance(m, nn.modules.batchnorm._NormBase):
            if m.weight is not None:
                if hasattr(m, "block_final_bn") and m.block_final_bn:
                    m.weight.data.fill_(0.0)
                else:
                    m.weight.data.fill_(1.0)
            if m.bias is not None:
                m.bias.data.zero_()
        if isinstance(m, nn.Linear):
            if hasattr(m, "xavier_init") and m.xavier_init:
                c2_xavier_fill(m)
            else:
                m.weight.data.normal_(mean=0.0, std=fc_init_std)
            if m.bias is not None:
                m.bias.data.zero_()
    return
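
Zero-initializing the weight and bias of the last BatchNorm in each bottleneck makes `branch2` output zeros at initialization. A `ResBlock` with an identity shortcut therefore starts out computing `activation(x)`. The toy, torch-free sketch below shows that effect on a single vector. `bn_affine` and the input values are hypothetical.

```python
from typing import List


def bn_affine(x_hat: List[float], gamma: float, beta: float) -> List[float]:
    # The affine part of BatchNorm applied to already-normalized values.
    return [gamma * v + beta for v in x_hat]


def relu(v: List[float]) -> List[float]:
    return [max(0.0, a) for a in v]


x = [0.5, -1.0, 2.0]  # block input (identity shortcut)
x_hat = [0.3, -0.7, 1.2]  # whatever norm_c normalizes to
branch2 = bn_affine(x_hat, gamma=0.0, beta=0.0)  # block_final_bn: weight 0, bias 0
out = relu([a + b for a, b in zip(x, branch2)])
print(out)  # [0.5, 0.0, 2.0], i.e. relu(x)
```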


def init_net_weights(
    model: nn.Module,
    init_std: float = 0.01,
    style: str = "resnet",
) -> None:
    """
    Performs weight initialization. Options include ResNet style weight initialization
    and transformer style weight initialization.
    Args:
        model (nn.Module): Model to be initialized.
        init_std (float): The expected standard deviation for initialization.
        style (str): Options include "resnet" and "vit".
    """
    assert style in ["resnet", "vit"]
    if style == "resnet":
        return _init_resnet_weights(model, init_std)
    else:
        raise NotImplementedError(f"Weight init style '{style}' is not implemented.")


class Net(nn.Module):
    """
    Builds a general Net model from a list of blocks for video recognition.
    ::
                                         Input
                                           ↓
                                         Block 1
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                         Block N
                                           ↓
    The ResNet builder can be found in `create_resnet`.
    """

    def __init__(self, *, blocks: nn.ModuleList) -> None:
        """
        Args:
            blocks (torch.nn.module_list): the list of block modules.
        """
        super().__init__()
        assert blocks is not None
        self.blocks = blocks
        init_net_weights(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for _, block in enumerate(self.blocks):
            x = block(x)
        return x


def _trivial_sum(x, y):
    """
    Utility function used in lieu of a lambda, since lambdas are not picklable.
    """
    return x + y


class ResBlock(nn.Module):
    """
    Residual block. Performs a summation between an identity shortcut in branch1 and a
    main block in branch2. When the input and output dimensions are different, a
    convolution followed by a normalization will be performed.
    ::
                                         Input
                                           |-------+
                                           ↓       |
                                         Block     |
                                           ↓       |
                                       Summation ←-+
                                           ↓
                                       Activation
    The builder can be found in `create_res_block`.
    """

    def __init__(
        self,
        branch1_conv: nn.Module = None,
        branch1_norm: nn.Module = None,
        branch2: nn.Module = None,
        activation: nn.Module = None,
        branch_fusion: Callable = None,
    ) -> None:
        """
        Args:
            branch1_conv (torch.nn.modules): convolutional module in branch1.
            branch1_norm (torch.nn.modules): normalization module in branch1.
            branch2 (torch.nn.modules): bottleneck block module in branch2.
            activation (torch.nn.modules): activation module.
            branch_fusion (Callable): A callable or layer that combines branch1
                and branch2.
        """
        super().__init__()
        set_attributes(self, locals())
        assert self.branch2 is not None

    def forward(self, x) -> torch.Tensor:
        if self.branch1_conv is None:
            x = self.branch_fusion(x, self.branch2(x))
        else:
            shortcut = self.branch1_conv(x)
            if self.branch1_norm is not None:
                shortcut = self.branch1_norm(shortcut)
            x = self.branch_fusion(shortcut, self.branch2(x))
        if self.activation is not None:
            x = self.activation(x)
        return x


def create_res_block(
    *,
    # Bottleneck Block configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    bottleneck: Callable,
    use_shortcut: bool = False,
    branch_fusion: Callable = _trivial_sum,
    # Conv configs.
    conv_a_kernel_size: Tuple[int] = (3, 1, 1),
    conv_a_stride: Tuple[int] = (2, 1, 1),
    conv_a_padding: Tuple[int] = (1, 0, 0),
    conv_a: Callable = nn.Conv3d,
    conv_b_kernel_size: Tuple[int] = (1, 3, 3),
    conv_b_stride: Tuple[int] = (1, 2, 2),
    conv_b_padding: Tuple[int] = (0, 1, 1),
    conv_b_num_groups: int = 1,
    conv_b_dilation: Tuple[int] = (1, 1, 1),
    conv_b: Callable = nn.Conv3d,
    conv_c: Callable = nn.Conv3d,
    conv_skip: Callable = nn.Conv3d,
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation_bottleneck: Callable = nn.ReLU,
    activation_block: Callable = nn.ReLU,
) -> nn.Module:
    """
    Residual block. Performs a summation between an identity shortcut in branch1 and a
    main block in branch2. When the input and output dimensions are different, a
    convolution followed by a normalization will be performed.
    ::
                                         Input
                                           |-------+
                                           ↓       |
                                         Block     |
                                           ↓       |
                                       Summation ←-+
                                           ↓
                                       Activation
    Normalization examples include: BatchNorm3d and None (no normalization).
    Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
    Transform examples include: BottleneckBlock.
    Args:
        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        bottleneck (callable): a callable that constructs bottleneck block layer.
            Examples include: create_bottleneck_block.
        use_shortcut (bool): If true, use conv and norm layers in skip connection.
        branch_fusion (callable): a callable that constructs summation layer.
            Examples include: lambda x, y: x + y, OctaveSum.
        conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
        conv_a_stride (tuple): convolutional stride size(s) for conv_a.
        conv_a_padding (tuple): convolutional padding(s) for conv_a.
        conv_a (callable): a callable that constructs the conv_a conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_b_stride (tuple): convolutional stride size(s) for conv_b.
        conv_b_padding (tuple): convolutional padding(s) for conv_b.
        conv_b_num_groups (int): number of groups for groupwise convolution for
            conv_b.
        conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
        conv_b (callable): a callable that constructs the conv_b conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_c (callable): a callable that constructs the conv_c conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_skip (callable): a callable that constructs the conv_skip conv layer,
            examples include nn.Conv3d, OctaveConv, etc
        norm (callable): a callable that constructs normalization layer. Examples
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.
        activation_bottleneck (callable): a callable that constructs activation layer in
            bottleneck. Examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None
            (not performing activation).
        activation_block (callable): a callable that constructs activation layer used
            at the end of the block. Examples include: nn.ReLU, nn.Softmax, nn.Sigmoid,
            and None (not performing activation).
    Returns:
        (nn.Module): resnet basic block layer.
    """
    branch1_conv_stride = tuple([x * y for x, y in zip(conv_a_stride, conv_b_stride)])
    norm_model = None
    if use_shortcut or (
        norm is not None and (dim_in != dim_out or np.prod(branch1_conv_stride) != 1)
    ):
        norm_model = norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum)

    return ResBlock(
        branch1_conv=conv_skip(
            dim_in,
            dim_out,
            kernel_size=(1, 1, 1),
            stride=branch1_conv_stride,
            bias=False,
        )
        if (dim_in != dim_out or np.prod(branch1_conv_stride) != 1) or use_shortcut
        else None,
        branch1_norm=norm_model,
        branch2=bottleneck(
            dim_in=dim_in,
            dim_inner=dim_inner,
            dim_out=dim_out,
            conv_a_kernel_size=conv_a_kernel_size,
            conv_a_stride=conv_a_stride,
            conv_a_padding=conv_a_padding,
            conv_a=conv_a,
            conv_b_kernel_size=conv_b_kernel_size,
            conv_b_stride=conv_b_stride,
            conv_b_padding=conv_b_padding,
            conv_b_num_groups=conv_b_num_groups,
            conv_b_dilation=conv_b_dilation,
            conv_b=conv_b,
            conv_c=conv_c,
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            activation=activation_bottleneck,
        ),
        activation=None if activation_block is None else activation_block(),
        branch_fusion=branch_fusion,
    )
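
The shortcut logic above adds a 1×1×1 projection when the channel count changes, when the combined stride is not 1, or when `use_shortcut` is set. The shortcut's stride is the elementwise product of `conv_a_stride` and `conv_b_stride`. With the default kernels and paddings, a stride-2 axis gives (L - 1) // 2 + 1 positions in both branches, so the two branches can be summed. The torch-free sketch below reproduces that decision. `needs_projection` is a hypothetical helper.

```python
import math
from typing import Tuple

Triple = Tuple[int, int, int]


def needs_projection(
    dim_in: int,
    dim_out: int,
    conv_a_stride: Triple,
    conv_b_stride: Triple,
    use_shortcut: bool = False,
) -> Tuple[bool, Triple]:
    # Mirrors the condition in create_res_block for building branch1_conv.
    stride = tuple(a * b for a, b in zip(conv_a_stride, conv_b_stride))
    needed = use_shortcut or dim_in != dim_out or math.prod(stride) != 1
    return needed, stride


print(needs_projection(256, 256, (1, 1, 1), (1, 1, 1)))  # (False, (1, 1, 1))
print(needs_projection(64, 256, (1, 1, 1), (1, 1, 1)))  # (True, (1, 1, 1))
print(needs_projection(256, 256, (2, 1, 1), (1, 2, 2)))  # (True, (2, 2, 2))
```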


class ResStage(nn.Module):
    """
    ResStage composes sequential blocks that make up a ResNet. These blocks could be,
    for example, Residual blocks, Non-Local layers, or Squeeze-Excitation layers.
    ::
                                        Input
                                           ↓
                                       ResBlock
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                       ResBlock
    The builder can be found in `create_res_stage`.
    """

    def __init__(self, res_blocks: nn.ModuleList) -> None:
        """
        Args:
            res_blocks (torch.nn.module_list): ResBlock module(s).
        """
        super().__init__()
        self.res_blocks = res_blocks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply the residual blocks sequentially and return the final activation.
        for res_block in self.res_blocks:
            x = res_block(x)
        return x


def create_res_stage(
    *,
    # Stage configs.
    depth: int,
    # Bottleneck Block configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    bottleneck: Callable,
    # Conv configs.
    conv_a_kernel_size: Union[Tuple[int], List[Tuple[int]]] = (3, 1, 1),
    conv_a_stride: Tuple[int] = (2, 1, 1),
    conv_a_padding: Union[Tuple[int], List[Tuple[int]]] = (1, 0, 0),
    conv_a: Callable = nn.Conv3d,
    conv_b_kernel_size: Tuple[int] = (1, 3, 3),
    conv_b_stride: Tuple[int] = (1, 2, 2),
    conv_b_padding: Tuple[int] = (0, 1, 1),
    conv_b_num_groups: int = 1,
    conv_b_dilation: Tuple[int] = (1, 1, 1),
    conv_b: Callable = nn.Conv3d,
    conv_c: Callable = nn.Conv3d,
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation: Callable = nn.ReLU,
) -> nn.Module:
    """
    Create Residual Stage, which composes sequential blocks that make up a ResNet. These
    blocks could be, for example, Residual blocks, Non-Local layers, or
    Squeeze-Excitation layers.
    ::
                                        Input
                                           ↓
                                       ResBlock
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                       ResBlock
    Normalization examples include: BatchNorm3d and None (no normalization).
    Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
    Bottleneck examples include: create_bottleneck_block.
    Args:
        depth (int): number of blocks to create.
        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        bottleneck (callable): a callable that constructs bottleneck block layer.
            Examples include: create_bottleneck_block.
        conv_a_kernel_size (tuple or list of tuple): convolutional kernel size(s)
            for conv_a. If conv_a_kernel_size is a tuple, use it for all blocks in
            the stage. If conv_a_kernel_size is a list of tuples, the kernel sizes
            are repeated until the list length equals depth. For example, for
            conv_a_kernel_size = [(3, 1, 1), (1, 1, 1)], the kernel sizes for the
            first 6 blocks would be [(3, 1, 1), (1, 1, 1), (3, 1, 1), (1, 1, 1),
            (3, 1, 1), (1, 1, 1)].
        conv_a_stride (tuple): convolutional stride size(s) for conv_a.
        conv_a_padding (tuple or list of tuple): convolutional padding(s) for
            conv_a. If conv_a_padding is a tuple, use it for all blocks in
            the stage. If conv_a_padding is a list of tuples, the padding sizes
            are repeated until the list length equals depth.
        conv_a (callable): a callable that constructs the conv_a conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_b_stride (tuple): convolutional stride size(s) for conv_b.
        conv_b_padding (tuple): convolutional padding(s) for conv_b.
        conv_b_num_groups (int): number of groups for groupwise convolution for
            conv_b.
        conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
        conv_b (callable): a callable that constructs the conv_b conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_c (callable): a callable that constructs the conv_c conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        norm (callable): a callable that constructs normalization layer. Examples
            include nn.BatchNorm3d, and None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.
        activation (callable): a callable that constructs activation layer. Examples
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).
    Returns:
        (nn.Module): resnet basic stage layer.
    """
    res_blocks = []
    if isinstance(conv_a_kernel_size[0], int):
        conv_a_kernel_size = [conv_a_kernel_size]
    if isinstance(conv_a_padding[0], int):
        conv_a_padding = [conv_a_padding]
    # Repeat conv_a kernels until having same length of depth in the stage.
    conv_a_kernel_size = (conv_a_kernel_size * depth)[:depth]
    conv_a_padding = (conv_a_padding * depth)[:depth]

    for ind in range(depth):
        block = create_res_block(
            dim_in=dim_in if ind == 0 else dim_out,
            dim_inner=dim_inner,
            dim_out=dim_out,
            bottleneck=bottleneck,
            conv_a_kernel_size=conv_a_kernel_size[ind],
            conv_a_stride=conv_a_stride if ind == 0 else (1, 1, 1),
            conv_a_padding=conv_a_padding[ind],
            conv_a=conv_a,
            conv_b_kernel_size=conv_b_kernel_size,
            conv_b_stride=conv_b_stride if ind == 0 else (1, 1, 1),
            conv_b_padding=conv_b_padding,
            conv_b_num_groups=conv_b_num_groups,
            conv_b_dilation=conv_b_dilation,
            conv_b=conv_b,
            conv_c=conv_c,
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            activation_bottleneck=activation,
            activation_block=activation,
        )
        res_blocks.append(block)
    return ResStage(res_blocks=nn.ModuleList(res_blocks))


class MultiPathWayWithFuse(nn.Module):
    """
    Build multi-pathway block with fusion for video recognition, each of the pathway
    contains its own Blocks and Fusion layers across different pathways.
    ::
                            Pathway 1  ... Pathway N
                                ↓              ↓
                             Block 1        Block N
                                ↓⭠ --Fusion----↓
    """

    def __init__(
        self,
        *,
        multipathway_blocks: nn.ModuleList,
        multipathway_fusion: Optional[nn.Module],
        inplace: Optional[bool] = True,
    ) -> None:
        """
        Args:
            multipathway_blocks (nn.module_list): list of models from all pathways.
            multipathway_fusion (nn.module): fusion model.
            inplace (bool): If inplace, directly update the input list without making
                a copy.
        """
        super().__init__()
        set_attributes(self, locals())

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        assert isinstance(
            x, list
        ), "input for MultiPathWayWithFuse needs to be a list of tensors"
        if self.inplace:
            x_out = x
        else:
            x_out = [None] * len(x)
        for pathway_idx in range(len(self.multipathway_blocks)):
            if self.multipathway_blocks[pathway_idx] is not None:
                x_out[pathway_idx] = self.multipathway_blocks[pathway_idx](
                    x[pathway_idx]
                )
        if self.multipathway_fusion is not None:
            x_out = self.multipathway_fusion(x_out)
        return x_out


def create_slowfast(
    *,
    # SlowFast configs.
    slowfast_channel_reduction_ratio: Union[Tuple[int], int] = (8,),
    slowfast_conv_channel_fusion_ratio: int = 2,
    slowfast_fusion_conv_kernel_size: Tuple[int] = (
        7,
        1,
        1,
    ),  # deprecated, use fusion_builder
    slowfast_fusion_conv_stride: Tuple[int] = (
        4,
        1,
        1,
    ),  # deprecated, use fusion_builder
    fusion_builder: Callable[
        [int, int], nn.Module
    ] = None,  # Args: fusion_dim_in, stage_idx
    # Input clip configs.
    input_channels: Tuple[int] = (3, 3),
    # Model configs.
    model_depth: int = 18,
    model_num_class: int = 5,
    dropout_rate: float = 0.5,
    # Normalization configs.
    norm: Callable = nn.BatchNorm3d,
    # Activation configs.
    activation: Callable = nn.ReLU,
    # Stem configs.
    stem_function: Tuple[Callable] = (
        create_res_basic_stem,
        create_res_basic_stem,
    ),
    stem_dim_outs: Tuple[int] = (64, 8),
    stem_conv_kernel_sizes: Tuple[Tuple[int]] = ((1, 7, 7), (5, 7, 7)),
    stem_conv_strides: Tuple[Tuple[int]] = ((1, 2, 2), (1, 2, 2)),
    stem_pool: Union[Callable, Tuple[Callable]] = (nn.MaxPool3d, nn.MaxPool3d),
    stem_pool_kernel_sizes: Tuple[Tuple[int]] = ((1, 3, 3), (1, 3, 3)),
    stem_pool_strides: Tuple[Tuple[int]] = ((1, 2, 2), (1, 2, 2)),
    # Stage configs.
    stage_conv_a_kernel_sizes: Tuple[Tuple[Tuple[int]]] = (
        ((1, 1, 1), (1, 1, 1), (3, 1, 1), (3, 1, 1)),
        ((3, 1, 1), (3, 1, 1), (3, 1, 1), (3, 1, 1)),
    ),
    stage_conv_b_kernel_sizes: Tuple[Tuple[Tuple[int]]] = (
        ((1, 3, 3), (1, 3, 3), (1, 3, 3), (1, 3, 3)),
        ((1, 3, 3), (1, 3, 3), (1, 3, 3), (1, 3, 3)),
    ),
    stage_conv_b_num_groups: Tuple[Tuple[int]] = ((1, 1, 1, 1), (1, 1, 1, 1)),
    stage_conv_b_dilations: Tuple[Tuple[Tuple[int]]] = (
        ((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1)),
        ((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1)),
    ),
    stage_spatial_strides: Tuple[Tuple[int]] = ((1, 2, 2, 2), (1, 2, 2, 2)),
    stage_temporal_strides: Tuple[Tuple[int]] = ((1, 1, 1, 1), (1, 1, 1, 1)),
    bottleneck: Union[Callable, Tuple[Tuple[Callable]]] = (
        (
            create_bottleneck_block,
            create_bottleneck_block,
            create_bottleneck_block,
            create_bottleneck_block,
        ),
        (
            create_bottleneck_block,
            create_bottleneck_block,
            create_bottleneck_block,
            create_bottleneck_block,
        ),
    ),
    # Head configs.
    head: Callable = create_res_basic_head,
    head_pool: Callable = nn.AvgPool3d,
    head_pool_kernel_sizes: Tuple[Tuple[int]] = ((8, 7, 7), (32, 7, 7)),
    head_output_size: Tuple[int] = (1, 1, 1),
    head_activation: Callable = None,
    head_output_with_global_average: bool = True,
) -> nn.Module:
    """
    Build SlowFast model for video recognition, SlowFast model involves a Slow pathway,
    operating at low frame rate, to capture spatial semantics, and a Fast pathway,
    operating at high frame rate, to capture motion at fine temporal resolution. The
    Fast pathway can be made very lightweight by reducing its channel capacity, yet can
    learn useful temporal information for video recognition. Details can be found in
    the paper:

    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf

    ::

                             Slow Input  Fast Input
                                  ↓           ↓
                                 Stem       Stem
                                  ↓ ⭠ Fusion- ↓
                               Stage 1     Stage 1
                                  ↓ ⭠ Fusion- ↓
                                  .           .
                                  ↓           ↓
                               Stage N     Stage N
                                  ↓ ⭠ Fusion- ↓
                                         ↓
                                       Head

    Args:
        slowfast_channel_reduction_ratio (int or tuple): the inverse of the channel
            reduction ratio $\beta$ between the Slow and Fast pathways.
        slowfast_conv_channel_fusion_ratio (int): Ratio of channel dimensions
            between the Slow and Fast pathways.
        DEPRECATED slowfast_fusion_conv_kernel_size (tuple): the convolutional kernel
            size used for fusion.
        DEPRECATED slowfast_fusion_conv_stride (tuple): the convolutional stride size
            used for fusion.
        fusion_builder (Callable[[int, int], nn.Module]): builder function that
            generates the fusion module from the stage dimension and stage index.

        input_channels (tuple): number of channels for the input video clip.

        model_depth (int): the depth of the resnet.
        model_num_class (int): the number of classes for the video dataset.
        dropout_rate (float): dropout rate.

        norm (callable): a callable that constructs normalization layer.

        activation (callable): a callable that constructs activation layer.

        stem_function (Tuple[Callable]): a callable that constructs stem layer.
            Examples include create_res_basic_stem. Indexed by pathway.
        stem_dim_outs (tuple): output channel size of the stem.
        stem_conv_kernel_sizes (tuple): convolutional kernel size(s) of stem.
        stem_conv_strides (tuple): convolutional stride size(s) of stem.
        stem_pool (Tuple[Callable]): a callable that constructs the stem pooling
            layer. Indexed by pathway.
        stem_pool_kernel_sizes (tuple): pooling kernel size(s).
        stem_pool_strides (tuple): pooling stride size(s).

        stage_conv_a_kernel_sizes (tuple): convolutional kernel size(s) for conv_a.
        stage_conv_b_kernel_sizes (tuple): convolutional kernel size(s) for conv_b.
        stage_conv_b_num_groups (tuple): number of groups for groupwise convolution
            for conv_b. 1 for ResNet, and larger than 1 for ResNeXt.
        stage_conv_b_dilations (tuple): dilation for 3D convolution for conv_b.
        stage_spatial_strides (tuple): the spatial stride for each stage.
        stage_temporal_strides (tuple): the temporal stride for each stage.
        bottleneck (Tuple[Tuple[Callable]]): a callable that constructs bottleneck
            block layer. Examples include: create_bottleneck_block.
            Indexed by pathway and stage index.

        head (callable): a callable that constructs the resnet-style head.
            Ex: create_res_basic_head
        head_pool (callable): a callable that constructs resnet head pooling layer.
        head_output_size (tuple): the size of output tensor for head.
        head_activation (callable): a callable that constructs activation layer.
        head_output_with_global_average (bool): if True, perform global averaging on
            the head output.
    Returns:
        (nn.Module): SlowFast model.
    """

    torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_slowfast")

    # Number of blocks for different stages given the model depth.
    _num_pathway = len(input_channels)
    assert (
        model_depth in _MODEL_STAGE_DEPTH.keys()
    ), f"{model_depth} is not in {_MODEL_STAGE_DEPTH.keys()}"
    stage_depths = _MODEL_STAGE_DEPTH[model_depth]

    # Fix up inputs
    if isinstance(slowfast_channel_reduction_ratio, int):
        slowfast_channel_reduction_ratio = (slowfast_channel_reduction_ratio,)
    if isinstance(stem_pool, Callable):
        stem_pool = (stem_pool,) * _num_pathway
    if isinstance(bottleneck, Callable):
        bottleneck = (bottleneck,) * len(stage_depths)
        bottleneck = (bottleneck,) * _num_pathway
    if fusion_builder is None:
        fusion_builder = FastToSlowFusionBuilder(
            slowfast_channel_reduction_ratio=slowfast_channel_reduction_ratio[0],
            conv_fusion_channel_ratio=slowfast_conv_channel_fusion_ratio,
            conv_kernel_size=slowfast_fusion_conv_kernel_size,
            conv_stride=slowfast_fusion_conv_stride,
            norm=norm,
            activation=activation,
            max_stage_idx=len(stage_depths) - 1,
        ).create_module

    # Build stem blocks.
    stems = []
    for pathway_idx in range(_num_pathway):
        stems.append(
            stem_function[pathway_idx](
                in_channels=input_channels[pathway_idx],
                out_channels=stem_dim_outs[pathway_idx],
                conv_kernel_size=stem_conv_kernel_sizes[pathway_idx],
                conv_stride=stem_conv_strides[pathway_idx],
                conv_padding=[
                    size // 2 for size in stem_conv_kernel_sizes[pathway_idx]
                ],
                pool=stem_pool[pathway_idx],
                pool_kernel_size=stem_pool_kernel_sizes[pathway_idx],
                pool_stride=stem_pool_strides[pathway_idx],
                pool_padding=[
                    size // 2 for size in stem_pool_kernel_sizes[pathway_idx]
                ],
                norm=norm,
                activation=activation,
            )
        )

    stages = []
    stages.append(
        MultiPathWayWithFuse(
            multipathway_blocks=nn.ModuleList(stems),
            multipathway_fusion=fusion_builder(
                fusion_dim_in=stem_dim_outs[0],
                stage_idx=0,
            ),
        )
    )

    # Build stages blocks.
    stage_dim_in = stem_dim_outs[0]
    stage_dim_out = stage_dim_in * 4
    for idx in range(len(stage_depths)):
        pathway_stage_dim_in = [
            stage_dim_in
            + stage_dim_in
            * slowfast_conv_channel_fusion_ratio
            // slowfast_channel_reduction_ratio[0],
        ]
        pathway_stage_dim_inner = [
            stage_dim_out // 4,
        ]
        pathway_stage_dim_out = [
            stage_dim_out,
        ]
        for reduction_ratio in slowfast_channel_reduction_ratio:
            pathway_stage_dim_in = pathway_stage_dim_in + [
                stage_dim_in // reduction_ratio
            ]
            pathway_stage_dim_inner = pathway_stage_dim_inner + [
                stage_dim_out // 4 // reduction_ratio
            ]
            pathway_stage_dim_out = pathway_stage_dim_out + [
                stage_dim_out // reduction_ratio
            ]

        stage = []
        for pathway_idx in range(_num_pathway):
            depth = stage_depths[idx]

            stage_conv_a_kernel = stage_conv_a_kernel_sizes[pathway_idx][idx]
            stage_conv_a_stride = (stage_temporal_strides[pathway_idx][idx], 1, 1)
            stage_conv_a_padding = (
                [size // 2 for size in stage_conv_a_kernel]
                if isinstance(stage_conv_a_kernel[0], int)
                else [[size // 2 for size in sizes] for sizes in stage_conv_a_kernel]
            )

            stage_conv_b_stride = (
                1,
                stage_spatial_strides[pathway_idx][idx],
                stage_spatial_strides[pathway_idx][idx],
            )
            stage.append(
                create_res_stage(
                    depth=depth,
                    dim_in=pathway_stage_dim_in[pathway_idx],
                    dim_inner=pathway_stage_dim_inner[pathway_idx],
                    dim_out=pathway_stage_dim_out[pathway_idx],
                    bottleneck=bottleneck[pathway_idx][idx],
                    conv_a_kernel_size=stage_conv_a_kernel,
                    conv_a_stride=stage_conv_a_stride,
                    conv_a_padding=stage_conv_a_padding,
                    conv_b_kernel_size=stage_conv_b_kernel_sizes[pathway_idx][idx],
                    conv_b_stride=stage_conv_b_stride,
                    conv_b_padding=(
                        stage_conv_b_kernel_sizes[pathway_idx][idx][0] // 2,
                        stage_conv_b_dilations[pathway_idx][idx][1]
                        if stage_conv_b_dilations[pathway_idx][idx][1] > 1
                        else stage_conv_b_kernel_sizes[pathway_idx][idx][1] // 2,
                        stage_conv_b_dilations[pathway_idx][idx][2]
                        if stage_conv_b_dilations[pathway_idx][idx][2] > 1
                        else stage_conv_b_kernel_sizes[pathway_idx][idx][2] // 2,
                    ),
                    conv_b_num_groups=stage_conv_b_num_groups[pathway_idx][idx],
                    conv_b_dilation=stage_conv_b_dilations[pathway_idx][idx],
                    norm=norm,
                    activation=activation,
                )
            )
        stages.append(
            MultiPathWayWithFuse(
                multipathway_blocks=nn.ModuleList(stage),
                multipathway_fusion=fusion_builder(
                    fusion_dim_in=stage_dim_out,
                    stage_idx=idx + 1,
                ),
            )
        )
        stage_dim_in = stage_dim_out
        stage_dim_out = stage_dim_out * 2

    if head_pool is None:
        pool_model = None
    elif head_pool == nn.AdaptiveAvgPool3d:
        pool_model = [head_pool(head_output_size[idx]) for idx in range(_num_pathway)]
    elif head_pool == nn.AvgPool3d:
        pool_model = [
            head_pool(
                kernel_size=head_pool_kernel_sizes[idx],
                stride=(1, 1, 1),
                padding=(0, 0, 0),
            )
            for idx in range(_num_pathway)
        ]
    else:
        raise NotImplementedError(f"Unsupported head_pool type {head_pool}")

    stages.append(PoolConcatPathway(retain_list=False, pool=nn.ModuleList(pool_model)))
    head_in_features = stage_dim_in
    for reduction_ratio in slowfast_channel_reduction_ratio:
        head_in_features = head_in_features + stage_dim_in // reduction_ratio
    if head is not None:
        stages.append(
            head(
                in_features=head_in_features,
                out_features=model_num_class,
                pool=None,
                output_size=head_output_size,
                dropout_rate=dropout_rate,
                activation=head_activation,
                output_with_global_average=head_output_with_global_average,
            )
        )
    return Net(blocks=nn.ModuleList(stages))


# TODO: move to pytorchvideo/layer once we have a common.py
class PoolConcatPathway(nn.Module):
    """
    Given a list of tensors, perform optional spatio-temporal pool and concatenate the
        tensors along the channel dimension.
    """

    def __init__(
        self,
        retain_list: bool = False,
        pool: Optional[nn.ModuleList] = None,
        dim: int = 1,
    ) -> None:
        """
        Args:
            retain_list (bool): if True, return the concatenated tensor in a list.
            pool (nn.module_list): if not None, list of pooling models for different
                pathways before performing concatenation.
            dim (int): dimension along which to concatenate.
        """
        super().__init__()
        set_attributes(self, locals())

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        if self.pool is not None:
            assert len(x) == len(self.pool)
        output = []
        for ind in range(len(x)):
            if x[ind] is not None:
                if self.pool is not None and self.pool[ind] is not None:
                    x[ind] = self.pool[ind](x[ind])
                output.append(x[ind])
        # Concatenate along self.dim (default 1, the channel dimension).
        if self.retain_list:
            return [torch.cat(output, self.dim)]
        else:
            return torch.cat(output, self.dim)


class FastToSlowFusionBuilder:
    def __init__(
        self,
        slowfast_channel_reduction_ratio: int,
        conv_fusion_channel_ratio: float,
        conv_kernel_size: Tuple[int],
        conv_stride: Tuple[int],
        norm: Callable = nn.BatchNorm3d,
        norm_eps: float = 1e-5,
        norm_momentum: float = 0.1,
        activation: Callable = nn.ReLU,
        max_stage_idx: int = 3,
    ) -> None:
        """
        Given a list of two tensors from the Slow and Fast pathways, fuse information
        from the Fast pathway into the Slow pathway through a convolution followed by a
        concatenation, then return the fused list of tensors from the Slow and Fast
        pathways, in that order.
        Args:
            slowfast_channel_reduction_ratio (int): Reduction ratio from the stage dimension.
                Used to compute conv_dim_in = fusion_dim_in // slowfast_channel_reduction_ratio
            conv_fusion_channel_ratio (float): channel ratio for the convolution used to
                fuse from Fast pathway to Slow pathway.
            conv_kernel_size (tuple): kernel size of the convolution used to fuse from
                Fast pathway to Slow pathway.
            conv_stride (tuple): stride size of the convolution used to fuse from Fast
                pathway to Slow pathway.
            norm (callable): a callable that constructs normalization layer, examples
                include nn.BatchNorm3d, None (not performing normalization).
            norm_eps (float): normalization epsilon.
            norm_momentum (float): normalization momentum.
            activation (callable): a callable that constructs activation layer, examples
                include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
                activation).
            max_stage_idx (int): highest stage index that gets a fusion module;
                create_module returns nn.Identity() for larger indices.
        """
        set_attributes(self, locals())

    def create_module(self, fusion_dim_in: int, stage_idx: int) -> nn.Module:
        """
        Creates the module for the given stage
        Args:
            fusion_dim_in (int): input stage dimension
            stage_idx (int): which stage this is
        """
        if stage_idx > self.max_stage_idx:
            return nn.Identity()

        conv_dim_in = fusion_dim_in // self.slowfast_channel_reduction_ratio
        # Cast once so the conv and the norm agree on an integer channel count.
        conv_dim_out = int(conv_dim_in * self.conv_fusion_channel_ratio)
        conv_fast_to_slow = nn.Conv3d(
            conv_dim_in,
            conv_dim_out,
            kernel_size=self.conv_kernel_size,
            stride=self.conv_stride,
            padding=[k_size // 2 for k_size in self.conv_kernel_size],
            bias=False,
        )
        norm_module = (
            None
            if self.norm is None
            else self.norm(
                num_features=conv_dim_out,
                eps=self.norm_eps,
                momentum=self.norm_momentum,
            )
        )
        activation_module = None if self.activation is None else self.activation()
        return FuseFastToSlow(
            conv_fast_to_slow=conv_fast_to_slow,
            norm=norm_module,
            activation=activation_module,
        )


class FuseFastToSlow(nn.Module):
    """
    Given a list of two tensors from the Slow and Fast pathways, fuse information
    from the Fast pathway into the Slow pathway through a convolution followed by a
    concatenation, then return the fused list of tensors from the Slow and Fast
    pathways, in that order.
    """

    def __init__(
        self,
        conv_fast_to_slow: nn.Module,
        norm: Optional[nn.Module] = None,
        activation: Optional[nn.Module] = None,
    ) -> None:
        """
        Args:
            conv_fast_to_slow (nn.module): convolution to perform fusion.
            norm (nn.module): normalization module.
            activation (nn.module): activation module.
        """
        super().__init__()
        set_attributes(self, locals())

    def forward(self, x):
        x_s = x[0]
        x_f = x[1]
        fuse = self.conv_fast_to_slow(x_f)
        if self.norm is not None:
            fuse = self.norm(fuse)
        if self.activation is not None:
            fuse = self.activation(fuse)
        x_s_fuse = torch.cat([x_s, fuse], 1)
        return [x_s_fuse, x_f]


if __name__ == "__main__":
    num_classes = 5
    # Slow view: 4 frames, Fast view: 32 frames, laid out as (B, C, T, H, W).
    slow_view = torch.rand(1, 3, 4, 224, 224)
    fast_view = torch.rand(1, 3, 32, 224, 224)

    model = create_slowfast(model_depth=18, model_num_class=num_classes)
    print(model)
    output = model([slow_view, fast_view])
    print(output.size())

  2. The exact command(s) you ran:
  3. What you observed (including the full logs):
  File "/home/user/skip/code/DSFNet_MTICI/play2.py", line 1476, in forward
    x_s_fuse = torch.cat([x_s, fuse], 1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 8 for tensor number 1 in the list.
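The mismatch can be traced with the output-length formula for a strided convolution. This is a minimal sketch that assumes the default config in the code above: temporal stride 1 in every stem and stage, a fusion kernel of (7, 1, 1) padded by kernel_size // 2, and 224x224 input. The Fast-to-Slow fusion conv downsamples the Fast pathway's frames by its temporal stride, and `torch.cat` in `FuseFastToSlow` only succeeds if the result equals the Slow pathway's frame count. With the default `slowfast_fusion_conv_stride=(4, 1, 1)`, 32 Fast frames become 8 while the Slow view has 4, which matches "Expected size 4 but got size 8". For a 4-frame Slow clip and a 32-frame Fast clip (alpha = 8 in the paper's notation), the stride would presumably need to be (8, 1, 1) and `head_pool_kernel_sizes` ((4, 7, 7), (32, 7, 7)). The helpers `conv_out_len`, `fused_fast_len` and `suggest_config` are hypothetical and not part of PyTorchVideo.

```python
from typing import Dict, Tuple


def conv_out_len(length: int, kernel: int, stride: int, padding: int) -> int:
    # Output length along one dimension of nn.Conv3d / nn.MaxPool3d
    # (dilation 1, floor rounding), per the PyTorch docs formula.
    return (length + 2 * padding - (kernel - 1) - 1) // stride + 1


def fused_fast_len(t_fast: int, kernel_t: int = 7, stride_t: int = 4) -> int:
    # Temporal length of the Fast-to-Slow fusion conv output. Padding mirrors
    # FastToSlowFusionBuilder (kernel_size // 2). Every stem and stage in the
    # default config uses temporal stride 1, so this length is the same at
    # every fusion point.
    return conv_out_len(t_fast, kernel_t, stride_t, kernel_t // 2)


def suggest_config(
    t_slow: int, t_fast: int, kernel_t: int = 7, spatial: int = 7
) -> Dict[str, Tuple]:
    # Hypothetical helper: derive create_slowfast kwargs from the clip lengths.
    # spatial = 7 assumes 224x224 input and the default spatial strides.
    if t_fast % t_slow != 0:
        raise ValueError("t_fast must be an integer multiple of t_slow")
    alpha = t_fast // t_slow  # frame-rate ratio between Fast and Slow inputs
    if fused_fast_len(t_fast, kernel_t, alpha) != t_slow:
        raise ValueError("fused Fast length does not match t_slow; use an odd kernel")
    return {
        "slowfast_fusion_conv_stride": (alpha, 1, 1),
        "head_pool_kernel_sizes": (
            (t_slow, spatial, spatial),
            (t_fast, spatial, spatial),
        ),
    }


if __name__ == "__main__":
    # Default stride 4: 32 Fast frames -> 8, but the Slow view has 4 frames,
    # hence "Expected size 4 but got size 8".
    print(fused_fast_len(32))  # 8
    print(suggest_config(t_slow=4, t_fast=32))
```

As a sanity check, `suggest_config(8, 32)` returns the stride (4, 1, 1) and head pool kernels ((8, 7, 7), (32, 7, 7)) that are already the defaults in `create_slowfast`, i.e. the default config expects an 8-frame Slow view.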

Please also simplify the steps as much as possible so they do not require additional resources to
run, such as a private dataset, models, etc.
