开发者

pytorch之torch.flatten()和torch.nn.Flatten()的用法

目录
  • torch.flatten()和torch.nn.Flatten()的用法
  • 下面举例说明
  • 总结

torch.flatten()和torch.nn.Flatten()的用法

flatten()函数的作用是将tensor铺平成一维

torch.flatten(input, start_dim=0, end_dim=- 1) → Tensor
  • input (Tensor) – the input tensor.
  • start_dim (int) – the first dim to flatten
  • end_dim (int) – the last dim to flatten

start_dim和end_dim构成了整个你要选择铺平的维度范围

下面举例说明

x = torch.tensor([[1,2], [3,4], [5,6]])
x = x.flatten(0)
x
------------------------
tensor([1, 2, 3, 4, 5, 6])

对于图片数据,我们往往期望进入fc层的维度为(channels, N)这样

x = torch.tensor([[[1,2],[3,4]], [[5,6],[7,8]]])
x = x.flatten(1)
x
-------------------------
tensor([[1, 2],
        [3, 4],
        [5, 6]])

注:

torch.nn.Flatten(start_dim=1, end_dim=- 1)

start_dim 默认为 1

所以在构建网络时,下面两种是等价的

class Classifier(nn.Module):
    def __ipythonnit__(self):
        super(Classifier, self).__init__()
        # The arguments for commonly used modules:
        # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0)
        # torch.nn.MaxPool2d(kernel_size, stride=None, padding=0)

        # input image size: [3, 128, 128]
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BATchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
编程客栈            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=4, padding=0),
        )
        self.fc_layers = nn.Sequential(
         编程   nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            phpnn.Linear(256, 11)
        )

    def forward(self, x):
        # input (x): [batch_size, 3, 128, 128]
        # output: [batch_size, 11]

        # Extract features by convolutional layers.
        x = self.cnn_layers(x)

 编程       # The extracted feature map must be flatten before going to fully-connected layers.
        x = x.flatten(1)

        # The features are transformed by fully-connected layers to obtain the final logits.
        x = self.fc_layers(x)
        return x
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=4, padding=0),

            nn.Flatten(),

            nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 11)
        )

    def forward(self, x):
       
        x = self.layers(x)

        return x

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程客栈(www.devze.com)。

0

上一篇:

下一篇:

精彩评论

暂无评论...
验证码 换一张
取 消

最新开发

开发排行榜