当前位置: 首页 > news >正文

python实战(十五)——中文手写体数字图像CNN分类

一、任务背景

        本次python实战,我们使用来自Kaggle的数据集《Chinese MNIST》进行CNN分类建模,不同于经典的MNIST数据集,我们这次使用的数据集是汉字手写体数字。除了常规的汉字“零”到“九”之外还多了“十”、“百”、“千”、“万”、“亿”,共15种汉字数字

二、python建模

1、数据读取

        首先,读取jpg数据文件,可以看到总共有15000张图像数据。

import pandas as pd
import ospath = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))

        我们也可以打印一张图片出来看看。

import matplotlib.pyplot as plt
import matplotlib.image as mpimg# 定义图片路径
image_path = path+files[3]# 加载图片
image = mpimg.imread(image_path)# 绘制图片
plt.figure(figsize=(3, 3))
plt.imshow(image)
plt.axis('off')  # 关闭坐标轴
plt.show()

2、数据集构建

        加载必要的库以便后续使用,再定义一些超参数。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_score# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])

        这里,我们看一看数据集介绍就会知道图片名称及其含义,需要从chinese_mnist.csv文件中根据图片名称中的几个数字来确定图片对应的标签。

# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label_from_index(filename, index_df):suite_id, sample_id, code = map(int, filename.split('.')[0].split('_')[1:])return index_df.loc[(suite_id, sample_id, code), 'value']# 构建value值对应的标签序号,用于模型训练
label_dic = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 100:11, 1000:12, 10000:13, 100000000:14}
# 获取所有图片的标签并转化为标签序号
all_labels = [get_label_from_index(os.path.basename(img), index_df) for img in all_images]
all_labels = [label_dic[li] for li in all_labels]# 将图片路径和标签分成训练集和测试集
train_images, test_images, train_labels, test_labels = train_test_split(all_images, all_labels, test_size=0.2, random_state=2024)

        下面定义数据集类并完成数据的加载。

# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = Image.open(self.image_paths[idx]).convert('L')  # 转换为灰度图像label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 创建训练集和测试集数据集
train_dataset = CustomDataset(train_images, train_labels, transform=transform)
test_dataset = CustomDataset(test_images, test_labels, transform=transform)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 打印一些信息
print(f'训练集样本数: {len(train_dataset)}')
print(f'测试集样本数: {len(test_dataset)}')

3、模型构建

        我们构建一个包含两层卷积层和池化层的CNN并且在池化层中使用最大池化的方式。

# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 16 * 16, 128)self.fc2 = nn.Linear(128, 15)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 16 * 16)x = F.relu(self.fc1(x))x = self.fc2(x)return x

4、模型实例化及训练

        下面我们对模型进行实例化并定义criterion和optimizer。

# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

        定义训练的代码并调用代码训练模型。

from tqdm import tqdm
# 训练模型
def train(model, train_loader, criterion, optimizer, epochs):model.train()running_loss = 0.0for epoch in range(epochs):for data, target in tqdm(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}], Loss: {running_loss / len(train_loader):.4f}')running_loss = 0.0train(model, train_loader, criterion, optimizer, num_epochs)

5、测试模型

        定义模型测试代码,调用代码看指标可知我们所构建的CNN模型表现还不错。

# 测试模型
def test(model, test_loader, criterion):model.eval()test_loss = 0correct = 0all_preds = []all_targets = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_preds.extend(pred.cpu().numpy())all_targets.extend(target.cpu().numpy())test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)precision = precision_score(all_targets, all_preds, average='macro')recall = recall_score(all_targets, all_preds, average='macro')f1 = f1_score(all_targets, all_preds, average='macro')print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')test(model, test_loader, criterion)

三、完整代码

import pandas as pd
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_scorepath = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label_from_index(filename, index_df):suite_id, sample_id, code = map(int, filename.split('.')[0].split('_')[1:])return index_df.loc[(suite_id, sample_id, code), 'value']# 构建value值对应的标签序号,用于模型训练
label_dic = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 100:11, 1000:12, 10000:13, 100000000:14}# 获取所有图片的标签并转化为标签序号
all_labels = [get_label_from_index(os.path.basename(img), index_df) for img in all_images]
all_labels = [label_dic[li] for li in all_labels]# 将图片路径和标签分成训练集和测试集
train_images, test_images, train_labels, test_labels = train_test_split(all_images, all_labels, test_size=0.2, random_state=2024)# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = Image.open(self.image_paths[idx]).convert('L')  # 转换为灰度图像label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 创建训练集和测试集数据集
train_dataset = CustomDataset(train_images, train_labels, transform=transform)
test_dataset = CustomDataset(test_images, test_labels, transform=transform)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 打印信息
print(f'训练集样本数: {len(train_dataset)}')
print(f'测试集样本数: {len(test_dataset)}')# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 16 * 16, 128)self.fc2 = nn.Linear(128, 15)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 16 * 16)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)# 训练模型
def train(model, train_loader, criterion, optimizer, epochs):model.train()running_loss = 0.0for epoch in range(epochs):for data, target in tqdm(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}], Loss: {running_loss / len(train_loader):.4f}')running_loss = 0.0train(model, train_loader, criterion, optimizer, num_epochs)# 测试模型
def test(model, test_loader, criterion):model.eval()test_loss = 0correct = 0all_preds = []all_targets = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_preds.extend(pred.cpu().numpy())all_targets.extend(target.cpu().numpy())test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)precision = precision_score(all_targets, all_preds, average='macro')recall = recall_score(all_targets, all_preds, average='macro')f1 = f1_score(all_targets, all_preds, average='macro')print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')test(model, test_loader, criterion)

四、总结

        本文基于汉字手写体数字图像进行了CNN分类实战,CNN作为图像处理的经典模型,展现出了它强大的图像特征提取能力,结合更加复杂的模型框架CNN还可用于高精度人脸识别、物体识别等任务中。

相关文章:

python实战(十五)——中文手写体数字图像CNN分类

一、任务背景 本次python实战,我们使用来自Kaggle的数据集《Chinese MNIST》进行CNN分类建模,不同于经典的MNIST数据集,我们这次使用的数据集是汉字手写体数字。除了常规的汉字“零”到“九”之外还多了“十”、“百”、“千”、“万”、“亿…...

[论文阅读] (37)CCS21 DeepAID:基于深度学习的异常检测(解释)

祝大家新春快乐,蛇年吉祥! 《娜璋带你读论文》系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢。由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正&#xff0…...

Linux - 进程间通信(2)

目录 2、进程池 1)理解进程池 2)进程池的实现 整体框架: a. 加载任务 b. 先描述,再组织 I. 先描述 II. 再组织 c. 创建信道和子进程 d. 通过channel控制子进程 e. 回收管道和子进程 问题1: 解答1&#xff…...

Kafka 消费端反复 Rebalance: `Attempt to heartbeat failed since group is rebalancing`

文章目录 Kafka 消费端反复 Rebalance: Attempt to heartbeat failed since group is rebalancing1. Rebalance 过程概述2. 错误原因分析2.1 消费者组频繁加入或退出2.1.1 消费者故障导致频繁重启2.1.2. 消费者加入和退出导致的 Rebalance2.1.3 消费者心跳超时导致的 Rebalance…...

SpringBoot+Electron教务管理系统 附带详细运行指导视频

文章目录 一、项目演示二、项目介绍三、运行截图四、主要代码1.查询课程表代码2.保存学生信息代码3.用户登录代码 一、项目演示 项目演示地址: 视频地址 二、项目介绍 项目描述:这是一个基于SpringBootElectron框架开发的教务管理系统。首先&#xff…...

操作系统(Linux Kernel 0.11Linux Kernel 0.12)解读整理——内核初始化(main init)之控制台工作

前言 在 Linux 内核中,字符设备主要包括控制终端设备和串行终端设备,对这些设备的输入输出涉及控制台驱动程序,这包括键盘中断驱动程序 keyboard.S 和控制台显示驱动程序 console.c,还有终端驱动程序与上层程序之间的接口部分。 终端驱动程序…...

Autogen_core: Message and Communication

目录 完整代码代码解释1. 消息的数据类:2. 创建代理人(MyAgent):3. 创建和运行代理人的运行时环境:4. 根据发送者路由消息的代理(RoutedBySenderAgent):5. 创建和运行带路由的代理&a…...

ComfyUI工作流教程、软件使用、开发指导、模型下载

在人工智能和设计技术迅速发展的今天,AI赋能的工作流已成为创意设计与生产的重要工具。无论是图片处理、服装试穿,还是室内设计与3D建模,这些智能化的解决方案极大地提高了效率和创作质量。 为了帮助设计师、开发者以及AI技术爱好者更好地利用这些工具,我们整理了一份详尽…...

零基础Vue学习1——Vue学习前环境准备

目录 环境准备 创建Vue项目 项目目录说明 后续开发过程中常用命令 环境准备 安装开发工具:vscode、webstorm、idea都可以安装node:V22以上版本即可安装pnpm 不知道怎么安装的可以私信我教你方法 创建Vue项目 本地新建一个文件夹,之后在文件夹下打开…...

定西市建筑房屋轮廓数据shp格式gis无偏移坐标(字段有高度和楼层)内容测评

定西市建筑房屋轮廓数据是GIS(Geographic Information System,地理信息系统)领域的重要资源,用于城市规划、土地管理、环境保护等多个方面。这份2022年的数据集采用shp(Shapefile)格式,这是一种…...

汉语向编程指南

汉语向编程指南 一、引言王阳明代数与流形学习理论慢道缓行理性人类型指标系统为己之学与意气实体过程晏殊几何学半可分离相如矩阵与生成气质邻域镶嵌气度曲面细分生成气质邻域镶嵌气度曲面细分社会科学概论琴生生物机械科技工业研究所软凝聚态物理开发工具包琴生生物机械 报告…...

Writing an Efficient Vulkan Renderer

本文出自GPU Zen 2。 Vulkan 是一个新的显式跨平台图形 API。它引入了许多新概念,即使是经验丰富的图形程序员也可能不熟悉。Vulkan 的主要目标是性能——然而,获得良好的性能需要深入了解这些概念及其高效应用方法,以及特定驱动程序实现的实…...

AI常见的算法

人工智能(AI)中常见的算法分为多个领域,如机器学习、深度学习、强化学习、自然语言处理和计算机视觉等。以下是一些常见的算法及其用途: 1. 机器学习 (Machine Learning) 监督学习 (Supervised Learning) 线性回归 (Linear Regr…...

LibreChat

文章目录 一、关于 LibreChat✨特点 二、使用LibreChat🪶多合一AI对话 一、关于 LibreChat LibreChat 是增强的ChatGPT克隆:Features Agents, Anthropic, AWS, OpenAI, Assistants API, Azure, Groq, o1, GPT-4o, Mistral, OpenRouter, Vertex AI, Gemi…...

Spring Boot 日志:项目的“行车记录仪”

一、什么是Spring Boot日志 (一)日志引入 在正式介绍日志之前,我们先来看看上篇文章中(Spring Boot 配置文件)中的验证码功能的一个代码片段: 这是一段校验用户输入的验证码是否正确的后端代码&#xff0c…...

Spring Boot 实现文件上传和下载

文章目录 Spring Boot 实现文件上传和下载一、引言二、文件上传1、配置Spring Boot项目2、创建文件上传控制器3、配置文件上传大小限制 三、文件下载1、创建文件下载控制器 四、使用示例1、文件上传2、文件下载 五、总结 Spring Boot 实现文件上传和下载 一、引言 在现代Web应…...

慕课:若鱼1919的视频课程:Java秒杀系统方案优化 高性能高并发实战,启动文档

代码: Javahhhh/miaosha191: 运行成功了慕课若鱼1919的视频课程:Java秒杀系统方案优化 高性能高并发实战https://github.com/Javahhhh/miaosha191 https://github.com/Javahhhh/miaosha191 miaosha项目启动文档 需安装的配置环境: VMwar…...

React第二十七章(Suspense)

Suspense Suspense 是一种异步渲染机制,其核心理念是在组件加载或数据获取过程中,先展示一个占位符(loading state),从而实现更自然流畅的用户界面更新体验。 应用场景 异步组件加载:通过代码分包实现组件…...

虚幻基础08:组件接口

能帮到你的话,就给个赞吧 😘 文章目录 作用 作用 组件接口:可以直接调用对方的组件接口,而无需转换为actor。 实现对象间的通知。 A 通知 B 做什么。...

iPhone SE(第三代) 设备详情图

目录 产品宣传图内部图——后设备详细信息 产品宣传图 内部图——后 设备详细信息 信息收集于HubWeb.cn...

2025苹果CMS v10短剧模板源码

文件不到70kb,加载非常快 无配置,没有详情页,上传就可以直接使用 使用教程:上传到网站template目录并解压、进入网站后台选择模板 注意:默认调用ID为1的数据和扩展分类,建议新建站使用 源码下载&#xf…...

2007-2020年各省国内专利申请授权量数据

2007-2020年各省国内专利申请授权量数据 1、时间:2007-2020年 2、来源:国家统计局、统计年鉴 3、指标:行政区划代码、地区名称、年份、国内专利申请授权量(项) 4、范围:31省 5、指标解释:专利是专利权的简称&…...

第一天-嵌入式应用开发介绍

首先,我们来介绍一下嵌入式的发展路线,虽然嵌入式的知识点众多,但是总体上来说,嵌入式分为以下两条主要路线: 单片机开发ArmLinux开发 当然,还有其他的一些例如FPGA这种的我们就不计算在内了,F…...

约瑟夫问题(信息学奥赛一本通-2037)

【题目描述】 N个人围成一圈,从第一个人开始报数,数到M的人出圈;再由下一个人开始报数,数到M 的人出圈;…输出依次出圈的人的编号。 【输入】 输入N和M。 【输出】 输出一行,依次出圈的人的编号。 【输入样…...

WPF5-x名称空间

1. x名称空间2. x名称空间内容3. x名称空间内容分类 3.1. x:Name3.2. x:Key3.3. x:Class3.4. x:TypeArguments 4. 总结 1. x名称空间 “x名称空间”的x是映射XAML名称空间时给它取的名字(取XAML的首字母),里面的成员(如x:Class、…...

一个python项目中的文件和目录的作用是什么?scripts,venv,predict的具体含义

今天学习SadTalker的项目,但目录和文件不知道都是干什么的,总结记录下,方便后续使用。 目录 1. docs: 作用: 这个文件夹通常包含项目的文档。文档可能包括用户指南、API 文档、开发文档等。 2. examples: 作用: 这里通常包含一些示例代码…...

python学opencv|读取图像(四十八)使用cv2.bitwise_xor()函数实现图像按位异或运算

【0】基础定义 按位与运算:两个等长度二进制数上下对齐,全1取1,其余取0。 按位或运算:两个等长度二进制数上下对齐,有1取1,其余取0。 按位取反运算:一个二进制数,0变1,1变0。 按…...

YOLOv11-ultralytics-8.3.67部分代码阅读笔记-block.py

block.py ultralytics\nn\modules\block.py 目录 block.py 1.所需的库和模块 2.class DFL(nn.Module): 3.class Proto(nn.Module): 4.class HGStem(nn.Module): 5.class HGBlock(nn.Module): 6.class SPP(nn.Module): 7.class SPPF(nn.Module): 8.class C1(nn…...

c++多态

1.多态的概念 通俗来说,就是多种形态,具体点就是去完成某个行为,当不同的对象去完成时会产生出不同 的状态。 2.多态的定义及实现 2.1多态的构成条件 多态是在不同继承关系的类对象,去调用同一函数,产生了不同的行为…...

ResNeSt: Split-Attention Networks 参考论文

参考文献 [1] Tensorflow Efficientnet. https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet. Accessed: 2020-03-04. 中文翻译:[1] TensorFlow EfficientNet. https://github.com/tensorflow/tpu/tree/master/models/official/efficien…...