-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathhandle_node_configure.js
More file actions
72 lines (65 loc) · 2.62 KB
/
handle_node_configure.js
File metadata and controls
72 lines (65 loc) · 2.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import { app } from '../../scripts/app.js'
import { applyBadgeToNode } from './handle_load_nodes.js'
import { hasModelInput } from './model_price.js'
// 链式callback实现,保留原始callback并添加新callback
function chainCallback(originalCallback, newCallback) {
return async function (...args) {
// 先执行原始callback
if (originalCallback) {
const result = originalCallback.apply(this, args)
// 如果原始callback返回Promise,等待它完成
if (result && typeof result.then === 'function') {
await result
}
}
// 然后执行新callback
await newCallback.apply(this, args)
}
}
function isImagesInputSlot(slot) {
return slot?.name === 'images' || slot?.localized_name === 'images'
}
app.registerExtension({
name: 'bizyair.handle.node.configure',
nodeCreated(node, app) {
// 做忽略的widget 这些widget不做获取价格的操作
const ignoreWidgets = ['prompt', 'negative_prompt', 'inputcount']
// 在这里可以拿到变化之后的值,并且也可以拿到node,这时候给node切换badge即可
if (node && node.widgets && Array.isArray(node.widgets)) {
if (!hasModelInput(node)) {
return
}
if (!node._bizyairPriceConnectionHooked) {
const originalOnConnectionsChange = node.onConnectionsChange
node.onConnectionsChange = chainCallback(
originalOnConnectionsChange,
async function (type, slotIndex, isConnected, linkInfo, ioSlot) {
const inputSlot = this.inputs?.[slotIndex]
if (!isImagesInputSlot(inputSlot)) {
return
}
// 只处理 images 输入槽位本身的连线变化,避免误判 output 变化。
if (ioSlot && ioSlot !== inputSlot) {
return
}
await applyBadgeToNode(node, true)
}
)
node._bizyairPriceConnectionHooked = true
}
// 不仅仅是切换model才会修改模型定价,比如切换输入参数也会修改模型定价
node.widgets.forEach(widget => {
// 对于prompt这种输入频繁的widget 不做获取价格操作
if (ignoreWidgets.includes(widget.name) || widget._bizyairManualPriceRefresh === true) {
return
}
// 保存原始callback,使用链式callback保留原始功能
const originalCallback = widget.callback
widget.callback = chainCallback(originalCallback, async function () {
// 用户手动修改widget时,强制刷新badge(不使用缓存的模型类型)
await applyBadgeToNode(node, true)
})
})
}
}
})