Fashion Image Classification on the FashionAI Dataset: The Impact of Different CNN Architectures on Classification Performance

1. Introduction

This experiment investigates image classification with convolutional neural networks (CNNs) using the FashionAI dataset, jointly released by Alibaba's "Image and Beauty" team and the Institute of Textiles & Clothing at The Hong Kong Polytechnic University. FashionAI is the first large-scale, high-quality fashion dataset in the industry that satisfies both apparel-domain expertise and machine-learning requirements; it originates from the Tianchi "FashionAI Global Challenge: Attribute Recognition of Apparel" competition. We use this dataset for an image classification task and, by analyzing the results, examine the strengths and limitations of CNNs for image classification. The findings can serve as a reference for further image classification work, and researchers can build applications such as garment image retrieval, tag-based navigation, and outfit matching on top of this experiment.

2. Experimental Design

1 Data Preprocessing

The data downloaded from Tianchi does not follow the ImageNet directory layout. Instead, the image paths and labels are stored in CSV files: the paths are relative and cannot be read directly, and the labels are strings such as nnnynnnn, which cannot be fed to torchvision as-is. We therefore write our own Python code to build the datasets.

1.1 Characteristics of the Dataset

Training set
The training data consists of the following archives: round1_fashionAI_attributes_train.tar (round-1 training set), round1_fashionAI_attributes_test_A.tar (round-1 test set A plus answers), round1_fashionAI_attributes_test_B.tar (round-1 test set B plus answers), and round2_fashionAI_attributes_train.tar (round-2 training set).
Training data file structure
a) Image data and recognition labels are provided for training, with the following folder structure:

  • Images
  • Annotations
  • README.md

b) Images: the image data, as JPEG-encoded files with names like 0000001.jpg
c) Annotations: the attribute annotation data, in CSV format.
d) README.md: a detailed description of the data.

Training data example

[Figure 3. Illustration of each attribute dimension in the training data]

A sample of the corresponding CSV annotation file:

| ImageName | AttrKey | AttrValues |
| --- | --- | --- |
| 0000001.jpg | sleeve_length_labels | nnnnnnnym |
| 0000001.jpg | skirt_length_labels | nynnnn |
| 0000001.jpg | neck_design_labels | nnnyn |
| 0000001.jpg | coat_length_labels | nnynnnnn |

Annotation format:
ImageName: the image file name, corresponding to a file in the Images folder.
AttrKey: the attribute dimension, e.g. sleeve length (sleeve_length_labels), pant length (pant_length_labels), and so on.
AttrValues: the attribute values for the given AttrKey. The sleeve-length dimension (AttrKey) has 9 attribute values (AttrValues): invisible, sleeveless, cap sleeve, short sleeve, elbow-length sleeve, 3/4 sleeve, wrist-length sleeve, long sleeve, extra-long sleeve. These correspond to the example annotation above, nnnnnnnym: nine positions, each taking one of three values, y (yes, definitely), m (maybe), or n (no, definitely not). For a given image and attribute dimension, exactly one position is marked 'y'; the rest are 'm' or 'n'.
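
Since exactly one position is 'y', the class index is simply the position of the 'y' in the AttrValues string; a minimal sketch using the neck_design_labels row from the table above:

# Hypothetical example: map an AttrValues string to its class index
attr_values = "nnnyn"                  # neck_design_labels row from the table above
class_index = attr_values.index("y")   # exactly one 'y' per row, so this is well defined
print(class_index)                     # -> 3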

1.2 Preprocessing Method

# Data preprocessing: only needs to run once for the whole project
# All attribute dimensions are listed in the Tianchi dataset's README.md
ALLATTRKEY = ['skirt_length_labels','coat_length_labels','collar_design_labels','lapel_design_labels','neck_design_labels','neckline_design_labels','pant_length_labels','sleeve_length_labels']

import os
import pandas as pd

for attrkey in ALLATTRKEY:

    # Read the raw annotation files
    df_train1 = pd.read_csv("./autodl-tmp-cp/round1_fashionAI_attributes_train/Annotations/label.csv", header=None)
    df_train2 = pd.read_csv("./autodl-tmp-cp/round2_fashionAI_attributes_train/Annotations/label.csv", header=None)
    df_test1a = pd.read_csv("./autodl-tmp-cp/round1_fashionAI_attributes_test_a/Tests/round1_fashionAI_attributes_answer_a.csv", header=None)
    df_test1b = pd.read_csv("./autodl-tmp-cp/round1_fashionAI_attributes_test_b/Tests/round1_fashionAI_attributes_answer_b.csv", header=None)
    print(df_train1)

    # Keep only the rows for the current attribute dimension
    df_train1 = df_train1[df_train1[1] == attrkey]
    df_train2 = df_train2[df_train2[1] == attrkey]
    df_test1a = df_test1a[df_test1a[1] == attrkey]
    df_test1b = df_test1b[df_test1b[1] == attrkey]
    print(df_train1)

    # Normalize the image paths and labels: prepend the archive directory to the
    # relative path, and map the AttrValues string to the index of its single 'y' (0-9)
    def df_label_std(df, pth):
        df_std = df.copy()
        for i in df.index:
            df_std.loc[i, 0] = pth + str(df.loc[i, 0])
            df_std.loc[i, 2] = str(df.loc[i, 2]).index('y')
        return df_std

    df_train1_std = df_label_std(df_train1, "./autodl-tmp-cp/round1_fashionAI_attributes_train/")
    df_train2_std = df_label_std(df_train2, "./autodl-tmp-cp/round2_fashionAI_attributes_train/")
    df_test1a_std = df_label_std(df_test1a, "./autodl-tmp-cp/round1_fashionAI_attributes_test_a/")
    df_test1b_std = df_label_std(df_test1b, "./autodl-tmp-cp/round1_fashionAI_attributes_test_b/")
    print(df_train1_std)

    # Merge the splits with pandas and reset the index
    df_train = pd.concat([df_train1_std, df_train2_std], ignore_index=True)
    df_eval = pd.concat([df_test1a_std], ignore_index=True)
    df_test = pd.concat([df_test1b_std], ignore_index=True)
    print(df_train)

    # Save each DataFrame split as a CSV file, one folder per attribute dimension
    os.makedirs("./df-path/" + attrkey, exist_ok=True)  # to_csv does not create directories
    df_train.to_csv("./df-path/" + attrkey + "/train.csv", header=False, index=False)
    df_eval.to_csv("./df-path/" + attrkey + "/eval.csv", header=False, index=False)
    df_test.to_csv("./df-path/" + attrkey + "/test.csv", header=False, index=False)

The code above filters the FashionAI annotations by attribute dimension, normalizes them, and saves the processed data as three files: train.csv, eval.csv, and test.csv.
The preprocessing steps are:

  1. Read the FashionAI annotation data, then filter it by attribute dimension (the elements of the ALLATTRKEY list) to obtain the annotations for each dimension.
  2. Normalize the annotations: keep the image file path in the first column and the label in the third, mapping the AttrValues string (e.g. nnnyn) to a number 0-9 given by the position of its 'y'.
  3. Split the data into training, validation, and test sets: round1_fashionAI_attributes_train.tar (round-1 training set) and round2_fashionAI_attributes_train.tar (round-2 training set) are merged into the training set, while round1_fashionAI_attributes_test_A.tar (round-1 test set A plus answers) and round1_fashionAI_attributes_test_B.tar (round-1 test set B plus answers) serve as the validation and test sets, respectively.
  4. Save the processed splits as train.csv, eval.csv, and test.csv, in a separate folder per attribute dimension.

With the dataset preprocessed this way, the splits can be loaded directly in PyTorch for model training and evaluation.
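
As a quick sanity check (a sketch; the skirt_length_labels path simply follows the layout produced above), one can confirm that the labels landed in the expected integer range:

# Sanity-check one of the generated CSV files
import pandas as pd

df = pd.read_csv("./df-path/skirt_length_labels/train.csv", header=None)
print(len(df))                  # number of training samples for this attribute
print(sorted(df[2].unique()))   # labels should be integers within 0-9
print(df[0].iloc[0])            # first image path, relative to the working directory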

2 Data Augmentation

Computing the mean and std

To help the model converge faster, we compute the per-channel mean and std used for normalization:

# Compute the normalization parameters mean and std (only needs to run once per experiment)
import pandas as pd
from torchvision import transforms
from torch.utils import data
from PIL import Image
import torch

for attrkey in ALLATTRKEY:
    df_train = pd.read_csv("./df-path/" + attrkey + "/train.csv", header=None)

    transform_plain = transforms.Compose([
        transforms.Resize((224, 224)),  # resize every image to 224x224
        transforms.ToTensor(),
    ])

    class fashion_dataset_train(data.Dataset):
        def __init__(self):
            self.df = df_train

        def __getitem__(self, index):
            label = self.df[2][self.df.index[index]]
            img = transform_plain(Image.open(self.df[0][self.df.index[index]]))
            return img, label

        def __len__(self):
            return len(self.df)

    dataset_train = fashion_dataset_train()

    # Average the per-image channel means and stds over the training set
    def get_mean_and_std(train_data):
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=1, shuffle=False, num_workers=0,
            pin_memory=True)
        mean = torch.zeros(3)
        std = torch.zeros(3)
        for X, _ in train_loader:
            for d in range(3):
                mean[d] += X[:, d, :, :].mean()
                std[d] += X[:, d, :, :].std()
        mean.div_(len(train_data))
        std.div_(len(train_data))
        return list(mean.numpy()), list(std.numpy())

    mean_and_std = get_mean_and_std(dataset_train)
    print(mean_and_std)

    # Save the mean and std computed on the training set, one file per AttrKey,
    # so they can be read back later
    with open("./df-path/" + attrkey + "/mean_and_std.txt", 'w') as f:
        for value in mean_and_std[0] + mean_and_std[1]:
            f.write(str(value) + '\n')
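
Note that averaging per-image statistics, as above, is a common approximation: the average of per-image stds is not exactly the dataset-wide std. If exact values are wanted, a sketch that accumulates per-channel pixel sums instead:

# Exact alternative (sketch): accumulate per-channel sums and squared sums over all pixels
def get_exact_mean_and_std(train_data):
    loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=False)
    n_pixels = 0
    channel_sum = torch.zeros(3)
    channel_sq_sum = torch.zeros(3)
    for X, _ in loader:                    # X has shape (B, 3, H, W)
        n_pixels += X.numel() // 3         # pixels per channel in this batch
        channel_sum += X.sum(dim=[0, 2, 3])
        channel_sq_sum += (X ** 2).sum(dim=[0, 2, 3])
    mean = channel_sum / n_pixels
    std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
    return mean, std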

Data Augmentation / Normalization

from torchvision import transforms
from cutout import Cutout # custom Cutout op: https://blog.csdn.net/u013685264/article/details/122562509
from random_erasing import RandomErasing # custom RandomErasing op: https://blog.csdn.net/u013685264/article/details/122564323
cut = Cutout() # custom Cutout op
rand_erase = RandomErasing() # custom RandomErasing op
norm_mean = [0.65300995, 0.61700195, 0.6039642]   # computed above (ImageNet defaults: [0.485, 0.456, 0.406])
norm_std = [0.24680348, 0.25832596, 0.25814256]   # computed above (ImageNet defaults: [0.229, 0.224, 0.225])
transform_aug = transforms.Compose([ # augmentation pipeline for training
    transforms.RandomRotation(10), # random rotation
    # transforms.CenterCrop(448), # center crop
    # transforms.GaussianBlur(kernel_size=(21,21)), # Gaussian blur
    # transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), # random brightness, contrast, and saturation
    transforms.Resize([224, 224]), # resize every image to 224x224
    # transforms.RandomCrop(224, padding=16), # random crop
    transforms.RandomHorizontalFlip(), # random horizontal flip
    transforms.ToTensor(), # convert the image to a tensor
    transforms.Normalize(norm_mean, norm_std) # normalize
])
transform_norm = transforms.Compose([ # normalization-only pipeline for evaluation
    transforms.Resize([224, 224]), # resize every image to 224x224
    transforms.ToTensor(), # convert the image to a tensor
    transforms.Normalize(norm_mean, norm_std) # normalize
])

This code defines two preprocessing pipelines, transform_aug and transform_norm. transform_aug applies data augmentation, here random rotation and horizontal flipping (with center/random cropping, Gaussian blur, and color jitter left commented out), to improve the model's robustness and generalization. transform_norm only resizes and normalizes, and is used to standardize the data during validation and testing so the model is easier to train and generalize.
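
Cutout and RandomErasing come from the project's own helper files (the CSDN links above). For reference, a minimal stand-in sketch of the Cutout idea, assuming a (C, H, W) tensor input and a hypothetical default patch size:

import torch

# Minimal Cutout sketch: zero out one random square patch of the image tensor
class SimpleCutout:
    def __init__(self, length=56):
        self.length = length  # side of the square to erase (assumed default)

    def __call__(self, img):  # img: Tensor of shape (C, H, W), already ToTensor()-ed
        _, h, w = img.shape
        y = torch.randint(0, h, (1,)).item()
        x = torch.randint(0, w, (1,)).item()
        y1, y2 = max(0, y - self.length // 2), min(h, y + self.length // 2)
        x1, x2 = max(0, x - self.length // 2), min(w, x + self.length // 2)
        img = img.clone()
        img[:, y1:y2, x1:x2] = 0.0  # erase the patch
        return img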

Displaying a processed image

# Display an image: https://tianchi.aliyun.com/course/337/4003
from PIL import Image
cat = Image.open(df_train[0][df_train.index[0]])  # open the first training image
print(cat.size)
cat_t = cut(transform_aug(cat))  # transforms take PIL input; cat_t is a tensor
cat_t.shape  # 3*224*224; if T.CenterCrop()'s size exceeds T.Resize()'s, the border is zero-padded
transforms.ToPILImage()(cat_t)

Custom Dataset

# Custom Dataset classes
from torch.utils import data
from PIL import Image
import torch

class fashion_dataset_train(data.Dataset):
    def __init__(self):
        self.df = df_train

    def __getitem__(self, index):
        label = self.df[2][self.df.index[index]]
        img = cut(transform_aug(Image.open(self.df[0][self.df.index[index]]))) # transform first, then Cutout
        return img, label

    def __len__(self):
        return len(self.df)

class fashion_dataset_eval(data.Dataset):
    def __init__(self):
        self.df = df_eval

    def __getitem__(self, index):
        label = self.df[2][self.df.index[index]]
        img = transform_norm(Image.open(self.df[0][self.df.index[index]]))
        return img, label

    def __len__(self):
        return len(self.df)

class fashion_dataset_test(data.Dataset):
    def __init__(self):
        self.df = df_test

    def __getitem__(self, index):
        label = self.df[2][self.df.index[index]]
        img = transform_norm(Image.open(self.df[0][self.df.index[index]]))
        return img, label

    def __len__(self):
        return len(self.df)

# Instantiate the Datasets
dataset_train = fashion_dataset_train() # the custom fashion_dataset_train defined above
print("-training set size = {}".format(len(dataset_train)))
dataset_eval = fashion_dataset_eval() # the custom fashion_dataset_eval defined above
print("-validation set size = {}".format(len(dataset_eval)))
dataset_test = fashion_dataset_test() # the custom fashion_dataset_test defined above
print("-test set size = {}".format(len(dataset_test)))

3 Model Selection

Models available in torchvision.models

| Model | Type | Venue | Paper |
| --- | --- | --- | --- |
| ResNet | CNN | CVPR 2016 | https://arxiv.org/abs/1512.03385 |
| ResNeXt | CNN | CVPR 2017 | https://arxiv.org/abs/1611.05431 |
| MobileNetV2 | Lightweight CNN | CVPR 2018 | https://arxiv.org/abs/1801.04381 |
| MobileNetV3 | Lightweight CNN | ICCV 2019 | https://arxiv.org/abs/1905.02244 |
| ShuffleNetV2 | Lightweight CNN | ECCV 2018 | https://arxiv.org/abs/1807.11164v1 |
| RegNet | Lightweight CNN | CVPR 2020 | https://arxiv.org/pdf/2003.13678.pdf |
| ConvNeXt | CNN | CVPR 2022 | https://arxiv.org/abs/2201.03545 |
| MaxViT | Transformer | ECCV 2022 | https://arxiv.org/pdf/2204.01697.pdf |

Balancing compute cost against accuracy, we ultimately chose the lightweight convolutional network RegNet for this experiment.
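
One way to quantify the compute side of this trade-off (a sketch; the accuracy side still comes from the experiments below) is to compare parameter counts:

# Compare candidate model sizes by parameter count
from torchvision import models

for model_name, ctor in [("resnet50", models.resnet50),
                         ("shufflenet_v2_x1_0", models.shufflenet_v2_x1_0),
                         ("regnet_y_1_6gf", models.regnet_y_1_6gf)]:
    m = ctor(weights=None)
    n_params = sum(p.numel() for p in m.parameters())
    print("{}: {:.1f}M parameters".format(model_name, n_params / 1e6))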

import torch
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Candidate models: uncomment one pair to load its locally downloaded pretrained weights
# model = models.resnet18(weights = None).to(device)
# model.load_state_dict(torch.load('./models/resnet18-f37072fd.pth'))

# model = models.resnet50(weights = None).to(device)
# model.load_state_dict(torch.load('./models/resnet50-0676ba61.pth'))

# model = models.resnext50_32x4d(weights = None).to(device)
# model.load_state_dict(torch.load('./models/resnext50_32x4d-7cdf4587.pth'))

# model = models.mobilenet_v2(weights = None).to(device)
# model.load_state_dict(torch.load('./models/mobilenet_v2-b0353104.pth'))

# model = models.mobilenet_v3_large(weights = None).to(device)
# model.load_state_dict(torch.load('./models/mobilenet_v3_large-8738ca79.pth'))

# model = models.shufflenet_v2_x1_0(weights = None).to(device)
# model.load_state_dict(torch.load('./models/shufflenetv2_x1-5666bf0f80.pth'))

# model = models.shufflenet_v2_x2_0(weights = None).to(device)
# model.load_state_dict(torch.load('./models/shufflenetv2_x2_0-8be3c8ee.pth'))

# model = models.regnet_y_800mf(weights = None).to(device)
# model.load_state_dict(torch.load('./models/regnet_y_800mf-58fc7688.pth'))

model = models.regnet_y_1_6gf(weights = None).to(device)
model.load_state_dict(torch.load('./models/regnet_y_1_6gf-0d7bc02a.pth'))

# model = models.regnet_y_3_2gf(weights = None).to(device)
# model.load_state_dict(torch.load('./models/regnet_y_3_2gf-9180c971.pth'))

# model = models.convnext_tiny(weights = None).to(device)
# model.load_state_dict(torch.load('./models/convnext_tiny-983f1562.pth'))

# model = models.convnext_small(weights = None).to(device)
# model.load_state_dict(torch.load('./models/convnext_small-0c510722.pth'))

# model = models.convnext_base(weights = None).to(device)
# model.load_state_dict(torch.load('./models/convnext_base-6075fbad.pth'))

# model = models.maxvit_t(weights = None).to(device)
# model.load_state_dict(torch.load('./models/maxvit_t-bc5ab103.pth'))

torch.backends.cudnn.benchmark = True
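
Note that the torchvision checkpoints above ship with a 1000-class ImageNet head. Training still runs, because the 0-9 attribute labels are valid indices into 1000 logits, but replacing the head to match the attribute's value count is the usual choice. A sketch for the ResNet/RegNet families, which expose the classifier as .fc (MobileNet and ConvNeXt use .classifier instead):

import torch.nn as nn

num_classes = 9  # depends on the AttrKey; the value counts are in the dataset README
model.fc = nn.Linear(model.fc.in_features, num_classes).to(device)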

4 Training and Validation

| Model | Batch size | Learning rate | Optimizer | Scheduler | Epochs | Best val acc | Step of best val acc | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| resnet18 | 64 | 1e-2 | RMSprop | None | 30 | 0.7192 | 27 | normal |
| resnet50 | 64 | 1e-2 | RMSprop | None | 30 | 0.6332 | 25 | normal |
| resnext50_32x4d | 64 | 1e-2 | RMSprop | None | 30 | 0.7323 | 25 | normal |
| shufflenet_v2_x1_0 | 64 | 1e-2 | RMSprop | None | 30 | 0.7887 | 26 | normal |
| shufflenet_v2_x2_0 | 64 | 5e-3 | RMSprop | None | 30 | 0.7481 | 22 | normal |
| mobilenet_v2 | 64 | 5e-3 | RMSprop | None | 30 | 0.3599 | 5 | overfits: training acc rises while validation acc oscillates |
| mobilenet_v3_large | 64 | 5e-3 | RMSprop | None | 30 | 0.501 | 11 | overfits: training acc rises while validation acc oscillates |
| maxvit | 64 | 5e-3 | RMSprop | None | 15 | 0.1968 | 2 | underfits: training acc rises, validation acc flat; training too expensive |
| convnext_tiny | 64 | 1e-3 | RMSprop | None | 5 | 0.1204 | 1 | training and validation acc/loss oscillate badly even at lower learning rates; training too expensive |
| regnet_y_800mf | 64 | 1e-3 | RMSprop | Cosine, T_max=30 | 90 | 0.8438 | 32 | best acc after cosine annealing was only 0.8266, a small gain for a long training time, so annealing was dropped |
| regnet_y_1_6gf | 128 | 1e-3 | AdamW | Cosine, T_max=30 | 30 | 0.8431 | 30 | best overall, with acceptable training time |

Based on these results, RegNet achieves the highest validation accuracy with acceptable training time, so we selected the pretrained RegNet for the remaining experiments.

The code is as follows:

# Name this training run, used for the log and checkpoint paths
name = "20_regnet_y_1_6gf_lr=1e-3+AdamW+Cosine"

# Create the checkpoint directory that torch.save below writes into
import os
os.makedirs("save/" + name, exist_ok=True)

# Imports
import torch.utils.data
from torch import nn
from torch.utils.tensorboard import SummaryWriter # TensorBoard logging
from torch.cuda.amp import autocast, GradScaler # mixed precision

dataset_train_size = len(dataset_train)
dataset_eval_size = len(dataset_eval)
print("------ training set size: {} ------".format(dataset_train_size))
print("------ validation set size: {} ------".format(dataset_eval_size))

# Data loaders
# Read the datasets in batches of batch_size; shuffle the training data each epoch
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True, num_workers=16, pin_memory=True, prefetch_factor=8, persistent_workers=True)
dataloader_eval = torch.utils.data.DataLoader(dataset_eval, batch_size=128, num_workers=16, pin_memory=True, prefetch_factor=8, persistent_workers=True)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=128, num_workers=16)

# Number of training epochs
EPOCHS = 30

# Cross-entropy loss
loss_fun = nn.CrossEntropyLoss().to(device)

# Optimizer
learning_rate = 1e-3
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Gradient scaler for mixed-precision training (guards against fp16 gradient underflow)
scaler = GradScaler()

# Learning-rate scheduler
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=1) # multiply the learning rate by gamma every step_size epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=EPOCHS, eta_min=1e-6) # T_max is the number of steps in half a cosine period

# Step counters
train_step = 0
eval_step = 0
eval_accuracy_max = 0

# TensorBoard log writer
writer = SummaryWriter("tf-logs/"+name)

# Training and validation loop
for epoch in range(EPOCHS):
    print("------ epoch {} ------".format(epoch + 1))

    # Training
    # put the model in training mode
    model.train()
    # running count of correct training predictions
    train_accuracy = 0
    for data in dataloader_train: # each iteration processes one minibatch
        # unpack the batch into images and labels, and move both to the GPU
        img, label = data
        img = img.to(device)
        label = label.to(device)
        with autocast(): # mixed precision
            # forward pass
            output = model(img)
            # compute the loss
            loss = loss_fun(output, label)

        # Optimization step
        # zero the gradients first
        optimizer.zero_grad()
        # backward pass and parameter update, scaled for mixed precision
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer) # unscale so clipping sees the true gradient magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), 20) # gradient clipping
        scaler.step(optimizer)
        scaler.update()

        train_step += 1

        # log the training loss every 10 steps
        if train_step % 10 == 0:
            writer.add_scalar("train_loss", loss.item(), train_step)
            print("step {}, loss: {}".format(train_step, loss.item()))

        # accumulate the training accuracy
        accuracy = (output.argmax(1) == label).sum()
        train_accuracy += accuracy

    # Validation
    # put the model in evaluation mode
    model.eval()
    # validation loss and accuracy
    loss_eval = 0
    eval_accuracy = 0
    with torch.no_grad():
        for data in dataloader_eval:
            # unpack the batch into images and labels, and move both to the GPU
            img, label = data
            img = img.to(device)
            label = label.to(device)
            with autocast(): # mixed precision
                # forward pass
                output = model(img)
                # compute the loss
                loss = loss_fun(output, label)
            loss_eval += loss.item()
            # accumulate the validation accuracy
            accuracy = (output.argmax(1) == label).sum()
            eval_accuracy += accuracy

    eval_step += 1
    # log the validation loss
    print("epoch {} validation loss: {}".format(epoch + 1, loss_eval))
    writer.add_scalar("eval_loss", loss_eval, eval_step)
    # log the training and validation accuracy
    print("epoch {} training accuracy: {}".format(epoch + 1, train_accuracy / dataset_train_size))
    print("epoch {} validation accuracy: {}".format(epoch + 1, eval_accuracy / dataset_eval_size))
    writer.add_scalar("train_accuracy", train_accuracy / dataset_train_size, eval_step)
    writer.add_scalar("eval_accuracy", eval_accuracy / dataset_eval_size, eval_step)

    # Checkpointing
    # save the model whenever the validation accuracy improves
    if eval_accuracy > eval_accuracy_max:
        eval_accuracy_max = eval_accuracy
        torch.save(model, "save/"+name+"/model_epoch={}_acc={:.2f}%.pth".format(epoch, float(eval_accuracy) / dataset_eval_size * 100))
        print("model saved")

    # update the learning rate
    scheduler.step()

# close the TensorBoard writer
writer.close()
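
The logged curves can then be inspected by pointing TensorBoard at the log directory used above, e.g. `tensorboard --logdir tf-logs`.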

5 Testing

# Testing
model = torch.load('save/16_regnet_y_800mf_lr=1e-3+RMSprop+Cosine.01+Cosine/model_epoch=28_acc=0.0%.pth').to(device) # checkpoint saved by an earlier run
model.eval()
dataset_test_size = len(dataset_test)
# test accuracy
test_accuracy = 0
with torch.no_grad():
    for data in dataloader_test:
        # unpack the batch into images and labels, and move both to the GPU
        img, label = data
        img = img.to(device)
        label = label.to(device)
        # forward pass (must be recomputed for each test batch)
        output = model(img)
        # accumulate the test accuracy
        accuracy = (output.argmax(1) == label).sum()
        test_accuracy += accuracy

# report the test accuracy over the test set
print("test accuracy: {}".format(test_accuracy / dataset_test_size))
Test accuracy: 0.8291

3. Results

This experiment shows that lightweight CNNs such as RegNet already perform well on FashionAI classification: both the validation and test accuracies exceed 80%.

4. Discussion

4.1 Limitations

  1. Accuracy
    After 30 epochs, the RegNet model used here reaches a best validation accuracy of 0.8431 and a test accuracy of 0.8291, leaving room for improvement.
  2. Ignoring the m (maybe) values in AttrValues
    The m (maybe) entries in the training labels are discarded, which may affect accuracy; one way to use them is sketched below.
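
One way to stop discarding the 'm' information (a sketch, not what this experiment did): convert each AttrValues string into a soft probability vector and train against it; nn.CrossEntropyLoss accepts class-probability targets since PyTorch 1.10.

import torch
from torch import nn

# Sketch: map y/m/n to soft targets instead of a hard argmax over 'y'.
# The 0.3 weight for 'm' is an assumed hyperparameter, not from the experiment.
def soft_target(attr_values: str, m_weight: float = 0.3) -> torch.Tensor:
    raw = torch.tensor([1.0 if c == "y" else m_weight if c == "m" else 0.0
                        for c in attr_values])
    return raw / raw.sum()  # normalize to a probability distribution

target = soft_target("nnnnnnnym")  # example row from the annotation table above
# loss = nn.CrossEntropyLoss()(logits, target.unsqueeze(0))  # probabilities as the target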

4.2 Suggested Improvements

  1. Recognize all attribute dimensions with one model. In this experiment, each trained model recognizes only a single attribute dimension, so covering all dimensions requires multiple models. A single multi-head model could handle every dimension; see the sketch after this list.
  2. Improve accuracy. A more complex network, or a Transformer model, might improve recognition accuracy.
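
For suggestion 1, a common pattern is a shared backbone with one classification head per attribute dimension; a sketch (the value counts shown are taken from the annotation examples above, and the full mapping would cover all eight AttrKeys):

import torch
import torch.nn as nn
from torchvision import models

# Sketch: one backbone, one head per AttrKey
ATTR_NUM_CLASSES = {"skirt_length_labels": 6, "sleeve_length_labels": 9}

class MultiHeadFashionNet(nn.Module):
    def __init__(self):
        super().__init__()
        backbone = models.regnet_y_1_6gf(weights=None)
        in_features = backbone.fc.in_features
        backbone.fc = nn.Identity()           # drop the single-task head
        self.backbone = backbone
        self.heads = nn.ModuleDict({k: nn.Linear(in_features, n)
                                    for k, n in ATTR_NUM_CLASSES.items()})

    def forward(self, x, attrkey):
        feat = self.backbone(x)
        return self.heads[attrkey](feat)      # logits for the requested attribute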

5. Conclusion

Based on the experimental results and analysis, we draw the following conclusions:

  1. On the FashionAI dataset, convolutional neural networks (CNNs) perform well on image classification. Training and testing on this dataset yields high classification accuracy, demonstrating the effectiveness of CNNs for apparel image classification.

  2. We experimented with classic CNN architectures such as ResNet, ShuffleNet, and MobileNet. Comparative experiments show that on FashionAI, lightweight CNNs such as RegNet classify very well, combining high accuracy with acceptable training cost.

  3. The results also show that data preprocessing strongly affects classification performance. Steps such as normalization and data augmentation improve classification accuracy and model robustness.

In summary, this image classification experiment on the FashionAI dataset confirms the effectiveness of convolutional neural networks for apparel image classification, and the limitations and improvement directions identified here can guide further research and applications.

