Skip to content

Commit 4745e69

Browse files
authored
Pass literals by URI for LiteralsToLaunchFormJson (#7502)
Signed-off-by: Katrina Rogan <katroganGH@gmail.com>
1 parent a680e61 commit 4745e69

19 files changed

Lines changed: 1608 additions & 1144 deletions

dataproxy/service/dataproxy_service.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ func (s *Service) GetActionData(
500500
return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read inputs from %s: %w", inputRef, err))
501501
}
502502
} else {
503+
resp.InputsUri = urisResp.Msg.GetInputsUri()
503504
logger.Debugf(groupCtx, "Read %d input literals and %d action contexts", len(resp.Inputs.Literals), len(resp.Inputs.Context))
504505
}
505506
return nil
@@ -523,6 +524,7 @@ func (s *Service) GetActionData(
523524
resp.Outputs = &task.Outputs{
524525
Literals: inputsOrOutputs.GetLiterals(),
525526
}
527+
resp.OutputsUri = urisResp.Msg.GetOutputsUri()
526528
logger.Debugf(groupCtx, "Read %d output literals", len(resp.Outputs.Literals))
527529
}
528530
return nil

dataproxy/service/dataproxy_service_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,9 +653,15 @@ func TestGetActionData(t *testing.T) {
653653
assert.Len(t, resp.Msg.GetOutputs().GetLiterals(), tt.expectOutputsLen)
654654
if tt.expectInputsLen > 0 {
655655
assert.Equal(t, "x", resp.Msg.GetInputs().GetLiterals()[0].GetName())
656+
assert.Equal(t, tt.inputsURI, resp.Msg.GetInputsUri())
657+
} else {
658+
assert.Empty(t, resp.Msg.GetInputsUri())
656659
}
657660
if tt.expectOutputsLen > 0 {
658661
assert.Equal(t, "o", resp.Msg.GetOutputs().GetLiterals()[0].GetName())
662+
assert.Equal(t, tt.outputsURI, resp.Msg.GetOutputsUri())
663+
} else {
664+
assert.Empty(t, resp.Msg.GetOutputsUri())
659665
}
660666
})
661667
}

dataproxy/setup.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func Setup(ctx context.Context, sc *app.SetupContext) error {
6767
sc.Mux.Handle(clusterPath, clusterHandler)
6868
logger.Infof(ctx, "Mounted ClusterService at %s", clusterPath)
6969

70-
translatorSvc := NewTranslatorService()
70+
translatorSvc := NewTranslatorService(sc.DataStore, runClient)
7171
translatorPath, translatorHandler := workflowconnect.NewTranslatorServiceHandler(translatorSvc, connect.WithInterceptors(otelInterceptor))
7272
sc.Mux.Handle(translatorPath, translatorHandler)
7373
logger.Infof(ctx, "Mounted TranslatorService at %s", translatorPath)

dataproxy/translator.go

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ package dataproxy
22

33
import (
44
"context"
5+
"fmt"
56

67
"connectrpc.com/connect"
78

89
"github.com/flyteorg/flyte/v2/dataproxy/converter"
10+
"github.com/flyteorg/flyte/v2/flytestdlib/storage"
11+
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task"
912
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
1013
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect"
1114
)
@@ -15,10 +18,16 @@ import (
1518
// binary so that translation requests do not transit the control plane.
1619
type TranslatorService struct {
1720
workflowconnect.UnimplementedTranslatorServiceHandler
21+
22+
dataStore *storage.DataStore
23+
runClient workflowconnect.RunServiceClient
1824
}
1925

20-
func NewTranslatorService() *TranslatorService {
21-
return &TranslatorService{}
26+
func NewTranslatorService(dataStore *storage.DataStore, runClient workflowconnect.RunServiceClient) *TranslatorService {
27+
return &TranslatorService{
28+
dataStore: dataStore,
29+
runClient: runClient,
30+
}
2231
}
2332

2433
var _ workflowconnect.TranslatorServiceHandler = (*TranslatorService)(nil)
@@ -27,13 +36,55 @@ func (s *TranslatorService) LiteralsToLaunchFormJson(
2736
ctx context.Context,
2837
req *connect.Request[workflow.LiteralsToLaunchFormJsonRequest],
2938
) (*connect.Response[workflow.LiteralsToLaunchFormJsonResponse], error) {
30-
schema, err := converter.LiteralsToLaunchFormJson(ctx, req.Msg.GetLiterals(), req.Msg.GetVariables())
39+
literals := req.Msg.GetLiterals()
40+
if req.Msg.GetLiteralsUri() != "" {
41+
var err error
42+
literals, err = s.readOffloadedLiterals(ctx, req.Msg)
43+
if err != nil {
44+
return nil, err
45+
}
46+
}
47+
schema, err := converter.LiteralsToLaunchFormJson(ctx, literals, req.Msg.GetVariables())
3148
if err != nil {
3249
return nil, err
3350
}
3451
return connect.NewResponse(&workflow.LiteralsToLaunchFormJsonResponse{Json: schema}), nil
3552
}
3653

54+
// readOffloadedLiterals reads action literals from the object store location named by
55+
// literals_uri. The URI must match one of the action's data URIs reported by RunService,
56+
// which both authorizes the read (RunService checks access to the action) and prevents
57+
// arbitrary storage paths from being supplied.
58+
func (s *TranslatorService) readOffloadedLiterals(
59+
ctx context.Context,
60+
req *workflow.LiteralsToLaunchFormJsonRequest,
61+
) ([]*task.NamedLiteral, error) {
62+
actionID := req.GetActionId()
63+
if actionID == nil {
64+
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("action_id is required when literals_uri is set"))
65+
}
66+
67+
urisResp, err := s.runClient.GetActionDataURIs(ctx, connect.NewRequest(&workflow.GetActionDataURIsRequest{
68+
ActionId: actionID,
69+
}))
70+
if err != nil {
71+
return nil, err
72+
}
73+
74+
uri := req.GetLiteralsUri()
75+
if uri != urisResp.Msg.GetInputsUri() && uri != urisResp.Msg.GetOutputsUri() {
76+
return nil, connect.NewError(connect.CodeInvalidArgument,
77+
fmt.Errorf("literals_uri does not match any data URI of action %s", actionID.GetName()))
78+
}
79+
80+
// Both inputs.pb and outputs.pb deserialize as task.Inputs (a NamedLiteral list).
81+
var inputsOrOutputs task.Inputs
82+
if err := s.dataStore.ReadProtobuf(ctx, storage.DataReference(uri), &inputsOrOutputs); err != nil {
83+
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read literals from %s: %w", uri, err))
84+
}
85+
return inputsOrOutputs.GetLiterals(), nil
86+
}
87+
3788
func (s *TranslatorService) LaunchFormJsonToLiterals(
3889
ctx context.Context,
3990
req *connect.Request[workflow.LaunchFormJsonToLiteralsRequest],

dataproxy/translator_test.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package dataproxy
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"connectrpc.com/connect"
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/mock"
10+
"github.com/stretchr/testify/require"
11+
"google.golang.org/protobuf/proto"
12+
13+
"github.com/flyteorg/flyte/v2/flytestdlib/storage"
14+
storageMocks "github.com/flyteorg/flyte/v2/flytestdlib/storage/mocks"
15+
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
16+
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
17+
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task"
18+
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
19+
workflowMocks "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect/mocks"
20+
)
21+
22+
func testActionID() *common.ActionIdentifier {
23+
return &common.ActionIdentifier{
24+
Run: &common.RunIdentifier{
25+
Org: "org",
26+
Project: "proj",
27+
Domain: "dev",
28+
Name: "run1",
29+
},
30+
Name: "a0",
31+
}
32+
}
33+
34+
func testNamedLiterals() []*task.NamedLiteral {
35+
return []*task.NamedLiteral{
36+
{
37+
Name: "test",
38+
Value: &core.Literal{
39+
Value: &core.Literal_Scalar{
40+
Scalar: &core.Scalar{
41+
Value: &core.Scalar_Primitive{
42+
Primitive: &core.Primitive{
43+
Value: &core.Primitive_StringValue{StringValue: "hello world"},
44+
},
45+
},
46+
},
47+
},
48+
},
49+
},
50+
}
51+
}
52+
53+
func testVariableMap() *core.VariableMap {
54+
return &core.VariableMap{
55+
Variables: []*core.VariableEntry{
56+
{
57+
Key: "test",
58+
Value: &core.Variable{
59+
Type: &core.LiteralType{
60+
Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING},
61+
},
62+
},
63+
},
64+
},
65+
}
66+
}
67+
68+
func assertHelloWorldSchema(t *testing.T, resp *connect.Response[workflow.LiteralsToLaunchFormJsonResponse]) {
69+
t.Helper()
70+
schema := resp.Msg.GetJson().AsMap()
71+
properties, ok := schema["properties"].(map[string]any)
72+
require.True(t, ok)
73+
testField, ok := properties["test"].(map[string]any)
74+
require.True(t, ok)
75+
assert.Equal(t, "hello world", testField["default"])
76+
}
77+
78+
func TestLiteralsToLaunchFormJson_Inline(t *testing.T) {
79+
svc := NewTranslatorService(nil, nil)
80+
81+
resp, err := svc.LiteralsToLaunchFormJson(context.Background(), connect.NewRequest(&workflow.LiteralsToLaunchFormJsonRequest{
82+
Literals: testNamedLiterals(),
83+
Variables: testVariableMap(),
84+
}))
85+
86+
require.NoError(t, err)
87+
assertHelloWorldSchema(t, resp)
88+
}
89+
90+
func TestLiteralsToLaunchFormJson_OffloadedURI(t *testing.T) {
91+
inputsURI := "s3://test-bucket/metadata/proj/dev/run1/a0/inputs.pb"
92+
storedInputs := &task.Inputs{Literals: testNamedLiterals()}
93+
94+
runClient := workflowMocks.NewRunServiceClient(t)
95+
runClient.EXPECT().GetActionDataURIs(mock.Anything, mock.Anything).Return(
96+
connect.NewResponse(&workflow.GetActionDataURIsResponse{
97+
InputsUri: inputsURI,
98+
OutputsUri: "s3://test-bucket/metadata/proj/dev/run1/a0/outputs.pb",
99+
}), nil)
100+
101+
mockComposedStore := storageMocks.NewComposedProtobufStore(t)
102+
mockComposedStore.On("ReadProtobuf", mock.Anything, storage.DataReference(inputsURI), mock.Anything).
103+
Run(func(args mock.Arguments) {
104+
msg := args.Get(2).(proto.Message)
105+
proto.Reset(msg)
106+
proto.Merge(msg, storedInputs)
107+
}).Return(nil)
108+
109+
svc := NewTranslatorService(&storage.DataStore{ComposedProtobufStore: mockComposedStore}, runClient)
110+
111+
resp, err := svc.LiteralsToLaunchFormJson(context.Background(), connect.NewRequest(&workflow.LiteralsToLaunchFormJsonRequest{
112+
Variables: testVariableMap(),
113+
LiteralsUri: inputsURI,
114+
ActionId: testActionID(),
115+
}))
116+
117+
require.NoError(t, err)
118+
assertHelloWorldSchema(t, resp)
119+
}
120+
121+
func TestLiteralsToLaunchFormJson_OffloadedURI_MissingActionId(t *testing.T) {
122+
svc := NewTranslatorService(nil, nil)
123+
124+
_, err := svc.LiteralsToLaunchFormJson(context.Background(), connect.NewRequest(&workflow.LiteralsToLaunchFormJsonRequest{
125+
Variables: testVariableMap(),
126+
LiteralsUri: "s3://test-bucket/metadata/proj/dev/run1/a0/inputs.pb",
127+
}))
128+
129+
require.Error(t, err)
130+
assert.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
131+
assert.Contains(t, err.Error(), "action_id is required")
132+
}
133+
134+
func TestLiteralsToLaunchFormJson_OffloadedURI_Mismatch(t *testing.T) {
135+
runClient := workflowMocks.NewRunServiceClient(t)
136+
runClient.EXPECT().GetActionDataURIs(mock.Anything, mock.Anything).Return(
137+
connect.NewResponse(&workflow.GetActionDataURIsResponse{
138+
InputsUri: "s3://test-bucket/metadata/proj/dev/run1/a0/inputs.pb",
139+
OutputsUri: "s3://test-bucket/metadata/proj/dev/run1/a0/outputs.pb",
140+
}), nil)
141+
142+
svc := NewTranslatorService(nil, runClient)
143+
144+
_, err := svc.LiteralsToLaunchFormJson(context.Background(), connect.NewRequest(&workflow.LiteralsToLaunchFormJsonRequest{
145+
Variables: testVariableMap(),
146+
LiteralsUri: "s3://test-bucket/some/other/object.pb",
147+
ActionId: testActionID(),
148+
}))
149+
150+
require.Error(t, err)
151+
assert.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
152+
assert.Contains(t, err.Error(), "does not match any data URI")
153+
}

flyteidl2/dataproxy/dataproxy_service.proto

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,14 @@ message GetActionDataResponse {
193193

194194
// Outputs for the action.
195195
task.Outputs outputs = 2;
196+
197+
// Raw object store URI (e.g. s3://bucket/...) for the action's offloaded inputs.
198+
// Empty if the inputs are only available inline (e.g. condition actions).
199+
string inputs_uri = 3;
200+
201+
// Raw object store URI for the action's offloaded outputs.
202+
// Empty if the action hasn't succeeded, has no outputs, or the outputs are only available inline.
203+
string outputs_uri = 4;
196204
}
197205

198206
// Request message for tailing logs.

flyteidl2/workflow/translator_service.proto

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ syntax = "proto3";
22

33
package flyteidl2.workflow;
44

5+
import "flyteidl2/common/identifier.proto";
56
import "flyteidl2/core/interface.proto";
67
import "flyteidl2/task/common.proto";
78
import "flyteidl2/task/task_definition.proto";
@@ -26,9 +27,19 @@ service TranslatorService {
2627
}
2728

2829
message LiteralsToLaunchFormJsonRequest {
29-
// The literals to convert to JSON.
30+
// The literals to convert to JSON. Ignored when literals_uri is set.
3031
repeated task.NamedLiteral literals = 1;
3132
flyteidl2.core.VariableMap variables = 2;
33+
34+
// Raw object store URI for offloaded action literals, as returned by
35+
// DataProxyService.GetActionData (inputs_uri / outputs_uri). When set, the
36+
// service reads the literals from storage instead of requiring them inline,
37+
// avoiding a round trip of potentially large payloads through the client.
38+
string literals_uri = 3;
39+
40+
// Identifies the action that owns literals_uri. Required when literals_uri is
41+
// set; the URI is validated against the action's data URIs before being read.
42+
common.ActionIdentifier action_id = 4;
3243
}
3344

3445
message LiteralsToLaunchFormJsonResponse {

0 commit comments

Comments
 (0)