PyTorch中torch.nn模块的实现
目录
- 1. nn.Module
- 2. Layers(层)
- 2.1 nn.Linear(全连接层)
- 2.2 nn.Conv2d(二维卷积层)
- 2.3 nn.MaxPool2d(二维最大池化层)
- 3. Loss Functions(损失函数)
- 3.1 nn.MSELoss(均方误差损失)
- 3.2 nn.CrossEntropyLoss(交叉熵损失)
- 4. Optimizers(优化器)
- 4.1 torch.optim.SGD(随机梯度下降)
- 4.2 torch.optim.Adam(自适应矩估计)
- 5. Activation Functions(激活函数)
- 5.1 nn.ReLU(修正线性单元)
- 6. Normalization Layers(归一化层)
- 6.1 nn.BATchNorm2d(二维批量归一化)
- 7. Dropout Layers(丢弃层)
- 7.1 nn.Dropout
- 8. Container Modules(容器模块)
- 8.1 nn.Sequential(顺序容器)
- 8.2 nn.ModuleList(模块列表)
- 9. Functional API (torch.nn.functional)
- 9.1 F.relu(ReLU 激活函数)
- 9.2 F.cross_entropy(交叉熵损失函数)
- 9.3 F.conv2d(二维卷积)
- 10. Parameter (torch.nn.Parameter)
- 示例代码:
- 综合示例
torch.nn
是 PyTorch 中专门用于构建和训练神经网络的模块。它的整体架构分为几个主要部分,每部分的原理、要点和使用场景如下:
1. nn.Module
原理和要点:nn.Module
是所有神经网络组件的基类。任何神经网络模型都应该继承 nn.Module
,并实现其 forward
方法。
使用场景:用于定义和管理神经网络模型,包括层、损失函数和自定义的前向传播逻辑。
主要 API 和使用场景:
__init__
: 初始化模型参数。forward
: 定义前向传播逻辑。parameters
: 返回模型的所有参数。
import torch import torch.nn as nn class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.linear = nn.Linear(10, 1) def forward(self, x): return self.linear(x) model = MyModel() print(model)
2. Layers(层)
- 原理和要点:层是神经网络的基本构建块,包括全连接层、卷积层、池化层等。每种层执行特定类型的操作,并包含可学习的参数。
- 使用场景:用于构建神经网络的各个组成部分,如特征提取、降维等。
2.1 nn.Linear(全连接层)
linear = nn.Linear(10, 5) input = torch.randn(1, 10) output = linear(input) print(output)
2.2 nn.Conv2d(二维卷积层)
conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3) input = torch.randn(1, 1, 5, 5) output = conv(input) print(output)
2.3 nn.MaxPool2d(二维最大池化层)
maxpool = nn.MaxPool2d(kernel_size=2) input = torch.randn(1, 1, 4, 4) output = maxpool(input) print(output)
3. Loss Functions(损失函数)
- 原理和要点:损失函数用于衡量模型预测与真实值之间的差异,指导模型优化过程。
- 使用场景:用于计算训练过程中需要最小化的误差。
3.1 nn.MSELoss(均方误差损失)
mse_loss = nn.MSELoss() input = torch.randn(3, 5) target = torch.randn(3, 5) loss = mse_loss(input, target) print(loss)
3.2 nn.CrossEntropyLoss(交叉熵损失)
cross_entropy_loss = nn.CrossEntropyLoss() input = torch.randn(3, 5) target = torch.tensor([1, 0, 4]) loss = cross_entropy_loss(input, target) print(loss)
4. Optimizers(优化器)
- 原理和要点:优化器用于调整模型参数,以最小化损失函数。
- 使用场景:用于训练模型,通过反向传播更新参数。
4.1 torch.optim.SGD(随机梯度下降)
import torch.optixGcaccXEDm as optim model = MyModel() optimizer = optim.SGD(model.parameters(), lr=0.01) criterion = nn.MSELoss() # Trandroidaining loop for epoch in range(100): optimizer.zero_grad() output = model(torch.randn(1, 10)) loss = criterion(output, torch.randn(1, 1)) loss.backward() optimizer.step()
4.2 torch.optim.Adam(自适应矩估计)
optimizer = optim.Adam(model.parameters(), lr=0.001) # Training loop for epoch in range(100): optimizer.zero_grad() output = model(torch.randn(1, 10)) loss = criterion(output, torch.randn(1, 1)) loss.backward() optimizer.step()
5. Activation Functions(激活函数)
- 原理和要点:激活函数引入非线性,使模型能够拟合复杂的函数。
- 使用场景:用于激活输入,增加模型表达能力。
5.1 nn.ReLU(修正线性单元)
relu = nn.ReLU() input = torch.randn(2) output = relu(input) print(output)
6. Normalization Layers(归一化层)
- 原理和要点:归一化层用于标准化输入,改善训练的稳定性和速度。
- 使用场景:用于标准化激活值,防止梯度爆炸或消失。
6.1 nn.BatchNorm2d(二维批量归一化)
batch_norm = nn.BatchNorm2d(3) input = torch.randn(1, 3, 5, 5) output = batch_norm(input) print(output)
7. Dropout Layers(丢弃层)
- 原理和要点:Dropout 层通过在训练过程中随机丢弃一部分神经元来防止过拟合。
- 使用场景:用于防止模型过拟合,增加模型的泛化能力。
7.1 nn.Dropout
dropout = nn.Dropout(p=0.5) input = torch.randn(2, 3) output = dropout(input) print(output)
8. Container Modules(容器模块)
- 原理和要点:容器模块用于组合多个层,构建复杂的神经网络结构。
- 使用场景:用于组合多个层,形成更复杂的网络结构。
8.1 nn.Sequential(顺序android容器)
model = nn.jsSequential( nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5) ) input = torch.randn(1, 10) output = model(input) print(output)
8.2 nn.ModuleList(模块列表)
layers = nn.ModuleList([ nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5) ]) input = torch.randn(1, 10) for layer in layers: input = layer(input) print(input)
9. Functional API (torch.nn.functional)
- 原理和要点:包含大量用于深度学习的无状态函数,这些函数通常是操作层的底层实现。
- 使用场景:用于在前向传播中灵活调用函数。
9.1 F.relu(ReLU 激活函数)
import torch.nn.functional as F input = torch.randn(2) output = F.relu(input) print(output)
9.2 F.cross_entropy(交叉熵损失函数)
input = torch.randn(3, 5) target = torch.tensor([1, 0, 4]) loss = F.cross_entropy(input, target) print(loss)
9.3 F.conv2d(二维卷积)
input = torch.randn(1, 1, 5, 5) weight = torch.randn(3, 1, 3, 3) # Manually defined weights output = F.conv2d(input, weight) print(output)
10. Parameter (torch.nn.Parameter)
- 原理和要点:
torch.nn.Parameter
&编程客栈nbsp;是torch.Tensor
的一种特殊子类,用于表示模型的可学习参数。它们在nn.Module
中会自动注册为参数。 - 使用场景:用于定义模型中的可学习参数。
示例代码:
class MyModelWithParam(nn.Module): def __init__(self): super(MyModelWithParam, self).__init__() self.my_param = nn.Parameter(torch.randn(10, 10)) def forward(self, x): return x @ self.my_param model = MyModelWithParam() input = torch.randn(1, 10) output = model(input) print(output) # 查看模型参数 for name, param in model.named_parameters(): print(name, param.size())
综合示例
下面是一个结合上述各个部分的综合示例:
import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim class MyComplexModel(nn.Module): def __init__(self): super(MyComplexModel, self).__init__() self.conv1 = nn.Conv2d(1, 32, kernel_size=3) self.bn1 = nn.BatchNorm2d(32) self.conv2 = nn.Conv2d(32, 64, kernel_size=3) self.bn2 = nn.BatchNorm2d(64) self.dropout = nn.Dropout(0.25) self.fc1 = nn.Linear(64*12*12, 128) self.fc2 = nn.Linear(128, 10) self.custom_param = nn.Parameter(torch.randn(128, 128)) def forward(self, x): x = F.relu(self .bn1(self.conv1(x))) x = F.max_pool2d(x, 2) x = F.relu(self.bn2(self.conv2(x))) x = F.max_pool2d(x, 2) x = self.dropout(x) x = x.view(x.size(0), -1) x = F.relu(self.fc1(x)) x = x @ self.custom_param x = self.fc2(x) return F.log_softmax(x, dim=1) model = MyComplexModel() criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) for epoch in range(10): optimizer.zero_grad() input = torch.randn(64, 1, 28, 28) target = torch.randint(0, 10, (64,)) output = model(input) loss = criterion(output, target) loss.backward() optimizer.step() print(f'Epoch {epoch+1}, Loss: {loss.item()}')
通过以上示例,可以更清晰地理解 torch.nn
模块的整体架构、原理、要点及其具体使用场景。
到此这篇关于PyTorch中torch.nn模块的实现的文章就介绍到这了,更多相关PyTorch torch.nn模块内容请搜索编程客栈(www.devze.com)以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程客栈(www.devze.com)!
精彩评论