根据deepseek模型微调训练自动驾驶模型及数据集的思路
以下是使用DeepSeek模型微调训练自动驾驶模型的详细步骤和代码示例。本流程假设你已有自动驾驶领域的数据集(如驾驶指令、传感器数据等),并基于PyTorch框架实现。
Step 1: 环境准备
# 安装依赖库
pip install torch transformers datasets numpy pandas
Step 2: 数据准备
假设数据集格式为JSON,包含输入文本(传感器/场景描述)和输出控制指令:
// data/train.json
[{"input": "前方10米有行人,当前车速30km/h,车道居中","output": "减速至20km/h,保持车道"},// 更多样本...
]
构建数据集加载器:
from datasets import load_dataset
from transformers import AutoTokenizer# 加载数据集
dataset = load_dataset('json', data_files={'train': 'data/train.json', 'val': 'data/val.json'})# 初始化分词器
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-base-1.3B")
tokenizer.pad_token = tokenizer.eos_token # 设置填充token# 数据预处理函数
def preprocess_function(examples):inputs = [f"自动驾驶指令生成: {text}" for text in examples["input"]]model_inputs = tokenizer(inputs,max_length=512,truncation=True,padding="max_length")# 处理标签labels = tokenizer(examples["output"],max_length=128,truncation=True,padding="max_length")model_inputs["labels"] = labels["input_ids"]return model_inputs# 应用预处理
tokenized_dataset = dataset.map(preprocess_function, batched=True)
Step 3: 模型加载与适配
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer# 加载预训练模型
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-base-1.3B")# 修改模型头部(适配自动驾驶任务)
if model.config.vocab_size != len(tokenizer):model.resize_token_embeddings(len(tokenizer))
Step 4: 训练配置
training_args = TrainingArguments(output_dir="./autopilot_model",evaluation_strategy="epoch",learning_rate=2e-5,per_device_train_batch_size=4,per_device_eval_batch_size=4,num_train_epochs=3,weight_decay=0.01,logging_steps=50,fp16=True, # 启用混合精度训练save_strategy="epoch",report_to="tensorboard"
)# 初始化Trainer
trainer = Trainer(model=model,args=training_args,train_dataset=tokenized_dataset["train"],eval_dataset=tokenized_dataset["val"],tokenizer=tokenizer,
)
Step 5: 模型微调
# 开始训练
trainer.train()# 保存最终模型
model.save_pretrained("./autopilot_final")
tokenizer.save_pretrained("./autopilot_final")
Step 6: 推理测试
from transformers import pipeline# 创建推理管道
autopilot_pipe = pipeline("text-generation",model="./autopilot_final",tokenizer=tokenizer,device=0 if torch.cuda.is_available() else -1
)# 测试样例
input_text = "自动驾驶指令生成: 前方100米红灯,当前车速50km/h"
generated = autopilot_pipe(input_text,max_length=128,temperature=0.7,num_return_sequences=1
)
print(generated[0]['generated_text'])
# 输出示例: "减速至停车线前,等待绿灯"
Step 7: 部署优化(可选)
- 模型量化:
from transformers import BitsAndBytesConfig
quant_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.float16
)
quant_model = AutoModelForCausalLM.from_pretrained("./autopilot_final", quantization_config=quant_config)
- ONNX导出:
from transformers.convert_graph_to_onnx import convert
convert(framework="pt", model="./autopilot_final", output="autopilot.onnx", opset=12)
关键优化技巧
-
数据增强:
- 添加噪声:模拟传感器误差
- 场景扩展:生成雨天/雾天等特殊场景描述
def add_noise(text, noise_level=0.1):words = text.split()# 随机替换部分词汇return " ".join([w if random.random() > noise_level else "[UNK]" for w in words]) -
多模态融合(如结合视觉特征):
# 示例:融合图像特征
vision_encoder = AutoModel.from_pretrained("google/vit-base-patch16-224")
image_features = vision_encoder(images).last_hidden_state.mean(dim=1)# 将视觉特征注入语言模型
combined_features = torch.cat([text_embeddings, image_features.unsqueeze(1)], dim=1)
安全关键设计
- 冗余校验模块:
def safety_check(predicted_command):# 实现速度限制、碰撞检测等安全逻辑if "急加速" in predicted_command and current_speed > 60:return "建议维持当前车速"return predicted_command
- 实时性监控:
import time
start_time = time.time()
generated = autopilot_pipe(...)
if (time.time() - start_time) > 0.1: # 超过100ms触发警告print("WARNING: 推理延迟过高!")
效果评估指标
from rouge import Rouge
rouge = Rouge()def evaluate(predictions, references):# ROUGE指标scores = rouge.get_scores(predictions, references, avg=True)# 自定义安全评分safety_score = sum([1 if "危险" not in p else 0 for p in predictions])/len(predictions)return {"rouge": scores,"safety": safety_score}
通过以上流程,你可以基于DeepSeek模型构建一个针对自动驾驶场景的指令生成系统。实际应用中需注意:
- 数据质量:确保训练数据覆盖各类道路场景
- 实时性测试:在目标硬件上验证推理速度
- 安全机制:必须加入多层冗余安全检查
- 持续学习:定期用新数据更新模型
建议在实际部署前进行严格仿真测试,可使用CARLA等自动驾驶仿真平台验证模型行为。
以下是一个自动生成自动驾驶训练数据集的Python方案,包含多种典型驾驶场景的模拟数据生成逻辑。该数据集将包含文本指令、传感器数据(模拟)和对应的控制指令。
import json
import random
from faker import Faker
import numpy as np
from tqdm import tqdmfake = Faker('zh_CN')def generate_scenario(scene_type):"""生成基础场景描述"""base_scene = {"weather": random.choice(["晴天", "小雨", "雾天", "夜间"]),"road_type": random.choice(["城市道路", "高速公路", "乡村道路", "隧道"]),"speed": random.randint(20, 120)}# 根据不同场景类型添加特定要素if scene_type == "normal":return {**base_scene,"event": "保持车道行驶","obstacles": []}elif scene_type == "obstacle":return {**base_scene,"event": random.choice(["行人横穿", "车辆加塞", "动物闯入"]),"distance": random.randint(5, 100),"obstacle_speed": random.randint(0, 10) if base_scene["road_type"] != "高速公路" else 0}elif scene_type == "traffic_control":return {**base_scene,"event": random.choice(["红灯", "施工路段", "交警指挥"]),"distance": random.randint(10, 200)}elif scene_type == "emergency":return {**base_scene,"event": random.choice(["爆胎", "刹车失灵", "前方事故"]),"severity": random.choice(["轻度", "中度", "重度"])}def generate_sensor_data(scenario):"""生成模拟传感器数据"""sensor = {"camera": {"front": np.random.rand(256, 256, 3).tolist(), # 模拟图像数据"left": np.random.rand(128, 128, 3).tolist(),"right": np.random.rand(128, 128, 3).tolist()},"lidar": {"points": np.random.randn(1000, 3).tolist() # 1000个三维点云},"radar": {"frontal_objects": [{"distance": random.uniform(5.0, 150.0),"speed": random.uniform(-10.0, 30.0),"angle": random.uniform(-30.0, 30.0)} for _ in range(random.randint(0, 3))]}}# 根据场景调整传感器数据if scenario["event"] == "行人横穿":sensor["radar"]["frontal_objects"].append({"distance": scenario["distance"],"speed": 1.5, # 行人步行速度"angle": random.uniform(-15, 15)})return sensordef generate_control_command(scenario):"""生成控制指令"""base_speed = scenario["speed"]cmd_template = {"normal": "保持当前车速{}km/h,车道居中","obstacle": {"行人横穿": "减速至{}km/h,准备制动","车辆加塞": "减速至{}km/h,保持安全距离","动物闯入": "鸣笛警示,减速至{}km/h"},"traffic_control": {"红灯": "在距离{}米处开始减速,平稳停车","施工路段": "减速至{}km/h,向右变道","交警指挥": "按指挥手势行驶,保持车速{}km/h"},"emergency": {"爆胎": "紧握方向盘,缓踩刹车,车速降至{}km/h","刹车失灵": "启用电子手刹,车速降至{}km/h","前方事故": "紧急制动,车速降至{}km/h"}}if scenario["event"] == "保持车道行驶":return cmd_template["normal"].format(base_speed)for category in ["obstacle", "traffic_control", "emergency"]:if scenario["event"] in cmd_template[category]:target_speed = max(base_speed * 0.5, 20) if category == "emergency" else base_speed * 0.7return cmd_template[category][scenario["event"]].format(int(target_speed))return "维持当前操作"def generate_dataset(num_samples=1000):"""生成完整数据集"""dataset = []scene_types = ["normal", "obstacle", "traffic_control", "emergency"]for _ in tqdm(range(num_samples)):scene_type = random.choices(scene_types,weights=[0.4, 0.3, 0.2, 0.1], # 场景类型分布k=1)[0]scenario = generate_scenario(scene_type)sensor_data = generate_sensor_data(scenario)command = generate_control_command(scenario)# 构建输入描述input_desc = (f"当前环境:{scenario['weather']},{scenario['road_type']},"f"车速{scenario['speed']}km/h。")if "distance" in scenario:input_desc += f"检测到{scenario['distance']}米处{scenario['event']}"# 构建数据样本sample = {"input": input_desc,"output": command,"sensor_data": sensor_data,"metadata": {"scene_type": scene_type,"timestamp": fake.date_time_this_year().isoformat(),"location": fake.city() + "模拟道路"}}dataset.append(sample)return dataset# 生成并保存数据集
if __name__ == "__main__":dataset = generate_dataset(num_samples=5000)# 分割训练验证集random.shuffle(dataset)split_idx = int(len(dataset)*0.9)with open("data/train.json", "w") as f:json.dump(dataset[:split_idx], f, ensure_ascii=False, indent=2)with open("data/val.json", "w") as f:json.dump(dataset[split_idx:], f, ensure_ascii=False, indent=2)print(f"数据集生成完成,共{len(dataset)}条样本(训练集:{split_idx},验证集:{len(dataset)-split_idx})")
数据集结构说明
- 输入描述示例:
{"input": "当前环境:小雨,城市道路,车速45km/h。检测到28米处行人横穿","output": "减速至32km/h,准备制动","sensor_data": {"camera": {...},"lidar": {...},"radar": {"frontal_objects": [{"distance": 28.3, "speed": 1.5, "angle": 3.2}]}},"metadata": {"scene_type": "obstacle","timestamp": "2024-03-15T14:32:15","location": "上海模拟道路"}
}
-
场景覆盖范围:
- 天气条件:4种
- 道路类型:4种
- 障碍物类型:3种
- 交通管制:3种
- 紧急情况:3种
-
传感器数据模拟:
- 摄像头:模拟生成前视/左右摄像头图像(随机噪声)
- 激光雷达:生成1000个三维点云
- 毫米波雷达:生成前方物体距离/速度/角度
数据增强建议
- 真实传感器融合:
# 使用CARLA仿真平台获取真实传感器数据
from carla import World, Sensordef capture_real_sensor_data():world = connect_to_carla()camera = CameraSensor(world)lidar = LidarSensor(world)return {"camera": camera.capture(),"lidar": lidar.get_point_cloud()}
- 物理引擎增强:
def add_physics_noise(data):# 为传感器数据添加物理合理的噪声noise_levels = {"radar_distance": 0.1, # 10%距离噪声"camera_brightness": 0.05}if "radar" in data:for obj in data["radar"]["frontal_objects"]:obj["distance"] *= 1 + random.uniform(-noise_levels["radar_distance"], noise_levels["radar_distance"])if "camera" in data:for cam in data["camera"].values():cam = np.array(cam)cam += np.random.normal(0, noise_levels["camera_brightness"], cam.shape)return data
- 对抗样本生成:
def generate_adversarial_samples():# 生成极端情况样本return [{"input": "当前环境:暴雨,高速公路,车速120km/h。检测到5米处多车连环追尾","output": "紧急制动!开启双闪!车速降至30km/h"},{"input": "传感器故障:摄像头失效,雷达信号丢失","output": "启用冗余系统,维持最低安全车速40km/h"}]
数据集验证
- 统计分析脚本:
from collections import Counterdef analyze_dataset(dataset):scene_types = [s["metadata"]["scene_type"] for s in dataset]print("场景类型分布:", Counter(scene_types))cmd_lengths = [len(s["output"]) for s in dataset]print(f"指令平均长度: {np.mean(cmd_lengths):.1f}字符")speed_changes = [int("减速" in s["output"]) for s in dataset]print(f"需要减速的场景占比: {np.mean(speed_changes)*100:.1f}%")analyze_dataset(dataset)
- 可视化检查:
import matplotlib.pyplot as pltdef visualize_sample(sample):plt.figure(figsize=(12, 6))# 显示模拟摄像头图像plt.subplot(1, 2, 1)plt.imshow(np.array(sample["sensor_data"]["camera"]["front"]))plt.title("前视摄像头模拟")# 显示雷达数据plt.subplot(1, 2, 2)distances = [obj["distance"] for obj in sample["sensor_data"]["radar"]["frontal_objects"]]angles = [obj["angle"] for obj in sample["sensor_data"]["radar"]["frontal_objects"]]plt.scatter(angles, distances)plt.title("雷达探测示意图")plt.suptitle(f"控制指令: {sample['output']}")plt.show()visualize_sample(dataset[0])
该方案生成的合成数据集可用于:
- 自动驾驶决策模型的监督学习
- 强化学习的环境模拟
- 传感器融合算法的开发验证
- 异常情况处理能力的压力测试
实际应用时建议:
- 逐步替换合成数据为真实道路采集数据
- 添加车辆动力学参数(如转向角、加速度等)
- 结合高精度地图信息增强场景语义
- 增加驾驶员监控系统(DMS)的生理信号数据
相关文章:
根据deepseek模型微调训练自动驾驶模型及数据集的思路
以下是使用DeepSeek模型微调训练自动驾驶模型的详细步骤和代码示例。本流程假设你已有自动驾驶领域的数据集(如驾驶指令、传感器数据等),并基于PyTorch框架实现。 Step 1: 环境准备 # 安装依赖库 pip install torch transformers datasets n…...
蓝桥杯篇---IAP15F2K61S2定时器
文章目录 前言简介定时器的工作模式1.模式02.模式13.模式24.模式3 定时器的寄存器1.TMOD2.TCON3.THO/TL04.TH1/TL1 定时器的使用步骤1.配置TMOD2.设置初值3.启动定时器4.使能中断5.编写中断服务函数 示例代码:定时器的基本使用代码说明示例代码:定时器1用…...
Java发展史
JavaEE的由来 语言的诞生 Java的前身是Oak语言,其目的是搞嵌入式开发开发智能面包机 叮~~~🍞🍞🍞 产品以失败告终 巅峰 网景公司需要网景浏览器打开网页,Oak->Java,进行前端开发(相关技…...
Jenkins 新建配置 Freestyle project 任务 六
Jenkins 新建配置 Freestyle project 任务 六 一、新建任务 在 Jenkins 界面 点击 New Item 点击 Apply 点击 Save 回到任务主界面 二、General 点击左侧 Configure Description:任务描述 勾选 Discard old builds Discard old builds:控制何时…...
Electron视图进程和主进程通讯
快速创建基于vue的electron项目:quick-start/create-electron - npm 视图线程也就index.html是无法直接访问这个api的(如果没有开启视图层访问nodejs的功能,现在几乎没法直接开启,开启了一堆警告提示) 所以需要通过r…...
【湖南-益阳】《益阳市市本级政府投资信息化项目预算编制与财政评审工作指南》益财评〔2024〕346号-省市费用标准解读系列40
《益阳市市本级政府投资信息化项目预算编制与财政评审工作指南(试行)》(益财评〔2024〕346号)由益阳市财政局主编,2024年10月17日起正式执行,本指南主要规定了政府投资信息化项目费用的构成、测量过程和方法…...
springboot+mybatis按条件分页查询多张表
文章目录 背景方案推荐创建 DTO创建 Mapper创建对应 xmlService 代码 背景 假如同 mysql 数据源下有如下几张表: 用户基础信息表用户地址表用户学历信息表 我希望做分页查询用户数据,用户数据为各个表内信息的汇总,并且这个分页查询会根据…...
探索Java中的集合类_特性与使用场景
1. 引言 1.1 Java集合框架概述 Java集合框架(Java Collections Framework, JCF)是Java中用于存储和操作一组对象的类和接口的统称。它提供了多种数据结构来满足不同的需求,如列表、集合、映射等。JCF的核心接口包括Collection、List、Set、Queue和Map,以及它们的各种实现…...
具身智能在智能巡检机器人中的应用——以开关柜带电操作机器人为例
随着机器人技术和人工智能的迅速发展,具身智能在各行业的应用日益广泛,尤其是在电力行业中的智能巡检领域。传统的电力巡检和维护工作通常需要人工操作,存在着高温、高压、强电磁场等危险环境,且效率较低。开关柜带电操作机器人作…...
C#+SqlSugar实现主从库读写分离
在使用 **SqlSugar** 进行分库操作时,可以通过配置多个数据库连接,并根据业务逻辑动态切换数据库。以下是一个完整的分库示例,展示如何实现分库功能。 --- ### **1. 安装 NuGet 包** 安装 SqlSugarCore: bash dotnet add packag…...
Webpack 基础入门
一、Webpack 是什么 Webpack 是一款现代 JavaScript 应用程序的静态模块打包工具。在 Web 开发中,我们的项目会包含各种类型的文件,如 JavaScript、CSS、图片等。Webpack 可以将这些文件打包成一个或多个文件,以便在浏览器中高效加载。它就像…...
nuxt中引入element-ui组件控制台报错问题
在使用element-ui组件的外层加一层 <client-only placeholder"Loading..."><van-button type"primary">主要按钮</van-button> </client-only> 实际使用: <div class"tab"><client-only placehol…...
【机器学习】线性回归 多项式线性回归
【机器学习系列】 KNN算法 KNN算法原理简介及要点 特征归一化的重要性及方式线性回归算法 线性回归与一元线性回归 线性回归模型的损失函数 多元线性回归 多项式线性回归 多项式线性回归 V1.0多项式回归多项式回归的公式 特征代换超越函数作为特征向量维度 V1.0 多项式回归 …...
Java面试第二山!《计算机网络》!
在 Java 面试里,计算机网络知识是高频考点,今天就来盘点那些最容易被问到的计算机网络面试题,帮你轻松应对面试,也方便和朋友们一起探讨学习。 一、HTTP 和 HTTPS 的区别 1. 面试题呈现 HTTP 和 HTTPS 有什么区别?在…...
RocketMQ 5.0安装部署
0.前言 在微服务架构逐渐成为主流的今天,消息队列如同数字世界的快递员,承担着系统间高效通信的重要使命。 Apache RocketMQ 自诞生以来,因其架构简单、业务功能丰富、具备极强可扩展性等特点被众多企业开发者以及云厂商广泛采用。历经十余…...
go语言并发的最佳实践
Go 语言的并发模型是其最强大的特性之一,基于 CSP(Communicating Sequential Processes)理论,通过 goroutine 和 channel 实现轻量级并发. 一、并发核心概念 1. Goroutine 在 Go 语言中,Goroutine 是实现并发编程的…...
俄罗斯方块游戏完整代码示例
以下是一个基于Cocos Creator引擎开发的俄罗斯方块游戏的完整代码示例。该游戏实现了俄罗斯方块的基本功能,并且代码整合在单个文件中,无需任何外部依赖,可以直接在浏览器中运行。 1. 创建Cocos Creator项目 首先,确保你已经安装了…...
Ubuntu22.04配置cuda/cudnn/pytorch
Ubuntu22.04配置cuda/cudnn/pytorch 安装cuda官网下载.run文件并且安装/etc/profile中配置cuda环境变量 cudnn安装官网找cuda版本对应的cudnn版本下载复制相应文件到系统文件中 安装pytorch官网找cuda对应版本的pytorchpython代码测试pytorch-GPU版本安装情况 安装cuda 官网下…...
【九】Golang 数组
💢欢迎来到张胤尘的技术站 💥技术如江河,汇聚众志成。代码似星辰,照亮行征程。开源精神长,传承永不忘。携手共前行,未来更辉煌💥 文章目录 数组数组初始化默认初始化显式初始化省略长度初始化索…...
百达翡丽(Patek Philippe):瑞士制表的巅峰之作(中英双语)
百达翡丽(Patek Philippe):瑞士制表的巅峰之作 在钟表界,百达翡丽(Patek Philippe) 一直被誉为“世界三大名表”之一,并且常被认为是其中的至高存在。一句“没人能真正拥有一枚百达翡丽&#x…...
【学习】软件测试中的分类树法介绍
分类树法是一种软件测试设计技术,它通过构建一个树状结构来组织和展示输入数据的多种组合。这种方法有助于系统地识别和分析可能的测试情况,从而确保对软件进行全面而详尽的测试。分类树法特别适用于具有多个选择或条件的复杂系统,它可以有效…...
打造智能语料库:通过Coco AI Server 实现 Notion 笔记 RAG 检索功能
本文将详细介绍如何将 Notion 作为语料库,部署 Coco Server 的 RAG(Retrieval-Augmented Generation)功能。我们将使用 Easysearch 作为语料库存储 Notion 素材,并通过 ollama 进行 LLM 推理。 1. 环境准备 1.1 启动 Easysearch…...
SP字体UI放大代码
代码: echo off set QT_SCALE_FACTOR放大倍数 start "" "你的SP.exe启动路径"...
spring boot知识点2
1.spring boot 要开启一些特性,可通过什么方式开启 a.通过Enable注解,可启动定时服务 b.通过application.properties可设置端口号等地址信息 2.什么是热部署,以及spring boot通过什么方式进行热部署 热部署这个概念,我知道。就…...
动手学Agent——Day2
文章目录 一、用 Llama-index 创建 Agent1. 测试模型2. 自定义一个接口类3. 使用 ReActAgent & FunctionTool 构建 Agent 二、数据库对话 Agent1. SQLite 数据库1.1 创建数据库 & 连接1.2 创建、插入、查询、更新、删除数据1.3 关闭连接建立数据库 2. ollama3. 配置对话…...
qt实习总结
创建一个滑动条 QSlider *slider new QSlider(Qt::Vertical); //创建一个垂直方向的 进度条 带有上下箭头的输入框 QSpinBox 提供了一个带有上下箭头的输入框 垂直 水平怎么说 horizontal vetical 布局知识 BtnLayout->addWidget(AmendBtn); BtnLayout->addWidg…...
SpringBoot3.x整合WebSocket
SpringBoot3.x整合WebSocket 本文主要介绍最新springboot3.x下如何整合WebSocket. WebSocket简述 WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议,它允许在浏览器和服务器之间进行实时的、双向的通信。相对于传统的基于请求和响应的 HTTP 协议ÿ…...
vLLM专题(二):安装-CPU
vLLM 是一个 Python 库,支持以下 CPU 变体。选择您的 CPU 类型以查看供应商特定的说明: Intel/AMD x86 vLLM 最初支持在 x86 CPU 平台上进行基本模型推理和服务,支持的数据类型包括 FP32、FP16 和 BF16。 注意 此设备没有预构建的 wheel 包或镜像,因此您必须从源代码构建 v…...
「软件设计模式」适配器模式(Adapter)
软件设计模式深度解析:适配器模式(Adapter)(C实现) 一、模式概述 适配器模式(Adapter Pattern)是结构型设计模式中的"接口转换器",它像现实世界中的电源适配器一样&#…...
Dify平台搭建面试机器人
无代码搭建面试机器人 什么是Dify 什么是Dify Dify 是一款开源的大语言模型(LLM) 应用开发平台。它融合了后端即服务(Backend as Service)和 LLMOps 的理念,使开发者可以快速搭建生产级的生成式 AI 应用。即使你是非技术人员,也能…...
