当前位置: 首页 > article >正文

作业2 CNN实现手写数字识别

# 导入必要库
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns  # 用于高级可视化
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import time  # 用于计时# ======================
# 1. 数据加载与预处理
# ======================# 加载MNIST数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()# 数据预处理
# 归一化并添加通道维度(CNN需要通道信息)
x_train = x_train.reshape((60000, 28, 28, 1)).astype('float32') / 255
x_test = x_test.reshape((10000, 28, 28, 1)).astype('float32') / 255# 将标签转换为one-hot编码
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)# ======================
# 2. 构建CNN模型
# ======================
model = keras.Sequential([# 第一卷积层:32个3x3滤波器,ReLU激活,输入28x28x1layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),  # 下采样# 第二卷积层:64个3x3滤波器,ReLU激活layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),# 全连接层前处理layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dropout(0.5),  # 防止过拟合# 输出层:10个类别,softmax激活layers.Dense(10, activation='softmax')
])# ======================
# 3. 模型编译与训练
# ======================
model.compile(optimizer='adam',loss='categorical_crossentropy', # 分类交叉熵metrics=['accuracy']   # 准确率
)# 训练配置
epochs = 15
batch_size = 128
validation_split = 0.1  # 使用10%训练数据作为验证集# 训练模型并记录历史数据
start_time = time.time()
history = model.fit(x_train, y_train,epochs=epochs,batch_size=batch_size,validation_split=validation_split,verbose=1  # 显示训练进度
)
training_time = time.time() - start_time# ======================
# 4. 模型评估与可视化
# ======================# 打印训练信息
print(f"\nTraining completed in {training_time:.2f} seconds")
print(f"Test accuracy: {model.evaluate(x_test, y_test, verbose=0)[1]:.4f}")# ======================
# 可视化1:训练过程曲线
# ======================
plt.figure(figsize=(12, 4))# 绘制损失曲线
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()# 绘制准确率曲线
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()plt.tight_layout()
plt.show()# ======================
# 可视化2:混淆矩阵(修正版)
# ======================
# 获取预测结果
y_pred = model.predict(x_test)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = np.argmax(y_test, axis=1)# 生成混淆矩阵
cm = confusion_matrix(y_true, y_pred_classes)# 手动设置类别标签(MNIST 是 0-9)
class_names = [str(i) for i in range(10)]# 可视化混淆矩阵
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True,      # 在单元格中显示数值fmt='d',         # 数值格式为整数(适用于混淆矩阵的计数)cmap='Blues',    # 颜色映射(蓝色渐变)xticklabels=class_names,  # X轴标签(类别名称)yticklabels=class_names)  # Y轴标签(类别名称)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()# ======================
# 可视化3:错误预测样本
# ======================
# 找出预测错误的样本
errors = (y_pred_classes != y_true)
error_samples = x_test[errors]
true_labels = y_true[errors]
pred_labels = y_pred_classes[errors]# 显示前15个错误样本
plt.figure(figsize=(15, 6))
for i in range(min(15, len(error_samples))):plt.subplot(3, 5, i + 1)plt.imshow(error_samples[i].reshape(28, 28), cmap='gray')plt.title(f"True: {true_labels[i]}, Pred: {pred_labels[i]}")plt.axis('off')
plt.tight_layout()
plt.show()# ======================
# 可视化4:特征图可视化
# ======================
# 获取第一个卷积层的输出
layer_outputs = [layer.output for layer in model.layers[:2]]
activation_model = keras.models.Model(inputs=model.input, outputs=layer_outputs)
activations = activation_model.predict(x_test[0:1])# 显示第一卷积层的特征图
plt.figure(figsize=(12, 6))
first_layer_activation = activations[0]
for i in range(32):  # 显示前32个滤波器plt.subplot(4, 8, i + 1)plt.imshow(first_layer_activation[0, :, :, i], cmap='viridis')plt.axis('off')
plt.suptitle('First Convolutional Layer Activations', fontsize=16)
plt.show()

运行结果 

Epoch 15/15
422/422 [==============================] - 16s 38ms/step - loss: 0.0184 - accuracy: 0.9941 - val_loss: 0.0295 - val_accuracy: 0.9938Training completed in 343.88 seconds
Test accuracy: 0.9931
313/313 [==============================] - 1s 2ms/step

 ======================
# 3. 新增功能:随机展示20张测试集样本(调整到模型训练之后)
# ======================
def show_random_samples(model, x_test, y_test, num_samples=20):"""显示随机测试样本及其预测结果"""# 确保模型已训练if not hasattr(model, 'layers'):raise ValueError("Model must be trained first")# 生成预测结果y_pred = model.predict(x_test)y_pred_classes = np.argmax(y_pred, axis=1)# 获取真实标签y_true = np.argmax(y_test, axis=1)# 随机选择样本sample_indices = random.sample(range(len(x_test)), num_samples)# 创建可视化plt.figure(figsize=(16, 18))plt.suptitle("Random Handwritten Digit Samples with Predictions\n(Green=Correct, Red=Wrong)",fontsize=16, y=1.03)rows, cols = 4, 5plt.subplots_adjust(hspace=0.5, wspace=0.3)# 使用新版Matplotlib APIcmap = plt.colormaps.get_cmap('RdYlGn')  # 修复弃用警告for i, idx in enumerate(sample_indices):ax = plt.subplot(rows, cols, i + 1)img = x_test[idx].squeeze()# 显示图像plt.imshow(img, cmap='gray')# 获取标签信息true_label = y_true[idx]pred_label = y_pred_classes[idx]# 设置标题和颜色color = 'green' if true_label == pred_label else 'red'title = f'True: {true_label}\nPred: {pred_label}'plt.title(title, color=color, fontsize=10, pad=8)plt.axis('off')plt.tight_layout()plt.show()

 

混淆矩阵基础结构

1. 矩阵布局(以二分类为例)

2. 关键指标计算

 

TensorFlow Keras 核心组件

1. 常用层类型

2. 构建模型的三种方式

方式1:顺序模型(Sequential API)

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Densemodel = Sequential([Dense(128, activation='relu', input_shape=(784,)),  # 输入层Dense(64, activation='relu'),                       # 隐藏层Dense(10, activation='softmax')                     # 输出层
])

方式2:函数式API(Functional API)

from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Denseinput_layer = Input(shape=(784,))
hidden = Dense(128, activation='relu')(input_layer)
output = Dense(10, activation='softmax')(hidden)
model = Model(inputs=input_layer, outputs=output)

方式3:子类化模型(Subclassing)

from tensorflow.keras import Model
from tensorflow.keras.layers import Denseclass MyModel(Model):def __init__(self):super(MyModel, self).__init__()self.dense1 = Dense(128, activation='relu')self.dense2 = Dense(10, activation='softmax')def call(self, inputs):x = self.dense1(inputs)return self.dense2(x)model = MyModel()

3. 模型编译与训练

# 编译模型
model.compile(optimizer='adam',                 # 优化器(自动调参)loss='sparse_categorical_crossentropy',  # 损失函数(分类任务)metrics=['accuracy']              # 评估指标
)# 训练模型
history = model.fit(x_train, y_train,batch_size=32,                    # 每批样本数epochs=10,                        # 训练轮次validation_split=0.2              # 验证集比例
)

相关文章:

作业2 CNN实现手写数字识别

# 导入必要库 import numpy as np import matplotlib.pyplot as plt import seaborn as sns # 用于高级可视化 from tensorflow import keras from tensorflow.keras import layers from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay import time # 用于…...

整流二极管详解:原理、作用、应用与选型要点

一、整流二极管的基本定义 整流二极管是一种利用PN结单向导电性将交流电(AC)转换为直流电(DC)的半导体器件。其核心特性是正向导通、反向截止,允许电流仅沿单一方向流动。 典型结构:硅材料(正向…...

WordPress自定义页面与文章:打造独特网站风格的进阶指南

文章目录 引言一、理解WordPress页面与文章的区别二、主题与模板层级:自定义的基础三、自定义页面模板:打造专属页面风格四、自定义文章模板:打造个性化文章呈现五、使用自定义字段和元数据:增强内容灵活性六、利用WordPress钩子&…...

PHP最新好看UI个人引导页网页源码

PHP最新好看UI个人引导页网页源码 采用PHP、HTML、CSS及JavaScript等前端技术,构建了一个既美观又实用的个人主页解决方案。 源码设计初衷在于提供一个高度可定制、跨平台兼容的模板,让用户无需深厚的编程基础,即可快速搭建出专业且富有创意的…...

jQuery — 动画和事件

介绍 jQuery动画与事件是提升网页交互的核心工具。动画方面,jQuery通过简洁API实现平滑过渡效果,提供预设方法如slideUp(),支持.animate()自定义CSS属性动画,并内置队列系统实现动画链式执行。开发者可精准控制动画速度、回调时机…...

arkTs:使用回调函数的方法实现子组件向父组件传值

使用回调函数的方法实现子组件向父组件传值 1 主要内容说明2 实现步骤2.1 父组件中定义回调函数2.2 子组件声明并调用回调函数2.3 注意事项 3 源码3.1 父组件3.2 子组件3.3 源码效果显示截图 4 结语5 定位日期 1 主要内容说明 本文源码是一套 父组件与子组件之间双向数据传递的…...

VBA 调用 dll 优化执行效率

问题描述 之前excel 用vba写过一个应用,请求的是aws lambda 后端, 但是受限于是云端服务,用起来响应特别慢,最近抽了点时间准备优化下,先加了点日志看看是哪里慢了 主方法代码如下,函数的主要目的是将 Excel 工作簿的…...

【机器学习-周总结】-第4周

以下是本周学习内容的整理总结,从技术学习、实战应用到科研辅助技能三个方面归纳: 文章目录 📘 一、技术学习模块:TCN 基础知识与结构理解🔹 博客1:【时序预测05】– TCN(Temporal Convolutiona…...

Django-Friendship 项目常见问题解决方案

Django-Friendship 项目常见问题解决方案 django-friendship Django app to manage following and bi-directional friendships 项目地址: https://gitcode.com/gh_mirrors/dj/django-friendship Django-Friendship 是一个基于 Django 的应用,它允许创建和管…...

C语言用if else求三个数最小值的一题多解

一、问题引入 假设x,y,z为整数,使用if else语句求x,y,z三个数中的最小值? 二、三种解法 第一种解法: #include<stdio.h> int main(){int x,y,z,min;printf("请输入三个整数&#xff1a;");scanf_s("%d %d %d", &x, &y, &z);//初始值…...

AI时代下 你需要和想要了解的英文缩写含义

在AI智能时代下&#xff0c;越来愈多的企业都开始重视并应用以及开发AI相关产品&#xff0c;这个时候都会或多或少的涉及到英文&#xff0c;英文还好&#xff0c;但是如果是缩写&#xff0c;如果我们没有提前了解过&#xff0c;我们往往很难以快速Get到对方的意思。在这里&…...

uniApp小程序保存定制二维码到本地(V3)

这里的二维码组件用的 uv-ui 的二维码 可以按需引入 QRCode 二维码 | 我的资料管理-uv-ui 是全面兼容vue32、nvue、app、h5、小程序等多端的uni-app生态框架 <uv-qrcode ref"qrcode" :size"280" :value"payCodeUrl"></uv-qrcode>&l…...

2025年对讲机选购指南:聚焦核心参数与场景适配

在无线通信领域&#xff0c;对讲机始终占据着专业通讯工具的独特地位。随着5G时代到来和物联网技术深化&#xff0c;2025年的对讲机市场正呈现智能化、专业化、场景化的升级趋势。面对琳琅满目的产品&#xff0c;选购者需从通信性能、环境适应性、智能集成度三个维度进行综合考…...

C/C++ 动态链接详细解读

1. 为什么要动态链接&#xff1f; 1.1 静态链接浪费内存和磁盘空间 静态链接的方式对于计算机内存和磁盘空间浪费非常严重&#xff0c;特别是多进程操作系统的情况下&#xff0c;静态链接极大的浪费了内存空间。在现在的Linux系统中&#xff0c;一个普通的程序会使用的C 语言静…...

力扣-hot100(无重复字符的最长子串)

3. 无重复字符的最长子串 中等 给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的 最长 子串 的长度。 示例 1: 输入: s "abcabcbb" 输出: 3 解释: 因为无重复字符的最长子串是 "abc"&#xff0c;所以其长度为 3。暴力直观解法一&#xff1…...

python flask 项目部署

文章目录 概述 windows 部署准备工作使用 Waitress 部署 Flask 应用 linux 部署**2. 使用 WSGI 服务器**示例&#xff1a;使用 Gunicorn nginx反向代理**5. 使用进程管理工具**示例&#xff1a;使用 Systemd 概述 在 Windows 上使用 Waitress 部署 Flask 应用是一个不错的选择…...

Java课程内容大纲(附重点与考试方向)

本文是在传统 Java 教程框架基础上&#xff0c;加入了重点提示与考试思路&#xff0c;适合用于课程备考、知识查漏与面试准备。 第1章&#xff1a;Java语言基础 ⭐ 重点知识&#xff1a; Java平台特点&#xff08;跨平台性、JVM&#xff09; JDK、JRE、JVM 区别 Java 程序的…...

实现AWS Lambda函数安全地请求企业内部API返回数据

需要编写一个Lambda函数在AWS云上运行&#xff0c;它需要访问企业内部的API获取JSON格式的数据&#xff0c;企业有网关和防火墙&#xff0c;API有公司的okta身份认证&#xff0c;通过公司的域账号来授权访问&#xff0c;现在需要创建一个专用的域账号&#xff0c;让Lambda函数访…...

面试题--随机(一)

MySQL事务中的ACID特性&#xff1f; A 原子性 事务是一组SQL语句&#xff0c;不可分割 C 一致性 事务中的SQL语句要么同时执行&#xff0c;即全部执行成功&#xff0c;要么全部不执行&#xff0c;即执行失败 I 隔离性 MySQL中的各个事务通过不同的事务隔离等级&#xff0c;产生…...

200+短剧出海平台:谁能成为“海外红果”?

2025年&#xff0c;短剧的国际市场表现令人瞩目。仅在两年前&#xff0c;业界关注的焦点仍是美国市场&#xff0c;如今国产短剧应用已成功打入包括印尼、巴西、美国、墨西哥、印度、菲律宾、泰国、日本、哥伦比亚及韩国在内的多个国家&#xff0c;轻松获得超过500万次下载。 市…...

Visio导出清晰图片步骤

在Visio里画完图之后如何导出清晰的图片&#xff1f;&#x1f447; ①左上角单击【文件】 ②导出—更改文件类型—PNG/JPG ③分辨率选择【打印机】&#xff0c;大小选择【源】&#xff0c;即可。 ④选择保存位置并命名 也可以根据自己需要选择是否需要【透明底】哈。 选PNG 然…...

Linux系统:详解进程等待wait与waitpid解决僵尸进程

本节重点 理解进程等待的相关概念掌握系统调用wait与waitpid的使用方法输出型status参数的存储结构阻塞等待与非阻塞等待 一、概念 进程等待是操作系统中父进程与子进程协作的核心机制&#xff0c;指父进程通过特定方式等待子进程终止并回收其资源的过程。这一机制的主要目的…...

6.7 ChatGPT自动生成定时任务脚本:Python与Cron双方案实战指南

ChatGPT自动生成定时任务脚本:Python与Cron双方案实战指南 关键词:定时任务调度, ChatGPT 代码生成, Cron 脚本开发, Python 调度器, 自动化更新系统 6.3 使用 ChatGPT 生成 Cron 调度脚本 在 GitHub Sentinel 的定期更新功能中,定时任务调度是核心模块。本节演示如何通过…...

K8S运维实战之集群证书升级与容器运行时更换全记录

第一部分&#xff1a;Kubernetes集群证书升级实战 tips:此博文只演示一个节点作为示范&#xff0c;所有的集群节点步骤都可以参考。 项目背景 某金融业务系统Kubernetes集群即将面临生产证书集中过期风险&#xff08;核心组件证书剩余有效期不足90天&#xff09;&#xff0c…...

IntelliJ IDEA clean git password

IntelliJ IDEA clean git password 清除git密码 方法一&#xff1a;&#xff08;这个要特别注意啊&#xff0c;恢复默认设置&#xff0c;你的插件什么要重新下载了&#xff09; File->Manage IDE Settings->Restore Default Settings以恢复IDEA的默认设置(可选); 清空…...

【已更新完毕】2025泰迪杯数据挖掘竞赛C题数学建模思路代码文章教学:竞赛智能客服机器人构建

完整内容请看文末最后的推广群 基于大模型的竞赛智能客服机器人构建 摘要 随着国内学科和技能竞赛的增多&#xff0c;参赛者对竞赛相关信息的需求不断上升&#xff0c;但传统人工客服存在效率低、成本高、服务不稳定和用户体验差的问题。因此&#xff0c;设计一款智能客服机器…...

2025年4月19日 记录大模型出现的计算问题

2025年4月19日 记录大模型出现的计算问题&#xff0c;用了四个大模型计算json的数值&#xff0c;3个错误&#xff0c;1个正确 问题 Class Train Val answer 2574 853 screen 5025 1959 blackBoard 7847 3445 teacher 8490 3228 stand…...

ACI EP Learning Whitepaper 3. Disabling IP Data-plane Learning 功能

目录 1. 使用场景 1.1 未disable IP data-plane learning时 1.2 disable IP data-plane learning后 2. 一代Leaf注意事项 3. L2 未知单播注意事项 1. 使用场景 Windows网卡的动态负载均衡绑定模式等。或多个设备共享相同VIP并通过ARP/GARP/ND来宣告VIP切换时,这些外部设…...

C++入门七式——模板初阶

目录 函数模板 函数模板概念 函数模板格式 函数模板的原理 函数模板的实例化 模板参数的匹配原则 类模板 类模板的定义格式 类模板的显式实例化 当面对下面的代码时&#xff0c;大家会不会有一种无力的感觉&#xff1f;明明这些代码差不多&#xff0c;只是因为类型不…...

计算机网络 - 在浏览器中输入 URL 地址到显示主页的过程?

第一步&#xff0c;浏览器通过 DNS 来解析 URL&#xff0c;得到相应的 ip 地址&#xff08;到哪里找) 和 方法&#xff08;做什么&#xff09; 第二步&#xff0c;浏览器于服务器建立 TCP 三次握手连接 第三步&#xff0c;建立好连接后&#xff0c;浏览器会组装 HTTP 请求报文…...