0

0

高效生成BERT词嵌入:解决内存溢出挑战

聖光之護

聖光之護

发布时间:2025-10-18 14:43:00

|

229人浏览过

|

来源于php中文网

原创

高效生成BERT词嵌入:解决内存溢出挑战

本文探讨了在使用bert模型生成词嵌入时常见的内存溢出问题,尤其是在处理长文本或大规模数据集时。我们将介绍如何利用hugging face transformers库进行高效的文本分词和模型前向传播,并强调通过批处理策略进一步优化内存使用,从而稳定地获取高质量的词嵌入。

在使用BERT等大型预训练模型生成词嵌入时,开发者常遇到内存溢出(OutOfMemoryError)的问题,尤其是在处理包含大量长文本的数据集时。这通常发生在尝试一次性将所有数据加载到GPU内存中进行处理时。本教程将提供一种高效且内存友好的方法来生成BERT词嵌入,并讨论如何进一步优化以避免内存问题。

1. 理解内存溢出问题

当您拥有一个包含2000多行长文本的数据集,并尝试使用bert_tokenizer.batch_encode_plus对所有文本进行分词,然后一次性将所有input_ids和attention_mask传递给BERT模型进行前向传播时,即使设置了max_length=512,也极易导致GPU内存不足。错误信息如OutOfMemoryError: CUDA out of memory. Tried to allocate X GiB.明确指出是GPU内存不足。

2. 高效的BERT词嵌入生成方法

为了避免内存问题,推荐使用Hugging Face transformers库提供的AutoModel和AutoTokenizer接口,它们在设计上考虑了效率和易用性。

2.1 加载模型与分词器

首先,加载匹配的预训练模型和分词器。这里以indolem/indobert-base-uncased为例,您可以根据需要替换为其他BERT模型。

import torch
from transformers import AutoModel, AutoTokenizer

# 示例输入文本列表
texts = ['这是一个测试句子,它可能有点长,但我们希望它能被正确处理。', 
         '另一个示例文本,用于演示如何生成词嵌入。']

# 加载匹配的模型和分词器
# 替换为您的模型名称,例如 "bert-base-uncased"
model_name = "indolem/indobert-base-uncased" 
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 将模型移动到GPU(如果可用)
if torch.cuda.is_available():
    model.to('cuda')
    print("模型已移至GPU。")
else:
    print("未检测到GPU,模型将在CPU上运行。")

2.2 文本分词与编码

直接使用分词器对文本列表进行编码,它会处理批量分词、填充和截断,并返回PyTorch张量。

# 对批量句子进行分词,截断至512,并进行填充
tokenized_texts = tokenizer(texts, 
                            max_length=512,       # 最大序列长度
                            truncation=True,      # 启用截断,超出max_length的部分将被截断
                            padding=True,         # 启用填充,短于max_length的部分将被填充
                            return_tensors='pt')  # 返回PyTorch张量

# 将分词结果移动到GPU(如果模型在GPU上)
if torch.cuda.is_available():
    tokenized_texts = {k: v.to('cuda') for k, v in tokenized_texts.items()}

print(f"分词结果的input_ids形状: {tokenized_texts['input_ids'].shape}")

参数说明:

  • max_length: 指定最大序列长度。超出此长度的文本将被截断。
  • truncation=True: 确保所有序列都被截断到max_length。
  • padding=True: 确保所有序列都被填充到max_length(或批次中最长序列的长度,如果未指定max_length)。
  • return_tensors='pt': 返回PyTorch张量。

2.3 模型前向传播获取词嵌入

在分词完成后,将编码后的输入传递给模型进行前向传播。为了节省内存,我们通常在推理阶段使用torch.no_grad()上下文管理器。

Quillbot
Quillbot

一款AI写作润色工具,QuillBot的人工智能改写工具将提高你的写作能力。

下载
# 前向传播
with torch.no_grad():
    input_ids = tokenized_texts['input_ids']
    attention_mask = tokenized_texts['attention_mask']

    outputs = model(input_ids=input_ids, 
                    attention_mask=attention_mask)

    # 获取最后一层的隐藏状态作为词嵌入
    word_embeddings = outputs.last_hidden_state

# 打印词嵌入的形状
print(f"生成的词嵌入形状: {word_embeddings.shape}")
# 预期输出形状示例: torch.Size([batch_size, num_seq_tokens, embed_size])
# 例如: torch.Size([2, 512, 768])

word_embeddings的形状通常是 [batch_size, num_seq_tokens, embed_size]。其中:

  • batch_size:输入文本的数量。
  • num_seq_tokens:序列中的token数量(通常是max_length或实际序列长度)。
  • embed_size:模型的隐藏层大小(例如BERT-base是768)。

3. 处理大规模数据集的内存优化:批处理

尽管上述方法已经非常高效,但在处理极大规模的数据集或极长的文本时,仍可能出现内存不足。此时,最有效的策略是将数据分成更小的批次(mini-batches)进行处理。

from torch.utils.data import DataLoader, TensorDataset

# 假设您有一个非常大的文本列表
all_texts = ['长文本1', '长文本2', ..., '长文本N'] # N可能非常大

# 定义批次大小
batch_size = 16 # 根据您的GPU内存调整,尝试16, 8, 4等更小的值

# 分词所有文本 (注意:如果all_texts非常大,这一步本身可能耗内存,可以考虑分批次分词)
# 为了演示方便,我们假设分词结果可以一次性存储
tokenized_inputs = tokenizer(all_texts, 
                             max_length=512, 
                             truncation=True, 
                             padding='max_length', # 确保所有批次长度一致
                             return_tensors='pt')

input_ids_tensor = tokenized_inputs['input_ids']
attention_mask_tensor = tokenized_inputs['attention_mask']

# 创建一个TensorDataset
dataset = TensorDataset(input_ids_tensor, attention_mask_tensor)

# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

all_embeddings = []

# 迭代处理每个批次
print(f"\n开始分批处理,批次大小为: {batch_size}")
with torch.no_grad():
    for batch_idx, batch in enumerate(dataloader):
        batch_input_ids, batch_attention_mask = batch

        # 将批次数据移动到GPU
        if torch.cuda.is_available():
            batch_input_ids = batch_input_ids.to('cuda')
            batch_attention_mask = batch_attention_mask.to('cuda')

        # 模型前向传播
        outputs = model(input_ids=batch_input_ids, 
                        attention_mask=batch_attention_mask)

        # 获取词嵌入并移回CPU(可选,但推荐,以释放GPU内存)
        batch_word_embeddings = outputs.last_hidden_state.cpu()
        all_embeddings.append(batch_word_embeddings)
        print(f"  处理批次 {batch_idx+1}/{len(dataloader)},词嵌入形状: {batch_word_embeddings.shape}")

# 合并所有批次的词嵌入
final_embeddings = torch.cat(all_embeddings, dim=0)
print(f"\n所有文本的最终词嵌入形状: {final_embeddings.shape}")

注意事项:

  • 调整batch_size: 这是解决内存溢出最关键的参数。如果仍然出现OOM,请进一步减小batch_size。
  • padding='max_length': 在分批处理时,为了确保每个批次的张量形状一致,通常建议将padding设置为'max_length',而不是默认的True(它会填充到批次内最长序列的长度)。
  • 及时释放GPU内存: 在处理完一个批次后,如果不再需要该批次的数据,可以将其从GPU移回CPU (.cpu()),或者在循环结束后清理不再需要的张量,以帮助释放GPU内存。

总结

生成BERT词嵌入时避免内存溢出,关键在于:

  1. 使用Hugging Face AutoTokenizer直接处理文本列表:它能高效地完成分词、填充和截断,生成适合模型输入的张量。
  2. 利用torch.no_grad()进行推理:在模型前向传播时禁用梯度计算,显著减少内存消耗。
  3. 实施批处理(Batching)策略:将大型数据集划分为更小的批次,逐批次送入模型处理,这是解决大规模数据内存问题的根本方法。

通过以上策略,您可以有效地生成BERT词嵌入,即使面对大规模长文本数据,也能稳定运行并避免常见的内存溢出问题。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
登录token无效
登录token无效

登录token无效解决方法:1、检查token的有效期限,如果token已经过期,需要重新获取一个新的token;2、检查token的签名,如果签名不正确,需要重新获取一个新的token;3、检查密钥的正确性,如果密钥不正确,需要重新获取一个新的token;4、使用HTTPS协议传输token,建议使用HTTPS协议进行传输 ;5、使用双因素认证,双因素认证可以提高账户的安全性。

6194

2023.09.14

登录token无效怎么办
登录token无效怎么办

登录token无效的解决办法有检查Token是否过期、检查Token是否正确、检查Token是否被篡改、检查Token是否与用户匹配、清除缓存或Cookie、检查网络连接和服务器状态、重新登录或请求新的Token、联系技术支持或开发人员等。本专题为大家提供token相关的文章、下载、课程内容,供大家免费下载体验。

819

2023.09.14

token怎么获取
token怎么获取

获取token值的方法:1、小程序调用“wx.login()”获取 临时登录凭证code,并回传到开发者服务器;2、开发者服务器以code换取,用户唯一标识openid和会话密钥“session_key”。想了解更详细的内容,可以阅读本专题下面的文章。

1069

2023.12.21

token什么意思
token什么意思

token是一种用于表示用户权限、记录交易信息、支付虚拟货币的数字货币。可以用来在特定的网络上进行交易,用来购买或出售特定的虚拟货币,也可以用来支付特定的服务费用。想了解更多token什么意思的相关内容可以访问本专题下面的文章。

1358

2024.03.01

硬盘接口类型介绍
硬盘接口类型介绍

硬盘接口类型有IDE、SATA、SCSI、Fibre Channel、USB、eSATA、mSATA、PCIe等等。详细介绍:1、IDE接口是一种并行接口,主要用于连接硬盘和光驱等设备,它主要有两种类型:ATA和ATAPI,IDE接口已经逐渐被SATA接口;2、SATA接口是一种串行接口,相较于IDE接口,它具有更高的传输速度、更低的功耗和更小的体积;3、SCSI接口等等。

1133

2023.10.19

PHP接口编写教程
PHP接口编写教程

本专题整合了PHP接口编写教程,阅读专题下面的文章了解更多详细内容。

213

2025.10.17

php8.4实现接口限流的教程
php8.4实现接口限流的教程

PHP8.4本身不内置限流功能,需借助Redis(令牌桶)或Swoole(漏桶)实现;文件锁因I/O瓶颈、无跨机共享、秒级精度等缺陷不适用高并发场景。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

1828

2025.12.29

java接口相关教程
java接口相关教程

本专题整合了java接口相关内容,阅读专题下面的文章了解更多详细内容。

20

2026.01.19

java入门学习合集
java入门学习合集

本专题整合了java入门学习指南、初学者项目实战、入门到精通等等内容,阅读专题下面的文章了解更多详细学习方法。

1

2026.01.29

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 22.4万人学习

Rust 教程
Rust 教程

共28课时 | 5.1万人学习

Git 教程
Git 教程

共21课时 | 3.1万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号