Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 49 additions & 6 deletions bin/spice/cmd/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ const (
modelKeyFlag = "model"
httpEndpointKeyFlag = "http-endpoint"
userAgentKeyFlag = "user-agent"
temperatureFlag = "temperature"
)

type Message struct {
Expand All @@ -52,6 +53,38 @@ type ChatRequestBody struct {
Model string `json:"model"`
Stream bool `json:"stream"`
StreamOptions *StreamOptions `json:"stream_options"`
ChatRequestOptions
}

// ChatRequestOptions contains all optional fields for chat requests.
// It is embedded in ChatRequestBody, so these fields are set directly on the request body.
type ChatRequestOptions struct {
// Temperature is a pointer so that an unset value (nil) is left out of the JSON payload by omitempty.
Temperature *float32 `json:"temperature,omitempty"`
}

// NewChatRequestBody allocates a ChatRequestBody populated with the required
// request fields. Optional fields (ChatRequestOptions) are left at their zero
// values.
func NewChatRequestBody(messages []Message, model string, stream bool, streamOptions *StreamOptions) *ChatRequestBody {
	request := new(ChatRequestBody)
	request.Messages = messages
	request.Model = model
	request.Stream = stream
	request.StreamOptions = streamOptions
	return request
}

// ApplyChatOptions copies optional chat settings from the command's flags onto
// body and returns the same body pointer.
//
// The temperature is applied only when the user explicitly set the flag, so an
// untouched flag leaves body.Temperature nil and the field is omitted from the
// request. An unreadable or negative temperature is logged and terminates the
// process with exit code 1.
//
// The returned error is currently always nil; it is kept in the signature for
// callers that already destructure two results.
func ApplyChatOptions(body *ChatRequestBody, cmd *cobra.Command) (*ChatRequestBody, error) {
	flags := cmd.Flags()

	// Look the flag up through temperatureFlag, the same constant used to
	// register it, so the name cannot drift from the registration.
	if flags.Changed(temperatureFlag) {
		temperature, err := flags.GetFloat32(temperatureFlag)
		if err != nil {
			slog.Error("could not get temperature flag", "error", err)
			os.Exit(1)
		}
		if temperature < 0 {
			slog.Error("temperature must be greater than or equal to 0")
			os.Exit(1)
		}
		body.Temperature = &temperature
	}

	return body, nil
}

type StreamOptions struct {
Expand Down Expand Up @@ -120,6 +153,16 @@ spice chat --model <model> --cloud
rtcontext.SetApiKey(apiKey)
}

temperature, err := cmd.Flags().GetFloat32("temperature")
if err != nil {
slog.Error("could not get temperature flag", "error", err)
os.Exit(1)
}
if temperature < 0 {
slog.Error("temperature must be greater than or equal to 0")
os.Exit(1)
}

userAgent, _ := cmd.Flags().GetString(userAgentKeyFlag)
if userAgent != "" {
rtcontext.SetUserAgent(userAgent)
Expand Down Expand Up @@ -215,12 +258,11 @@ spice chat --model <model> --cloud
util.ShowSpinner(done)
}()

body := &ChatRequestBody{
Messages: messages,
Model: model,
Stream: true,
StreamOptions: &StreamOptions{IncludeUsage: true},
}
body := NewChatRequestBody(messages, model, true, &StreamOptions{
IncludeUsage: true,
})
body, _ = ApplyChatOptions(body, cmd)

var timeAtCompletion time.Time
var timeAtFirstToken time.Time
startTime := time.Now()
Expand Down Expand Up @@ -382,6 +424,7 @@ func init() {
chatCmd.Flags().String(modelKeyFlag, "", "Model to chat with")
chatCmd.Flags().String(httpEndpointKeyFlag, "", "HTTP endpoint for chat (default: http://localhost:8090)")
chatCmd.Flags().String(userAgentKeyFlag, "", "User agent to use in all requests")
chatCmd.Flags().Float32(temperatureFlag, 1, "Model temperature for chat request")

RootCmd.AddCommand(chatCmd)
}
2 changes: 2 additions & 0 deletions bin/spiced/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ pub async fn run(args: Args) -> Result<()> {
}

let result = match server_thread.await {
// Don't treat force terminated as an error
Ok(Err(runtime::Error::ForceTerminated { .. })) => Ok(()),
Ok(ok) => ok.context(UnableToStartServersSnafu),
Err(_) => Err(Error::GenericError {
reason: "Unable to start spiced".into(),
Expand Down
20 changes: 5 additions & 15 deletions crates/runtime/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub struct RuntimeBuilder {
datasets_health_monitor_enabled: bool,
metrics_endpoint: Option<SocketAddr>,
prometheus_registry: Option<prometheus::Registry>,
runtime_status: Option<Arc<status::RuntimeStatus>>,
runtime_status: Arc<status::RuntimeStatus>,
rate_limits: Option<Arc<RateLimits>>,
accelerator_engine_registry: Arc<AcceleratorEngineRegistry>,
datafusion_configuration_fn: Option<DatafusionConfigurationCallback>,
Expand All @@ -60,7 +60,7 @@ impl RuntimeBuilder {
metrics_endpoint: None,
prometheus_registry: None,
autoload_extensions: HashMap::new(),
runtime_status: None,
runtime_status: status::RuntimeStatus::new(),
rate_limits: None,
accelerator_engine_registry: Arc::new(AcceleratorEngineRegistry::new()),
datafusion_configuration_fn: None,
Expand Down Expand Up @@ -130,11 +130,6 @@ impl RuntimeBuilder {
self
}

pub fn with_runtime_status(mut self, runtime_status: Arc<status::RuntimeStatus>) -> Self {
self.runtime_status = Some(runtime_status);
self
}

pub fn with_rate_limits(mut self, rate_limits: RateLimits) -> Self {
self.rate_limits = Some(Arc::new(rate_limits));
self
Expand All @@ -147,13 +142,8 @@ impl RuntimeBuilder {
tools::factory::register_all_factories().await;
document_parse::register_all().await;

let status = match self.runtime_status {
Some(status) => status,
None => status::RuntimeStatus::new(),
};

let mut df = DataFusion::builder(
Arc::clone(&status),
Arc::clone(&self.runtime_status),
Arc::clone(&self.accelerator_engine_registry),
)
.build();
Expand Down Expand Up @@ -203,9 +193,9 @@ impl RuntimeBuilder {
metrics_endpoint: self.metrics_endpoint,
prometheus_registry: self.prometheus_registry,
rate_limits: self.rate_limits.unwrap_or_default(),
status,
status: self.runtime_status,
runtime_tasks: Arc::new(RwLock::new(HashMap::new())),
accelerator_engine_registry: Arc::clone(&self.accelerator_engine_registry),
accelerator_engine_registry: self.accelerator_engine_registry,
};

let mut extensions: HashMap<String, Arc<dyn Extension>> = HashMap::new();
Expand Down
22 changes: 18 additions & 4 deletions crates/runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::time::Duration;
use std::{collections::HashMap, sync::Arc};
use tokio::task::JoinHandle;
use tokio::time::Instant;
use util::force_shutdown_signal;

use crate::dataaccelerator::AcceleratorEngineRegistry;
use crate::{
Expand Down Expand Up @@ -314,6 +315,9 @@ pub enum Error {

#[snafu(display("Initialization has been cancelled"))]
ComponentsInitializationCancelled,

#[snafu(display("Force shutdown requested"))]
ForceTerminated,
}

const HTTP_SERVER: &str = "http_server";
Expand Down Expand Up @@ -546,10 +550,20 @@ impl Runtime {

// Shutdown signal
let shutdown_signal_future = async {
shutdown_signal().await;
tracing::debug!("Shutdown signal received.");
self.shutdown().await;
Ok(())
let graceful_shutdown = async {
shutdown_signal().await;
tracing::debug!("Shutdown signal received. Press Ctrl-C again to force exit.");
self.shutdown().await;
Ok(())
};
tokio::select! {
result = graceful_shutdown => result,
() = force_shutdown_signal() => {
tracing::info!("Force shutdown signal received. Terminating immediately.");
// return error to force stop waiting for other tasks and terminate immediately
Err(Error::ForceTerminated)
}
}
};

// wait for all servers to shut down or if any of the servers fail to start
Expand Down
1 change: 1 addition & 0 deletions crates/util/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ humantime = "2.1.0"
rand = "0.9.0"
tokio = { workspace = true }
tracing = { workspace = true }
ctrlc = "3.4"
33 changes: 33 additions & 0 deletions crates/util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ limitations under the License.

use std::{
cmp,
sync::Arc,
time::{Duration, SystemTime, SystemTimeError},
};

pub mod fibonacci_backoff;
pub use backoff::Error as RetryError;
pub use backoff::future::retry;
use tokio::{sync::oneshot, time::Instant};

#[allow(clippy::cast_precision_loss)]
#[allow(clippy::cast_sign_loss)]
Expand Down Expand Up @@ -59,6 +61,37 @@ pub async fn shutdown_signal() {
shutdown_signal_impl().await;
}

/// Waits for an additional Ctrl-C after the initial shutdown signal to trigger a forced shutdown.
///
/// The future first waits for [`shutdown_signal`], then registers a Ctrl-C handler through
/// `ctrlc::set_handler` and resolves once that handler reports a Ctrl-C received more than
/// 500ms after the initial signal.
///
/// If the handler cannot be registered, the failure is logged and this future never resolves,
/// so callers racing it against a graceful shutdown simply never see a forced shutdown.
pub async fn force_shutdown_signal() {
    shutdown_signal().await;

    // use 500ms as a debounce window to prevent the same Ctrl-C signal from being handled multiple times
    let last_signal_time = Instant::now();

    let (notify_ctrl_c, on_second_ctrl_c) = oneshot::channel::<()>();
    // The sender is consumed by `send`, so it lives in an `Option` that the handler takes once.
    let notify_ctrl_c = Arc::new(std::sync::Mutex::new(Some(notify_ctrl_c)));

    let registration = ctrlc::set_handler(move || {
        if Instant::now().duration_since(last_signal_time) < Duration::from_millis(500) {
            return;
        }
        if let Some(tx) = notify_ctrl_c
            .lock()
            .ok()
            .and_then(|mut tx_opt| tx_opt.take())
        {
            tracing::debug!("Received Ctrl-C after the initial shutdown signal");
            tx.send(()).ok();
        }
    });

    if let Err(err) = registration {
        tracing::error!("Failed to set listener for Ctrl-C: {err}");
        // do not exit; otherwise, it will be interpreted as a force shutdown signal.
        // The closure owning the sender was moved into `set_handler`; once it is discarded the
        // receiver below would resolve immediately with a closed-channel error, so park here
        // forever instead of falling through.
        std::future::pending::<()>().await;
    }

    on_second_ctrl_c.await.ok();
}

#[cfg(unix)]
async fn shutdown_signal_impl() {
use tokio::signal::unix::{SignalKind, signal};
Expand Down
Loading