Skip to content

Commit 49e10e1

Browse files
committed
Fix issue
1 parent 14f8568 commit 49e10e1

File tree

2 files changed

+305
-0
lines changed

2 files changed

+305
-0
lines changed
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2024 The Kubeflow Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Enhanced hello world pipeline demonstrating multi-task workflow with inputs and outputs.
15+
16+
This pipeline improves test coverage by:
17+
- Accepting pipeline-level inputs
18+
- Passing data between multiple tasks
19+
- Producing pipeline-level outputs
20+
- Using both container and Python components
21+
"""
22+
23+
from kfp import compiler
24+
from kfp import dsl
25+
26+
27+
@dsl.component(base_image="public.ecr.aws/docker/library/python:3.12")
28+
def generate_greeting(name: str, greeting_type: str = "formal") -> str:
29+
"""Generate a personalized greeting message.
30+
31+
Args:
32+
name: Name of the recipient
33+
greeting_type: Type of greeting ("formal" or "casual")
34+
35+
Returns:
36+
Formatted greeting message
37+
"""
38+
if greeting_type == "formal":
39+
return f"Dear {name}, welcome to Kubeflow Pipelines!"
40+
else:
41+
return f"Hey {name}! Welcome to KFP!"
42+
43+
44+
@dsl.container_component
45+
def display_greeting(message: dsl.Input[str], output_file: dsl.Output[dsl.Dataset]):
46+
"""Display greeting and save to output file.
47+
48+
Args:
49+
message: The greeting message to display
50+
output_file: Output artifact containing the greeting
51+
"""
52+
return dsl.ContainerSpec(
53+
image='registry.access.redhat.com/ubi9/python-311:latest',
54+
command=['sh', '-c'],
55+
args=[
56+
f'echo "Processing greeting..." && '
57+
f'echo "{message}" | tee {output_file.path} && '
58+
f'echo "Greeting saved to {output_file.path}"'
59+
]
60+
)
61+
62+
63+
@dsl.component(base_image="public.ecr.aws/docker/library/python:3.12")
64+
def count_words(text: str) -> int:
65+
"""Count words in the given text.
66+
67+
Args:
68+
text: Input text to analyze
69+
70+
Returns:
71+
Number of words in the text
72+
"""
73+
word_count = len(text.split())
74+
print(f"Word count: {word_count}")
75+
return word_count
76+
77+
78+
@dsl.pipeline(
79+
name='hello-world-multi-task-io',
80+
description='Enhanced hello world with multiple tasks, inputs, and outputs'
81+
)
82+
def hello_world_pipeline(
83+
recipient_name: str = 'Kubeflow User',
84+
greeting_style: str = 'formal'
85+
) -> str:
86+
"""Multi-task pipeline demonstrating data flow between components.
87+
88+
Args:
89+
recipient_name: Name of the person to greet
90+
greeting_style: Style of greeting ("formal" or "casual")
91+
92+
Returns:
93+
Final greeting message
94+
"""
95+
# Task 1: Generate greeting message
96+
greeting_task = generate_greeting(
97+
name=recipient_name,
98+
greeting_type=greeting_style
99+
)
100+
101+
# Task 2: Display and save greeting (depends on Task 1)
102+
display_task = display_greeting(
103+
message=greeting_task.output
104+
)
105+
106+
# Task 3: Count words in greeting (depends on Task 1)
107+
count_task = count_words(
108+
text=greeting_task.output
109+
)
110+
111+
# Return the final greeting as pipeline output
112+
return greeting_task.output
113+
114+
115+
if __name__ == '__main__':
116+
compiler.Compiler().compile(
117+
pipeline_func=hello_world_pipeline,
118+
package_path=__file__.replace('.py', '.yaml')
119+
)
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright 2024 The Kubeflow Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Pipeline demonstrating Pydantic-based input validation.
15+
16+
This example shows how to use Pydantic models for robust input validation
17+
in KFP components, improving reliability and providing clear error messages.
18+
"""
19+
20+
from kfp import compiler
21+
from kfp import dsl
22+
23+
24+
@dsl.component(
25+
base_image="public.ecr.aws/docker/library/python:3.12",
26+
packages_to_install=["pydantic>=2.0,<3"]
27+
)
28+
def validate_training_config(
29+
learning_rate: float,
30+
epochs: int,
31+
batch_size: int,
32+
model_type: str
33+
) -> str:
34+
"""Validate training configuration using Pydantic.
35+
36+
Args:
37+
learning_rate: Learning rate for training (0 < lr <= 1)
38+
epochs: Number of training epochs (1-1000)
39+
batch_size: Batch size (must be power of 2)
40+
model_type: Type of model (cnn, rnn, or transformer)
41+
42+
Returns:
43+
Validation status message
44+
"""
45+
from pydantic import BaseModel, Field, field_validator
46+
47+
class TrainingConfig(BaseModel):
48+
"""Validated training configuration."""
49+
learning_rate: float = Field(gt=0, le=1, description="Learning rate")
50+
epochs: int = Field(gt=0, le=1000, description="Number of epochs")
51+
batch_size: int = Field(gt=0, description="Batch size")
52+
model_type: str = Field(pattern="^(cnn|rnn|transformer)$", description="Model architecture")
53+
54+
@field_validator('batch_size')
55+
@classmethod
56+
def batch_size_power_of_two(cls, v):
57+
"""Ensure batch_size is a power of two."""
58+
if v & (v - 1) != 0:
59+
raise ValueError(f'batch_size must be power of 2, got {v}')
60+
return v
61+
62+
# Validate configuration
63+
config = TrainingConfig(
64+
learning_rate=learning_rate,
65+
epochs=epochs,
66+
batch_size=batch_size,
67+
model_type=model_type
68+
)
69+
70+
status = f"✓ Configuration valid: lr={config.learning_rate}, epochs={config.epochs}, batch_size={config.batch_size}, model={config.model_type}"
71+
print(status)
72+
return status
73+
74+
75+
@dsl.component(
76+
base_image="public.ecr.aws/docker/library/python:3.12",
77+
packages_to_install=["pydantic>=2.0,<3"]
78+
)
79+
def validate_data_paths(
80+
train_path: str,
81+
valid_path: str,
82+
test_path: str
83+
) -> str:
84+
"""Validate data file paths.
85+
86+
Args:
87+
train_path: Path to training data
88+
valid_path: Path to validation data
89+
test_path: Path to test data
90+
91+
Returns:
92+
Validation status
93+
"""
94+
from pydantic import BaseModel, Field
95+
96+
class DataPaths(BaseModel):
97+
"""Validated data paths."""
98+
train_path: str = Field(min_length=1, description="Training data path")
99+
valid_path: str = Field(min_length=1, description="Validation data path")
100+
test_path: str = Field(min_length=1, description="Test data path")
101+
102+
paths = DataPaths(
103+
train_path=train_path,
104+
valid_path=valid_path,
105+
test_path=test_path
106+
)
107+
108+
status = f"✓ Paths validated: train={paths.train_path}, valid={paths.valid_path}, test={paths.test_path}"
109+
print(status)
110+
return status
111+
112+
113+
@dsl.component(base_image="public.ecr.aws/docker/library/python:3.12")
114+
def mock_training(config_status: str, paths_status: str) -> str:
115+
"""Mock training task that uses validated inputs.
116+
117+
Args:
118+
config_status: Configuration validation status
119+
paths_status: Paths validation status
120+
121+
Returns:
122+
Training completion message
123+
"""
124+
print("Validation results:")
125+
print(config_status)
126+
print(paths_status)
127+
print("Starting model training with validated configuration...")
128+
return "Training completed successfully!"
129+
130+
131+
@dsl.pipeline(
132+
name='pipeline-with-pydantic-validation',
133+
description='Demonstrates Pydantic validation for pipeline inputs'
134+
)
135+
def validated_pipeline(
136+
learning_rate: float = 0.001,
137+
epochs: int = 10,
138+
batch_size: int = 32,
139+
model_type: str = "cnn",
140+
train_data_path: str = "/data/train.csv",
141+
valid_data_path: str = "/data/valid.csv",
142+
test_data_path: str = "/data/test.csv"
143+
) -> str:
144+
"""Pipeline with Pydantic-validated inputs.
145+
146+
Args:
147+
learning_rate: Learning rate (0 < lr <= 1)
148+
epochs: Number of epochs (1-1000)
149+
batch_size: Batch size (power of 2)
150+
model_type: Model type (cnn/rnn/transformer)
151+
train_data_path: Path to training data
152+
valid_data_path: Path to validation data
153+
test_data_path: Path to test data
154+
155+
Returns:
156+
Training result message
157+
"""
158+
# Validate training configuration
159+
config_validation = validate_training_config(
160+
learning_rate=learning_rate,
161+
epochs=epochs,
162+
batch_size=batch_size,
163+
model_type=model_type
164+
)
165+
166+
# Validate data paths
167+
paths_validation = validate_data_paths(
168+
train_path=train_data_path,
169+
valid_path=valid_data_path,
170+
test_path=test_data_path
171+
)
172+
173+
# Run training with validated inputs
174+
training_result = mock_training(
175+
config_status=config_validation.output,
176+
paths_status=paths_validation.output
177+
)
178+
179+
return training_result.output
180+
181+
182+
if __name__ == '__main__':
183+
compiler.Compiler().compile(
184+
pipeline_func=validated_pipeline,
185+
package_path=__file__.replace('.py', '.yaml')
186+
)

0 commit comments

Comments
 (0)