
Training the U-Net Architecture on Your Own Dataset with PyTorch

Contents
  • 1. Introduction to U-Net
  • 2. Training U-Net on VOC
    • 2.1 U-Net code implementation
    • 2.2 Dataset processing
    • 2.3 Training

For image segmentation there are two main schools of thought: encoder-decoder networks and dilated convolutions. This article covers U-Net, the most classic of the encoder-decoder networks. As backbone networks have evolved, many of the derived networks are essentially improvements on U-Net, but the underlying idea has not changed much; examples include FCDenseNet, which combines DenseNet with U-Net, and UNet++.

1. Introduction to U-Net

Paper: https://arxiv.org/abs/1505.04597v1 (2015)

U-Net was designed for medical image segmentation. Because medical imaging datasets are small, the paper's method noticeably improves results when training on little data, and it also proposes an effective way to handle large images.

U-Net's architecture is inherited from FCN, with a few changes on top of it. It introduced the encoder-decoder idea, which is essentially FCN's scheme of convolving first and then upsampling.

[Figure: the U-Net network architecture]

The figure above shows the U-Net structure. From it we can see:

The left half of the structure is the encoder, i.e. the downsampling path that extracts features. Its basic block is a double convolution: the input passes through two conv 3x3 layers. The paper uses valid convolutions; in our implementation we add padding and use same convolutions instead, so that the skip architecture lines up without cropping. Downsampling is done with pooling layers that simply halve the spatial size.
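
To see the difference concretely, here is a minimal sketch (not from the original post) comparing a valid 3x3 convolution with a same 3x3 convolution:

import torch
import torch.nn as nn

x = torch.randn(1, 3, 256, 256)
# valid convolution (padding=0), as in the original paper: each 3x3 conv shrinks H and W by 2
valid = nn.Conv2d(3, 64, kernel_size=3, padding=0)
print(valid(x).shape)   # torch.Size([1, 64, 254, 254])
# same convolution (padding=1), as used in the code below: H and W are preserved,
# so encoder and decoder feature maps line up for the skip connections without cropping
same = nn.Conv2d(3, 64, kernel_size=3, padding=1)
print(same(x).shape)    # torch.Size([1, 64, 256, 256])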

The right half is the decoder, i.e. the upsampling path that restores the image size and produces the prediction. The decoder also uses the double-convolution block; upsampling is done with transposed convolutions, each of which doubles the spatial size.
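
For instance, a 2x2 transposed convolution with stride 2, as used in the decoder below, doubles the spatial size while halving the channel count (a quick illustrative check):

import torch
import torch.nn as nn

feat = torch.randn(1, 1024, 32, 32)
up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
print(up(feat).shape)  # torch.Size([1, 512, 64, 64]) -- spatial size doubled, channels halved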

The "copy and crop" links in the middle are a cat operation, i.e. concatenating feature maps along the channel dimension.
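
This is why the decoder's double convolutions take twice as many input channels as the corresponding transposed convolution outputs; a small sketch with made-up tensors:

import torch

up6 = torch.randn(1, 512, 64, 64)    # output of the transposed convolution
conv4 = torch.randn(1, 512, 64, 64)  # encoder feature map at the same resolution
merge6 = torch.cat([up6, conv4], dim=1)
print(merge6.shape)  # torch.Size([1, 1024, 64, 64]) -- fed into DoubleConv(1024, 512)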

2. Training U-Net on VOC

2.1 U-Net code implementation

Given the structure described above, the network is very symmetric and simple. The implementation, unet.py, is as follows:

import torch
import torch.nn as nn
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)
class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Encoder
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.output = nn.Conv2d(64, out_ch, 1)
    def forward(self, x):
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)
        conv5 = self.conv5(pool4)
        # Decoder: upsample, concatenate with the matching encoder feature map, then double conv
        up6 = self.up6(conv5)
        merge6 = torch.cat([up6, conv4], dim=1)
        conv6 = self.conv6(merge6)
        up7 = self.up7(conv6)
        merge7 = torch.cat([up7, conv3], dim=1)
        conv7 = self.conv7(merge7)
        up8 = self.up8(conv7)
        merge8 = torch.cat([up8, conv2], dim=1)
        conv8 = self.conv8(merge8)
        up9 = self.up9(conv8)
        merge9 = torch.cat([up9, conv1], dim=1)
        conv9 = self.conv9(merge9)
        out = self.output(conv9)
        return out
if __name__=="__main__":
    model = Unet(3, 21)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(model)
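
A quick sanity check (not part of the original script; it assumes the file above is saved as unet.py) confirms that with same convolutions the output spatial size matches the input, with one channel per class:

import torch
from unet import Unet

model = Unet(3, 21)
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 21, 256, 256])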

2.2 Dataset processing


The data comes from Kaggle (I've forgotten the download link). It contains 2 classes, car and background, with 5k+ images; simply split them into training and validation sets by a fixed ratio. See carnava.py:

from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
import numpy as np
import os
import matplotlib.pyplot as plt
class Car(Dataset):
    def __init__(self, root, train=True):
        self.root = root
        self.crop_size = (256, 256)
        self.img_path = os.path.join(root, "train_hq")
        self.label_path = os.path.join(root, "train_masks")
        img_path_list = [os.path.join(self.img_path, im) for im in os.listdir(self.img_path)]
        train_path_list, val_path_list = self._split_data_set(img_path_list)
        if train:
            self.imgs_list = train_path_list
        else:
            self.imgs_list = val_path_list
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.transforms = T.Compose([
                T.Resize(256),
                T.CenterCrop(256),
                T.ToTensor(),
                normalize
            ])
        self.transforms_val = T.Compose([
            T.Resize(256),
            T.CenterCrop(256)
        ])
        self.color_map = [[0, 0, 0], [255, 255, 255]]  # background (black), car (white)
    def __getitem__(self, index: int):
        im_path = self.imgs_list[index]
        image = Image.open(im_path).convert("RGB")
        data = self.transforms(image)
        (filepath, filename) = os.path.split(im_path)
        filename = filename.split('.')[0]
        label = Image.open(self.label_path +"/"+filename+"_mask.gif").convert("RGB")
        label = self.transforms_val(label)
        # map each RGB mask color to its class index via a lookup table
        cm2lb = np.zeros(256 ** 3)
        for i, cm in enumerate(self.color_map):
            cm2lb[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i
        label = np.array(label, dtype=np.int64)
        idx = (label[:, :, 0] * 256 + label[:, :, 1]) * 256 + label[:, :, 2]
        label = np.array(cm2lb[idx], dtype=np.int64)
        label = torch.from_numpy(label).long()
        return data, label
    def label2img(self, label):
        cmap = self.color_map
        cmap = np.array(cmap).astype(np.uint8)
        pred = cmap[label]
        return pred
    def __len__(self):
        return len(self.imgs_list)
    def _split_data_set(self, img_path_list):
        # every 8th image goes to the validation set, the rest to training
        val_path_list = img_path_list[::8]
        train_path_list = []
        for item in img_path_list:
            if item not in val_path_list:
                train_path_list.append(item)
        return train_path_list, val_path_list
if __name__=="__main__":
    root = "../dataset/carvana"
    car_train = Car(root,train=True)
    train_dataloader = DataLoader(car_train, batch_size=8, shuffle=True)
    print(len(car_train))
    print(len(train_dataloader))
    # for data, label in car_train:
    #     print(data.shape)
    #     print(label.shape)
    #     break
    (data, label) = car_train[190]
    label_np = label.data.numpy()
    label_im = car_train.label2img(label_np)
    plt.figure()
    plt.imshow(label_im)
    plt.show()
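
The cm2lb lookup table in __getitem__ may look cryptic: each RGB color is packed into a single integer (R*256 + G)*256 + B, and that integer indexes a table mapping colors to class ids. A minimal sketch of the same trick with a tiny hand-made mask:

import numpy as np

color_map = [[0, 0, 0], [255, 255, 255]]   # background, car
cm2lb = np.zeros(256 ** 3, dtype=np.int64)
for i, (r, g, b) in enumerate(color_map):
    cm2lb[(r * 256 + g) * 256 + b] = i

mask = np.array([[[0, 0, 0], [255, 255, 255]]])               # a 1x2 RGB mask
idx = (mask[..., 0] * 256 + mask[..., 1]) * 256 + mask[..., 2]
print(cm2lb[idx])                                             # [[0 1]]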

2.3 Training

Segmentation is really just classifying every pixel, so the loss function is still cross-entropy, and accuracy is the number of correctly classified pixels divided by the total number of pixels.
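
Concretely, nn.CrossEntropyLoss for segmentation expects logits of shape (N, C, H, W) and integer labels of shape (N, H, W), and pixel accuracy is just the fraction of matching pixels (a minimal sketch with random tensors):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 2, 256, 256)            # (N, C, H, W) raw network output
labels = torch.randint(0, 2, (4, 256, 256))     # (N, H, W) class index per pixel

loss = criterion(logits, labels)
pred = logits.argmax(dim=1)                     # (N, H, W) predicted class per pixel
pixel_acc = (pred == labels).float().mean()     # correct pixels / total pixels
print(loss.item(), pixel_acc.item())

The full training and evaluation script follows: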

import os
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from carnava import Car
from unet import Unet
# compute the confusion matrix
def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist
def label_accuracy_score(label_trues, label_preds, n_class):
    """Returns accuracy score evaLuation result.
      - overall accuracy
      - mean accuracy
      - mean IU
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = np.diag(hist) / (
            hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
        )
    mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    return acc, acc_cls, mean_iu
out_path = "./out"
if not os.path.exists(out_path):
    os.makedirs(out_path)
log_path = os.path.join(out_path, "result.txt")
if os.path.exists(log_path):
    os.remove(log_path)
model_path = os.path.join(out_path, "best_model.pth")
root = "../dataset/carvana"
epochs = 5
numclasses = 2
train_data = Car(root, train=True)
train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True)
val_data = Car(root, train=False)
val_dataloader = DataLoader(val_data, batch_size=16, shuffle=True)
net = Unet(3, numclasses)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
def train_model():
    best_score = 0.0
    for e in range(epochs):
        net.train()
        train_loss = 0.0
        label_true = torch.LongTensor()
        label_pred = torch.LongTensor()
        for batch_id, (data, label) in enumerate(train_dataloader):
            data, label = data.to(device), label.to(device)
            output = net(data)
            loss = criterion(output, label)
            pred = output.argmax(dim=1).squeeze().data.cpu()
            real = label.data.cpu()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss+=loss.cpu().item()
            label_true = torch.cat((label_true,real),dim=0)
            label_pred = torch.cat((label_pred,pred),dim=0)
        train_loss /= len(train_dataloader)
        acc, acc_cls, mean_iu = label_accuracy_score(label_true.numpy(),label_pred.numpy(),numclasses)
        print("\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}".format(
            e+1, train_loss, acc, acc_cls, mean_iu))
        with open(log_path, 'a') as f:
            f.write('\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
                e+1, train_loss, acc, acc_cls, mean_iu))
        net.eval()
        val_loss = 0.0
        val_label_true = torch.LongTensor()
        val_label_pred = torch.LongTensor()
        with torch.no_grad():
            for batch_id, (data, label) in enumerate(val_dataloader):
                data, label = data.to(device), label.to(device)
                output = net(data)
                loss = criterion(output, label)
                pred = output.argmax(dim=1).squeeze().data.cpu()
                real = label.data.cpu()
                val_loss += loss.cpu().item()
                val_label_true = torch.cat((val_label_true, real), dim=0)
                val_label_pred = torch.cat((val_label_pred, pred), dim=0)
            val_loss/=len(val_dataloader)
            val_acc, val_acc_cls, val_mean_iu = label_accuracy_score(val_label_true.numpy(),
                                                                    val_label_pred.numpy(),numclasses)
        print('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(e+1, val_loss, val_acc, val_acc_cls, val_mean_iu))
        with open(log_path, 'a') as f:
            f.write('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
            e+1,val_loss,val_acc, val_acc_cls, val_mean_iu))
        score = (val_acc_cls+val_mean_iu)/2
        if score > best_score:
            best_score = score
            torch.save(net.state_dict(), model_path)
def evaluate():
    import random
    import matplotlib.pyplot as plt
    net.load_state_dict(torch.load(model_path))
    net.eval()
    # pick a random validation sample and run it through the network
    index = random.randint(0, len(val_data)-1)
    val_image, val_label = val_data[index]
    with torch.no_grad():
        out = net(val_image.unsqueeze(0).to(device))
    pred = out.argmax(dim=1).squeeze().data.cpu().numpy()
    label = val_label.data.numpy()
    img_pred = val_data.label2img(pred)
    img_label = val_data.label2img(label)
    temp = val_image.numpy()
    temp = (temp-np.min(temp)) / (np.max(temp)-np.min(temp))*255
    fig, ax = plt.subplots(1,3)
    ax[0].imshow(temp.transpose(1,2,0).astype("uint8"))
    ax[1].imshow(img_label)
    ax[2].imshow(img_pred)
    plt.show()
if __name__=="__main__":
    # train_model()
    evaluate()

The final training results:

[Figure: training results]

Because the data is fairly simple, mIoU already reaches 0.97 by epoch 5.

Finally, a quick test of the results:

[Figure: a test example]

From left to right: the original image, the ground-truth label, and the predicted label.

Remarks:

I actually started by training on the VOC dataset, but the results were terrible and I never found the problem. After switching datasets the results were good. There are probably two reasons:

1. I made a mistake when processing the VOC data and didn't catch it.

2. This dataset is fairly simple and easy to learn, so decent results come easily.

That concludes this article on training the U-Net architecture on your own dataset with PyTorch. For more on PyTorch and U-Net, search our previous articles, and we hope you'll keep supporting us!
