Skip to content

Commit aaf6fec

Browse files
authored
Agent auto discovery conflicts with explicit task type mapping (#6464)
Signed-off-by: Alex Wu <c.alexwu@gmail.com>
1 parent 8bf1de6 commit aaf6fec

4 files changed

Lines changed: 130 additions & 10 deletions

File tree

flyteplugins/go/tasks/plugins/webapi/agent/client.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,11 @@ func getAgentRegistry(ctx context.Context, cs *ClientSet) Registry {
150150
strings.Join(maps.Keys(agentSupportedTaskCategories), ", "))
151151
}
152152

153-
// If the agent doesn't implement the metadata service, we construct the registry based on the configuration
153+
// Always replace the registry with the settings defined in the configuration
154154
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
155155
if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok {
156-
if _, ok := newAgentRegistry[taskType]; !ok {
157-
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
158-
newAgentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
159-
}
156+
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
157+
newAgentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
160158
}
161159
}
162160

flyteplugins/go/tasks/plugins/webapi/agent/client_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ import (
55
"testing"
66

77
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/mock"
9+
10+
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"
11+
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
812
)
913

1014
func TestInitializeClients(t *testing.T) {
@@ -26,3 +30,61 @@ func TestInitializeClients(t *testing.T) {
2630
_, ok = cs.asyncAgentClients["x"]
2731
assert.True(t, ok)
2832
}
33+
34+
func TestAgentForTaskTypesAlwaysOverwrite(t *testing.T) {
35+
deploymentX := Deployment{Endpoint: "x"}
36+
deploymentY := Deployment{Endpoint: "y"}
37+
deploymentZ := Deployment{Endpoint: "z"}
38+
cfg := defaultConfig
39+
cfg.AgentDeployments = map[string]*Deployment{
40+
"x": &deploymentX,
41+
"y": &deploymentY,
42+
"z": &deploymentZ,
43+
}
44+
cfg.AgentForTaskTypes = map[string]string{
45+
"task1": "x", // we expect the "task1" task type should always route to deploymentX
46+
}
47+
ctx := context.Background()
48+
err := SetConfig(&cfg)
49+
assert.NoError(t, err)
50+
cs := getAgentClientSets(ctx)
51+
52+
// let's mock the "ListAgent" behaviour for 3 deployments
53+
// they both have SupportedTaskTypes "task1"
54+
mockClientForDeploymentX := mocks.NewAgentMetadataServiceClient(t)
55+
mockClientForDeploymentY := mocks.NewAgentMetadataServiceClient(t)
56+
mockClientForDeploymentZ := mocks.NewAgentMetadataServiceClient(t)
57+
mockClientForDeploymentX.On("ListAgents", mock.Anything, mock.Anything).Return(&admin.ListAgentsResponse{
58+
Agents: []*admin.Agent{
59+
{
60+
Name: "agent1",
61+
SupportedTaskTypes: []string{"task1"},
62+
},
63+
},
64+
}, nil)
65+
mockClientForDeploymentY.On("ListAgents", mock.Anything, mock.Anything).Return(&admin.ListAgentsResponse{
66+
Agents: []*admin.Agent{
67+
{
68+
Name: "agent2",
69+
SupportedTaskTypes: []string{"task1"},
70+
},
71+
},
72+
}, nil)
73+
mockClientForDeploymentZ.On("ListAgents", mock.Anything, mock.Anything).Return(&admin.ListAgentsResponse{
74+
Agents: []*admin.Agent{
75+
{
76+
Name: "agent3",
77+
SupportedTaskTypes: []string{"task1"},
78+
},
79+
},
80+
}, nil)
81+
cs.agentMetadataClients[deploymentX.Endpoint] = mockClientForDeploymentX
82+
cs.agentMetadataClients[deploymentY.Endpoint] = mockClientForDeploymentY
83+
cs.agentMetadataClients[deploymentZ.Endpoint] = mockClientForDeploymentZ
84+
// while auto-discovery execute in getAgentRegistry function, the deployment of task1 will be amended to deploymentZ
85+
// but the always-overwrite policy will overwrite deployment of task1 back to deploymentX according to cfg.AgentForTaskTypes
86+
registry := getAgentRegistry(ctx, cs)
87+
finalDeployment := registry["task1"][defaultTaskTypeVersion].AgentDeployment
88+
expectedDeployment := &deploymentX
89+
assert.Equal(t, finalDeployment, expectedDeployment)
90+
}

flyteplugins/go/tasks/plugins/webapi/connector/client.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,11 @@ func getConnectorRegistry(ctx context.Context, cs *ClientSet) Registry {
149149
strings.Join(maps.Keys(connectorSupportedTaskCategories), ", "))
150150
}
151151

152-
// If the connector doesn't implement the metadata service, we construct the registry based on the configuration
152+
// Always replace the connector registry with the settings defined in the configuration
153153
for taskType, connectorDeploymentID := range cfg.ConnectorForTaskTypes {
154154
if connectorDeployment, ok := cfg.ConnectorDeployments[connectorDeploymentID]; ok {
155-
if _, ok := newConnectorRegistry[taskType]; !ok {
156-
connector := &Connector{ConnectorDeployment: connectorDeployment, IsSync: false}
157-
newConnectorRegistry[taskType] = map[int32]*Connector{defaultTaskTypeVersion: connector}
158-
}
155+
connector := &Connector{ConnectorDeployment: connectorDeployment, IsSync: false}
156+
newConnectorRegistry[taskType] = map[int32]*Connector{defaultTaskTypeVersion: connector}
159157
}
160158
}
161159

flyteplugins/go/tasks/plugins/webapi/connector/client_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ import (
55
"testing"
66

77
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/mock"
9+
10+
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"
11+
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
812
)
913

1014
func TestInitializeClients(t *testing.T) {
@@ -26,3 +30,61 @@ func TestInitializeClients(t *testing.T) {
2630
_, ok = cs.asyncConnectorClients["x"]
2731
assert.True(t, ok)
2832
}
33+
34+
func TestAgentForTaskTypesAlwaysOverwrite(t *testing.T) {
35+
deploymentX := Deployment{Endpoint: "x"}
36+
deploymentY := Deployment{Endpoint: "y"}
37+
deploymentZ := Deployment{Endpoint: "z"}
38+
cfg := defaultConfig
39+
cfg.ConnectorDeployments = map[string]*Deployment{
40+
"x": &deploymentX,
41+
"y": &deploymentY,
42+
"z": &deploymentZ,
43+
}
44+
cfg.ConnectorForTaskTypes = map[string]string{
45+
"task1": "x", // we expect the "task1" task type should always route to deploymentX
46+
}
47+
ctx := context.Background()
48+
err := SetConfig(&cfg)
49+
assert.NoError(t, err)
50+
cs := getConnectorClientSets(ctx)
51+
52+
// let's mock the "ListAgent" behaviour for 3 deployments
53+
// they both have SupportedTaskTypes "task1"
54+
mockClientForDeploymentX := mocks.NewAgentMetadataServiceClient(t)
55+
mockClientForDeploymentY := mocks.NewAgentMetadataServiceClient(t)
56+
mockClientForDeploymentZ := mocks.NewAgentMetadataServiceClient(t)
57+
mockClientForDeploymentX.On("ListAgents", mock.Anything, mock.Anything).Return(&admin.ListAgentsResponse{
58+
Agents: []*admin.Agent{
59+
{
60+
Name: "connector1",
61+
SupportedTaskTypes: []string{"task1"},
62+
},
63+
},
64+
}, nil)
65+
mockClientForDeploymentY.On("ListAgents", mock.Anything, mock.Anything).Return(&admin.ListAgentsResponse{
66+
Agents: []*admin.Agent{
67+
{
68+
Name: "connector2",
69+
SupportedTaskTypes: []string{"task1"},
70+
},
71+
},
72+
}, nil)
73+
mockClientForDeploymentZ.On("ListAgents", mock.Anything, mock.Anything).Return(&admin.ListAgentsResponse{
74+
Agents: []*admin.Agent{
75+
{
76+
Name: "connector3",
77+
SupportedTaskTypes: []string{"task1"},
78+
},
79+
},
80+
}, nil)
81+
cs.connectorMetadataClients[deploymentX.Endpoint] = mockClientForDeploymentX
82+
cs.connectorMetadataClients[deploymentY.Endpoint] = mockClientForDeploymentY
83+
cs.connectorMetadataClients[deploymentZ.Endpoint] = mockClientForDeploymentZ
84+
// while auto-discovery execute in getAgentRegistry function, the deployment of task1 will be amended to deploymentZ
85+
// but the always-overwrite policy will overwrite deployment of task1 back to deploymentX according to cfg.AgentForTaskTypes
86+
registry := getConnectorRegistry(ctx, cs)
87+
finalDeployment := registry["task1"][defaultTaskTypeVersion].ConnectorDeployment
88+
expectedDeployment := &deploymentX
89+
assert.Equal(t, finalDeployment, expectedDeployment)
90+
}

0 commit comments

Comments
 (0)