使用 CNN 训练自己的数据集
CNN(练习数据集)
- 1.导包:
- 2.导入数据集:
- 3. 使用image_dataset_from_directory()将数据加载tf.data.Dataset中:
- 4. 查看数据集中的一部分图像,以及它们对应的标签:
- 5.迭代数据集 train_ds,以便查看第一批图像和标签的形状:
- 6.使用TensorFlow的ImageDataGenerator类来创建一个数据增强的对象:
- 7.将数据集缓存到内存中,加快速度:
- 8. 通过卷积层和池化层提取特征,再通过全连接层进行分类:
- 9.打印网络结构:
- 10.设置优化器,定义了训练轮次和批量大小:
- 11.训练数据集:
- 12.画出图像:
- 13.评估您的模型在验证数据集的性能:
- 14.输出在验证集上的预测结果和真实值的对比:
- 15.输出可视化报表:
- 在网上寻找一个新的数据集,自己进行训练
1.导包:
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import matplotlib.pyplot as plt
import pickle
import pathlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
输出结果:

2.导入数据集:
# 定义超参数
data_dir = "D:\JUANJI"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*.jpg')))
print("图片总数为:", image_count)
batch_size = 30
img_height = 180
img_width = 180
输出结果:

3. 使用image_dataset_from_directory()将数据加载tf.data.Dataset中:
# 使用image_dataset_from_directory()将数据加载到tf.data.Dataset中
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2, # 验证集0.2subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
输出结果:

4. 查看数据集中的一部分图像,以及它们对应的标签:
class_names = train_ds.class_names
print(class_names)
# 可视化
plt.figure(figsize=(16, 8))
for images, labels in train_ds.take(1):for i in range(16):ax = plt.subplot(4, 4, i + 1)# plt.imshow(images[i], cmap=plt.cm.binary)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")
plt.show()
输出结果:


5.迭代数据集 train_ds,以便查看第一批图像和标签的形状:
for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
输出结果:

6.使用TensorFlow的ImageDataGenerator类来创建一个数据增强的对象:
aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,horizontal_flip=True, fill_mode="nearest")
x = aug.flow(image_batch, labels_batch)
AUTOTUNE = tf.data.AUTOTUNE
输出结果:

7.将数据集缓存到内存中,加快速度:
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
输出结果:

8. 通过卷积层和池化层提取特征,再通过全连接层进行分类:
# 为了增加模型的泛化能力,增加了Dropout层,并将最大池化层更新为平均池化层
num_classes = 3
model = models.Sequential([layers.experimental.preprocessing.Rescaling(1./255,input_shape=(img_height,img_width, 3)),layers.Conv2D(32, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(128, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(256, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(512, activation='relu'),layers.Dense(num_classes)
])
输出结果:

9.打印网络结构:
model.summary()
输出结果:

10.设置优化器,定义了训练轮次和批量大小:
# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=0.001)model.compile(optimizer=opt,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])EPOCHS = 100
BS = 5
输出结果:

11.训练数据集:
# 训练网络
# model.fit 可同时处理训练和即时扩充的增强数据。
# 我们必须将训练数据作为第一个参数传递给生成器。生成器将根据我们先前进行的设置生成批量的增强训练数据。
for images_train, labels_train in train_ds:continue
for images_test, labels_test in val_ds:continue
history = model.fit(x=aug.flow(images_train,labels_train, batch_size=BS),validation_data=(images_test,labels_test),
steps_per_epoch=1,epochs=EPOCHS)
输出结果:

12.画出图像:
# 画出训练精确度和损失图
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, history.history["loss"], label="train_loss")
plt.plot(N, history.history["val_loss"], label="val_loss")
plt.plot(N, history.history["accuracy"], label="train_acc")
plt.plot(N, history.history["val_accuracy"], label="val_acc")
plt.title("Aug Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc='upper right') # legend显示位置
plt.show()
输出结果:

13.评估您的模型在验证数据集的性能:
test_loss, test_acc = model.evaluate(val_ds, verbose=2)
print(test_loss, test_acc)
输出结果:

14.输出在验证集上的预测结果和真实值的对比:
# 优化2 输出在验证集上的预测结果和真实值的对比
pre = model.predict(val_ds)
for images, labels in val_ds.take(1):for i in range(4):ax = plt.subplot(1, 4, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.xticks([])plt.yticks([])# plt.xlabel('pre: ' + class_names[np.argmax(pre[i])] + ' real: ' + class_names[labels[i]])plt.xlabel('pre: ' + class_names[np.argmax(pre[i])])print('pre: ' + str(class_names[np.argmax(pre[i])]) + ' real: ' + class_names[labels[i]])
plt.show()
输出结果:

15.输出可视化报表:
print(labels_test)
print(labels)
print(pre)
print(class_names)
from sklearn.metrics import classification_report
# 优化1 输出可视化报表
print(classification_report(labels_test,pre.argmax(axis=1),
target_names=class_names))
输出结果:

相关文章:
使用 CNN 训练自己的数据集
CNN(练习数据集) 1.导包:2.导入数据集:3. 使用image_dataset_from_directory()将数据加载tf.data.Dataset中:4. 查看数据集中的一部分图像,以及它们对应的标签:5.迭代数据集 train_ds࿰…...
自动控制: 最小二乘估计(LSE)、加权最小二乘估计(WLS)和线性最小方差估计
自动控制: 最小二乘估计(LSE)、加权最小二乘估计(WLS)和线性最小方差估计 在数据分析和机器学习中,参数估计是一个关键步骤。最小二乘估计(LSE)、加权最小二乘估计(WLS&…...
基于VMware安装Linux虚拟机
1.准备Linux环境 首先,我们要准备一个Linux的系统,成本最低的方式就是在本地安装一台虚拟机。为了统一学习环境,不管是使用MacOS还是Windows系统的同学,都建议安装一台虚拟机。 windows采用VMware,Mac则采用Fusion …...
6、phpjm混淆解密和php反序列化
题目:青少年雏形系统 1、打开链接也是一个登入面板 2、尝试了sqlmap没头绪 3、尝试御剑,发现一个www.zip 4、下载打开,有一个php文件打开有一段phpjm混淆加密 5、使用手工解混淆 具体解法链接:奇安信攻防社区-phpjm混淆解密浅谈…...
Codeforces Round 909 (Div. 3) E. Queue Sort(模拟 + 贪心之找到了一个边界点)
弗拉德找到了一个由 n 个整数组成的数组 a ,并决定按不递减的顺序排序。 为此,弗拉德可以多次执行下面的操作: 提取数组的第一个元素并将其插入末尾; 将个元素与前一个元素对调,直到它变成第一个元素或严格大于前一个…...
设计模式基础——设计原则介绍
1.概述 对于面向对象软件系统的设计而言,如何同时提高一个软件系统的可维护性、可复用性、可拓展性是面向对象设计需要解决的核心问题之一。面向对象设计原则应运而生,这些原则你会在设计模式中找到它们的影子,也是设计模式的基础。往往判…...
【校园网网络维修】当前用户使用的IP与设备重定向地址中IP不一致,请重新认证
出现的网络问题:当前用户使用的IP与设备重定向地址中IP不一致,请重新认证 可能的原因: 把之前登录的网页收藏到浏览器,然后直接通过这个链接进行登录认证。可能是收藏网址导致的ip地址请求参数不一致。 解决方法: 方法…...
如何找到docker的run(启动命令)
使用python三方库进行 需要安装python解释器 安装runlike安装包 pip3 install runlike 运行命令 runlike -p <container_name> # 后面可以是容器名和容器id,-p参数是显示自动换行实验 使用docker启动一个jenkins 启动命令为 docker run -d \ -p 9002:80…...
Spring如何管理Bean的生命周期呢?
我们都知道,在面试的过程中,关于 Spring 的面试题,那是各种各样,很多时候就会问到关于 Spring的相关问题,比如 AOP ,IOC 等等,还有就是关于 Spring 是如何管理 Bean 的生命周期的相关问题&#…...
Java网络编程:UDP通信篇
目录 UDP协议 Java中的UDP通信 DatagramSocket DatagramPacket UDP客户端-服务端代码实现 UDP协议 对于UDP协议,这里简单做一下介绍: 在TCP/IP协议簇中,用户数据报协议(UDP)是传输层的一个主要协议之一…...
HTML+CSS+JS简易计算器
HTMLCSSJS简易计算器 index.html <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>简易计算器</t…...
STM32使用ST-LINK下载程序中需要注意的几点
使用keil5的ST-link下载界面 前提是ST-LINK已经连接好,(下图中是没有连接ST-link设备),只是为了展示如何查看STlink设备是否连接的方式 下载前一定设置下载完成后自启动 这个虽然不是必须,但对立即看到新程序的现象…...
我和jetson-Nano的故事(12)——安装pytorch 以及 torchvision
在jetson nano中安装Anaconda、pytorch 以及 torchvision 1.Pytorch下载安装2.Torchvision安装 1.Pytorch下载安装 首先登录英伟达官网下载Pytorch安装包,这里以PyTorch v1.10.0为例 安装依赖库 sudo apt-get install libjpeg-dev zlib1g-dev libpython3-dev liba…...
「异步魔法:Python数据库交互的革命」(一)
Hi,我是阿佑,今天将和大家一块打开异步魔法的大门,进入Python异步编程的神秘领域,学习如何同时施展多个咒语而不需等待。了解asyncio的魔力,掌握Async SQLAlchemy和Tortoise-ORM的秘密,让你的数据库操作快如…...
探秘GPT-4o:从版本对比到技术能力的全面评价
随着人工智能技术的不断发展,自然语言处理领域的突破性技术——GPT(Generative Pre-trained Transformer)系列模型也在不断演进。最新一代的GPT-4o横空出世,引起了广泛的关注和讨论。在本文中,我们将对GPT-4o进行全面评…...
四川汇烁面试总结
自我介绍项目介绍、 目录 1.jdk和jre的区别? 2.一段代码的执行流程? 3.接口与抽象类的区别? 4.ArrayList与LinkList的区别? 5.对HashMap的理解? 6.常见的异常? 7.throw 和 throws 有什么区别? 8.…...
【小程序 按钮 表单 】
按钮 代码演示 xxx.wxml <view class"boss" hover-class"box"hover-start-time"2000"hover-stay-time"5000">测试文本<view hover-stop-propagation"true">子集</view><view>子集2</view>…...
高铁Wifi是如何接入的?
使用PC端的朋友,请将页面缩小到最小比例,阅读最佳! 在飞驰的高铁上,除了窗外一闪而过的风景,你是否好奇过,高铁Wifi信号如何连接的呢? 远动的火车可不能连接光纤吧,难道是连接的卫星…...
gitlab之docker-compose汉化离线安装
目录 概述离线资源docker-compose结束 概述 gitlab可以去 hub 上拉取最新版本,在此我选择汉化 gitlab ,版本 11.x 离线资源 想自制离线安装镜像,请稳步参考 docker镜像的导入导出 ,无兴趣的直接使用在此提供离线资源 百度网盘(链…...
【算法】dd爱转转
✨题目链接: dd爱旋转 ✨题目描述 读入一个n∗n的矩阵,对于一个矩阵有以下两种操作 1:顺时针旋180 2:关于行镜像 如 变成 给出q个操作,输出操作完的矩阵 ✨输入描述: 第一行一个数n(1≤n≤1000),表示矩阵大小 接下来n行ÿ…...
工业安全零事故的智能守护者:一体化AI智能安防平台
前言: 通过AI视觉技术,为船厂提供全面的安全监控解决方案,涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面,能够实现对应负责人反馈机制,并最终实现数据的统计报表。提升船厂…...
基于matlab策略迭代和值迭代法的动态规划
经典的基于策略迭代和值迭代法的动态规划matlab代码,实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...
使用 SymPy 进行向量和矩阵的高级操作
在科学计算和工程领域,向量和矩阵操作是解决问题的核心技能之一。Python 的 SymPy 库提供了强大的符号计算功能,能够高效地处理向量和矩阵的各种操作。本文将深入探讨如何使用 SymPy 进行向量和矩阵的创建、合并以及维度拓展等操作,并通过具体…...
深度学习习题2
1.如果增加神经网络的宽度,精确度会增加到一个特定阈值后,便开始降低。造成这一现象的可能原因是什么? A、即使增加卷积核的数量,只有少部分的核会被用作预测 B、当卷积核数量增加时,神经网络的预测能力会降低 C、当卷…...
React---day11
14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store: 我们在使用异步的时候理应是要使用中间件的,但是configureStore 已经自动集成了 redux-thunk,注意action里面要返回函数 import { configureS…...
在Mathematica中实现Newton-Raphson迭代的收敛时间算法(一般三次多项式)
考察一般的三次多项式,以r为参数: p[z_, r_] : z^3 (r - 1) z - r; roots[r_] : z /. Solve[p[z, r] 0, z]; 此多项式的根为: 尽管看起来这个多项式是特殊的,其实一般的三次多项式都是可以通过线性变换化为这个形式…...
iview框架主题色的应用
1.下载 less要使用3.0.0以下的版本 npm install less2.7.3 npm install less-loader4.0.52./src/config/theme.js文件 module.exports {yellow: {theme-color: #FDCE04},blue: {theme-color: #547CE7} }在sass中使用theme配置的颜色主题,无需引入,直接可…...
es6+和css3新增的特性有哪些
一:ECMAScript 新特性(ES6) ES6 (2015) - 革命性更新 1,记住的方法,从一个方法里面用到了哪些技术 1,let /const块级作用域声明2,**默认参数**:函数参数可以设置默认值。3&#x…...
解析两阶段提交与三阶段提交的核心差异及MySQL实现方案
引言 在分布式系统的事务处理中,如何保障跨节点数据操作的一致性始终是核心挑战。经典的两阶段提交协议(2PC)通过准备阶段与提交阶段的协调机制,以同步决策模式确保事务原子性。其改进版本三阶段提交协议(3PC…...
DeepSeek越强,Kimi越慌?
被DeepSeek吊打的Kimi,还有多少人在用? 去年,月之暗面创始人杨植麟别提有多风光了。90后清华学霸,国产大模型六小虎之一,手握十几亿美金的融资。旗下的AI助手Kimi烧钱如流水,单月光是投流就花费2个亿。 疯…...
