Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions api/lmes/v1alpha1/lmevaljob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,43 @@ type PersistentVolumeClaimManaged struct {
Size string `json:"size,omitempty"`
}

// MLFlowExportType defines what to export to MLFlow
// +kubebuilder:validation:Enum=metrics;artifacts
type MLFlowExportType string

const (
// Export evaluation metrics to MLFlow
MLFlowMetricsExport MLFlowExportType = "metrics"
// Export evaluation artifacts to MLFlow
MLFlowArtifactsExport MLFlowExportType = "artifacts"
)

// MLFlowOutput defines the configuration for MLFlow output
type MLFlowOutput struct {
// TrackingUri is the MLFlow tracking server URI
// +kubebuilder:validation:Pattern=`^https?://[a-zA-Z0-9.-]+(:[0-9]+)?(/.*)?$`
TrackingUri string `json:"trackingUri"`
// ExperimentName is the name of the MLFlow experiment
// +optional
ExperimentName *string `json:"experimentName,omitempty"`
// RunId is the specific MLFlow run ID to use (optional)
// +optional
RunId *string `json:"runId,omitempty"`
// Export defines what to export to MLFlow (metrics, artifacts, or both)
// +optional
Export []MLFlowExportType `json:"export,omitempty"`
}

type Outputs struct {
// Use an existing PVC to store the outputs
// +optional
PersistentVolumeClaimName *string `json:"pvcName,omitempty"`
// Create an operator managed PVC
// +optional
PersistentVolumeClaimManaged *PersistentVolumeClaimManaged `json:"pvcManaged,omitempty"`
// Export results to MLFlow tracking server
// +optional
MLFlow *MLFlowOutput `json:"mlflow,omitempty"`
}

func (c *LMEvalContainer) GetSecurityContext() *corev1.SecurityContext {
Expand Down Expand Up @@ -604,6 +634,37 @@ func (o *Outputs) HasExistingPVC() bool {
return o.PersistentVolumeClaimName != nil
}

// HasMLFlow returns whether the outputs define MLFlow configuration
func (o *Outputs) HasMLFlow() bool {
return o.MLFlow != nil
}

// HasMLFlowMetrics returns whether MLFlow export includes metrics
func (m *MLFlowOutput) HasMLFlowMetrics() bool {
if m == nil || len(m.Export) == 0 {
return false
}
for _, export := range m.Export {
if export == MLFlowMetricsExport {
return true
}
}
return false
}

// HasMLFlowArtifacts returns whether MLFlow export includes artifacts
func (m *MLFlowOutput) HasMLFlowArtifacts() bool {
if m == nil || len(m.Export) == 0 {
return false
}
for _, export := range m.Export {
if export == MLFlowArtifactsExport {
return true
}
}
return false
}

// LMEvalJobStatus defines the observed state of LMEvalJob
type LMEvalJobStatus struct {
// Important: Run "make" to regenerate code after modifying this file
Expand Down
35 changes: 35 additions & 0 deletions api/lmes/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

77 changes: 46 additions & 31 deletions cmd/lmes_driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,36 @@ func (t *strArrayArg) String() string {
}

var (
taskRecipes strArrayArg
customArtifactArgs strArrayArg
taskNames strArrayArg
copy = flag.String("copy", "", "copy this binary to specified destination path")
getStatus = flag.Bool("get-status", false, "Get current status")
shutdown = flag.Bool("shutdown", false, "Shutdown the driver")
outputPath = flag.String("output-path", OutputPath, "output path")
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
commPort = flag.Int("listen-port", driver.DefaultPort, "driver serves APIs on the port")
downloadAssetsS3 = flag.Bool("download-assets-s3", false, "Download assets from S3")
customTaskGitURL = flag.String("custom-task-git-url", "", "Git repository URL for custom tasks")
customTaskGitBranch = flag.String("custom-task-git-branch", "", "Git repository branch for custom tasks")
customTaskGitCommit = flag.String("custom-task-git-commit", "", "Git commit for custom tasks")
customTaskGitPath = flag.String("custom-task-git-path", "", "Custom task path")
allowOnline = flag.Bool("allow-online", false, "Allow LMEval online access")
driverLog = ctrl.Log.WithName("driver")
taskRecipes strArrayArg
customArtifactArgs strArrayArg
taskNames strArrayArg
mlflowExportTypes strArrayArg
copy = flag.String("copy", "", "copy this binary to specified destination path")
getStatus = flag.Bool("get-status", false, "Get current status")
shutdown = flag.Bool("shutdown", false, "Shutdown the driver")
outputPath = flag.String("output-path", OutputPath, "output path")
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
commPort = flag.Int("listen-port", driver.DefaultPort, "driver serves APIs on the port")
downloadAssetsS3 = flag.Bool("download-assets-s3", false, "Download assets from S3")
customTaskGitURL = flag.String("custom-task-git-url", "", "Git repository URL for custom tasks")
customTaskGitBranch = flag.String("custom-task-git-branch", "", "Git repository branch for custom tasks")
customTaskGitCommit = flag.String("custom-task-git-commit", "", "Git commit for custom tasks")
customTaskGitPath = flag.String("custom-task-git-path", "", "Custom task path")
allowOnline = flag.Bool("allow-online", false, "Allow LMEval online access")
mlflowTrackingUri = flag.String("mlflow-tracking-uri", "", "MLFlow tracking server URI")
mlflowExperimentName = flag.String("mlflow-experiment-name", "", "MLFlow experiment name")
mlflowRunId = flag.String("mlflow-run-id", "", "MLFlow run ID")
mlflowSourceName = flag.String("mlflow-source-name", "", "Value for mlflow.source.name tag (e.g. CR name)")
mlflowSourceType = flag.String("mlflow-source-type", "", "Value for mlflow.source.type tag")
mlflowParamsJSON = flag.String("mlflow-params-json", "", "JSON-encoded parameters to send to MLFlow")
driverLog = ctrl.Log.WithName("driver")
)

func init() {
flag.Var(&taskRecipes, "task-recipe", "task recipe")
flag.Var(&customArtifactArgs, "custom-artifact", "A string contains an artifact's type, name and value. Use | as separator")
flag.Var(&taskNames, "task-name", "A task name for custom tasks")
flag.Var(&mlflowExportTypes, "mlflow-export-type", "MLFlow export type (metrics, artifacts)")
}

func main() {
Expand Down Expand Up @@ -120,21 +128,28 @@ func main() {
}

driverOpt := driver.DriverOption{
Context: ctx,
OutputPath: *outputPath,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
CustomArtifacts: customArtifacts,
Args: args,
CommPort: *commPort,
DownloadAssetsS3: *downloadAssetsS3,
CustomTaskGitURL: *customTaskGitURL,
CustomTaskGitBranch: *customTaskGitBranch,
CustomTaskGitCommit: *customTaskGitCommit,
CustomTaskGitPath: *customTaskGitPath,
TaskNames: taskNames,
AllowOnline: *allowOnline,
Context: ctx,
OutputPath: *outputPath,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
CustomArtifacts: customArtifacts,
Args: args,
CommPort: *commPort,
DownloadAssetsS3: *downloadAssetsS3,
CustomTaskGitURL: *customTaskGitURL,
CustomTaskGitBranch: *customTaskGitBranch,
CustomTaskGitCommit: *customTaskGitCommit,
CustomTaskGitPath: *customTaskGitPath,
TaskNames: taskNames,
AllowOnline: *allowOnline,
MLFlowTrackingUri: *mlflowTrackingUri,
MLFlowExperimentName: *mlflowExperimentName,
MLFlowRunId: *mlflowRunId,
MLFlowExportTypes: mlflowExportTypes,
MLFlowSourceName: *mlflowSourceName,
MLFlowSourceType: *mlflowSourceType,
MLFlowParamsJSON: *mlflowParamsJSON,
}

driver, err := driver.NewDriver(&driverOpt)
Expand Down
27 changes: 27 additions & 0 deletions config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,33 @@ spec:
outputs:
description: Outputs specifies storage for evaluation results
properties:
mlflow:
description: Export results to MLFlow tracking server
properties:
experimentName:
description: ExperimentName is the name of the MLFlow experiment
type: string
export:
description: Export defines what to export to MLFlow (metrics,
artifacts, or both)
items:
description: MLFlowExportType defines what to export to
MLFlow
enum:
- metrics
- artifacts
type: string
type: array
runId:
description: RunId is the specific MLFlow run ID to use (optional)
type: string
trackingUri:
description: TrackingUri is the MLFlow tracking server URI
pattern: ^https?://[a-zA-Z0-9.-]+(:[0-9]+)?(/.*)?$
type: string
required:
- trackingUri
type: object
pvcManaged:
description: Create an operator managed PVC
properties:
Expand Down
108 changes: 108 additions & 0 deletions controllers/lmes/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ type DriverOption struct {
CustomTaskGitPath string
TaskNames []string
AllowOnline bool
// MLFlow configuration
MLFlowTrackingUri string
MLFlowExperimentName string
MLFlowRunId string
MLFlowExportTypes []string
MLFlowSourceName string
MLFlowSourceType string
MLFlowParamsJSON string
}

type ArtifactType string
Expand Down Expand Up @@ -443,6 +451,15 @@ func (d *driverImpl) updateCompleteStatus(err error) {
var results string
results, err = d.getResults()
d.status.Results = results

// Export to MLFlow if configured and no errors
if err == nil {
mlflowErr := d.exportMLFlow()
if mlflowErr != nil {
d.Option.Logger.Error(mlflowErr, "failed to export results to MLFlow")
// Don't fail the job for MLFlow export errors, just log them
}
}
}

if err != nil {
Expand Down Expand Up @@ -710,3 +727,94 @@ func (d *driverImpl) fetchGitCustomTasks() error {

return nil
}

// exportMLFlow exports evaluation results to MLFlow tracking server if configured
func (d *driverImpl) exportMLFlow() error {
// Check if MLFlow is configured
if d.Option.MLFlowTrackingUri == "" || len(d.Option.MLFlowExportTypes) == 0 {
return nil
}

d.Option.Logger.Info("Exporting results to MLFlow",
"trackingUri", d.Option.MLFlowTrackingUri,
"experimentName", d.Option.MLFlowExperimentName,
"runId", d.Option.MLFlowRunId,
"exportTypes", d.Option.MLFlowExportTypes,
"sourceName", d.Option.MLFlowSourceName,
"sourceType", d.Option.MLFlowSourceType)

// Build command arguments
args := []string{
"/opt/app-root/src/scripts/mlflow_export.py",
"--output-dir", d.Option.OutputPath,
"--tracking-uri", d.Option.MLFlowTrackingUri,
}

if d.Option.MLFlowExperimentName != "" {
args = append(args, "--experiment-name", d.Option.MLFlowExperimentName)
}

if d.Option.MLFlowRunId != "" {
args = append(args, "--run-id", d.Option.MLFlowRunId)
}

if len(d.Option.MLFlowExportTypes) > 0 {
args = append(args, "--export-types")
args = append(args, d.Option.MLFlowExportTypes...)
}

if d.Option.MLFlowSourceName != "" {
args = append(args, "--source-name", d.Option.MLFlowSourceName)
}

if d.Option.MLFlowSourceType != "" {
args = append(args, "--source-type", d.Option.MLFlowSourceType)
}

if d.Option.MLFlowParamsJSON != "" {
args = append(args, "--params-json", d.Option.MLFlowParamsJSON)
}

// Execute MLFlow export script
cmd := exec.Command("python", args...)
Comment on lines +778 to +779
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Running the MLFlow export script without a timeout can cause hangs.

Because exec.Command isn’t tied to a context, a stuck mlflow_export.py (e.g., due to tracking server issues) could block updateCompleteStatus indefinitely and delay job completion. Consider using exec.CommandContext with d.Option.Context (or a derived context with a timeout) so the export can be cancelled cleanly and doesn’t block the driver.


// Set environment variables for the script
env := append(os.Environ(),
"MLFLOW_TRACKING_URI="+d.Option.MLFlowTrackingUri)

if d.Option.MLFlowExperimentName != "" {
env = append(env, "MLFLOW_EXPERIMENT_NAME="+d.Option.MLFlowExperimentName)
}

if d.Option.MLFlowRunId != "" {
env = append(env, "MLFLOW_RUN_ID="+d.Option.MLFlowRunId)
}

if len(d.Option.MLFlowExportTypes) > 0 {
env = append(env, "MLFLOW_EXPORT_TYPES="+strings.Join(d.Option.MLFlowExportTypes, ","))
}

if d.Option.MLFlowSourceName != "" {
env = append(env, "MLFLOW_SOURCE_NAME="+d.Option.MLFlowSourceName)
}

if d.Option.MLFlowSourceType != "" {
env = append(env, "MLFLOW_SOURCE_TYPE="+d.Option.MLFlowSourceType)
}

if d.Option.MLFlowParamsJSON != "" {
env = append(env, "MLFLOW_PARAMS_JSON="+d.Option.MLFlowParamsJSON)
}

cmd.Env = env

// Capture output and errors
output, err := cmd.CombinedOutput()
if err != nil {
d.Option.Logger.Error(err, "MLFlow export script failed", "output", string(output))
return fmt.Errorf("MLFlow export failed: %v, output: %s", err, string(output))
}

d.Option.Logger.Info("MLFlow export completed successfully", "output", string(output))
return nil
}
Loading
Loading