|
3 | 3 |
|
4 | 4 | use std::ffi::OsStr; |
5 | 5 | use std::future::Future; |
6 | | -use std::time::Duration; |
7 | 6 |
|
8 | 7 | use clap_complete::engine::CompletionCandidate; |
| 8 | +use openshell_bootstrap::edge_token::load_edge_token; |
| 9 | +use openshell_bootstrap::oidc_token::{is_token_expired, load_oidc_token, store_oidc_token}; |
9 | 10 | use openshell_bootstrap::{list_gateways, load_active_gateway, load_gateway_metadata}; |
10 | 11 | use openshell_core::ObjectName; |
| 12 | +use openshell_core::auth::EdgeAuthInterceptor; |
11 | 13 | use openshell_core::proto::open_shell_client::OpenShellClient; |
12 | 14 | use openshell_core::proto::{ListProvidersRequest, ListSandboxesRequest}; |
13 | | -use tonic::transport::{Channel, Endpoint}; |
| 15 | +use tonic::service::interceptor::InterceptedService; |
| 16 | +use tonic::transport::Channel; |
14 | 17 |
|
15 | | -use crate::tls::{TlsOptions, build_tonic_tls_config, require_tls_materials}; |
| 18 | +use crate::oidc_auth::oidc_refresh_token; |
| 19 | +use crate::tls::{TlsOptions, build_channel}; |
16 | 20 |
|
17 | 21 | /// Complete gateway names from local metadata files (no network call). |
18 | 22 | pub fn complete_gateway_names(_prefix: &OsStr) -> Vec<CompletionCandidate> { |
@@ -84,17 +88,46 @@ fn resolve_active_gateway() -> Option<(String, String)> { |
84 | 88 | async fn completion_grpc_client( |
85 | 89 | server: &str, |
86 | 90 | gateway_name: &str, |
87 | | -) -> Option<OpenShellClient<Channel>> { |
88 | | - let tls_opts = TlsOptions::default().with_gateway_name(gateway_name); |
89 | | - let materials = require_tls_materials(server, &tls_opts).ok()?; |
90 | | - let tls_config = build_tonic_tls_config(&materials); |
91 | | - let endpoint = Endpoint::from_shared(server.to_string()) |
92 | | - .ok()? |
93 | | - .connect_timeout(Duration::from_secs(2)) |
94 | | - .tls_config(tls_config) |
95 | | - .ok()?; |
96 | | - let channel = endpoint.connect().await.ok()?; |
97 | | - Some(OpenShellClient::new(channel)) |
| 91 | +) -> Option<OpenShellClient<InterceptedService<Channel, EdgeAuthInterceptor>>> { |
| 92 | + let mut tls_opts = TlsOptions::default().with_gateway_name(gateway_name); |
| 93 | + tls_opts.gateway_insecure = std::env::var("OPENSHELL_GATEWAY_INSECURE") |
| 94 | + .is_ok_and(|v| !v.is_empty() && v != "0" && v != "false"); |
| 95 | + |
| 96 | + if let Ok(meta) = load_gateway_metadata(gateway_name) { |
| 97 | + match meta.auth_mode.as_deref() { |
| 98 | + Some("oidc") => { |
| 99 | + if let Some(bundle) = load_oidc_token(gateway_name) { |
| 100 | + if is_token_expired(&bundle) { |
| 101 | + match oidc_refresh_token(&bundle).await { |
| 102 | + Ok(refreshed) => { |
| 103 | + let _ = store_oidc_token(gateway_name, &refreshed); |
| 104 | + tls_opts.oidc_token = Some(refreshed.access_token); |
| 105 | + } |
| 106 | + Err(_) => { |
| 107 | + tls_opts.oidc_token = Some(bundle.access_token); |
| 108 | + } |
| 109 | + } |
| 110 | + } else { |
| 111 | + tls_opts.oidc_token = Some(bundle.access_token); |
| 112 | + } |
| 113 | + } |
| 114 | + } |
| 115 | + Some("cloudflare_jwt") => { |
| 116 | + if let Some(token) = load_edge_token(gateway_name) { |
| 117 | + tls_opts.edge_token = Some(token); |
| 118 | + } |
| 119 | + } |
| 120 | + _ => {} |
| 121 | + } |
| 122 | + } |
| 123 | + |
| 124 | + let channel = build_channel(server, &tls_opts).await.ok()?; |
| 125 | + let interceptor = EdgeAuthInterceptor::new( |
| 126 | + tls_opts.oidc_token.as_deref(), |
| 127 | + tls_opts.edge_token.as_deref(), |
| 128 | + ) |
| 129 | + .ok()?; |
| 130 | + Some(OpenShellClient::with_interceptor(channel, interceptor)) |
98 | 131 | } |
99 | 132 |
|
100 | 133 | /// Run an async future on a dedicated thread to avoid nested tokio runtime panics. |
|
0 commit comments