0

0

CPM-Distill:经过知识蒸馏的小型文本生成模型

P粉084495128

P粉084495128

发布时间:2025-07-18 13:46:16

|

246人浏览过

|

来源于php中文网

原创

本文介绍知识蒸馏技术及基于PaddleNLP加载CPM-Distill模型实现文本生成。知识蒸馏是模型压缩方法,以“教师-学生网络”思想,让简单模型拟合复杂模型输出,效果优于从头训练。CPM-Distill由GPT-2 Large蒸馏得到,文中还给出安装依赖、加载模型、解码方法及文本生成示例。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

cpm-distill:经过知识蒸馏的小型文本生成模型 - php中文网

引入

  • 近些年来,随着 Bert 这样的大规模预训练模型的问世,NLP 领域的模型也逐渐变得越来越大了
  • 受限于算力水平,如此大规模的模型要应用在实际的部署场景都是不太实际的
  • 因此需要通过一些方式对大规模的模型进行压缩,使其能够在部署场景下达到一个相对可用的速度
  • 常见的模型压缩方法有:剪枝、量化、知识蒸馏等
  • 最近 CPM(Chinese Pre-Trained Models)项目又开源了一个使用知识蒸馏得到的小型文本生成模型 CPM-Distill
  • 本次项目就简单介绍一下知识蒸馏技术并且通过 PaddleNLP 套件加载 CPM-Distill 模型实现文本生成

相关项目

  • Paddle2.0:构建一个经典的文本生成模型GPT-2
  • 文本生成:使用GPT-2加载CPM-LM模型实现简单的问答机器人
  • 文本生成:让AI帮你写文章吧
  • 【AI创造营】PaddleHub 配合 PaddleNLP 实现简单的文本生成

相关资料

  • 论文:
    • CPM: A Large-scale Generative Chinese Pre-trained Language Model
    • Distilling the Knowledge in a Neural Network
  • 官方实现:TsinghuaAI/CPM-Distill

模型压缩技术

CPM-Distill:经过知识蒸馏的小型文本生成模型 - php中文网

知识蒸馏(Knowledge Distillation)

  • 知识蒸馏是一种模型压缩方法,是一种基于“教师-学生网络思想”的训练方法。

  • 由 Hinton 在 2015 年 Distilling the Knowledge in a Neural Network 的论文首次提出了知识蒸馏的并尝试在 CV 领域中使用,旨在把大模型学到的知识灌输到小模型中,以达到缩小模型的目标,示意图如下:

CPM-Distill:经过知识蒸馏的小型文本生成模型 - php中文网

maven使用方法 中文WORD版
maven使用方法 中文WORD版

本文档主要讲述的是maven使用方法;Maven是基于项目对象模型的(pom),可以通过一小段描述信息来管理项目的构建,报告和文档的软件项目管理工具。Maven将你的注意力从昨夜基层转移到项目管理层。Maven项目已经能够知道 如何构建和捆绑代码,运行测试,生成文档并宿主项目网页。希望本文档会给有需要的朋友带来帮助;感兴趣的朋友可以过来看看

下载
  • 说人话就是指用一个简单模型去拟合复杂模型的输出,这个输出也叫做“软标签”,当然也可以加入真实数据作为“硬标签”一同训练。
  • 使用知识蒸馏技术相比直接从头训练的效果一般会更好一些,因为教师模型能够指导学生模型收敛到一个更佳的位置。

CPM-Distill:经过知识蒸馏的小型文本生成模型 - php中文网

  • 知识蒸馏技术除了可以用来将网络从大网络转化成一个小网络,并保留接近于大网络的性能;
  • 也可以将多个网络的学到的知识转移到一个网络中,使得单个网络的性能接近 emsemble 的结果。

蒸馏模型信息

  • 教师模型为 GPT-2 Large,具体的模型参数如下:
teacher_model = GPTModel(
    vocab_size=30000,
    hidden_size=2560,
    num_hidden_layers=32,
    num_attention_heads=32,
    intermediate_size=10240,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=1024,
    type_vocab_size=1,
    initializer_range=0.02,
    pad_token_id=0,
    topo=None)
  • 学生模型为 GPT-2 Small,具体的模型参数如下:
teacher_model = GPTModel(
    vocab_size=30000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=1024,
    type_vocab_size=1,
    initializer_range=0.02,
    pad_token_id=0,
    topo=None)

蒸馏 loss

  • 将大模型和小模型每个位置上输出之间的 KL 散度作为蒸馏 loss,同时加上原来的 language model loss。总 loss 如下:

CPM-Distill:经过知识蒸馏的小型文本生成模型 - php中文网

其中 LlmLlm 为 GPT-2 原始的 language modeling loss。

安装依赖

In [ ]
!pip install paddlenlp==2.0.1 sentencepiece==0.1.92

加载模型

In [1]
import paddlefrom paddlenlp.transformers import GPTModel, GPTForPretraining, GPTChineseTokenizer# tokenizer 与 CPM-LM 模型一致tokenizer = GPTChineseTokenizer.from_pretrained('gpt-cpm-large-cn')# 实例化 GPT2-small 模型gpt = GPTModel(
    vocab_size=30000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=1024,
    type_vocab_size=1,
    initializer_range=0.02,
    pad_token_id=0,
    topo=None)# 加载预训练模型参数params = paddle.load('data/data92160/gpt-cpm-small-cn-distill.pdparams')# 设置参数gpt.set_dict(params)# 使用 GPTForPretraining 向模型中添加输出层model = GPTForPretraining(gpt)# 将模型设置为评估模式model.eval()
[2021-05-28 19:38:04,469] [    INFO] - Found /home/aistudio/.paddlenlp/models/gpt-cpm-large-cn/gpt-cpm-cn-sentencepiece.model

模型解码

In [40]
import paddleimport numpy as np# Greedy Searchdef greedy_search(text, max_len=32, end_word=None):
    # # 终止标志
    if end_word is not None:
        stop_id = tokenizer.encode(end_word)['input_ids']
        length = len(stop_id)    else:
        stop_id = [tokenizer.eod_token_id]
        length = len(stop_id)    
    # 初始预测
    ids = tokenizer.encode(text)['input_ids']
    input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
    output, cached_kvs = model(input_id, use_cache=True)
    next_token = int(np.argmax(output[0, -1].numpy()))
    ids.append(next_token)    # 使用缓存进行继续预测
    for i in range(max_len-1):
        input_id = paddle.to_tensor(np.array([next_token]).reshape(1, -1).astype('int64'))
        output, cached_kvs = model(input_id, use_cache=True, cache=cached_kvs)
        next_token = int(np.argmax(output[0, -1].numpy()))
        ids.append(next_token)        # 根据终止标志停止预测
        if ids[-length:]==stop_id:            if end_word is None:
               ids = ids[:-1]            break
    
    return tokenizer.convert_ids_to_string(ids)
In [39]
import paddleimport numpy as np# top_k and top_p filteringdef top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.shape[-1])  # Safety check
    logits_np = logits.numpy()    if top_k > 0:        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits_np < np.sort(logits_np)[-top_k]
        logits_np[indices_to_remove] = filter_value    if top_p < 1.0:
        sorted_logits = paddle.sort(logits, descending=True)
        sorted_indices = paddle.argsort(logits, descending=True).numpy()
        cumulative_probs = paddle.cumsum(paddle.nn.functional.softmax(sorted_logits, axis=-1), axis=-1).numpy()        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits_np[indices_to_remove] = filter_value    return paddle.to_tensor(logits_np)# Nucleus Sampledef nucleus_sample(text, max_len=32, end_word=None, repitition_penalty=1.0, temperature=1.0, top_k=0, top_p=1.0):
    # 终止标志
    if end_word is not None:
        stop_id = tokenizer.encode(end_word)['input_ids']
        length = len(stop_id)    else:
        stop_id = [tokenizer.eod_token_id]
        length = len(stop_id)    # 初始预测
    ids = tokenizer.encode(text)['input_ids']
    input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
    output, cached_kvs = model(input_id, use_cache=True)
    next_token_logits = output[0, -1, :]    for id in set(ids):
        next_token_logits[id] /= repitition_penalty
    next_token_logits = next_token_logits / temperature
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
    next_token = paddle.multinomial(paddle.nn.functional.softmax(filtered_logits, axis=-1), num_samples=1).numpy()
    ids += [int(next_token)]    # 使用缓存进行继续预测
    for i in range(max_len-1):
        input_id = paddle.to_tensor(np.array([next_token]).reshape(1, -1).astype('int64'))
        output, cached_kvs = model(input_id, use_cache=True, cache=cached_kvs)
        next_token_logits = output[0, -1, :]        for id in set(ids):
            next_token_logits[id] /= repitition_penalty
        next_token_logits = next_token_logits / temperature
        filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
        next_token = paddle.multinomial(paddle.nn.functional.softmax(filtered_logits, axis=-1), num_samples=1).numpy()
        ids += [int(next_token)]        # 根据终止标志停止预测
        if ids[-length:]==stop_id:            if end_word is None:
               ids = ids[:-1]            break

    return tokenizer.convert_ids_to_string(ids)

文本生成

In [41]
# 输入文本inputs = input('请输入文本:')print(inputs)# 使用 Nucleus Sample 进行文本生成outputs = greedy_search(
    inputs, # 输入文本
    max_len=128, # 最大生成文本的长度
    end_word=None)# 打印输出print(outputs)
请输入文本:请在此处输入你的姓名
请在此处输入你的姓名,然后点击“确定”,就可以开始游戏了。
游戏目标:在限定时间内,成功地把所有的牌都通通打完。
In [43]
# 输入文本inputs = input('请输入文本:')print(inputs)for x in range(5):    # 使用 Nucleus Sample 进行文本生成
    outputs = nucleus_sample(
        inputs, # 输入文本
        max_len=128, # 最大生成文本的长度
        end_word='。', # 终止符号
        repitition_penalty=1.0, # 重复度抑制
        temperature=1.0, # 温度
        top_k=3000, # 取前k个最大输出再进行采样
        top_p=0.9 # 抑制概率低于top_p的输出再进行采样
    )    # 打印输出
    print(outputs)
请输入文本:请在此处输入你的姓名
请在此处输入你的姓名、学校、专业及学科,并在社交媒体上公布你的个人简介。
请在此处输入你的姓名或者电话,对方会及时通知你。
请在此处输入你的姓名、民族及籍贯信息,当您找到 CADULI 的联系方式后,我们会按您所选择的申请中心,以电子邮件的形式向您发送邮件。
请在此处输入你的姓名和电话号码,由资深会所接待员进行介绍,因为此处有不少中国的大老板,英文能看。
请在此处输入你的姓名、联系电话、银行卡号和手机号。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
拼多多赚钱的5种方法 拼多多赚钱的5种方法
拼多多赚钱的5种方法 拼多多赚钱的5种方法

在拼多多上赚钱主要可以通过无货源模式一件代发、精细化运营特色店铺、参与官方高流量活动、利用拼团机制社交裂变,以及成为多多进宝推广员这5种方法实现。核心策略在于通过低成本、高效率的供应链管理与营销,利用平台社交电商红利实现盈利。

28

2026.01.26

edge浏览器怎样设置主页 edge浏览器自定义设置教程
edge浏览器怎样设置主页 edge浏览器自定义设置教程

在Edge浏览器中设置主页,请依次点击右上角“...”图标 > 设置 > 开始、主页和新建标签页。在“Microsoft Edge 启动时”选择“打开以下页面”,点击“添加新页面”并输入网址。若要使用主页按钮,需在“外观”设置中开启“显示主页按钮”并设定网址。

8

2026.01.26

苹果官方查询网站 苹果手机正品激活查询入口
苹果官方查询网站 苹果手机正品激活查询入口

苹果官方查询网站主要通过 checkcoverage.apple.com/cn/zh/ 进行,可用于查询序列号(SN)对应的保修状态、激活日期及技术支持服务。此外,查找丢失设备请使用 iCloud.com/find,购买信息与物流可访问 Apple (中国大陆) 订单状态页面。

31

2026.01.26

npd人格什么意思 npd人格有什么特征
npd人格什么意思 npd人格有什么特征

NPD(Narcissistic Personality Disorder)即自恋型人格障碍,是一种心理健康问题,特点是极度夸大自我重要性、需要过度赞美与关注,同时极度缺乏共情能力,背后常掩藏着低自尊和不安全感,影响人际关系、工作和生活,通常在青少年时期开始显现,需由专业人士诊断。

3

2026.01.26

windows安全中心怎么关闭 windows安全中心怎么执行操作
windows安全中心怎么关闭 windows安全中心怎么执行操作

关闭Windows安全中心(Windows Defender)可通过系统设置暂时关闭,或使用组策略/注册表永久关闭。最简单的方法是:进入设置 > 隐私和安全性 > Windows安全中心 > 病毒和威胁防护 > 管理设置,将实时保护等选项关闭。

5

2026.01.26

2026年春运抢票攻略大全 春运抢票攻略教你三招手【技巧】
2026年春运抢票攻略大全 春运抢票攻略教你三招手【技巧】

铁路12306提供起售时间查询、起售提醒、购票预填、候补购票及误购限时免费退票五项服务,并强调官方渠道唯一性与信息安全。

35

2026.01.26

个人所得税税率表2026 个人所得税率最新税率表
个人所得税税率表2026 个人所得税率最新税率表

以工资薪金所得为例,应纳税额 = 应纳税所得额 × 税率 - 速算扣除数。应纳税所得额 = 月度收入 - 5000 元 - 专项扣除 - 专项附加扣除 - 依法确定的其他扣除。假设某员工月工资 10000 元,专项扣除 1000 元,专项附加扣除 2000 元,当月应纳税所得额为 10000 - 5000 - 1000 - 2000 = 2000 元,对应税率为 3%,速算扣除数为 0,则当月应纳税额为 2000×3% = 60 元。

12

2026.01.26

oppo云服务官网登录入口 oppo云服务登录手机版
oppo云服务官网登录入口 oppo云服务登录手机版

oppo云服务https://cloud.oppo.com/可以在云端安全存储您的照片、视频、联系人、便签等重要数据。当您的手机数据意外丢失或者需要更换手机时,可以随时将这些存储在云端的数据快速恢复到手机中。

40

2026.01.26

抖币充值官方网站 抖币性价比充值链接地址
抖币充值官方网站 抖币性价比充值链接地址

网页端充值步骤:打开浏览器,输入https://www.douyin.com,登录账号;点击右上角头像,选择“钱包”;进入“充值中心”,操作和APP端一致。注意:切勿通过第三方链接、二维码充值,谨防受骗

7

2026.01.26

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Git 教程
Git 教程

共21课时 | 3万人学习

Git版本控制工具
Git版本控制工具

共8课时 | 1.5万人学习

Git中文开发手册
Git中文开发手册

共0课时 | 0人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号