博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch中实现深度学习网络的训练和推断——以yolov3为例
阅读量:2810 次
发布时间:2019-05-13

本文共 4850 字,大约阅读时间需要 16 分钟。

文章目录

一、简介

Pytorch是目前非常流行的大规模矩阵计算框架,上手简易,文档详尽,最新发表的深度学习领域的论文中有多半是以pytorch框架来实现的,足以看出其易用性和流行度。

这篇文章将以yolov3为例,介绍pytorch中如何实现一个网络的训练和推断。

二、Pytorch构建深度学习网络

这一部分主要讲解一下,在pytorch中构建一个深度学习网络,需要包含哪些部分,各部分都起了什么作用。不同的框架的实现方式会有许多不同,但基本都包含这些部分。在以下的讲解中我隐去了一些具体的实现细节,如果想详细了解,可以前往这个github了解,我的讲解代码也是以它为基础改编的,两个版本配合着看能更好地了解和上手。

1.datasets

数据集在网络的训练过程是必须的。通常在训练脚本中,会看到类似下面的这样一行代码。

dataloader = torch.utils.data.DataLoader(Dataset(train_path))

其中的Dataset是自定义的一个类,train_path是训练数据集的路径。Dataset通常定义在命名为datasets的文件内,当然也有以VOCDataset、COCODataset来命名的,其作用都是相同的,定义一个数据集类,以便pytorch调用。下面给出一个Dataset的类定义模板,该模板为yolov3的框架所使用。

class Dataset(Dataset):    def __init__(self, img_dir, label_dir):        self.img_files = glob.glob(os.path.join(img_dir, '*.*'))        self.label_files = glob.glob(os.path.join(label_dir, '*.*'))    def __getitem__(self, index):        # === 图片 ===        # 读取图片        img_path = self.img_files[index % len(self.img_files)].rstrip()        img = np.array(Image.open(img_path))        img = torch.from_numpy(img.transpose(2, 0, 1)).float().div(255)   # 将numpy.array的格式转为torch.Tensor格式,并转换通道        # 图像预处理(可选)        # 做一些诸如pad、resize之类的操作                # === 标签 ===        # 获取标签文件路径        label_path = self.label_files[index % len(self.img_files)].rstrip()        # 解析标签文件(可选)        # 读取label_path的文件然后解析,也可直接返回label_path                return img, label_path    def __len__(self):        return len(self.img_files)

基本上所有的Dataset类都会包含init、getitem、len这三个函数,在getitem函数中,一般会包含图像预处理和标签预处理,也有些是把这两部分放在外部处理,getitem只获取图像和标签文件路径,值得注意的是,有不少的框架对getitem进行了重载,所以你可能没有找到getitem函数,但是有其他函数能代替getitem的作用。

2.models

在DL框架中models是一个最为重要的部分,它实现了整个网络的整体结构和具体细节,在一些通用型的大型项目框架内,通常会把这部分拆分成多个modules进行实现,而在一些小项目里,models也可能仅仅用一个文件来实现。这里我还是以yolov3的models来举例介绍。

# === 读取cfg配置文件 ===def create_modules(cfg):	# 根据配置文件进行解析    return module_list# === yolo层定义 ===class YOLOLayer(nn.Module):    def __init__(self, cfg):        super(YOLOLayer, self).__init__()    def forward(self, x, targets=None):        if targets is not None:            # === 训练阶段 ===        	# 计算loss,根据输入的x的结果与targets进行计算,最后得到loss            return x, loss        else:            # === 推断阶段 ===        	# 根据输入的x计算出预测结果            return x# === darknet网络结构定义 ===class Darknet(nn.Module):    def __init__(self, cfg):        super(Darknet, self).__init__()        self.module_list = create_modules(cfg)    def forward(self, x, targets=None):    	losses= []        for module in self.module_list:			if module is not 'YOLO':                x = module(x)            else:                # === 训练阶段 ===                if is_training:                    x, loss = module(x, targets)                    losses.append(loss)                # === 推断阶段 ===                else:                    x = module(x)        return x, losses

在网络结构和yolo层定义中,init和forward这两个函数是必须的,事实上这两个函数也是torch内置已经定义过了的,这里这样写实际是重载了这两个函数。有些项目中可能会把训练和推断的forward函数拆分成两个函数,函数名字也改变了,实际运用时要注意。

3.train

训练脚本可以说是网络中最为关键的部分,它直接影响了模型的性能和鲁棒性。基本上不同网络的训练脚本均有不同之处,但是均可以达到一定的效果。一个训练脚本一般包含dataloader、optimizer、model三个部分,运用这三个部分构成train迭代循环过程。

# 构建model,模型结构model = Darknet(model_config_path)model.apply(weights_init_normal)model.train()if cuda:	model = model.cuda()# 设置dataloader ,数据集加载器# batch_size根据显存大小调整,shuffle是指是否打乱数据集的读取顺序,num_workers是指用多少个线程读取数据集dataloader = torch.utils.data.DataLoader(    Dataset(img_dir, label_dir), batch_size=16, shuffle=True, num_workers=4)# 设置optimizer,优化器# 优化器的种类有非常多,建议新手使用Adam,因为这是一个自适应调整学习率的优化器,不需要设置很多参数# 如果需要精调模型,或者对这方面比较熟练,可以使用SGD+Momentum优化器optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))# 主循环train过程total_epoch = 10for epoch in range(total_epoch):    for batch_i, (imgs, targets) in enumerate(dataloader):        # 注意,输入的图像必须进行通道转换,我这里忽略了这个步骤,因为我之前已经在Dataset部分实现了        # 这里的imgs的shape应该为(B, C, H,  W),B为batch_size,C为通道,H为高,W为宽        imgs.requires_grad = True   # imgs的requires_grad属性必须为True,而targets的requires_grad属性为False(默认为False)        if cuda:        	imgs = imgs.cuda()        	targets = targets.cuda()        optimizer.zero_grad()        _, loss = model(imgs, targets)        loss.backward()        optimizer.step()        print('epoch:', epoch, 'batch:', batch_i, 'loss:', loss.detach().cpu().numpy())       if epoch % 1 == 0:        torch.save(model.state_dict(), 'backup.pth')

以上就是训练脚本中所包含的基本部分,关于loss的计算,有些project把它放在了forward函数里面,也是没问题的,只要注意进行计算imgs的requires_grad必须为True就可以了。

4.inference

推断脚本相对于训练来说比较简单,基本上大同小异,只要模型结构没错基本上输出结果都是相同的。

# 加载模型model = Darknet(model_config_path)params = torch.load('backup.pth')model.load_state_dict(params)model.eval()# 读取图片和推断img = cv2.imread(path)img = torch.from_numpy(img.transpose(2, 0, 1)).float().div(255).unsqueeze(0)with torch.no_grad():	out, _ = model(img)# 处理out,例如进行nms和结果显示,该部分省略

三、总结

以上就是在pytorch中构建模型和训练、推断的主要过程,这个博客主要目的是帮大家理解这个过程,所以对于一些具体实现细节我没有给出,想详细了解的可以去这个github上进行了解,后续我有时间也会公布一个我个人的yolov3的pytorch版本。如有疑问也可以在下面评论,我有空会回复,谢谢。

转载地址:http://mjqqd.baihongyu.com/

你可能感兴趣的文章
spring cloud 实战(干货)
查看>>
docker简介
查看>>
docker镜像和仓库
查看>>
3_小米监控Open-Falcon 后端服务安装并启动
查看>>
4_小米监控Open-Falcon 前端安装
查看>>
5_小米监控Open-Falcon 安装-Agent
查看>>
8_小米监控Open-Falcon安装查询组件-API
查看>>
2_RabbitMQ-3.7.2安装手册
查看>>
3_rabbitmq后台管理界面
查看>>
4_rabbitmq java操作简单队列
查看>>
5_rabbitmq work queues 工作队列
查看>>
7_rabbitmq订阅模式 PublishSubscribe
查看>>
8_rabbitmq路由模式
查看>>
11_RabbitMQ之消息确认机制
查看>>
4_ElaticSearch 使用terms搜索多个值
查看>>
5_ElaticSearch 基于range filter来进行范围过滤
查看>>
6_ElatisSearch 控制全文检索结果的精准度
查看>>
7_ElaticSearch term+bool实现的multiword搜索原理
查看>>
10_ElasticSearch dis_max实现best fields策略进行多字段搜索
查看>>
13_ElasticSearch multi_match+most fiels策略进行multi-field搜索
查看>>