由于需要移植模型到比特大陆,华为昇腾这些平台。他们基本都支持caffe的模型,对其他模型支持不太好。用其他方法pytorch转caffe不然就是绕道太多,不然就是很多坑。这里记录一个最简单的方法:

[作者环境: torch 1.2.0 torchvision 0.4.0 ]

pip install pytorch2caffe

import torch
import torchvision
from pytorch2caffe import pytorch2caffe
def SaveDemo():
    from torchvision.models import resnet

    name = 'resnet18'
    resnet18 = resnet.resnet18()
    resnet18.eval()
    dummy_input = torch.ones([1, 3, 224, 224])
    pytorch2caffe.trans_net(resnet18, dummy_input, name)
    pytorch2caffe.save_prototxt('{}.prototxt'.format(name))
    pytorch2caffe.save_caffemodel('{}.caffemodel'.format(name))


if __name__ == '__main__':
    SaveDemo()

如果你的模型中使用了avg_pool 使用这种写法:

x = F.avg_pool2d(x,7)

 

 

 

Logo

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链

更多推荐