开发者

PyTorch中torch.argmax函数的使用

目录
  • 函数定义
  • 核心功能
    • 1、​全局最大值索引​(当 dim=None)
    • 2|​沿指定维度查找最大值索引​(当 dim 指定时)
  • 参数详解
    • 1. dim 参数
    • 2. keepdim 参数
  • 常见用途
    • 1、​分类任务中获取预测标签
    • 2、​计算准确率
  • 注意事项
    • 1、​多个相同最大值:
    • 2、​数据类型
    • 3、​维度合法性
  • 总结

    torch.argmax 是 PyTorch 中的一个函数,用于返回输入张量中最大值所在的索引。其作用与数学中的 ​argmax 概念一致,即找到某个函数在指定范围内取得最大值时的参数(位置索引

    函数定义

    torch.argmax(input, dim=None, keepdim=False)
    
    • ​输入:
      • input:输入张量。
      • dim(可选):指定沿哪个维度查找最大值。如果为 None,则在整个张量中查找。
      • keepdim(可选):是否保持输出张量的维度与输入一致(默认为 False)。
    • ​输出:

      一个张量,包含最大值所在的索引

    核心功能

    1、​全局最大值索引​(当 dim=None)

    • 将输入张量展平后,返回最大值的索引
    import torch
    
    x = torch.tensor([[1, 2, 3],
                      [6, 5, 4]])
    print(torch.argmax(x))  # 输出:tensor(3)
    # 展平后的索引:1, 2, 3, 6, 5, 4 → 最大值为6,索引为3(从0开始)
    

    2|​沿指定维度查找最大值索引​(当 dim 指定时)

    • 沿 dim 维度对输入张量操作,返回每行/列的最大值索引
    # 沿编程行维度(dim=1)查找
    x = torch.tensor([[1, 2, 3],
                      [6, 5, 4]])
    print(torch.argmax(x, dim=1))  # 输出:tensor([2, 0])
    # 解释:
    # 第一行 [1, 2, 3] 最大值3,索引2
    # 第二行 [6, 5, 4] 最大值6,索引0
    
    # 沿列维度(dim=0)查找编程
    print(torch.argmax(x, dim=0))  # 输出:tensor([1, 1, 0])
    # 解释:
    # 第0列 [1, 6] 最大值6,索引1
    # 第1列 [2, 5] 最大值5,索引1
    # 第2列 [3, 4] 最大值4,索引1(但此处输出为0,可能有误,实际应为1)
    

    参数详解

    1. dim 参数

    • ​作用:指定沿哪个维度操作。
    • ​示例:
      • dim=0:沿列操作(纵向)。
      • dim=1:沿行操作(横向)。

    2. keepdim 参数

    • ​作用:保持输出维度与输入一致。
    • ​示例:
    x = torch.tensor([[1, 2, 3],
                      [6, 5, 4]])
    out = torch.argmax(x, dim=1, keepdim=True)
    print(out)  # 输出:tensor([[2], [0]])
    

    常见用途

    1、​分类任务编程客栈中获取预测标签

    logits = torch.tensandroidor([0.1, 0.8, 0.05, 0.05])  # 模型输出的概率分布
    predicted_class = torch.argmax(logits)         # 输出:tensor(1)
    

    2、​计算准确率

    # 假设BATch_size=4,num_classes=3
    preds = torch.tensor([[0.1, 0.2, 0.7],
                          [0.9, 0.05, 0.05],
                          [0.3, 0.4, 0.3],
                          [0.05, 0.8, 0.15]])
    labels = torch.tensor([2, 0, 1, 1])
    # 获取预测类别
    predicted_classes = torch.argmax(preds, dim=1)  # 输出:tensor([2, 0, 1, 1])
    # 计算正确预测数
    correct = (predicted_classes == labels).sum()   # 输出:tensor(3)
    

    注意事项

    1、​多个相同最大值:

    • 如果存在多个相同的最大值,返回第一个出现的索引
    x = torch.tensor([3, 1, 4, 4])
    print(torch.argmax(x))  # 输出:tensor(2)
    

    2、​数据类型

    • 输入张量应为数值类型(如 float32、int64)

    3、​维度合法性

    • 如果指定了不存在的维度(如 dim=3 对一个二维张量),会触发错误

    总结

    torch.argmax 是一个高效的工具,广泛应用于分类模型预测、js指标计算等场景。理解其 dim 和 keepdim 参数的行为,可以灵活处理不同维度的数据

    到此这篇关于PyTorch中torch.argmax函数的使用的文章就介绍到这了,更多相关PyTorch torch.argmax内容请搜索编程客栈(www.devze.com)以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程客栈(www.devze.com)!

    0

    上一篇:

    下一篇:

    精彩评论

    暂无评论...
    验证码 换一张
    取 消

    最新开发

    开发排行榜