
change #8

Open
cuteboyqq opened this issue Aug 4, 2024 · 1 comment

Comments

@cuteboyqq

When I set `--sample_duration 32`, I get the error below. Do you know how to solve it?
(screenshot of the error)

@cuteboyqq
Author

cuteboyqq commented Aug 4, 2024

I removed the validation step from training. The model now trains, but it is no longer validated.
I also modified your model network, and it still reaches high accuracy. I rewrote it because I could not understand the logic of the ResNet forward pass in your CNNLSTM model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet101, ResNet101_Weights


class CNNLSTM(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        # Pretrained ResNet-101 backbone; its classifier head is replaced so that
        # each frame is mapped to a 300-dim feature vector.
        # (weights=... replaces the deprecated pretrained=True in torchvision >= 0.13)
        self.resnet = resnet101(weights=ResNet101_Weights.DEFAULT)
        self.resnet.fc = nn.Sequential(nn.Linear(self.resnet.fc.in_features, 300))
        self.lstm = nn.LSTM(input_size=300, hidden_size=256, num_layers=3, batch_first=True)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x_3d):
        # x_3d: (batch_size, seq_len, c, h, w); seq_len is the number of frames per clip
        batch_size, seq_len, c, h, w = x_3d.size()

        # Fold time into the batch dimension so ResNet processes all frames at once
        # (reshape also works if the input tensor is not contiguous)
        x = x_3d.reshape(-1, c, h, w)  # (batch_size * seq_len, c, h, w)

        # Forward pass through ResNet
        features = self.resnet(x)  # (batch_size * seq_len, 300)

        # Unfold the features back to (batch_size, seq_len, 300)
        features = features.reshape(batch_size, seq_len, -1)

        # LSTM forward pass over the per-frame features
        out, _ = self.lstm(features)  # (batch_size, seq_len, 256)

        # Use the output from the last time step
        x = out[:, -1, :]  # (batch_size, 256)

        # Classification head
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x
```
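The batched forward above folds the time dimension into the batch dimension, runs the CNN once, and unfolds the result before the LSTM. One way to check that this matches a frame-by-frame loop (CNN on frame t of every clip, then one LSTM step) is the sketch below. It assumes a tiny hypothetical CNN as a stand-in for ResNet-101 so it runs without downloading weights, and it uses eval mode. Caveat: ResNet contains BatchNorm, so in training mode the two approaches are not identical, because batch statistics are computed over different sets of frames.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the ResNet-101 backbone (assumption): maps
# (N, C, H, W) -> (N, 300), like resnet101 with fc replaced by Linear(..., 300).
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 300),
)
lstm = nn.LSTM(input_size=300, hidden_size=256, num_layers=3, batch_first=True)
backbone.eval()
lstm.eval()

# seq_len plays the role of --sample_duration
batch_size, seq_len, c, h, w = 2, 32, 3, 32, 32
x_3d = torch.randn(batch_size, seq_len, c, h, w)

with torch.no_grad():
    # (a) Frame by frame: CNN on frame t of every clip, then one LSTM step,
    #     carrying the (h, c) hidden state forward.
    hidden = None
    for t in range(seq_len):
        feat_t = backbone(x_3d[:, t])                      # (batch_size, 300)
        out_t, hidden = lstm(feat_t.unsqueeze(1), hidden)  # (batch_size, 1, 256)
    last_loop = out_t[:, -1, :]                            # (batch_size, 256)

    # (b) Batched: fold time into the batch dim, run the CNN once, unfold.
    feats = backbone(x_3d.reshape(-1, c, h, w)).reshape(batch_size, seq_len, -1)
    out, _ = lstm(feats)                                   # (batch_size, seq_len, 256)
    last_batched = out[:, -1, :]                           # (batch_size, 256)

print(torch.allclose(last_loop, last_batched, atol=1e-5))
```

The batched version makes one large CNN call instead of `seq_len` small ones, which is usually faster on a GPU. However, it pushes `batch_size * seq_len` frames through the backbone at once, so memory use grows with `--sample_duration`.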
