diff --git a/backend/src/apiserver/main.go b/backend/src/apiserver/main.go index fc1331ee9c4..f5389f64f57 100644 --- a/backend/src/apiserver/main.go +++ b/backend/src/apiserver/main.go @@ -356,6 +356,12 @@ func startRPCServer(resourceManager *resource.ResourceManager, tlsCfg *tls.Confi ReportServerV1 := server.NewReportServerV1(resourceManager) ReportServer := server.NewReportServer(resourceManager) + VisualizationServerV1 := server.NewVisualizationServerV1(resourceManager) + VisualizationServer := server.NewVisualizationServer(resourceManager) + + AuthServerV1 := server.NewAuthServerV1(resourceManager) + AuthServer := server.NewAuthServer(resourceManager) + apiv1beta1.RegisterExperimentServiceServer(s, ExperimentServerV1) apiv1beta1.RegisterPipelineServiceServer(s, PipelineServerV1) apiv1beta1.RegisterJobServiceServer(s, JobServerV1) @@ -363,20 +369,16 @@ func startRPCServer(resourceManager *resource.ResourceManager, tlsCfg *tls.Confi apiv1beta1.RegisterTaskServiceServer(s, server.NewTaskServer(resourceManager)) apiv1beta1.RegisterReportServiceServer(s, ReportServerV1) - apiv1beta1.RegisterVisualizationServiceServer( - s, - server.NewVisualizationServer( - resourceManager, - common.GetStringConfig(cm.VisualizationServiceHost), - common.GetStringConfig(cm.VisualizationServicePort), - )) - apiv1beta1.RegisterAuthServiceServer(s, server.NewAuthServer(resourceManager)) + apiv1beta1.RegisterVisualizationServiceServer(s, VisualizationServerV1) + apiv1beta1.RegisterAuthServiceServer(s, AuthServerV1) apiv2beta1.RegisterExperimentServiceServer(s, ExperimentServer) apiv2beta1.RegisterPipelineServiceServer(s, PipelineServer) apiv2beta1.RegisterRecurringRunServiceServer(s, JobServer) apiv2beta1.RegisterRunServiceServer(s, RunServer) apiv2beta1.RegisterReportServiceServer(s, ReportServer) + apiv2beta1.RegisterAuthServiceServer(s, AuthServer) + apiv2beta1.RegisterVisualizationServiceServer(s, VisualizationServer) // Register reflection service on gRPC server. reflection.Register(s) @@ -424,6 +426,8 @@ func startHTTPProxy(resourceManager *resource.ResourceManager, usePipelinesKuber register(apiv2beta1.RegisterRecurringRunServiceHandlerFromEndpoint, "RecurringRunService") register(apiv2beta1.RegisterRunServiceHandlerFromEndpoint, "RunService") register(apiv2beta1.RegisterReportServiceHandlerFromEndpoint, "ReportService") + register(apiv2beta1.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization") + register(apiv2beta1.RegisterAuthServiceHandlerFromEndpoint, "AuthService") sharedPipelineUploadServer := server.NewPipelineUploadServer(resourceManager, &server.PipelineUploadServerOptions{CollectMetrics: *collectMetricsFlag}) runLogServer := server.NewRunLogServer(resourceManager) diff --git a/backend/src/apiserver/server/auth_server.go b/backend/src/apiserver/server/auth_server.go index cc124053009..eaab6a377b1 100644 --- a/backend/src/apiserver/server/auth_server.go +++ b/backend/src/apiserver/server/auth_server.go @@ -19,8 +19,8 @@ import ( "strings" apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" - api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/apiserver/resource" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -37,12 +37,45 @@ var rbacResourceTypeToGroup = map[string]string{ common.RbacResourceTypeVisualizations: common.RbacPipelinesGroup, } -type AuthServer struct { +type AuthServerV1 struct { resourceManager *resource.ResourceManager apiv1beta1.UnimplementedAuthServiceServer } -func (s *AuthServer) AuthorizeV1(ctx context.Context, request *api.AuthorizeRequest) ( +type AuthServer struct { + resourceManager *resource.ResourceManager + apiv2beta1.UnimplementedAuthServiceServer +} + +func (s *AuthServerV1) AuthorizeV1(ctx context.Context, request *apiv1beta1.AuthorizeRequest) ( + *emptypb.Empty, error, +) { + err := ValidateAuthorizeRequestV1(request) + if err != nil { + return nil, util.Wrap(err, "Authorize request is not valid") + } + + namespace := strings.ToLower(request.GetNamespace()) + verb := strings.ToLower(request.GetVerb().String()) + resource := strings.ToLower(request.GetResources().String()) + resourceAttributes := &authorizationv1.ResourceAttributes{ + Namespace: namespace, + Verb: verb, + Group: rbacResourceTypeToGroup[resource], + Version: common.RbacPipelinesVersion, + Resource: resource, + Subresource: "", + Name: "", + } + err = s.resourceManager.IsAuthorized(ctx, resourceAttributes) + if err != nil { + return nil, util.Wrap(err, "Failed to authorize the request") + } + + return &emptypb.Empty{}, nil +} + +func (s *AuthServer) Authorize(ctx context.Context, request *apiv2beta1.AuthorizeRequest) ( *emptypb.Empty, error, ) { err := ValidateAuthorizeRequest(request) @@ -70,22 +103,40 @@ func (s *AuthServer) AuthorizeV1(ctx context.Context, request *api.AuthorizeRequ return &emptypb.Empty{}, nil } -func ValidateAuthorizeRequest(request *api.AuthorizeRequest) error { +func ValidateAuthorizeRequestV1(request *apiv1beta1.AuthorizeRequest) error { + if request == nil { + return util.NewInvalidInputError("request object is empty") + } + if len(request.Namespace) == 0 { + return util.NewInvalidInputError("Namespace is empty. Please specify a valid namespace") + } + if request.Resources == apiv1beta1.AuthorizeRequest_UNASSIGNED_RESOURCES { + return util.NewInvalidInputError("Resources not specified. Please specify a valid resources") + } + if request.Verb == apiv1beta1.AuthorizeRequest_UNASSIGNED_VERB { + return util.NewInvalidInputError("Verb not specified. Please specify a valid verb") + } + return nil +} +func ValidateAuthorizeRequest(request *apiv2beta1.AuthorizeRequest) error { if request == nil { return util.NewInvalidInputError("request object is empty") } if len(request.Namespace) == 0 { return util.NewInvalidInputError("Namespace is empty. Please specify a valid namespace") } - if request.Resources == api.AuthorizeRequest_UNASSIGNED_RESOURCES { + if request.Resources == apiv2beta1.AuthorizeRequest_UNASSIGNED_RESOURCES { return util.NewInvalidInputError("Resources not specified. Please specify a valid resources") } - if request.Verb == api.AuthorizeRequest_UNASSIGNED_VERB { + if request.Verb == apiv2beta1.AuthorizeRequest_UNASSIGNED_VERB { return util.NewInvalidInputError("Verb not specified. Please specify a valid verb") } return nil } +func NewAuthServerV1(resourceManager *resource.ResourceManager) *AuthServerV1 { + return &AuthServerV1{resourceManager: resourceManager} +} func NewAuthServer(resourceManager *resource.ResourceManager) *AuthServer { return &AuthServer{resourceManager: resourceManager} } diff --git a/backend/src/apiserver/server/auth_server_test.go b/backend/src/apiserver/server/auth_server_test.go index a6e0a2d5adb..a8cecab1935 100644 --- a/backend/src/apiserver/server/auth_server_test.go +++ b/backend/src/apiserver/server/auth_server_test.go @@ -19,6 +19,7 @@ import ( "testing" api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/spf13/viper" @@ -30,7 +31,7 @@ import ( func TestAuthorizeRequest_SingleUserMode(t *testing.T) { clients, manager, _ := initWithExperiment(t) defer clients.Close() - authServer := AuthServer{resourceManager: manager} + authServer := AuthServerV1{resourceManager: manager} clients.SubjectAccessReviewClientFake = client.NewFakeSubjectAccessReviewClientUnauthorized() md := metadata.New(map[string]string{}) @@ -53,7 +54,7 @@ func TestAuthorizeRequest_InvalidRequest(t *testing.T) { clients, manager, _ := initWithExperiment(t) defer clients.Close() - authServer := AuthServer{resourceManager: manager} + authServer := AuthServerV1{resourceManager: manager} md := metadata.New(map[string]string{}) ctx := metadata.NewIncomingContext(context.Background(), md) @@ -75,7 +76,7 @@ func TestAuthorizeRequest_Authorized(t *testing.T) { clients, manager, _ := initWithExperiment(t) defer clients.Close() - authServer := AuthServer{resourceManager: manager} + authServer := AuthServerV1{resourceManager: manager} md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: "accounts.google.com:user@google.com"}) ctx := metadata.NewIncomingContext(context.Background(), md) @@ -96,7 +97,7 @@ func TestAuthorizeRequest_Unauthorized(t *testing.T) { clients, manager, _ := initWithExperiment_SubjectAccessReview_Unauthorized(t) defer clients.Close() - authServer := AuthServer{resourceManager: manager} + authServer := AuthServerV1{resourceManager: manager} userIdentity := "user@google.com" md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: common.GoogleIAPUserIdentityPrefix + userIdentity}) @@ -129,7 +130,7 @@ func TestAuthorizeRequest_EmptyUserIdPrefix(t *testing.T) { clients, manager, _ := initWithExperiment(t) defer clients.Close() - authServer := AuthServer{resourceManager: manager} + authServer := AuthServerV1{resourceManager: manager} md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: "user@google.com"}) ctx := metadata.NewIncomingContext(context.Background(), md) @@ -150,7 +151,7 @@ func TestAuthorizeRequest_Unauthenticated(t *testing.T) { clients, manager, _ := initWithExperiment(t) defer clients.Close() - authServer := AuthServer{resourceManager: manager} + authServer := AuthServerV1{resourceManager: manager} md := metadata.New(map[string]string{"no-identity-header": "user"}) ctx := metadata.NewIncomingContext(context.Background(), md) @@ -169,3 +170,146 @@ func TestAuthorizeRequest_Unauthenticated(t *testing.T) { "there is no user identity header", ) } + +func TestAuthorizeV2Request_SingleUserMode(t *testing.T) { + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + authServer := AuthServer{resourceManager: manager} + clients.SubjectAccessReviewClientFake = client.NewFakeSubjectAccessReviewClientUnauthorized() + + md := metadata.New(map[string]string{}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + request := &apiv2beta1.AuthorizeRequest{ + Namespace: "ns1", + Resources: apiv2beta1.AuthorizeRequest_VIEWERS, + Verb: apiv2beta1.AuthorizeRequest_GET, + } + + _, err := authServer.Authorize(ctx, request) + // Authz is completely skipped without checking anything. + assert.Nil(t, err) +} + +func TestAuthorizeV2Request_InvalidRequest(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + authServer := AuthServer{resourceManager: manager} + + md := metadata.New(map[string]string{}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + request := &apiv2beta1.AuthorizeRequest{ + Namespace: "", + Resources: apiv2beta1.AuthorizeRequest_UNASSIGNED_RESOURCES, + Verb: apiv2beta1.AuthorizeRequest_UNASSIGNED_VERB, + } + + _, err := authServer.Authorize(ctx, request) + assert.Error(t, err) + assert.EqualError(t, err, "Authorize request is not valid: Invalid input error: Namespace is empty. Please specify a valid namespace") +} + +func TestAuthorizeV2Request_Authorized(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + authServer := AuthServer{resourceManager: manager} + + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: "accounts.google.com:user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + request := &apiv2beta1.AuthorizeRequest{ + Namespace: "ns1", + Resources: apiv2beta1.AuthorizeRequest_VIEWERS, + Verb: apiv2beta1.AuthorizeRequest_GET, + } + + _, err := authServer.Authorize(ctx, request) + assert.Nil(t, err) +} + +func TestAuthorizeV2Request_Unauthorized(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + clients, manager, _ := initWithExperiment_SubjectAccessReview_Unauthorized(t) + defer clients.Close() + authServer := AuthServer{resourceManager: manager} + + userIdentity := "user@google.com" + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: common.GoogleIAPUserIdentityPrefix + userIdentity}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + request := &apiv2beta1.AuthorizeRequest{ + Namespace: "ns1", + Resources: apiv2beta1.AuthorizeRequest_VIEWERS, + Verb: apiv2beta1.AuthorizeRequest_GET, + } + + _, err := authServer.Authorize(ctx, request) + assert.Error(t, err) + + resourceAttributes := &authorizationv1.ResourceAttributes{ + Namespace: "ns1", + Verb: common.RbacResourceVerbGet, + Group: common.RbacKubeflowGroup, + Version: common.RbacPipelinesVersion, + Resource: common.RbacResourceTypeViewers, + } + assert.EqualError(t, err, wrapFailedAuthzRequestError(getPermissionDeniedError(userIdentity, resourceAttributes)).Error()) +} + +func TestAuthorizeV2Request_EmptyUserIdPrefix(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + viper.Set(common.KubeflowUserIDPrefix, "") + defer viper.Set(common.KubeflowUserIDPrefix, common.GoogleIAPUserIdentityPrefix) + + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + authServer := AuthServer{resourceManager: manager} + + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: "user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + request := &apiv2beta1.AuthorizeRequest{ + Namespace: "ns1", + Resources: apiv2beta1.AuthorizeRequest_VIEWERS, + Verb: apiv2beta1.AuthorizeRequest_GET, + } + + _, err := authServer.Authorize(ctx, request) + assert.Nil(t, err) +} + +func TestAuthorizeV2Request_Unauthenticated(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + authServer := AuthServer{resourceManager: manager} + + md := metadata.New(map[string]string{"no-identity-header": "user"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + request := &apiv2beta1.AuthorizeRequest{ + Namespace: "ns1", + Resources: apiv2beta1.AuthorizeRequest_VIEWERS, + Verb: apiv2beta1.AuthorizeRequest_GET, + } + + _, err := authServer.Authorize(ctx, request) + assert.NotNil(t, err) + assert.Contains( + t, + err.Error(), + "there is no user identity header", + ) +} diff --git a/backend/src/apiserver/server/visualization_server.go b/backend/src/apiserver/server/visualization_server.go index ed6fbf3964f..eee5192b3f7 100644 --- a/backend/src/apiserver/server/visualization_server.go +++ b/backend/src/apiserver/server/visualization_server.go @@ -19,14 +19,15 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/url" "strings" apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" "github.com/golang/glog" - "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/apiserver/resource" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -39,19 +40,50 @@ const ( visualizationServicePort = "VisualizationService.Port" ) -type VisualizationServer struct { +func buildVisualizationServiceURL(namespace string) string { + host := common.GetStringConfig(visualizationServiceName) + if common.IsMultiUserMode() && len(namespace) > 0 { + host = fmt.Sprintf("%s.%s", host, namespace) + } + u := &url.URL{ + Scheme: "http", + Host: net.JoinHostPort(host, common.GetStringConfig(visualizationServicePort)), + } + return u.String() +} + +func isVisualizationServiceAlive(serviceURL string) error { + resp, err := http.Get(serviceURL) + if err != nil { + wrappedErr := util.Wrapf(err, "Unable to verify visualization service aliveness by sending request to %s", serviceURL) + glog.Error(wrappedErr) + return wrappedErr + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + wrappedErr := errors.New(fmt.Sprintf("Unable to verify visualization service aliveness by sending request to %s and got response code: %s", serviceURL, resp.Status)) + glog.Error(wrappedErr) + return wrappedErr + } + return nil +} + +type VisualizationServerV1 struct { resourceManager *resource.ResourceManager - serviceURL string apiv1beta1.UnimplementedVisualizationServiceServer } -func (s *VisualizationServer) CreateVisualizationV1(ctx context.Context, request *go_client.CreateVisualizationRequest) (*go_client.Visualization, error) { +func NewVisualizationServerV1(resourceManager *resource.ResourceManager) *VisualizationServerV1 { + return &VisualizationServerV1{resourceManager: resourceManager} +} + +func (s *VisualizationServerV1) CreateVisualizationV1(ctx context.Context, request *apiv1beta1.CreateVisualizationRequest) (*apiv1beta1.Visualization, error) { if err := s.validateCreateVisualizationRequest(request); err != nil { return nil, err } - // In multi-user mode, we allow empty namespace in which case we fall back to use the visualization service in system namespace. - // See getVisualizationServiceURL() for details. + // In multi-user mode, allow empty namespace falls back to the + // visualization service running in the system namespace. if common.IsMultiUserMode() && len(request.Namespace) > 0 { resourceAttributes := &authorizationv1.ResourceAttributes{ Namespace: request.Namespace, @@ -62,13 +94,12 @@ func (s *VisualizationServer) CreateVisualizationV1(ctx context.Context, request Subresource: "", Name: "", } - err := s.resourceManager.IsAuthorized(ctx, resourceAttributes) - if err != nil { + if err := s.resourceManager.IsAuthorized(ctx, resourceAttributes); err != nil { return nil, util.Wrap(err, "Failed to authorize on namespace") } } - body, err := s.generateVisualizationFromRequest(request) + body, err := s.generateVisualization(request) if err != nil { return nil, err } @@ -76,22 +107,12 @@ func (s *VisualizationServer) CreateVisualizationV1(ctx context.Context, request return request.Visualization, nil } -// validateCreateVisualizationRequest ensures that a go_client.Visualization -// object has valid values. -// It returns an error if a go_client.Visualization object does not have valid -// values. -func (s *VisualizationServer) validateCreateVisualizationRequest(request *go_client.CreateVisualizationRequest) error { - // Only validate that a source is provided for non-custom visualizations. - if request.Visualization.Type != go_client.Visualization_CUSTOM { +func (s *VisualizationServerV1) validateCreateVisualizationRequest(request *apiv1beta1.CreateVisualizationRequest) error { + if request.Visualization.Type != apiv1beta1.Visualization_CUSTOM { if len(request.Visualization.Source) == 0 { return util.NewInvalidInputError("A visualization requires a Source to be provided. Received %s", request.Visualization.Source) } } - // Manually set Arguments to empty JSON if nothing is provided. This is done - // because visualizations such as TFDV and TFMA only require a Source to - // be provided for a visualization to be generated. If no JSON is provided - // json.Valid will fail without this check as an empty string is provided for - // those visualizations. if len(request.Visualization.Arguments) == 0 { request.Visualization.Arguments = "{}" } @@ -101,15 +122,12 @@ func (s *VisualizationServer) validateCreateVisualizationRequest(request *go_cli return nil } -// generateVisualizationFromRequest communicates with the python visualization -// service to generate HTML visualizations from a request. -// It returns the generated HTML as a string and any error that is encountered. -func (s *VisualizationServer) generateVisualizationFromRequest(request *go_client.CreateVisualizationRequest) ([]byte, error) { - serviceURL := s.getVisualizationServiceURL(request) +func (s *VisualizationServerV1) generateVisualization(request *apiv1beta1.CreateVisualizationRequest) ([]byte, error) { + serviceURL := buildVisualizationServiceURL(request.Namespace) if err := isVisualizationServiceAlive(serviceURL); err != nil { return nil, util.Wrap(err, "Cannot generate visualization") } - visualizationType := strings.ToLower(go_client.Visualization_Type_name[int32(request.Visualization.Type)]) + visualizationType := strings.ToLower(apiv1beta1.Visualization_Type_name[int32(request.Visualization.Type)]) urlValues := url.Values{ "arguments": {request.Visualization.Arguments}, "source": {request.Visualization.Source}, @@ -119,10 +137,10 @@ func (s *VisualizationServer) generateVisualizationFromRequest(request *go_clien if err != nil { return nil, util.Wrap(err, "Unable to initialize visualization request") } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("%s", resp.Status) + return nil, fmt.Errorf("visualization service returned non-OK status: %s", resp.Status) } - defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, util.Wrap(err, "Unable to parse visualization response") @@ -130,35 +148,80 @@ func (s *VisualizationServer) generateVisualizationFromRequest(request *go_clien return body, nil } -func (s *VisualizationServer) getVisualizationServiceURL(request *go_client.CreateVisualizationRequest) string { +type VisualizationServer struct { + resourceManager *resource.ResourceManager + apiv2beta1.UnimplementedVisualizationServiceServer +} + +func NewVisualizationServer(resourceManager *resource.ResourceManager) *VisualizationServer { + return &VisualizationServer{resourceManager: resourceManager} +} + +func (s *VisualizationServer) CreateVisualization(ctx context.Context, request *apiv2beta1.CreateVisualizationRequest) (*apiv2beta1.Visualization, error) { + if err := s.validateCreateVisualizationRequest(request); err != nil { + return nil, err + } + if common.IsMultiUserMode() && len(request.Namespace) > 0 { - return fmt.Sprintf("http://%s.%s:%s", - common.GetStringConfig(visualizationServiceName), - request.Namespace, - common.GetStringConfig(visualizationServicePort)) + resourceAttributes := &authorizationv1.ResourceAttributes{ + Namespace: request.Namespace, + Verb: common.RbacResourceVerbCreate, + Group: common.RbacPipelinesGroup, + Version: common.RbacPipelinesVersion, + Resource: common.RbacResourceTypeVisualizations, + Subresource: "", + Name: "", + } + if err := s.resourceManager.IsAuthorized(ctx, resourceAttributes); err != nil { + return nil, util.Wrap(err, "Failed to authorize on namespace") + } } - return s.serviceURL -} -func isVisualizationServiceAlive(serviceURL string) error { - resp, err := http.Get(serviceURL) + body, err := s.generateVisualization(request) if err != nil { - wrappedErr := util.Wrapf(err, "Unable to verify visualization service aliveness by sending request to %s", serviceURL) - glog.Error(wrappedErr) - return wrappedErr - } else if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - wrappedErr := errors.New(fmt.Sprintf("Unable to verify visualization service aliveness by sending request to %s and get response code: %s !", serviceURL, resp.Status)) - glog.Error(wrappedErr) - return wrappedErr + return nil, err + } + request.Visualization.Html = string(body) + return request.Visualization, nil +} + +func (s *VisualizationServer) validateCreateVisualizationRequest(request *apiv2beta1.CreateVisualizationRequest) error { + if request.Visualization.Type != apiv2beta1.Visualization_CUSTOM { + if len(request.Visualization.Source) == 0 { + return util.NewInvalidInputError("A visualization requires a Source to be provided. Received %s", request.Visualization.Source) + } + } + if len(request.Visualization.Arguments) == 0 { + request.Visualization.Arguments = "{}" + } + if !json.Valid([]byte(request.Visualization.Arguments)) { + return util.NewInvalidInputError("A visualization requires valid JSON to be provided as Arguments. Received %s", request.Visualization.Arguments) } return nil } -func NewVisualizationServer(resourceManager *resource.ResourceManager, serviceHost string, servicePort string) *VisualizationServer { - serviceURL := fmt.Sprintf("http://%s:%s", serviceHost, servicePort) - return &VisualizationServer{ - resourceManager: resourceManager, - serviceURL: serviceURL, +func (s *VisualizationServer) generateVisualization(request *apiv2beta1.CreateVisualizationRequest) ([]byte, error) { + serviceURL := buildVisualizationServiceURL(request.Namespace) + if err := isVisualizationServiceAlive(serviceURL); err != nil { + return nil, util.Wrap(err, "Cannot generate visualization") } + visualizationType := strings.ToLower(apiv2beta1.Visualization_Type_name[int32(request.Visualization.Type)]) + urlValues := url.Values{ + "arguments": {request.Visualization.Arguments}, + "source": {request.Visualization.Source}, + "type": {visualizationType}, + } + resp, err := http.PostForm(serviceURL, urlValues) + if err != nil { + return nil, util.Wrap(err, "Unable to initialize visualization request") + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("visualization service returned non-OK status: %s", resp.Status) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, util.Wrap(err, "Unable to parse visualization response") + } + return body, nil } diff --git a/backend/src/apiserver/server/visualization_server_test.go b/backend/src/apiserver/server/visualization_server_test.go index da6ecfee1fc..1acd977b19c 100644 --- a/backend/src/apiserver/server/visualization_server_test.go +++ b/backend/src/apiserver/server/visualization_server_test.go @@ -16,12 +16,13 @@ package server import ( "context" - "fmt" + "net" "net/http" "net/http/httptest" "testing" apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/apiserver/resource" @@ -32,241 +33,448 @@ import ( authorizationv1 "k8s.io/api/authorization/v1" ) -func TestValidateCreateVisualizationRequest(t *testing.T) { - clients, manager, _ := initWithExperiment(t) - defer clients.Close() - server := &VisualizationServer{ - resourceManager: manager, - } - visualization := &apiv1beta1.Visualization{ - Type: apiv1beta1.Visualization_ROC_CURVE, - Source: "gs://ml-pipeline/roc/data.csv", - Arguments: "{}", - } - request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: visualization, +// startFakeVisualizationService spins up an httptest server that returns 200 +// for GET (liveness probe) and writes responseBody for POST (generate). +// It sets viper so that buildVisualizationServiceURL resolves to it. +// Callers must defer the returned closer. +func startFakeVisualizationService(t *testing.T, responseBody string, postStatusCode int) (close func()) { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + assert.Equal(t, "/", req.URL.String()) + if req.Method == http.MethodGet { + rw.WriteHeader(http.StatusOK) + return + } + rw.WriteHeader(postStatusCode) + if responseBody != "" { + rw.Write([]byte(responseBody)) + } + })) + + addr := srv.Listener.Addr().String() + host, port, _ := splitHostPort(addr) + viper.Set(visualizationServiceName, host) + viper.Set(visualizationServicePort, port) + + return func() { + srv.Close() + viper.Set(visualizationServiceName, "") + viper.Set(visualizationServicePort, "") } - err := server.validateCreateVisualizationRequest(request) - assert.Nil(t, err) } -func TestValidateCreateVisualizationRequest_ArgumentsAreEmpty(t *testing.T) { +func splitHostPort(hostport string) (host, port string, err error) { + return net.SplitHostPort(hostport) +} + +func newV1Server(t *testing.T) (*VisualizationServerV1, func() error) { + t.Helper() clients, manager, _ := initWithExperiment(t) - defer clients.Close() - server := &VisualizationServer{ - resourceManager: manager, - } - visualization := &apiv1beta1.Visualization{ - Type: apiv1beta1.Visualization_ROC_CURVE, - Source: "gs://ml-pipeline/roc/data.csv", - Arguments: "", - } - request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: visualization, - } - err := server.validateCreateVisualizationRequest(request) - assert.Nil(t, err) + return NewVisualizationServerV1(manager), clients.Close } -func TestValidateCreateVisualizationRequest_SourceIsEmpty(t *testing.T) { +func newV2Server(t *testing.T) (*VisualizationServer, func() error) { + t.Helper() clients, manager, _ := initWithExperiment(t) - defer clients.Close() - server := &VisualizationServer{ - resourceManager: manager, - } - visualization := &apiv1beta1.Visualization{ - Type: apiv1beta1.Visualization_ROC_CURVE, - Source: "", - Arguments: "{}", + return NewVisualizationServer(manager), clients.Close +} + +func TestBuildVisualizationServiceURL_SingleUser(t *testing.T) { + viper.Set(visualizationServiceName, "ml-pipeline-visualizationserver") + viper.Set(visualizationServicePort, "8888") + defer func() { + viper.Set(visualizationServiceName, "") + viper.Set(visualizationServicePort, "") + }() + + url := buildVisualizationServiceURL("") + assert.Equal(t, "http://ml-pipeline-visualizationserver:8888", url) +} + +func TestBuildVisualizationServiceURL_MultiuserWithNamespace(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + viper.Set(visualizationServiceName, "ml-pipeline-visualizationserver") + viper.Set(visualizationServicePort, "8888") + defer func() { + viper.Set(visualizationServiceName, "") + viper.Set(visualizationServicePort, "") + }() + + url := buildVisualizationServiceURL("ns1") + assert.Equal(t, "http://ml-pipeline-visualizationserver.ns1:8888", url) +} + +func TestBuildVisualizationServiceURL_MultiuserEmptyNamespaceFallsBackToServiceName(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + viper.Set(visualizationServiceName, "ml-pipeline-visualizationserver") + viper.Set(visualizationServicePort, "8888") + defer func() { + viper.Set(visualizationServiceName, "") + viper.Set(visualizationServicePort, "") + }() + + url := buildVisualizationServiceURL("") + assert.Equal(t, "http://ml-pipeline-visualizationserver:8888", url) +} + +func TestV1_ValidateCreateVisualizationRequest(t *testing.T) { + server, close := newV1Server(t) + defer close() + + request := &apiv1beta1.CreateVisualizationRequest{ + Visualization: &apiv1beta1.Visualization{ + Type: apiv1beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + }, } + assert.Nil(t, server.validateCreateVisualizationRequest(request)) +} + +func TestV1_ValidateCreateVisualizationRequest_ArgumentsAreEmpty(t *testing.T) { + server, close := newV1Server(t) + defer close() + + request := &apiv1beta1.CreateVisualizationRequest{ + Visualization: &apiv1beta1.Visualization{ + Type: apiv1beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "", + }, + } + assert.Nil(t, server.validateCreateVisualizationRequest(request)) + assert.Equal(t, "{}", request.Visualization.Arguments) +} + +func TestV1_ValidateCreateVisualizationRequest_SourceIsEmpty(t *testing.T) { + server, close := newV1Server(t) + defer close() + request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: visualization, + Visualization: &apiv1beta1.Visualization{ + Type: apiv1beta1.Visualization_ROC_CURVE, + Source: "", + Arguments: "{}", + }, } err := server.validateCreateVisualizationRequest(request) assert.Contains(t, err.Error(), "A visualization requires a Source to be provided. Received") } -func TestValidateCreateVisualizationRequest_SourceIsEmptyAndTypeIsCustom(t *testing.T) { - clients, manager, _ := initWithExperiment(t) - defer clients.Close() - server := &VisualizationServer{ - resourceManager: manager, - } - visualization := &apiv1beta1.Visualization{ - Type: apiv1beta1.Visualization_CUSTOM, - Arguments: "{}", - } +func TestV1_ValidateCreateVisualizationRequest_SourceIsEmptyAndTypeIsCustom(t *testing.T) { + server, close := newV1Server(t) + defer close() + request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: visualization, + Visualization: &apiv1beta1.Visualization{ + Type: apiv1beta1.Visualization_CUSTOM, + Arguments: "{}", + }, } - err := server.validateCreateVisualizationRequest(request) - assert.Nil(t, err) + assert.Nil(t, server.validateCreateVisualizationRequest(request)) } -func TestValidateCreateVisualizationRequest_ArgumentsNotValidJSON(t *testing.T) { - clients, manager, _ := initWithExperiment(t) - defer clients.Close() - server := &VisualizationServer{ - resourceManager: manager, - } - visualization := &apiv1beta1.Visualization{ - Type: apiv1beta1.Visualization_ROC_CURVE, - Source: "gs://ml-pipeline/roc/data.csv", - Arguments: "{", - } +func TestV1_ValidateCreateVisualizationRequest_ArgumentsNotValidJSON(t *testing.T) { + server, close := newV1Server(t) + defer close() + request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: visualization, + Visualization: &apiv1beta1.Visualization{ + Type: apiv1beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{", + }, } err := server.validateCreateVisualizationRequest(request) assert.Contains(t, err.Error(), "A visualization requires valid JSON to be provided as Arguments. Received {") } -func TestGenerateVisualization(t *testing.T) { - clients, manager, _ := initWithExperiment(t) - defer clients.Close() - httpServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - assert.Equal(t, "/", req.URL.String()) - rw.Write([]byte("roc_curve")) - })) - defer httpServer.Close() - server := &VisualizationServer{ - resourceManager: manager, - serviceURL: httpServer.URL, - } - visualization := &apiv1beta1.Visualization{ - Type: apiv1beta1.Visualization_ROC_CURVE, - Source: "gs://ml-pipeline/roc/data.csv", - Arguments: "{}", - } +func TestV1_GenerateVisualization(t *testing.T) { + server, close := newV1Server(t) + defer close() + stopService := startFakeVisualizationService(t, "roc_curve", http.StatusOK) + defer stopService() + request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: visualization, + Visualization: &apiv1beta1.Visualization{ + Type: apiv1beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + }, } - body, err := server.generateVisualizationFromRequest(request) + body, err := server.generateVisualization(request) assert.Nil(t, err) assert.Equal(t, []byte("roc_curve"), body) } -func TestGenerateVisualization_ServiceNotAvailableError(t *testing.T) { - clients, manager, _ := initWithExperiment(t) - defer clients.Close() - httpServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - assert.Equal(t, "/", req.URL.String()) - if req.Method == http.MethodGet { - rw.WriteHeader(500) - } else { - rw.WriteHeader(200) - } +func TestV1_GenerateVisualization_ServiceNotAvailableError(t *testing.T) { + server, close := newV1Server(t) + defer close() + + srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusInternalServerError) })) - server := &VisualizationServer{ - resourceManager: manager, - serviceURL: httpServer.URL, - } - visualization := &apiv1beta1.Visualization{ - Type: apiv1beta1.Visualization_ROC_CURVE, - Source: "gs://ml-pipeline/roc/data.csv", - Arguments: "{}", - } + defer srv.Close() + host, port, _ := splitHostPort(srv.Listener.Addr().String()) + viper.Set(visualizationServiceName, host) + viper.Set(visualizationServicePort, port) + defer func() { + viper.Set(visualizationServiceName, "") + viper.Set(visualizationServicePort, "") + }() + request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: visualization, + Visualization: &apiv1beta1.Visualization{ + Type: apiv1beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + }, } - body, err := server.generateVisualizationFromRequest(request) + body, err := server.generateVisualization(request) assert.Nil(t, body) assert.Contains(t, err.Error(), "500 Internal Server Error") } -func TestGenerateVisualization_ServiceHostNotExistError(t *testing.T) { - clients, manager, _ := initWithExperiment(t) - defer clients.Close() - nonExistingServerURL := "http://127.0.0.2:53484" - server := &VisualizationServer{ - resourceManager: manager, - serviceURL: nonExistingServerURL, - } - visualization := &apiv1beta1.Visualization{ - Type: apiv1beta1.Visualization_ROC_CURVE, - Source: "gs://ml-pipeline/roc/data.csv", - Arguments: "{}", - } +func TestV1_GenerateVisualization_ServiceHostNotExistError(t *testing.T) { + server, close := newV1Server(t) + defer close() + + viper.Set(visualizationServiceName, "127.0.0.2") + viper.Set(visualizationServicePort, "53484") + defer func() { + viper.Set(visualizationServiceName, "") + viper.Set(visualizationServicePort, "") + }() + request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: visualization, + Visualization: &apiv1beta1.Visualization{ + Type: apiv1beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + }, } - body, err := server.generateVisualizationFromRequest(request) + body, err := server.generateVisualization(request) assert.Nil(t, body) - errMsg := err.Error() - assert.Contains(t, errMsg, "Unable to verify visualization service aliveness") - assert.Contains(t, err.Error(), fmt.Sprintf("dial tcp %s", nonExistingServerURL[7:])) + assert.Contains(t, err.Error(), "Unable to verify visualization service aliveness") + assert.Contains(t, err.Error(), "dial tcp 127.0.0.2:53484") } -func TestGenerateVisualization_ServerError(t *testing.T) { - clients, manager, _ := initWithExperiment(t) - defer clients.Close() - httpServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - assert.Equal(t, "/", req.URL.String()) - // The get requests 200s to indicate the service is alive, but the - // visualization request fails with a 500. +func TestV1_GenerateVisualization_ServerError(t *testing.T) { + server, close := newV1Server(t) + defer close() + + srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { if req.Method == http.MethodGet { - rw.WriteHeader(200) + rw.WriteHeader(http.StatusOK) } else { - rw.WriteHeader(500) + rw.WriteHeader(http.StatusInternalServerError) } })) - defer httpServer.Close() - server := &VisualizationServer{ - resourceManager: manager, - serviceURL: httpServer.URL, - } - visualization := &apiv1beta1.Visualization{ - Type: apiv1beta1.Visualization_ROC_CURVE, - Source: "gs://ml-pipeline/roc/data.csv", - Arguments: "{}", - } + defer srv.Close() + host, port, _ := splitHostPort(srv.Listener.Addr().String()) + viper.Set(visualizationServiceName, host) + viper.Set(visualizationServicePort, port) + defer func() { + viper.Set(visualizationServiceName, "") + viper.Set(visualizationServicePort, "") + }() + request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: visualization, + Visualization: &apiv1beta1.Visualization{ + Type: apiv1beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + }, } - body, err := server.generateVisualizationFromRequest(request) + body, err := server.generateVisualization(request) assert.Nil(t, body) - assert.Equal(t, "500 Internal Server Error", err.Error()) + assert.Equal(t, "visualization service returned non-OK status: 500 Internal Server Error", err.Error()) } -func TestGetVisualizationServiceURL(t *testing.T) { - server := &VisualizationServer{ - resourceManager: nil, - serviceURL: "http://host:port", - } +func TestV1_CreateVisualization_Unauthorized(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + userIdentity := "user@google.com" + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: common.GoogleIAPUserIdentityPrefix + userIdentity}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clientManager := resource.NewFakeClientManagerOrFatal(util.NewFakeTimeForEpoch()) + clientManager.SubjectAccessReviewClientFake = client.NewFakeSubjectAccessReviewClientUnauthorized() + resourceManager := resource.NewResourceManager(clientManager, &resource.ResourceManagerOptions{CollectMetrics: false}) + defer clientManager.Close() + + server := NewVisualizationServerV1(resourceManager) request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: nil, + Visualization: &apiv1beta1.Visualization{ + Type: apiv1beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + }, + Namespace: "ns1", } - url := server.getVisualizationServiceURL(request) - assert.Equal(t, "http://host:port", url) + _, err := server.CreateVisualizationV1(ctx, request) + assert.NotNil(t, err) + resourceAttributes := &authorizationv1.ResourceAttributes{ + Namespace: "ns1", + Verb: common.RbacResourceVerbCreate, + Group: common.RbacPipelinesGroup, + Version: common.RbacPipelinesVersion, + Resource: common.RbacResourceTypeVisualizations, + } + assert.EqualError(t, err, util.Wrap(getPermissionDeniedError(userIdentity, resourceAttributes), "Failed to authorize on namespace").Error()) } -func TestGetVisualizationServiceURL_Multiuser(t *testing.T) { +func TestV1_CreateVisualization_Unauthenticated(t *testing.T) { viper.Set(common.MultiUserMode, "true") defer viper.Set(common.MultiUserMode, "false") - viper.Set("VisualizationService.Name", "ml-pipeline-visualizationserver") - viper.Set("VisualizationService.Port", "8888") - server := &VisualizationServer{ - resourceManager: nil, - serviceURL: "http://host:port", - } + md := metadata.New(map[string]string{"no-identity-header": "user"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clientManager := resource.NewFakeClientManagerOrFatal(util.NewFakeTimeForEpoch()) + resourceManager := resource.NewResourceManager(clientManager, &resource.ResourceManagerOptions{CollectMetrics: false}) + defer clientManager.Close() + server := NewVisualizationServerV1(resourceManager) request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: nil, - Namespace: "ns1", + Visualization: &apiv1beta1.Visualization{ + Type: apiv1beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + }, + Namespace: "ns1", } - url := server.getVisualizationServiceURL(request) - assert.Equal(t, "http://ml-pipeline-visualizationserver.ns1:8888", url) + _, err := server.CreateVisualizationV1(ctx, request) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "there is no user identity header") +} + +func TestV2_ValidateCreateVisualizationRequest(t *testing.T) { + server, close := newV2Server(t) + defer close() + + request := &apiv2beta1.CreateVisualizationRequest{ + Visualization: &apiv2beta1.Visualization{ + Type: apiv2beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + }, + } + assert.Nil(t, server.validateCreateVisualizationRequest(request)) +} + +func TestV2_ValidateCreateVisualizationRequest_ArgumentsAreEmpty(t *testing.T) { + server, close := newV2Server(t) + defer close() + + request := &apiv2beta1.CreateVisualizationRequest{ + Visualization: &apiv2beta1.Visualization{ + Type: apiv2beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "", + }, + } + assert.Nil(t, server.validateCreateVisualizationRequest(request)) + assert.Equal(t, "{}", request.Visualization.Arguments) +} + +func TestV2_ValidateCreateVisualizationRequest_SourceIsEmpty(t *testing.T) { + server, close := newV2Server(t) + defer close() - // when namespace is not provided, we fall back to the default visuliaztion service - request = &apiv1beta1.CreateVisualizationRequest{ - Visualization: nil, + request := &apiv2beta1.CreateVisualizationRequest{ + Visualization: &apiv2beta1.Visualization{ + Type: apiv2beta1.Visualization_ROC_CURVE, + Source: "", + Arguments: "{}", + }, } - url = server.getVisualizationServiceURL(request) - assert.Equal(t, "http://host:port", url) + err := server.validateCreateVisualizationRequest(request) + assert.Contains(t, err.Error(), "A visualization requires a Source to be provided. Received") } -func TestCreateVisualization_Unauthorized(t *testing.T) { +func TestV2_ValidateCreateVisualizationRequest_SourceIsEmptyAndTypeIsCustom(t *testing.T) { + server, close := newV2Server(t) + defer close() + + request := &apiv2beta1.CreateVisualizationRequest{ + Visualization: &apiv2beta1.Visualization{ + Type: apiv2beta1.Visualization_CUSTOM, + Arguments: "{}", + }, + } + assert.Nil(t, server.validateCreateVisualizationRequest(request)) +} + +func TestV2_ValidateCreateVisualizationRequest_ArgumentsNotValidJSON(t *testing.T) { + server, close := newV2Server(t) + defer close() + + request := &apiv2beta1.CreateVisualizationRequest{ + Visualization: &apiv2beta1.Visualization{ + Type: apiv2beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{", + }, + } + err := server.validateCreateVisualizationRequest(request) + assert.Contains(t, err.Error(), "A visualization requires valid JSON to be provided as Arguments. Received {") +} + +func TestV2_GenerateVisualization(t *testing.T) { + server, close := newV2Server(t) + defer close() + stopService := startFakeVisualizationService(t, "roc_curve", http.StatusOK) + defer stopService() + + request := &apiv2beta1.CreateVisualizationRequest{ + Visualization: &apiv2beta1.Visualization{ + Type: apiv2beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + }, + } + body, err := server.generateVisualization(request) + assert.Nil(t, err) + assert.Equal(t, []byte("roc_curve"), body) +} + +func TestV2_GenerateVisualization_ServerError(t *testing.T) { + server, close := newV2Server(t) + defer close() + + srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Method == http.MethodGet { + rw.WriteHeader(http.StatusOK) + } else { + rw.WriteHeader(http.StatusInternalServerError) + } + })) + defer srv.Close() + host, port, _ := splitHostPort(srv.Listener.Addr().String()) + viper.Set(visualizationServiceName, host) + viper.Set(visualizationServicePort, port) + defer func() { + viper.Set(visualizationServiceName, "") + viper.Set(visualizationServicePort, "") + }() + + request := &apiv2beta1.CreateVisualizationRequest{ + Visualization: &apiv2beta1.Visualization{ + Type: apiv2beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + }, + } + body, err := server.generateVisualization(request) + assert.Nil(t, body) + assert.Equal(t, "visualization service returned non-OK status: 500 Internal Server Error", err.Error()) +} + +func TestV2_CreateVisualization_Unauthorized(t *testing.T) { viper.Set(common.MultiUserMode, "true") defer viper.Set(common.MultiUserMode, "false") @@ -279,20 +487,16 @@ func TestCreateVisualization_Unauthorized(t *testing.T) { resourceManager := resource.NewResourceManager(clientManager, &resource.ResourceManagerOptions{CollectMetrics: false}) defer clientManager.Close() - server := &VisualizationServer{ - resourceManager: resourceManager, - } - visualization := &apiv1beta1.Visualization{ - Type: apiv1beta1.Visualization_ROC_CURVE, - Source: "gs://ml-pipeline/roc/data.csv", - Arguments: "{}", - } - - request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: visualization, - Namespace: "ns1", + server := NewVisualizationServer(resourceManager) + request := &apiv2beta1.CreateVisualizationRequest{ + Visualization: &apiv2beta1.Visualization{ + Type: apiv2beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + }, + Namespace: "ns1", } - _, err := server.CreateVisualizationV1(ctx, request) + _, err := server.CreateVisualization(ctx, request) assert.NotNil(t, err) resourceAttributes := &authorizationv1.ResourceAttributes{ Namespace: "ns1", @@ -301,14 +505,10 @@ func TestCreateVisualization_Unauthorized(t *testing.T) { Version: common.RbacPipelinesVersion, Resource: common.RbacResourceTypeVisualizations, } - assert.EqualError( - t, - err, - util.Wrap(getPermissionDeniedError(userIdentity, resourceAttributes), "Failed to authorize on namespace").Error(), - ) + assert.EqualError(t, err, util.Wrap(getPermissionDeniedError(userIdentity, resourceAttributes), "Failed to authorize on namespace").Error()) } -func TestCreateVisualization_Unauthenticated(t *testing.T) { +func TestV2_CreateVisualization_Unauthenticated(t *testing.T) { viper.Set(common.MultiUserMode, "true") defer viper.Set(common.MultiUserMode, "false") @@ -319,24 +519,17 @@ func TestCreateVisualization_Unauthenticated(t *testing.T) { resourceManager := resource.NewResourceManager(clientManager, &resource.ResourceManagerOptions{CollectMetrics: false}) defer clientManager.Close() - server := &VisualizationServer{ - resourceManager: resourceManager, - } - visualization := &apiv1beta1.Visualization{ - Type: apiv1beta1.Visualization_ROC_CURVE, - Source: "gs://ml-pipeline/roc/data.csv", - Arguments: "{}", - } - - request := &apiv1beta1.CreateVisualizationRequest{ - Visualization: visualization, - Namespace: "ns1", + server := NewVisualizationServer(resourceManager) + request := &apiv2beta1.CreateVisualizationRequest{ + Visualization: &apiv2beta1.Visualization{ + Type: apiv2beta1.Visualization_ROC_CURVE, + Source: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + }, + Namespace: "ns1", } - _, err := server.CreateVisualizationV1(ctx, request) + _, err := server.CreateVisualization(ctx, request) assert.NotNil(t, err) - assert.Contains( - t, - err.Error(), - "there is no user identity header", - ) + assert.Contains(t, err.Error(), "there is no user identity header") } +