Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions decentralized-api/apiconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,32 @@ type SeedInfo struct {
}

type ApiConfig struct {
Port int `koanf:"port" json:"port"`
PoCCallbackUrl string `koanf:"poc_callback_url" json:"poc_callback_url"`
MlGrpcCallbackAddress string `koanf:"ml_grpc_callback_address" json:"ml_grpc_callback_address"`
PublicUrl string `koanf:"public_url" json:"public_url"`
PublicServerPort int `koanf:"public_server_port" json:"public_server_port"`
MLServerPort int `koanf:"ml_server_port" json:"ml_server_port"`
AdminServerPort int `koanf:"admin_server_port" json:"admin_server_port"`
MlGrpcServerPort int `koanf:"ml_grpc_server_port" json:"ml_grpc_server_port"`
TestMode bool `koanf:"test_mode" json:"test_mode"`
NodeManagerGrpcPort int `koanf:"node_manager_grpc_port" json:"node_manager_grpc_port"`
NodeManagerLockTTLSeconds int `koanf:"node_manager_lock_ttl_seconds" json:"node_manager_lock_ttl_seconds"`
Port int `koanf:"port" json:"port"`
PoCCallbackUrl string `koanf:"poc_callback_url" json:"poc_callback_url"`
MlGrpcCallbackAddress string `koanf:"ml_grpc_callback_address" json:"ml_grpc_callback_address"`
PublicUrl string `koanf:"public_url" json:"public_url"`
PublicServerPort int `koanf:"public_server_port" json:"public_server_port"`
MLServerPort int `koanf:"ml_server_port" json:"ml_server_port"`
AdminServerPort int `koanf:"admin_server_port" json:"admin_server_port"`
MlGrpcServerPort int `koanf:"ml_grpc_server_port" json:"ml_grpc_server_port"`
TestMode bool `koanf:"test_mode" json:"test_mode"`
NodeManagerGrpcPort int `koanf:"node_manager_grpc_port" json:"node_manager_grpc_port"`
NodeManagerLockTTLSeconds int `koanf:"node_manager_lock_ttl_seconds" json:"node_manager_lock_ttl_seconds"`
MLNodeTLS MLNodeTLSConfig `koanf:"mlnode_tls" json:"mlnode_tls"`
}

type MLNodeTLSConfig struct {
Enabled bool `koanf:"enabled" json:"enabled"`
CertFile string `koanf:"cert_file" json:"cert_file"`
KeyFile string `koanf:"key_file" json:"key_file"`
PeerCertFile string `koanf:"peer_cert_file" json:"peer_cert_file"`
}

func (c MLNodeTLSConfig) Scheme() string {
if c.Enabled {
return "https"
}
return "http"
}

type ChainNodeConfig struct {
Expand Down
42 changes: 35 additions & 7 deletions decentralized-api/broker/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
"slices"
"sort"
Expand Down Expand Up @@ -171,25 +172,33 @@ type Node struct {
}

func (n *Node) InferenceUrl() string {
return fmt.Sprintf("http://%s:%d%s", n.Host, n.InferencePort, n.InferenceSegment)
return n.inferenceUrlWithScheme("http", "")
}

func (n *Node) InferenceUrlWithVersion(version string) string {
return n.inferenceUrlWithScheme("http", version)
}

func (n *Node) inferenceUrlWithScheme(scheme, version string) string {
if version == "" {
return n.InferenceUrl()
return fmt.Sprintf("%s://%s:%d%s", scheme, n.Host, n.InferencePort, n.InferenceSegment)
}
return fmt.Sprintf("http://%s:%d/%s%s", n.Host, n.InferencePort, version, n.InferenceSegment)
return fmt.Sprintf("%s://%s:%d/%s%s", scheme, n.Host, n.InferencePort, version, n.InferenceSegment)
}

func (n *Node) PoCUrl() string {
return fmt.Sprintf("http://%s:%d%s", n.Host, n.PoCPort, n.PoCSegment)
return n.pocUrlWithScheme("http", "")
}

func (n *Node) PoCUrlWithVersion(version string) string {
return n.pocUrlWithScheme("http", version)
}

func (n *Node) pocUrlWithScheme(scheme, version string) string {
if version == "" {
return n.PoCUrl()
return fmt.Sprintf("%s://%s:%d%s", scheme, n.Host, n.PoCPort, n.PoCSegment)
}
return fmt.Sprintf("http://%s:%d/%s%s", n.Host, n.PoCPort, version, n.PoCSegment)
return fmt.Sprintf("%s://%s:%d/%s%s", scheme, n.Host, n.PoCPort, version, n.PoCSegment)
}

type NodeWithState struct {
Expand Down Expand Up @@ -450,7 +459,26 @@ func (b *Broker) QueueMessage(command Command) error {

func (b *Broker) NewNodeClient(node *Node) mlnodeclient.MLNodeClient {
version := b.configManager.GetCurrentNodeVersion()
return b.mlNodeClientFactory.CreateClient(node.PoCUrlWithVersion(version), node.InferenceUrlWithVersion(version))
return b.mlNodeClientFactory.CreateClient(b.pocUrlForNode(node, version), b.InferenceUrlForNode(node, version))
}

func (b *Broker) NewMLNodeHTTPClient(timeout time.Duration) *http.Client {
return b.mlNodeClientFactory.NewHTTPClient(timeout)
}

func (b *Broker) mlNodeScheme() string {
if b.configManager != nil {
return b.configManager.GetApiConfig().MLNodeTLS.Scheme()
}
return "http"
}

func (b *Broker) pocUrlForNode(node *Node, version string) string {
return node.pocUrlWithScheme(b.mlNodeScheme(), version)
}

func (b *Broker) InferenceUrlForNode(node *Node, version string) string {
return node.inferenceUrlWithScheme(b.mlNodeScheme(), version)
}

func (b *Broker) lockAvailableNode(command LockAvailableNode) {
Expand Down
2 changes: 1 addition & 1 deletion decentralized-api/broker/node_lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (b *Broker) AcquireMLNode(ctx context.Context, model string, skipNodeIDs []
b.lockMap[lockID] = lockEntry{nodeID: node.Id, createdAt: time.Now()}
b.lockMapMu.Unlock()
version := b.configManager.GetCurrentNodeVersion()
return lockID, node.InferenceUrlWithVersion(version), node.Id, nil
return lockID, b.InferenceUrlForNode(node, version), node.Id, nil
}
}

Expand Down
4 changes: 2 additions & 2 deletions decentralized-api/broker/node_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ func (w *NodeWorker) CheckClientVersionAlive(version string, factory mlnodeclien
}

node := w.node.Node
pocUrl := node.PoCUrlWithVersion(version)
inferenceUrl := node.InferenceUrlWithVersion(version)
pocUrl := w.broker.pocUrlForNode(&node, version)
inferenceUrl := w.broker.InferenceUrlForNode(&node, version)

versionClient := factory.CreateClient(pocUrl, inferenceUrl)
_, err := versionClient.NodeState(context.Background())
Expand Down
27 changes: 26 additions & 1 deletion decentralized-api/cmd/devshardd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (

"decentralized-api/apiconfig"
internaldevshard "decentralized-api/internal/devshard"
"decentralized-api/internal/mtls"
pserver "decentralized-api/internal/server/public"
"decentralized-api/payloadstorage"

Expand Down Expand Up @@ -135,7 +136,10 @@ func main() {
3*time.Minute,
)

httpClient := pserver.NewNoRedirectClient(internaldevshard.MLNodeHTTPTimeout)
httpClient, err := newMLNodeHTTPClient()
if err != nil {
log.Fatalf("mlnode http client: %v", err)
}

availabilityTracker := devshardpkg.NewAvailabilityTracker(true, 0, 0)
chainParams := newChainParamsProvider(ctx, recorder, availabilityTracker)
Expand Down Expand Up @@ -413,6 +417,27 @@ func envOr(key, fallback string) string {
return fallback
}

func newMLNodeHTTPClient() (*http.Client, error) {
client := pserver.NewNoRedirectClient(internaldevshard.MLNodeHTTPTimeout)

certFile := os.Getenv("MLNODE_TLS_CERT_FILE")
keyFile := os.Getenv("MLNODE_TLS_KEY_FILE")
peerCertFile := os.Getenv("MLNODE_TLS_PEER_CERT_FILE")
if certFile == "" && keyFile == "" && peerCertFile == "" {
return client, nil
}

tlsConfig, err := mtls.ClientConfig(certFile, keyFile, peerCertFile)
if err != nil {
return nil, err
}
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = tlsConfig
client.Transport = transport
slog.Info("mlnode mTLS client enabled", "cert", certFile, "pinned_peer_cert", peerCertFile)
return client, nil
}

func expandHome(path string) (string, error) {
if strings.HasPrefix(path, "~/") {
home, err := os.UserHomeDir()
Expand Down
2 changes: 1 addition & 1 deletion decentralized-api/internal/devshard/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (e *EngineAdapter) Execute(ctx context.Context, req devshard.ExecuteRequest
func (e *EngineAdapter) executeMLRequest(ctx context.Context, model string, body []byte) (*http.Response, error) {
resp, err := broker.DoWithLockedNodeHTTPRetry(e.broker, model, nil, 3,
func(node *broker.Node) (*http.Response, *broker.ActionError) {
url := node.InferenceUrlWithVersion(e.nodeVersion) + "/v1/chat/completions"
url := e.broker.InferenceUrlForNode(node, e.nodeVersion) + "/v1/chat/completions"
httpReq, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if reqErr != nil {
return nil, broker.NewApplicationActionError(reqErr)
Expand Down
2 changes: 1 addition & 1 deletion decentralized-api/internal/devshard/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (v *ValidationAdapter) Validate(ctx context.Context, req devshard.ValidateR
func (v *ValidationAdapter) executeMLRequest(ctx context.Context, model string, body []byte) (*http.Response, error) {
resp, err := broker.DoWithLockedNodeHTTPRetry(v.broker, model, nil, 3,
func(node *broker.Node) (*http.Response, *broker.ActionError) {
url := node.InferenceUrlWithVersion(v.nodeVersion) + "/v1/chat/completions"
url := v.broker.InferenceUrlForNode(node, v.nodeVersion) + "/v1/chat/completions"
httpReq, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if reqErr != nil {
return nil, broker.NewApplicationActionError(reqErr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type NodesConfigManagerInterface interface {
GetNodes() []apiconfig.InferenceNodeConfig
GetCurrentNodeVersion() string
SetNodes(nodes []apiconfig.InferenceNodeConfig) error
GetApiConfig() apiconfig.ApiConfig
}

// PhaseTrackerInterface defines the minimal interface needed from PhaseTracker
Expand Down Expand Up @@ -125,8 +126,8 @@ func (m *MLNodeBackgroundManager) isInDownloadWindow(epochState *chainphase.Epoc
// checkNodeModels checks and downloads models for a specific node
func (m *MLNodeBackgroundManager) checkNodeModels(node apiconfig.InferenceNodeConfig) {
version := m.configManager.GetCurrentNodeVersion()
pocUrl := getPoCUrlWithVersion(node, version)
inferenceUrl := getInferenceUrlWithVersion(node, version)
pocUrl := m.pocUrlForNode(node, version)
inferenceUrl := m.inferenceUrlForNode(node, version)
client := m.mlNodeClientFactory.CreateClient(pocUrl, inferenceUrl)

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
Expand Down Expand Up @@ -192,42 +193,20 @@ func (m *MLNodeBackgroundManager) checkNodeModels(node apiconfig.InferenceNodeCo
}
}

func getPoCUrlWithVersion(node apiconfig.InferenceNodeConfig, version string) string {
func (m *MLNodeBackgroundManager) pocUrlForNode(node apiconfig.InferenceNodeConfig, version string) string {
scheme := m.configManager.GetApiConfig().MLNodeTLS.Scheme()
if version == "" {
return getPoCUrl(node)
return fmt.Sprintf("%s://%s:%d%s", scheme, node.Host, node.PoCPort, node.PoCSegment)
}
return getPoCUrlVersioned(node, version)
return fmt.Sprintf("%s://%s:%d/%s%s", scheme, node.Host, node.PoCPort, version, node.PoCSegment)
}

func getInferenceUrlWithVersion(node apiconfig.InferenceNodeConfig, version string) string {
func (m *MLNodeBackgroundManager) inferenceUrlForNode(node apiconfig.InferenceNodeConfig, version string) string {
scheme := m.configManager.GetApiConfig().MLNodeTLS.Scheme()
if version == "" {
return getInferenceUrl(node)
return fmt.Sprintf("%s://%s:%d%s", scheme, node.Host, node.InferencePort, node.InferenceSegment)
}
return getInferenceUrlVersioned(node, version)
}

func getPoCUrl(node apiconfig.InferenceNodeConfig) string {
return formatURL(node.Host, node.PoCPort, node.PoCSegment)
}

func getPoCUrlVersioned(node apiconfig.InferenceNodeConfig, version string) string {
return formatURLWithVersion(node.Host, node.PoCPort, version, node.PoCSegment)
}

func getInferenceUrl(node apiconfig.InferenceNodeConfig) string {
return formatURL(node.Host, node.InferencePort, node.InferenceSegment)
}

func getInferenceUrlVersioned(node apiconfig.InferenceNodeConfig, version string) string {
return formatURLWithVersion(node.Host, node.InferencePort, version, node.InferenceSegment)
}

func formatURL(host string, port int, segment string) string {
return fmt.Sprintf("http://%s:%d%s", host, port, segment)
}

func formatURLWithVersion(host string, port int, version string, segment string) string {
return fmt.Sprintf("http://%s:%d/%s%s", host, port, version, segment)
return fmt.Sprintf("%s://%s:%d/%s%s", scheme, node.Host, node.InferencePort, version, node.InferenceSegment)
}

// checkAndUpdateGPUs fetches GPU info from all nodes and updates hardware
Expand Down Expand Up @@ -289,8 +268,8 @@ func (m *MLNodeBackgroundManager) checkAndUpdateGPUs(ctx context.Context) {
// fetchNodeGPUHardware fetches GPU devices and transforms to Hardware entries
func (m *MLNodeBackgroundManager) fetchNodeGPUHardware(ctx context.Context, node *apiconfig.InferenceNodeConfig) ([]apiconfig.Hardware, error) {
version := m.configManager.GetCurrentNodeVersion()
pocUrl := getPoCUrlWithVersion(*node, version)
inferenceUrl := getInferenceUrlWithVersion(*node, version)
pocUrl := m.pocUrlForNode(*node, version)
inferenceUrl := m.inferenceUrlForNode(*node, version)
client := m.mlNodeClientFactory.CreateClient(pocUrl, inferenceUrl)

timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"decentralized-api/chainphase"
"decentralized-api/mlnodeclient"
"errors"
"net/http"
"testing"
"time"

Expand All @@ -18,12 +19,19 @@ type mockConfigManager struct {
nodes []apiconfig.InferenceNodeConfig
currentNodeVersion string
setNodesError error
mlNodeTLSEnabled bool
}

func (m *mockConfigManager) GetNodes() []apiconfig.InferenceNodeConfig {
return m.nodes
}

func (m *mockConfigManager) GetApiConfig() apiconfig.ApiConfig {
return apiconfig.ApiConfig{
MLNodeTLS: apiconfig.MLNodeTLSConfig{Enabled: m.mlNodeTLSEnabled},
}
}

func (m *mockConfigManager) GetCurrentNodeVersion() string {
return m.currentNodeVersion
}
Expand Down Expand Up @@ -78,6 +86,10 @@ func (m *mockClientFactory) CreateClient(pocUrl, inferenceUrl string) mlnodeclie
return m.client
}

func (m *mockClientFactory) NewHTTPClient(timeout time.Duration) *http.Client {
return &http.Client{Timeout: timeout}
}

// Custom mock client for testing error handling
type customMockClient struct {
*mlnodeclient.MockClient
Expand Down Expand Up @@ -561,49 +573,28 @@ func TestURLFormatting(t *testing.T) {
InferenceSegment: "/inference",
}

t.Run("PoC URL without version", func(t *testing.T) {
url := getPoCUrl(node)
expected := "http://localhost:8080/api/v1"
if url != expected {
t.Errorf("expected %s, got %s", expected, url)
}
})

t.Run("PoC URL with version", func(t *testing.T) {
url := getPoCUrlVersioned(node, "v2")
expected := "http://localhost:8080/v2/api/v1"
if url != expected {
t.Errorf("expected %s, got %s", expected, url)
}
})

t.Run("Inference URL without version", func(t *testing.T) {
url := getInferenceUrl(node)
m := &MLNodeBackgroundManager{configManager: &mockConfigManager{}}
url := m.inferenceUrlForNode(node, "")
expected := "http://localhost:8081/inference"
if url != expected {
t.Errorf("expected %s, got %s", expected, url)
}
})

t.Run("Inference URL with version", func(t *testing.T) {
url := getInferenceUrlVersioned(node, "v2")
m := &MLNodeBackgroundManager{configManager: &mockConfigManager{}}
url := m.inferenceUrlForNode(node, "v2")
expected := "http://localhost:8081/v2/inference"
if url != expected {
t.Errorf("expected %s, got %s", expected, url)
}
})

t.Run("URL with version helper", func(t *testing.T) {
url := getPoCUrlWithVersion(node, "v2")
expected := "http://localhost:8080/v2/api/v1"
if url != expected {
t.Errorf("expected %s, got %s", expected, url)
}
})

t.Run("URL without version helper (empty string)", func(t *testing.T) {
url := getPoCUrlWithVersion(node, "")
expected := "http://localhost:8080/api/v1"
t.Run("Inference URL with mTLS enabled", func(t *testing.T) {
m := &MLNodeBackgroundManager{configManager: &mockConfigManager{mlNodeTLSEnabled: true}}
url := m.inferenceUrlForNode(node, "v2")
expected := "https://localhost:8081/v2/inference"
if url != expected {
t.Errorf("expected %s, got %s", expected, url)
}
Expand Down
Loading
Loading