当前位置: 首页 > news >正文

大语言模型控制生成的过程Trick:自定义LogitsProcessor实践

前言

在大模型的生成过程中,部分原生的大语言模型未经过特殊的对齐训练,往往会“胡说八道”的生成一些敏感词语等用户不想生成的词语,最简单粗暴的方式就是在大模型生成的文本之后,添加敏感词库等规则手段进行敏感词过滤,但是在生成过程中,生成敏感词仍然耗费了时间和算力成本。

本文以chatglm2-6B为例,通过自定义LogitsProcessor,实践大模型在生成过程中控制一些词语的生成。

LogitsProcessor

从下面代码可以看到,LogitsProcessor的作用就是在生成过程中修改score,改变模型输出的概率分布的工具。

class LogitsProcessor:"""Abstract base class for all logit processors that can be applied during generation."""@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:raise NotImplementedError(f"{self.__class__} is an abstract class. Only classes inheriting this class can be called.")class LogitsProcessorList(list):"""This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a`scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each[`LogitsProcessor`] or [`LogitsWarper`] to the inputs."""def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:r"""Args:input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):Prediction scores of a language modeling head. These can be logits for each vocabulary when not usingbeam search or log softmax for each vocabulary token when using beam searchkwargs (`Dict[str, Any]`, *optional*):Additional kwargs that are specific to a logits processor.Return:`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:The processed prediction scores."""for processor in self:function_args = inspect.signature(processor.__call__).parametersif len(function_args) > 2:if not all(arg in kwargs for arg in list(function_args.keys())[2:]):raise ValueError(f"Make sure that all the required parameters: {list(function_args.keys())} for "f"{processor.__class__} are passed to the logits processor.")scores = processor(input_ids, scores, **kwargs)else:scores = processor(input_ids, scores)return scores

自定义LogitsProcessor实践

回到正题,如何自定义LogitsProcessor控制大模型生成的过程呢?下面直接上实践代码:

class new_logits_processor(LogitsProcessor):def __init__(self, forbid_token_id_list: List[int] = None):self.forbid_token_id_list = forbid_token_id_listdef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:for id_ in self.forbid_token_id_list:scores[:, id_] = -float('inf')return scores

forbid_token_id_list是不让模型生成词语的id映射列表,对于这些抑制生成的词语,在自定义logits_processor时将其概率推向负无穷大即可。

chatglm2-6B详细实践代码:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextStreamer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from typing import List
import torchclass new_logits_processor(LogitsProcessor):def __init__(self, forbid_token_id_list: List[int] = None):self.forbid_token_id_list = forbid_token_id_listdef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:for id_ in self.forbid_token_id_list:scores[:, id_] = -float('inf')return scoresmodel_path = "THUDM/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True).to('mps')def add_forbid_words():'''添加需要抑制的词语,这里简单添加了数字和几个词语进行对比:return:list'''forbid_words = []for i in range(10):forbid_words.append(tokenizer.convert_tokens_to_ids(str(i)))forbid_words.append(tokenizer.convert_tokens_to_ids("首先"))forbid_words.append(tokenizer.convert_tokens_to_ids("积极"))forbid_words.append(tokenizer.convert_tokens_to_ids("回答"))forbid_words.append(tokenizer.convert_tokens_to_ids("勇敢"))forbid_words.append(tokenizer.convert_tokens_to_ids("勇气"))return forbid_wordslogits_processor = LogitsProcessorList()
logits_processor.append(new_logits_processor(add_forbid_words()))streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)input = "列举出10个积极的词语:"outputs = model.generate(tokenizer(input, return_tensors='pt').input_ids.to("mps"),max_new_tokens=1024,logits_processor=logits_processor,  # 不开启注释即可streamer=streamer
)
decode_text = tokenizer.batch_decode(outputs, streamer=streamer)[0]
print(decode_text)

抑制前输出:

1. 勇敢
2. 快乐
3. 成功
4. 努力
5. 积极
6. 乐观
7. 自信
8. 开朗
9. 团结
10. 奋斗

抑制后输出:

- 积极主动
- 乐观向上
- 自信
- 自律
- 诚实守信
- 乐于助人
- 勇于尝试
- 坚韧不拔
- 乐观开朗
- 团结一心

小结

本文通过自定义LogitsProcessor,简单的实践了大语言模型在生成过程中屏蔽生成用户自定义词语的trick。在现实场景中,根据特定场景探索如何灵活的利用LogitsProcessor进行有针对性的控制生成模型的生成过程非常重要。

参考文献

【1】https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/generation/logits_process.py

相关文章:

大语言模型控制生成的过程Trick:自定义LogitsProcessor实践

前言 在大模型的生成过程中,部分原生的大语言模型未经过特殊的对齐训练,往往会“胡说八道”的生成一些敏感词语等用户不想生成的词语,最简单粗暴的方式就是在大模型生成的文本之后,添加敏感词库等规则手段进行敏感词过滤&#xf…...

Docker容器:docker的资源控制及docker数据管理

文章目录 一.docker的资源控制1.CPU 资源控制1.1 资源控制工具1.2 cgroups有四大功能1.3 设置CPU使用率上限1.4 进行CPU压力测试1.5 设置50%的比例分配CPU使用时间上限1.6 设置CPU资源占用比(设置多个容器时才有效)1.6.1 两个容器测试cpu1.6.2 设置容器绑…...

从零开始打造家装预约咨询小程序

在如今互联网高度发达的时代,家装行业也逐渐意识到了线上渠道的重要性。为了更好地服务客户,提高用户体验,越来越多的家装公司开始寻找合适的小程序制作平台。本文将向大家介绍如何使用第三方制作平台,如乔拓云网,打造…...

es线上处理命令记录

常用命令 搜索 GET _search {"query": {"match_all": {}} }获取全部模版 GET _index_template GET _index_template/yst_crawler_template获取全部索引 GET /_cat/indices?v 获取当前mapping GET /yst_crawler/_mapping创建一个mapping PUT /yst_c…...

mysql 在nodejs中的简单使用(增删改查)

一 、封装SQL查询请求链接 const mysql require(mysql) //创建开发工具数据库链接池 const pool mysql.createPool({host: 192.168.1.133,user: user_name, password: 123456,database: database_name,port: 3306,connectionLimit: 50 // 限制连接数 });// sql:查…...

1.MySQL数据库的基本操作

数据库操作过程: 1.用户在客户端输入 SQL 2.客户端会把 SQL 通过网络发送给服务器 3.服务器执行这个 SQL,把结果返回给客户端 4.客户端收到结果,显示到界面上 数据库的操作 这里的数据库不是代表一个软件,而是代表一个数据集合。 显示当前的数据库 …...

Zabbix-6.4.4 邮箱告警SMS告警配置

目录 ​------------------------- # 邮箱告警 ---------------------------------- 1.安装mailx与postfix软件包 2.修改mailx配置文件 3. 创建文件夹 4. 编写mail-send.sh脚本 5. 将该脚本赋予执行权限 6. 进入web界面进行设置—> Alerts —> Media Types 7. 添…...

网络安全 Day30-运维安全项目-容器架构上

容器架构上 1. 什么是容器2. 容器 vs 虚拟机(化) :star::star:3. Docker极速上手指南1)使用rpm包安装docker2) docker下载镜像加速的配置3) 载入镜像大礼包(老师资料包中有) 4. Docker使用案例1) 案例01::star::star::…...

深入理解设计模式-创建型之单例模式

为什么要使用单例 1、表示全局唯一 如果有些数据在系统中应该且只能保存一份,那就应该设计为单例类。 如:配置类:在系统中,我们只有一个配置文件,当配置文件被加载到内存之后,应该被映射为一个唯一的【配…...

Vue中路由缓存问题及解决方法

一.问题 Vue Router 允许你在你的应用中创建多个视图,并根据路由来动态切换这些视图。默认情况下,当你从一个路由切换到另一个路由时,Vue Router 会销毁前一个路由的组件实例并创建新的组件实例。然而,有时候你可能希望保持一些页…...

Linux与bash(基础内容一)

一、常见的linux命令: 1、文件: (1)常见的文件命令: (2)文件属性: (3)修改文件属性: 查看文件的属性: ls -l 查看文件的属性 ls …...

NVIDIA Omniverse与GPT-4结合生成3D内容

全球各行业对 3D 世界和虚拟环境的需求呈指数级增长。3D 工作流程是工业数字化的核心,开发实时模拟来测试和验证自动驾驶车辆和机器人,操作数字孪生来优化工业制造,并为科学发现铺平新的道路。 如今,3D 设计和世界构建仍然是高度…...

Windows Server --- RDP远程桌面服务器激活和RD授权

RDP远程桌面服务器激活和RD授权 一、激活服务器二、设置RD授权 系统:Window server 2008 R2 服务:远程桌面服务 注:该方法适合该远程桌面服务器没网络状态下(离线),激活服务器。 一、激活服务器 1.打开远…...

关于游戏盾

游戏盾(Game Shield)是一种针对游戏行业特点的网络安全解决方案,主要针对游戏平台面临的各种网络攻击和安全威胁。以下是一些原因,说明为什么游戏平台需要加游戏盾: 1. DDoS攻击:游戏平台通常容易受到分布式…...

回归预测 | MATLAB实现基于SSA-KELM-Adaboost麻雀算法优化核极限学习机结合AdaBoost多输入单输出回归预测

回归预测 | MATLAB实现基于SSA-KELM-Adaboost麻雀算法优化核极限学习机结合AdaBoost多输入单输出回归预测 目录 回归预测 | MATLAB实现基于SSA-KELM-Adaboost麻雀算法优化核极限学习机结合AdaBoost多输入单输出回归预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本…...

《cpolar内网穿透》外网SSH远程连接linux(CentOS)服务器

本次教程我们来实现如何在外公网环境下,SSH远程连接家里/公司的Linux CentOS服务器,无需公网IP,也不需要设置路由器。 视频教程 [video(video-jrpesBrv-1680147672481)(type-csdn)(url-CSDN直播https://live-file.csdnimg.cn/release/live/…...

IDEA启动报错【java.sql.SQLSyntaxErrorException: ORA-00904: “P“.“PRJ_NO“: 标识符无效】

IDEA报错如下: 2023-08-17 11:26:15.535 ERROR [egrant-biz,b48324d82fe23753,b48324d82fe23753,true] 24108 --- [ XNIO-1 task-1] c.i.c.l.c.RestExceptionController : 服务器异常org.springframework.jdbc.BadSqlGrammarException: ### Error queryin…...

Nginx详解

1、高并发时代 单台tomcat在理想情况下可支持的最大并发数量在200~500之间,如果大于这个数量可能会造成响应缓慢甚至宕机。 解决方案是通过多台服务器分摊并发压力,这不仅需要有多台tomcat服务器,还需要一台服务器专门用来分配请求。这既是…...

摸清一下mysql授权语句的实际执行关系

样例 ---------------------------------------------------------------------- grant all PRIVILEGES on db1.* to test% identified by test1; grant all PRIVILEGES on db2.* to test% identified by test2; grant all PRIVILEGES on db3.* to test127.0.0.1 identified …...

sCrypt于8月12日在上海亮相BSV数字未来论坛

2023年8月12日,由上海可一澈科技有限公司(以下简称“可一科技”)、 临港国际科创研究院发起,携手美国sCrypt公司、福州博泉网络科技有限公司、复旦大学区块链协会,举办的BSV数字未来论坛在中国上海成功落下帷幕。 本次…...

(十)学生端搭建

本次旨在将之前的已完成的部分功能进行拼装到学生端,同时完善学生端的构建。本次工作主要包括: 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...

树莓派超全系列教程文档--(62)使用rpicam-app通过网络流式传输视频

使用rpicam-app通过网络流式传输视频 使用 rpicam-app 通过网络流式传输视频UDPTCPRTSPlibavGStreamerRTPlibcamerasrc GStreamer 元素 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 使用 rpicam-app 通过网络流式传输视频 本节介绍来自 rpica…...

【JavaEE】-- HTTP

1. HTTP是什么? HTTP(全称为"超文本传输协议")是一种应用非常广泛的应用层协议,HTTP是基于TCP协议的一种应用层协议。 应用层协议:是计算机网络协议栈中最高层的协议,它定义了运行在不同主机上…...

苍穹外卖--缓存菜品

1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得,如果用户端访问量比较大,数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据,减少数据库查询操作。 缓存逻辑分析: ①每个分类下的菜品保持一份缓存数据…...

sqlserver 根据指定字符 解析拼接字符串

DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...

论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一)

宇树机器人多姿态起立控制强化学习框架论文解析 论文解读&#xff1a;交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架&#xff08;一&#xff09; 论文解读&#xff1a;交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化…...

相机Camera日志分析之三十一:高通Camx HAL十种流程基础分析关键字汇总(后续持续更新中)

【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了:有对最普通的场景进行各个日志注释讲解,但相机场景太多,日志差异也巨大。后面将展示各种场景下的日志。 通过notepad++打开场景下的日志,通过下列分类关键字搜索,即可清晰的分析不同场景的相机运行流程差异…...

Spring数据访问模块设计

前面我们已经完成了IoC和web模块的设计&#xff0c;聪明的码友立马就知道了&#xff0c;该到数据访问模块了&#xff0c;要不就这俩玩个6啊&#xff0c;查库势在必行&#xff0c;至此&#xff0c;它来了。 一、核心设计理念 1、痛点在哪 应用离不开数据&#xff08;数据库、No…...

ABAP设计模式之---“简单设计原则(Simple Design)”

“Simple Design”&#xff08;简单设计&#xff09;是软件开发中的一个重要理念&#xff0c;倡导以最简单的方式实现软件功能&#xff0c;以确保代码清晰易懂、易维护&#xff0c;并在项目需求变化时能够快速适应。 其核心目标是避免复杂和过度设计&#xff0c;遵循“让事情保…...

JavaScript基础-API 和 Web API

在学习JavaScript的过程中&#xff0c;理解API&#xff08;应用程序接口&#xff09;和Web API的概念及其应用是非常重要的。这些工具极大地扩展了JavaScript的功能&#xff0c;使得开发者能够创建出功能丰富、交互性强的Web应用程序。本文将深入探讨JavaScript中的API与Web AP…...