百度飞浆ResNet50大模型微调实现十二种猫图像分类
12种猫分类比赛传送门
要求很简单,给train和test集,训练模型实现图像分类。
这里使用的是残差连接模型,这个平台有预训练好的模型,可以直接拿来主义。
训练十几个迭代,每个批次60左右,准确率达到90%以上
一、导入库,解压文件
import os import zipfile import random import json import cv2 import numpy as np from PIL import Imageimport matplotlib.pyplot as plt from sklearn.model_selection import train_test_split import paddle import paddle.nn as nn from paddle.io import Dataset,DataLoader from paddle.nn import \Layer, \Conv2D, Linear, \Embedding, MaxPool2D, \BatchNorm2D, ReLUimport paddle.vision.transforms as transforms from paddle.vision.models import resnet50 from paddle.metric import Accuracytrain_parameters = {"input_size": [3, 224, 224], # 输入图片的shape"class_dim": 12, # 分类数"src_path":"data/data10954/cat_12_train.zip", # 原始数据集路径"src_test_path":"data/data10954/cat_12_test.zip", # 原始数据集路径"target_path":"/home/aistudio/data/dataset", # 要解压的路径 "train_list_path": "./train.txt", # train_data.txt路径"eval_list_path": "./eval.txt", # eval_data.txt路径"label_dict":{}, # 标签字典"readme_path": "/home/aistudio/data/readme.json",# readme.json路径"num_epochs":6, # 训练轮数"train_batch_size": 16, # 批次的大小"learning_strategy": { # 优化函数相关的配置"lr": 0.0005 # 超参数学习率} }scr_path=train_parameters['src_path'] target_path=train_parameters['target_path'] src_test_path=train_parameters["src_test_path"] z = zipfile.ZipFile(scr_path, 'r') z.extractall(path=target_path) z = zipfile.ZipFile(src_test_path, 'r') z.extractall(path=target_path) z.close() for imgpath in os.listdir(target_path + '/cat_12_train'):src = os.path.join(target_path + '/cat_12_train/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)for imgpath in os.listdir(target_path + '/cat_12_test'):src = os.path.join(target_path + '/cat_12_test/', imgpath)img = Image.open(src)if img.mode != 'RGB':img = img.convert('RGB')img.save(src)
解压后将所有图像变为RGB图像
二、加载训练集,进行预处理、数据增强、格式变换
transform = transforms.Compose([transforms.Resize(size=224),transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ])x_train,x_eval,y_train=[],[],[]#获取训练图像和标签、测试图像和标签 contents=[] with open('data/data10954/train_list.txt')as f:contents=f.read().split('\n')for item in contents:if item=='':continuepath='data/dataset/'+item.split('\t')[0]data=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_train.append(data)y_train.append(int(item.split('\t')[-1]))contetns=os.listdir('data/dataset/cat_12_test') for item in contetns:path='data/dataset/cat_12_test/'+itemdata=np.array(Image.open(path).convert('RGB'))data=np.array(transform(data))x_eval.append(data)
重点是transforms变换的预处理
三、划分训练集和测试集
x_train=np.array(x_train)y_train=np.array(y_train)x_eval=np.array(x_eval)x_train,x_test,y_train,y_test=train_test_split(x_train,y_train,test_size=0.2,random_state=42,stratify=y_train)x_train=paddle.to_tensor(x_train,dtype='float32') y_train=paddle.to_tensor(y_train,dtype='int64') x_test=paddle.to_tensor(x_test,dtype='float32') y_test=paddle.to_tensor(y_test,dtype='int64') x_eval=paddle.to_tensor(x_eval,dtype='float32')
这是必要的,可以随时利用测试集查看准确率
四、加载预训练模型,选择损失函数和优化器
learning_rate=0.001 epochs =5 # 迭代轮数 batch_size = 50 # 批次大小 weight_decay=1e-5 num_class=12cnn=resnet50(pretrained=True) checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False cnn.fc = nn.Linear(2048, num_class) cnn.set_dict(checkpoint['cnn_state_dict']) criterion=nn.CrossEntropyLoss() optimizer = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=cnn.fc.parameters(),weight_decay=weight_decay)
第一次训练把加载模型注释掉即可,优化器包含最后一层全连接的参数
五、模型训练
if x_train.shape[3]==3:x_train=paddle.transpose(x_train,perm=(0,3,1,2))dataset = paddle.io.TensorDataset([x_train, y_train]) data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) for epoch in range(epochs):for batch_data, batch_labels in data_loader:outputs = cnn(batch_data)loss = criterion(outputs, batch_labels)print(epoch)loss.backward()optimizer.step()optimizer.clear_grad()print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.numpy()[0]}")#保存参数 paddle.save({'cnn_state_dict': cnn.state_dict(),}, 'checkpoint.pdparams')
使用批处理,这个很重要,不然平台分分钟炸了
六、测试集准确率
num_class=12 batch_size=64 cnn=resnet50(pretrained=True) checkpoint=paddle.load('checkpoint.pdparams')for param in cnn.parameters():param.requires_grad=False cnn.fc = nn.Linear(2048, num_class) cnn.set_dict(checkpoint['cnn_state_dict'])cnn.eval()if x_test.shape[3]==3:x_test=paddle.transpose(x_test,perm=(0,3,1,2)) dataset = paddle.io.TensorDataset([x_test, y_test]) data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)with paddle.no_grad():score=0for batch_data, batch_labels in data_loader:predictions = cnn(batch_data)predicted_probabilities = paddle.nn.functional.softmax(predictions, axis=1)predicted_labels = paddle.argmax(predicted_probabilities, axis=1) print(predicted_labels)for i in range(len(predicted_labels)):if predicted_labels[i].numpy()==batch_labels[i]:score+=1print(score/len(y_test))
设置eval模式,使用批处理测试准确率
相关文章:

百度飞浆ResNet50大模型微调实现十二种猫图像分类
12种猫分类比赛传送门 要求很简单,给train和test集,训练模型实现图像分类。 这里使用的是残差连接模型,这个平台有预训练好的模型,可以直接拿来主义。 训练十几个迭代,每个批次60左右,准确率达到90%以上…...

多服务器云探针源码(服务器云监控)/多服务器多节点_云监控程序python源码
源码简介: 多服务器云探针源码(服务器云监控),支持python多服务器多节点,云监控程序源码。它是一款很实用的云探针和服务器云监控程序源码。使用它可以帮助管理员能够快速监控和管理各种服务器和节点,实用性强。 源码链接: 网盘…...

ESP8266 WiFi物联网智能插座—下位机软件实现
目录 1、软件架构 2、开发环境 3、软件功能 4、程序设计 4.1、初始化 4.2、主循环状态机 4.3、初始化模式 4.4、配置模式 4.5、运行模式 4.6、重启模式 4.7、升级模式 5、程序功能特点 5.1、日志管理 5.2、数据缓存队列 本篇博文开始讲解下位机插座节点的MCU软件…...

微信小程序--下拉选择框组件封装,可CV直接使用
一、起因 接到的项目需求,查看ui设计图后,由于微信小程序官方设计的下拉选择框不符合需求,而且常用的第三方库也没有封装类似的,所以选择自己自定义组件。在此记录一下,方便日后复用。 ui设计图如下: 微信官方提供的选择框 对比发现并不能实现我们想要的功能。 二、自定义组件…...

代码随想录算法训练营第五十九天 |647. 回文子串、516.最长回文子序列、动态规划总结篇
一、647. 回文子串 题目链接/文章讲解:代码随想录 思考: 1.确定dp数组(dp table)以及下标的含义 如果本题定义dp[i] 为 下标i结尾的字符串有 dp[i]个回文串的话: 会发现很难找到递归关系,dp[i] 和 dp[i-1]…...

互联网性能和可用性优化CDN和DNS
当涉及到互联网性能和可用性优化时,DNS(Domain Name System)和CDN(Content Delivery Network)是两个至关重要的元素。它们各自发挥着关键作用,以确保用户能够快速、可靠地访问网站和应用程序。在本文中&…...

使用 ErrorStack 在出现报错 ORA-14402 时产生的日志量
0、测试结论: 测试结果:设置 ErrorStack 级别为 1 时产生 Trace 的日志量最小,大小为 308K,同时在 alert 日志中也存在记录。 1、准备测试数据: sqlplus / as sysdba show pdbs alter session set containerpdb; …...

详解Spring-ApplicationContext
加载器目前有两种选择:ContextLoaderListener和ContextLoaderServlet。 这两者在功能上完全等同,只是一个是基于Servlet2.3版本中新引入的Listener接口实现,而另一个基于Servlet接口实现。开发中可根据目标Web容器的实际情况进行选择。 配…...

关键字extern、static与const
关键字extern、static与const extern关键字与include的区别 extern:于声明某个函数或变量是外部的(其他源文件中)include:用于批量引入 项目中可以根据需要引入的函数或变量数量决定使用extern还是include static关键字 static关键字用于限制函数和全局变量的作用域仅在当…...

虹科方案|国庆出游季,古建筑振动监测让历史古迹不再受损
全文导读: 国庆长假即将到来,各位小伙伴是不是都做好了出游计划呢?今年中秋、国庆“双节”连休八天,多地预计游客接待量将创下新高,而各地的名胜古迹更是人流爆满。迎接游客的同时,如何保障历史古迹不因巨大…...

Python学习笔记-使用哈希算法Hash,Hashlib进行数据加密
文章目录 一、概述1.1 哈希算法1.2 常见算法分类1.2.1 SHA算法1.2.2 MD4算法1.2.3 MD5算法 1.3 Hash算法的特性1.4 Hash算法的应用场景1.4.1 数据校验1.4.2 安全加密1.4.3 数字签名 二、Hash算法使用2.1 使用hash函数直接获取hash值2.2 使用hashlib库进行hash计算2.2.1 基本使用…...

跨境电商能否成为黄河流域产业带的新引擎?
近年来,随着全球贸易格局的不断演变和中国经济的快速崛起,跨境电商已经成为中国外贸的一大亮点。而在中国国内,黄河流域产业带一直以其丰富的资源和悠久的历史而闻名,但也面临着转型升级的挑战。那么,跨境电商是否有潜…...

从数据到决策:企业投资信息查询API的关键作用
前言 在现代商业环境中,数据是一项无价的资产。企业不仅需要访问大量数据,还需要将这些数据转化为有用的见解,以支持战略决策。对于企业投资而言,准确的信息和实时的市场数据至关重要。在这个信息时代,企业投资信息查…...

NSIC2050JBT3G 车规级120V 50mA ±15% 用于LED照明的线性恒流调节器(CCR) 增强汽车安全
随着汽车行业的巨大变革,高品质的汽车氛围灯效、仪表盘等LED指示灯效已成为汽车内饰设计中不可或缺的元素。深力科安森美LED驱动芯片系列赋能智能座舱灯效充满艺术感和科技感——NSIC2050JBT3G LED驱动芯片,实现对每路LED亮度和颜色进行细腻控制…...

LuatOS-SOC接口文档(air780E)-- ftp - ftp 客户端
ftp.login(adapter,ip_addr,port,username,password)# FTP客户端 参数 传入值类型 解释 int 适配器序号, 只能是socket.ETH0, socket.STA, socket.AP,如果不填,会选择平台自带的方式,然后是最后一个注册的适配器 string ip_addr 地址 string port 端口,默认21 string…...

第二证券:市净率高好还是低好?
市净率是一个衡量公司股票投资价值的指标,通过比较公司股票价格和公司每股净资产的比值来评估公司股票的估值水平。市净率高好还是低好这个问题并没有一个简单的答案,取决于具体的市场环境和投资者的需求。本文将从多个角度分析市净率高好还是低好。 首…...

HTTP协议是什么
HTTP (全称为 “超文本传输协议”) 是一种应用非常广泛的 应用层协议,是一种网络通信协议。 超文本:所谓 “超文本” 的含义, 就是传输的内容不仅仅是文本(比如 html, css 这个就是文本), 还可以是一些其他的资源, 比如图片, 视频, 音频等二进制的数据。…...

微服务09-Sentinel的入门
文章目录 微服务中的雪崩现象解决办法:1. 超时处理2. 舱壁模式3. 熔断降级4.流量控制 Sentinel1.介绍2.使用操作3.限流规则4.实战:流量监控5.高级选项功能的使用1.关联模式2.链路模式3.总结 流控效果1.预热模式2.排队等待模式3.总结4.热点参数限流5.实战…...

2023-2024-1 高级语言程序设计实验一: 选择结构
7-1 古时年龄称谓知多少? 输入一个人的年龄(岁),判断出他属于哪个年龄段 ? 0-9 :垂髫之年; 10-19: 志学之年; 20-29 :弱冠之年; 30-39 &#…...

js事件循环详解
事件循环简介 JavaScript的事件循环是一种处理异步事件和回调函数的机制,它是在浏览器或Node.js环境中运行,用于管理任务队列和调用栈,以及在适当的时候执行回调函数。 事件循环的基本原理是,JavaScript引擎在空闲时等待事件的到…...

实战指南:使用 kube-prometheus-stack 监控 K3s 集群
作者简介 王海龙,Rancher 中国社区技术经理,Linux Foundation APAC Evangelist,负责 Rancher 中国技术社区的维护和运营。拥有 9 年的云计算领域经验,经历了 OpenStack 到 Kubernetes 的技术变革,无论底层操作系统 Lin…...

golang调用scws实现简易中文分词
1、安装 scws 官网以及文档 https://github.com/hightman/scws wget -q -O - http://www.xunsearch.com/scws/down/scws-1.2.3.tar.bz2 | tar xjf -cd scws-1.2.3 ./configure --prefix/usr/local/scws --enable-shared make && make installLibraries have been ins…...

Excel 中使用数据透视图进行数据可视化
使用数据透视表(PivotTable)是在Excel中进行数据可视化的强大工具。下面将提供详细的步骤来使用数据透视表进行数据可视化。 **步骤一:准备数据** 首先,确保你有一个包含所需数据的Excel表格。数据应该按照一定的结构和格式组织…...

在SIP 语音呼叫中出现单通时要怎么解决?
在VoIP的环境中,特别是基于SIP通信的环境中,我们经常会遇到一些非常常见的问题,例如,单通,注册问题,回声,单通等。这些问题事实上都有非常直接的排查方式和解决办法,用户可以按照一定…...

【师兄啊师兄2】公布,李长寿成功渡劫,敖乙叛变,又一美女登场
Hello,小伙伴们,我是小郑继续为大家深度解析国漫资讯。 由玄机制作的师兄啊师兄第一季这才完结没有多久,没想到现在第二季就公布了,连海报和预告都出来了,看样子已经做得差不多了。预告看下来,能够明显感觉到官方又进步…...

视频倒着播放,原来是这么实现的
视频倒放是当今最流行的内容类型之一,这些视频服务于不同的目的,例如营销、娱乐、教育等。它们以独特的内容脱颖而出,并给观众留下深刻印象,将视频编辑带到了一个全新的水平。 在本文中,您将了解有关视频倒着播放的内…...

# 02 初识Verilog HDL
02 初识Verilog HDL 对于Verilog的语言的学习,我认为没必要一开始就从头到尾认真的学习这个语言,把这个语言所有细节都搞清楚也不现实,我们能够看懂当前FPGA的代码的程度就可以了,随着学习FPGA深度的增加,再不断的…...

使用 Eziriz .NET Reactor 对c#程序加密
我目前测试过好几个c#加密软件。效果很多时候是加密后程序执行错误,或者字段找不到的现象 遇到这个加密软件用了一段时间都很正常,分享一下使用流程 破解版本自行百度。有钱的支持正版,我用的是 Eziriz .NET Reactor 6.8.0 第一步 安装 Ezi…...

Restclient-cpp库介绍和实际应用:爬取www.sohu.com
概述 Restclient-cpp是一个用C编写的简单而优雅的RESTful客户端库,它可以方便地发送HTTP请求和处理响应。它基于libcurl和jsoncpp,支持GET, POST, PUT, PATCH, DELETE, HEAD等方法,以及自定义HTTP头部,超时设置,代理服…...

提升市场调研和竞品分析效率:利用Appium实现App数据爬取
市场调研和竞品分析通常需要获取大量的数据,而手动收集这些数据往往耗时且容易出错。而利用Appium框架,我们可以轻松地实现自动化的App数据爬取,这种方法不仅可以节省时间和人力成本,还可以提高数据的准确性和一致性。 Appium是一…...