0

0

使用 Transformers 解决 BERT 词嵌入中的内存问题

霞舞

霞舞

发布时间:2025-10-19 12:55:11

|

631人浏览过

|

来源于php中文网

原创

使用 transformers 解决 bert 词嵌入中的内存问题

本文旨在解决在使用 BERT 等 Transformer 模型进行词嵌入时遇到的内存不足问题。通过直接使用 tokenizer 处理文本输入,避免 `batch_encode_plus` 可能带来的问题。同时,提供了降低批次大小以进一步优化内存使用的建议,帮助用户高效地生成词嵌入。

在使用 BERT 或其他 Transformer 模型处理大量文本数据生成词嵌入时,OutOfMemoryError 是一个常见的问题。这通常是由于模型参数过多、输入序列过长或批次大小过大造成的。本文提供一种更高效的方法,通过优化文本预处理流程和调整批次大小来解决这个问题。

优化文本预处理

通常,我们会先使用 batch_encode_plus 对文本进行分词和编码,然后再将其输入到模型中。然而,对于长文本数据集,这种方法可能会导致内存占用过高。一种更优的方案是直接使用 tokenizer 处理文本输入,让 tokenizer 自身处理文本的截断、填充等操作。

以下是使用 AutoModel 和 AutoTokenizer 的示例代码:

import torch
from transformers import AutoModel, AutoTokenizer

# 输入文本列表 (可以是长句子)
texts = ['This is a test sentence.', 'Another test sentence.']

# 加载预训练模型和 tokenizer
model_name = "indolem/indobert-base-uncased" # 这里替换成你想要使用的模型
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 使用 tokenizer 对文本进行分词、截断和填充
tokenized_texts = tokenizer(texts, 
                            max_length=512,  # 根据实际情况调整
                            truncation=True, 
                            padding=True, 
                            return_tensors='pt')

这段代码首先加载了预训练的 BERT 模型和 tokenizer。然后,它使用 tokenizer 对文本进行分词、截断和填充,并将结果转换为 PyTorch 张量。通过这种方式,tokenizer 可以更好地管理内存,避免 batch_encode_plus 可能带来的问题。

模型前向传播

接下来,将编码后的文本输入到模型中进行前向传播,获取词嵌入:

# 前向传播
with torch.no_grad():
    input_ids, attention_mask = tokenized_texts['input_ids'], tokenized_texts['attention_mask']

    outputs = model(input_ids=input_ids, 
                    attention_mask=attention_mask)

    word_embeddings = outputs.last_hidden_state

这段代码使用 torch.no_grad() 上下文管理器禁用梯度计算,以减少内存占用。然后,它将 input_ids 和 attention_mask 输入到模型中,获取 last_hidden_state,即词嵌入。

腾讯AI 开放平台
腾讯AI 开放平台

腾讯AI开放平台

下载

结果分析

获取的词嵌入的形状为 [batch_size, num_seq_tokens, embed_size],其中 batch_size 是批次大小,num_seq_tokens 是序列长度,embed_size 是嵌入维度。

print(word_embeddings.shape)
# output: torch.Size([2, 4, 768])

进一步优化:降低批次大小

如果仍然遇到 OutOfMemoryError,可以尝试降低批次大小。这意味着将数据集分成更小的批次进行处理。

batch_size = 8  # 根据实际情况调整
for i in range(0, len(texts), batch_size):
    batch_texts = texts[i:i+batch_size]
    tokenized_texts = tokenizer(batch_texts, 
                                max_length=512,  # 根据实际情况调整
                                truncation=True, 
                                padding=True, 
                                return_tensors='pt')
    with torch.no_grad():
        input_ids, attention_mask = tokenized_texts['input_ids'], tokenized_texts['attention_mask']

        outputs = model(input_ids=input_ids, 
                        attention_mask=attention_mask)

        word_embeddings = outputs.last_hidden_state
        # 对 word_embeddings 进行后续处理

这段代码将数据集分成大小为 batch_size 的批次,并逐批处理。通过降低批次大小,可以显著减少内存占用。

总结与注意事项

通过直接使用 tokenizer 处理文本输入和降低批次大小,可以有效地解决在使用 BERT 等 Transformer 模型进行词嵌入时遇到的内存不足问题。

注意事项:

  • max_length 参数需要根据数据集的实际情况进行调整。过小的 max_length 可能会导致信息丢失,过大的 max_length 会增加内存占用。
  • 选择合适的预训练模型也很重要。较大的模型通常具有更好的性能,但也需要更多的内存。
  • 如果仍然遇到内存问题,可以考虑使用更小的模型或增加 GPU 内存。
  • 根据实际情况,可以尝试使用梯度累积等技术来进一步优化内存使用。
  • 在Colab上使用GPU时,确保已经选择了GPU运行时环境。
  • 可以尝试使用torch.cuda.empty_cache()释放不再使用的GPU内存。

通过以上方法,可以更有效地使用 Transformer 模型生成词嵌入,并避免 OutOfMemoryError。

相关专题

更多
pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

432

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

24

2025.12.22

c++ 根号
c++ 根号

本专题整合了c++根号相关教程,阅读专题下面的文章了解更多详细内容。

17

2026.01.23

c++空格相关教程合集
c++空格相关教程合集

本专题整合了c++空格相关教程,阅读专题下面的文章了解更多详细内容。

22

2026.01.23

yy漫画官方登录入口地址合集
yy漫画官方登录入口地址合集

本专题整合了yy漫画入口相关合集,阅读专题下面的文章了解更多详细内容。

91

2026.01.23

漫蛙最新入口地址汇总2026
漫蛙最新入口地址汇总2026

本专题整合了漫蛙最新入口地址大全,阅读专题下面的文章了解更多详细内容。

124

2026.01.23

C++ 高级模板编程与元编程
C++ 高级模板编程与元编程

本专题深入讲解 C++ 中的高级模板编程与元编程技术,涵盖模板特化、SFINAE、模板递归、类型萃取、编译时常量与计算、C++17 的折叠表达式与变长模板参数等。通过多个实际示例,帮助开发者掌握 如何利用 C++ 模板机制编写高效、可扩展的通用代码,并提升代码的灵活性与性能。

14

2026.01.23

php远程文件教程合集
php远程文件教程合集

本专题整合了php远程文件相关教程,阅读专题下面的文章了解更多详细内容。

65

2026.01.22

PHP后端开发相关内容汇总
PHP后端开发相关内容汇总

本专题整合了PHP后端开发相关内容,阅读专题下面的文章了解更多详细内容。

59

2026.01.22

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 19.2万人学习

Rust 教程
Rust 教程

共28课时 | 4.8万人学习

Git 教程
Git 教程

共21课时 | 3万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号