神经网络-MNIST数据集训练

news/2024/10/4 19:07:17/文章来源:https://blog.csdn.net/2301_77698138/article/details/142316164

文章目录

  • 一、MNIST数据集
    • 1.数据集概述
    • 2.数据集组成
    • 3.文件结构
    • 4.数据特点
  • 二、代码实现
    • 1.数据加载与预处理
    • 2. 模型定义
    • 3. 训练和测试函数
    • 4.训练和测试结果
  • 三、总结

一、MNIST数据集

MNIST数据集是深度学习和计算机视觉领域非常经典且基础的数据集,它包含了大量的手写数字图片,通常用于训练各种图像处理系统,也被广泛用于机器学习领域的训练和测试。

1.数据集概述

  • 来源:MNIST数据集由Yann LeCun等人于1994年创建,它是NIST(美国国家标准与技术研究所)数据集的一个子集。
  • 内容:数据集主要包含手写数字(0~9)的图片及其对应的标签。
  • 用途:作为深度学习和计算机视觉领域的入门级数据集,它适合初学者练习建立模型、训练和预测。

2.数据集组成

MNIST数据集总共包含两个子数据集:训练数据集和测试数据集。

训练数据集:

  • 包含了60,000张28x28像素的灰度图像。
  • 对应的标签文件包含了60,000个标签,每个标签对应一张图像中的手写数字。

测试数据集:

  • 包含了10,000张28x28像素的灰度图像。
  • 对应的标签文件包含了10,000个标签。

3.文件结构

MNIST数据集包含四个文件,分别是训练集图像、训练集标签、测试集图像和测试集标签。这些文件以gzip格式压缩,并且不是标准的图像格式,需要通过专门的编程方式读取。

  • 训练集图像:train-images-idx3-ubyte.gz
  • 训练集标签:train-labels-idx1-ubyte.gz)
  • 测试集图像:t10k-images-idx3-ubyte.gz
  • 测试集标签:t10k-labels-idx1-ubyte.gz

4.数据特点

  • 图像大小:每张图像的大小为28x28像素,是一个灰度图像,位深度为8(灰度值范围为0~255)。
  • 数据来源:手写数字来自250个不同的人。
  • 数据格式:图像数据以字节的形式存储在二进制文件中,标签文件则存储了每张图像对应的数字标签。

二、代码实现

1.数据加载与预处理

import torch
from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据
from torchvision import datasets  # 封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor"""下载训练集数据(包含训练图片和标签)"""
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),  # 张量,图片是不能直接传入神经网络模型
)"""下载测试集数据(包括训练图片和标签)"""
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)  # 64张图片为一个包
test_dataloader = DataLoader(test_data, batch_size=64)
  • 下载数据集:使用torchvision.datasets.MNIST下载并加载MNIST数据集。数据集分为训练集和测试集,train=True为训练集数据,train=False为测试集数据。
  • 数据转换:数据通过transform=ToTensor()进行预处理,将图片转换为PyTorch张量(Tensor),并自动将像素值归一化到[0,1]区间。
  • 数据封装:使用DataLoader将数据集封装成批次(batch)形式,便于后续的训练和测试过程。

2. 模型定义

class NeuralNetwork(nn.Module):  # 通过调用类的形式来使用神经网络,神经网络的模型,nn.moduledef __init__(self):  # python基础关于类,self类自己本身super().__init__()  # 继承的父类初始化self.flatten = nn.Flatten()  # 展开,创建一个展开对象flattenself.hidden1 = nn.Linear(28 * 28, 128)  # 第1个参数:有多少个神经元传入进来,第2个参数:有多少个数据传出去前一层神经元的个数,当前本层神经元个数self.hidden2 = nn.Linear(128, 256)self.hidden3 = nn.Linear(256, 128)self.out = nn.Linear(128, 10)def forward(self, x):  # 前向传播,告诉它,数据的流向。x = self.flatten(x)  # 图像进行展开x = self.hidden1(x)x = torch.sigmoid(x) x = self.hidden2(x)x = torch.sigmoid(x)x = self.hidden3(x)x = torch.sigmoid(x)x = self.out(x)return xmodel = NeuralNetwork().to(device)  # 把刚刚创建的模型传入到gpu
print(model)

定义类:定义了一个名为NeuralNetwork的类,该类继承自nn.Module,用于构建神经网络模型。
模型结构:模型包含输入层,输出层,隐藏层,其中隐藏层使用了Sigmoid激活函数,最后输出10个类别的得分(对应0-9的数字)
打印模型结构:打印了模型的结构,有助于理解模型的架构。
在这里插入图片描述

3. 训练和测试函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:  # 其中batch为每一个数据的编号X, y = X.to(device), y.to(device)  # 把训练数据集和标签传入cpu或GPUpred = model.forward(X)  # .forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值loss# Backpropaqation 进来-个bqtch的数据,计算一次梯度,更新一次网络optimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向传播计算得到每个参数的梯度值woptimizer.step()  # 根据梯度更新网络w参数loss_value = loss.item()  # 从tensor数据中提取数据出来,tensor获取损失值if batch_size_num % 100 == 0:print(f"loss:{loss_value:>7f}  [number:{batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 测试,w就不能再更新。test_loss, correct = 0, 0with torch.no_grad():  # 一个上下文管理器,关闭梯度计算。for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()  # test loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)  # dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号b = (pred.argmax(1) == y).type(torch.float)  # 把预测值Ture、False 转换为01test_loss /= num_batches  # 评判模型的好坏correct /= size  # 平均的准确率print(f"Test result:\n Accuracy:{(100 * correct)}%,Avg loss:{test_loss}")
  • train函数负责训练模型。它遍历训练数据集的每个批次,计算模型的预测、损失,并执行反向传播和参数更新。
  • test函数用于评估模型在测试集上的性能。它遍历测试数据集的每个批次,计算模型的预测和损失,但不进行反向传播或参数更新。
  • 在训练和测试过程中,都使用了torch.no_grad()上下文管理器来关闭梯度计算,这可以节省内存和计算资源。

4.训练和测试结果

loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 创建一个优化器,S6D为随机梯度下降算法epochs = 10
for t in range(epochs):print(f"Epoch {t + 1}\n-------------------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)
  • 使用torch.optim.Adam优化器来优化模型的参数,这里的学习率设置为0.01。
  • 定义了训练轮次(epochs),并在每个epoch中调用train函数来训练模型。
  • 最后,使用test函数来评估模型在测试集上的性能,并打印出准确率和平均损失。
    在这里插入图片描述

三、总结

本文为大家介绍了MNIST数据集的组成、文件结构与数据集特点,然后为大家提供了MNIST数据集训练的相关代码,通过对数据集进行处理,训练来得出准确率与损失率,为大家更好的展示。总之,MNIST数据集是深度学习和计算机视觉领域不可或缺的基础数据集之一,对于初学者来说是一个非常好的练手项目,同时也为相关领域的研究和实验提供了宝贵的数据资源。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ldbm.cn/p/443512.html

如若内容造成侵权/违法违规/事实不符,请联系编程新知网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

动态规划:07.路径问题_珠宝的最大价值_C++

题目链接:LCR 166. 珠宝的最高价值 - 力扣(LeetCode)https://leetcode.cn/problems/li-wu-de-zui-da-jie-zhi-lcof/description/ 一、题目解析 题目: 解析: 有过做前几道题的经验,我们会发现这道题其实就…

MySQL篇(窗口函数/公用表达式(CTE))

目录 讲解一:窗口函数 一、简介 二、常见操作 1. sumgroup by常规的聚合函数操作 2. sum窗口函数的聚合操作 三、基本语法 1. Function(arg1,..., argn) 1.1. 聚合函数 sum函数:求和 min函数 :最小值 1.2. 排序函数 1.3. 跨行函数…

markdown 使用技巧

文章目录 markdown使用技巧1.标题快捷键设置2.文档可读性设置 markdown使用技巧 1.标题快捷键设置 ctl 1:一级标题 ctl 2:二级标题 ctl 3:三级标题 ctl 4:四级标题 ...2.文档可读性设置 输入~~~pro 可选择代码框,并且可以选择不同的字体 ctrl shift ] : 可…

基于MicroPython的ESP32控制LED灯闪烁设计方案的Wokwi仿真

以下是一个基于MicroPython的ESP32控制LED灯闪烁设计方案的Wokwi仿真: 一、硬件准备: 在Wokwi仿真平台(https://wokwi.com/)选择ESP32开发板,添加一个LED灯,和一个220欧姆限流电阻。 二、硬件连接: 1. 将LED灯的阳极…

gin配置swagger文档

一、基本准备工作 1、安装依赖包 go get -u github.com/swaggo/swag/cmd/swag go get -u github.com/swaggo/gin-swagger go get -u github.com/swaggo/files2、在根目录上配置swagger的路由文件 //2.初始化路由router : initialize.Routers()// 配置swaggerdocs.SwaggerInfo…

Linux进程等待 | 程序替换

进程终止 一个进程退出了,无非只有三种情况: 代码跑完了,结果正确代码跑完了,结果不正确代码没跑完,程序异常退出了 代码跑完了,我们可以通过退出码获取其结果是否正确,(这个退出…

初始爬虫6

数据提取 数据提取总结 响应分类 结构化 json数据(高频出现) json模块 jsonpath模块 xml数据(低频出现) re模块 …

【OJ刷题】双指针问题6

这里是阿川的博客,祝您变得更强 ✨ 个人主页:在线OJ的阿川 💖文章专栏:OJ刷题入门到进阶 🌏代码仓库: 写在开头 现在您看到的是我的结论或想法,但在这背后凝结了大量的思考、经验和讨论 目录 1…

Rust使用Actix-web和SeaORM库开发WebAPI通过Swagger UI查看接口文档

本文将介绍Rust语言使用Actix-web和SeaORM库,数据库使用PostgreSQL,开发增删改查项目,同时可以通过Swagger UI查看接口文档和查看标准Rust文档 开始项目 首先创建新项目,名称为rusty_crab_api cargo new rusty_crab_apiCargo.t…

中标喜讯!湖北产教融合教育研究院携手湖北医药学院,共筑同等学力申硕新篇章

在深化教育改革、推动产教融合的大潮中,湖北产教融合教育研究院再传捷报!其控股子公司——武汉产教融汇教育科技有限公司,凭借卓越的技术研发实力、丰富的教育资源储备及高效的运营管理能力,成功中标湖北医药学院同等学力申硕工作…

Windows下SDL2创建最简单的一个窗口

先看运行效果 再上代码&#xff1a; #include <stdio.h> #include "SDL.h"int main(int argc, char* argv[]) {// 初始化SDL视频子系统if (SDL_Init(SDL_INIT_VIDEO) -1){printf("Error: %s\n", SDL_GetError());return -1;} // 创建一个窗口SDL_…

通过防火墙分段增强网络安全

什么是网络分段‌ 随着组织规模的扩大&#xff0c;管理一个不断扩大的网络成为一件棘手的事情&#xff0c;同时确保安全性、合规性、性能和不间断的运行可能是一项艰巨的任务。为了克服这一挑战&#xff0c;网络管理员部署了网络分段&#xff0c;这是一种将网络划分为更小且易…

nvm无法下载npm的问题

1、问题 执行 nvm install 14.21.3 命令&#xff0c;node可以正常下载成功&#xff0c;npm下载失败 2、nvm配置信息 …/nvm/settings.txt root: D:\soft\nvm path: D:\soft\nodejs node_mirror: npmmirror.com/mirrors/node/ npm_mirror: registry.npmmirror.com/mirrors/…

【设计模式-外观】

这里写自定义目录标题 定义UML图角色作用代码使用场景 定义 为子系统中一组相关接口提供一致界面&#xff0c;定义一个高级接口&#xff0c;使得子系统更加容易使用。 UML图 角色作用 外观&#xff08;Facade&#xff09;角色&#xff1a;这是外观模式的核心&#xff0c;它知…

论文阅读: SigLit | SigLip |Sigmoid Loss for Language Image Pre-Training

论文地址&#xff1a;https://arxiv.org/pdf/2303.15343 项目地址&#xff1a;https://github.com/google-research/big_vision 发表时间&#xff1a;2023年3月27日 我们提出了一种用于语言图像预训练&#xff08;SigLIP&#xff09;的简单成对 Sigmoid 损失。与使用 softmax …

避免服务器安装多个mysql引起冲突的安装方法

最近工作中涉及到了数据迁移的工作. 需要升级mysql版本到8.4.2为了避免升级后服务出现异常, 因此需要保留原来的mysql,所以会出现一台服务器上运行两个mysql的情况 mysql并不陌生, 但是安装不当很容易引起服务配置文件的冲突,导致服务不可用, 今天就来介绍一种可以完美避免冲突…

linux网络编程1

24.9.16学习目录 一.TCP/IP协议简介1.TCP/IP的分层结构2.协议的简介 二、MAC地址和IP地址1.网卡2.MAC地址3.IP地址&#xff08;1&#xff09;IP地址的分类&#xff08;2&#xff09;IP地址的特点&#xff08;3&#xff09;回环IP地址 3.子网掩码4.端口&#xff08;1&#xff09…

7天速成前端 ------学习日志 (继苍穹外卖之后)

前端速成计划总结&#xff1a; 全26h课程&#xff0c;包含html&#xff0c;css&#xff0c;js&#xff0c;vue3&#xff0c;预计7天内学完。 起始日期&#xff1a;9.16 预计截止&#xff1a;9.22 每日更新&#xff0c;学完为止。 学前计划 课…

Linux操作系统入门(五)

————————————————————————————————————————— 至此&#xff0c;大部分Linux操作系统的文件操作指令已经总结完成&#xff0c;最后还需进行vim编辑器的使用 使用方法&#xff1a;在FinalShell终端中输入"vim [文件]",以下图…

【CPP】模板(后篇)

目录 13.1 非类型模板参数13.2 函数模板的特化13.3 类模板的特化13.4 模板的分离编译 这里是oldking呐呐,感谢阅读口牙!先赞后看,养成习惯! 个人主页:oldking呐呐 专栏主页:深入CPP语法口牙 13.1 非类型模板参数 顾名思义,非类型模板参数就是一个模板的参数,只不过不是类型,而…