昇思MindSpore 应用学习-ShuffleNet图像分类-CSDN

昇思MindSpore 应用学习-ShuffleNet图像分类(AI 代码解析)

ShuffleNet网络介绍

ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型,和MobileNet, SqueezeNet等一样主要应用在移动端,所以模型的设计目标就是利用有限的计算资源来达到最好的模型精度。ShuffleNetV1的设计核心是引入了两种操作:Pointwise Group Convolution和Channel Shuffle,这在保持精度的同时大大降低了模型的计算量。因此,ShuffleNetV1和MobileNet类似,都是通过设计更高效的网络结构来实现模型的压缩和加速。
了解ShuffleNet更多详细内容,详见论文ShuffleNet
如下图所示,ShuffleNet在保持不低的准确率的前提下,将参数量几乎降低到了最小,因此其运算速度较快,单位参数量对模型准确率的贡献非常高。

图片来源:Bianco S, Cadene R, Celona L, et al. Benchmark analysis of representative deep neural network architectures[J]. IEEE access, 2018, 6: 64270-64277.

模型架构

ShuffleNet最显著的特点在于对不同通道进行重排来解决Group Convolution带来的弊端。通过对ResNet的Bottleneck单元进行改进,在较小的计算量的情况下达到了较高的准确率。

Pointwise Group Convolution

Group Convolution(分组卷积)原理如下图所示,相比于普通的卷积操作,分组卷积的情况下,每一组的卷积核大小为in_channels/gkk,一共有g组,所有组共有(in_channels/gkk)out_channels个参数,是正常卷积参数的1/g。分组卷积中,每个卷积核只处理输入特征图的一部分通道,其优点在于参数量会有所降低,但输出通道数仍等于卷积核的数量

图片来源:Huang G, Liu S, Van der Maaten L, et al. Condensenet: An efficient densenet using learned group convolutions[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 2752-2761.
Depthwise Convolution(深度可分离卷积)将组数g分为和输入通道相等的in_channels,然后对每一个in_channels做卷积操作,每个卷积核只处理一个通道,记卷积核大小为1
kk,则卷积核参数量为:in_channelskk,得到的feature maps通道数与输入通道数相等
Pointwise Group Convolution(逐点分组卷积)在分组卷积的基础上,令每一组的卷积核大小为 1×1,卷积核参数量为(in_channels/g
1*1)*out_channels。

from mindspore import nn
import mindspore.ops as ops
from mindspore import Tensor

class GroupConv(nn.Cell):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
        # 初始化父类
        super(GroupConv, self).__init__()
        # 设置组数
        self.groups = groups
        # 创建卷积层列表
        self.convs = nn.CellList()
        # 循环根据组数添加卷积层
        for _ in range(groups):
            self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,
                                        kernel_size=kernel_size, stride=stride, has_bias=has_bias,
                                        padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))

    def construct(self, x):
        # 将输入张量在指定轴上分割,并根据组数划分
        features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)
        outputs = ()
        # 遍历每一组,进行卷积操作
        for i in range(self.groups):
            outputs = outputs + (self.convs[i](features[i].astype("float32")),)
        # 将所有输出在指定轴上拼接
        out = ops.cat(outputs, axis=1)
        return out

解析:

  1. 导入必要的模块和库
from mindspore import nn
import mindspore.ops as ops
from mindspore import Tensor
  • mindspore.nn 用于构建神经网络层。
  • mindspore.ops 包含各种操作函数。
  • mindspore.Tensor 用于创建张量。
  1. 定义 GroupConv
class GroupConv(nn.Cell):
  • GroupConv 继承自 nn.Cell,是一个自定义的神经网络层。
  1. **初始化方法 **__init__
def __init__(self, in_channels, out_channels, kernel_size,
             stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
    super(GroupConv, self).__init__()
    self.groups = groups
    self.convs = nn.CellList()
    for _ in range(groups):
        self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,
                                    kernel_size=kernel_size, stride=stride, has_bias=has_bias,
                                    padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))
  • 参数解析:
    • in_channels:输入通道数。
    • out_channels:输出通道数。
    • kernel_size:卷积核大小。
    • stride:步幅。
    • pad_mode:填充模式。
    • pad:填充值。
    • groups:组数。
    • has_bias:是否有偏置。
  • 初始化卷积层列表 self.convs,根据组数创建多个卷积层。
  1. **构建方法 **construct
def construct(self, x):
    features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)
    outputs = ()
    for i in range(self.groups):
        outputs = outputs + (self.convs[i](features[i].astype("float32")),)
    out = ops.cat(outputs, axis=1)
    return out
  • ops.split:将输入 x 在指定轴上分割,根据组数划分。
  • outputs:存储每组卷积后的输出。
  • 遍历每一组,进行卷积操作,并将结果拼接到 outputs 中。
  • 最后使用 ops.cat 在指定轴上拼接所有卷积输出,返回最终结果。

Channel Shuffle

Group Convolution的弊端在于不同组别的通道无法进行信息交流,堆积GConv层后一个问题是不同组之间的特征图是不通信的,这就好像分成了g个互不相干的道路,每一个人各走各的,这可能会降低网络的特征提取能力。这也是Xception,MobileNet等网络采用密集的1x1卷积(Dense Pointwise Convolution)的原因。
为了解决不同组别通道“近亲繁殖”的问题,ShuffleNet优化了大量密集的1x1卷积(在使用的情况下计算量占用率达到了惊人的93.4%),引入Channel Shuffle机制(通道重排)。这项操作直观上表现为将不同分组通道均匀分散重组,使网络在下一层能处理不同组别通道的信息。

如下图所示,对于g组,每组有n个通道的特征图,首先reshape成g行n列的矩阵,再将矩阵转置成n行g列,最后进行flatten操作,得到新的排列。这些操作都是可微分可导的且计算简单,在解决了信息交互的同时符合了ShuffleNet轻量级网络设计的轻量特征。

为了阅读方便,将Channel Shuffle的代码实现放在下方ShuffleNet模块的代码中。

ShuffleNet模块

如下图所示,ShuffleNet对ResNet中的Bottleneck结构进行由(a)到(b), ©的更改:

  1. 将开始和最后的1×1卷积模块(降维、升维)改成Point Wise Group Convolution;
  2. 为了进行不同通道的信息交流,再降维之后进行Channel Shuffle;
  3. 降采样模块中,3×3 Depth Wise Convolution的步长设置为2,长宽降为原来的一般,因此shortcut中采用步长为2的3×3平均池化,并把相加改成拼接。

class ShuffleV1Block(nn.Cell):
    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
        super(ShuffleV1Block, self).__init__()
        self.stride = stride
        pad = ksize // 2
        self.group = group
        if stride == 2:
            outputs = oup - inp
        else:
            outputs = oup
        self.relu = nn.ReLU()
        branch_main_1 = [
            GroupConv(in_channels=inp, out_channels=mid_channels,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=1 if first_group else group),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        ]
        branch_main_2 = [
            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels,
                      weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(mid_channels),
            GroupConv(in_channels=mid_channels, out_channels=outputs,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=group),
            nn.BatchNorm2d(outputs),
        ]
        self.branch_main_1 = nn.SequentialCell(branch_main_1)
        self.branch_main_2 = nn.SequentialCell(branch_main_2)
        if stride == 2:
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, old_x):
        left = old_x
        right = old_x
        out = old_x
        right = self.branch_main_1(right)
        if self.group > 1:
            right = self.channel_shuffle(right)
        right = self.branch_main_2(right)
        if self.stride == 1:
            out = self.relu(left + right)
        elif self.stride == 2:
            left = self.branch_proj(left)
            out = ops.cat((left, right), 1)
            out = self.relu(out)
        return out

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = ops.shape(x)
        group_channels = num_channels // self.group
        x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        x = ops.reshape(x, (batchsize, num_channels, height, width))
        return x

解析:

  1. **初始化方法 **__init__
def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
    super(ShuffleV1Block, self).__init__()
    self.stride = stride
    pad = ksize // 2
    self.group = group
    if stride == 2:
        outputs = oup - inp
    else:
        outputs = oup
    self.relu = nn.ReLU()
  • 参数解析:
    • inp:输入通道数。
    • oup:输出通道数。
    • group:分组数。
    • first_group:是否是第一个分组。
    • mid_channels:中间通道数。
    • ksize:卷积核大小。
    • stride:步幅。
  • 根据步幅设置输出通道数。
  • 初始化 ReLU 激活函数。
  1. 构建分支
branch_main_1 = [
    GroupConv(in_channels=inp, out_channels=mid_channels,
              kernel_size=1, stride=1, pad_mode="pad", pad=0,
              groups=1 if first_group else group),
    nn.BatchNorm2d(mid_channels),
    nn.ReLU(),
]
branch_main_2 = [
    nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
              pad_mode='pad', padding=pad, group=mid_channels,
              weight_init='xavier_uniform', has_bias=False),
    nn.BatchNorm2d(mid_channels),
    GroupConv(in_channels=mid_channels, out_channels=outputs,
              kernel_size=1, stride=1, pad_mode="pad", pad=0,
              groups=group),
    nn.BatchNorm2d(outputs),
]
self.branch_main_1 = nn.SequentialCell(branch_main_1)
self.branch_main_2 = nn.SequentialCell(branch_main_2)
  • 定义两个分支 branch_main_1branch_main_2,分别包含卷积、批归一化和 ReLU 激活函数。
  • branch_main_1 使用 GroupConv 进行分组卷积。
  • branch_main_2 使用 nn.Conv2d 进行深度可分离卷积,然后再次使用 GroupConv 进行分组卷积。
  1. **构建方法 **construct
def construct(self, old_x):
    left = old_x
    right = old_x
    out = old_x
    right = self.branch_main_1(right)
    if self.group > 1:
        right = self.channel_shuffle(right)
    right = self.branch_main_2(right)
    if self.stride == 1:
        out = self.relu(left + right)
    elif self.stride == 2:
        left = self.branch_proj(left)
        out = ops.cat((left, right), 1)
        out = self.relu(out)
    return out
  • 根据步幅选择不同的操作:
    • 如果步幅为 1,则进行残差连接。
    • 如果步幅为 2,则进行平均池化和拼接操作。
  1. **通道混洗方法 **channel_shuffle
def channel_shuffle(self, x):
    batchsize, num_channels, height, width = ops.shape(x)
    group_channels = num_channels // self.group
    x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
    x = ops.transpose(x, (0, 2, 1, 3, 4))
    x = ops.reshape(x, (batchsize, num_channels, height, width))
    return x
  • 对输入张量进行通道混洗操作,以增加通道间的信息流动。
  • 通过重塑和转置操作实现通道混洗。

构建ShuffleNet网络

ShuffleNet网络结构如下图所示,以输入图像224×224,组数3(g = 3)为例,首先通过数量24,卷积核大小为3×3,stride为2的卷积层,输出特征图大小为112×112,channel为24;然后通过stride为2的最大池化层,输出特征图大小为56×56,channel数不变;再堆叠3个ShuffleNet模块(Stage2, Stage3, Stage4),三个模块分别重复4次、8次、4次,其中每个模块开始先经过一次下采样模块(上图©),使特征图长宽减半,channel翻倍(Stage2的下采样模块除外,将channel数从24变为240);随后经过全局平均池化,输出大小为1×1×960,再经过全连接层和softmax,得到分类概率。

class ShuffleNetV1(nn.Cell):
    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)
        
        # 定义网络的每个阶段的重复次数
        self.stage_repeats = [4, 8, 4]
        
        # 设置模型大小
        self.model_size = model_size
        
        # 根据group和model_size设置输出通道数
        if group == 3:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 12, 120, 240, 480]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 240, 480, 960]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 360, 720, 1440]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 480, 960, 1920]
            else:
                raise NotImplementedError
        elif group == 8:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 16, 192, 384, 768]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 384, 768, 1536]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 576, 1152, 2304]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 768, 1536, 3072]
            else:
                raise NotImplementedError
        
        input_channel = self.stage_out_channels[1]
        
        # 定义第一层卷积和池化层
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        
        # 构建ShuffleNetV1的中间层
        features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                stride = 2 if i == 0 else 1
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel
        
        self.features = nn.SequentialCell(features)
        
        # 定义全局平均池化层和分类器
        self.globalpool = nn.AvgPool2d(7)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        # 构建前向传播
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x)
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)
        return x

解析:

  1. **初始化方法 **__init__
def __init__(self, n_class=1000, model_size='2.0x', group=3):
    super(ShuffleNetV1, self).__init__()
    print('model size is ', model_size)
    self.stage_repeats = [4, 8, 4]
    self.model_size = model_size
  • 参数解析:
    • n_class:分类数。
    • model_size:模型大小。
    • group:分组数。
  • 打印模型大小并初始化各个阶段的重复次数。
  1. 设置输出通道数
if group == 3:
    if model_size == '0.5x':
        self.stage_out_channels = [-1, 12, 120, 240, 480]
    elif model_size == '1.0x':
        self.stage_out_channels = [-1, 24, 240, 480, 960]
    elif model_size == '1.5x':
        self.stage_out_channels = [-1, 24, 360, 720, 1440]
    elif model_size == '2.0x':
        self.stage_out_channels = [-1, 48, 480, 960, 1920]
    else:
        raise NotImplementedError
elif group == 8:
    if model_size == '0.5x':
        self.stage_out_channels = [-1, 16, 192, 384, 768]
    elif model_size == '1.0x':
        self.stage_out_channels = [-1, 24, 384, 768, 1536]
    elif model_size == '1.5x':
        self.stage_out_channels = [-1, 24, 576, 1152, 2304]
    elif model_size == '2.0x':
        self.stage_out_channels = [-1, 48, 768, 1536, 3072]
    else:
        raise NotImplementedError
input_channel = self.stage_out_channels[1]
  • 根据 group 和 model_size 设置 stage_out_channels
  1. 定义第一层卷积和池化层
self.first_conv = nn.SequentialCell(
    nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
    nn.BatchNorm2d(input_channel),
    nn.ReLU(),
)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
  • first_conv 包含一个卷积层、一个批归一化层和一个 ReLU 激活层。
  • maxpool 是一个最大池化层。
  1. 构建 ShuffleNetV1 的中间层
features = []
for idxstage in range(len(self.stage_repeats)):
    numrepeat = self.stage_repeats[idxstage]
    output_channel = self.stage_out_channels[idxstage + 2]
    for i in range(numrepeat):
        stride = 2 if i == 0 else 1
        first_group = idxstage == 0 and i == 0
        features.append(ShuffleV1Block(input_channel, output_channel,
                                       group=group, first_group=first_group,
                                       mid_channels=output_channel // 4, ksize=3, stride=stride))
        input_channel = output_channel
self.features = nn.SequentialCell(features)
  • 根据 stage_repeats 构建 ShuffleNetV1 的中间层。
  • 每个阶段中的模块数由 stage_repeats 决定。
  • 第一个模块的步幅为 2,其余模块的步幅为 1。
  • 第一个模块是否为第一个分组由 first_group 决定。
  1. 定义全局平均池化层和分类器
self.globalpool = nn.AvgPool2d(7)
self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)
  1. **构建前向传播方法 **construct
def construct(self, x):
    x = self.first_conv(x)
    x = self.maxpool(x)
    x = self.features(x)
    x = self.globalpool(x)
    x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
    x = self.classifier(x)
    return x
  • 前向传播中依次通过初始卷积层、最大池化层、中间层、全局平均池化层,最后进行分类。

模型训练和评估

采用CIFAR-10数据集对ShuffleNet进行预训练。

训练集准备与加载

采用CIFAR-10数据集对ShuffleNet进行预训练。CIFAR-10共有60000张32*32的彩色图像,均匀地分为10个类别,其中50000张图片作为训练集,10000图片作为测试集。如下示例使用mindspore.dataset.Cifar10Dataset接口下载并加载CIFAR-10的训练集。目前仅支持二进制版本(CIFAR-10 binary version)。

from download import download  # 从download模块导入download函数

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"
# 定义一个字符串变量url,用于存储要下载的文件的链接

download(url, "./dataset", kind="tar.gz", replace=True)
# 调用download函数,传入以下参数:
# url: 文件下载的URL地址
# "./dataset": 文件下载后的存储路径
# kind="tar.gz": 指定下载文件的类型为tar.gz格式
# replace=True: 如果指定路径下已存在同名文件,则替换已存在的文件

解析:

  1. from download import download
    这是导入语句,从名为 download 的模块中导入 download 函数。
  2. url** 变量**
    定义了一个字符串变量 url,指向要下载的文件的网络地址,这里是一个 .tar.gz 格式的压缩文件。
  3. download** 函数**
    调用 download 函数执行下载操作,有四个参数:
    • url: 需要下载的文件的 URL。
    • "./dataset": 下载后的文件存储路径。
    • kind="tar.gz": 文件类型,指定下载的是一个 tar.gz 格式的压缩包。
    • replace=True: 如果目标路径下已有同名文件,设置为 True 表示将其替换。

API:

  • download(url, path, kind, replace):
    这是一个下载函数,通常用于从指定的 URL 地址下载文件到本地指定路径,并根据参数处理文件类型和是否替换已存在的文件。
import mindspore as ms
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transforms

def get_dataset(train_dataset_path, batch_size, usage):
    image_trans = []
    if usage == "train":
        image_trans = [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),  # 随机裁剪图像到32x32大小,裁剪偏移量为(4, 4, 4, 4)
            vision.RandomHorizontalFlip(prob=0.5),  # 随机水平翻转图像,概率为0.5
            vision.Resize((224, 224)),  # 调整图像大小到224x224
            vision.Rescale(1.0 / 255.0, 0.0),  # 将图像像素值从[0, 255]缩放到[0, 1]
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),  # 标准化图像
            vision.HWC2CHW()  # 将图像从HWC格式转换为CHW格式
        ]
    elif usage == "test":
        image_trans = [
            vision.Resize((224, 224)),  # 调整图像大小到224x224
            vision.Rescale(1.0 / 255.0, 0.0),  # 将图像像素值从[0, 255]缩放到[0, 1]
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),  # 标准化图像
            vision.HWC2CHW()  # 将图像从HWC格式转换为CHW格式
        ]
    label_trans = transforms.TypeCast(ms.int32)  # 将标签数据类型转换为int32
    dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True)  # 加载Cifar10数据集
    dataset = dataset.map(image_trans, 'image')  # 对图像应用变换
    dataset = dataset.map(label_trans, 'label')  # 对标签应用变换
    dataset = dataset.batch(batch_size, drop_remainder=True)  # 将数据集分批处理,丢弃不足一个batch的数据
    return dataset

dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "train")  # 获取训练数据集
batches_per_epoch = dataset.get_dataset_size()  # 获取每个epoch的batch数量

解析:

  1. 导入模块
    • mindsporemindspore.dataset 是 MindSpore 框架的核心模块。
    • Cifar10Dataset 用于加载 CIFAR-10 数据集。
    • visiontransforms 模块包含图像处理和数据转换的函数。
  2. get_dataset** 函数**:
    • 根据 usage 参数(“train” 或 “test”)定义不同的图像变换。
    • 使用 Cifar10Dataset 加载数据集,并根据路径、用途和是否打乱数据进行初始化。
    • 使用 map 方法对图像和标签应用相应的变换。
    • 使用 batch 方法将数据集分批处理,并丢弃不足一个batch的数据。
  3. 数据集变换
    • RandomCrop:随机裁剪图像。
    • RandomHorizontalFlip:随机水平翻转图像。
    • Resize:调整图像大小。
    • Rescale:缩放图像像素值。
    • Normalize:标准化图像。
    • HWC2CHW:将图像从 HWC 格式转换为 CHW 格式。
    • TypeCast:将标签数据类型转换为 int32
  4. 获取数据集
    • 调用 get_dataset 函数获取训练数据集。
    • 使用 get_dataset_size 方法获取每个epoch的batch数量。

API:

  • Cifar10Dataset:加载 CIFAR-10 数据集。
  • vision.RandomCrop:随机裁剪图像。
  • vision.RandomHorizontalFlip:随机水平翻转图像。
  • vision.Resize:调整图像大小。
  • vision.Rescale:缩放图像像素值。
  • vision.Normalize:标准化图像。
  • vision.HWC2CHW:将图像从 HWC 格式转换为 CHW 格式。
  • transforms.TypeCast:将数据类型转换为指定类型。
  • dataset.map:对数据集应用变换。
  • dataset.batch:将数据集分批处理。
  • dataset.get_dataset_size:获取数据集的batch数量。

模型训练

本节用随机初始化的参数做预训练。首先调用ShuffleNetV1定义网络,参数量选择"2.0x",并定义损失函数为交叉熵损失,学习率经过4轮的warmup后采用余弦退火,优化器采用Momentum。最后用train.model中的Model接口将模型、损失函数、优化器封装在model中,并用model.train()对网络进行训练。将ModelCheckpointCheckpointConfigTimeMonitorLossMonitor传入回调函数中,将会打印训练的轮数、损失和时间,并将ckpt文件保存在当前目录下。

import time  # 导入时间模块,用于记录训练时间
import mindspore  # 导入MindSpore框架
import numpy as np  # 导入NumPy库,用于数值计算
from mindspore import Tensor, nn  # 导入Tensor和神经网络模块
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy  # 导入训练相关的模块

def train():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="GPU")  # 设置MindSpore上下文,使用PYNATIVE模式和GPU设备
    net = ShuffleNetV1(model_size="2.0x", n_class=10)  # 初始化ShuffleNetV1模型,指定模型大小和类别数
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)  # 定义交叉熵损失函数,使用标签平滑
    min_lr = 0.0005  # 最小学习率
    base_lr = 0.05  # 基础学习率
    lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr, base_lr, batches_per_epoch*250, batches_per_epoch, decay_epoch=250)  # 余弦退火学习率调度器
    lr = Tensor(lr_scheduler[-1])  # 获取学习率的最后一值,并转换为Tensor
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)  # 定义优化器,使用动量优化
    loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)  # 定义固定损失缩放管理器
    model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)  # 初始化模型,指定AMP级别和损失缩放管理器
    callback = [TimeMonitor(), LossMonitor()]  # 定义回调函数列表,包括时间监控和损失监控
    save_ckpt_path = "./"  # 设置保存检查点的路径
    config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)  # 配置检查点保存策略,每个epoch保存一次,最多保留5个检查点
    ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)  # 定义检查点回调函数
    callback += [ckpt_callback]  # 将检查点回调函数添加到回调函数列表中

    print("============== Starting Training ==============")  # 打印训练开始提示
    start_time = time.time()  # 记录开始时间
    model.train(250, dataset, callbacks=callback)  # 开始模型训练,训练250个epoch,使用定义好的回调函数
    use_time = time.time() - start_time  # 计算训练总时间
    hour = str(int(use_time // 60 // 60))  # 计算训练时间的小时部分
    minute = str(int(use_time // 60 % 60))  # 计算训练时间的分钟部分
    second = str(int(use_time % 60))  # 计算训练时间的秒部分
    print("total time:" + hour + "h " + minute + "m " + second + "s")  # 打印训练总时间
    print("============== Train Success ==============")  # 打印训练成功提示

if __name__ == '__main__':
    train()  # 如果脚本作为主程序运行,则调用train函数

解析:

  1. 导入模块
    • time:用于记录和计算训练时间。
    • mindspore:MindSpore框架的核心模块。
    • numpy:用于数值计算。
    • mindspore.Tensormindspore.nn:Tensor和神经网络模块。
    • mindspore.train:训练相关的模块,包括模型检查点、回调函数等。
  2. set_context
    • mindspore.set_context(mode, device_target):设置MindSpore的运行模式和设备,使用PYNATIVE模式和GPU设备。
  3. 模型和损失函数
    • ShuffleNetV1:初始化ShuffleNetV1模型,指定模型大小为"2.0x",类别数为10。
    • nn.CrossEntropyLoss:定义交叉熵损失函数,使用标签平滑(label_smoothing)。
  4. 学习率调度和优化器
    • mindspore.nn.cosine_decay_lr:定义余弦退火学习率调度器,设置最小和基础学习率,计算每个epoch的学习率。
    • nn.Momentum:定义动量优化器,设置学习率、动量、权重衰减和损失缩放。
  5. 模型和损失缩放管理器
    • ms.amp.FixedLossScaleManager:定义固定损失缩放管理器,以避免数值问题。
    • Model:初始化模型,指定损失函数、优化器、AMP级别和损失缩放管理器。
  6. 回调函数
    • TimeMonitorLossMonitor:定义时间和损失监控回调函数。
    • CheckpointConfigModelCheckpoint:定义模型检查点配置和回调函数,每个epoch保存一次检查点,最多保留5个检查点。
    • 将检查点回调函数添加到回调函数列表中。
  7. 训练过程
    • 打印“Starting Training”提示。
    • 记录开始时间。
    • 调用model.train方法,开始训练250个epoch,并使用定义好的回调函数。
    • 计算和打印训练总时间。
    • 打印“Train Success”提示。

API:

  • mindspore.set_context:设置MindSpore的运行模式和设备。
  • ShuffleNetV1:初始化ShuffleNetV1模型。
  • nn.CrossEntropyLoss:定义交叉熵损失函数。
  • mindspore.nn.cosine_decay_lr:余弦退火学习率调度器。
  • nn.Momentum:动量优化器。
  • ms.amp.FixedLossScaleManager:固定损失缩放管理器。
  • Model:初始化模型。
  • TimeMonitor** 和 **LossMonitor:时间和损失监控回调函数。
  • CheckpointConfig** 和 **ModelCheckpoint:模型检查点配置和回调函数。
  • model.train:开始模型训练。

训练好的模型保存在当前目录的shufflenetv1-250_391.ckpt中,用作评估。

模型评估

在CIFAR-10的测试集上对模型进行评估。
设置好评估模型的路径后加载数据集,并设置Top 1, Top 5的评估标准,最后用model.eval()接口对模型进行评估。

from mindspore import load_checkpoint, load_param_into_net  # 导入MindSpore的检查点加载函数

def test():
    mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="GPU")  # 设置MindSpore上下文,使用GRAPH模式和GPU设备
    dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")  # 获取测试数据集
    net = ShuffleNetV1(model_size="2.0x", n_class=10)  # 初始化ShuffleNetV1模型,指定模型大小和类别数
    param_dict = load_checkpoint("shufflenetv1-250_391.ckpt")  # 加载训练好的模型参数
    load_param_into_net(net, param_dict)  # 将参数加载到网络中
    net.set_train(False)  # 设置网络为评估模式
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)  # 定义交叉熵损失函数,使用标签平滑
    eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(), 'Top_5_Acc': Top5CategoricalAccuracy()}  # 定义评估指标
    model = Model(net, loss_fn=loss, metrics=eval_metrics)  # 初始化模型,指定损失函数和评估指标

    start_time = time.time()  # 记录开始时间
    res = model.eval(dataset, dataset_sink_mode=False)  # 评估模型,使用测试数据集
    use_time = time.time() - start_time  # 计算评估时间
    hour = str(int(use_time // 60 // 60))  # 计算时间的小时部分
    minute = str(int(use_time // 60 % 60))  # 计算时间的分钟部分
    second = str(int(use_time % 60))  # 计算时间的秒部分
    log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-250_391.ckpt" + "', time: " + hour + "h " + minute + "m " + second + "s"  # 构建日志信息
    print(log)  # 打印日志信息

    filename = './eval_log.txt'  # 日志文件名
    with open(filename, 'a') as file_object:  # 以追加模式打开日志文件
        file_object.write(log + '\n')  # 写入日志信息

if __name__ == '__main__':
    test()  # 如果脚本作为主程序运行,则调用test函数

解析:

  1. 导入模块
    • load_checkpointload_param_into_net 用于加载已保存的模型参数。
  2. test** 函数**:
    • 设置上下文mindspore.set_context 设置为 GRAPH 模式并使用 GPU 设备。
    • 获取数据集:调用 get_dataset 函数获取测试数据集。
    • 初始化模型ShuffleNetV1 初始化模型,指定模型大小为"2.0x",类别数为10。
    • 加载模型参数:使用 load_checkpoint 加载模型参数,并使用 load_param_into_net 将参数加载到网络中。
    • 设置评估模式net.set_train(False) 设置网络为评估模式。
    • 定义损失函数nn.CrossEntropyLoss 定义交叉熵损失函数,使用标签平滑。
    • 定义评估指标eval_metrics 包含损失、Top-1 准确率和 Top-5 准确率。
    • 初始化模型Model 初始化模型,指定损失函数和评估指标。
  3. 评估过程
    • 记录开始时间:使用 time.time() 记录评估开始时间。
    • 评估模型:调用 model.eval 方法评估模型,并使用测试数据集。
    • 计算评估时间:计算评估所用时间,将其转换为小时、分钟和秒。
    • 记录和打印日志:构建日志信息,打印日志,并将日志写入文件 eval_log.txt 中。

API:

  • mindspore.set_context:设置MindSpore的运行模式和设备。
  • load_checkpoint:加载已保存的模型参数。
  • load_param_into_net:将加载的参数导入到模型中。
  • model.eval:评估模型性能。
  • time.time:记录时间,用于计算评估时长。

模型预测

在CIFAR-10的测试集上对模型进行预测,并将预测结果可视化。

import mindspore  # 导入MindSpore框架
import matplotlib.pyplot as plt  # 导入Matplotlib库,用于绘制图像
import mindspore.dataset as ds  # 导入MindSpore的数据集模块

# 初始化ShuffleNetV1模型
net = ShuffleNetV1(model_size="2.0x", n_class=10)
show_lst = []  # 用于存放显示图像的列表
param_dict = load_checkpoint("shufflenetv1-250_391.ckpt")  # 加载训练好的模型参数
load_param_into_net(net, param_dict)  # 将模型参数加载到网络中
model = Model(net)  # 初始化模型

# 加载CIFAR-10数据集,用于预测
dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")

# 加载CIFAR-10数据集,用于显示图像
dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = dataset_show.batch(16)  # 按批次加载数据
show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()  # 获取一个批次的图像并转换为NumPy数组

# 定义图像预处理步骤
image_trans = [
    vision.RandomCrop((32, 32), (4, 4, 4, 4)),  # 随机裁剪图像到32x32大小,裁剪偏移量为(4, 4, 4, 4)
    vision.RandomHorizontalFlip(prob=0.5),  # 随机水平翻转图像,概率为0.5
    vision.Resize((224, 224)),  # 调整图像大小到224x224
    vision.Rescale(1.0 / 255.0, 0.0),  # 将图像像素值从[0, 255]缩放到[0, 1]
    vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),  # 标准化图像
    vision.HWC2CHW()  # 将图像从HWC格式转换为CHW格式
]

# 应用预处理步骤并按批次加载预测数据
dataset_predict = dataset_predict.map(image_trans, 'image')
dataset_predict = dataset_predict.batch(16)

# 类别标签字典
class_dict = {
    0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer",
    5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"
}

# 推理效果展示(上方为预测的结果,下方为推理效果图片)
plt.figure(figsize=(16, 5))
predict_data = next(dataset_predict.create_dict_iterator())  # 获取一个批次的预测数据
output = model.predict(ms.Tensor(predict_data['image']))  # 进行模型推理
pred = np.argmax(output.asnumpy(), axis=1)  # 获取预测结果的类别索引
index = 0

# 遍历并显示图像和预测结果
for image in show_images_lst:
    plt.subplot(2, 8, index + 1)
    plt.title('{}'.format(class_dict[pred[index]]))  # 显示预测的类别标签
    index += 1
    plt.imshow(image)  # 显示图像
    plt.axis("off")  # 隐藏坐标轴
plt.show()  # 显示所有图像

解析:

  1. 导入模块
    • mindspore:MindSpore框架的核心模块。
    • matplotlib.pyplot:用于绘制图像。
    • mindspore.dataset:MindSpore的数据集模块。
  2. 模型初始化
    • ShuffleNetV1:初始化ShuffleNetV1模型,指定模型大小为"2.0x",类别数为10。
    • load_checkpointload_param_into_net:加载训练好的模型参数,并将其加载到网络中。
    • Model:初始化模型。
  3. 加载数据集
    • Cifar10Dataset:加载CIFAR-10数据集。
    • dataset_show:用于显示图像的数据集,按批次加载数据。
    • show_images_lst:获取一个批次的图像并转换为NumPy数组。
  4. 图像预处理
    • vision.RandomCropvision.RandomHorizontalFlipvision.Resizevision.Rescalevision.Normalizevision.HWC2CHW:定义图像预处理步骤。
    • dataset_predict.map(image_trans, 'image'):应用预处理步骤。
    • 按批次加载预测数据。
  5. 类别标签字典
    • class_dict:定义类别标签字典。
  6. 推理和展示
    • 使用Matplotlib绘制图像,显示预测结果和对应图像。
    • model.predict:进行模型推理。
    • np.argmax:获取预测结果的类别索引。
    • 遍历并显示图像和预测结果,隐藏坐标轴。

API:

  • load_checkpoint:加载已保存的模型参数。
  • load_param_into_net:将加载的参数导入到模型中。
  • Cifar10Dataset:加载CIFAR-10数据集。
  • model.predict:进行模型推理。
  • np.argmax:获取最大值的索引。
  • Matplotlib:用于绘制图像。

整体代码

代码解析

GroupConv 类
class GroupConv(nn.Cell):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
        super(GroupConv, self).__init__()
        self.groups = groups
        self.convs = nn.CellList()
        for _ in range(groups):
            self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,
                                        kernel_size=kernel_size, stride=stride, has_bias=has_bias,
                                        padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))

    def construct(self, x):
        features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)
        outputs = ()
        for i in range(self.groups):
            outputs = outputs + (self.convs[i](features[i].astype("float32")),)
        out = ops.cat(outputs, axis=1)
        return out

解析:

  • GroupConv 类实现了分组卷积操作。
  • 在初始化方法中,根据输入和输出的通道数以及分组数,创建多个卷积层,每个卷积层处理一部分输入通道。
  • construct 方法中,首先将输入特征图按通道数分成多个组,然后对每个组应用相应的卷积层,最后将所有组的输出拼接起来。
ShuffleV1Block 类
class ShuffleV1Block(nn.Cell):
    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
        super(ShuffleV1Block, self).__init__()
        self.stride = stride
        pad = ksize // 2
        self.group = group
        if stride == 2:
            outputs = oup - inp
        else:
            outputs = oup
        self.relu = nn.ReLU()
        branch_main_1 = [
            GroupConv(in_channels=inp, out_channels=mid_channels,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=1 if first_group else group),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        ]
        branch_main_2 = [
            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels,
                      weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(mid_channels),
            GroupConv(in_channels=mid_channels, out_channels=outputs,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=group),
            nn.BatchNorm2d(outputs),
        ]
        self.branch_main_1 = nn.SequentialCell(branch_main_1)
        self.branch_main_2 = nn.SequentialCell(branch_main_2)
        if stride == 2:
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, old_x):
        left = old_x
        right = old_x
        out = old_x
        right = self.branch_main_1(right)
        if self.group > 1:
            right = self.channel_shuffle(right)
        right = self.branch_main_2(right)
        if self.stride == 1:
            out = self.relu(left + right)
        elif self.stride == 2:
            left = self.branch_proj(left)
            out = ops.cat((left, right), 1)
            out = self.relu(out)
        return out

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = ops.shape(x)
        group_channels = num_channels // self.group
        x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        x = ops.reshape(x, (batchsize, num_channels, height, width))
        return x

解析:

  • ShuffleV1Block 类实现了 ShuffleNet 的基本模块。
  • 初始化方法中定义了两个分支:branch_main_1branch_main_2,分别用于处理输入特征图的不同部分。
  • construct 方法中,根据步长决定如何处理输入特征图,并进行通道重排(Channel Shuffle)。
  • channel_shuffle 方法实现了通道重排操作,将特征图的通道重新排列以实现不同组之间的信息交流。
ShuffleNetV1 类
class ShuffleNetV1(nn.Cell):
    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        if group == 3:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 12, 120, 240, 480]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 240, 480, 960]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 360, 720, 1440]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 480, 960, 1920]
            else:
                raise NotImplementedError
        elif group == 8:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 16, 192, 384, 768]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 384, 768, 1536]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 576, 1152, 2304]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 768, 1536, 3072]
            else:
                raise NotImplementedError
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                stride = 2 if i == 0 else 1
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel
        self.features = nn.SequentialCell(features)
        self.globalpool = nn.AvgPool2d(7)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x)
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)
        return x

解析:

  • ShuffleNetV1 类定义了整个 ShuffleNet 网络结构。
  • 初始化方法中定义了网络的各个阶段和通道数,并根据模型大小和组数设置相应的参数。
  • construct 方法中,依次应用初始卷积层、最大池化层、多个 ShuffleNet 模块、全局平均池化层和全连接层,最终输出分类结果。
数据集准备与加载
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transforms

def get_dataset(train_dataset_path, batch_size, usage):
    image_trans = []
    if usage == "train":
        image_trans = [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
        ]
    elif usage == "test":
        image_trans = [
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
        ]
    label_trans = transforms.TypeCast(ms.int32)
    dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True)
    dataset = dataset.map(image_trans, 'image')
    dataset = dataset.map(label_trans, 'label')
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset

dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "train")
batches_per_epoch = dataset.get_dataset_size()

解析:

  • get_dataset 函数用于加载和预处理 CIFAR-10 数据集。
  • 根据训练或测试的不同需求,应用不同的图像变换操作。
  • 最后将数据集按批次大小进行分批处理。
模型训练
import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy

def train():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="GPU")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    min_lr = 0.0005
    base_lr = 0.05
    lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                                base_lr,
                                                batches_per_epoch*250,
                                                batches_per_epoch,
                                                decay_epoch=250)
    lr = Tensor(lr_scheduler[-1])
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
    loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
    callback = [TimeMonitor(), LossMonitor()]
    save_ckpt_path = "./"
    config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
    ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
    callback += [ckpt_callback]

    print("============== Starting Training ==============")
    start_time = time.time()
    model.train(250, dataset, callbacks=callback)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    print("total time:" + hour + "h " + minute + "m " + second + "s")
    print("============== Train Success ==============")

if __name__ == '__main__':
    train()

解析:

  • train 函数用于训练 ShuffleNet 模型。
  • 设置训练模式和设备,定义网络、损失函数、学习率调度器和优化器。
  • 使用 Model 类封装模型、损失函数和优化器,并设置回调函数进行训练监控和模型保存。
  • 最后调用 model.train 进行模型训练,并打印训练时间和结果。
模型评估
from mindspore import load_checkpoint, load_param_into_net

def test():
    mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="GPU")
    dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    param_dict = load_checkpoint("shufflenetv1-250_391.ckpt")
    load_param_into_net(net, param_dict)
    net.set_train(False)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),
                    'Top_5_Acc': Top5CategoricalAccuracy()}
    model = Model(net, loss_fn=loss, metrics=eval_metrics)
    start_time = time.time()
    res = model.eval(dataset, dataset_sink_mode=False)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-250_391.ckpt" \
        + "', time: " + hour + "h " + minute + "m " + second + "s"
    print(log)
    filename = './eval_log.txt'
    with open(filename, 'a') as file_object:
        file_object.write(log + '\n')

if __name__ == '__main__':
    test()

解析:

  • test 函数用于评估训练好的 ShuffleNet 模型。
  • 设置评估模式和设备,加载预训练的模型参数。
  • 定义评估指标,包括损失和 Top-1、Top-5 准确率。
  • 使用 Model

模型评估(续)

解析(续):

  • 使用 Model 类封装模型、损失函数和评估指标。
  • 调用 model.eval 进行模型评估,并打印评估结果和时间。
模型预测
import mindspore
import matplotlib.pyplot as plt
import mindspore.dataset as ds

net = ShuffleNetV1(model_size="2.0x", n_class=10)
show_lst = []
param_dict = load_checkpoint("shufflenetv1-250_391.ckpt")
load_param_into_net(net, param_dict)
model = Model(net)
dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = dataset_show.batch(16)
show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()
image_trans = [
    vision.RandomCrop((32, 32), (4, 4, 4, 4)),
    vision.RandomHorizontalFlip(prob=0.5),
    vision.Resize((224, 224)),
    vision.Rescale(1.0 / 255.0, 0.0),
    vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    vision.HWC2CHW()
]
dataset_predict = dataset_predict.map(image_trans, 'image')
dataset_predict = dataset_predict.batch(16)
class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}

# 推理效果展示(上方为预测的结果,下方为推理效果图片)
plt.figure(figsize=(16, 5))
predict_data = next(dataset_predict.create_dict_iterator())
output = model.predict(ms.Tensor(predict_data['image']))
pred = np.argmax(output.asnumpy(), axis=1)
index = 0
for image in show_images_lst:
    plt.subplot(2, 8, index+1)
    plt.title('{}'.format(class_dict[pred[index]]))
    index += 1
    plt.imshow(image)
    plt.axis("off")
plt.show()

解析:

  • 加载预训练模型参数并定义模型。
  • 使用 CIFAR-10 数据集加载并预处理预测数据。
  • 在预测数据上执行模型预测,并将预测结果可视化。
  • class_dict 定义了 CIFAR-10 数据集的类别标签。
  • 使用 Matplotlib 将预测结果和对应的图像以网格形式展示。

详细解析各个 API

  1. mindspore.nn.Cell
    • Cell 是 MindSpore 中所有网络层和模型的基类。通过继承 Cell,用户可以自定义网络模型。
  2. mindspore.nn.Conv2d
    • Conv2d 是一个二维卷积层。其主要参数包括输入和输出通道数、卷积核大小、步长、填充方式等。
  3. mindspore.ops.split
    • split 操作用于在指定轴上将张量分割为多个子张量。
  4. mindspore.ops.cat
    • cat 操作用于在指定轴上拼接多个张量。
  5. mindspore.nn.SequentialCell
    • SequentialCell 是一个容器,用于将多个网络层顺序组合在一起,便于构建复杂的网络结构。
  6. mindspore.dataset.Cifar10Dataset
    • Cifar10Dataset 是一个数据集类,用于加载 CIFAR-10 数据集。
  7. mindspore.dataset.transforms
    • transforms 模块提供了多种数据预处理和数据增强方法,例如 RandomCrop, RandomHorizontalFlip, Rescale, Normalize 等。
  8. mindspore.train.Model
    • Model 是一个高级 API,用于封装网络模型、损失函数、优化器和评估指标,便于训练、评估和推理。
  9. mindspore.train.callback
    • callback 提供了一些回调函数,例如 TimeMonitor, LossMonitor, ModelCheckpoint, 用于在训练过程中记录时间、监控损失和保存模型。
  10. mindspore.load_checkpoint
  • load_checkpoint 用于加载预训练的模型参数。
  1. mindspore.load_param_into_net
  • load_param_into_net 用于将加载的模型参数导入网络中。
  1. vision
  • vision 提供了一些图像处理方法,例如图像裁剪、翻转、缩放和归一化等。

通过以上代码和解析,相信你已经了解了 ShuffleNet 的实现以及如何在 MindSpore 中进行数据加载、模型训练、评估和预测。希望这些信息对你有所帮助!

相关推荐

  1. MindSpore 应用学习-ResNet50迁移学习-CSDN

    2024-07-21 01:20:02       28 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-21 01:20:02       145 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-21 01:20:02       159 阅读
  3. 在Django里面运行非项目文件

    2024-07-21 01:20:02       133 阅读
  4. Python语言-面向对象

    2024-07-21 01:20:02       145 阅读

热门阅读

  1. Piping(√)

    2024-07-21 01:20:02       27 阅读
  2. KTV点歌系统有什么作用?

    2024-07-21 01:20:02       28 阅读
  3. Flutter 状态管理调研总结

    2024-07-21 01:20:02       28 阅读