Skip to content

Commit 9d18174

Browse files
committed
dev: gemini function calling
1 parent d8c49d8 commit 9d18174

File tree

3 files changed

+82
-2
lines changed

3 files changed

+82
-2
lines changed

src/backend/src/modules/puterai/GeminiService.js

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ const { GoogleGenerativeAI } = require('@google/generative-ai');
33
const GeminiSquareHole = require("./lib/GeminiSquareHole");
44
const { TypedValue } = require("../../services/drivers/meta/Runtime");
55
const putility = require("@heyputer/putility");
6+
const FunctionCalling = require("./lib/FunctionCalling");
67

78
class GeminiService extends BaseService {
89
async _init () {
@@ -31,17 +32,22 @@ class GeminiService extends BaseService {
3132
},
3233

3334
async complete ({ messages, stream, model, tools }) {
35+
tools = FunctionCalling.make_gemini_tools(tools);
36+
3437
const genAI = new GoogleGenerativeAI(this.config.apiKey);
3538
const genModel = genAI.getGenerativeModel({
3639
model: model ?? 'gemini-2.0-flash',
40+
tools,
3741
});
3842

3943
messages = await GeminiSquareHole.process_input_messages(messages);
4044

4145
// History is separate, so the last message gets special treatment.
4246
const last_message = messages.pop();
4347
const last_message_parts = last_message.parts.map(
44-
part => typeof part === 'string' ? part : part.text
48+
part => typeof part === 'string' ? part :
49+
typeof part.text === 'string' ? part.text :
50+
part
4551
);
4652

4753
const chat = genModel.startChat({

src/backend/src/modules/puterai/lib/FunctionCalling.js

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,16 @@ module.exports = class FunctionCalling {
9090
};
9191
});
9292
}
93+
94+
static make_gemini_tools (tools) {
95+
return [
96+
{
97+
function_declarations: tools.map(t => {
98+
const tool = t.function;
99+
delete tool.parameters.additionalProperties;
100+
return tool;
101+
})
102+
}
103+
];
104+
}
93105
}

src/backend/src/modules/puterai/lib/GeminiSquareHole.js

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
* but Google's AI API defies all the established conventions
44
* so it made sense to defy them here as well.
55
*/
6+
7+
const crypto = require('crypto');
68
module.exports = class GeminiSquareHole {
79
static process_input_messages = async (messages) => {
810
messages = messages.slice();
@@ -14,6 +16,35 @@ module.exports = class GeminiSquareHole {
1416
if ( msg.role === 'assistant' ) {
1517
msg.role = 'model';
1618
}
19+
20+
for ( let i=0 ; i < msg.parts.length ; i++ ) {
21+
const part = msg.parts[i];
22+
console.log('what the part is', part);
23+
if ( part.type === 'tool_use' ) {
24+
msg.parts[i] = {
25+
functionCall: {
26+
name: part.id,
27+
args: part.input,
28+
},
29+
};
30+
}
31+
if ( part.type === 'tool_result' ) {
32+
msg.parts[i] = {
33+
functionResponse: {
34+
name: part.tool_use_id,
35+
response: {
36+
name: part.tool_use_id,
37+
content: part.content,
38+
},
39+
},
40+
};
41+
}
42+
if ( part.type === 'text' ) {
43+
msg.parts[i] = {
44+
text: part.text,
45+
};
46+
}
47+
}
1748
}
1849

1950
return messages;
@@ -46,7 +77,12 @@ module.exports = class GeminiSquareHole {
4677
usage_promise,
4778
}) => async ({ chatStream }) => {
4879
const message = chatStream.message();
80+
4981
let textblock = message.contentBlock({ type: 'text' });
82+
let toolblock = null;
83+
let mode = 'text';
84+
85+
5086
let last_usage = null;
5187
for await ( const chunk of stream ) {
5288
// This is spread across several lines so that the stack trace
@@ -56,6 +92,31 @@ module.exports = class GeminiSquareHole {
5692
const content = candidate.content;
5793
const parts = content.parts;
5894
for ( const part of parts ) {
95+
if ( part.functionCall ) {
96+
if ( mode === 'text' ) {
97+
mode = 'tool';
98+
textblock.end();
99+
}
100+
101+
toolblock = message.contentBlock({
102+
type: 'tool_use',
103+
id: part.functionCall.name,
104+
name: part.functionCall.name,
105+
});
106+
toolblock.addPartialJSON(JSON.stringify(
107+
part.functionCall.args,
108+
));
109+
110+
continue;
111+
}
112+
113+
if ( mode === 'tool' ) {
114+
mode = 'text';
115+
toolblock.end();
116+
textblock = message.contentBlock({ type: 'text' });
117+
}
118+
119+
// assume text as default
59120
const text = part.text;
60121
textblock.addText(text);
61122
}
@@ -65,7 +126,8 @@ module.exports = class GeminiSquareHole {
65126

66127
usage_promise.resolve(last_usage);
67128

68-
textblock.end();
129+
if ( mode === 'text' ) textblock.end();
130+
if ( mode === 'tool' ) toolblock.end();
69131
message.end();
70132
chatStream.end();
71133
}

0 commit comments

Comments
 (0)