Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(amazonq): use Dispatchers.IO for @workspace requests #5374

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type" : "bugfix",
"description" : "Amazon Q: Attempt to reduce thread pool contention locking IDE caused by `@workspace` making a large number of requests"
}
6 changes: 2 additions & 4 deletions detekt-rules/detekt.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ coroutines:
active: true
GlobalCoroutineUsage:
active: true
RedundantSuspendModifier:
active: true
SleepInsteadOfDelay:
active: true
InjectDispatcher:
active: false
SuspendFunWithFlowReturnType:
active: true

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.core.coroutines.IO
import software.aws.toolkits.jetbrains.utils.pluginAwareExecuteOnPooledThread
import java.util.concurrent.TimeoutException

Expand All @@ -28,7 +30,7 @@ class ProjectContextController(private val project: Project, private val cs: Cor
// TODO: Ideally we should inject dependencies via constructor for easier testing, refer to how [TelemetryService] inject publisher and batcher
private val encoderServer: EncoderServer = EncoderServer(project)
private val projectContextProvider: ProjectContextProvider = ProjectContextProvider(project, encoderServer, cs)
val initJob: Job = cs.launch {
val initJob: Job = cs.launch(IO) {
encoderServer.downloadArtifactsAndStartServer()
}

Expand All @@ -51,16 +53,16 @@ class ProjectContextController(private val project: Project, private val cs: Cor

fun getProjectContextIndexComplete() = projectContextProvider.isIndexComplete.get()

suspend fun queryChat(prompt: String, timeout: Long?): List<RelevantDocument> {
suspend fun queryChat(prompt: String, timeout: Long?): List<RelevantDocument> = withContext(IO) {
try {
return projectContextProvider.query(prompt, timeout)
projectContextProvider.query(prompt, timeout)
} catch (e: Exception) {
logger.warn { "error while querying for project context $e.message" }
return emptyList()
emptyList()
}
}

suspend fun queryInline(query: String, filePath: String): List<InlineBm25Chunk> =
suspend fun queryInline(query: String, filePath: String): List<InlineBm25Chunk> = withContext(IO) {
try {
projectContextProvider.queryInline(query, filePath, InlineContextTarget.CODEMAP)
} catch (e: Exception) {
Expand All @@ -71,6 +73,7 @@ class ProjectContextController(private val project: Project, private val cs: Cor
logger.warn { logStr }
emptyList()
}
}

@RequiresBackgroundThread
fun updateIndex(filePaths: List<String>, mode: IndexUpdateMode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@ import com.intellij.openapi.vfs.VfsUtilCore
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.openapi.vfs.VirtualFileVisitor
import com.intellij.openapi.vfs.isFile
import com.intellij.util.concurrency.annotations.RequiresBackgroundThread
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.async
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.info
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.core.coroutines.IO
import software.aws.toolkits.jetbrains.services.amazonq.CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT
import software.aws.toolkits.jetbrains.services.amazonq.FeatureDevSessionContext
import software.aws.toolkits.jetbrains.services.amazonq.SUPPLEMENTAL_CONTEXT_TIMEOUT
Expand All @@ -33,7 +36,7 @@ import software.aws.toolkits.jetbrains.settings.CodeWhispererSettings
import software.aws.toolkits.telemetry.AmazonqTelemetry
import java.io.OutputStreamWriter
import java.net.HttpURLConnection
import java.net.URL
import java.net.URI
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import kotlin.time.Duration.Companion.minutes
Expand All @@ -44,7 +47,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
private val mapper = jacksonObjectMapper()

init {
cs.launch {
cs.launch(IO) {
if (ApplicationManager.getApplication().isUnitTestMode) {
return@launch
}
Expand Down Expand Up @@ -127,13 +130,13 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
}
}

private fun initEncryption(): Boolean {
private suspend fun initEncryption(): Boolean {
val request = encoderServer.getEncryptionRequest()
val response = sendMsgToLsp(LspMessage.Initialize, request)
return response.responseCode == 200
}

fun index(): Boolean {
suspend fun index(): Boolean {
val projectRoot = project.basePath ?: return false

val indexStartTime = System.currentTimeMillis()
Expand Down Expand Up @@ -179,7 +182,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
}.await()
}

fun getUsage(): Usage? {
internal suspend fun getUsage(): Usage? {
val response = sendMsgToLsp(LspMessage.GetUsageMetrics, request = null)
return try {
val parsedResponse = mapper.readValue<Usage>(response.responseBody)
Expand All @@ -190,9 +193,10 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
}
}

@RequiresBackgroundThread
fun updateIndex(filePaths: List<String>, mode: IndexUpdateMode) {
val encrypted = encryptRequest(UpdateIndexRequest(filePaths, mode.command))
sendMsgToLsp(LspMessage.UpdateIndex, encrypted)
runBlocking(IO) { sendMsgToLsp(LspMessage.UpdateIndex, encrypted) }
}

private fun recordIndexWorkspace(
Expand Down Expand Up @@ -306,12 +310,13 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
return encoderServer.encrypt(payloadJson)
}

private fun sendMsgToLsp(msgType: LspMessage, request: String?): LspResponse {
private suspend fun sendMsgToLsp(msgType: LspMessage, request: String?): LspResponse = withContext(IO) {
logger.info { "sending message: ${msgType.endpoint} to lsp on port ${encoderServer.port}" }
val url = URL("http://localhost:${encoderServer.port}/${msgType.endpoint}")
val url = URI("http://127.0.0.1:${encoderServer.port}/${msgType.endpoint}").toURL()
// use 1h as timeout for index, 5 seconds for other APIs
val timeoutMs = if (msgType is LspMessage.Index) 60.minutes.inWholeMilliseconds.toInt() else 5000
return with(url.openConnection() as HttpURLConnection) {

with(url.openConnection() as HttpURLConnection) {
setConnectionProperties(this)
setConnectionTimeout(this, timeoutMs)
request?.let { r ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@ fun getCoroutineUiContext(): CoroutineContext = EdtCoroutineDispatcher
fun getCoroutineBgContext(): CoroutineContext = AppExecutorUtil.getAppExecutorService().asCoroutineDispatcher()

val EDT = Dispatchers.EDT

val IO = Dispatchers.IO
Loading