Skip to content

Commit 8ae51dc

Browse files
committed
#27によって発生していた画面リロード時に認証情報が表示されない問題、LLMの一覧取得ができない問題など種々の課題を修正
1 parent 3ce0b2a commit 8ae51dc

File tree

6 files changed

+36
-15
lines changed

6 files changed

+36
-15
lines changed

src/main/api/bedrock/services/imageService.ts

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { BedrockRuntimeClient, InvokeModelCommand } from '@aws-sdk/client-bedrock-runtime'
2-
import type { AWSCredentials, ServiceContext } from '../types'
2+
import type { ServiceContext } from '../types'
33
import type {
44
AspectRatio,
55
GenerateImageRequest,
@@ -111,14 +111,20 @@ export class ImageService {
111111
]
112112

113113
constructor(private context: ServiceContext) {
114-
const awsCredentials: AWSCredentials = this.context.store.get('aws')
115-
const { credentials, region } = awsCredentials
114+
const { accessKeyId, secretAccessKey, sessionToken, region } = this.context.store.get('aws')
116115

117-
if (!credentials || !credentials.accessKeyId || !credentials.secretAccessKey || !region) {
116+
if (!accessKeyId || !secretAccessKey || !region) {
118117
console.warn('AWS credentials not configured')
119118
}
120119

121-
this.runtimeClient = new BedrockRuntimeClient(awsCredentials)
120+
this.runtimeClient = new BedrockRuntimeClient({
121+
credentials: {
122+
accessKeyId,
123+
secretAccessKey,
124+
sessionToken
125+
},
126+
region
127+
})
122128
}
123129

124130
private getModelType(modelId: ImageGeneratorModel): ModelType {

src/main/api/bedrock/services/modelService.ts

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { getDefaultPromptRouter, getModelsForRegion } from '../models'
22
import { getAccountId } from '../utils/awsUtils'
3-
import type { AWSCredentials, ServiceContext } from '../types'
3+
import type { ServiceContext } from '../types'
44
import { BedrockSupportRegion } from '../../../../types/llm'
55

66
export class ModelService {
@@ -10,14 +10,14 @@ export class ModelService {
1010
constructor(private context: ServiceContext) {}
1111

1212
async listModels() {
13-
const awsCredentials: AWSCredentials = this.context.store.get('aws')
14-
const { credentials, region } = awsCredentials
15-
if (!credentials || !credentials.accessKeyId || !credentials.secretAccessKey || !region) {
13+
const { accessKeyId, secretAccessKey, sessionToken, region } = this.context.store.get('aws')
14+
15+
if (!accessKeyId || !secretAccessKey || !region) {
1616
console.warn('AWS credentials not configured')
1717
return []
1818
}
1919

20-
const cacheKey = `${region}-${credentials.accessKeyId}`
20+
const cacheKey = `${region}-${accessKeyId}`
2121
const cachedData = this.modelCache[cacheKey]
2222

2323
if (
@@ -31,7 +31,12 @@ export class ModelService {
3131
try {
3232
const models = getModelsForRegion(region as BedrockSupportRegion)
3333

34-
const accountId = await getAccountId(awsCredentials)
34+
const accountId = await getAccountId({
35+
accessKeyId,
36+
secretAccessKey,
37+
sessionToken,
38+
region
39+
})
3540
const promptRouterModels = accountId ? getDefaultPromptRouter(accountId, region) : []
3641
const result = [...models, ...promptRouterModels]
3742
this.modelCache[cacheKey] = [...result, { _timestamp: Date.now() } as any]

src/main/api/bedrock/types.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import { Message } from '@aws-sdk/client-bedrock-runtime'
2-
import type { AwsCredentialIdentity } from '@smithy/types'
32

43
export type CallConverseAPIProps = {
54
modelId: string
@@ -9,7 +8,9 @@ export type CallConverseAPIProps = {
98
}
109

1110
export type AWSCredentials = {
12-
credentials: AwsCredentialIdentity
11+
accessKeyId: string
12+
secretAccessKey: string
13+
sessionToken?: string
1314
region: string
1415
}
1516

src/main/api/bedrock/utils/awsUtils.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,14 @@ export function getAlternateRegionOnThrottling(
5252

5353
export async function getAccountId(awsCredentials: AWSCredentials) {
5454
try {
55-
const sts = new STSClient(awsCredentials)
55+
const sts = new STSClient({
56+
credentials: {
57+
accessKeyId: awsCredentials.accessKeyId,
58+
secretAccessKey: awsCredentials.secretAccessKey,
59+
sessionToken: awsCredentials?.sessionToken
60+
},
61+
region: awsCredentials.region
62+
})
5663
const command = new GetCallerIdentityCommand({})
5764
const res = await sts.send(command)
5865
return res.Account

src/preload/store.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type StoreScheme = {
3737
region: string
3838
accessKeyId: string
3939
secretAccessKey: string
40+
sessionToken?: string
4041
}
4142
customAgents: CustomAgent[]
4243
selectedAgentId: string

src/renderer/src/contexts/SettingsContext.tsx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ export const SettingsProvider: React.FC<{ children: React.ReactNode }> = ({ chil
294294
setStateAwsAccessKeyId(awsConfig.accessKeyId || '')
295295
setStateAwsSecretAccessKey(awsConfig.secretAccessKey || '')
296296
setStateAwsSessionToken(awsConfig.sessionToken || '')
297+
console.log({ awsConfig })
297298
}
298299

299300
// Load Custom Agents
@@ -516,7 +517,7 @@ export const SettingsProvider: React.FC<{ children: React.ReactNode }> = ({ chil
516517

517518
const saveAwsConfig = (credentials: AwsCredentialIdentity, region: string) => {
518519
window.store.set('aws', {
519-
accessKey: credentials.accessKeyId,
520+
accessKeyId: credentials.accessKeyId,
520521
secretAccessKey: credentials.secretAccessKey,
521522
sessionToken: credentials.sessionToken,
522523
region

0 commit comments

Comments
 (0)