一起学Hugging Face Transformers(13)- 模型微调之自定义训练循环

文章目录

  • 前言
  • 一、什么是训练循环
    • 1. 训练循环的关键步骤
    • 2. 示例
    • 3. 训练循环的重要性
  • 二、使用 Hugging Face Transformers 库实现自定义训练循环
    • 1. 前期准备
      • 1)安装依赖
      • 2)导入必要的库
    • 2. 加载数据和模型
      • 1) 加载数据集
      • 2) 加载预训练模型和分词器
      • 3) 预处理数据
      • 4) 创建数据加载器
    • 3. 自定义训练循环
      • 1) 定义优化器和学习率调度器
      • 2) 定义训练和评估函数
      • 3) 运行训练和评估
  • 总结


前言

Hugging Face Transformers 库为 NLP 模型的预训练和微调提供了丰富的工具和简便的方法。虽然 Trainer API 简化了许多常见任务,但有时我们需要更多的控制权和灵活性,这时可以实现自定义训练循环。本文将介绍什么是训练循环以及如何使用 Hugging Face Transformers 库实现自定义训练循环。


一、什么是训练循环

在模型微调过程中,训练循环是指模型训练的核心过程,通过多次迭代数据集来调整模型的参数,使其在特定任务上表现更好。训练循环包含以下几个关键步骤:

1. 训练循环的关键步骤

1) 前向传播(Forward Pass)

  • 模型接收输入数据并通过网络进行计算,生成预测输出。这一步是将输入数据通过模型的各层逐步传递,计算出最终的预测结果。

2) 计算损失(Compute Loss)

  • 将模型的预测输出与真实标签进行比较,计算损失函数的值。损失函数是一个衡量预测结果与真实值之间差距的指标,常用的损失函数有交叉熵损失(用于分类任务)和均方误差(用于回归任务)。

3) 反向传播(Backward Pass)

  • 根据损失函数的值,计算每个参数对损失的贡献,得到梯度。反向传播使用链式法则,将损失对每个参数的梯度计算出来。

4) 参数更新(Parameter Update)

  • 使用优化算法(如梯度下降、Adam 等)根据计算出的梯度调整模型的参数。优化算法会更新每个参数,使损失函数的值逐步减小,模型的预测性能逐步提高。

5) 重复以上步骤

  • 以上过程在整个数据集上进行多次(多个epoch),每次遍历数据集被称为一个epoch。随着训练的进行,模型的性能会不断提升。

2. 示例

假设你在微调一个BERT模型用于情感分析任务,训练循环的步骤如下:

1) 前向传播

  • 输入一条文本评论,模型通过各层网络计算,生成预测的情感标签(如正面或负面)。

2) 计算损失

  • 将模型的预测标签与实际标签进行比较,计算交叉熵损失。

3) 反向传播

  • 计算损失对每个模型参数的梯度,确定每个参数需要调整的方向和幅度。

4) 参数更新

  • 使用Adam优化器,根据计算出的梯度调整模型的参数。

5) 重复以上步骤

  • 在整个训练数据集上进行多次迭代,不断调整参数,使模型的预测精度逐步提高。

3. 训练循环的重要性

训练循环是模型微调的核心,通过多次迭代和参数更新,使模型能够从数据中学习,逐步提高在特定任务上的性能。理解训练循环的各个步骤和原理,有助于更好地调试和优化模型,获得更好的结果。

在实际应用中,训练循环可能会包含一些额外的步骤和技术,例如:

  • 批量训练(Mini-Batch Training):将数据集分成小批量,每次训练一个批量,降低计算资源的需求。
  • 学习率调度(Learning Rate Scheduling):动态调整学习率,以提高训练效率和模型性能。
  • 正则化技术(Regularization Techniques):如Dropout、权重衰减等,防止模型过拟合。

这些技术和方法结合使用,可以进一步提升模型微调的效果和性能。

二、使用 Hugging Face Transformers 库实现自定义训练循环

1. 前期准备

1)安装依赖

首先,确保已经安装了必要的库:

pip install transformers datasets torch

2)导入必要的库

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
from datasets import load_dataset
from tqdm.auto import tqdm

2. 加载数据和模型

1) 加载数据集

这里我们以 IMDb 电影评论数据集为例:

dataset = load_dataset("imdb")

2) 加载预训练模型和分词器

我们将使用 distilbert-base-uncased 作为基础模型:

model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

3) 预处理数据

定义一个预处理函数,并将其应用到数据集:

def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

encoded_dataset = dataset.map(preprocess_function, batched=True)
encoded_dataset = encoded_dataset.rename_column("label", "labels")
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

4) 创建数据加载器

train_dataloader = DataLoader(encoded_dataset["train"], batch_size=8, shuffle=True)
eval_dataloader = DataLoader(encoded_dataset["test"], batch_size=8)

3. 自定义训练循环

1) 定义优化器和学习率调度器

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

2) 定义训练和评估函数

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

def train_loop():
    model.train()
    for batch in tqdm(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

def eval_loop():
    model.eval()
    total_loss = 0
    correct_predictions = 0

    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            logits = outputs.logits
            total_loss += loss.item()

            predictions = torch.argmax(logits, dim=-1)
            correct_predictions += (predictions == batch["labels"]).sum().item()

    avg_loss = total_loss / len(eval_dataloader)
    accuracy = correct_predictions / len(eval_dataloader.dataset)
    return avg_loss, accuracy

3) 运行训练和评估

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train_loop()
    avg_loss, accuracy = eval_loop()
    print(f"Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

总结

通过上述步骤,我们实现了使用 Hugging Face Transformers 库的自定义训练循环。这种方法提供了更大的灵活性,可以根据具体需求调整训练过程。无论是优化器、学习率调度器,还是其他训练策略,都可以根据需要进行定制。希望这篇文章能帮助你更好地理解和实现自定义训练循环,为你的 NLP 项目提供更强大的支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/782281.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

论文略读:Can Long-Context Language Models Subsume Retrieval, RAG, SQL, and More?

202406 arxiv 1 intro 传统上,复杂的AI任务需要多个专门系统协作完成。 这类系统通常需要独立的模块来进行信息检索、问答和数据库查询等任务大模型时代,尤其是上下文语言模型(LCLM)时代,上述问题可以“一体化”完成…

Qt/C++音视频开发78-获取本地摄像头支持的分辨率/帧率/格式等信息/mjpeg/yuyv/h264

一、前言 上一篇文章讲到用ffmpeg命令方式执行打印到日志输出,可以拿到本地摄像头设备信息,顺藤摸瓜,发现可以通过执行 ffmpeg -f dshow -list_options true -i video“Webcam” 命令获取指定摄像头设备的分辨率帧率格式等信息,会…

Python 全栈系列258 线程并发与协程并发

说明 最近在大模型调用上,为了尽快的进行大量的数据处理,需要采用并发进行处理。 Before: 以前主要是自己利用CPU和GPU来搭建数据处理程序或者服务,资源受限于所用的硬件,并不那么考虑并发问题。在处理程序中,并发主要…

互联网十万个为什么之什么是数据备份?

数据备份是按照一定的备份频率创建数据副本的过程,将重要的数据复制到其它位置或者存储介质,并对生成的副本保留一定的时长。备份通常储存在不同的物理介质或云端,以确保数据的连续性和完整性。有效的备份策略至关重要,以防止数据…

ESP32-C3-Arduino-uart

引脚图 2实现串口发送接收 1默认值初始化串口(默认是uart0) Serial.begin(UART_BAUD); 参数是波特率 2自定义其他串口 2-1创建实例 HardwareSerial SerialUART(0); //数值指的是uart0 1为uart1.。。。。 2-2初始化 SerialUART.begin(UART_BAU…

LabVIEW的Actor Framework (AF) 结构介绍

LabVIEW的Actor Framework (AF) 是一种高级架构,用于开发并发、可扩展和模块化的应用程序。通过面向对象编程(OOP)和消息传递机制,AF结构实现了高效的任务管理和数据处理。其主要特点包括并发执行、动态可扩展性和强大的错误处理能…

不是哥们?你怎么抖成这样了?求你进来学学防抖吧!全方位深入剖析防抖的奥秘

前言 古有猴哥三打白骨精,白骨精 > 噶 今有用户疯狂点请求,服务器 > 噶 所以这防抖咱必须得学会!!! 本文就来讲解一下Web前端中防抖的奥秘吧!!!! 为什么要做防…

适用于 Windows 11/10/8/7/Vista/XP 的最佳免费分区软件

无论您使用的是 SSD、机械磁盘还是任何类型的 RAID 阵列,硬盘驱动器都是 Windows 计算机中不可或缺的组件。在将文件保存到全新磁盘之前,您应该初始化它,创建分区并使用文件系统格式化。在运行计算机一段时间后,您需要收缩、扩展、…

14-25 剑和侠客 – 预训练模型三部曲2 – 视觉

概述 在第 1 部分中,我们讨论了适用于文本的预训练模型的重要性及其在当今世界的相关性。大型语言模型 (LLM),尤其是 GPT-3 和随后的 GPT-3.5,已经获得了极大的欢迎,从而在 AI 讨论中引起了越来越多的关注。我们已经看到了用于构…

everything高级搜索-cnblog

everything高级搜索用法 基础4选项验证 总结搜索方式 高级搜索搜指定路径文件名: 文件名 路径不含文件名: !文件名包含单词 路径包含指定内容: 路径 content:内容 大小写 区分大小写搜索搜指定路径文件名: case:文件名 路径全字匹配 全字搜指定路径文件名: wholewo…

网络安全基础-2

知识点 1.网站搭建前置知识 域名,子域名,DNS,HTTP/HTTPS,证书等 注册购买域名:阿里云企航_万网域名_商标注册_资质备案_软件著作权_网站建设-阿里云 2.web应用环境架构类 理解不同WEB应用组成角色功能架构: 开发语…

四、(1)网络爬虫入门及准备工作(爬虫及数据可视化)

四、(1)网络爬虫入门及准备工作(爬虫及数据可视化) 1,网络爬虫入门1.1 百度指数1.2 天眼查1.3 爬虫原理1.4 搜索引擎原理 2,准备工作2.1 分析爬取页面2.2 爬虫拿到的不仅是网页还是网页的源代码2.3 爬虫就是…

Golang | Leetcode Golang题解之第213题打家劫舍II

题目: 题解: func _rob(nums []int) int {first, second : nums[0], max(nums[0], nums[1])for _, v : range nums[2:] {first, second second, max(firstv, second)}return second }func rob(nums []int) int {n : len(nums)if n 1 {return nums[0]}…

7.pwn 工具安装和使用

关闭保护的方法 pie: -no-pie Canary:-fno-stack-protector aslr:查看:cat /proc/sys/kernel/randomize_va_space 2表示打开 关闭:echo 0>/proc/sys/kernel/randomize_va_space NX:-z execstack gdb使用以及插件安装 是GNU软件系统中的标准调试工具,此外GD…

【计组OS】I/O方式笔记总结

苏泽 “弃工从研”的路上很孤独,于是我记下了些许笔记相伴,希望能够帮助到大家 目录 IO方式:程序查询方式 工作原理 程序查询方式的详细流程: 1. 初始化阶段 2. 发送I/O命令 3. 循环检查状态 4. 数据传输 5. 继续查询 6…

reactor和proactor模型

Reactor模型是非阻塞的同步IO模型。在主线程中也就是IO处理单元中,只负责监听文件描述符上是否有事件发生,有的话就立即将事件通知工作线程,将socket可读可写事件放入请求队列,交给工作线程处理。 总而言之就是主线程监听有事件发…

期末考试结束,老师该如何私发成绩?

随着期末考试的落幕,校园里又恢复了往日的宁静。然而,对于老师们来说,这并不意味着工作的结束,相反,一系列繁琐的任务才刚刚开始。 成绩单的发放,就是其中一项让人头疼的工作。家长们焦急地等待着孩子的考试…

可视化作品集(08):能源电力领域

能源电力领域的可视化大屏,有着巨大的用武之地,不要小看它。 监控能源生产和消耗情况: 通过可视化大屏,可以实时监控能源生产和消耗情况,包括发电量、能源供应情况、能源消耗情况等,帮助管理者及时了解能…

14-39 剑和诗人13 - 顶级大模型测试分析和建议

​​​​​ 随着对高级语言功能的需求不断飙升,市场上涌现出大量语言模型,每种模型都拥有独特的优势和功能。然而,驾驭这个错综复杂的生态系统可能是一项艰巨的任务,开发人员和研究人员经常面临选择最适合其特定需求的模型的挑战。…

React中的useMemo和memo

引言 React是一个声明式的JavaScript库,用于构建用户界面。在开发过程中,性能优化是一个重要的方面。useMemo和memo是React提供的工具,用于帮助开发者避免不必要的渲染和计算,从而提升应用性能。 问题背景 在React应用中&#…