remoteInference.ts
import { ICompletions, IParams, AIRequestType, RemoteBackendOPModel, JsonStreamParser } from "../../types/types";
import { GenerationParams, CompletionParams, InsertionParams } from "../../types/models";
import { buildSolgptPrompt } from "../../prompts/promptBuilder";
import EventEmitter from "events";
import { ChatHistory } from "../../prompts/chat";
import axios from 'axios';

const defaultErrorMessage = `Unable to get a response from AI server`

export class RemoteInferencer implements ICompletions {
  api_url: string
  completion_url: string
  max_history = 7
  model_op = RemoteBackendOPModel.CODELLAMA // default model operation; change this to LLAMA if necessary
  event: EventEmitter
  test_env = true
  test_url = "http://solcodertest.org"

  constructor(apiUrl?: string, completionUrl?: string) {
    this.api_url = apiUrl !== undefined ? apiUrl : this.test_env ? this.test_url : "https://solcoder.remixproject.org"
    this.completion_url = completionUrl !== undefined ? completionUrl : this.test_env ? this.test_url : "https://completion.remixproject.org"
    this.event = new EventEmitter()
  }
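
  /**
   * POST the payload to the selected endpoint and unwrap the JSON response.
   * Completion requests return the generated text directly; general requests
   * also record the exchange in ChatHistory.
   */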
  private async _makeRequest(payload, rType: AIRequestType) {
    this.event.emit("onInference")
    const requestURL = rType === AIRequestType.COMPLETION ? this.completion_url : this.api_url

    try {
      const options = { headers: { 'Content-Type': 'application/json' } }
      const result = await axios.post(requestURL, payload, options)

      switch (rType) {
      case AIRequestType.COMPLETION:
        // check the status code rather than statusText: HTTP/2 responses may omit the reason phrase
        if (result.status === 200) {
          return result.data.generatedText
        } else {
          return defaultErrorMessage
        }
      case AIRequestType.GENERAL:
        if (result.status === 200) {
          if (result.data?.error) return result.data?.error
          const resultText = result.data.generatedText
          ChatHistory.pushHistory(payload.prompt, resultText)
          return resultText
        } else {
          return defaultErrorMessage
        }
      }
    } catch (e) {
      ChatHistory.clearHistory()
      console.error('Error making request to Inference server:', e.message)
    } finally {
      this.event.emit("onInferenceDone")
    }
  }
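
  /**
   * Stream the response from the inference server. Emits 'onStreamResult'
   * for each partial chunk and resolves with the full generated text once
   * the server reports that generation is complete.
   */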
  private async _streamInferenceRequest(payload, rType: AIRequestType) {
    let resultText = ""
    try {
      this.event.emit('onInference')
      const requestURL = rType === AIRequestType.COMPLETION ? this.completion_url : this.api_url
      const response = await fetch(requestURL, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify(payload),
      });

      if (payload.return_stream_response) {
        return response
      }

      const reader = response.body?.getReader();
      if (!reader) return defaultErrorMessage

      const decoder = new TextDecoder();
      const parser = new JsonStreamParser();
      // eslint-disable-next-line no-constant-condition
      while (true) {
        const { done, value } = await reader.read();
        if (done) break;

        try {
          // decode once with stream:true so multi-byte characters split across chunks are handled correctly
          const chunk = parser.safeJsonParse<{ generatedText: string; isGenerating: boolean }>(decoder.decode(value, { stream: true }));
          for (const parsedData of chunk) {
            if (parsedData.isGenerating) {
              this.event.emit('onStreamResult', parsedData.generatedText);
              resultText = resultText + parsedData.generatedText
            } else {
              // stream generation is complete; return the accumulated text
              resultText = resultText + parsedData.generatedText
              ChatHistory.pushHistory(payload.prompt, resultText)
              return resultText
            }
          }
        } catch (error) {
          console.error('Error parsing JSON:', error);
          ChatHistory.clearHistory()
        }
      }
      return resultText
    } catch (error) {
      ChatHistory.clearHistory()
      console.error('Error making stream request to Inference server:', error.message);
    } finally {
      this.event.emit('onInferenceDone')
    }
  }
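
  /**
   * Public endpoints. Each method builds an endpoint-specific payload and
   * either streams the response (when options.stream_result is set) or
   * falls back to a plain request.
   */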
  async code_completion(prompt, promptAfter, ctxFiles, fileName, options: IParams = CompletionParams): Promise<any> {
    const payload = { prompt, 'context': promptAfter, "endpoint": "code_completion",
      'ctxFiles': ctxFiles, 'currentFileName': fileName, ...options }
    return this._makeRequest(payload, AIRequestType.COMPLETION)
  }

  async code_insertion(msg_pfx, msg_sfx, ctxFiles, fileName, options: IParams = InsertionParams): Promise<any> {
    const payload = { "endpoint": "code_insertion", msg_pfx, msg_sfx, 'ctxFiles': ctxFiles,
      'currentFileName': fileName, ...options, prompt: '' }
    return this._makeRequest(payload, AIRequestType.COMPLETION)
  }

  async code_generation(prompt, options: IParams = GenerationParams): Promise<any> {
    // generation reuses the code_completion endpoint on the completion server
    const payload = { prompt, "endpoint": "code_completion", ...options }
    if (options.stream_result) return this._streamInferenceRequest(payload, AIRequestType.COMPLETION)
    else return this._makeRequest(payload, AIRequestType.COMPLETION)
  }

  async solidity_answer(prompt, options: IParams = GenerationParams): Promise<any> {
    const main_prompt = buildSolgptPrompt(prompt, this.model_op)
    const payload = { 'prompt': main_prompt, "endpoint": "solidity_answer", ...options }
    if (options.stream_result) return this._streamInferenceRequest(payload, AIRequestType.GENERAL)
    else return this._makeRequest(payload, AIRequestType.GENERAL)
  }

  async code_explaining(prompt, context: string = "", options: IParams = GenerationParams): Promise<any> {
    const payload = { prompt, "endpoint": "code_explaining", context, ...options }
    if (options.stream_result) return this._streamInferenceRequest(payload, AIRequestType.GENERAL)
    else return this._makeRequest(payload, AIRequestType.GENERAL)
  }

  async error_explaining(prompt, options: IParams = GenerationParams): Promise<any> {
    const payload = { prompt, "endpoint": "error_explaining", ...options }
    if (options.stream_result) return this._streamInferenceRequest(payload, AIRequestType.GENERAL)
    else return this._makeRequest(payload, AIRequestType.GENERAL)
  }

  async vulnerability_check(prompt, options: IParams = GenerationParams): Promise<any> {
    const payload = { prompt, "endpoint": "vulnerability_check", ...options }
    if (options.stream_result) return this._streamInferenceRequest(payload, AIRequestType.GENERAL)
    else return this._makeRequest(payload, AIRequestType.GENERAL)
  }

  async generate(userPrompt, options: IParams = GenerationParams): Promise<any> {
    const payload = { prompt: userPrompt, "endpoint": "generate", ...options }
    if (options.stream_result) return this._streamInferenceRequest(payload, AIRequestType.GENERAL)
    else return this._makeRequest(payload, AIRequestType.GENERAL)
  }
}
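
/*
 * Usage sketch (illustrative only, not part of the module): a minimal caller
 * assuming the default endpoints defined above are reachable and that
 * GenerationParams exposes a stream_result flag, as the methods above expect.
 *
 *   const inferencer = new RemoteInferencer()
 *   // partial chunks arrive via the 'onStreamResult' event while streaming
 *   inferencer.event.on('onStreamResult', (chunk) => process.stdout.write(chunk))
 *   const answer = await inferencer.solidity_answer(
 *     'What does the fallback function do?',
 *     { ...GenerationParams, stream_result: true }
 *   )
 */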