Pytorch转Caffe最简单方法
由于需要移植模型到比特大陆,华为昇腾这些平台。他们基本都支持caffe的模型,对其他模型支持不太好。用其他方法pytorch转caffe不然就是绕道太多,不然就是很多坑。这里记录一个最简单的方法:import torchimport torchvisionfrom pytorch2caffe import pytorch2caffedef SaveDemo():from torchvision.m
·
由于需要移植模型到比特大陆,华为昇腾这些平台。他们基本都支持caffe的模型,对其他模型支持不太好。用其他方法pytorch转caffe不然就是绕道太多,不然就是很多坑。这里记录一个最简单的方法:
[作者环境: torch 1.2.0 torchvision 0.4.0 ]
pip install pytorch2caffe
import torch
import torchvision
from pytorch2caffe import pytorch2caffe
def SaveDemo():
from torchvision.models import resnet
name = 'resnet18'
resnet18 = resnet.resnet18()
resnet18.eval()
dummy_input = torch.ones([1, 3, 224, 224])
pytorch2caffe.trans_net(resnet18, dummy_input, name)
pytorch2caffe.save_prototxt('{}.prototxt'.format(name))
pytorch2caffe.save_caffemodel('{}.caffemodel'.format(name))
if __name__ == '__main__':
SaveDemo()
如果你的模型中使用了avg_pool 使用这种写法:
x = F.avg_pool2d(x,7)

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链
更多推荐
所有评论(0)