Ascend: from single-node single-card training to single-node multi-card training

Distributed training

(1) Training flow on a single machine with a single card

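  • Read data from disk
  • The CPU preprocesses the data and assembles it into a batch
  • Transfer the batch to the GPU
  • Forward pass through the network to compute the loss
  • Backward pass through the network to compute the gradients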

(2) Data Parallel (DP), the earliest data-parallel framework in PyTorch: single process, multiple threads

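  • Read data from disk; a single CPU process splits each batch into several shards
  • Each GPU receives one shard
  • Each GPU independently runs the forward and backward pass and computes its own gradients
  • All other GPUs send their gradients to GPU0, which averages them
  • GPU0 updates its own network parameters with the globally averaged gradients
  • GPU0 broadcasts the updated parameters to the other GPUs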
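The DP steps listed above can be traced with a toy simulation. This is a minimal sketch in plain Python, not PyTorch's actual DataParallel implementation: the "model" is a single weight w fitted to y = 3x, the GPUs are simulated by list entries, and the names grad_on_shard and dp_step are made up for illustration.

```python
def grad_on_shard(w, shard):
    """Mean gradient of the squared error (w*x - y)^2 over one shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def dp_step(w, batch, num_gpus, lr=0.01):
    """One Data Parallel step, simulated on the CPU."""
    # One process splits the batch into equal shards, one per GPU.
    shard_size = len(batch) // num_gpus
    shards = [batch[i * shard_size:(i + 1) * shard_size] for i in range(num_gpus)]
    # Every GPU computes the gradient on its own shard.
    grads = [grad_on_shard(w, shard) for shard in shards]
    # GPU0 gathers all gradients and averages them.
    avg_grad = sum(grads) / num_gpus
    # GPU0 updates the parameter.
    w = w - lr * avg_grad
    # GPU0 broadcasts the new parameter, so every replica holds the same w.
    replicas = [w] * num_gpus
    return w, replicas


batch = [(float(x), 3.0 * x) for x in range(1, 9)]  # 8 samples of y = 3x
w = 0.0
for _ in range(200):
    w, replicas = dp_step(w, batch, num_gpus=4)
print(round(w, 3))
```

Note that steps 3 to 5 all run on GPU0 alone, which is the bottleneck the following analysis quantifies.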

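The key problem in distributed training is how to reduce the communication volume between cards so that training becomes more efficient.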

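  • Communication volume of the DP mode:
    Let the number of parameters be $\psi$ and the number of nodes (GPUs) be $N$.
    GPU0 receives $(N-1)\psi$ of gradients and sends out $(N-1)\psi$ of parameters.
    Every other GPU sends out $\psi$ of gradients and receives $\psi$ of parameters.
  • Problems with the DP mode:
    It is single-process and multi-threaded, so because of the Python GIL it can use only one CPU core;
    GPU0 alone collects the gradients, updates the parameters, and broadcasts them, so it carries a heavy communication and computation load.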
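To make the imbalance concrete, the formulas above can be turned into bytes. The sketch below is only arithmetic; the model size (100 million parameters) and 4 bytes per value (fp32) are assumptions chosen for illustration, and dp_traffic_bytes is a hypothetical helper name.

```python
def dp_traffic_bytes(num_params, num_gpus, bytes_per_param=4):
    """Per-step traffic in bytes under DP, assuming fp32 gradients and parameters."""
    psi = num_params * bytes_per_param
    # GPU0 receives N-1 gradient copies and sends N-1 parameter copies.
    gpu0 = {"in": (num_gpus - 1) * psi, "out": (num_gpus - 1) * psi}
    # Every other GPU receives one parameter copy and sends one gradient copy.
    other = {"in": psi, "out": psi}
    return gpu0, other


for n in (2, 4, 8):
    gpu0, other = dp_traffic_bytes(num_params=100_000_000, num_gpus=n)
    print(f"N={n}: GPU0 moves {(gpu0['in'] + gpu0['out']) / 1e9:.1f} GB, "
          f"each other GPU moves {(other['in'] + other['out']) / 1e9:.1f} GB")
```

The traffic through GPU0 grows linearly with the number of cards, while every other card stays constant.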

(3) Distributed Data Parallel (DDP), the framework that replaces DP in PyTorch: multi-process, distributed data parallelism

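In a ring-connected Ring-AllReduce, every GPU carries the same load. Each GPU sends and receives at the same time, which makes the best use of every card's uplink and downlink bandwidth.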

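  • DDP is multi-process: each process prepares data for its own GPU and communicates with the other GPUs, and each process runs the forward and backward pass on its own data to compute its own gradients.
  • Because each process sees different data, the gradients computed for each batch also differ, so Ring-AllReduce is used to synchronize the gradients across GPUs; after synchronization every GPU holds the same gradients.
  • Each process then updates its own copy of the network with its own optimizer. The optimizers are identical and so are the gradients, so the network state and optimizer state stay in sync and every replica reaches the same result.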
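The gradient synchronization described above can be simulated in a few lines. This is a sketch of the textbook Ring-AllReduce (a scatter-reduce phase followed by an all-gather phase) on plain Python lists, assuming the vector length is divisible by the ring size; it is not how HCCL or NCCL are implemented internally, and ring_allreduce is a name chosen for illustration.

```python
def ring_allreduce(vectors):
    """Sum equal-length vectors with a simulated ring all-reduce.

    vectors[r] is the gradient held by rank r. Returns the per-rank results
    and the number of elements each rank sent.
    """
    n = len(vectors)
    length = len(vectors[0])
    assert length % n == 0, "sketch assumes the length is divisible by the ring size"
    size = length // n
    # Each rank views its vector as n chunks.
    data = [[list(v[c * size:(c + 1) * size]) for c in range(n)] for v in vectors]
    sent = [0] * n

    # Phase 1, scatter-reduce: after n-1 steps rank r owns the full sum of chunk (r+1) % n.
    for step in range(n - 1):
        # Snapshot the outgoing chunks first so all ranks send simultaneously.
        outgoing = [(r, (r - step) % n, list(data[r][(r - step) % n])) for r in range(n)]
        for r, c, chunk in outgoing:
            dst = (r + 1) % n
            data[dst][c] = [a + b for a, b in zip(data[dst][c], chunk)]
            sent[r] += len(chunk)

    # Phase 2, all-gather: pass the reduced chunks around the ring.
    for step in range(n - 1):
        outgoing = [(r, (r + 1 - step) % n, list(data[r][(r + 1 - step) % n])) for r in range(n)]
        for r, c, chunk in outgoing:
            dst = (r + 1) % n
            data[dst][c] = chunk
            sent[r] += len(chunk)

    results = [[x for chunk in rank_chunks for x in chunk] for rank_chunks in data]
    return results, sent


grads = [[float(r + 1)] * 8 for r in range(4)]  # rank r holds a vector filled with r+1
results, sent = ring_allreduce(grads)
print(results[0], sent)
```

Every rank ends up with the same summed vector and sends the same number of elements; dividing the result by the number of ranks gives the averaged gradient.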

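Communication volume of the DDP mode:
Let the number of parameters be $\psi$ and the number of processes be $N$.
For each GPU process:
Scatter-Reduce phase, in/out: $(N-1)\frac{\psi}{N}$
AllGather phase, in/out: $(N-1)\frac{\psi}{N}$
Total in/out: $2(N-1)\frac{\psi}{N} \approx 2\psi$, essentially independent of the cluster size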
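A quick numeric check of these formulas shows why the ring scales better than DP. The sketch below only evaluates the expressions from this article; the function names and the value of psi are illustrative.

```python
def ring_traffic(psi, n):
    """Elements each process sends (and also receives) in one Ring-AllReduce."""
    scatter_reduce = (n - 1) * psi / n
    all_gather = (n - 1) * psi / n
    return scatter_reduce + all_gather


def dp_gpu0_traffic(psi, n):
    """Elements GPU0 receives (and also sends) per step under DP."""
    return (n - 1) * psi


psi = 1_000_000
for n in (2, 4, 8, 64):
    # Traffic expressed as a multiple of the parameter count.
    print(n, ring_traffic(psi, n) / psi, dp_gpu0_traffic(psi, n) / psi)
```

The ring column stays below 2 however many processes are added, while the DP column for GPU0 grows as N-1.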

Ascend DDP reference code: shell version

(1) Huawei's official single-node multi-card training script demo: ddp_test_shell.py

# Import dependencies and libraries
import torch 
from torch import nn 
import torch_npu 
import torch.distributed as dist 
from torch.utils.data import DataLoader 
from torchvision import datasets 
from torchvision.transforms import ToTensor 
import time 
import torch.multiprocessing as mp 
import os 
 
torch.manual_seed(0) 
# Download the training data
training_data = datasets.FashionMNIST( 
    root="./data", 
    train=True, 
    download=True, 
    transform=ToTensor(), 
) 
 
# Download the test data
test_data = datasets.FashionMNIST( 
    root="./data", 
    train=False, 
    download=True, 
    transform=ToTensor(), 
) 
 
# Build the model
class NeuralNetwork(nn.Module): 
    def __init__(self): 
        super().__init__() 
        self.flatten = nn.Flatten() 
        self.linear_relu_stack = nn.Sequential( 
            nn.Linear(28*28, 512), 
            nn.ReLU(), 
            nn.Linear(512, 512), 
            nn.ReLU(), 
            nn.Linear(512, 10) 
        ) 
 
    def forward(self, x): 
        x = self.flatten(x) 
        logits = self.linear_relu_stack(x) 
        return logits 
 
def test(dataloader, model, loss_fn): 
    size = len(dataloader.dataset) 
    num_batches = len(dataloader) 
    model.eval() 
    # Evaluate on whatever device the model already lives on
    device = next(model.parameters()).device
    test_loss, correct = 0, 0 
    with torch.no_grad(): 
        for X, y in dataloader: 
            X, y = X.to(device), y.to(device) 
            pred = model(X) 
            test_loss += loss_fn(pred, y).item() 
            correct += (pred.argmax(1) == y).type(torch.float).sum().item() 
    test_loss /= num_batches 
    correct /= size 
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    
    
def main(world_size: int,  batch_size = 64, total_epochs = 5,):    # user-configurable defaults
    # Relies on the module-level `args` parsed in the __main__ block
    ngpus_per_node = world_size
    main_worker(args.gpu, ngpus_per_node, args)
    
def ddp_setup(rank, world_size):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"    # set according to your environment
    os.environ["MASTER_PORT"] = "29500"    # set according to your environment
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size)
    
def main_worker(gpu, ngpus_per_node, args):

    start_epoch = 0
    end_epoch = 5
    args.gpu = int(os.environ['LOCAL_RANK'])    # run.sh loops over the devices and exports LOCAL_RANK as the device to use
    ddp_setup(args.gpu, args.world_size)

    torch_npu.npu.set_device(args.gpu)
    total_batch_size = args.batch_size
    total_workers = ngpus_per_node

    batch_size = int(total_batch_size / ngpus_per_node)    
    workers = int((total_workers + ngpus_per_node - 1) / ngpus_per_node)

    model = NeuralNetwork()

    device = torch.device("npu")

    train_sampler = torch.utils.data.distributed.DistributedSampler(training_data)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_data)

    train_loader = torch.utils.data.DataLoader(
        training_data, batch_size=batch_size, shuffle=(train_sampler is None),
        num_workers=workers, pin_memory=False, sampler=train_sampler, drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        test_data, batch_size=batch_size, shuffle=(test_sampler is None),
        num_workers=workers, pin_memory=False, sampler=test_sampler, drop_last=True)

    loc = 'npu:{}'.format(args.gpu)
    model = model.to(loc)
    criterion = nn.CrossEntropyLoss().to(loc)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    for epoch in range(start_epoch, end_epoch):
        train_sampler.set_epoch(epoch)
        train(train_loader, model, criterion, optimizer, epoch, args.gpu)


def train(train_loader, model, criterion, optimizer, epoch, gpu):
    size = len(train_loader.dataset)
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time

        loc = 'npu:{}'.format(gpu)
        target = target.to(torch.int32)        
        images, target = images.to(loc, non_blocking=False), target.to(loc, non_blocking=False)

        # compute output
        output = model(images)
        loss = criterion(output, target)


        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        end = time.time()
        if i % 100 == 0:
            loss, current = loss.item(), i * len(target)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            
            
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('--batch_size', default=512, type=int,
                        help='Total batch size across all devices (default: 512)')
    parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
    args = parser.parse_args()
    # One process per visible NPU: the launch script must start exactly this many processes
    world_size = torch.npu.device_count()
    args.world_size = world_size

    # main_worker reads LOCAL_RANK (exported by the shell script), sets up the
    # process group, builds the data loaders and the DDP model, and runs training
    main_worker(args.gpu, world_size, args)

(2) Huawei's official single-node multi-card launch script demo: run.sh
This example uses one machine with two cards:

#!/bin/bash
export HCCL_WHITELIST_DISABLE=1
RANK_ID_START=0
WORLD_SIZE=2
for((RANK_ID=$RANK_ID_START;RANK_ID<$((WORLD_SIZE+RANK_ID_START));RANK_ID++));
do
    echo "Device ID: $RANK_ID"
    export LOCAL_RANK=$RANK_ID
    python3 /home/work/user-job-dir/app/notebook/RTDosePrediction-main/RTDosePrediction/Src/DCNN/train-2p.py &
done
wait
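For readers who prefer to launch from Python, the loop in run.sh can be mirrored with the standard library. This is a sketch under assumptions: launch_workers is a hypothetical helper, and the command used here is a stand-in one-liner that only prints the LOCAL_RANK it receives, not the training script.

```python
import os
import subprocess
import sys


def launch_workers(command, world_size, rank_start=0):
    """Start one process per device, each with its own LOCAL_RANK, and wait for all."""
    procs = []
    for rank in range(rank_start, rank_start + world_size):
        env = dict(os.environ)
        env["HCCL_WHITELIST_DISABLE"] = "1"  # same setting as run.sh
        env["LOCAL_RANK"] = str(rank)
        procs.append(subprocess.Popen(command, env=env, stdout=subprocess.PIPE, text=True))
    # Equivalent of the shell `wait`: collect output and exit codes.
    outputs = [p.communicate()[0].strip() for p in procs]
    codes = [p.returncode for p in procs]
    return outputs, codes


# Stand-in worker: prints the LOCAL_RANK it was given.
demo_cmd = [sys.executable, "-c", "import os; print(os.environ['LOCAL_RANK'])"]
outputs, codes = launch_workers(demo_cmd, world_size=2)
print(outputs, codes)
```

To launch real training, the stand-in command would be replaced by the interpreter followed by the absolute path of the training script.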

Detailed explanation

Reference code: https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu.py
Reference video: https://www.youtube.com/watch?v=-LAtx9Q6DA8&list=PL_lsbAsL_o2CSuhUhJIiW0IkdT5C2wGWj&index=3

01 Import the distributed-training packages

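(1) torch.multiprocessing
An extension of Python's standard multiprocessing library for sharing and transferring PyTorch objects, especially tensors and models, between processes. Sharing data is a challenge in a multi-process setting; torch.multiprocessing lets processes share PyTorch tensors through shared memory instead of copying the data between processes.
(2) torch.utils.data.distributed
train_sampler = torch.utils.data.distributed.DistributedSampler(training_data)
Samples the dataset in a distributed training environment. The sampler ensures that each process handles only its own subset of the dataset, so that multiple processes and GPUs can be used effectively to speed up training.

(3) torch.nn.parallel
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
DDP creates one replica of the model in each process and runs the forward and backward pass independently on each replica, which gives parallel computation. In every training step DDP automatically synchronizes the gradients computed by the processes, so the model parameters stay identical across processes. DDP relies on an efficient communication backend (NCCL on NVIDIA GPUs, HCCL on Ascend NPUs) to synchronize gradients between processes, which matters most for device-to-device communication.
(4) torch.distributed
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size)
Initializes the distributed training environment. The call sets up a process group, that is, a set of processes that can communicate with one another for distributed training.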

02 Initialize the distributed process group

def ddp_setup(rank, world_size):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"    # set according to your environment
    os.environ["MASTER_PORT"] = "29500"    # set according to your environment
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size)

MASTER_ADDR: the machine that runs the rank 0 process
MASTER_PORT: any free port on that machine
(master: the machine that coordinates communication across all processes)

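In dist.init_process_group(backend="hccl", rank=rank, world_size=world_size):

  • backend selects the backend used for inter-process communication. Huawei Ascend NPUs use hccl, which provides efficient distributed communication on NPUs; NVIDIA GPUs use nccl (NVIDIA Collective Communications Library).
  • rank is the unique identifier of the current process, usually an integer starting from 0. It determines which device each process should use and how other processes are located during collective communication.
  • world_size is the total number of processes in the process group, that is, the number of processes taking part in distributed training. If you have 4 NPUs and want to train on all 4 of them, world_size should be 4.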
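One way to obtain rank and world_size in each process is to read them from environment variables set by the launcher. The sketch below assumes the launcher exports LOCAL_RANK, RANK, and WORLD_SIZE (these are the names used by torchrun; the run.sh in this article only exports LOCAL_RANK), and read_dist_env is a hypothetical helper.

```python
import os


def read_dist_env(environ=None):
    """Return (rank, local_rank, world_size) from environment variables.

    Falls back to a single-process setup when a variable is missing.
    """
    env = os.environ if environ is None else environ
    local_rank = int(env.get("LOCAL_RANK", 0))
    # On a single machine the global rank equals the local rank.
    rank = int(env.get("RANK", local_rank))
    world_size = int(env.get("WORLD_SIZE", 1))
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside [0, {world_size})")
    return rank, local_rank, world_size


print(read_dist_env({"LOCAL_RANK": "1", "WORLD_SIZE": "2"}))
```

The range check catches a common launch mistake early: if world_size is larger than the number of processes actually started, init_process_group waits for ranks that never arrive.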

03 Wrap the model with DDP before training

model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

04 Remember to unwrap the model when saving

Because the model was wrapped in a DDP object above, do not save it with ckp = model.state_dict()
Instead use: ckp = model.module.state_dict()
Example code:

    def _save_checkpoint(self, epoch):
        ckp = self.model.module.state_dict()
        PATH = "checkpoint.pt"
        torch.save(ckp, PATH)
        print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")
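The effect of unwrapping can be shown without any accelerator. In the sketch below, PlainModel and FakeDDP are stand-ins written for this example (they are not PyTorch classes) that mimic how a wrapper prefixes state-dict keys, and unwrap_model is a hypothetical helper that works whether or not the model is wrapped.

```python
def unwrap_model(model):
    """Return the underlying model whether or not it is wrapped by DP/DDP."""
    return getattr(model, "module", model)


class PlainModel:            # stand-in for an nn.Module
    def state_dict(self):
        return {"linear.weight": 1.0}


class FakeDDP:               # stand-in for DistributedDataParallel
    def __init__(self, module):
        self.module = module

    def state_dict(self):
        # The wrapper prefixes every key with "module."
        return {f"module.{k}": v for k, v in self.module.state_dict().items()}


wrapped = FakeDDP(PlainModel())
print(list(wrapped.state_dict()))                 # keys carry the module. prefix
print(list(unwrap_model(wrapped).state_dict()))   # clean keys
```

One caveat of the getattr approach: a model that itself defines a submodule named module would be unwrapped by mistake, so an explicit isinstance check against the wrapper class is safer in real code.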

05 Make each process handle only a subset of the dataset during distributed training: DistributedSampler

train_sampler = torch.utils.data.distributed.DistributedSampler(training_data)
test_sampler = torch.utils.data.distributed.DistributedSampler(test_data)

train_loader = torch.utils.data.DataLoader(
    training_data, batch_size=batch_size, shuffle=(train_sampler is None),
    num_workers=workers, pin_memory=False, sampler=train_sampler, drop_last=True)

val_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=(test_sampler is None),
    num_workers=workers, pin_memory=False, sampler=test_sampler, drop_last=True)

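Pass the sampler to the DataLoader through the sampler argument. Note that in distributed training, once sampler is passed, shuffle must be set to False. Shuffling is then handled by the sampler, which is why train_sampler.set_epoch(epoch) is called at the start of each epoch.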
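How a sampler of this kind divides the dataset can be sketched in plain Python. This is a simplified imitation of the idea behind DistributedSampler, not its source code: it uses Python's random module instead of torch's generator, so the actual permutations differ, and shard_indices is a hypothetical name.

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank, epoch=0, seed=0, shuffle=True):
    """Indices one rank would iterate over in one epoch (simplified)."""
    indices = list(range(dataset_len))
    if shuffle:
        # Same seed on every rank, so all ranks agree on the permutation;
        # adding the epoch changes the order from epoch to epoch.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating samples from the front so the data splits evenly.
    per_rank = math.ceil(dataset_len / num_replicas)
    total = per_rank * num_replicas
    indices += indices[: total - dataset_len]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total:num_replicas]


shards = [shard_indices(10, num_replicas=4, rank=r, epoch=0) for r in range(4)]
print(shards)
```

Every rank gets the same number of samples and together the ranks cover the whole dataset. Without a call like set_epoch, the seed never changes and each epoch would reuse the same order.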

06 Modify the main function

Example code: pass two extra arguments, rank and world_size

def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
    ddp_setup(rank, world_size)
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, rank, save_every)
    trainer.train(total_epochs)
    destroy_process_group()

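Notes

(1) The Python script called in the shell script must be given as an absolute path; otherwise the run fails with: No such file or directory
Note: the path returned by "Copy Path" is incomplete, so get the absolute path like this: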

import os
# Get the absolute path of the current file
current_path = os.path.abspath(__file__)
print(current_path)

(2) When submitting a training job on the modelmate platform and starting training, a Bus error may be reported:


notebook/RTDosePrediction-main/RTDosePrediction/Src/DCNN/run_8p.sh: line 11:   738 Bus error               (core dumped) python3 /home/work/user-job-dir/app/notebook/RTDosePrediction-main/RTDosePrediction/Src/DCNN/train-8p.py

The cause is unknown; retrying a few times may make it work again.
(3) Loading the trained model for testing fails.
The error is:

Traceback (most recent call last):
  File "/home/work/user-job-dir/app/notebook/RTDosePrediction-main/RTDosePrediction/Src/DCNN/test.py", line 115, in <module>
    trainer.init_trainer(ckpt_file=args.model_path,
  File "/home/work/user-job-dir/app/notebook/RTDosePrediction-main/RTDosePrediction/Src/NetworkTrainer/network_trainer.py", line 359, in init_trainer
    self.setting.network.load_state_dict(ckpt['network_state_dict'])
  File "/home/naie/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Model:
        Missing key(s) in state_dict: "encoder.encoder_1.0.single_conv.0.weight", "encoder.encoder_1.0.single_conv.1.weight", "encoder.encoder_1.0.single_conv.1.bias", "encoder.encoder_1.0.single_conv.1.running_mean", "encoder.encoder_1.0.single_conv.1.running_var", "encoder.encoder_1.1.single_conv.0.weight", "encoder.encoder_1.1.single_conv.1.weight", "encoder.encoder_1.1.single_conv.1.bias", "encoder.encoder_1.1.single_conv.1.running_mean", "encoder.encoder_1.1.single_conv.1.running_var", "encoder.encoder_2.1.single_conv.0.weight", "encoder.encoder_2.1.single_conv.1.weight", "encoder.encoder_2.1.single_conv.1.bias", "encoder.encoder_2.1.single_conv.1.running_mean", "encoder.encoder_2.1.single_conv.1.running_var", "encoder.encoder_2.2.single_conv.0.weight", "encoder.encoder_2.2.single_conv.1.weight", "encoder.encoder_2.2.single_conv.1.bias", "encoder.encoder_2.2.single_conv.1.running_mean", "encoder.encoder_2.2.single_conv.1.running_var", "encoder.encoder_3.1.single_conv.0.weight", "encoder.encoder_3.1.single_conv.1.weight", "encoder.encoder_3.1.single_conv.1.bias", "encoder.encoder_3.1.single_conv.1.running_mean", "encoder.encoder_3.1.single_conv.1.running_var", "encoder.encoder_3.2.single_conv.0.weight", "encoder.encoder_3.2.single_conv.1.weight", "encoder.encoder_3.2.single_conv.1.bias", "encoder.encoder_3.2.single_conv.1.running_mean", "encoder.encoder_3.2.single_conv.1.running_var", "encoder.encoder_4.1.single_conv.0.weight", "encoder.encoder_4.1.single_conv.1.weight", "encoder.encoder_4.1.single_conv.1.bias", "encoder.encoder_4.1.single_conv.1.running_mean", "encoder.encoder_4.1.single_conv.1.running_var", "encoder.encoder_4.2.single_conv.0.weight", "encoder.encoder_4.2.single_conv.1.weight", "encoder.encoder_4.2.single_conv.1.bias", "encoder.encoder_4.2.single_conv.1.running_mean", "encoder.encoder_4.2.single_conv.1.running_var", "decoder.upconv_3_1.weight", "decoder.upconv_3_1.bias", "decoder.decoder_conv_3_1.0.single_conv.0.weight", 
"decoder.decoder_conv_3_1.0.single_conv.1.weight", "decoder.decoder_conv_3_1.0.single_conv.1.bias", "decoder.decoder_conv_3_1.0.single_conv.1.running_mean", "decoder.decoder_conv_3_1.0.single_conv.1.running_var", "decoder.decoder_conv_3_1.1.single_conv.0.weight", "decoder.decoder_conv_3_1.1.single_conv.1.weight", "decoder.decoder_conv_3_1.1.single_conv.1.bias", "decoder.decoder_conv_3_1.1.single_conv.1.running_mean", "decoder.decoder_conv_3_1.1.single_conv.1.running_var", "decoder.upconv_2_1.weight", "decoder.upconv_2_1.bias", "decoder.decoder_conv_2_1.0.single_conv.0.weight", "decoder.decoder_conv_2_1.0.single_conv.1.weight", "decoder.decoder_conv_2_1.0.single_conv.1.bias", "decoder.decoder_conv_2_1.0.single_conv.1.running_mean", "decoder.decoder_conv_2_1.0.single_conv.1.running_var", "decoder.decoder_conv_2_1.1.single_conv.0.weight", "decoder.decoder_conv_2_1.1.single_conv.1.weight", "decoder.decoder_conv_2_1.1.single_conv.1.bias", "decoder.decoder_conv_2_1.1.single_conv.1.running_mean", "decoder.decoder_conv_2_1.1.single_conv.1.running_var", "decoder.upconv_1_1.weight", "decoder.upconv_1_1.bias", "decoder.decoder_conv_1_1.0.single_conv.0.weight", "decoder.decoder_conv_1_1.0.single_conv.1.weight", "decoder.decoder_conv_1_1.0.single_conv.1.bias", "decoder.decoder_conv_1_1.0.single_conv.1.running_mean", "decoder.decoder_conv_1_1.0.single_conv.1.running_var", "decoder.decoder_conv_1_1.1.single_conv.0.weight", "decoder.decoder_conv_1_1.1.single_conv.1.weight", "decoder.decoder_conv_1_1.1.single_conv.1.bias", "decoder.decoder_conv_1_1.1.single_conv.1.running_mean", "decoder.decoder_conv_1_1.1.single_conv.1.running_var", "decoder.conv_out.0.weight", "decoder.conv_out.0.bias". 
        Unexpected key(s) in state_dict: "module.encoder.encoder_1.0.single_conv.0.weight", "module.encoder.encoder_1.0.single_conv.1.weight", "module.encoder.encoder_1.0.single_conv.1.bias", "module.encoder.encoder_1.0.single_conv.1.running_mean", "module.encoder.encoder_1.0.single_conv.1.running_var", "module.encoder.encoder_1.0.single_conv.1.num_batches_tracked", "module.encoder.encoder_1.1.single_conv.0.weight", "module.encoder.encoder_1.1.single_conv.1.weight", "module.encoder.encoder_1.1.single_conv.1.bias", "module.encoder.encoder_1.1.single_conv.1.running_mean", "module.encoder.encoder_1.1.single_conv.1.running_var", "module.encoder.encoder_1.1.single_conv.1.num_batches_tracked", "module.encoder.encoder_2.1.single_conv.0.weight", "module.encoder.encoder_2.1.single_conv.1.weight", "module.encoder.encoder_2.1.single_conv.1.bias", "module.encoder.encoder_2.1.single_conv.1.running_mean", "module.encoder.encoder_2.1.single_conv.1.running_var", "module.encoder.encoder_2.1.single_conv.1.num_batches_tracked", "module.encoder.encoder_2.2.single_conv.0.weight", "module.encoder.encoder_2.2.single_conv.1.weight", "module.encoder.encoder_2.2.single_conv.1.bias", "module.encoder.encoder_2.2.single_conv.1.running_mean", "module.encoder.encoder_2.2.single_conv.1.running_var", "module.encoder.encoder_2.2.single_conv.1.num_batches_tracked", "module.encoder.encoder_3.1.single_conv.0.weight", "module.encoder.encoder_3.1.single_conv.1.weight", "module.encoder.encoder_3.1.single_conv.1.bias", "module.encoder.encoder_3.1.single_conv.1.running_mean", "module.encoder.encoder_3.1.single_conv.1.running_var", "module.encoder.encoder_3.1.single_conv.1.num_batches_tracked", "module.encoder.encoder_3.2.single_conv.0.weight", "module.encoder.encoder_3.2.single_conv.1.weight", "module.encoder.encoder_3.2.single_conv.1.bias", "module.encoder.encoder_3.2.single_conv.1.running_mean", "module.encoder.encoder_3.2.single_conv.1.running_var", 
"module.encoder.encoder_3.2.single_conv.1.num_batches_tracked", "module.encoder.encoder_4.1.single_conv.0.weight", "module.encoder.encoder_4.1.single_conv.1.weight", "module.encoder.encoder_4.1.single_conv.1.bias", "module.encoder.encoder_4.1.single_conv.1.running_mean", "module.encoder.encoder_4.1.single_conv.1.running_var", "module.encoder.encoder_4.1.single_conv.1.num_batches_tracked", "module.encoder.encoder_4.2.single_conv.0.weight", "module.encoder.encoder_4.2.single_conv.1.weight", "module.encoder.encoder_4.2.single_conv.1.bias", "module.encoder.encoder_4.2.single_conv.1.running_mean", "module.encoder.encoder_4.2.single_conv.1.running_var", "module.encoder.encoder_4.2.single_conv.1.num_batches_tracked", "module.encoder.DFA.conv1.0.weight", "module.encoder.DFA.conv1.0.bias", "module.encoder.DFA.conv1.0.running_mean", "module.encoder.DFA.conv1.0.running_var", "module.encoder.DFA.conv1.0.num_batches_tracked", "module.encoder.DFA.conv1.2.weight", "module.encoder.DFA.conv1.2.bias", "module.encoder.DFA.conv2.0.weight", "module.encoder.DFA.conv2.0.bias", "module.encoder.DFA.conv2.0.running_mean", "module.encoder.DFA.conv2.0.running_var", "module.encoder.DFA.conv2.0.num_batches_tracked", "module.encoder.DFA.conv2.2.weight", "module.encoder.DFA.conv2.2.bias", "module.encoder.DFA.conv3.0.weight", "module.encoder.DFA.conv3.0.bias", "module.encoder.DFA.conv3.0.running_mean", "module.encoder.DFA.conv3.0.running_var", "module.encoder.DFA.conv3.0.num_batches_tracked", "module.encoder.DFA.conv3.2.weight", "module.encoder.DFA.conv3.2.bias", "module.encoder.DFA.conv4.0.weight", "module.encoder.DFA.conv4.0.bias", "module.encoder.DFA.conv4.0.running_mean", "module.encoder.DFA.conv4.0.running_var", "module.encoder.DFA.conv4.0.num_batches_tracked", "module.encoder.DFA.conv4.2.weight", "module.encoder.DFA.conv4.2.bias", "module.encoder.DFA.conv5.0.weight", "module.encoder.DFA.conv5.0.bias", "module.encoder.DFA.conv5.0.running_mean", "module.encoder.DFA.conv5.0.running_var", 
"module.encoder.DFA.conv5.0.num_batches_tracked", "module.encoder.DFA.conv5.2.weight", "module.encoder.DFA.conv5.2.bias", "module.encoder.DFA.conv_out.0.weight", "module.encoder.DFA.conv_out.0.bias", "module.encoder.DFA.conv_out.0.running_mean", "module.encoder.DFA.conv_out.0.running_var", "module.encoder.DFA.conv_out.0.num_batches_tracked", "module.encoder.DFA.conv_out.2.weight", "module.encoder.DFA.conv_out.2.bias", "module.decoder.upconv_3_1.weight", "module.decoder.upconv_3_1.bias", "module.decoder.decoder_conv_3_1.0.single_conv.0.weight", "module.decoder.decoder_conv_3_1.0.single_conv.1.weight", "module.decoder.decoder_conv_3_1.0.single_conv.1.bias", "module.decoder.decoder_conv_3_1.0.single_conv.1.running_mean", "module.decoder.decoder_conv_3_1.0.single_conv.1.running_var", "module.decoder.decoder_conv_3_1.0.single_conv.1.num_batches_tracked", "module.decoder.decoder_conv_3_1.1.single_conv.0.weight", "module.decoder.decoder_conv_3_1.1.single_conv.1.weight", "module.decoder.decoder_conv_3_1.1.single_conv.1.bias", "module.decoder.decoder_conv_3_1.1.single_conv.1.running_mean", "module.decoder.decoder_conv_3_1.1.single_conv.1.running_var", "module.decoder.decoder_conv_3_1.1.single_conv.1.num_batches_tracked", "module.decoder.upconv_2_1.weight", "module.decoder.upconv_2_1.bias", "module.decoder.decoder_conv_2_1.0.single_conv.0.weight", "module.decoder.decoder_conv_2_1.0.single_conv.1.weight", "module.decoder.decoder_conv_2_1.0.single_conv.1.bias", "module.decoder.decoder_conv_2_1.0.single_conv.1.running_mean", "module.decoder.decoder_conv_2_1.0.single_conv.1.running_var", "module.decoder.decoder_conv_2_1.0.single_conv.1.num_batches_tracked", "module.decoder.decoder_conv_2_1.1.single_conv.0.weight", "module.decoder.decoder_conv_2_1.1.single_conv.1.weight", "module.decoder.decoder_conv_2_1.1.single_conv.1.bias", "module.decoder.decoder_conv_2_1.1.single_conv.1.running_mean", "module.decoder.decoder_conv_2_1.1.single_conv.1.running_var", 
"module.decoder.decoder_conv_2_1.1.single_conv.1.num_batches_tracked", "module.decoder.upconv_1_1.weight", "module.decoder.upconv_1_1.bias", "module.decoder.decoder_conv_1_1.0.single_conv.0.weight", "module.decoder.decoder_conv_1_1.0.single_conv.1.weight", "module.decoder.decoder_conv_1_1.0.single_conv.1.bias", "module.decoder.decoder_conv_1_1.0.single_conv.1.running_mean", "module.decoder.decoder_conv_1_1.0.single_conv.1.running_var", "module.decoder.decoder_conv_1_1.0.single_conv.1.num_batches_tracked", "module.decoder.decoder_conv_1_1.1.single_conv.0.weight", "module.decoder.decoder_conv_1_1.1.single_conv.1.weight", "module.decoder.decoder_conv_1_1.1.single_conv.1.bias", "module.decoder.decoder_conv_1_1.1.single_conv.1.running_mean", "module.decoder.decoder_conv_1_1.1.single_conv.1.running_var", "module.decoder.decoder_conv_1_1.1.single_conv.1.num_batches_tracked", "module.decoder.conv_out.0.weight", "module.decoder.conv_out.0.bias". 

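Cause: if the model was trained on multiple cards (for example with DataParallel or DistributedDataParallel) and the wrapped model's state dict was saved, loading it in a single-card environment can fail like this, because multi-card training adds the module. prefix to the keys in the state dict.
Fix: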

from collections import OrderedDict

    def init_trainer(self, ckpt_file, list_GPU_ids, only_network=True):
        ckpt = torch.load(ckpt_file, map_location='cpu')
        # Strip the module. prefix
        if 'network_state_dict' in ckpt:
            new_state_dict = OrderedDict()
            for k, v in ckpt['network_state_dict'].items():
                name = k[7:] if k.startswith('module.') else k
                new_state_dict[name] = v
            ckpt['network_state_dict'] = new_state_dict

        # self.setting.network.load_state_dict(ckpt['network_state_dict'], strict=False)

        self.setting.network.load_state_dict(ckpt['network_state_dict'])

        if not only_network:
            self.setting.lr_scheduler.load_state_dict(ckpt['lr_scheduler_state_dict'])
            self.setting.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            self.log = ckpt['log']

        self.set_GPU_device(list_GPU_ids)

        # If do not do so, the states of optimizer will always in cpu
        # This for Adam
        if type(self.setting.optimizer).__name__ == 'Adam':
            for key in self.setting.optimizer.state.items():
                key[1]['exp_avg'] = key[1]['exp_avg'].to(self.setting.device)
                key[1]['exp_avg_sq'] = key[1]['exp_avg_sq'].to(self.setting.device)
                key[1]['max_exp_avg_sq'] = key[1]['max_exp_avg_sq'].to(self.setting.device)

        self.print_log_to_file('==> Init trainer from ' + ckpt_file + ' successfully! \n', 'a')
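The prefix-stripping step in the fix above can be isolated into a small reusable function. The sketch below uses a plain OrderedDict with integer values as a stand-in for a real state dict, and strip_module_prefix is a name chosen for illustration.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from every key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


saved = OrderedDict([("module.encoder.weight", 1), ("module.decoder.bias", 2)])
print(list(strip_module_prefix(saved)))
```

Recent PyTorch releases also provide torch.nn.modules.utils.consume_prefix_in_state_dict_if_present for the same purpose. Stripping the prefix only repairs the naming: in the traceback above, the checkpoint also contains module.encoder.DFA.* keys with no counterpart among the missing keys, so if the test-time model does not define those layers, a strict load would still report them as unexpected.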

Appendix: code from the YouTube video

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datautils import MyTrainDataset

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os


def ddp_setup(rank, world_size):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.cuda.set_device(rank)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)

class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        save_every: int,
    ) -> None:
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.model = DDP(model, device_ids=[gpu_id])

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets)

    def _save_checkpoint(self, epoch):
        ckp = self.model.module.state_dict()
        PATH = "checkpoint.pt"
        torch.save(ckp, PATH)
        print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self, max_epochs: int):
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_checkpoint(epoch)


def load_train_objs():
    train_set = MyTrainDataset(2048)  # load your dataset
    model = torch.nn.Linear(20, 1)  # load your model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(dataset)
    )


def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
    ddp_setup(rank, world_size)
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, rank, save_every)
    trainer.train(total_epochs)
    destroy_process_group()


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()

    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)