Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions packages/core/src/code_assist/oauth-credential-storage.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,39 @@ describe('OAuthCredentialStorage', () => {
);
});

it('should merge existing refresh token when new payload lacks one', async () => {
const oldCredentials: OAuthCredentials = {
serverName: 'main-account',
token: {
accessToken: 'old-access-token',
refreshToken: 'persistent-refresh-token',
tokenType: 'Bearer',
expiresAt: Date.now() + 3600000,
scope: 'email',
},
updatedAt: Date.now(),
};
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
oldCredentials,
);

const newTokens: Credentials = {
access_token: 'new-access-token',
expiry_date: Date.now() + 3600000,
};

await OAuthCredentialStorage.saveCredentials(newTokens);

expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
expect.objectContaining({
token: expect.objectContaining({
accessToken: 'new-access-token',
refreshToken: 'persistent-refresh-token', // correctly merged
}),
}),
);
});

it('should throw an error if access_token is missing', async () => {
const invalidCredentials: Credentials = {
...mockCredentials,
Expand Down
6 changes: 5 additions & 1 deletion packages/core/src/code_assist/oauth-credential-storage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,16 @@ export class OAuthCredentialStorage {
throw new Error('Attempted to save credentials without an access token.');
}

const existing = await this.storage.getCredentials(MAIN_ACCOUNT_KEY);
const mergedRefreshToken =
credentials.refresh_token || existing?.token.refreshToken;

// Convert Google Credentials to OAuthCredentials format
const mcpCredentials: OAuthCredentials = {
serverName: MAIN_ACCOUNT_KEY,
token: {
accessToken: credentials.access_token,
refreshToken: credentials.refresh_token || undefined,
refreshToken: mergedRefreshToken || undefined,
tokenType: credentials.token_type || 'Bearer',
scope: credentials.scope || undefined,
expiresAt: credentials.expiry_date || undefined,
Expand Down
81 changes: 81 additions & 0 deletions packages/core/src/mcp/oauth-token-storage.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,38 @@ describe('MCPOAuthTokenStorage', () => {
expect(savedData[0].serverName).toBe('existing-server');
});

it('should merge existing refresh token when new payload lacks one', async () => {
const existingCredentials: OAuthCredentials = {
...mockCredentials,
serverName: 'existing-server',
token: {
...mockToken,
refreshToken: 'old-refresh-token',
},
};
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([existingCredentials]),
);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);

const newToken: OAuthToken = {
accessToken: 'new_access_token',
expiresAt: Date.now() + ONE_HR_MS,
tokenType: 'Bearer',
}; // missing refreshToken

await tokenStorage.saveToken('existing-server', newToken);

const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(
writeCall[1] as string,
) as OAuthCredentials[];

expect(savedData).toHaveLength(1);
expect(savedData[0].token.accessToken).toBe('new_access_token');
expect(savedData[0].token.refreshToken).toBe('old-refresh-token'); // successfully merged
});

it('should handle write errors gracefully', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
Expand Down Expand Up @@ -447,6 +479,55 @@ describe('MCPOAuthTokenStorage', () => {
expect(fs.mkdir).toHaveBeenCalled();
});

it('should merge existing refresh token when new payload lacks one in encrypted storage', async () => {
const serverName = 'server1';
const now = Date.now();
vi.spyOn(Date, 'now').mockReturnValue(now);

const existingCredentials: OAuthCredentials = {
serverName,
token: {
...mockToken,
refreshToken: 'old-refresh-token',
},
updatedAt: now,
};

mockHybridTokenStorage.getCredentials.mockResolvedValue(
existingCredentials,
);

const newToken: OAuthToken = {
accessToken: 'new_access_token',
expiresAt: Date.now() + ONE_HR_MS,
tokenType: 'Bearer',
};

await tokenStorage.saveToken(
serverName,
newToken,
'clientId',
'tokenUrl',
'mcpUrl',
);

const expectedCredential: OAuthCredentials = {
serverName,
token: {
...newToken,
refreshToken: 'old-refresh-token',
},
clientId: 'clientId',
tokenUrl: 'tokenUrl',
mcpServerUrl: 'mcpUrl',
updatedAt: now,
};

expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
expectedCredential,
);
});

it('should use HybridTokenStorage to get credentials', async () => {
mockHybridTokenStorage.getCredentials.mockResolvedValue(mockCredentials);
const result = await tokenStorage.getCredentials('server1');
Expand Down
11 changes: 10 additions & 1 deletion packages/core/src/mcp/oauth-token-storage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,18 @@ export class MCPOAuthTokenStorage implements TokenStorage {
): Promise<void> {
await this.ensureConfigDir();

const existing = await this.getCredentials(serverName);
const mergedRefreshToken =
token.refreshToken || existing?.token.refreshToken;

const mergedToken = {
...token,
refreshToken: mergedRefreshToken,
};

const credential: OAuthCredentials = {
serverName,
token,
token: mergedToken,
clientId,
tokenUrl,
mcpServerUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ describe('KeychainTokenStorage', () => {
expect(retrieved?.serverName).toBe('test-server');
});

it('should return null if no credentials are found or they are expired', async () => {
it('should return null if no credentials are found or they are expired and unrefreshable', async () => {
expect(await storage.getCredentials('missing')).toBeNull();

const expiredCreds = {
Expand All @@ -81,6 +81,20 @@ describe('KeychainTokenStorage', () => {
};
await storage.setCredentials(expiredCreds);
expect(await storage.getCredentials('test-server')).toBeNull();

// Ensure that if it has a refresh token, it is NOT returned as null
const expiredWithRefresh = {
...validCredentials,
token: {
...validCredentials.token,
expiresAt: Date.now() - 1000,
refreshToken: 'some-refresh-token',
},
};
await storage.setCredentials(expiredWithRefresh);
const retrieved = await storage.getCredentials('test-server');
expect(retrieved).not.toBeNull();
expect(retrieved?.token.refreshToken).toBe('some-refresh-token');
});

it('should throw if stored data is corrupted JSON', async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export class KeychainTokenStorage
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const credentials = JSON.parse(data) as OAuthCredentials;

if (this.isTokenExpired(credentials)) {
if (this.isTokenExpired(credentials) && !credentials.token.refreshToken) {
return null;
}

Expand Down Expand Up @@ -104,7 +104,7 @@ export class KeychainTokenStorage
try {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const data = JSON.parse(cred.password) as OAuthCredentials;
if (!this.isTokenExpired(data)) {
if (!this.isTokenExpired(data) || data.token.refreshToken) {
result.set(cred.account, data);
}
} catch (error) {
Expand Down
Loading