以Llama-2为例,在生成模型中使用自定义LogitsProcessor
以Llama-2为例,在生成模型中使用自定义LogitsProcessor
- 1. 前言
- 2. 场景介绍
- 3. 解决方法
- 4. 结语
1. 前言
在上一篇文章 以Llama-2为例,在生成模型中使用自定义StoppingCriteria中,介绍了怎样在生成的过程中,使用stopping criteria
来控制生成过程的结束,本文将继续这一话题,结合具体的场景,介绍如何实现自定义的logits processor
,并以此来控制生成的过程。
2. 场景介绍
场景延续上篇介绍stopping criteria的文章,假如我们希望使用Llama-2模型,来生成一篇新闻的概要,希望它能够生成一句简短的话,来描述这篇新闻中主要发生了什么。
在上一篇文章中,我们成功的使用stopping criteria解决了模型废话太多的问题,然而,在某些情况下,模型输出的结果并不是我们想要的,它没有用一句话概括,反而是一条一条列举了其中的主要信息,类似:
1. ......
2. ......
3. ......
针对这种情况,我们可以强制要求生成的第一个token,不可以是数字,这样的话,就只能从字母中选择合适的单词生成,也就达到我们的目的了。为了实现这一策略,就需要用到logits processor。
3. 解决方法
logits processor
是在生成的过程中,每一个step的score计算完成之后,对score进行进一步的加工,改变模型输出的概率分布,从而影响后续生成的结果。
transformers模块中提供了若干内置的processor可以直接调用,具体的整理和简介可以参考之前的文章以beam search为例,详解transformers中generate方法(上)。
现在我们需要设计这样一个processor,判断如果是第一个生成的第一个token,则禁止它生成数字,也就是把所有数字对应的得分强制设置为负无穷。
首先,引入需要用到的类,与stopping criteria类似的,也是有要给基础类,和一个容器类:
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
然后继承基础类,实现我们所需的processor:
class SuppressSpecificBOSTokenLogitsProcessor(LogitsProcessor):"""防止生成的第一个token是某些特定的token---------------ver: 2023-08-02by: changhongyu"""def __init__(self, bad_bos_token_id_list: List[int] = None):""":param bad_bos_token_id_list: 不可以作为第一个token的token的id列表"""self.bad_bos_token_id_list = bad_bos_token_id_listdef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:new_token_len = input_ids.shape[-1] - current_token_lenif new_token_len == 0:for id_ in self.bad_bos_token_id_list:scores[:, id_] = -float('inf')return scores
logits processor的使用方法与stopping criteria是一样的,我们设计好自己的processor类之后,实例化一个容器,再将实例化的processor放到这个容器中就好了:
NUMBER_ID_LIST = []
for i in range(10):NUMBER_ID_LIST.append(tokenizer.convert_tokens_to_ids(str(i)))
logits_processor = LogitsProcessorList()
logits_processor.append(SuppressSpecificBOSTokenLogitsProcessor(NUMBER_ID_LIST))
如果有多个processor的话,可能需要注意一下放入容器的顺序。
最后在生成的时候,将它作为参数传给generate方法就好了。
例如,原本生成的代码是:
outputs = model.generate(**inputs)
使用processor的话,可以写作:
outputs = model.generate(logits_processor=logits_processor, **inputs)
注意在实现的时候有一个小细节,由于是对话模型,输入的除了当前的query之外,还包括历史的对话记录,二者拼接在一起才是完整的prompt(prompt构建参考这一篇),所以我们并不能仅仅根据当前输入input_ids
的长度,来判断当前step是不是这一轮生成的第一个token,这就是为什么上面的代码中有一个为声明定义的变量current_token_len
。
对于这个current_token_len
,只需要在model.generate执行之前,对他global一下就可以了。
例如像这个样子,每次生成之前先计算一下截至生成之前的长度:
global current_token_len
current_token_len = inputs['input_ids'].shape[1]outputs = model.generate(logits_processor=logits_processor, **inputs)
4. 结语
作为用户控制生成过程的主要手段,如何巧妙地利用好logits processor对使用生成式模型来说非常重要。在实际情况中,需要针对场景,发现其中地规律,然后又针对性地去设计一个processor。它主要解决的问题,是一些有规律可循的场景,从一定意义上理解,可以认为是对生成模型的解空间进行了限制和变换。在解决问题的风格上给人的感觉,有点像抽取式模型所做的风格了,比如对于一个关键词生成任务,如果我们不希望模型生成文章中没有出现过的token,那完全可以利用本文中类似的方法,把生成结果限定为文中出现过的token。
以上就是本文的全部内容,如果对你有所帮助或启发,记得留下一个免费的赞,我们下期再见。
相关文章:
以Llama-2为例,在生成模型中使用自定义LogitsProcessor
以Llama-2为例,在生成模型中使用自定义LogitsProcessor 1. 前言2. 场景介绍3. 解决方法4. 结语 1. 前言 在上一篇文章 以Llama-2为例,在生成模型中使用自定义StoppingCriteria中,介绍了怎样在生成的过程中,使用stopping criteria…...
python 计算图片hash 缓存图片为key
python,有时希望缓存图片作为key,怎么办?缓存整张突破占用内存太多,不妨缓存hash值: Fast way to Hash Numpy objects for Caching import hashlib import numpy a numpy.random.rand(10, 100) b a.view(numpy.uin…...

制造型企业如何实现车间设备生产数据的实时采集?需要5G网络吗?
引言 在制造业数字化转型的浪潮下,实时采集车间设备生产数据变得尤为重要。工业边缘网关HiWoo Box作为一款专为工业应用而设计的智能设备,具备工业级设计和多种联网方式,为制造型企业提供了高性能的车间设备数据实时采集解决方案。本文将重点…...
第2章 HTML中的JavaScript
引言 将JavaScript引入网页,首先要解决它与网页的主导语言HTML的关系问题 script元素 将JavaScript插入HTML的主要方法是使用script元素,script有8个可选属性 async:表示异步加载js文件内容,他们之间的顺序不一定按照html顺序ch…...

景联文科技高质量成品数据集上新啦!
景联文科技近期上新多个成品数据集,包含图像、视频等多种类型的数据,涵盖丰富的场景,可满足不同模型的多元化需求。 高质量成品数据集可用于训练和优化模型,使得模型能够更加全面和精准地理解和处理任务,更好地应对复…...
flask------请求拓展
flask中也有类似与django中的中间件,只不过是另一种写法,但是他们的作用是一样的,下面我们就一一介绍: 1.before_request 作用 : before_request 相当于 django 中的 process_request,每一个请求在被处理前都会经…...
大数据-玩转数据-FLINK-从kafka消费数据
一、基于前面kafka部署 大数据-玩转数据-Kafka安装 二、FLINK中编写代码 package com.lyh.flink04;import org.apache.flink.api.common.serialization.SimpleStringSchema; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apa…...

介绍Sping Boot的5个扩展点
1、初始化器ApplicationContextInitializer 我们在启动Spring Boot项目的时候,是执行这样一个方法来启动的 我们一层一层往下点,最终发现执行的是这个方法 所以我们在启动项目的时候也可以这样启动 new SpringApplication(SpringbootExtensionPointAp…...
Linux2.6内核配置说明
maturity level options代码成熟度选项 Prompt for development and/or incomplete code/drivers 显示尚在开发中或尚未完成的代码与驱动.除非你是测试人员或者开发者,否则请勿选择 setup常规设置 Local version - append to kernel release 在内核版本后面加上自定义的…...

Pytest简介及jenkins集成
一、pytest介绍 pytest介绍 - unittest\nose pytest:基于unittest之上的单元测试框架 自动发现测试模块和测试方法 断言使用assert表达式即可 可以设置测试会话级、模块级、类级、函数级的fixtures 数据准备 清理工作 unittest:setUp、teardown、…...

【LeetCode】105. 从前序与中序遍历序列构造二叉树 106. 从中序与后序遍历序列构造二叉树
105. 从前序与中序遍历序列构造二叉树 这道题也是经典的数据结构题了,有时候面试题也会遇到,已知前序与中序的遍历序列,由前序遍历我们可以知道第一个元素就是根节点,而中序遍历的特点就是根节点的左边全部为左子树,右…...

堆内存和一些检测工具
17 堆定义 通过new关键字创建,创建对象都会使用堆内存。 是线程共享的,需要考虑线程安全问题。 有垃圾回收机制。18 堆-内存溢出 当默认情况下,发现执行到26,出现内存溢出。 当我们将堆内存调为8m,继续执行ÿ…...
【JavaScript】元素获取指南
简介 在 JavaScript 中,我们经常需要通过获取元素来进行 DOM 操作和交互。本教程将介绍多种获取元素的方式,包括根据 ID、标签名、类名、选择器、属性和名称等。 通过ID获取元素 使用getElementById方法根据元素的ID属性获取单个元素。 var element = document.getElementB…...

uniapp 返回上一页并刷新
如要刷新的是mine页面 在/pages/mine/improveInfo页面修改信息,点击保存后跳转到个人中心(/pages/mine/index)页面并刷新更新数据 点击保存按钮时执行以下代码: wx.switchTab({url: /pages/mine/index }) // 页面重载 let pages …...

Java阶段五Day21
Java阶段五Day21 文章目录 Java阶段五Day21问题解析rocketmq清空数据 linux学习背景什么是linux系统虚拟机介绍启动 虚拟机linux虚拟机网络的问题 linux系统的基础命令命令提示符命令格式pwd指令ls指令cd指令mkdirtouch指令cp指令rm指令mv指令cat指令tail指令 文本编辑器vim操作…...

2023,谁在引领实时互动进入高清时代?
实践是检验真理的唯一标准,技术是行业进步的核心动能。在实时互动的新时代里,不断进化的声网已然完成自证。 作者|斗斗 出品|产业家 “一个医疗行业的客户,曾向我们提出一个需求,希望在120急救场景下,可以远程看清…...

STM32(HAL)串口中断接收
目录 1、简介 2 基础配置 2.1.1 SYS配置 2.1.2 RCC配置 2.2 串口外设配置 2.3 项目生成 3、KEIL端程序整合 1、简介 本文对HAL串口中断函数进行介绍。 2 基础配置 2.1.1 SYS配置 2.1.2 RCC配置 2.2 串口外设配置 2.3 项目生成 3、KEIL端程序整合 首先在main.c文件中进行…...

word转pdf怎么转?几种常用方法分享
word转pdf怎么转?在日常工作和学习中,将Word文档转换为PDF格式是一项必要的任务。不仅可以保证文档的格式不变,还可以防止文档被他人篡改。但是,Word文档并不是所有人都能够轻松打开和编辑的,而PDF文件则可以在各种设备…...

自学(黑客)技术,入门到入狱!
1.网络安全是什么 网络安全可以基于攻击和防御视角来分类,我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术,而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 2.网络安全市场 一、是市场需求量高; 二、则是发展相对成熟入…...
js 函数、闭包及函数对象
js的函数是对象,可以通过程序来操控。比如,可以把函数赋值给变量,然后再传递给其他函数,也可以在函数上设置属性,甚至调用函数的方法。 js函数可以嵌套定义在其他函数里,内嵌函数可以访问定义在函数作用域…...
[特殊字符] 智能合约中的数据是如何在区块链中保持一致的?
🧠 智能合约中的数据是如何在区块链中保持一致的? 为什么所有区块链节点都能得出相同结果?合约调用这么复杂,状态真能保持一致吗?本篇带你从底层视角理解“状态一致性”的真相。 一、智能合约的数据存储在哪里…...

C++实现分布式网络通信框架RPC(3)--rpc调用端
目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中,我们已经大致实现了rpc服务端的各项功能代…...
Java 8 Stream API 入门到实践详解
一、告别 for 循环! 传统痛点: Java 8 之前,集合操作离不开冗长的 for 循环和匿名类。例如,过滤列表中的偶数: List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...
电脑插入多块移动硬盘后经常出现卡顿和蓝屏
当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时,可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案: 1. 检查电源供电问题 问题原因:多块移动硬盘同时运行可能导致USB接口供电不足&#x…...

深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南
🚀 C extern 关键字深度解析:跨文件编程的终极指南 📅 更新时间:2025年6月5日 🏷️ 标签:C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言🔥一、extern 是什么?&…...

初学 pytest 记录
安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...
iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈
在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...
Redis:现代应用开发的高效内存数据存储利器
一、Redis的起源与发展 Redis最初由意大利程序员Salvatore Sanfilippo在2009年开发,其初衷是为了满足他自己的一个项目需求,即需要一个高性能的键值存储系统来解决传统数据库在高并发场景下的性能瓶颈。随着项目的开源,Redis凭借其简单易用、…...
深度解析:etcd 在 Milvus 向量数据库中的关键作用
目录 🚀 深度解析:etcd 在 Milvus 向量数据库中的关键作用 💡 什么是 etcd? 🧠 Milvus 架构简介 📦 etcd 在 Milvus 中的核心作用 🔧 实际工作流程示意 ⚠️ 如果 etcd 出现问题会怎样&am…...

Axure零基础跟我学:展开与收回
亲爱的小伙伴,如有帮助请订阅专栏!跟着老师每课一练,系统学习Axure交互设计课程! Axure产品经理精品视频课https://edu.csdn.net/course/detail/40420 课程主题:Axure菜单展开与收回 课程视频:...