大语言模型控制生成的过程Trick:自定义LogitsProcessor实践
前言
在大模型的生成过程中,部分原生的大语言模型未经过特殊的对齐训练,往往会“胡说八道”的生成一些敏感词语等用户不想生成的词语,最简单粗暴的方式就是在大模型生成的文本之后,添加敏感词库等规则手段进行敏感词过滤,但是在生成过程中,生成敏感词仍然耗费了时间和算力成本。
本文以chatglm2-6B为例,通过自定义LogitsProcessor,实践大模型在生成过程中控制一些词语的生成。
LogitsProcessor
从下面代码可以看到,LogitsProcessor的作用就是在生成过程中修改score,改变模型输出的概率分布的工具。
class LogitsProcessor:"""Abstract base class for all logit processors that can be applied during generation."""@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:raise NotImplementedError(f"{self.__class__} is an abstract class. Only classes inheriting this class can be called.")class LogitsProcessorList(list):"""This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a`scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each[`LogitsProcessor`] or [`LogitsWarper`] to the inputs."""def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:r"""Args:input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):Prediction scores of a language modeling head. These can be logits for each vocabulary when not usingbeam search or log softmax for each vocabulary token when using beam searchkwargs (`Dict[str, Any]`, *optional*):Additional kwargs that are specific to a logits processor.Return:`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:The processed prediction scores."""for processor in self:function_args = inspect.signature(processor.__call__).parametersif len(function_args) > 2:if not all(arg in kwargs for arg in list(function_args.keys())[2:]):raise ValueError(f"Make sure that all the required parameters: {list(function_args.keys())} for "f"{processor.__class__} are passed to the logits processor.")scores = processor(input_ids, scores, **kwargs)else:scores = processor(input_ids, scores)return scores
自定义LogitsProcessor实践
回到正题,如何自定义LogitsProcessor控制大模型生成的过程呢?下面直接上实践代码:
class new_logits_processor(LogitsProcessor):def __init__(self, forbid_token_id_list: List[int] = None):self.forbid_token_id_list = forbid_token_id_listdef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:for id_ in self.forbid_token_id_list:scores[:, id_] = -float('inf')return scores
forbid_token_id_list是不让模型生成词语的id映射列表,对于这些抑制生成的词语,在自定义logits_processor时将其概率推向负无穷大即可。
chatglm2-6B详细实践代码:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextStreamer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from typing import List
import torchclass new_logits_processor(LogitsProcessor):def __init__(self, forbid_token_id_list: List[int] = None):self.forbid_token_id_list = forbid_token_id_listdef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:for id_ in self.forbid_token_id_list:scores[:, id_] = -float('inf')return scoresmodel_path = "THUDM/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True).to('mps')def add_forbid_words():'''添加需要抑制的词语,这里简单添加了数字和几个词语进行对比:return:list'''forbid_words = []for i in range(10):forbid_words.append(tokenizer.convert_tokens_to_ids(str(i)))forbid_words.append(tokenizer.convert_tokens_to_ids("首先"))forbid_words.append(tokenizer.convert_tokens_to_ids("积极"))forbid_words.append(tokenizer.convert_tokens_to_ids("回答"))forbid_words.append(tokenizer.convert_tokens_to_ids("勇敢"))forbid_words.append(tokenizer.convert_tokens_to_ids("勇气"))return forbid_wordslogits_processor = LogitsProcessorList()
logits_processor.append(new_logits_processor(add_forbid_words()))streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)input = "列举出10个积极的词语:"outputs = model.generate(tokenizer(input, return_tensors='pt').input_ids.to("mps"),max_new_tokens=1024,logits_processor=logits_processor, # 不开启注释即可streamer=streamer
)
decode_text = tokenizer.batch_decode(outputs, streamer=streamer)[0]
print(decode_text)
抑制前输出:
1. 勇敢
2. 快乐
3. 成功
4. 努力
5. 积极
6. 乐观
7. 自信
8. 开朗
9. 团结
10. 奋斗
抑制后输出:
- 积极主动
- 乐观向上
- 自信
- 自律
- 诚实守信
- 乐于助人
- 勇于尝试
- 坚韧不拔
- 乐观开朗
- 团结一心
小结
本文通过自定义LogitsProcessor,简单的实践了大语言模型在生成过程中屏蔽生成用户自定义词语的trick。在现实场景中,根据特定场景探索如何灵活的利用LogitsProcessor进行有针对性的控制生成模型的生成过程非常重要。
参考文献
【1】https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/generation/logits_process.py
相关文章:

大语言模型控制生成的过程Trick:自定义LogitsProcessor实践
前言 在大模型的生成过程中,部分原生的大语言模型未经过特殊的对齐训练,往往会“胡说八道”的生成一些敏感词语等用户不想生成的词语,最简单粗暴的方式就是在大模型生成的文本之后,添加敏感词库等规则手段进行敏感词过滤…...

Docker容器:docker的资源控制及docker数据管理
文章目录 一.docker的资源控制1.CPU 资源控制1.1 资源控制工具1.2 cgroups有四大功能1.3 设置CPU使用率上限1.4 进行CPU压力测试1.5 设置50%的比例分配CPU使用时间上限1.6 设置CPU资源占用比(设置多个容器时才有效)1.6.1 两个容器测试cpu1.6.2 设置容器绑…...

从零开始打造家装预约咨询小程序
在如今互联网高度发达的时代,家装行业也逐渐意识到了线上渠道的重要性。为了更好地服务客户,提高用户体验,越来越多的家装公司开始寻找合适的小程序制作平台。本文将向大家介绍如何使用第三方制作平台,如乔拓云网,打造…...

es线上处理命令记录
常用命令 搜索 GET _search {"query": {"match_all": {}} }获取全部模版 GET _index_template GET _index_template/yst_crawler_template获取全部索引 GET /_cat/indices?v 获取当前mapping GET /yst_crawler/_mapping创建一个mapping PUT /yst_c…...

mysql 在nodejs中的简单使用(增删改查)
一 、封装SQL查询请求链接 const mysql require(mysql) //创建开发工具数据库链接池 const pool mysql.createPool({host: 192.168.1.133,user: user_name, password: 123456,database: database_name,port: 3306,connectionLimit: 50 // 限制连接数 });// sql:查…...

1.MySQL数据库的基本操作
数据库操作过程: 1.用户在客户端输入 SQL 2.客户端会把 SQL 通过网络发送给服务器 3.服务器执行这个 SQL,把结果返回给客户端 4.客户端收到结果,显示到界面上 数据库的操作 这里的数据库不是代表一个软件,而是代表一个数据集合。 显示当前的数据库 …...

Zabbix-6.4.4 邮箱告警SMS告警配置
目录 ------------------------- # 邮箱告警 ---------------------------------- 1.安装mailx与postfix软件包 2.修改mailx配置文件 3. 创建文件夹 4. 编写mail-send.sh脚本 5. 将该脚本赋予执行权限 6. 进入web界面进行设置—> Alerts —> Media Types 7. 添…...

网络安全 Day30-运维安全项目-容器架构上
容器架构上 1. 什么是容器2. 容器 vs 虚拟机(化) :star::star:3. Docker极速上手指南1)使用rpm包安装docker2) docker下载镜像加速的配置3) 载入镜像大礼包(老师资料包中有) 4. Docker使用案例1) 案例01::star::star::…...

深入理解设计模式-创建型之单例模式
为什么要使用单例 1、表示全局唯一 如果有些数据在系统中应该且只能保存一份,那就应该设计为单例类。 如:配置类:在系统中,我们只有一个配置文件,当配置文件被加载到内存之后,应该被映射为一个唯一的【配…...

Vue中路由缓存问题及解决方法
一.问题 Vue Router 允许你在你的应用中创建多个视图,并根据路由来动态切换这些视图。默认情况下,当你从一个路由切换到另一个路由时,Vue Router 会销毁前一个路由的组件实例并创建新的组件实例。然而,有时候你可能希望保持一些页…...

Linux与bash(基础内容一)
一、常见的linux命令: 1、文件: (1)常见的文件命令: (2)文件属性: (3)修改文件属性: 查看文件的属性: ls -l 查看文件的属性 ls …...

NVIDIA Omniverse与GPT-4结合生成3D内容
全球各行业对 3D 世界和虚拟环境的需求呈指数级增长。3D 工作流程是工业数字化的核心,开发实时模拟来测试和验证自动驾驶车辆和机器人,操作数字孪生来优化工业制造,并为科学发现铺平新的道路。 如今,3D 设计和世界构建仍然是高度…...

Windows Server --- RDP远程桌面服务器激活和RD授权
RDP远程桌面服务器激活和RD授权 一、激活服务器二、设置RD授权 系统:Window server 2008 R2 服务:远程桌面服务 注:该方法适合该远程桌面服务器没网络状态下(离线),激活服务器。 一、激活服务器 1.打开远…...

关于游戏盾
游戏盾(Game Shield)是一种针对游戏行业特点的网络安全解决方案,主要针对游戏平台面临的各种网络攻击和安全威胁。以下是一些原因,说明为什么游戏平台需要加游戏盾: 1. DDoS攻击:游戏平台通常容易受到分布式…...

回归预测 | MATLAB实现基于SSA-KELM-Adaboost麻雀算法优化核极限学习机结合AdaBoost多输入单输出回归预测
回归预测 | MATLAB实现基于SSA-KELM-Adaboost麻雀算法优化核极限学习机结合AdaBoost多输入单输出回归预测 目录 回归预测 | MATLAB实现基于SSA-KELM-Adaboost麻雀算法优化核极限学习机结合AdaBoost多输入单输出回归预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本…...

《cpolar内网穿透》外网SSH远程连接linux(CentOS)服务器
本次教程我们来实现如何在外公网环境下,SSH远程连接家里/公司的Linux CentOS服务器,无需公网IP,也不需要设置路由器。 视频教程 [video(video-jrpesBrv-1680147672481)(type-csdn)(url-CSDN直播https://live-file.csdnimg.cn/release/live/…...

IDEA启动报错【java.sql.SQLSyntaxErrorException: ORA-00904: “P“.“PRJ_NO“: 标识符无效】
IDEA报错如下: 2023-08-17 11:26:15.535 ERROR [egrant-biz,b48324d82fe23753,b48324d82fe23753,true] 24108 --- [ XNIO-1 task-1] c.i.c.l.c.RestExceptionController : 服务器异常org.springframework.jdbc.BadSqlGrammarException: ### Error queryin…...

Nginx详解
1、高并发时代 单台tomcat在理想情况下可支持的最大并发数量在200~500之间,如果大于这个数量可能会造成响应缓慢甚至宕机。 解决方案是通过多台服务器分摊并发压力,这不仅需要有多台tomcat服务器,还需要一台服务器专门用来分配请求。这既是…...

摸清一下mysql授权语句的实际执行关系
样例 ---------------------------------------------------------------------- grant all PRIVILEGES on db1.* to test% identified by test1; grant all PRIVILEGES on db2.* to test% identified by test2; grant all PRIVILEGES on db3.* to test127.0.0.1 identified …...

sCrypt于8月12日在上海亮相BSV数字未来论坛
2023年8月12日,由上海可一澈科技有限公司(以下简称“可一科技”)、 临港国际科创研究院发起,携手美国sCrypt公司、福州博泉网络科技有限公司、复旦大学区块链协会,举办的BSV数字未来论坛在中国上海成功落下帷幕。 本次…...

Hbase的列式存储到底是什么意思?一篇文章让你彻底明白
一、 HBase 定义 Apache HBase™ 是以 hdfs 为数据存储的,一种分布式、可扩展的 NoSQL 数据库。 二、 HBase 数据模型 HBase 的设计理念依据 Google 的 BigTable 论文,论文中对于数据模型的首句介绍。 Bigtable 是一个稀疏的、分布式的、持久的多维排…...

机器学习|Softmax 回归的数学理解及代码解析
机器学习|Softmax 回归的数学理解及代码解析 Softmax 回归是一种常用的多类别分类算法,适用于将输入向量映射到多个类别的概率分布。在本文中,我们将深入探讨 Softmax 回归的数学原理,并提供 Python 示例代码帮助读者更好地理解和…...

EmbedPress Pro 在WordPress网站中嵌入任何内容
EmbedPress Pro可让您通过高级自定义、自定义品牌、延迟加载和更多惊人功能嵌入源。为古腾堡块和Elementor编辑器提供支持的一体化 WordPress 嵌入解决方案。使用 EmbedPress 在古腾堡创建交互式内容。使用 EmbedPress 的古腾堡块立即将任何内容嵌入到您的网站。 网址: EmbedP…...

【C++学习手札】一文带你初识C++继承
食用指南:本文在有C基础的情况下食用更佳 🍀本文前置知识: C类 ♈️今日夜电波:napori—Vaundy 1:21 ━━━━━━️💟──────── 3:23 …...

【ubuntu18.04】01-network-manager-all.yaml和interfaces和resolv.conf各有什么区别和联系
文章目录 01-network-manager-all.yaml、interfaces 和 resolv.conf 是与网络配置相关的文件,它们在网络设置中有着不同的作用和使用方式。 01-network-manager-all.yaml: 这是一个配置文件,通常在 Ubuntu 系统上使用 NetworkManager 进行网络管理时使用…...

24近3年内蒙古大学自动化考研院校分析
今天给大家带来的是内蒙古大学控制考研分析 满满干货~还不快快点赞收藏 一、内蒙古大学 学校简介 内蒙古大学位于内蒙古自治区首府、历史文化名城呼和浩特市,距北京400余公里,是中华人民共和国成立后党和国家在民族地区创办的第一所综合大…...

大语言模型(LLM)与 Jupyter 连接起来了
现在,大语言模型(LLM)与 Jupyter 连接起来了! 这主要归功于一个名叫 Jupyter AI 的项目,它是官方支持的 Project Jupyter 子项目。目前该项目已经完全开源,其连接的模型主要来自 AI21、Anthropic、AWS、Co…...

ChatGLM2-6B在Windows下的微调
ChatGLM2-6B在Windows下的微调 零、重要参考资料 1、ChatGLM2-6B! 我跑通啦!本地部署微调(windows系统):这是最关键的一篇文章,提供了Windows下的脚本 2、LangChain ChatGLM2-6B 搭建个人专属知识库:提供…...

聊聊火车的发展
目录 1.火车的概念 2.火车的发展历史 3.火车对战争的影响 4.火车对人们出行造成的影响 1.火车的概念 火车是一种由机械动力驱动的陆上交通工具,通常用来运输人员和货物。它由一列或多列的连接在一起的车厢组成,有轨道作为其行驶的基础,并通…...

IDEA使用@Autowired为什么会警告?
在使用IDEA编写Spring相关的项目时,当在字段上使用Autowired注解时,总会出现一个波浪线提示:”Field injection is not recommended.” 这让我不禁疑惑:我每天都在使用这种方式,为何不被推荐呢?今天&#x…...