当前位置: 首页 > news >正文

以Llama-2为例,在生成模型中使用自定义LogitsProcessor

以Llama-2为例,在生成模型中使用自定义LogitsProcessor

  • 1. 前言
  • 2. 场景介绍
  • 3. 解决方法
  • 4. 结语

1. 前言

在上一篇文章 以Llama-2为例,在生成模型中使用自定义StoppingCriteria中,介绍了怎样在生成的过程中,使用stopping criteria来控制生成过程的结束,本文将继续这一话题,结合具体的场景,介绍如何实现自定义的logits processor,并以此来控制生成的过程。

2. 场景介绍

场景延续上篇介绍stopping criteria的文章,假如我们希望使用Llama-2模型,来生成一篇新闻的概要,希望它能够生成一句简短的话,来描述这篇新闻中主要发生了什么。

在上一篇文章中,我们成功的使用stopping criteria解决了模型废话太多的问题,然而,在某些情况下,模型输出的结果并不是我们想要的,它没有用一句话概括,反而是一条一条列举了其中的主要信息,类似:

1. ......
2. ......
3. ......

针对这种情况,我们可以强制要求生成的第一个token,不可以是数字,这样的话,就只能从字母中选择合适的单词生成,也就达到我们的目的了。为了实现这一策略,就需要用到logits processor。

3. 解决方法

logits processor是在生成的过程中,每一个step的score计算完成之后,对score进行进一步的加工,改变模型输出的概率分布,从而影响后续生成的结果。

transformers模块中提供了若干内置的processor可以直接调用,具体的整理和简介可以参考之前的文章以beam search为例,详解transformers中generate方法(上)。

现在我们需要设计这样一个processor,判断如果是第一个生成的第一个token,则禁止它生成数字,也就是把所有数字对应的得分强制设置为负无穷。

首先,引入需要用到的类,与stopping criteria类似的,也是有要给基础类,和一个容器类:

from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList

然后继承基础类,实现我们所需的processor:

class SuppressSpecificBOSTokenLogitsProcessor(LogitsProcessor):"""防止生成的第一个token是某些特定的token---------------ver: 2023-08-02by: changhongyu"""def __init__(self, bad_bos_token_id_list: List[int] = None):""":param bad_bos_token_id_list: 不可以作为第一个token的token的id列表"""self.bad_bos_token_id_list = bad_bos_token_id_listdef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:new_token_len = input_ids.shape[-1] - current_token_lenif new_token_len == 0:for id_ in self.bad_bos_token_id_list:scores[:, id_] = -float('inf')return scores

logits processor的使用方法与stopping criteria是一样的,我们设计好自己的processor类之后,实例化一个容器,再将实例化的processor放到这个容器中就好了:

NUMBER_ID_LIST = []
for i in range(10):NUMBER_ID_LIST.append(tokenizer.convert_tokens_to_ids(str(i)))
logits_processor = LogitsProcessorList()
logits_processor.append(SuppressSpecificBOSTokenLogitsProcessor(NUMBER_ID_LIST))

如果有多个processor的话,可能需要注意一下放入容器的顺序。

最后在生成的时候,将它作为参数传给generate方法就好了。

例如,原本生成的代码是:

outputs = model.generate(**inputs)

使用processor的话,可以写作:

outputs = model.generate(logits_processor=logits_processor, **inputs)

注意在实现的时候有一个小细节,由于是对话模型,输入的除了当前的query之外,还包括历史的对话记录,二者拼接在一起才是完整的prompt(prompt构建参考这一篇),所以我们并不能仅仅根据当前输入input_ids的长度,来判断当前step是不是这一轮生成的第一个token,这就是为什么上面的代码中有一个为声明定义的变量current_token_len

对于这个current_token_len,只需要在model.generate执行之前,对他global一下就可以了。

例如像这个样子,每次生成之前先计算一下截至生成之前的长度:

global current_token_len
current_token_len = inputs['input_ids'].shape[1]outputs = model.generate(logits_processor=logits_processor, **inputs)

4. 结语

作为用户控制生成过程的主要手段,如何巧妙地利用好logits processor对使用生成式模型来说非常重要。在实际情况中,需要针对场景,发现其中地规律,然后又针对性地去设计一个processor。它主要解决的问题,是一些有规律可循的场景,从一定意义上理解,可以认为是对生成模型的解空间进行了限制和变换。在解决问题的风格上给人的感觉,有点像抽取式模型所做的风格了,比如对于一个关键词生成任务,如果我们不希望模型生成文章中没有出现过的token,那完全可以利用本文中类似的方法,把生成结果限定为文中出现过的token。

以上就是本文的全部内容,如果对你有所帮助或启发,记得留下一个免费的赞,我们下期再见。

相关文章:

以Llama-2为例,在生成模型中使用自定义LogitsProcessor

以Llama-2为例,在生成模型中使用自定义LogitsProcessor 1. 前言2. 场景介绍3. 解决方法4. 结语 1. 前言 在上一篇文章 以Llama-2为例,在生成模型中使用自定义StoppingCriteria中,介绍了怎样在生成的过程中,使用stopping criteria…...

python 计算图片hash 缓存图片为key

python,有时希望缓存图片作为key,怎么办?缓存整张突破占用内存太多,不妨缓存hash值: Fast way to Hash Numpy objects for Caching import hashlib import numpy a numpy.random.rand(10, 100) b a.view(numpy.uin…...

制造型企业如何实现车间设备生产数据的实时采集?需要5G网络吗?

引言 在制造业数字化转型的浪潮下,实时采集车间设备生产数据变得尤为重要。工业边缘网关HiWoo Box作为一款专为工业应用而设计的智能设备,具备工业级设计和多种联网方式,为制造型企业提供了高性能的车间设备数据实时采集解决方案。本文将重点…...

第2章 HTML中的JavaScript

引言 将JavaScript引入网页,首先要解决它与网页的主导语言HTML的关系问题 script元素 将JavaScript插入HTML的主要方法是使用script元素,script有8个可选属性 async:表示异步加载js文件内容,他们之间的顺序不一定按照html顺序ch…...

景联文科技高质量成品数据集上新啦!

景联文科技近期上新多个成品数据集,包含图像、视频等多种类型的数据,涵盖丰富的场景,可满足不同模型的多元化需求。 高质量成品数据集可用于训练和优化模型,使得模型能够更加全面和精准地理解和处理任务,更好地应对复…...

flask------请求拓展

flask中也有类似与django中的中间件,只不过是另一种写法,但是他们的作用是一样的,下面我们就一一介绍: 1.before_request 作用 : before_request 相当于 django 中的 process_request,每一个请求在被处理前都会经…...

大数据-玩转数据-FLINK-从kafka消费数据

一、基于前面kafka部署 大数据-玩转数据-Kafka安装 二、FLINK中编写代码 package com.lyh.flink04;import org.apache.flink.api.common.serialization.SimpleStringSchema; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apa…...

介绍Sping Boot的5个扩展点

1、初始化器ApplicationContextInitializer 我们在启动Spring Boot项目的时候,是执行这样一个方法来启动的 我们一层一层往下点,最终发现执行的是这个方法 所以我们在启动项目的时候也可以这样启动 new SpringApplication(SpringbootExtensionPointAp…...

Linux2.6内核配置说明

maturity level options代码成熟度选项 Prompt for development and/or incomplete code/drivers 显示尚在开发中或尚未完成的代码与驱动.除非你是测试人员或者开发者,否则请勿选择 setup常规设置 Local version - append to kernel release 在内核版本后面加上自定义的…...

Pytest简介及jenkins集成

一、pytest介绍 pytest介绍 - unittest\nose pytest:基于unittest之上的单元测试框架 自动发现测试模块和测试方法 断言使用assert表达式即可 可以设置测试会话级、模块级、类级、函数级的fixtures 数据准备 清理工作 unittest:setUp、teardown、…...

【LeetCode】105. 从前序与中序遍历序列构造二叉树 106. 从中序与后序遍历序列构造二叉树

105. 从前序与中序遍历序列构造二叉树 这道题也是经典的数据结构题了,有时候面试题也会遇到,已知前序与中序的遍历序列,由前序遍历我们可以知道第一个元素就是根节点,而中序遍历的特点就是根节点的左边全部为左子树,右…...

堆内存和一些检测工具

17 堆定义 通过new关键字创建,创建对象都会使用堆内存。 是线程共享的,需要考虑线程安全问题。 有垃圾回收机制。18 堆-内存溢出 当默认情况下,发现执行到26,出现内存溢出。 当我们将堆内存调为8m,继续执行&#xff…...

【JavaScript】元素获取指南

简介 在 JavaScript 中,我们经常需要通过获取元素来进行 DOM 操作和交互。本教程将介绍多种获取元素的方式,包括根据 ID、标签名、类名、选择器、属性和名称等。 通过ID获取元素 使用getElementById方法根据元素的ID属性获取单个元素。 var element = document.getElementB…...

uniapp 返回上一页并刷新

如要刷新的是mine页面 在/pages/mine/improveInfo页面修改信息,点击保存后跳转到个人中心(/pages/mine/index)页面并刷新更新数据 点击保存按钮时执行以下代码: wx.switchTab({url: /pages/mine/index }) // 页面重载 let pages …...

Java阶段五Day21

Java阶段五Day21 文章目录 Java阶段五Day21问题解析rocketmq清空数据 linux学习背景什么是linux系统虚拟机介绍启动 虚拟机linux虚拟机网络的问题 linux系统的基础命令命令提示符命令格式pwd指令ls指令cd指令mkdirtouch指令cp指令rm指令mv指令cat指令tail指令 文本编辑器vim操作…...

2023,谁在引领实时互动进入高清时代?

实践是检验真理的唯一标准,技术是行业进步的核心动能。在实时互动的新时代里,不断进化的声网已然完成自证。 作者|斗斗 出品|产业家 “一个医疗行业的客户,曾向我们提出一个需求,希望在120急救场景下,可以远程看清…...

STM32(HAL)串口中断接收

目录 1、简介 2 基础配置 2.1.1 SYS配置 2.1.2 RCC配置 2.2 串口外设配置 2.3 项目生成 3、KEIL端程序整合 1、简介 本文对HAL串口中断函数进行介绍。 2 基础配置 2.1.1 SYS配置 2.1.2 RCC配置 2.2 串口外设配置 2.3 项目生成 3、KEIL端程序整合 首先在main.c文件中进行…...

word转pdf怎么转?几种常用方法分享

word转pdf怎么转?在日常工作和学习中,将Word文档转换为PDF格式是一项必要的任务。不仅可以保证文档的格式不变,还可以防止文档被他人篡改。但是,Word文档并不是所有人都能够轻松打开和编辑的,而PDF文件则可以在各种设备…...

自学(黑客)技术,入门到入狱!

1.网络安全是什么 网络安全可以基于攻击和防御视角来分类,我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术,而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 2.网络安全市场 一、是市场需求量高; 二、则是发展相对成熟入…...

js 函数、闭包及函数对象

js的函数是对象,可以通过程序来操控。比如,可以把函数赋值给变量,然后再传递给其他函数,也可以在函数上设置属性,甚至调用函数的方法。 js函数可以嵌套定义在其他函数里,内嵌函数可以访问定义在函数作用域…...

OpenLayers 可视化之热力图

注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 热力图(Heatmap)又叫热点图,是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...

地震勘探——干扰波识别、井中地震时距曲线特点

目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波:可以用来解决所提出的地质任务的波;干扰波:所有妨碍辨认、追踪有效波的其他波。 地震勘探中,有效波和干扰波是相对的。例如,在反射波…...

Leetcode 3576. Transform Array to All Equal Elements

Leetcode 3576. Transform Array to All Equal Elements 1. 解题思路2. 代码实现 题目链接:3576. Transform Array to All Equal Elements 1. 解题思路 这一题思路上就是分别考察一下是否能将其转化为全1或者全-1数组即可。 至于每一种情况是否可以达到&#xf…...

day52 ResNet18 CBAM

在深度学习的旅程中,我们不断探索如何提升模型的性能。今天,我将分享我在 ResNet18 模型中插入 CBAM(Convolutional Block Attention Module)模块,并采用分阶段微调策略的实践过程。通过这个过程,我不仅提升…...

前端导出带有合并单元格的列表

// 导出async function exportExcel(fileName "共识调整.xlsx") {// 所有数据const exportData await getAllMainData();// 表头内容let fitstTitleList [];const secondTitleList [];allColumns.value.forEach(column > {if (!column.children) {fitstTitleL…...

【论文笔记】若干矿井粉尘检测算法概述

总的来说,传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度,通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍

文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结: 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析: 实际业务去理解体会统一注…...

JDK 17 新特性

#JDK 17 新特性 /**************** 文本块 *****************/ python/scala中早就支持,不稀奇 String json “”" { “name”: “Java”, “version”: 17 } “”"; /**************** Switch 语句 -> 表达式 *****************/ 挺好的&#xff…...

【Oracle】分区表

个人主页:Guiat 归属专栏:Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”

2025年#高考 将在近日拉开帷幕,#AI 监考一度冲上热搜。当AI深度融入高考,#时间同步 不再是辅助功能,而是决定AI监考系统成败的“生命线”。 AI亮相2025高考,40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕,江西、…...