0

0

使用 Transformers 解决 BERT 词嵌入中的内存溢出问题

心靈之曲

心靈之曲

发布时间:2025-10-18 15:35:18

|

204人浏览过

|

来源于php中文网

原创

使用 transformers 解决 bert 词嵌入中的内存溢出问题

本文旨在提供一种解决在使用 BERT 等 Transformers 模型进行词嵌入时遇到的内存溢出问题的有效方法。通过直接使用 tokenizer 处理文本输入,并适当调整 batch size,可以避免 `batch_encode_plus` 可能带来的内存压力,从而顺利生成词嵌入。

在使用 BERT 或其他 Transformers 模型生成文本数据集的词嵌入时,经常会遇到 OutOfMemoryError 错误,尤其是在处理长文本序列时。这主要是因为模型需要加载大量数据到 GPU 内存中进行计算。本文将介绍一种更高效的方法,通过优化文本处理流程和调整 batch size 来解决这个问题。

解决方案:优化文本处理和 Batch Size

传统的 batch_encode_plus 方法可能会导致内存占用过高。一种更有效的方法是直接使用 tokenizer 处理文本输入,并结合适当的 batch size。

1. 直接使用 Tokenizer 处理文本

不再使用 batch_encode_plus,而是直接使用 tokenizer 对象处理文本列表。这允许 tokenizer 内部更有效地管理内存。

import torch
from transformers import AutoModel, AutoTokenizer

# 输入文本列表 (可以是长句子)
texts = ['test1', 'test2']

# 加载预训练模型和 tokenizer
model_name = "indolem/indobert-base-uncased" # 这里替换为你想要使用的模型
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 对文本进行分词、截断和填充
tokenized_texts = tokenizer(texts,
                            max_length=512,
                            truncation=True,
                            padding=True,
                            return_tensors='pt')

代码解释:

Type
Type

生成草稿,转换文本,获得写作帮助-等等。

下载
  • AutoModel.from_pretrained(model_name): 加载指定名称的预训练模型。
  • AutoTokenizer.from_pretrained(model_name): 加载与模型对应的 tokenizer。
  • tokenizer(...): 使用 tokenizer 直接处理文本列表,设置最大长度、截断和填充策略,并返回 PyTorch 张量。

2. 前向传播

将 tokenizer 处理后的文本批次传递给模型进行前向传播。

# 前向传播
with torch.no_grad():
    input_ids, attention_mask = tokenized_texts['input_ids'], tokenized_texts['attention_mask']

    outputs = model(input_ids=input_ids,
                    attention_mask=attention_mask)

    word_embeddings = outputs.last_hidden_state

代码解释:

  • with torch.no_grad():: 禁用梯度计算,减少内存占用。
  • outputs = model(...): 将输入 ID 和 attention mask 传递给模型进行前向传播。
  • word_embeddings = outputs.last_hidden_state: 获取最后一层的隐藏状态,即词嵌入。

3. 检查输出形状

验证词嵌入的形状是否符合预期。

print(word_embeddings.shape)
# 输出: torch.Size([batch_size, num_seq_tokens, embed_size])
# 例如: torch.Size([2, 4, 768])

代码解释:

  • word_embeddings.shape: 打印词嵌入的形状,通常为 [batch_size, num_seq_tokens, embed_size],其中 batch_size 是批次大小,num_seq_tokens 是序列中的 token 数量,embed_size 是嵌入维度。

4. 调整 Batch Size (如果仍然出现 OOM)

如果即使使用上述方法仍然出现 OutOfMemoryError,则需要减小 batch size。可以循环处理数据,每次处理较小的批次。

import torch
from transformers import AutoModel, AutoTokenizer

# 输入文本列表
texts = ['test1', 'test2', 'test3', 'test4', 'test5']

# 加载预训练模型和 tokenizer
model_name = "indolem/indobert-base-uncased" # 这里替换为你想要使用的模型
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

batch_size = 2  # 设置较小的 batch size

for i in range(0, len(texts), batch_size):
    batch_texts = texts[i:i + batch_size]

    # 对文本进行分词、截断和填充
    tokenized_texts = tokenizer(batch_texts,
                                max_length=512,
                                truncation=True,
                                padding=True,
                                return_tensors='pt')

    # 前向传播
    with torch.no_grad():
        input_ids, attention_mask = tokenized_texts['input_ids'], tokenized_texts['attention_mask']

        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask)

        word_embeddings = outputs.last_hidden_state

    print(f"Batch {i//batch_size + 1} embeddings shape: {word_embeddings.shape}")

    # 在这里处理词嵌入,例如存储或进一步分析

代码解释:

  • batch_size = 2: 设置较小的 batch size,例如 2。
  • for i in range(0, len(texts), batch_size):: 循环处理数据,每次处理一个批次。
  • batch_texts = texts[i:i + batch_size]: 提取当前批次的文本。
  • 循环内部的代码与之前的示例相同,但现在每次处理的文本量较小,从而降低了内存占用。

注意事项

  • 选择合适的 Batch Size: Batch size 的选择取决于 GPU 的内存大小和模型的复杂度。可以尝试不同的 batch size,找到一个既能充分利用 GPU 资源又能避免内存溢出的值。
  • 使用 GPU: 确保代码在 GPU 上运行,这可以显著提高计算速度。
  • 优化模型: 如果可能,可以尝试使用更小的模型或对模型进行量化,以减少内存占用。
  • 梯度累积: 在某些情况下,可以使用梯度累积来模拟更大的 batch size,而无需增加内存占用。

总结

通过直接使用 tokenizer 处理文本输入并适当调整 batch size,可以有效地解决在使用 Transformers 模型进行词嵌入时遇到的内存溢出问题。这种方法简单易行,并且适用于各种 Transformers 模型。在实际应用中,可以根据具体情况调整 batch size 和其他参数,以达到最佳性能。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
登录token无效
登录token无效

登录token无效解决方法:1、检查token的有效期限,如果token已经过期,需要重新获取一个新的token;2、检查token的签名,如果签名不正确,需要重新获取一个新的token;3、检查密钥的正确性,如果密钥不正确,需要重新获取一个新的token;4、使用HTTPS协议传输token,建议使用HTTPS协议进行传输 ;5、使用双因素认证,双因素认证可以提高账户的安全性。

6172

2023.09.14

登录token无效怎么办
登录token无效怎么办

登录token无效的解决办法有检查Token是否过期、检查Token是否正确、检查Token是否被篡改、检查Token是否与用户匹配、清除缓存或Cookie、检查网络连接和服务器状态、重新登录或请求新的Token、联系技术支持或开发人员等。本专题为大家提供token相关的文章、下载、课程内容,供大家免费下载体验。

819

2023.09.14

token怎么获取
token怎么获取

获取token值的方法:1、小程序调用“wx.login()”获取 临时登录凭证code,并回传到开发者服务器;2、开发者服务器以code换取,用户唯一标识openid和会话密钥“session_key”。想了解更详细的内容,可以阅读本专题下面的文章。

1069

2023.12.21

token什么意思
token什么意思

token是一种用于表示用户权限、记录交易信息、支付虚拟货币的数字货币。可以用来在特定的网络上进行交易,用来购买或出售特定的虚拟货币,也可以用来支付特定的服务费用。想了解更多token什么意思的相关内容可以访问本专题下面的文章。

1358

2024.03.01

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

433

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

24

2025.12.22

clawdbot ai使用教程 保姆级clawdbot部署安装手册
clawdbot ai使用教程 保姆级clawdbot部署安装手册

Clawdbot是一个“有灵魂”的AI助手,可以帮用户清空收件箱、发送电子邮件、管理日历、办理航班值机等等,并且可以接入用户常用的任何聊天APP,所有的操作均可通过WhatsApp、Telegram等平台完成,用户只需通过对话,就能操控设备自动执行各类任务。

18

2026.01.29

clawdbot龙虾机器人官网入口 clawdbot ai官方网站地址
clawdbot龙虾机器人官网入口 clawdbot ai官方网站地址

clawdbot龙虾机器人官网入口:https://clawd.bot/,clawdbot ai是一个“有灵魂”的AI助手,可以帮用户清空收件箱、发送电子邮件、管理日历、办理航班值机等等,并且可以接入用户常用的任何聊天APP,所有的操作均可通过WhatsApp、Telegram等平台完成,用户只需通过对话,就能操控设备自动执行各类任务。

12

2026.01.29

Golang 网络安全与加密实战
Golang 网络安全与加密实战

本专题系统讲解 Golang 在网络安全与加密技术中的应用,包括对称加密与非对称加密(AES、RSA)、哈希与数字签名、JWT身份认证、SSL/TLS 安全通信、常见网络攻击防范(如SQL注入、XSS、CSRF)及其防护措施。通过实战案例,帮助学习者掌握 如何使用 Go 语言保障网络通信的安全性,保护用户数据与隐私。

8

2026.01.29

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
10分钟--Midjourney创作自己的漫画
10分钟--Midjourney创作自己的漫画

共1课时 | 0.1万人学习

Midjourney 关键词系列整合
Midjourney 关键词系列整合

共13课时 | 0.9万人学习

AI绘画教程
AI绘画教程

共2课时 | 0.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号