-
Notifications
You must be signed in to change notification settings - Fork 4
Open
Description
我最近也在写类似这个的插件,自己训练了识别角色的模型让机器人检测是否有两个角色同框的图片,被tensorflowjs的依赖阻拦了很久,在网上搜索了相关问题发现了贵项目,看到需要手动补充文件并且在windows上使用很费劲。
我发现tensorflow官方有提供wasm作为后端,可以用他来代替需繁琐补充和编译的libtensorflow.so.2(tensorflow.dll)以及tfjs-building.node,虽然性能可能略差于原生C语言库,但是兼容性会更好,下方是我写的一个demo,目前在linux和windows上可以直接运行,无需手动补充任何依赖
import * as tf from '@tensorflow/tfjs-node';
import { setWasmPaths } from '@tensorflow/tfjs-backend-wasm'; // 导入WASM 后端
import sharp from 'sharp';
import fs from 'fs';
import path from 'path';
import { fileURLToPath } from 'url';
// --- 0. 设置路径 ---
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
// 指定WASM文件路径
const wasmPath = path.join(__dirname, 'node_modules', '@tensorflow', 'tfjs-backend-wasm', 'dist', path.sep);
// --- 1. 定义文件路径 ---
const MODEL_PATH = path.join(__dirname, 'my_tfjs_model', 'model.json');
const IMAGE_PATH = path.join(__dirname, 'test.jpg');
const LABELS_PATH = path.join(__dirname, 'labels.txt');
const IMAGE_SIZE = 224; // 必须和训练时设置的 IMAGE_SIZE 一致
// --- 2. 加载类别标签 ---
let labels = [];
try {
const labelsContent = fs.readFileSync(LABELS_PATH, 'utf8');
labels = labelsContent.split('\n')
.filter(Boolean)
.map(line => line.split(':')[1].replace(/\r/g, ''));
if (labels.length === 0) {
throw new Error('labels.txt 为空或格式不正确。');
}
console.log('✅ 标签加载成功:', labels);
} catch (error) {
console.error(`❌ 无法加载 labels.txt: ${error.message}`);
process.exit(1);
}
// --- 3. 图像预处理函数 (【使用 Sharp 修复】) ---
async function preprocessImage(imagePath) {
// 使用 sharp 从路径读取、移除 alpha 通道、并获取原始像素数据
const { data, info } = await sharp(imagePath)
.removeAlpha() // 确保只有 3 个通道 (RGB)
.raw() // 获取原始像素数据
.toBuffer({ resolveWithObject: true }); // 获取 Buffer 和 info
// 2. 从 info 中获取高度和宽度
const { height, width } = info;
// 3. 将原始像素数据 (Buffer) 创建为 3D Tensor ([height, width, 3])
const imageTensor = tf.tensor3d(data, [height, width, 3], 'int32');
// 4. 缩放图片到模型输入的尺寸 [224, 224]
// (tf.image.resizeBilinear 将在 WASM 后端上运行)
const resizedTensor = tf.image.resizeBilinear(imageTensor, [IMAGE_SIZE, IMAGE_SIZE]);
// 5. 归一化
const normalizedTensor = resizedTensor.asType('float32').div(255.0);
// 6. 增加一个维度 (batch 维度),[1, 224, 224, 3]
const batchTensor = normalizedTensor.expandDims(0);
// 7. 内存回收 (清理中间过程的 tensor)
imageTensor.dispose();
resizedTensor.dispose();
normalizedTensor.dispose();
return batchTensor;
}
// --- 4. 主函数:加载并运行模型 ---
async function run() {
// 初始化 WASM 后端
console.log(`... 正在设置 WASM 后端路径: ${wasmPath} ...`);
setWasmPaths(wasmPath);
console.log('... 正在初始化 WASM 后端 ...');
// 确保覆盖 tfjs-node 的 C++ 后端,明确使用 wasm
await tf.setBackend('wasm');
await tf.ready(); // 等待后端完全准备好
console.log('✅ TFJS WASM 后端已激活。');
console.log('... 正在加载模型 ...');
// tf.io.fileSystem 现在在 tf 对象上可用
const handler = tf.io.fileSystem(MODEL_PATH);
// 1. 加载模型 (传入 IOHandler 对象)
const model = await tf.loadGraphModel(handler);
console.log('✅ 模型加载成功。');
// 2. 预处理图片
console.log(`... 正在处理图片: ${IMAGE_PATH} ...`);
const imageTensor = await preprocessImage(IMAGE_PATH);
// 3. 执行预测 (GraphModel 使用 .execute())
console.log('... 正在执行预测 ...');
const predictions = model.execute(imageTensor);
// 4. 获取预测结果 (概率数组)
const probabilities = await predictions.data();
// 5. 显示结果
console.log('--- 识别结果 ---');
let results = [];
probabilities.forEach((probability, index) => {
results.push({
label: labels[index] || `未知类别 ${index}`,
probability: probability
});
});
// 按概率从高到低排序
results.sort((a, b) => b.probability - a.probability);
// 打印格式化的结果
results.forEach(res => {
const percentage = (res.probability * 100).toFixed(2);
console.log(`[${percentage}%] - ${res.label}`);
});
// 6. 彻底清理内存
imageTensor.dispose();
predictions.dispose();
model.dispose();
}
// --- 5. 启动脚本 ---
run().catch(error => {
console.error('❌ 脚本运行出错:', error);
});
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels