PyTorch CNN, incompatible tensor shapes

Here is my PyTorch CNN. The input tensor has size torch.Size([4, 1, 128, 128]), representing a batch of 4 single-channel 128x128 images:

class My_Net(nn.Module):

    def __init__(self, image_length):

        self.image_length = image_length

        # Creating the layers here (convolutional, pooling, and linear layers)

        super(My_Net, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(5, 5), padding='same')
        self.pool1 = nn.MaxPool2d(kernel_size=(10, 10))

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(5, 5), padding='same')
        self.pool2 = nn.MaxPool2d(kernel_size=(8, 8))

        self.lin1 = nn.Linear(128, 50)

        self.lin2 = nn.Linear(50, 9)

    def forward(self, x):

        x = self.pool1(F.relu(self.conv1(x)))  # first convolution and pooling step with relu activation

        x = self.pool2(F.relu(self.conv2(x)))  # second convolution and pooling step with relu activation

        print(x.size())  # added this to see the tensor dimensions before passing into the view and linear layers

        x = x.view((128 * 1 * 1, 4))  # second reshape

        x = F.relu(self.lin1(x))  # relu activation function on the first linear layer

        x = F.relu(self.lin2(x))  # we want only positive values so relu works best here

        return x

I'm getting an error in the forward pass that I am having a hard time fixing. I think it comes from a lack of understanding of how the dimensions are changing. The error is as follows:

line 51, in forward x = F.relu(self.lin1(x))

line 1102, in _call_impl return forward_call(*input, **kwargs)

line 103, in forward return F.linear(input, self.weight, self.bias)

line 1848, in linear return torch._C._nn.linear(input, weight, bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x4 and 128x50)

My main objective is to send the images through 2 convolutional and pooling layers, then 2 linear layers, ending in a 9-neuron linear layer whose output is compared via MSE to a set of 9 growth conditions for a given image. The output of the x.size() call in forward is torch.Size([4, 128, 1, 1]).


PyTorch linear layers operate on the last dimension of their input, so they handle single instances and batched data alike. For batched data, shape the input as (batch_size, *), where the last dimension equals the layer's in_features. Here the pooled tensor has shape (4, 128, 1, 1), so it should be reshaped to (4, 128) before the first linear layer, for example with x = x.view((4, 128)). The original x.view((128 * 1 * 1, 4)) produces a (128, 4) matrix instead, which is why the error reports shapes 128x4 and 128x50.
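Hard-coding the batch size of 4 in the view call breaks as soon as the batch size changes (for example on a smaller final batch). A minimal sketch of a batch-size-independent version is below; FlattenDemo and output_shape are hypothetical names used only for illustration, the layer sizes are copied from the question, and padding='same' assumes a PyTorch version that supports it, as the question's code already does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FlattenDemo(nn.Module):
    """Hypothetical stand-in for My_Net with the same layer sizes."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, padding='same')
        self.pool1 = nn.MaxPool2d(kernel_size=10)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, padding='same')
        self.pool2 = nn.MaxPool2d(kernel_size=8)
        self.lin1 = nn.Linear(128, 50)
        self.lin2 = nn.Linear(50, 9)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Keep dim 0 (the batch) and merge the rest: (N, 128, 1, 1) -> (N, 128)
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.lin1(x))
        return F.relu(self.lin2(x))


def output_shape(batch_size):
    """Run one random batch through the net and return the output shape."""
    net = FlattenDemo()
    with torch.no_grad():
        out = net(torch.randn(batch_size, 1, 128, 128))
    return tuple(out.shape)
```

Calling output_shape(4) or output_shape(2) should give one row of 9 values per image. Writing x.view(x.size(0), -1) in place of torch.flatten would have the same effect here.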

Additionally, the call to super().__init__() in My_Net comes after self.image_length is assigned. This happens to work because image_length is a plain value, but assigning a layer or parameter before the parent nn.Module class's init method has run raises an AttributeError, so the convention is to make the super() call the first statement in init. Your init method should look like this:

def __init__(self, image_length):
    super(My_Net, self).__init__()  # Add this line
   
    # The rest of your net
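In the question's code the super() call is present but placed after the first attribute assignment, so the concern is ordering. The sketch below illustrates why the order matters; LayerBeforeInit, InitFirst, and try_build are hypothetical names for this illustration only.

```python
import torch.nn as nn


class LayerBeforeInit(nn.Module):
    def __init__(self):
        # Assigning a submodule before nn.Module.__init__ has run fails,
        # because the module registry does not exist yet.
        self.lin = nn.Linear(4, 2)
        super().__init__()


class InitFirst(nn.Module):
    def __init__(self):
        super().__init__()  # set up nn.Module's internals first
        self.lin = nn.Linear(4, 2)


def try_build(cls):
    """Instantiate cls and report whether construction succeeded."""
    try:
        cls()
        return "ok"
    except AttributeError as err:
        return f"AttributeError: {err}"
```

Building InitFirst succeeds, while building LayerBeforeInit raises an AttributeError. A plain Python value such as image_length is not affected, which is why the question's class still constructs.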


With the reshape fixed and super().__init__() called first, the class becomes:

import torch.nn as nn
import torch.nn.functional as F

class My_Net(nn.Module):

    def __init__(self, image_length):

        super(My_Net, self).__init__()  # initialize nn.Module before assigning attributes

        self.image_length = image_length

        # Creating the layers here (convolutional, pooling, and linear layers)

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(5, 5), padding='same')
        self.pool1 = nn.MaxPool2d(kernel_size=(10, 10))

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(5, 5), padding='same')
        self.pool2 = nn.MaxPool2d(kernel_size=(8, 8))

        self.lin1 = nn.Linear(128, 50)

        self.lin2 = nn.Linear(50, 9)

    def forward(self, x):

        x = self.pool1(F.relu(self.conv1(x)))  # first convolution and pooling step with relu activation

        x = self.pool2(F.relu(self.conv2(x)))  # second convolution and pooling step with relu activation

        x = x.view(x.size(0), -1)  # flatten (batch_size, 128, 1, 1) to (batch_size, 128)

        x = F.relu(self.lin1(x))  # relu activation function on the first linear layer

        x = F.relu(self.lin2(x))  # we want only positive values so relu works best here

        return x
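Since the question is about how the dimensions change, the pooling arithmetic can be sketched in plain Python. This assumes the default MaxPool2d settings used in the question (stride equal to the kernel size, no padding, no dilation) and that the padding='same' convolutions leave height and width unchanged; pool_out and trace are hypothetical helper names.

```python
def pool_out(size, kernel, stride=None):
    """Output length of one spatial dimension after max pooling.

    Assumes no padding and no dilation; stride defaults to the kernel size.
    """
    stride = kernel if stride is None else stride
    return (size - kernel) // stride + 1


def trace(size=128, channels=128):
    """Follow one spatial dimension through pool1 (10x10) and pool2 (8x8)."""
    after_pool1 = pool_out(size, 10)        # 128 -> 12
    after_pool2 = pool_out(after_pool1, 8)  # 12 -> 1
    features = channels * after_pool2 * after_pool2
    return after_pool1, after_pool2, features


print(trace())  # (12, 1, 128)
```

The final value is the number of features per image after flattening, which is why nn.Linear(128, 50) fits once the batch dimension is kept first.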