python实战(十五)——中文手写体数字图像CNN分类
一、任务背景
本次python实战,我们使用来自Kaggle的数据集《Chinese MNIST》进行CNN分类建模,不同于经典的MNIST数据集,我们这次使用的数据集是汉字手写体数字。除了常规的汉字“零”到“九”之外还多了“十”、“百”、“千”、“万”、“亿”,共15种汉字数字。
二、python建模
1、数据读取
首先,读取jpg数据文件,可以看到总共有15000张图像数据。
import pandas as pd
import ospath = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))
我们也可以打印一张图片出来看看。
import matplotlib.pyplot as plt
import matplotlib.image as mpimg# 定义图片路径
image_path = path+files[3]# 加载图片
image = mpimg.imread(image_path)# 绘制图片
plt.figure(figsize=(3, 3))
plt.imshow(image)
plt.axis('off') # 关闭坐标轴
plt.show()
2、数据集构建
加载必要的库以便后续使用,再定义一些超参数。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_score# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
这里,我们看一看数据集介绍就会知道图片名称及其含义,需要从chinese_mnist.csv文件中根据图片名称中的几个数字来确定图片对应的标签。
# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label_from_index(filename, index_df):suite_id, sample_id, code = map(int, filename.split('.')[0].split('_')[1:])return index_df.loc[(suite_id, sample_id, code), 'value']# 构建value值对应的标签序号,用于模型训练
label_dic = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 100:11, 1000:12, 10000:13, 100000000:14}
# 获取所有图片的标签并转化为标签序号
all_labels = [get_label_from_index(os.path.basename(img), index_df) for img in all_images]
all_labels = [label_dic[li] for li in all_labels]# 将图片路径和标签分成训练集和测试集
train_images, test_images, train_labels, test_labels = train_test_split(all_images, all_labels, test_size=0.2, random_state=2024)
下面定义数据集类并完成数据的加载。
# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = Image.open(self.image_paths[idx]).convert('L') # 转换为灰度图像label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 创建训练集和测试集数据集
train_dataset = CustomDataset(train_images, train_labels, transform=transform)
test_dataset = CustomDataset(test_images, test_labels, transform=transform)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 打印一些信息
print(f'训练集样本数: {len(train_dataset)}')
print(f'测试集样本数: {len(test_dataset)}')
3、模型构建
我们构建一个包含两层卷积层和池化层的CNN并且在池化层中使用最大池化的方式。
# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 16 * 16, 128)self.fc2 = nn.Linear(128, 15)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 16 * 16)x = F.relu(self.fc1(x))x = self.fc2(x)return x
4、模型实例化及训练
下面我们对模型进行实例化并定义criterion和optimizer。
# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
定义训练的代码并调用代码训练模型。
from tqdm import tqdm
# 训练模型
def train(model, train_loader, criterion, optimizer, epochs):model.train()running_loss = 0.0for epoch in range(epochs):for data, target in tqdm(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}], Loss: {running_loss / len(train_loader):.4f}')running_loss = 0.0train(model, train_loader, criterion, optimizer, num_epochs)
5、测试模型
定义模型测试代码,调用代码看指标可知我们所构建的CNN模型表现还不错。
# 测试模型
def test(model, test_loader, criterion):model.eval()test_loss = 0correct = 0all_preds = []all_targets = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_preds.extend(pred.cpu().numpy())all_targets.extend(target.cpu().numpy())test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)precision = precision_score(all_targets, all_preds, average='macro')recall = recall_score(all_targets, all_preds, average='macro')f1 = f1_score(all_targets, all_preds, average='macro')print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')test(model, test_loader, criterion)
三、完整代码
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_scorepath = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label_from_index(filename, index_df):suite_id, sample_id, code = map(int, filename.split('.')[0].split('_')[1:])return index_df.loc[(suite_id, sample_id, code), 'value']# 构建value值对应的标签序号,用于模型训练
label_dic = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 100:11, 1000:12, 10000:13, 100000000:14}# 获取所有图片的标签并转化为标签序号
all_labels = [get_label_from_index(os.path.basename(img), index_df) for img in all_images]
all_labels = [label_dic[li] for li in all_labels]# 将图片路径和标签分成训练集和测试集
train_images, test_images, train_labels, test_labels = train_test_split(all_images, all_labels, test_size=0.2, random_state=2024)# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = Image.open(self.image_paths[idx]).convert('L') # 转换为灰度图像label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 创建训练集和测试集数据集
train_dataset = CustomDataset(train_images, train_labels, transform=transform)
test_dataset = CustomDataset(test_images, test_labels, transform=transform)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 打印信息
print(f'训练集样本数: {len(train_dataset)}')
print(f'测试集样本数: {len(test_dataset)}')# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 16 * 16, 128)self.fc2 = nn.Linear(128, 15)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 16 * 16)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)# 训练模型
def train(model, train_loader, criterion, optimizer, epochs):model.train()running_loss = 0.0for epoch in range(epochs):for data, target in tqdm(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}], Loss: {running_loss / len(train_loader):.4f}')running_loss = 0.0train(model, train_loader, criterion, optimizer, num_epochs)# 测试模型
def test(model, test_loader, criterion):model.eval()test_loss = 0correct = 0all_preds = []all_targets = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_preds.extend(pred.cpu().numpy())all_targets.extend(target.cpu().numpy())test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)precision = precision_score(all_targets, all_preds, average='macro')recall = recall_score(all_targets, all_preds, average='macro')f1 = f1_score(all_targets, all_preds, average='macro')print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')test(model, test_loader, criterion)
四、总结
本文基于汉字手写体数字图像进行了CNN分类实战,CNN作为图像处理的经典模型,展现出了它强大的图像特征提取能力,结合更加复杂的模型框架CNN还可用于高精度人脸识别、物体识别等任务中。
相关文章:

python实战(十五)——中文手写体数字图像CNN分类
一、任务背景 本次python实战,我们使用来自Kaggle的数据集《Chinese MNIST》进行CNN分类建模,不同于经典的MNIST数据集,我们这次使用的数据集是汉字手写体数字。除了常规的汉字“零”到“九”之外还多了“十”、“百”、“千”、“万”、“亿…...

[论文阅读] (37)CCS21 DeepAID:基于深度学习的异常检测(解释)
祝大家新春快乐,蛇年吉祥! 《娜璋带你读论文》系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢。由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正࿰…...

Linux - 进程间通信(2)
目录 2、进程池 1)理解进程池 2)进程池的实现 整体框架: a. 加载任务 b. 先描述,再组织 I. 先描述 II. 再组织 c. 创建信道和子进程 d. 通过channel控制子进程 e. 回收管道和子进程 问题1: 解答1ÿ…...
Kafka 消费端反复 Rebalance: `Attempt to heartbeat failed since group is rebalancing`
文章目录 Kafka 消费端反复 Rebalance: Attempt to heartbeat failed since group is rebalancing1. Rebalance 过程概述2. 错误原因分析2.1 消费者组频繁加入或退出2.1.1 消费者故障导致频繁重启2.1.2. 消费者加入和退出导致的 Rebalance2.1.3 消费者心跳超时导致的 Rebalance…...

SpringBoot+Electron教务管理系统 附带详细运行指导视频
文章目录 一、项目演示二、项目介绍三、运行截图四、主要代码1.查询课程表代码2.保存学生信息代码3.用户登录代码 一、项目演示 项目演示地址: 视频地址 二、项目介绍 项目描述:这是一个基于SpringBootElectron框架开发的教务管理系统。首先ÿ…...

操作系统(Linux Kernel 0.11Linux Kernel 0.12)解读整理——内核初始化(main init)之控制台工作
前言 在 Linux 内核中,字符设备主要包括控制终端设备和串行终端设备,对这些设备的输入输出涉及控制台驱动程序,这包括键盘中断驱动程序 keyboard.S 和控制台显示驱动程序 console.c,还有终端驱动程序与上层程序之间的接口部分。 终端驱动程序…...
Autogen_core: Message and Communication
目录 完整代码代码解释1. 消息的数据类:2. 创建代理人(MyAgent):3. 创建和运行代理人的运行时环境:4. 根据发送者路由消息的代理(RoutedBySenderAgent):5. 创建和运行带路由的代理&a…...
ComfyUI工作流教程、软件使用、开发指导、模型下载
在人工智能和设计技术迅速发展的今天,AI赋能的工作流已成为创意设计与生产的重要工具。无论是图片处理、服装试穿,还是室内设计与3D建模,这些智能化的解决方案极大地提高了效率和创作质量。 为了帮助设计师、开发者以及AI技术爱好者更好地利用这些工具,我们整理了一份详尽…...

零基础Vue学习1——Vue学习前环境准备
目录 环境准备 创建Vue项目 项目目录说明 后续开发过程中常用命令 环境准备 安装开发工具:vscode、webstorm、idea都可以安装node:V22以上版本即可安装pnpm 不知道怎么安装的可以私信我教你方法 创建Vue项目 本地新建一个文件夹,之后在文件夹下打开…...

定西市建筑房屋轮廓数据shp格式gis无偏移坐标(字段有高度和楼层)内容测评
定西市建筑房屋轮廓数据是GIS(Geographic Information System,地理信息系统)领域的重要资源,用于城市规划、土地管理、环境保护等多个方面。这份2022年的数据集采用shp(Shapefile)格式,这是一种…...

汉语向编程指南
汉语向编程指南 一、引言王阳明代数与流形学习理论慢道缓行理性人类型指标系统为己之学与意气实体过程晏殊几何学半可分离相如矩阵与生成气质邻域镶嵌气度曲面细分生成气质邻域镶嵌气度曲面细分社会科学概论琴生生物机械科技工业研究所软凝聚态物理开发工具包琴生生物机械 报告…...
Writing an Efficient Vulkan Renderer
本文出自GPU Zen 2。 Vulkan 是一个新的显式跨平台图形 API。它引入了许多新概念,即使是经验丰富的图形程序员也可能不熟悉。Vulkan 的主要目标是性能——然而,获得良好的性能需要深入了解这些概念及其高效应用方法,以及特定驱动程序实现的实…...
AI常见的算法
人工智能(AI)中常见的算法分为多个领域,如机器学习、深度学习、强化学习、自然语言处理和计算机视觉等。以下是一些常见的算法及其用途: 1. 机器学习 (Machine Learning) 监督学习 (Supervised Learning) 线性回归 (Linear Regr…...

LibreChat
文章目录 一、关于 LibreChat✨特点 二、使用LibreChat🪶多合一AI对话 一、关于 LibreChat LibreChat 是增强的ChatGPT克隆:Features Agents, Anthropic, AWS, OpenAI, Assistants API, Azure, Groq, o1, GPT-4o, Mistral, OpenRouter, Vertex AI, Gemi…...

Spring Boot 日志:项目的“行车记录仪”
一、什么是Spring Boot日志 (一)日志引入 在正式介绍日志之前,我们先来看看上篇文章中(Spring Boot 配置文件)中的验证码功能的一个代码片段: 这是一段校验用户输入的验证码是否正确的后端代码,…...

Spring Boot 实现文件上传和下载
文章目录 Spring Boot 实现文件上传和下载一、引言二、文件上传1、配置Spring Boot项目2、创建文件上传控制器3、配置文件上传大小限制 三、文件下载1、创建文件下载控制器 四、使用示例1、文件上传2、文件下载 五、总结 Spring Boot 实现文件上传和下载 一、引言 在现代Web应…...

慕课:若鱼1919的视频课程:Java秒杀系统方案优化 高性能高并发实战,启动文档
代码: Javahhhh/miaosha191: 运行成功了慕课若鱼1919的视频课程:Java秒杀系统方案优化 高性能高并发实战https://github.com/Javahhhh/miaosha191 https://github.com/Javahhhh/miaosha191 miaosha项目启动文档 需安装的配置环境: VMwar…...

React第二十七章(Suspense)
Suspense Suspense 是一种异步渲染机制,其核心理念是在组件加载或数据获取过程中,先展示一个占位符(loading state),从而实现更自然流畅的用户界面更新体验。 应用场景 异步组件加载:通过代码分包实现组件…...
虚幻基础08:组件接口
能帮到你的话,就给个赞吧 😘 文章目录 作用 作用 组件接口:可以直接调用对方的组件接口,而无需转换为actor。 实现对象间的通知。 A 通知 B 做什么。...

iPhone SE(第三代) 设备详情图
目录 产品宣传图内部图——后设备详细信息 产品宣传图 内部图——后 设备详细信息 信息收集于HubWeb.cn...

智慧医疗能源事业线深度画像分析(上)
引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...

python打卡day49
知识点回顾: 通道注意力模块复习空间注意力模块CBAM的定义 作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程 import torch import torch.nn as nn# 定义通道注意力 class ChannelAttention(nn.Module):def __init__(self,…...

基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真
目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销,平衡网络负载,延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...

《从零掌握MIPI CSI-2: 协议精解与FPGA摄像头开发实战》-- CSI-2 协议详细解析 (一)
CSI-2 协议详细解析 (一) 1. CSI-2层定义(CSI-2 Layer Definitions) 分层结构 :CSI-2协议分为6层: 物理层(PHY Layer) : 定义电气特性、时钟机制和传输介质(导线&#…...

dedecms 织梦自定义表单留言增加ajax验证码功能
增加ajax功能模块,用户不点击提交按钮,只要输入框失去焦点,就会提前提示验证码是否正确。 一,模板上增加验证码 <input name"vdcode"id"vdcode" placeholder"请输入验证码" type"text&quo…...
渲染学进阶内容——模型
最近在写模组的时候发现渲染器里面离不开模型的定义,在渲染的第二篇文章中简单的讲解了一下关于模型部分的内容,其实不管是方块还是方块实体,都离不开模型的内容 🧱 一、CubeListBuilder 功能解析 CubeListBuilder 是 Minecraft Java 版模型系统的核心构建器,用于动态创…...
linux 错误码总结
1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...

CocosCreator 之 JavaScript/TypeScript和Java的相互交互
引擎版本: 3.8.1 语言: JavaScript/TypeScript、C、Java 环境:Window 参考:Java原生反射机制 您好,我是鹤九日! 回顾 在上篇文章中:CocosCreator Android项目接入UnityAds 广告SDK。 我们简单讲…...

如何理解 IP 数据报中的 TTL?
目录 前言理解 前言 面试灵魂一问:说说对 IP 数据报中 TTL 的理解?我们都知道,IP 数据报由首部和数据两部分组成,首部又分为两部分:固定部分和可变部分,共占 20 字节,而即将讨论的 TTL 就位于首…...
【HarmonyOS 5 开发速记】如何获取用户信息(头像/昵称/手机号)
1.获取 authorizationCode: 2.利用 authorizationCode 获取 accessToken:文档中心 3.获取手机:文档中心 4.获取昵称头像:文档中心 首先创建 request 若要获取手机号,scope必填 phone,permissions 必填 …...