Java中用DJL加载和运行预训练模型只需三步:添加依赖(如djl-api、pytorch-engine等)、选择模型(URL/本地路径/模型ID)、构建Predictor执行推理;DJL自动适配PyTorch等引擎,无需编写底层计算逻辑。

Java中用DJL加载和运行预训练模型,核心是三步:添加依赖、选择模型(本地或远程)、构建Predictor执行推理。不需要写底层计算逻辑,DJL自动处理引擎适配(如PyTorch、TensorFlow、ONNX Runtime)。
1. 添加DJL依赖(Maven)
DJL支持多引擎,推荐从PyTorch开始(生态成熟、模型丰富)。在pom.xml中引入:
-
核心API:
djl-api -
PyTorch引擎:
model-zoo+pytorch-engine -
预编译本地库(免编译):
pytorch-native-auto(自动匹配系统架构)
示例依赖片段:
ai.djl api 0.27.0 ai.djl.pytorch pytorch-engine 0.27.0 ai.djl.pytorch pytorch-native-auto 2.1.2
2. 加载预训练模型(支持URL/本地路径/模型ID)
DJL内置ModelZoo,可直接用HuggingFace ID或DJL Model Zoo地址加载。例如加载bert-base-uncased文本分类模型:
立即学习“Java免费学习笔记(深入)”;
- 用
Criteria声明输入输出类型、模型来源、设备(CPU/GPU) - 调用
ModelLoader.loadModel()获得Model实例 - 注意:首次加载会自动下载模型权重到本地缓存(
~/.djl.ai/cache)
代码示例:
Criteriacriteria = Criteria.builder() .setTypes(String.class, Classifications.class) .optModelUrls("https://resources.djl.ai/test-models/pytorch/transformers/bert-base-uncased.zip") .optEngine("PyTorch") .optTranslator(new BertTranslator()) .build(); Model model = Model.newInstance("bert"); model = ModelLoader.loadModel(criteria);
3. 构建Predictor并运行推理
Predictor是执行推理的入口,封装了预处理、前向传播、后处理。创建后调用predict()即可:
-
Translator负责输入转NDArray、输出转业务对象(如Classifications) - 支持批量输入(List),也支持单条字符串
- 用完记得
close()释放资源(推荐try-with-resources)
完整推理示例:
try (Predictorpredictor = model.newPredictor(new BertTranslator())) { Classifications result = predictor.predict("I love DJL!"); System.out.println(result); // 输出类似:positive: 0.982, negative: 0.018 }
4. 常见问题与建议
实际使用时容易卡在环境或格式上,注意以下几点:
- GPU支持需安装CUDA驱动+cuDNN,并用
pytorch-native-cu118等对应版本依赖 - 模型输入必须匹配
Translator定义(如Bert要tokenize,CNN图像要resize+normalize) - 自定义模型:把
model.pt和synset.txt等放在同目录,用optModelPath(Paths.get("models/my-model")) - 性能优化:启用
setLimit(1)限制线程数,或用Model.setBlock()手动指定计算图










