Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/vs/platform/extensions/common/extensionsApiProposals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,9 @@ const _allApiProposals = {
tokenInformation: {
proposal: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.tokenInformation.d.ts',
},
toolProgress: {
proposal: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.toolProgress.d.ts',
},
treeViewActiveItem: {
proposal: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.treeViewActiveItem.d.ts',
},
Expand Down
20 changes: 14 additions & 6 deletions src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import { CancellationToken } from '../../../base/common/cancellation.js';
import { Disposable, DisposableMap } from '../../../base/common/lifecycle.js';
import { revive } from '../../../base/common/marshalling.js';
import { IProgress, IProgressStep } from '../../../platform/progress/common/progress.js';
import { CountTokensCallback, ILanguageModelToolsService, IToolData, IToolInvocation, IToolResult } from '../../contrib/chat/common/languageModelToolsService.js';
import { IExtHostContext, extHostNamedCustomer } from '../../services/extensions/common/extHostCustomers.js';
import { Dto } from '../../services/extensions/common/proxyIdentifier.js';
Expand All @@ -16,7 +17,10 @@ export class MainThreadLanguageModelTools extends Disposable implements MainThre

private readonly _proxy: ExtHostLanguageModelToolsShape;
private readonly _tools = this._register(new DisposableMap<string>());
private readonly _countTokenCallbacks = new Map</* call ID */string, CountTokensCallback>();
private readonly _runningToolCalls = new Map</* call ID */string, {
countTokens: CountTokensCallback;
progress: IProgress<IProgressStep>;
}>();

constructor(
extHostContext: IExtHostContext,
Expand Down Expand Up @@ -45,26 +49,30 @@ export class MainThreadLanguageModelTools extends Disposable implements MainThre
};
}

$acceptToolProgress(callId: string, progress: IProgressStep): void {
this._runningToolCalls.get(callId)?.progress.report(progress);
}

$countTokensForInvocation(callId: string, input: string, token: CancellationToken): Promise<number> {
const fn = this._countTokenCallbacks.get(callId);
const fn = this._runningToolCalls.get(callId);
if (!fn) {
throw new Error(`Tool invocation call ${callId} not found`);
}

return fn(input, token);
return fn.countTokens(input, token);
}

$registerTool(id: string): void {
const disposable = this._languageModelToolsService.registerToolImplementation(
id,
{
invoke: async (dto, countTokens, token) => {
invoke: async (dto, countTokens, progress, token) => {
try {
this._countTokenCallbacks.set(dto.callId, countTokens);
this._runningToolCalls.set(dto.callId, { countTokens, progress });
const resultDto = await this._proxy.$invokeTool(dto, token);
return revive(resultDto) as IToolResult;
} finally {
this._countTokenCallbacks.delete(dto.callId);
this._runningToolCalls.delete(dto.callId);
}
},
prepareToolInvocation: (parameters, token) => this._proxy.$prepareToolInvocation(id, parameters, token),
Expand Down
1 change: 1 addition & 0 deletions src/vs/workbench/api/common/extHost.protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1380,6 +1380,7 @@ export type IToolDataDto = Omit<IToolData, 'when'>;

export interface MainThreadLanguageModelToolsShape extends IDisposable {
$getTools(): Promise<Dto<IToolDataDto>[]>;
$acceptToolProgress(callId: string, progress: IProgressStep): void;
$invokeTool(dto: IToolInvocation, token?: CancellationToken): Promise<Dto<IToolResult>>;
$countTokensForInvocation(callId: string, input: string, token: CancellationToken): Promise<number>;
$registerTool(id: string): void;
Expand Down
15 changes: 14 additions & 1 deletion src/vs/workbench/api/common/extHostLanguageModelTools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape
throw new Error(`Unknown tool ${dto.toolId}`);
}

const options: vscode.LanguageModelToolInvocationOptions<Object> = {
const options: vscode.LanguageModelToolInvocation<Object> = {
input: dto.parameters,
toolInvocationToken: dto.context as vscode.ChatParticipantToolToken | undefined,
progress: undefined!,
};
if (isProposedApiEnabled(item.extension, 'chatParticipantPrivate')) {
options.chatRequestId = dto.chatRequestId;
Expand All @@ -138,6 +139,18 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape
};
}

if (isProposedApiEnabled(item.extension, 'toolProgress')) {
options.progress = {
report: value => {
this._proxy.$acceptToolProgress(dto.callId, {
message: value.message,
increment: value.increment,
total: 100,
});
}
};
}

const extensionResult = await raceCancellation(Promise.resolve(item.tool.invoke(options, token)), token);
if (!extensionResult) {
throw new CancellationError();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { Codicon } from '../../../../../base/common/codicons.js';
import { Emitter } from '../../../../../base/common/event.js';
import { IMarkdownString, MarkdownString } from '../../../../../base/common/htmlContent.js';
import { Disposable, DisposableStore, IDisposable, thenIfNotDisposed, toDisposable } from '../../../../../base/common/lifecycle.js';
import { autorunWithStore } from '../../../../../base/common/observable.js';
import { ThemeIcon } from '../../../../../base/common/themables.js';
import { URI } from '../../../../../base/common/uri.js';
import { generateUuid } from '../../../../../base/common/uuid.js';
Expand Down Expand Up @@ -454,27 +455,37 @@ class ChatToolInvocationSubPart extends Disposable {
}

private createProgressPart(): HTMLElement {
let content: IMarkdownString;
if (this.toolInvocation.isComplete && this.toolInvocation.isConfirmed !== false && this.toolInvocation.pastTenseMessage) {
content = typeof this.toolInvocation.pastTenseMessage === 'string' ?
new MarkdownString().appendText(this.toolInvocation.pastTenseMessage) :
this.toolInvocation.pastTenseMessage;
const part = this.renderProgressContent(this.toolInvocation.pastTenseMessage);
this._register(part);
return part.domNode;
} else {
content = typeof this.toolInvocation.invocationMessage === 'string' ?
new MarkdownString().appendText(this.toolInvocation.invocationMessage + '…') :
MarkdownString.lift(this.toolInvocation.invocationMessage).appendText('…');
const container = document.createElement('div');
const progressObservable = this.toolInvocation.kind === 'toolInvocation' ? this.toolInvocation.progress : undefined;
this._register(autorunWithStore((reader, store) => {
const progress = progressObservable?.read(reader);
const part = store.add(this.renderProgressContent(progress?.message || this.toolInvocation.invocationMessage));
dom.reset(container, part.domNode);
}));
return container;
}
}

private renderProgressContent(content: IMarkdownString | string) {
if (typeof content === 'string') {
content = new MarkdownString().appendText(content);
}

const progressMessage: IChatProgressMessage = {
kind: 'progressMessage',
content
};

const iconOverride = !this.toolInvocation.isConfirmed ?
Codicon.error :
this.toolInvocation.isComplete ?
Codicon.check : undefined;
const progressPart = this._register(this.instantiationService.createInstance(ChatProgressContentPart, progressMessage, this.renderer, this.context, undefined, true, iconOverride));
return progressPart.domNode;
return this.instantiationService.createInstance(ChatProgressContentPart, progressMessage, this.renderer, this.context, undefined, true, iconOverride);
}

private createTerminalMarkdownProgressPart(toolInvocation: IChatToolInvocation | IChatToolInvocationSerialized, terminalData: IChatTerminalToolInvocationData): HTMLElement {
Expand Down
4 changes: 2 additions & 2 deletions src/vs/workbench/contrib/chat/browser/chatSetup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import { ILogService } from '../../../../platform/log/common/log.js';
import { IOpenerService } from '../../../../platform/opener/common/opener.js';
import product from '../../../../platform/product/common/product.js';
import { IProductService } from '../../../../platform/product/common/productService.js';
import { IProgressService, ProgressLocation } from '../../../../platform/progress/common/progress.js';
import { IProgress, IProgressService, IProgressStep, ProgressLocation } from '../../../../platform/progress/common/progress.js';
import { IQuickInputService } from '../../../../platform/quickinput/common/quickInput.js';
import { Registry } from '../../../../platform/registry/common/platform.js';
import { ITelemetryService, TelemetryLevel } from '../../../../platform/telemetry/common/telemetry.js';
Expand Down Expand Up @@ -518,7 +518,7 @@ class SetupTool extends Disposable implements IToolImpl {
super();
}

invoke(invocation: IToolInvocation, countTokens: CountTokensCallback, token: CancellationToken): Promise<IToolResult> {
invoke(invocation: IToolInvocation, countTokens: CountTokensCallback, progress: IProgress<IProgressStep>, token: CancellationToken): Promise<IToolResult> {
const result: IToolResult = {
content: [
{
Expand Down
25 changes: 18 additions & 7 deletions src/vs/workbench/contrib/chat/browser/languageModelToolsService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { Emitter } from '../../../../base/common/event.js';
import { MarkdownString } from '../../../../base/common/htmlContent.js';
import { Iterable } from '../../../../base/common/iterator.js';
import { Lazy } from '../../../../base/common/lazy.js';
import { Disposable, DisposableStore, dispose, IDisposable, toDisposable } from '../../../../base/common/lifecycle.js';
import { Disposable, DisposableStore, IDisposable, toDisposable } from '../../../../base/common/lifecycle.js';
import { LRUCache } from '../../../../base/common/map.js';
import { localize } from '../../../../nls.js';
import { IAccessibilityService } from '../../../../platform/accessibility/common/accessibility.js';
Expand Down Expand Up @@ -40,6 +40,11 @@ interface IToolEntry {
impl?: IToolImpl;
}

interface ITrackedCall {
invocation?: ChatToolInvocation;
store: IDisposable;
}

export class LanguageModelToolsService extends Disposable implements ILanguageModelToolsService {
_serviceBrand: undefined;

Expand All @@ -53,7 +58,7 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo
private _toolContextKeys = new Set<string>();
private readonly _ctxToolsCount: IContextKey<number>;

private _callsByRequestId = new Map<string, IDisposable[]>();
private _callsByRequestId = new Map<string, ITrackedCall[]>();

private _workspaceToolConfirmStore: Lazy<ToolConfirmStore>;
private _profileToolConfirmStore: Lazy<ToolConfirmStore>;
Expand Down Expand Up @@ -235,7 +240,8 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo
if (!this._callsByRequestId.has(requestId)) {
this._callsByRequestId.set(requestId, []);
}
this._callsByRequestId.get(requestId)!.push(store);
const trackedCall: ITrackedCall = { store };
this._callsByRequestId.get(requestId)!.push(trackedCall);

const source = new CancellationTokenSource();
store.add(toDisposable(() => {
Expand All @@ -252,6 +258,7 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo

const prepared = await this.prepareToolInvocation(tool, dto, token);
toolInvocation = new ChatToolInvocation(prepared, tool.data, dto.callId);
trackedCall.invocation = toolInvocation;
if (this.shouldAutoConfirm(tool.data.id, tool.data.runsInWorkspace)) {
toolInvocation.confirmed.complete(true);
}
Expand Down Expand Up @@ -285,7 +292,11 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo
throw new CancellationError();
}

toolResult = await tool.impl.invoke(dto, countTokens, token);
toolResult = await tool.impl.invoke(dto, countTokens, {
report: step => {
toolInvocation?.acceptProgress(step);
}
}, token);
this.ensureToolDetails(dto, toolResult, tool.data);

this._telemetryService.publicLog2<LanguageModelToolInvokedEvent, LanguageModelToolInvokedClassification>(
Expand Down Expand Up @@ -403,7 +414,7 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo
private cleanupCallDisposables(requestId: string, store: DisposableStore): void {
const disposables = this._callsByRequestId.get(requestId);
if (disposables) {
const index = disposables.indexOf(store);
const index = disposables.findIndex(d => d.store === store);
if (index > -1) {
disposables.splice(index, 1);
}
Expand All @@ -417,15 +428,15 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo
cancelToolCallsForRequest(requestId: string): void {
const calls = this._callsByRequestId.get(requestId);
if (calls) {
calls.forEach(call => call.dispose());
calls.forEach(call => call.store.dispose());
this._callsByRequestId.delete(requestId);
}
}

public override dispose(): void {
super.dispose();

this._callsByRequestId.forEach(calls => dispose(calls));
this._callsByRequestId.forEach(calls => calls.forEach(call => call.store.dispose()));
this._ctxToolsCount.reset();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

import { DeferredPromise } from '../../../../../base/common/async.js';
import { IMarkdownString } from '../../../../../base/common/htmlContent.js';
import { observableValue } from '../../../../../base/common/observable.js';
import { localize } from '../../../../../nls.js';
import { IProgressStep } from '../../../../../platform/progress/common/progress.js';
import { IChatTerminalToolInvocationData, IChatToolInputInvocationData, IChatToolInvocation, IChatToolInvocationSerialized } from '../chatService.js';
import { IPreparedToolInvocation, IToolConfirmationMessages, IToolData, IToolResult } from '../languageModelToolsService.js';

Expand Down Expand Up @@ -45,6 +47,8 @@ export class ChatToolInvocation implements IChatToolInvocation {

public readonly toolSpecificData?: IChatTerminalToolInvocationData | IChatToolInputInvocationData;

public readonly progress = observableValue<{ message?: string; progress: number }>(this, { progress: 0 });

constructor(preparedInvocation: IPreparedToolInvocation | undefined, toolData: IToolData, public readonly toolCallId: string) {
const defaultMessage = localize('toolInvocationMessage', "Using {0}", `"${toolData.displayName}"`);
const invocationMessage = preparedInvocation?.invocationMessage ?? defaultMessage;
Expand Down Expand Up @@ -84,6 +88,14 @@ export class ChatToolInvocation implements IChatToolInvocation {
return this._confirmationMessages;
}

public acceptProgress(step: IProgressStep) {
const prev = this.progress.get();
this.progress.set({
progress: step.increment ? (prev.progress + step.increment) : prev.progress,
message: step.message,
}, undefined);
}

public toJSON(): IChatToolInvocationSerialized {
return {
kind: 'toolInvocationSerialized',
Expand Down
2 changes: 2 additions & 0 deletions src/vs/workbench/contrib/chat/common/chatService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { DeferredPromise } from '../../../../base/common/async.js';
import { CancellationToken } from '../../../../base/common/cancellation.js';
import { Event } from '../../../../base/common/event.js';
import { IMarkdownString } from '../../../../base/common/htmlContent.js';
import { IObservable } from '../../../../base/common/observable.js';
import { ThemeIcon } from '../../../../base/common/themables.js';
import { URI } from '../../../../base/common/uri.js';
import { IRange, Range } from '../../../../editor/common/core/range.js';
Expand Down Expand Up @@ -234,6 +235,7 @@ export interface IChatToolInvocation {
invocationMessage: string | IMarkdownString;
pastTenseMessage: string | IMarkdownString | undefined;
resultDetails: IToolResult['toolResultDetails'];
progress: IObservable<{ message?: string; progress: number }>;
readonly toolId: string;
readonly toolCallId: string;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import { Event } from '../../../../base/common/event.js';
import { IMarkdownString } from '../../../../base/common/htmlContent.js';
import { IJSONSchema } from '../../../../base/common/jsonSchema.js';
import { IDisposable } from '../../../../base/common/lifecycle.js';
import { Schemas } from '../../../../base/common/network.js';
import { ThemeIcon } from '../../../../base/common/themables.js';
import { URI } from '../../../../base/common/uri.js';
import { Location } from '../../../../editor/common/languages.js';
import { ContextKeyExpression } from '../../../../platform/contextkey/common/contextkey.js';
import { ExtensionIdentifier } from '../../../../platform/extensions/common/extensions.js';
import { createDecorator } from '../../../../platform/instantiation/common/instantiation.js';
import { Location } from '../../../../editor/common/languages.js';
import { IProgress, IProgressStep } from '../../../../platform/progress/common/progress.js';
import { IChatTerminalToolInvocationData, IChatToolInputInvocationData } from './chatService.js';
import { Schemas } from '../../../../base/common/network.js';
import { PromptElementJSON, stringifyPromptElementJSON } from './tools/promptTsxTypes.js';

export interface IToolData {
Expand Down Expand Up @@ -134,7 +135,7 @@ export interface IPreparedToolInvocation {
}

export interface IToolImpl {
invoke(invocation: IToolInvocation, countTokens: CountTokensCallback, token: CancellationToken): Promise<IToolResult>;
invoke(invocation: IToolInvocation, countTokens: CountTokensCallback, progress: IProgress<IProgressStep>, token: CancellationToken): Promise<IToolResult>;
prepareToolInvocation?(parameters: any, token: CancellationToken): Promise<IPreparedToolInvocation | undefined>;
}

Expand Down
3 changes: 2 additions & 1 deletion src/vs/workbench/contrib/chat/common/tools/editFileTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { IDisposable } from '../../../../../base/common/lifecycle.js';
import { autorun } from '../../../../../base/common/observable.js';
import { URI, UriComponents } from '../../../../../base/common/uri.js';
import { generateUuid } from '../../../../../base/common/uuid.js';
import { IProgress, IProgressStep } from '../../../../../platform/progress/common/progress.js';
import { SaveReason } from '../../../../common/editor.js';
import { ITextFileService } from '../../../../services/textfile/common/textfiles.js';
import { CellUri } from '../../../notebook/common/notebookCommon.js';
Expand Down Expand Up @@ -42,7 +43,7 @@ export class EditTool implements IToolImpl {
@INotebookService private readonly notebookService: INotebookService,
) { }

async invoke(invocation: IToolInvocation, countTokens: CountTokensCallback, token: CancellationToken): Promise<IToolResult> {
async invoke(invocation: IToolInvocation, countTokens: CountTokensCallback, _progress: IProgress<IProgressStep>, token: CancellationToken): Promise<IToolResult> {
if (!invocation.context) {
throw new Error('toolInvocationToken is required for this tool');
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { ITrustedDomainService } from '../../../url/browser/trustedDomainService
import { CountTokensCallback, IPreparedToolInvocation, IToolData, IToolImpl, IToolInvocation, IToolResult, IToolResultTextPart } from '../../common/languageModelToolsService.js';
import { MarkdownString } from '../../../../../base/common/htmlContent.js';
import { InternalFetchWebPageToolId } from '../../common/tools/tools.js';
import { IProgress, IProgressStep } from '../../../../../platform/progress/common/progress.js';

export const FetchWebPageToolData: IToolData = {
id: InternalFetchWebPageToolId,
Expand Down Expand Up @@ -41,7 +42,7 @@ export class FetchWebPageTool implements IToolImpl {
@ITrustedDomainService private readonly _trustedDomainService: ITrustedDomainService,
) { }

async invoke(invocation: IToolInvocation, _countTokens: CountTokensCallback, _token: CancellationToken): Promise<IToolResult> {
async invoke(invocation: IToolInvocation, _countTokens: CountTokensCallback, _progress: IProgress<IProgressStep>, _token: CancellationToken): Promise<IToolResult> {
const parsedUriResults = this._parseUris((invocation.parameters as { urls?: string[] }).urls);
const validUris = Array.from(parsedUriResults.values()).filter((uri): uri is URI => !!uri);
if (!validUris.length) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ suite('LanguageModelToolsService', () => {

const toolBarrier = new Barrier();
const toolImpl: IToolImpl = {
invoke: async (invocation, countTokens, cancelToken) => {
invoke: async (invocation, countTokens, progress, cancelToken) => {
assert.strictEqual(invocation.callId, '1');
assert.strictEqual(invocation.toolId, 'testTool');
assert.deepStrictEqual(invocation.parameters, { a: 1 });
Expand Down
Loading
Loading