python实战(十五)——中文手写体数字图像CNN分类
一、任务背景
本次python实战,我们使用来自Kaggle的数据集《Chinese MNIST》进行CNN分类建模,不同于经典的MNIST数据集,我们这次使用的数据集是汉字手写体数字。除了常规的汉字“零”到“九”之外还多了“十”、“百”、“千”、“万”、“亿”,共15种汉字数字。

二、python建模
1、数据读取
首先,读取jpg数据文件,可以看到总共有15000张图像数据。
import pandas as pd
import ospath = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))

我们也可以打印一张图片出来看看。
import matplotlib.pyplot as plt
import matplotlib.image as mpimg# 定义图片路径
image_path = path+files[3]# 加载图片
image = mpimg.imread(image_path)# 绘制图片
plt.figure(figsize=(3, 3))
plt.imshow(image)
plt.axis('off') # 关闭坐标轴
plt.show()

2、数据集构建
加载必要的库以便后续使用,再定义一些超参数。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_score# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
这里,我们看一看数据集介绍就会知道图片名称及其含义,需要从chinese_mnist.csv文件中根据图片名称中的几个数字来确定图片对应的标签。

# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label_from_index(filename, index_df):suite_id, sample_id, code = map(int, filename.split('.')[0].split('_')[1:])return index_df.loc[(suite_id, sample_id, code), 'value']# 构建value值对应的标签序号,用于模型训练
label_dic = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 100:11, 1000:12, 10000:13, 100000000:14}
# 获取所有图片的标签并转化为标签序号
all_labels = [get_label_from_index(os.path.basename(img), index_df) for img in all_images]
all_labels = [label_dic[li] for li in all_labels]# 将图片路径和标签分成训练集和测试集
train_images, test_images, train_labels, test_labels = train_test_split(all_images, all_labels, test_size=0.2, random_state=2024)
下面定义数据集类并完成数据的加载。
# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = Image.open(self.image_paths[idx]).convert('L') # 转换为灰度图像label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 创建训练集和测试集数据集
train_dataset = CustomDataset(train_images, train_labels, transform=transform)
test_dataset = CustomDataset(test_images, test_labels, transform=transform)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 打印一些信息
print(f'训练集样本数: {len(train_dataset)}')
print(f'测试集样本数: {len(test_dataset)}')
3、模型构建
我们构建一个包含两层卷积层和池化层的CNN并且在池化层中使用最大池化的方式。
# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 16 * 16, 128)self.fc2 = nn.Linear(128, 15)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 16 * 16)x = F.relu(self.fc1(x))x = self.fc2(x)return x
4、模型实例化及训练
下面我们对模型进行实例化并定义criterion和optimizer。
# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

定义训练的代码并调用代码训练模型。
from tqdm import tqdm
# 训练模型
def train(model, train_loader, criterion, optimizer, epochs):model.train()running_loss = 0.0for epoch in range(epochs):for data, target in tqdm(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}], Loss: {running_loss / len(train_loader):.4f}')running_loss = 0.0train(model, train_loader, criterion, optimizer, num_epochs)

5、测试模型
定义模型测试代码,调用代码看指标可知我们所构建的CNN模型表现还不错。
# 测试模型
def test(model, test_loader, criterion):model.eval()test_loss = 0correct = 0all_preds = []all_targets = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_preds.extend(pred.cpu().numpy())all_targets.extend(target.cpu().numpy())test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)precision = precision_score(all_targets, all_preds, average='macro')recall = recall_score(all_targets, all_preds, average='macro')f1 = f1_score(all_targets, all_preds, average='macro')print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')test(model, test_loader, criterion)

三、完整代码
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_scorepath = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label_from_index(filename, index_df):suite_id, sample_id, code = map(int, filename.split('.')[0].split('_')[1:])return index_df.loc[(suite_id, sample_id, code), 'value']# 构建value值对应的标签序号,用于模型训练
label_dic = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 100:11, 1000:12, 10000:13, 100000000:14}# 获取所有图片的标签并转化为标签序号
all_labels = [get_label_from_index(os.path.basename(img), index_df) for img in all_images]
all_labels = [label_dic[li] for li in all_labels]# 将图片路径和标签分成训练集和测试集
train_images, test_images, train_labels, test_labels = train_test_split(all_images, all_labels, test_size=0.2, random_state=2024)# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = Image.open(self.image_paths[idx]).convert('L') # 转换为灰度图像label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 创建训练集和测试集数据集
train_dataset = CustomDataset(train_images, train_labels, transform=transform)
test_dataset = CustomDataset(test_images, test_labels, transform=transform)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 打印信息
print(f'训练集样本数: {len(train_dataset)}')
print(f'测试集样本数: {len(test_dataset)}')# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 16 * 16, 128)self.fc2 = nn.Linear(128, 15)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 16 * 16)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)# 训练模型
def train(model, train_loader, criterion, optimizer, epochs):model.train()running_loss = 0.0for epoch in range(epochs):for data, target in tqdm(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}], Loss: {running_loss / len(train_loader):.4f}')running_loss = 0.0train(model, train_loader, criterion, optimizer, num_epochs)# 测试模型
def test(model, test_loader, criterion):model.eval()test_loss = 0correct = 0all_preds = []all_targets = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_preds.extend(pred.cpu().numpy())all_targets.extend(target.cpu().numpy())test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)precision = precision_score(all_targets, all_preds, average='macro')recall = recall_score(all_targets, all_preds, average='macro')f1 = f1_score(all_targets, all_preds, average='macro')print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')test(model, test_loader, criterion)
四、总结
本文基于汉字手写体数字图像进行了CNN分类实战,CNN作为图像处理的经典模型,展现出了它强大的图像特征提取能力,结合更加复杂的模型框架CNN还可用于高精度人脸识别、物体识别等任务中。
相关文章:
python实战(十五)——中文手写体数字图像CNN分类
一、任务背景 本次python实战,我们使用来自Kaggle的数据集《Chinese MNIST》进行CNN分类建模,不同于经典的MNIST数据集,我们这次使用的数据集是汉字手写体数字。除了常规的汉字“零”到“九”之外还多了“十”、“百”、“千”、“万”、“亿…...
[论文阅读] (37)CCS21 DeepAID:基于深度学习的异常检测(解释)
祝大家新春快乐,蛇年吉祥! 《娜璋带你读论文》系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢。由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正࿰…...
Linux - 进程间通信(2)
目录 2、进程池 1)理解进程池 2)进程池的实现 整体框架: a. 加载任务 b. 先描述,再组织 I. 先描述 II. 再组织 c. 创建信道和子进程 d. 通过channel控制子进程 e. 回收管道和子进程 问题1: 解答1ÿ…...
Kafka 消费端反复 Rebalance: `Attempt to heartbeat failed since group is rebalancing`
文章目录 Kafka 消费端反复 Rebalance: Attempt to heartbeat failed since group is rebalancing1. Rebalance 过程概述2. 错误原因分析2.1 消费者组频繁加入或退出2.1.1 消费者故障导致频繁重启2.1.2. 消费者加入和退出导致的 Rebalance2.1.3 消费者心跳超时导致的 Rebalance…...
SpringBoot+Electron教务管理系统 附带详细运行指导视频
文章目录 一、项目演示二、项目介绍三、运行截图四、主要代码1.查询课程表代码2.保存学生信息代码3.用户登录代码 一、项目演示 项目演示地址: 视频地址 二、项目介绍 项目描述:这是一个基于SpringBootElectron框架开发的教务管理系统。首先ÿ…...
操作系统(Linux Kernel 0.11Linux Kernel 0.12)解读整理——内核初始化(main init)之控制台工作
前言 在 Linux 内核中,字符设备主要包括控制终端设备和串行终端设备,对这些设备的输入输出涉及控制台驱动程序,这包括键盘中断驱动程序 keyboard.S 和控制台显示驱动程序 console.c,还有终端驱动程序与上层程序之间的接口部分。 终端驱动程序…...
Autogen_core: Message and Communication
目录 完整代码代码解释1. 消息的数据类:2. 创建代理人(MyAgent):3. 创建和运行代理人的运行时环境:4. 根据发送者路由消息的代理(RoutedBySenderAgent):5. 创建和运行带路由的代理&a…...
ComfyUI工作流教程、软件使用、开发指导、模型下载
在人工智能和设计技术迅速发展的今天,AI赋能的工作流已成为创意设计与生产的重要工具。无论是图片处理、服装试穿,还是室内设计与3D建模,这些智能化的解决方案极大地提高了效率和创作质量。 为了帮助设计师、开发者以及AI技术爱好者更好地利用这些工具,我们整理了一份详尽…...
零基础Vue学习1——Vue学习前环境准备
目录 环境准备 创建Vue项目 项目目录说明 后续开发过程中常用命令 环境准备 安装开发工具:vscode、webstorm、idea都可以安装node:V22以上版本即可安装pnpm 不知道怎么安装的可以私信我教你方法 创建Vue项目 本地新建一个文件夹,之后在文件夹下打开…...
定西市建筑房屋轮廓数据shp格式gis无偏移坐标(字段有高度和楼层)内容测评
定西市建筑房屋轮廓数据是GIS(Geographic Information System,地理信息系统)领域的重要资源,用于城市规划、土地管理、环境保护等多个方面。这份2022年的数据集采用shp(Shapefile)格式,这是一种…...
汉语向编程指南
汉语向编程指南 一、引言王阳明代数与流形学习理论慢道缓行理性人类型指标系统为己之学与意气实体过程晏殊几何学半可分离相如矩阵与生成气质邻域镶嵌气度曲面细分生成气质邻域镶嵌气度曲面细分社会科学概论琴生生物机械科技工业研究所软凝聚态物理开发工具包琴生生物机械 报告…...
Writing an Efficient Vulkan Renderer
本文出自GPU Zen 2。 Vulkan 是一个新的显式跨平台图形 API。它引入了许多新概念,即使是经验丰富的图形程序员也可能不熟悉。Vulkan 的主要目标是性能——然而,获得良好的性能需要深入了解这些概念及其高效应用方法,以及特定驱动程序实现的实…...
AI常见的算法
人工智能(AI)中常见的算法分为多个领域,如机器学习、深度学习、强化学习、自然语言处理和计算机视觉等。以下是一些常见的算法及其用途: 1. 机器学习 (Machine Learning) 监督学习 (Supervised Learning) 线性回归 (Linear Regr…...
LibreChat
文章目录 一、关于 LibreChat✨特点 二、使用LibreChat🪶多合一AI对话 一、关于 LibreChat LibreChat 是增强的ChatGPT克隆:Features Agents, Anthropic, AWS, OpenAI, Assistants API, Azure, Groq, o1, GPT-4o, Mistral, OpenRouter, Vertex AI, Gemi…...
Spring Boot 日志:项目的“行车记录仪”
一、什么是Spring Boot日志 (一)日志引入 在正式介绍日志之前,我们先来看看上篇文章中(Spring Boot 配置文件)中的验证码功能的一个代码片段: 这是一段校验用户输入的验证码是否正确的后端代码,…...
Spring Boot 实现文件上传和下载
文章目录 Spring Boot 实现文件上传和下载一、引言二、文件上传1、配置Spring Boot项目2、创建文件上传控制器3、配置文件上传大小限制 三、文件下载1、创建文件下载控制器 四、使用示例1、文件上传2、文件下载 五、总结 Spring Boot 实现文件上传和下载 一、引言 在现代Web应…...
慕课:若鱼1919的视频课程:Java秒杀系统方案优化 高性能高并发实战,启动文档
代码: Javahhhh/miaosha191: 运行成功了慕课若鱼1919的视频课程:Java秒杀系统方案优化 高性能高并发实战https://github.com/Javahhhh/miaosha191 https://github.com/Javahhhh/miaosha191 miaosha项目启动文档 需安装的配置环境: VMwar…...
React第二十七章(Suspense)
Suspense Suspense 是一种异步渲染机制,其核心理念是在组件加载或数据获取过程中,先展示一个占位符(loading state),从而实现更自然流畅的用户界面更新体验。 应用场景 异步组件加载:通过代码分包实现组件…...
虚幻基础08:组件接口
能帮到你的话,就给个赞吧 😘 文章目录 作用 作用 组件接口:可以直接调用对方的组件接口,而无需转换为actor。 实现对象间的通知。 A 通知 B 做什么。...
iPhone SE(第三代) 设备详情图
目录 产品宣传图内部图——后设备详细信息 产品宣传图 内部图——后 设备详细信息 信息收集于HubWeb.cn...
AI-调查研究-01-正念冥想有用吗?对健康的影响及科学指南
点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...
el-switch文字内置
el-switch文字内置 效果 vue <div style"color:#ffffff;font-size:14px;float:left;margin-bottom:5px;margin-right:5px;">自动加载</div> <el-switch v-model"value" active-color"#3E99FB" inactive-color"#DCDFE6"…...
从零实现STL哈希容器:unordered_map/unordered_set封装详解
本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说,直接开始吧! 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...
深度学习习题2
1.如果增加神经网络的宽度,精确度会增加到一个特定阈值后,便开始降低。造成这一现象的可能原因是什么? A、即使增加卷积核的数量,只有少部分的核会被用作预测 B、当卷积核数量增加时,神经网络的预测能力会降低 C、当卷…...
毫米波雷达基础理论(3D+4D)
3D、4D毫米波雷达基础知识及厂商选型 PreView : https://mp.weixin.qq.com/s/bQkju4r6med7I3TBGJI_bQ 1. FMCW毫米波雷达基础知识 主要参考博文: 一文入门汽车毫米波雷达基本原理 :https://mp.weixin.qq.com/s/_EN7A5lKcz2Eh8dLnjE19w 毫米波雷达基础…...
解析奥地利 XARION激光超声检测系统:无膜光学麦克风 + 无耦合剂的技术协同优势及多元应用
在工业制造领域,无损检测(NDT)的精度与效率直接影响产品质量与生产安全。奥地利 XARION开发的激光超声精密检测系统,以非接触式光学麦克风技术为核心,打破传统检测瓶颈,为半导体、航空航天、汽车制造等行业提供了高灵敏…...
OD 算法题 B卷【正整数到Excel编号之间的转换】
文章目录 正整数到Excel编号之间的转换 正整数到Excel编号之间的转换 excel的列编号是这样的:a b c … z aa ab ac… az ba bb bc…yz za zb zc …zz aaa aab aac…; 分别代表以下的编号1 2 3 … 26 27 28 29… 52 53 54 55… 676 677 678 679 … 702 703 704 705;…...
nnUNet V2修改网络——暴力替换网络为UNet++
更换前,要用nnUNet V2跑通所用数据集,证明nnUNet V2、数据集、运行环境等没有问题 阅读nnU-Net V2 的 U-Net结构,初步了解要修改的网络,知己知彼,修改起来才能游刃有余。 U-Net存在两个局限,一是网络的最佳深度因应用场景而异,这取决于任务的难度和可用于训练的标注数…...
Linux部署私有文件管理系统MinIO
最近需要用到一个文件管理服务,但是又不想花钱,所以就想着自己搭建一个,刚好我们用的一个开源框架已经集成了MinIO,所以就选了这个 我这边对文件服务性能要求不是太高,单机版就可以 安装非常简单,几个命令就…...
Leetcode33( 搜索旋转排序数组)
题目表述 整数数组 nums 按升序排列,数组中的值 互不相同 。 在传递给函数之前,nums 在预先未知的某个下标 k(0 < k < nums.length)上进行了 旋转,使数组变为 [nums[k], nums[k1], …, nums[n-1], nums[0], nu…...
