Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions core/scheduler/handler/v1beta1/replay.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import (
pb "github.com/goto/optimus/protos/gotocompany/optimus/core/v1beta1"
)

type ReplayValidator interface {
ValidateDateRange(ctx context.Context, replayRequest *scheduler.Replay) error
}
type ReplayService interface {
CreateReplay(ctx context.Context, tenant tenant.Tenant, jobName scheduler.JobName, config *scheduler.ReplayConfig) (replayID uuid.UUID, err error)
GetReplayList(ctx context.Context, projectName tenant.ProjectName) (replays []*scheduler.Replay, err error)
Expand All @@ -43,8 +46,9 @@ type replayRequest interface {
}

type ReplayHandler struct {
l log.Logger
service ReplayService
l log.Logger
service ReplayService
validator ReplayValidator

pb.UnimplementedReplayServiceServer
}
Expand All @@ -54,6 +58,11 @@ func (h ReplayHandler) ReplayDryRun(ctx context.Context, req *pb.ReplayDryRunReq
if err != nil {
return nil, err
}
err = h.validator.ValidateDateRange(ctx, replayReq)
if err != nil {
return nil, errors.GRPCErr(err, "invalid date range for replay")
}

// TODO: should convert from logical time
runs, err := h.service.GetRunsStatus(ctx, replayReq.Tenant(), replayReq.JobName(), replayReq.Config())
if err != nil {
Expand Down Expand Up @@ -302,6 +311,6 @@ func parseJobConfig(jobConfig string) (map[string]string, error) {
return configs, nil
}

func NewReplayHandler(l log.Logger, service ReplayService) *ReplayHandler {
return &ReplayHandler{l: l, service: service}
func NewReplayHandler(l log.Logger, service ReplayService, validator ReplayValidator) *ReplayHandler {
return &ReplayHandler{l: l, service: service, validator: validator}
}
101 changes: 73 additions & 28 deletions core/scheduler/handler/v1beta1/replay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ func TestReplayHandler(t *testing.T) {
t.Run("ReplayDryRun", func(t *testing.T) {
t.Run("returns error when unable to create tenant", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayDryRunRequest{
JobName: jobName.String(),
Expand All @@ -62,7 +63,8 @@ func TestReplayHandler(t *testing.T) {
})
t.Run("returns error when job name is invalid", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayDryRunRequest{
ProjectName: projectName,
Expand All @@ -83,7 +85,8 @@ func TestReplayHandler(t *testing.T) {
})
t.Run("returns error when start time is invalid", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayDryRunRequest{
ProjectName: projectName,
Expand All @@ -105,7 +108,8 @@ func TestReplayHandler(t *testing.T) {

t.Run("returns error when end time is present but invalid", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayDryRunRequest{
ProjectName: projectName,
Expand All @@ -127,9 +131,6 @@ func TestReplayHandler(t *testing.T) {
})

t.Run("returns error when unable to get runs status", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)

req := &pb.ReplayDryRunRequest{
ProjectName: projectName,
JobName: jobName.String(),
Expand All @@ -145,6 +146,12 @@ func TestReplayHandler(t *testing.T) {
}
replayConfig := scheduler.NewReplayConfig(req.StartTime.AsTime(), req.EndTime.AsTime(), false, jobConfig, description, category, approvalID, userID)

service := new(mockReplayService)
validator := new(mockReplayValidator)

validator.On("ValidateDateRange", ctx, mock.Anything).Return(nil).Once()
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

service.On("GetRunsStatus", ctx, jobTenant, jobName, replayConfig).Return(nil, errors.New("internal error"))

result, err := replayHandler.ReplayDryRun(ctx, req)
Expand All @@ -154,7 +161,8 @@ func TestReplayHandler(t *testing.T) {

t.Run("returns list of replay runs status when success", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayDryRunRequest{
ProjectName: projectName,
Expand All @@ -177,6 +185,7 @@ func TestReplayHandler(t *testing.T) {
},
}

validator.On("ValidateDateRange", ctx, mock.Anything).Return(nil).Once()
service.On("GetRunsStatus", ctx, jobTenant, jobName, replayConfig).Return(runs, nil)

result, err := replayHandler.ReplayDryRun(ctx, req)
Expand All @@ -188,7 +197,8 @@ func TestReplayHandler(t *testing.T) {
t.Run("Replay", func(t *testing.T) {
t.Run("returns replay ID when able to create replay successfully", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayRequest{
ProjectName: projectName,
Expand All @@ -213,7 +223,8 @@ func TestReplayHandler(t *testing.T) {
})
t.Run("returns replay ID when able to create replay successfully without overriding job config", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayRequest{
ProjectName: projectName,
Expand All @@ -237,7 +248,8 @@ func TestReplayHandler(t *testing.T) {
})
t.Run("returns error when unable to create tenant", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayRequest{
JobName: jobName.String(),
Expand All @@ -258,7 +270,8 @@ func TestReplayHandler(t *testing.T) {
})
t.Run("returns error when job name is invalid", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayRequest{
ProjectName: projectName,
Expand All @@ -279,7 +292,8 @@ func TestReplayHandler(t *testing.T) {
})
t.Run("returns error when start time is invalid", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayRequest{
ProjectName: projectName,
Expand All @@ -300,7 +314,8 @@ func TestReplayHandler(t *testing.T) {
})
t.Run("returns no error when end time is empty", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayRequest{
ProjectName: projectName,
Expand All @@ -324,7 +339,8 @@ func TestReplayHandler(t *testing.T) {
})
t.Run("returns error when end time is present but invalid", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayRequest{
ProjectName: projectName,
Expand All @@ -349,7 +365,8 @@ func TestReplayHandler(t *testing.T) {
})
t.Run("returns error when unable to create replay", func(t *testing.T) {
service := new(mockReplayService)
replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ReplayRequest{
ProjectName: projectName,
Expand Down Expand Up @@ -379,7 +396,8 @@ func TestReplayHandler(t *testing.T) {
service := new(mockReplayService)
defer service.AssertExpectations(t)

replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ListReplayRequest{
ProjectName: "",
Expand All @@ -394,7 +412,8 @@ func TestReplayHandler(t *testing.T) {
service.On("GetReplayList", ctx, tenant.ProjectName("project-test")).Return(nil, errors.New("some error"))
defer service.AssertExpectations(t)

replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ListReplayRequest{
ProjectName: "project-test",
Expand All @@ -409,7 +428,8 @@ func TestReplayHandler(t *testing.T) {
service.On("GetReplayList", ctx, tenant.ProjectName("project-test")).Return([]*scheduler.Replay{}, nil)
defer service.AssertExpectations(t)

replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ListReplayRequest{
ProjectName: "project-test",
Expand All @@ -433,7 +453,8 @@ func TestReplayHandler(t *testing.T) {
service.On("GetReplayList", ctx, tenant.ProjectName("project-test")).Return([]*scheduler.Replay{replay1, replay2, replay3}, nil)
defer service.AssertExpectations(t)

replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.ListReplayRequest{
ProjectName: "project-test",
Expand All @@ -447,7 +468,7 @@ func TestReplayHandler(t *testing.T) {

t.Run("GetReplay", func(t *testing.T) {
t.Run("returns error when uuid is not valid", func(t *testing.T) {
replayHandler := v1beta1.NewReplayHandler(logger, nil)
replayHandler := v1beta1.NewReplayHandler(logger, nil, nil)

req := &pb.GetReplayRequest{
ProjectName: projectName,
Expand All @@ -464,7 +485,8 @@ func TestReplayHandler(t *testing.T) {
replayID := uuid.New()
service.On("GetReplayByID", ctx, replayID).Return(nil, errors.New("internal error"))

replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.GetReplayRequest{
ProjectName: projectName,
Expand All @@ -481,7 +503,8 @@ func TestReplayHandler(t *testing.T) {
replayID := uuid.New()
service.On("GetReplayByID", ctx, replayID).Return(nil, errs.NotFound("entity", "not found"))

replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.GetReplayRequest{
ProjectName: projectName,
Expand Down Expand Up @@ -512,7 +535,8 @@ func TestReplayHandler(t *testing.T) {
},
}, nil)

replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.GetReplayRequest{
ProjectName: projectName,
Expand All @@ -526,7 +550,7 @@ func TestReplayHandler(t *testing.T) {

t.Run("CancelReplay", func(t *testing.T) {
t.Run("returns error when uuid is not valid", func(t *testing.T) {
replayHandler := v1beta1.NewReplayHandler(logger, nil)
replayHandler := v1beta1.NewReplayHandler(logger, nil, nil)

req := &pb.CancelReplayRequest{
ProjectName: projectName,
Expand All @@ -543,7 +567,8 @@ func TestReplayHandler(t *testing.T) {
replayID := uuid.New()
service.On("GetReplayByID", ctx, replayID).Return(nil, errors.New("internal error"))

replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.CancelReplayRequest{
ProjectName: projectName,
Expand All @@ -560,7 +585,8 @@ func TestReplayHandler(t *testing.T) {
replayID := uuid.New()
service.On("GetReplayByID", ctx, replayID).Return(nil, errs.NotFound("entity", "not found"))

replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.CancelReplayRequest{
ProjectName: projectName,
Expand Down Expand Up @@ -594,7 +620,8 @@ func TestReplayHandler(t *testing.T) {
service.On("GetReplayByID", ctx, replayID).Return(replayWithRun, nil)
service.On("CancelReplay", ctx, replayWithRun).Return(nil)

replayHandler := v1beta1.NewReplayHandler(logger, service)
validator := new(mockReplayValidator)
replayHandler := v1beta1.NewReplayHandler(logger, service, validator)

req := &pb.CancelReplayRequest{
ProjectName: projectName,
Expand Down Expand Up @@ -785,3 +812,21 @@ func (_m *mockReplayService) GetReplayConfig(ctx context.Context, projectName te
}
return args.Get(0).(map[string]string), args.Error(1)
}

// ReplayValidator is an autogenerated mock type for the ReplayValidator type
type mockReplayValidator struct {
mock.Mock
}

// ValidateDateRange provides a mock function with given fields: ctx, replayRequest
func (_m *mockReplayValidator) ValidateDateRange(ctx context.Context, replayRequest *scheduler.Replay) error {
ret := _m.Called(ctx, replayRequest)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *scheduler.Replay) error); ok {
r0 = rf(ctx, replayRequest)
} else {
r0 = ret.Error(0)
}

return r0
}
13 changes: 11 additions & 2 deletions core/scheduler/service/replay_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func NewValidator(replayRepository ReplayRepository, scheduler ReplayScheduler,
}

func (v Validator) Validate(ctx context.Context, replayRequest *scheduler.Replay, jobCron *cron.ScheduleSpec) error {
if err := v.validateDateRange(ctx, replayRequest, jobCron); err != nil {
if err := v.ValidateDateRange(ctx, replayRequest); err != nil {
return err
}

Expand All @@ -33,11 +33,20 @@ func (v Validator) Validate(ctx context.Context, replayRequest *scheduler.Replay
return v.validateConflictedRun(ctx, replayRequest, jobCron)
}

func (v Validator) validateDateRange(ctx context.Context, replayRequest *scheduler.Replay, jobCron *cron.ScheduleSpec) error {
func (v Validator) ValidateDateRange(ctx context.Context, replayRequest *scheduler.Replay) error {
jobSpec, err := v.jobRepo.GetJobDetails(ctx, replayRequest.Tenant().ProjectName(), replayRequest.JobName())
if err != nil {
return err
}

if jobSpec.Schedule.Interval == "" {
return errors.NewError(errors.ErrInternalError, scheduler.EntityReplay, "job schedule interval is empty")
}
jobCron, err := cron.ParseCronSchedule(jobSpec.Schedule.Interval)
if err != nil {
return err
}

replayStartDate := replayRequest.Config().StartTime.UTC()
replayEndDate := replayRequest.Config().EndTime.UTC()
jobLogicalStartDate := jobSpec.Schedule.StartDate.UTC()
Expand Down
Loading
Loading