Skip to content

WIP: Tool calling UI #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Imports:
jsonlite,
promises (>= 1.3.2),
rlang,
S7,
shiny (>= 1.10.0)
Suggests:
later,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export(chat_clear)
export(chat_mod_server)
export(chat_mod_ui)
export(chat_ui)
export(contents_shinychat)
export(markdown_stream)
export(output_markdown_stream)
import(rlang)
Expand Down
4 changes: 4 additions & 0 deletions R/chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,10 @@ rlang::on_load(

res$add(msg)

if (S7::S7_inherits(msg, ellmer::Content)) {
msg <- contents_shinychat(msg)
}

chat_append_message(
id,
list(role = role, content = msg),
Expand Down
54 changes: 37 additions & 17 deletions R/chat_app.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,7 @@ chat_mod_ui <- function(id, ..., client = NULL, messages = NULL) {
if (!is.null(client)) {
check_ellmer_chat(client)

client_msgs <- map(client$get_turns(), function(turn) {
content <- ellmer::contents_markdown(turn)
if (is.null(content) || identical(content, "")) {
return(NULL)
}
list(role = turn@role, content = content)
})
client_msgs <- compact(client_msgs)
client_msgs <- contents_shinychat(client)

if (length(client_msgs)) {
if (!is.null(messages)) {
Expand All @@ -125,10 +118,22 @@ chat_mod_ui <- function(id, ..., client = NULL, messages = NULL) {
}
}

shinychat::chat_ui(
shiny::NS(id, "chat"),
messages = messages,
...
shiny::tagList(
shinychat::chat_ui(
shiny::NS(id, "chat"),
messages = messages,
...
),
shiny::includeCSS(system.file(
"tools",
"tool-request.css",
package = "shinychat"
)),
shiny::includeScript(system.file(
"tools",
"tool-request.js",
package = "shinychat"
))
)
}

Expand All @@ -139,12 +144,27 @@ chat_mod_server <- function(id, client) {

append_stream_task <- shiny::ExtendedTask$new(
function(client, ui_id, user_input) {
promises::then(
promises::promise_resolve(client$stream_async(user_input)),
function(stream) {
chat_append(ui_id, stream)
}
clear_on_tool_result <- client$on_tool_result(function(result) {
session <- shiny::getDefaultReactiveDomain()
if (is.null(session)) return()
session$sendCustomMessage(
"shinychat-hide-tool-request",
result@request@id
)
})

stream <- client$stream_async(
user_input,
stream = "content"
)

p <- promises::promise_resolve(stream)
p <- promises::then(p, function(stream) {
chat_append(ui_id, stream)
})
promises::finally(p, function() {
clear_on_tool_result()
})
}
)

Expand Down
229 changes: 229 additions & 0 deletions R/contents_shinychat.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
#' Format ellmer content for shinychat
#'
#' @param content An [`ellmer::Content`] object.
#' @param ... Additional arguments passed to underlying methods.
#'
#' @return Returns text or HTML formatted for use in `chat_ui()`.
#'
#' @export
contents_shinychat <- S7::new_generic("contents_shinychat", "content")

S7::method(contents_shinychat, ellmer::Content) <- function(content, ...) {
# Fall back to html or markdown
html <- ellmer::contents_html(content)
if (!is.null(html)) shiny::HTML(html) else ellmer::contents_markdown(content)
}

S7::method(contents_shinychat, ellmer::ContentText) <- function(content) {
content@text
}

S7::method(contents_shinychat, ellmer::ContentToolRequest) <- function(
content,
...
) {
call <- format(content, show = "call")
if (length(call) > 1) {
call <- sprintf("%s()", content@name)
}
shiny::HTML(sprintf(
'\n\n<p class="shiny-tool-request" data-tool-call-id="%s">Running <code>%s</code></p>\n\n',
content@id,
call
))
}

S7::method(contents_shinychat, ellmer::ContentToolResult) <- function(
content,
...
) {
pre_code <- function(x) {
x <- gsub("`", "&#96;", x, fixed = TRUE)
x <- gsub("<", "&lt;", x, fixed = TRUE)
x <- gsub(">", "&gt;", x, fixed = TRUE)
sprintf("<pre><code>%s</code></pre>", paste(x, collapse = "\n"))
}

deps <- NULL

tool_result_display <- function(content) {
if (is.null(content@extra$display)) {
return(pre_code(content@value))
}

display <- content@extra$display

html <- NULL
md <- NULL
text <- NULL

if (
is.list(content@extra$display) &&
!inherits(content@extra$display, c("shiny.tag.list", "shiny.tag"))
) {
if (
!some(
c("text", "markdown", "html"),
\(x) x %in% names(content@extra$display)
)
) {
stop(
"ContentToolResult@extra$display must be a list with at least one of the following elements: text, markdown, html."
)
}
html <- content@extra$display$html
md <- content@extra$display$markdown
text <- content@extra$display$text
} else {
if (inherits(content@extra$display, "html")) {
html <- content@extra$display
} else {
md <- content@extra$display
}
}

if (!is.null(html)) {
deps <<- htmltools::findDependencies(html)
return(format(html))
}

if (!is.null(markdown)) {
md <- paste(md, collapse = "\n")
md <- paste0("\n\n", md, "\n\n")
return(md)
}

return(text %||% pre_code(contents$value))
}

if (isFALSE(content@extra$display_tool_request)) {
res <- tool_result_display(content)
if (!is.null(deps)) {
res <- htmltools::attachDependencies(res, deps)
}
return(res)
}

if (!is.null(content@error)) {
class <- "shiny-tool-result failed"
summary_text <- "Failed to call"
tool_result <- sprintf(
"<strong>Error</strong>%s",
pre_code(strip_ansi(content@error))
)
} else {
class <- "shiny-tool-result"
summary_text <- "Result from"
tool_result <- sprintf(
'<strong>Tool Result</strong>%s',
tool_result_display(content)
)
}

if (!is.null(content@request@tool)) {
if (!is.null(content@request@tool@annotations$title)) {
# Use the tool title if available
tool_name <- content@request@tool@annotations$title
summary_text <- ""
} else {
# Fallback to tool name
tool_name <- content@request@tool@name
}
} else {
tool_name <- "unknown tool"
}

intent <- ""
if (!is.null(content@request@arguments$intent)) {
intent <- sprintf(
' | <span class="intent">%s</span>',
content@request@arguments$intent
)
}

tool_call <-
details_open <- sprintf(
'<details class="%s" id="%s">',
class,
content@request@id
)

summary <- sprintf(
'<summary>%s <span class="function-name">%s</span>%s</summary>',
summary_text,
tool_name,
intent
)

tool_call <- sprintf(
'<strong>Tool Call</strong>%s',
pre_code(format(content@request, show = "call"))
)

body <- sprintf(
'<p>%s</p><p>%s</p></details>\n\n',
tool_call,
tool_result
)

res <- shiny::HTML(paste0(details_open, summary, body))
if (!is.null(deps)) {
res <- htmltools::attachDependencies(res, deps)
}
return(res)
}

S7::method(contents_shinychat, ellmer::Turn) <- function(content) {
lapply(content@contents, contents_shinychat)
}

S7::method(contents_shinychat, S7::new_S3_class(c("Chat", "R6"))) <- function(
content,
...
) {
# Consolidate tool calls into assistant turns. This currently assumes that
# tool calls are always returned in user turns that have at least one
# proceeding assistant turn.
turns <- map(content$get_turns(), function(turn) {
if (
all(map_lgl(turn@contents, S7::S7_inherits, ellmer::ContentToolResult))
) {
turn@role <- "assistant"
}
is_tool_request <- map_lgl(
turn@contents,
S7::S7_inherits,
ellmer::ContentToolRequest
)
turn@contents <- turn@contents[!is_tool_request]
turn
})
turns <- reduce(turns, .init = list(), function(turns, turn) {
if (length(turns) == 0) {
return(list(turn))
}

# consolidate turns with adjacent roles
last_turn <- turns[[length(turns)]]
if (identical(last_turn@role, turn@role)) {
turns[[length(turns)]]@contents <- c(last_turn@contents, turn@contents)
return(turns)
}

c(turns, list(turn))
})

messages <- map(turns, function(turn) {
content <- compact(contents_shinychat(turn))
if (is.null(content) || identical(content, "")) {
return(NULL)
}
if (every(content, is.character)) {
# TODO: Fix chat_ui() to handle lists of strings
content <- paste(unlist(content), collapse = "\n\n")
}
list(role = turn@role, content = content)
})

compact(messages)
}
1 change: 1 addition & 0 deletions R/shinychat-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ NULL
ignore_unused_imports <- function() {
jsonlite::fromJSON
fastmap::fastqueue
ellmer::contents_html
}

release_bullets <- function() {
Expand Down
1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.onLoad <- function(libname, pkgname) {
rlang::run_on_load()
S7::methods_register()
}

as_generator <- function(x) {
Expand Down
Loading