以Llama-2为例,在生成模型中使用自定义LogitsProcessor
以Llama-2为例,在生成模型中使用自定义LogitsProcessor
- 1. 前言
- 2. 场景介绍
- 3. 解决方法
- 4. 结语
1. 前言
在上一篇文章 以Llama-2为例,在生成模型中使用自定义StoppingCriteria中,介绍了怎样在生成的过程中,使用stopping criteria来控制生成过程的结束,本文将继续这一话题,结合具体的场景,介绍如何实现自定义的logits processor,并以此来控制生成的过程。
2. 场景介绍
场景延续上篇介绍stopping criteria的文章,假如我们希望使用Llama-2模型,来生成一篇新闻的概要,希望它能够生成一句简短的话,来描述这篇新闻中主要发生了什么。
在上一篇文章中,我们成功的使用stopping criteria解决了模型废话太多的问题,然而,在某些情况下,模型输出的结果并不是我们想要的,它没有用一句话概括,反而是一条一条列举了其中的主要信息,类似:
1. ......
2. ......
3. ......
针对这种情况,我们可以强制要求生成的第一个token,不可以是数字,这样的话,就只能从字母中选择合适的单词生成,也就达到我们的目的了。为了实现这一策略,就需要用到logits processor。
3. 解决方法
logits processor是在生成的过程中,每一个step的score计算完成之后,对score进行进一步的加工,改变模型输出的概率分布,从而影响后续生成的结果。
transformers模块中提供了若干内置的processor可以直接调用,具体的整理和简介可以参考之前的文章以beam search为例,详解transformers中generate方法(上)。
现在我们需要设计这样一个processor,判断如果是第一个生成的第一个token,则禁止它生成数字,也就是把所有数字对应的得分强制设置为负无穷。
首先,引入需要用到的类,与stopping criteria类似的,也是有要给基础类,和一个容器类:
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
然后继承基础类,实现我们所需的processor:
class SuppressSpecificBOSTokenLogitsProcessor(LogitsProcessor):"""防止生成的第一个token是某些特定的token---------------ver: 2023-08-02by: changhongyu"""def __init__(self, bad_bos_token_id_list: List[int] = None):""":param bad_bos_token_id_list: 不可以作为第一个token的token的id列表"""self.bad_bos_token_id_list = bad_bos_token_id_listdef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:new_token_len = input_ids.shape[-1] - current_token_lenif new_token_len == 0:for id_ in self.bad_bos_token_id_list:scores[:, id_] = -float('inf')return scores
logits processor的使用方法与stopping criteria是一样的,我们设计好自己的processor类之后,实例化一个容器,再将实例化的processor放到这个容器中就好了:
NUMBER_ID_LIST = []
for i in range(10):NUMBER_ID_LIST.append(tokenizer.convert_tokens_to_ids(str(i)))
logits_processor = LogitsProcessorList()
logits_processor.append(SuppressSpecificBOSTokenLogitsProcessor(NUMBER_ID_LIST))
如果有多个processor的话,可能需要注意一下放入容器的顺序。
最后在生成的时候,将它作为参数传给generate方法就好了。
例如,原本生成的代码是:
outputs = model.generate(**inputs)
使用processor的话,可以写作:
outputs = model.generate(logits_processor=logits_processor, **inputs)
注意在实现的时候有一个小细节,由于是对话模型,输入的除了当前的query之外,还包括历史的对话记录,二者拼接在一起才是完整的prompt(prompt构建参考这一篇),所以我们并不能仅仅根据当前输入input_ids的长度,来判断当前step是不是这一轮生成的第一个token,这就是为什么上面的代码中有一个为声明定义的变量current_token_len。
对于这个current_token_len,只需要在model.generate执行之前,对他global一下就可以了。
例如像这个样子,每次生成之前先计算一下截至生成之前的长度:
global current_token_len
current_token_len = inputs['input_ids'].shape[1]outputs = model.generate(logits_processor=logits_processor, **inputs)
4. 结语
作为用户控制生成过程的主要手段,如何巧妙地利用好logits processor对使用生成式模型来说非常重要。在实际情况中,需要针对场景,发现其中地规律,然后又针对性地去设计一个processor。它主要解决的问题,是一些有规律可循的场景,从一定意义上理解,可以认为是对生成模型的解空间进行了限制和变换。在解决问题的风格上给人的感觉,有点像抽取式模型所做的风格了,比如对于一个关键词生成任务,如果我们不希望模型生成文章中没有出现过的token,那完全可以利用本文中类似的方法,把生成结果限定为文中出现过的token。
以上就是本文的全部内容,如果对你有所帮助或启发,记得留下一个免费的赞,我们下期再见。
相关文章:
以Llama-2为例,在生成模型中使用自定义LogitsProcessor
以Llama-2为例,在生成模型中使用自定义LogitsProcessor 1. 前言2. 场景介绍3. 解决方法4. 结语 1. 前言 在上一篇文章 以Llama-2为例,在生成模型中使用自定义StoppingCriteria中,介绍了怎样在生成的过程中,使用stopping criteria…...
python 计算图片hash 缓存图片为key
python,有时希望缓存图片作为key,怎么办?缓存整张突破占用内存太多,不妨缓存hash值: Fast way to Hash Numpy objects for Caching import hashlib import numpy a numpy.random.rand(10, 100) b a.view(numpy.uin…...
制造型企业如何实现车间设备生产数据的实时采集?需要5G网络吗?
引言 在制造业数字化转型的浪潮下,实时采集车间设备生产数据变得尤为重要。工业边缘网关HiWoo Box作为一款专为工业应用而设计的智能设备,具备工业级设计和多种联网方式,为制造型企业提供了高性能的车间设备数据实时采集解决方案。本文将重点…...
第2章 HTML中的JavaScript
引言 将JavaScript引入网页,首先要解决它与网页的主导语言HTML的关系问题 script元素 将JavaScript插入HTML的主要方法是使用script元素,script有8个可选属性 async:表示异步加载js文件内容,他们之间的顺序不一定按照html顺序ch…...
景联文科技高质量成品数据集上新啦!
景联文科技近期上新多个成品数据集,包含图像、视频等多种类型的数据,涵盖丰富的场景,可满足不同模型的多元化需求。 高质量成品数据集可用于训练和优化模型,使得模型能够更加全面和精准地理解和处理任务,更好地应对复…...
flask------请求拓展
flask中也有类似与django中的中间件,只不过是另一种写法,但是他们的作用是一样的,下面我们就一一介绍: 1.before_request 作用 : before_request 相当于 django 中的 process_request,每一个请求在被处理前都会经…...
大数据-玩转数据-FLINK-从kafka消费数据
一、基于前面kafka部署 大数据-玩转数据-Kafka安装 二、FLINK中编写代码 package com.lyh.flink04;import org.apache.flink.api.common.serialization.SimpleStringSchema; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apa…...
介绍Sping Boot的5个扩展点
1、初始化器ApplicationContextInitializer 我们在启动Spring Boot项目的时候,是执行这样一个方法来启动的 我们一层一层往下点,最终发现执行的是这个方法 所以我们在启动项目的时候也可以这样启动 new SpringApplication(SpringbootExtensionPointAp…...
Linux2.6内核配置说明
maturity level options代码成熟度选项 Prompt for development and/or incomplete code/drivers 显示尚在开发中或尚未完成的代码与驱动.除非你是测试人员或者开发者,否则请勿选择 setup常规设置 Local version - append to kernel release 在内核版本后面加上自定义的…...
Pytest简介及jenkins集成
一、pytest介绍 pytest介绍 - unittest\nose pytest:基于unittest之上的单元测试框架 自动发现测试模块和测试方法 断言使用assert表达式即可 可以设置测试会话级、模块级、类级、函数级的fixtures 数据准备 清理工作 unittest:setUp、teardown、…...
【LeetCode】105. 从前序与中序遍历序列构造二叉树 106. 从中序与后序遍历序列构造二叉树
105. 从前序与中序遍历序列构造二叉树 这道题也是经典的数据结构题了,有时候面试题也会遇到,已知前序与中序的遍历序列,由前序遍历我们可以知道第一个元素就是根节点,而中序遍历的特点就是根节点的左边全部为左子树,右…...
堆内存和一些检测工具
17 堆定义 通过new关键字创建,创建对象都会使用堆内存。 是线程共享的,需要考虑线程安全问题。 有垃圾回收机制。18 堆-内存溢出 当默认情况下,发现执行到26,出现内存溢出。 当我们将堆内存调为8m,继续执行ÿ…...
【JavaScript】元素获取指南
简介 在 JavaScript 中,我们经常需要通过获取元素来进行 DOM 操作和交互。本教程将介绍多种获取元素的方式,包括根据 ID、标签名、类名、选择器、属性和名称等。 通过ID获取元素 使用getElementById方法根据元素的ID属性获取单个元素。 var element = document.getElementB…...
uniapp 返回上一页并刷新
如要刷新的是mine页面 在/pages/mine/improveInfo页面修改信息,点击保存后跳转到个人中心(/pages/mine/index)页面并刷新更新数据 点击保存按钮时执行以下代码: wx.switchTab({url: /pages/mine/index }) // 页面重载 let pages …...
Java阶段五Day21
Java阶段五Day21 文章目录 Java阶段五Day21问题解析rocketmq清空数据 linux学习背景什么是linux系统虚拟机介绍启动 虚拟机linux虚拟机网络的问题 linux系统的基础命令命令提示符命令格式pwd指令ls指令cd指令mkdirtouch指令cp指令rm指令mv指令cat指令tail指令 文本编辑器vim操作…...
2023,谁在引领实时互动进入高清时代?
实践是检验真理的唯一标准,技术是行业进步的核心动能。在实时互动的新时代里,不断进化的声网已然完成自证。 作者|斗斗 出品|产业家 “一个医疗行业的客户,曾向我们提出一个需求,希望在120急救场景下,可以远程看清…...
STM32(HAL)串口中断接收
目录 1、简介 2 基础配置 2.1.1 SYS配置 2.1.2 RCC配置 2.2 串口外设配置 2.3 项目生成 3、KEIL端程序整合 1、简介 本文对HAL串口中断函数进行介绍。 2 基础配置 2.1.1 SYS配置 2.1.2 RCC配置 2.2 串口外设配置 2.3 项目生成 3、KEIL端程序整合 首先在main.c文件中进行…...
word转pdf怎么转?几种常用方法分享
word转pdf怎么转?在日常工作和学习中,将Word文档转换为PDF格式是一项必要的任务。不仅可以保证文档的格式不变,还可以防止文档被他人篡改。但是,Word文档并不是所有人都能够轻松打开和编辑的,而PDF文件则可以在各种设备…...
自学(黑客)技术,入门到入狱!
1.网络安全是什么 网络安全可以基于攻击和防御视角来分类,我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术,而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 2.网络安全市场 一、是市场需求量高; 二、则是发展相对成熟入…...
js 函数、闭包及函数对象
js的函数是对象,可以通过程序来操控。比如,可以把函数赋值给变量,然后再传递给其他函数,也可以在函数上设置属性,甚至调用函数的方法。 js函数可以嵌套定义在其他函数里,内嵌函数可以访问定义在函数作用域…...
XML Group端口详解
在XML数据映射过程中,经常需要对数据进行分组聚合操作。例如,当处理包含多个物料明细的XML文件时,可能需要将相同物料号的明细归为一组,或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码,增加了开…...
地震勘探——干扰波识别、井中地震时距曲线特点
目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波:可以用来解决所提出的地质任务的波;干扰波:所有妨碍辨认、追踪有效波的其他波。 地震勘探中,有效波和干扰波是相对的。例如,在反射波…...
云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地
借阿里云中企出海大会的东风,以**「云启出海,智联未来|打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办,现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...
Java面试专项一-准备篇
一、企业简历筛选规则 一般企业的简历筛选流程:首先由HR先筛选一部分简历后,在将简历给到对应的项目负责人后再进行下一步的操作。 HR如何筛选简历 例如:Boss直聘(招聘方平台) 直接按照条件进行筛选 例如:…...
是否存在路径(FIFOBB算法)
题目描述 一个具有 n 个顶点e条边的无向图,该图顶点的编号依次为0到n-1且不存在顶点与自身相连的边。请使用FIFOBB算法编写程序,确定是否存在从顶点 source到顶点 destination的路径。 输入 第一行两个整数,分别表示n 和 e 的值(1…...
算法:模拟
1.替换所有的问号 1576. 替换所有的问号 - 力扣(LeetCode) 遍历字符串:通过外层循环逐一检查每个字符。遇到 ? 时处理: 内层循环遍历小写字母(a 到 z)。对每个字母检查是否满足: 与…...
redis和redission的区别
Redis 和 Redisson 是两个密切相关但又本质不同的技术,它们扮演着完全不同的角色: Redis: 内存数据库/数据结构存储 本质: 它是一个开源的、高性能的、基于内存的 键值存储数据库。它也可以将数据持久化到磁盘。 核心功能: 提供丰…...
前端开发者常用网站
Can I use网站:一个查询网页技术兼容性的网站 一个查询网页技术兼容性的网站Can I use:Can I use... Support tables for HTML5, CSS3, etc (查询浏览器对HTML5的支持情况) 权威网站:MDN JavaScript权威网站:JavaScript | MDN...
数据库——redis
一、Redis 介绍 1. 概述 Redis(Remote Dictionary Server)是一个开源的、高性能的内存键值数据库系统,具有以下核心特点: 内存存储架构:数据主要存储在内存中,提供微秒级的读写响应 多数据结构支持&…...
数据分析六部曲?
引言 上一章我们说到了数据分析六部曲,何谓六部曲呢? 其实啊,数据分析没那么难,只要掌握了下面这六个步骤,也就是数据分析六部曲,就算你是个啥都不懂的小白,也能慢慢上手做数据分析啦。 第一…...
