当前位置: 首页 > news >正文

以Llama-2为例,在生成模型中使用自定义StoppingCriteria

以Llama-2为例,在生成模型中使用自定义StoppingCriteria

  • 1. 前言
  • 2. 场景介绍
  • 3. 解决方法
  • 4. 结语

1. 前言

在之前的文章中,介绍了使用transformers模块创建的模型,其generate方法的详细原理和使用方法,文章链接:

以beam search为例,详解transformers中generate方法(上)
以beam search为例,详解transformers中generate方法(下)

其中提到了用户参与生成过程的两个关键组件,logits_processorstopping_criteria,使用这两个类,是用户控制生成过程的主要手段。其中,logits_processor用来在生成过程中,根据用户设置的指定规则,强行修改当前step在词表空间上的概率分布,而stopping_criteria,根据用户所规定的规则来中止生成。

这两个组件在transformers模块中都有一些预设的类可以直接使用,预设类的基本信息介绍可参考以beam search为例,详解transformers中generate方法(上)。

本文将结合实际应用场景,介绍用户如何根据自己的需求来设计并实现一个自定义的stopping_criteria,来控制生成过程提前结束。

2. 场景介绍

这次介绍的场景是,使用Llama-2的生成能力对一段新闻进行概括,希望能够生成一句简短的话,来概括新闻中发生的最核心的事情。

通过给定对话背景,结合历史样例的方式,希望Llama-2能够输出期望的结果。

对话的prompt构造方法可以参考之前的内容:NLP实践——Llama-2 多轮对话prompt构建。

然而,即便是采用了in-context learning的方式,Llama-2生成的结果仍然过于冗长。

例如对于这样一篇新闻:

text = """, Photo Credit : Associated Press Four air crew members were missing after an Australian army helicopter ditched into waters off the Queensland state coast during joint military exercises with the United States, officials said Saturday. The MRH-90 Taipan helicopter went down near Lindeman Island, a Great Barrier... ..."""  
# 后边忽略若干内容

模型生成的结果为:

Four Australian army air crew members are missing after an Australian army MRH-90 Taipan helicopter ditched into waters off the Queensland state coast during joint military exercises with the United States. The helicopter went down near Lindeman Island, a Great Barrier Reef tourist resort, at around 11 pm on Friday. A search involving US, Canadian, and Australian personnel is underway to find the missing crew, who are all Australian men. Debris that appeared to be from a helicopter has been recovered, according to Queensland Police Assistant Commissioner Douglas McDonald. The Taipan was taking part in Talisman Sabre, a biennial joint US-Australian military exercise that is largely based in Queensland. This year's exercise involves 13 nations and over 30,000 military personnel. Defense Minister Richard Marles said the helicopter ditched, which refers to an emergency landing on water. He added that defense exercises, which are so necessary for the readiness of our defense force, are serious and carry risk. US Defense Secretary... ...
# 后边忽略若干内容

可以看出,并不是模型生成的结果不好,但是它太啰嗦了,而对于我的需求而言,模型只需要输出其中的第一句话就足够了。

这时候可能有人就会觉得:“那我分句然后把第一句话保留下来不就好了?”

——这样做虽然也可以达成效果,但是这个生成过程,时间和算力已经被消耗了。

所以需要采取方法,让模型在生成到第一个句号的时候,就停止生成,返回结果。于是就需要用到今天的主角——Stopping Criteria。

3. 解决方法

transformers模块中内置了几个默认的stopping criteria,然而,在很多情况下,它们并不能满足需求,这时,就需要创建自定义的stopping criteria。

首先需要引用基类:

from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings

其中,

  • StoppingCriteriaList是一个容器,需要将所有的criteria都添加到其中,generate时传入的是这个容器;
  • StoppingCriteria是基础类,自定义的criteria需要继承这个基础类。

接下来就实现一个criteria,效果是,遇到指定的token时,就停止生成:

class StopAtSpecificTokenCriteria(StoppingCriteria):"""当生成出第一个指定token时,立即停止生成---------------ver: 2023-08-02by: changhongyu"""def __init__(self, token_id_list: List[int] = None):""":param token_id_list: 停止生成的指定token的id的列表"""self.token_id_list = token_id_list@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:# return np.argmax(scores[-1].detach().cpu().numpy()) in self.token_id_list# 储存scores会额外占用资源,所以直接用input_ids进行判断return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list

那么,如果希望遇到句号就停止生成,那就用句号对应的token_id去实例化一个这样的stopping criteria,并将它添加到容器中:

# Llama-2的词表中,英文句号的id是29889
stopping_criteria = StoppingCriteriaList()
stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list=[29889]))

然后,在生成的时候,假如原本的生成指令是:

model.generate(**inputs)

那么再把stopping criteria作为参数传入进去,就可以发挥效果了:

model.generate(stopping_criteria=stopping_criteria, **inputs)

4. 结语

Stopping Criteria用于在每一个step的生成结束时,判断生成过程是否要结束,是用户控制生成过程的有效手段,其发挥作用的方式也比较直接,实现自定义criteria也并不复杂,只需要确保该类的调用方法返回值是bool值,并覆盖全部情况即可。

Logits Processor是用户控制生成的另一个有效工具,在接下来的博客中,还将介绍自定义logits processor是如何使用的,欢迎感兴趣的同学继续关注。

相关文章:

以Llama-2为例,在生成模型中使用自定义StoppingCriteria

以Llama-2为例,在生成模型中使用自定义StoppingCriteria 1. 前言2. 场景介绍3. 解决方法4. 结语 1. 前言 在之前的文章中,介绍了使用transformers模块创建的模型,其generate方法的详细原理和使用方法,文章链接: 以be…...

servlet接受参数和乱码问题

servlet接受参数和乱码问题 1、乱码问题 1)get请求 传输参数出现中文乱码问题: 如果还存在问题: 2)post请求 传输参数出现中文乱码问题: 2、接受参数: 3、登录注册案例...

2023-08-05力扣今日三题

链接&#xff1a; 剑指 Offer 22. 链表中倒数第k个节点 题意&#xff1a; 如题 解&#xff1a; 快慢指针 实际代码&#xff1a; #include<iostream> using namespace std; struct ListNode {int val;ListNode *next;ListNode(int x) : val(x), next(NULL) {} }; L…...

webpack图片压缩

减少代码体积 | 尚硅谷 Web 前端之 Webpack5 教程 (yk2012.github.io) npm install image-mininizer webpack plugin imagemin -D 无损压缩 npm install imagemin-gifsicle imagemin-jpegtran imagemin-optipng imagemin-svgo -D 有损压缩 npm install imagemin-gifsicle image…...

JPA使用nativeQuery自定义SQL怎么插入一个对象参数呢?

0、我们在前后端传递数据时候&#xff0c;参数多的情况下&#xff0c;常常将这些参数封装成对象&#xff1b;当有些场景你需要使用JPA nativeQuery自定义SQL&#xff0c;要将这个对象insert时候&#xff0c;初学者似乎有点犯难&#xff0c;jpa不是spring-data项目的内容吗&…...

用C语言构建一个数字识别卷积神经网络

卷积神经网络的具体原理和对应的python例子参见末尾的参考资料2.3. 这里仅叙述卷积神经网络的配置, 其余部分不做赘述&#xff0c;构建和训练神经网络的具体步骤请参见上一篇: 用C语言构建一个手写数字识别神经网路 卷积网络同样采用简单的三层结构&#xff0c;包括输入层con…...

【CSS】圆形放大的hover效果

效果 index.html <!DOCTYPE html> <html><head><title> Document </title><link type"text/css" rel"styleSheet" href"index.css" /></head><body><div class"avatar"></…...

work weekly

每周汇报&#xff1a;围绕着项目范围及需求内容完成情况多少、人力资源情况、整体进度情况、成本情况、【范围】多少工作、【资源】投入多少人、【时间】花费多少时间、【成本】花了多少钱 【质量】一般没有特别要求的默认软件开发过程规范要求响应时间 【沟通】这里不说了 …...

Mac端口扫描工具

端口扫描工具 Mac内置了一个网络工具 网络使用工具 按住 Command 空格 然后搜索 “网络实用工具” 或 “Network Utility” 即可 域名/ip转换Lookup ping功能 端口扫描 https://zhhll.icu/2022/Mac/端口扫描工具/ 本文由 mdnice 多平台发布...

如何隐藏开源流媒体EasyPlayer.js视频H.265播放器的实时录像按钮?

目前我们TSINGSEE青犀视频所有的视频监控平台&#xff0c;集成的都是EasyPlayer.js版播放器&#xff0c;它属于一款高效、精炼、稳定且免费的流媒体播放器&#xff0c;可支持多种流媒体协议播放&#xff0c;包括WebSocket-FLV、HTTP-FLV&#xff0c;HLS&#xff08;m3u8&#x…...

Spring Cloud Eureka 和 zookeeper 的区别

CAP理论 在了解eureka和zookeeper区别之前&#xff0c;我们先来了解一下这个知识&#xff0c;cap理论。 1998年的加州大学的计算机科学家 Eric Brewer 提出&#xff0c;分布式有三个指标。Consistency&#xff0c;Availability&#xff0c;Partition tolerance。简称即为CAP。…...

Golang之路---04 并发编程——信道/通道

信道/通道 如果说 goroutine 是 Go语言程序的并发体的话&#xff0c;那么 channel&#xff08;信道&#xff09; 就是 它们之间的通信机制。channel&#xff0c;是一个可以让一个 goroutine 与另一个 goroutine 传输信息的通道&#xff0c;我把他叫做信道&#xff0c;也有人将…...

【Rust 基础篇】Rust派生宏:自动实现trait的魔法

导言 Rust是一门现代的、安全的系统级编程语言&#xff0c;它提供了丰富的元编程特性&#xff0c;其中派生宏&#xff08;Derive Macros&#xff09;是其中之一。派生宏允许开发者自定义类型上的trait实现&#xff0c;从而在编译期间自动实现trait。在本篇博客中&#xff0c;我…...

PHP8的程序结构-PHP8知识详解

在做任何事情之前&#xff0c;都需要遵循一定的规则。在PHP8中&#xff0c;程序能够安照人们的意愿执行程序&#xff0c;主要依靠程序的流程控制语句。 不管多复杂的程序&#xff0c;都是由这些基本的语句组成的。语句是构造程序的基本单位。程序执行的过程就是执行程序语句的…...

Spring Cloud +UniApp 智慧工地云平台源码,智能监控和AI分析系统,危大工程管理、视频监控管理、项目人员管理、绿色施工管理

一套智慧工地云平台源码&#xff0c;PC管理端APP端平板端可视化数据大屏端源码 智慧工地可视化系统利用物联网、人工智能、云计算、大数据、移动互联网等新一代信息技术&#xff0c;通过工地中台、三维建模服务、视频AI分析服务等技术支撑&#xff0c;实现智慧工地高精度动态仿…...

“科创中国”青百会轮值主席吴甜:以大语言模型为代表的AI将引发产业变革

8月1日&#xff0c;“科创中国”青年百人会&#xff08;后文简称青百会&#xff09;联合百度举办“青创汇”高端对话&#xff0c;围绕人工智能技术创新与产业发展交流研讨&#xff0c;同时正式成立“科创中国”青年百人会女性工作委员会。该委员会将鼓励更多女性投身科技创新事…...

【Git /Github】知识学习

1.新手入门视频Github 新手够用指南 | 全程演示&个人找项目技巧放送_哔哩哔哩_bilibili 找开源项目的一些途径 • https://github.com/trending/ 指定一些语言显示出star数较高的项目 • https://github.com/521xueweihan/HelloGitHub 定期分析各种项目 • https://g…...

【雕爷学编程】Arduino动手做(181)---Maixduino AI开发板2

37款传感器与执行器的提法&#xff0c;在网络上广泛流传&#xff0c;其实Arduino能够兼容的传感器模块肯定是不止这37种的。鉴于本人手头积累了一些传感器和执行器模块&#xff0c;依照实践出真知&#xff08;一定要动手做&#xff09;的理念&#xff0c;以学习和交流为目的&am…...

PHP 编译问题PEAR package PHP_Archive not installed的解决

php 的编译时需要依赖pear package &#xff0c;目前的问题错误"PEAR package PHP_Archive not installed"&#xff0c;已经明显报出这个问题。 因此编译使用参数 --without-pear 将pear 屏蔽掉编译安装后&#xff0c;再进行安装&#xff1b;同时因为phar 属于pear…...

【探索Linux】—— 步步学习强大的命令行工具 P.1(Linux简介)

目录 前言 一、Linux简介 二、linux的不同发行版本 三、Linux的开源性质 四、Linux的特点 五、Linux代码演示&#xff08;仅供参考&#xff09; 总结 前言 前面我们讲了C语言的基础知识&#xff0c;也了解了一些数据结构&#xff0c;并且讲了有关C的一些知识&#xff…...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建&#xff08;全平台详解&#xff09; 在开始使用 React Native 开发移动应用之前&#xff0c;正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南&#xff0c;涵盖 macOS 和 Windows 平台的配置步骤&#xff0c;如何在 Android 和 iOS…...

中南大学无人机智能体的全面评估!BEDI:用于评估无人机上具身智能体的综合性基准测试

作者&#xff1a;Mingning Guo, Mengwei Wu, Jiarun He, Shaoxian Li, Haifeng Li, Chao Tao单位&#xff1a;中南大学地球科学与信息物理学院论文标题&#xff1a;BEDI: A Comprehensive Benchmark for Evaluating Embodied Agents on UAVs论文链接&#xff1a;https://arxiv.…...

连锁超市冷库节能解决方案:如何实现超市降本增效

在连锁超市冷库运营中&#xff0c;高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术&#xff0c;实现年省电费15%-60%&#xff0c;且不改动原有装备、安装快捷、…...

MVC 数据库

MVC 数据库 引言 在软件开发领域,Model-View-Controller(MVC)是一种流行的软件架构模式,它将应用程序分为三个核心组件:模型(Model)、视图(View)和控制器(Controller)。这种模式有助于提高代码的可维护性和可扩展性。本文将深入探讨MVC架构与数据库之间的关系,以…...

工业自动化时代的精准装配革新:迁移科技3D视觉系统如何重塑机器人定位装配

AI3D视觉的工业赋能者 迁移科技成立于2017年&#xff0c;作为行业领先的3D工业相机及视觉系统供应商&#xff0c;累计完成数亿元融资。其核心技术覆盖硬件设计、算法优化及软件集成&#xff0c;通过稳定、易用、高回报的AI3D视觉系统&#xff0c;为汽车、新能源、金属制造等行…...

安宝特案例丨Vuzix AR智能眼镜集成专业软件,助力卢森堡医院药房转型,赢得辉瑞创新奖

在Vuzix M400 AR智能眼镜的助力下&#xff0c;卢森堡罗伯特舒曼医院&#xff08;the Robert Schuman Hospitals, HRS&#xff09;凭借在无菌制剂生产流程中引入增强现实技术&#xff08;AR&#xff09;创新项目&#xff0c;荣获了2024年6月7日由卢森堡医院药剂师协会&#xff0…...

安全突围:重塑内生安全体系:齐向东在2025年BCS大会的演讲

文章目录 前言第一部分&#xff1a;体系力量是突围之钥第一重困境是体系思想落地不畅。第二重困境是大小体系融合瓶颈。第三重困境是“小体系”运营梗阻。 第二部分&#xff1a;体系矛盾是突围之障一是数据孤岛的障碍。二是投入不足的障碍。三是新旧兼容难的障碍。 第三部分&am…...

STM32HAL库USART源代码解析及应用

STM32HAL库USART源代码解析 前言STM32CubeIDE配置串口USART和UART的选择使用模式参数设置GPIO配置DMA配置中断配置硬件流控制使能生成代码解析和使用方法串口初始化__UART_HandleTypeDef结构体浅析HAL库代码实际使用方法使用轮询方式发送使用轮询方式接收使用中断方式发送使用中…...

CSS | transition 和 transform的用处和区别

省流总结&#xff1a; transform用于变换/变形&#xff0c;transition是动画控制器 transform 用来对元素进行变形&#xff0c;常见的操作如下&#xff0c;它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...

wpf在image控件上快速显示内存图像

wpf在image控件上快速显示内存图像https://www.cnblogs.com/haodafeng/p/10431387.html 如果你在寻找能够快速在image控件刷新大图像&#xff08;比如分辨率3000*3000的图像&#xff09;的办法&#xff0c;尤其是想把内存中的裸数据&#xff08;只有图像的数据&#xff0c;不包…...