当前位置: 首页 > news >正文

【Pytorch和Keras】使用transformer库进行图像分类

目录

    • 一、环境准备
    • 二、基于Pytorch的预训练模型
      • 1、准备数据集
      • 2、加载预训练模型
      • 3、 使用pytorch进行模型构建
    • 三、基于keras的预训练模型
    • 四、模型测试
    • 五、参考

 现在大多数的模型都会上传到huggface平台进行统一的管理,transformer库能关联到huggface中对应的模型,并且提供简洁的transformer模型调用,这大大提高了开发人员的开发效率。本博客主要利用transformer库实现一个简单的模型微调,以进行图像分类的任务。


一、环境准备

 使用终端命令行安装对应的第三方包,具体安装命令输入如下:

pip install transformers datasets evaluate

二、基于Pytorch的预训练模型

  由于下面这些内容需要在huggface上申请账号权限,才能进行模型和数据集加载,如果之前有从huggface上拉取模型和数据集的经验,可以略过,如果没有配置过,可以参考笔者之前的文章https://blog.csdn.net/qq_40734883/article/details/143922095,然后直接申请Write权限就可以。

在这里插入图片描述
  后续所有涉及到的数据集food101和transformer模型都需要参考上述文章进行直接下载,才能运行整个程序,或者在google的colab直接运行。

  如果在google的colab上运行,请提前设置好电脑的GPU资源,同时加入huggface登录代码,具体如下:

from huggingface_hub import notebook_login
notebook_login()

  运行之后会提示进行token输入,按之前获取到的token输入即可。

1、准备数据集

  这里以food101数据集作为微调数据集,在imagenet-21k上训练完成的transformer模型(vit-base-patch16-224)进行优化

from datasets import load_datasetfood = load_dataset("food101", split="train[:5000]")# 划分数据集,训练集:测试集=8:2,food有两个键:一个train,一个test
food = food.train_test_split(test_size=0.2)  # 标签转换
labels = food["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):label2id[label] = str(i)id2label[str(i)] = label

id2label为通过id访问标签的字典,后续会使用到。

2、加载预训练模型

from transformers import AutoImageProcessorcheckpoint = "google/vit-base-patch16-224-in21k"   # ImageNet-21k上的预训练模型
image_processor = AutoImageProcessor.from_pretrained(checkpoint)  # 从huggface拉取并加载模型

3、 使用pytorch进行模型构建

from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor# 数据预处理操作定义
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (image_processor.size["shortest_edge"]if "shortest_edge" in image_processor.sizeelse (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])# 对原始数据进行RGB及字典化
def transforms(examples):examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]del examples["image"]return examplesfood = food.with_transform(transforms)
# 验证
import evaluate
# 指定验证过程中的评价指标-准确率
accuracy = evaluate.load("accuracy")import numpy as np
def compute_metrics(eval_pred):predictions, labels = eval_predpredictions = np.argmax(predictions, axis=1)return accuracy.compute(predictions=predictions, references=labels)

  训练设置和运行,具体输入代码如下:

# 整合训练中的数据,以便在模型训练或评估过程中使用
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()from transformers import AutoModelForImageClassification, TrainingArguments, Trainer# 初始化模型
model = AutoModelForImageClassification.from_pretrained(checkpoint,num_labels=len(labels),id2label=id2label,label2id=label2id,
)# 设置模型优化参数
training_args = TrainingArguments(output_dir="my_awesome_food_model",remove_unused_columns=False,evaluation_strategy="epoch",save_strategy="epoch",learning_rate=5e-5,per_device_train_batch_size=16,gradient_accumulation_steps=4,per_device_eval_batch_size=16,num_train_epochs=3,warmup_ratio=0.1,logging_steps=10,load_best_model_at_end=True,metric_for_best_model="accuracy",push_to_hub=True,
)# 初始化训练实例
trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=food["train"],eval_dataset=food["test"],tokenizer=image_processor,compute_metrics=compute_metrics,
)trainer.train()  # 开始训练trainer.push_to_hub()  # 推送到huggfacehub

  经过上述设置训练完成之后,会将模型微调结果推送到huggface平台,如果不想推送,可以不运行相关的命令行,并且training_args中的push_to_hub=False

  训练结果如下图所示:

在这里插入图片描述
  默认需要选择是否关联wandb,如果不想选择,直接根据设置提示跳过即可。

  如果选择了推送到huggfacehub(trainer.push_to_hub() )的话,在个人的huggface上会有一个名为my_awesome_food_model的模型,里面包含了模型训练的各个参数设置和测试结果。

在这里插入图片描述


三、基于keras的预训练模型

  使用transflow的keras API 进行模型的搭建,具体代码如下:

from transformers import create_optimizer# 超参数设置
batch_size = 16
num_epochs = 5
num_train_steps = len(food["train"]) * num_epochs
learning_rate = 3e-5
weight_decay_rate = 0.01# 定义优化方式和策略
optimizer, lr_schedule = create_optimizer(init_lr=learning_rate, num_train_steps=num_train_steps, weight_decay_rate=weight_decay_rate, num_warmup_steps=0)# 定义分类器
from transformers import TFAutoModelForImageClassification
model = TFAutoModelForImageClassification.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)# converting our train dataset to tf.data.Dataset
tf_train_dataset = food["train"].to_tf_dataset(columns="pixel_values", label_cols="label", shuffle=True, batch_size=batch_size, collate_fn=data_collator)# converting our test dataset to tf.data.Dataset
tf_eval_dataset = food["test"].to_tf_dataset(columns="pixel_values", label_cols="label", shuffle=False, batch_size=batch_size, collate_fn=data_collator)# 定义损失函数
from tensorflow.keras.losses import SparseCategoricalCrossentropy
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)model.compile(optimizer=optimizer, loss=loss)from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback
# 定义验证指标
metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
# 推送到huggface回调函数
push_to_hub_callback = PushToHubCallback(output_dir="food_classifier", tokenizer=image_processor, save_strategy="no")
callbacks = [metric_callback, push_to_hub_callback]# 开始训练
model.fit(tf_train_dataset, validation_data=tf_eval_dataset, epochs=num_epochs, callbacks=callbacks)

四、模型测试

 这里使用微调好的模型在food101上找一张验证图像进行简单的验证测试,具体代码如下:

# 验证food中验证集的某一张图像
ds = load_dataset("food101", split="validation[-5:-1]")
image = ds["image"][-1]# visualize image
import matplotlib.pyplot as plt
plt.imshow(image)
plt.axis('off')  
plt.show()

  测试图像如下所示:
在这里插入图片描述


from transformers import pipeline
# initialize classifier instance
classifier = pipeline("image-classification", model="my_awesome_food_model")
classifier(image)from transformers import AutoImageProcessor
import torch
# load pre-trained image processor
image_processor = AutoImageProcessor.from_pretrained("my_awesome_food_model")
inputs = image_processor(image, return_tensors="pt")from transformers import AutoModelForImageClassification
# laod pre-trained model
model = AutoModelForImageClassification.from_pretrained("my_awesome_food_model")
with torch.no_grad():logits = model(**inputs).logits# 输出测试结果
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])

  输出结果如下所示:

Device set to use cuda:0
[{'label': 'ramen', 'score': 0.9517934918403625},{'label': 'bruschetta', 'score': 0.7566707730293274},{'label': 'hamburger', 'score': 0.7004948854446411},{'label': 'chicken_wings', 'score': 0.6275856494903564},{'label': 'prime_rib', 'score': 0.5991673469543457}]

  预测结果为:ramen

  释义:“ramen”一词源于日语“ラーメン”,是“拉面”的意思。它进一步追溯至汉语“拉面”,是一种起源于中国、流行于日本及其他东亚地区的面条食品。在日本,拉面通常由小麦面粉制成的面条,搭配肉汤和各种配料,如叉烧、鸡蛋、蔬菜等。

五、参考

[1] https://huggingface.co/docs/transformers/main/tasks/image_classification

[2] https://github.com/huggingface/transformers/blob/main/docs/source/en/installation.md

相关文章:

【Pytorch和Keras】使用transformer库进行图像分类

目录 一、环境准备二、基于Pytorch的预训练模型1、准备数据集2、加载预训练模型3、 使用pytorch进行模型构建 三、基于keras的预训练模型四、模型测试五、参考 现在大多数的模型都会上传到huggface平台进行统一的管理,transformer库能关联到huggface中对应的模型&am…...

快速了解 c++ 异常处理 基础知识

相关代码概览&#xff1a; #include<stdexcept>std::runtime_errorcatch (const std::runtime_error& e) e.what() 相信大家一定见过这些代码&#xff0c;那么这些代码具体什么意思呢&#xff1f;我们一起来看一下 知识精讲&#xff1a; 异常处理是C中非常重要…...

deepseek API 调用-python

【1】创建 API keys 【2】安装openai SDK pip3 install openai 【3】代码&#xff1a; https://download.csdn.net/download/notfindjob/90343352...

玩转Gin框架:Golang使用Gin完成登录流程

文章目录 背景基于Token认证机制简介常见的Token类型Token的生成和验证在项目工程里创建jwt.go文件根目录新建.env文件 创建登录接口 /loginToken认证机制的优点 背景 登录流程&#xff0c;相信大家都很熟悉的。传统网站采用session后端验证登录状态&#xff0c;大致流程如下&…...

Linux学习笔记16---高精度延时实验

延时函数是很常用的 API 函数&#xff0c;在前面的实验中我们使用循环来实现延时函数&#xff0c;但是使用循环来实现的延时函数不准确&#xff0c;误差会很大。虽然使用到延时函数的地方精度要求都不会很严格( 要求严格的话就使用硬件定时器了 ) &#xff0c;但是延时函数肯定…...

vue2:如何动态控制el-form-item之间的行间距

需求 某页面有查看和编辑两种状态: 编辑: 查看: 可以看到,查看时,行间距太大导致页面不紧凑,所以希望缩小查看是的行间距。 行间距设置 行间距通常是通过 CSS 的 margin 或 padding 属性来控制的。在 Element UI 的样式表中,.el-form-item 的下边距(margin-bottom)…...

deepseek从网络拓扑图生成说明文字实例

deepseek对话页面中输入问题指令&#xff1a; 我是安全测评工程师&#xff0c;正在撰写系统测评报告&#xff0c;现在需要对系统网络架构进行详细说明&#xff0c;请根据附件网络拓扑图输出详细说明文字。用总分的段落结构&#xff0c;先介绍各网络区域&#xff0c;再介绍网络…...

两种文件类型(pdf/图片)打印A4半张纸方法

环境:windows10、Adobe Reader XI v11.0.23 Pdf: 1.把内容由横排变为纵排&#xff1a; 2.点击打印按钮&#xff1a; 3.选择打印页范围和多页&#xff1a; 4.内容打印在纸张上部 图片&#xff1a; 1.右键图片点击打印&#xff1a; 2.选择打印类型&#xff1a; 3.打印配置&am…...

HTB:UnderPass[WriteUP]

目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用nmap对靶机UDP开放端口进行脚本、服务扫描 …...

【deepseek实战】绿色好用,不断网

前言 最佳deepseek火热网络&#xff0c;我也开发一款windows的电脑端&#xff0c;接入了deepseek&#xff0c;基本是复刻了网页端&#xff0c;还加入一些特色功能。 助力国内AI&#xff0c;发出自己的热量 说一下开发过程和内容的使用吧。 目录 一、介绍 二、具体工作 1.1、引…...

MySQL 进阶专题:索引(索引原理/操作/优缺点/B+树)

在数据库的秋招面试中&#xff0c;索引&#xff08;Index&#xff09;是一个经典且高频的题目。索引的作用类似于书中的目录&#x1f4d6;&#xff0c;它能够显著加快数据库查询的速度。本文将深入探讨索引的概念、作用、优缺点以及背后的数据结构&#xff0c;帮助你从原理到应…...

用NeuralProphet预测股价:AI金融新利器(附源码)

作者&#xff1a;老余捞鱼 原创不易&#xff0c;转载请标明出处及原作者。 写在前面的话&#xff1a;我用NeuralProphet模型预测了股票价格&#xff0c;发现其通过结合时间序列分析和神经网络算法&#xff0c;确实能提供比传统Last Value方法更精准的预测。经过一系列超参数调优…...

【Elasticsearch】parent aggregation

在Elasticsearch中&#xff0c;Parent Aggregation是一种特殊的单桶聚合&#xff0c;用于选择具有指定类型的父文档&#xff0c;这些类型是通过一个join字段定义的。以下是关于Parent Aggregation的详细介绍&#xff1a; 1.基本概念 Parent Aggregation是一种聚合操作&#x…...

IDEA使用Auto-dev+DeepSeek 10分钟快速集成,让java开发起飞

在当今的软件开发领域,AI 工具的辅助作用愈发凸显,DeepSeek AI 便是其中的佼佼者。它凭借强大的自然语言处理能力和高效的代码生成能力,成为众多开发者的得力助手。而 IntelliJ IDEA 作为一款广受欢迎的集成开发环境(IDE),若能与 DeepSeek AI 无缝集成,无疑将为开发者带…...

ASP.NET Core中间件Markdown转换器

目录 需求 文本编码检测 Markdown→HTML 注意 实现 需求 Markdown是一种文本格式&#xff1b;不被浏览器支持&#xff1b;编写一个在服务器端把Markdown转换为HTML的中间件。我们开发的中间件是构建在ASP.NET Core内置的StaticFiles中间件之上&#xff0c;并且在它之前运…...

使用page assist浏览器插件结合deepseek-r1 7b本地模型

为本地部署的DeepSeek R1 7b模型安装Page Assist&#xff0c;可以按照以下步骤进行&#xff1a; 一、下载并安装Ollama‌ 首先&#xff0c;你需要下载并安装Ollama&#xff0c;这是部署DeepSeek所必需的工具。你可以访问Ollama的官方网站&#xff08;ollama.com&#xff09;下…...

【华为OD-E卷 - 108 最大矩阵和 100分(python、java、c++、js、c)】

【华为OD-E卷 - 最大矩阵和 100分&#xff08;python、java、c、js、c&#xff09;】 题目 给定一个二维整数矩阵&#xff0c;要在这个矩阵中选出一个子矩阵&#xff0c;使得这个子矩阵内所有的数字和尽量大&#xff0c;我们把这个子矩阵称为和最大子矩阵&#xff0c;子矩阵的…...

【Reading Notes】Favorite Articles from 2025

文章目录 1、January2、February3、March4、April5、May6、June7、July8、August9、September10、October11、November12、December 1、January 极越之后&#xff0c;中国车市只会倒下更多人&#xff08;2025年01月01日&#xff09; 在这波枪林弹雨中&#xff0c;合资品牌中最…...

云计算行业分析

云计算作为数字经济的核心基础设施&#xff0c;未来十年将持续重塑全球科技格局&#xff0c;并渗透到几乎所有行业的数字化转型中。 一、云计算的发展潜力 1. 技术融合驱动爆发式创新 AI与云计算的深度耦合 - **智能云服务**&#xff1a;云厂商将提供预训练模型、自动化ML工…...

【Linux系统】线程:线程的优点 / 缺点 / 超线程技术 / 异常 / 用途

1、线程的优点 创建和删除线程代价较小 创建一个新线程的代价要比创建一个新进程小得多&#xff0c;删除代价也小。这种说法主要基于以下几个方面&#xff1a; &#xff08;1&#xff09;资源共享 内存空间&#xff1a;每个进程都有自己独立的内存空间&#xff0c;包括代码段…...

Kook Zimage 真实幻想 Typora文档集成方案

Kook Zimage 真实幻想 Typora文档集成方案 1. 引言 技术文档写作最头疼的是什么&#xff1f;文字描述得再生动&#xff0c;也不如一张直观的图片来得有说服力。传统的文档创作流程中&#xff0c;我们需要先在专门的AI绘图工具中生成图片&#xff0c;然后下载保存&#xff0c;…...

SmallThinker-3B-Preview赋能Java后端:智能客服系统数据库设计

SmallThinker-3B-Preview赋能Java后端&#xff1a;智能客服系统数据库设计 最近在做一个Java后端的智能客服项目&#xff0c;核心是要接入一个轻量级的AI模型——SmallThinker-3B-Preview。模型选好了&#xff0c;代码逻辑也搭得差不多了&#xff0c;但一到数据库设计这块&…...

告别“差不多就行”:用Cascade R-CNN解决目标检测中那些“似对非对”的边界框

从边界框“模糊地带”到工业级精度&#xff1a;Cascade R-CNN实战全解析 当你在自动驾驶系统中看到车辆识别框与真实车身存在5个像素的偏移&#xff0c;或在工业质检场景中某个关键缺陷的检测框刚好漏掉了1毫米的裂纹区域&#xff0c;这些“看似正确实则不准”的预测结果&#…...

【技术解析】SimpleNet:用极简网络架构革新工业图像异常检测

1. 工业图像异常检测的现状与挑战 工业生产线上的质检环节一直是个让人头疼的问题。想象一下&#xff0c;你站在一条每分钟生产上百件产品的流水线旁&#xff0c;需要肉眼检查每个产品表面是否有划痕、凹陷或污渍——这几乎是不可能完成的任务。传统计算机视觉方法在这个领域已…...

黑客为什么不攻击微信钱包?

黑客为什么不攻击微信钱包&#xff1f; 现在人人手机里都装着微信和支付宝&#xff0c;里面都或多或少存了些钱。怎么从来没听说谁的钱被技术牛逼黑客惦记走&#xff1f; 是黑客没攻击过&#xff1f;还是黑客不敢攻击&#xff1f;其实都不是。阿里巴巴首席风险官郑俊芳就说过&…...

保姆级教程:在Ubuntu 22.04上从Anaconda到PyTorch,一步步搞定GPU环境(含CUDA 11.7避坑指南)

保姆级教程&#xff1a;在Ubuntu 22.04上从Anaconda到PyTorch&#xff0c;一步步搞定GPU环境&#xff08;含CUDA 11.7避坑指南&#xff09; 刚接触深度学习的开发者们&#xff0c;最头疼的往往不是模型设计本身&#xff0c;而是环境搭建这个"拦路虎"。本文将手把手带…...

片上网络NOC:可生成RTL源代码与UVM验证环境的实用学习资料

片上网络NOC&#xff0c;可生成RTL源代码&#xff0c;生成uvm验证环境&#xff0c;内含有丰富的文档&#xff0c;带有readme文档&#xff0c;有例子工程&#xff0c;操作简单&#xff0c;是学习工作的好资料最近折腾NoC项目的时候挖到一个宝藏工具包&#xff0c;名字先不透露&a…...

我已战胜一切!感谢哥白尼,感谢爱因斯坦,感谢豆包,,,曾经我都经历过什么,我自己非常清楚,既有爱因斯坦的压缩版,又有哥白尼的压缩版,,,

不是时代不好&#xff0c;是人心中的成见就像一座大山般&#xff0c;无法被逾越&#xff0c;只有暴雨降下&#xff0c;洗刷这个世界&#xff0c;重塑这个宇宙&#xff0c;各位其位&#xff0c;大道至简。历史的车轮早已不可阻挡&#xff0c;&#xff0c;&#xff0c;暴风雨会来…...

新手避坑指南:用Altium Designer打开嘉立创PCB文件,这3个设置不改布线全乱

Altium Designer导入嘉立创PCB文件的三大核心设置解析 刚接触硬件设计的新手工程师们&#xff0c;当你们第一次尝试用Altium Designer打开从嘉立创EDA导出的PCB文件时&#xff0c;是否遇到过这样的场景&#xff1a;板框莫名其妙错位、网络连接全部丢失、设计规则一片混乱&#…...

新手福音:基于预置镜像,在快马平台零配置开启Python Web开发之旅

作为一个刚接触Python Web开发的新手&#xff0c;我最近在InsCode(快马)平台上体验了一把零配置搭建个人博客的过程。不得不说&#xff0c;这种基于预置镜像的开发方式&#xff0c;简直是为我们这些初学者量身定制的福音。下面我就来分享一下这次的学习心得。 为什么选择预置镜…...