
A Guide to Using nn.Module in PyTorch

Contents
  • 1. What is nn.Module
  • 2. Basic Usage
    • 2.1 Defining a Custom Model Class
    • 2.2 Instantiation and Invocation
  • 3. Core Methods in Detail
    • 3.1 __init__()
    • 3.2 forward()
  • 4. Common Layer Modules
  • 5. Nested Model Structure (Submodules)
  • 6. Built-in Methods and Attributes
  • 7. Using nn.Sequential
  • 8. Complete Worked Example: an MNIST Classifier
  • 9. Common Pitfalls and Tips
  • 10. Summary
  • 11. Comprehensive Examples
    • 1. A Concise ResNet18 (for image classification)
    • 2. UNet (for image segmentation)
    • 3. A Simplified Transformer Encoder (for sequence modeling)
    • 4. Comparison Summary

In PyTorch, nn.Module is the core base class on which every model is built. Understanding nn.Module and using it fluently is the key to mastering PyTorch.

1. What is nn.Module

nn.Module is the base class of every neural network module in PyTorch. You can think of it as a "container for a neural network" that takes care of the following (see the sketch after this list):

1. Network layers (such as Linear, Conv2d, etc.)
2. Forward-propagation logic (the forward method)
3. Model parameters (registered automatically and trainable)
4. Nesting (a module can contain any number of submodules)
5. Convenience utilities such as model saving and loading
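
To make point 3 concrete, here is a minimal sketch (the class name Tiny is purely illustrative): assigning a layer to self is all it takes for its parameters to be tracked.

import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)    # assigning to self registers the layer

m = Tiny()
for name, p in m.named_parameters():
    print(name, tuple(p.shape))      # fc.weight (2, 4), fc.bias (2,)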

2. Basic Usage

2.1 Defining a Custom Model Class

                          import torch
                          import torch.nn as nn
                          class MyNet(nn.Module):
                              def __init__(self):
                                  super().__init__()
                                  self.fc1 = nn.Linear(784, 128)
                                  self.relu = nn.ReLU()
                                  self.fc2 = nn.Linear(128, 10)
                              def forward(self, x):
                                  x = self.fc1(x)
                                  x = self.relu(x)
                                  x = self.fc2(x)
                                  return x

2.2 Instantiation and Invocation

model = MyNet()
x = torch.randn(32, 784)     # batch_size = 32
output = model(x)            # calls forward automatically
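
Here output has shape (32, 10): ten class scores for each of the 32 samples. Calling model(x) goes through nn.Module.__call__, which is why forward runs automatically.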

3. Core Methods in Detail

3.1 __init__()

• Defines the model's structure: submodules, layers, and so on.
• For example, self.conv1 = nn.Conv2d(...) is registered automatically, so its parameters are tracked by the model.

3.2 forward()

• Defines the forward-propagation logic.
• Should not be called by hand; use the model(x) form instead. A sketch of why follows.
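
A small sketch of the difference, reusing MyNet from section 2.1: model(x) runs forward plus any registered hooks, while a bare forward(x) silently skips them.

model = MyNet()
model.register_forward_hook(lambda mod, inp, out: print("hook fired"))

x = torch.randn(1, 784)
_ = model(x)          # prints "hook fired"
_ = model.forward(x)  # silent: the hook is bypassed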

4. Common Layer Modules

Module           Purpose                          Example
nn.Linear        Fully connected layer            nn.Linear(128, 64)
nn.Conv2d        2D convolution layer             nn.Conv2d(3, 16, 3)
nn.ReLU          Activation function              nn.ReLU()
nn.Sigmoid       Activation function              nn.Sigmoid()
nn.BatchNorm2d   Batch normalization              nn.BatchNorm2d(16)
nn.Dropout       Dropout layer                    nn.Dropout(0.5)
nn.LSTM          LSTM layer                       nn.LSTM(10, 20)
nn.Sequential    Sequential container of layers   See section 7
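
As a quick illustration, several of the layers above compose directly (the shapes here are arbitrary examples):

block = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),   # 3 input channels -> 16 output channels
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Dropout(0.5),
)
x = torch.randn(8, 3, 32, 32)         # (batch, channels, height, width)
print(block(x).shape)                 # torch.Size([8, 16, 32, 32])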

5. Nested Model Structure (Submodules)

One nn.Module can be nested inside another as a submodule:

class Block(nn.Module):
                              def __init__(self):
                                  super().__init__()
                                  self.layer = nn.Sequential(
                                      nn.Linear(64, 64),
                                      nn.ReLU()
                                  )
                              def forward(self, x):
                                  return self.layer(x)
                          class Net(nn.Module):
                              def __init__(self):
                                  super().__init__()
                                  self.block1 = Block()
                                  self.block2 = Block()
                                  self.output = nn.Linear(64, 10)
                              def forward(self, x):
                                  x = self.block1(x)
                                  x = self.block2(x)
                                  return self.output(x)
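
Submodules registered this way are tracked recursively; named_modules() walks the whole tree:

net = Net()
for name, module in net.named_modules():
    print(name or "(root)", "->", type(module).__name__)
# block1, block1.layer, block1.layer.0, ... are all visible to PyTorch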

6. Built-in Methods and Attributes

Method / Attribute         Description
model.parameters()         All trainable parameters (pass these to the optimizer)
model.named_parameters()   Iterator of (name, parameter) pairs
model.children()           Iterator over direct submodules
model.eval()               Evaluation mode (Dropout disabled, BatchNorm uses running statistics)
model.train()              Training mode
model.to(device)           Move the model to GPU/CPU
model.state_dict()         Parameter dictionary (for saving)
model.load_state_dict()    Load a parameter dictionary
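
A minimal save/load sketch using the methods above, reusing MyNet from section 2.1 (the file name mynet.pt is arbitrary):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyNet().to(device)

torch.save(model.state_dict(), "mynet.pt")   # save parameters only, not code

restored = MyNet()
restored.load_state_dict(torch.load("mynet.pt", map_location="cpu"))
restored.eval()                              # switch to inference mode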

7. Using nn.Sequential

nn.Sequential is an ordered container that makes simple architectures shorter to define:

                          model = nn.Sequential(
                              nn.Linear(784, 128),
                              nn.ReLU(),
                              nn.Linear(128, 10)
                          )

This is equivalent to the hand-written nn.Module version. It fits architectures where the forward pass flows straight through the layers in order.
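
Usage is identical to the hand-written version, and Sequential additionally supports indexing into its layers:

x = torch.randn(32, 784)
out = model(x)       # shape: (32, 10)
print(model[0])      # Linear(in_features=784, out_features=128, bias=True)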

8. Complete Worked Example: an MNIST Classifier

                          class MNISTNet(nn.Module):
                              def __init__(self):
                                  super().__init__()
                                  self.net = nn.Sequential(
                                      nn.Flatten(),
                                      nn.Linear(28*28, 256),
                                      nn.ReLU(),
                                      nn.Linear(256, 10)
                                  )
                              def forward(self, x):
                                  return self.net(x)
# Instantiate the model
                          model = MNISTNet()
                          print(model)
# Configure training
                          criterion = nn.CrossEntropyLoss()
                          optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Example training loop; train_loader is assumed to be a DataLoader of (images, labels) batches
                          for epoch in range(10):
                              for images, labels in train_loader:
                                  output = model(images)
                                  loss = criterion(output, labels)
                                  optimizer.zero_grad()
                                  loss.backward()
                                  optimizer.step()
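
A matching evaluation sketch (test_loader is assumed to exist, analogous to train_loader):

model.eval()                          # Dropout off, BatchNorm uses running stats
correct = total = 0
with torch.no_grad():                 # no gradients needed for inference
    for images, labels in test_loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"accuracy: {correct / total:.4f}")
model.train()                         # switch back before further training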

9. Common Pitfalls and Tips

Problem                            Explanation
forward() seems to do nothing      Call model(x); do not invoke model.forward(x) by hand
Forgetting super().__init__()      Submodules will not be registered
Parameters not registered          Layers/modules must be assigned as attributes: self.xxx = ...
Mixing up train and eval modes     Remember to switch between model.train() and model.eval()
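
One common variant of the "parameters not registered" pitfall deserves a sketch: a plain Python list of layers is invisible to nn.Module, whereas nn.ModuleList is tracked.

class Bad(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4)]                  # NOT registered

class Good(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4)])   # registered

print(len(list(Bad().parameters())))    # 0
print(len(list(Good().parameters())))   # 2 (weight and bias)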

10. Summary

Item                     Description
__init__()               Defines the model structure (submodules, layers)
forward()                Defines the forward pass
Automatic registration   Every self.xxx = nn.XXX(...) assignment is tracked
Nested modules           Submodule calls recurse automatically
Convenience methods      .parameters(), .to(), .eval(), ...

11. Comprehensive Examples

Below are concise yet complete implementations of three classic deep-learning architectures built on PyTorch's nn.Module: ResNet18, UNet, and a Transformer encoder. They are meant to get beginners productive quickly.

1. A Concise ResNet18 (for image classification)

                          import torch
                          import torch.nn as nn
                          import torch.nn.functional as F
                          class BasicBlock(nn.Module):
                              expansion = 1
                              def __init__(self, in_planes, planes, stride=1, downsample=None):
                                  super().__init__()
                                  self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
                                  self.bn1   = nn.BatchNorm2d(planes)
                                  self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
                                  self.bn2   = nn.BatchNorm2d(planes)
                                  self.downsample = downsample
                              def forward(self, x):
                                  identity = x
                                  if self.downsample:
                                      identity = self.downsample(x)
                                  out = F.relu(self.bn1(self.conv1(x)))
                                  out = self.bn2(self.conv2(out))
                                  out += identity
                                  return F.relu(out)
                          class ResNet(nn.Module):
                              def __init__(self, block, layers, num_classes=1000):
                                  super().__init__()
                                  self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
                                  self.bn1   = nn.BatchNorm2d(64)
                                  self.pool  = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                                  self.layer1 = self._make_layer(block, 64,  layers[0])
                                  self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
                                  self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
                                  self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
                                  self.fc      = nn.Linear(512 * block.expansion, num_classes)
                              def _make_layer(self, block, planes, blocks, stride=1):
                                  downsample = None
                                  if stride != 1 or self.in_planes != planes * block.expansion:
                                      downsample = nn.Sequential(
            nn.Conv2d(self.in_planes, planes * block.expansion,
                                                    kernel_size=1, stride=stride, bias=False),
                                          nn.BatchNorm2d(planes * block.expansion)
                                      )
                                  layers = [block(self.in_planes, planes, stride, downsample)]
        self.in_planes = planes * block.expansion
                                  for _ in range(1, blocks):
                                      layers.append(block(self.in_planes, planes))
                                  return nn.Sequential(*layers)
                              def forward(self, x):
                                  x = self.pool(F.relu(self.bn1(self.conv1(x))))
                                  x = self.layer1(x)
                                  x = self.layer2(x)
                                  x = self.layer3(x)
                                  x = self.layer4(x)
                                  x = self.avgpool(x).flatten(1)
                                  return self.fc(x)
                          def ResNet18(num_classes=1000):
                              return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
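
A quick shape check (224x224 is the standard ImageNet input size; batch size 2 is arbitrary):

model = ResNet18(num_classes=10)
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)   # torch.Size([2, 10])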

2. UNet (for image segmentation)

                          class UNetBlock(nn.Module):
                              def __init__(self, in_ch, out_ch):
                                  super().__init__()
                                  self.block = nn.Sequential(
                                      nn.Conv2d(in_ch, out_ch, 3, padding=1),
                                      nn.ReLU(inplace=True),
                                      nn.Conv2d(out_ch, out_ch, 3, padding=1),
                                      nn.ReLU(inplace=True)
                                  )
                              def forward(self, x):
                                  return self.block(x)
                          class UNet(nn.Module):
                              def __init__(self, in_channels=1, out_channels=1):
                                  super().__init__()
                                  self.enc1 = UNetBlock(in_channels, 64)
                                  self.enc2 = UNetBlock(64, 128)
                                  self.enc3 = UNetBlock(128, 256)
                                  self.enc4 = UNetBlock(256, 512)
                                  self.pool = nn.MaxPool2d(2)
                                  self.bottleneck = UNetBlock(512, 1024)
                                  self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
                                  self.dec4 = UNetBlock(1024, 512)
                                  self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
                                  self.dec3 = UNetBlock(512, 256)
                                  self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
                                  self.dec2 = UNetBlock(256, 128)
                                  self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
                                  self.dec1 = UNetBlock(128, 64)
                                  self.final = nn.Conv2d(64, out_channels, kernel_size=1)
                              def forward(self, x):
                                  e1 = self.enc1(x)
                                  e2 = self.enc2(self.pool(e1))
                                  e3 = self.enc3(self.pool(e2))
                                  e4 = self.enc4(self.pool(e3))
                                  b  = self.bottleneck(self.pool(e4))
                                  d4 = self.upconv4(b)
                                  d4 = self.dec4(torch.cat([d4, e4], dim=1))
                                  d3 = self.upconv3(d4)
                                  d3 = self.dec3(torch.cat([d3, e3], dim=1))
                                  d2 = self.upconv2(d3)
                                  d2 = self.dec2(torch.cat([d2, e2], dim=1))
                                  d1 = self.upconv1(d2)
                                  d1 = self.dec1(torch.cat([d1, e1], dim=1))
                                  return self.final(d1)
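
A quick shape check; because of the four 2x2 poolings, the input height and width should be divisible by 16:

model = UNet(in_channels=1, out_channels=1)
x = torch.randn(1, 1, 256, 256)
print(model(x).shape)   # torch.Size([1, 1, 256, 256])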

3. A Simplified Transformer Encoder (for sequence modeling)

                          class TransformerBlock(nn.Module):
                              def __init__(self, embed_dim, heads, ff_hidden_dim, dropout=0.1):
                                  super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, heads, dropout=dropout, batch_first=True)
                                  self.ff = nn.Sequential(
                                      nn.Linear(embed_dim, ff_hidden_dim),
                                      nn.ReLU(),
                                      nn.Linear(ff_hidden_dim, embed_dim)
                                  )
                                  self.norm1 = nn.LayerNorm(embed_dim)
                                  self.norm2 = nn.LayerNorm(embed_dim)
                                  self.dropout = nn.Dropout(dropout)
                              def forward(self, x, mask=None):
                                  attn_out, _ = self.attn(x, x, x, attn_mask=mask)
                                  x = self.norm1(x + self.dropout(attn_out))
                                  ff_out = self.ff(x)
                                  x = self.norm2(x + self.dropout(ff_out))
                                  return x
                          class TransformerEncoder(nn.Module):
                              def __init__(self, vocab_size, embed_dim=512, n_heads=8, ff_dim=2048, num_layers=6, max_len=512):
                                  super().__init__()
                                  self.embedding = nn.Embedding(vocab_size, embed_dim)
                                  self.pos_encoding = self._generate_positional_encoding(max_len, embed_dim)
                                  self.layers = nn.ModuleList([
                                      TransformerBlock(embed_dim, n_heads, ff_dim)
                                      for _ in range(num_layers)
                                  ])
                                  self.dropout = nn.Dropout(0.1)
                              def _generate_positional_encoding(self, max_len, d_model):
                                  pos = torch.arange(0, max_len).unsqueeze(1)
                                  i = torch.arange(0, d_model, 2)
                                  angle_rates = 1 / torch.pow(10000, (i / d_model))
                                  pos_enc = torch.zeros(max_len, d_model)
                                  pos_enc[:, 0::2] = torch.sin(pos * angle_rates)
                                  pos_enc[:, 1::2] = torch.cos(pos * angle_rates)
                                  return pos_enc.unsqueeze(0)
                              def forward(self, x):
                                  B, T = x.shape
                                  x = self.embedding(x) + self.pos_encoding[:, :T].to(x.device)
                                  x = self.dropout(x)
                                  for layer in self.layers:
                                      x = layer(x)
                                  return x
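
A quick shape check with deliberately small, illustrative hyperparameters (embed_dim must be divisible by n_heads):

enc = TransformerEncoder(vocab_size=1000, embed_dim=64, n_heads=8, ff_dim=128, num_layers=2)
tokens = torch.randint(0, 1000, (2, 16))   # (batch, seq_len) of token ids
print(enc(tokens).shape)                   # torch.Size([2, 16, 64])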

4. Comparison Summary

Model         Typical use               Characteristics
ResNet18      Image classification      Deep residual structure; well suited to transfer learning
UNet          Image segmentation        Symmetric encoder-decoder with skip connections
Transformer   NLP / sequence modeling   Pure attention: no convolution, no recurrence

This concludes this guide to nn.Module in PyTorch and its comprehensive code examples.
