Skip to content

Commit 11010ee

Browse files
committed
Fix schema yaml parsing, check for required.
1 parent 3bb7d0f commit 11010ee

File tree

2 files changed

+64
-20
lines changed

2 files changed

+64
-20
lines changed

runprompt

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def main():
229229
if input_schema:
230230
first_key = list(input_schema.keys())[0]
231231
variables[first_key] = args_str
232+
validate_required_inputs(meta, variables)
232233
prompt = render_template(template, variables)
233234
log("Rendered prompt: %s" % prompt)
234235
output_config = meta.get("output", {})
@@ -533,7 +534,6 @@ def parse_yaml(s):
533534
result = {}
534535
stack = [(result, -1)]
535536
current_list = None
536-
current_list_key = None
537537
current_list_indent = -1
538538
for line in s.split("\n"):
539539
if not line.strip() or line.strip().startswith("#"):
@@ -543,14 +543,14 @@ def parse_yaml(s):
543543
list_match = re.match(r"^(\s*)-\s*(.*)", line)
544544
if list_match:
545545
item_value = list_match.group(2).strip()
546-
if current_list is not None and indent > current_list_indent:
546+
if current_list is not None and indent >= current_list_indent:
547547
current_list.append(parse_yaml_value(item_value) if item_value else item_value)
548548
continue
549-
# Not a list item or different context - reset list tracking
549+
# Not a list item - reset list tracking if we've dedented
550550
if current_list is not None and indent <= current_list_indent:
551551
current_list = None
552-
current_list_key = None
553552
current_list_indent = -1
553+
# Pop stack for dedented lines
554554
while stack and indent <= stack[-1][1]:
555555
stack.pop()
556556
if not stack:
@@ -563,23 +563,33 @@ def parse_yaml(s):
563563
parent = stack[-1][0]
564564
if value:
565565
parent[key] = parse_yaml_value(value)
566+
# Reset list tracking when we have a value
567+
current_list = None
568+
current_list_indent = -1
566569
else:
567-
parent[key] = {}
568-
stack.append((parent[key], indent))
569-
# Set up for potential list children
570-
current_list_key = key
571-
current_list = []
572-
parent[key] = current_list
573-
current_list_indent = indent
574-
# Convert empty lists back to empty dicts
575-
def fix_empty(obj):
576-
if isinstance(obj, dict):
577-
for k, v in obj.items():
578-
if isinstance(v, list) and len(v) == 0:
579-
obj[k] = {}
580-
else:
581-
fix_empty(v)
582-
fix_empty(result)
570+
# Check if next non-empty line is a list item
571+
remaining_lines = s.split("\n")
572+
line_idx = remaining_lines.index(line) if line in remaining_lines else -1
573+
is_list_parent = False
574+
if line_idx >= 0:
575+
for next_line in remaining_lines[line_idx + 1:]:
576+
if not next_line.strip() or next_line.strip().startswith("#"):
577+
continue
578+
next_indent = len(next_line) - len(next_line.lstrip())
579+
if next_indent <= indent:
580+
break
581+
if re.match(r"^\s*-\s*", next_line):
582+
is_list_parent = True
583+
break
584+
if is_list_parent:
585+
parent[key] = []
586+
current_list = parent[key]
587+
current_list_indent = indent
588+
else:
589+
parent[key] = {}
590+
stack.append((parent[key], indent))
591+
current_list = None
592+
current_list_indent = -1
583593
return result
584594

585595

@@ -606,6 +616,38 @@ def parse_yaml_value(s):
606616
return s
607617

608618

619+
def get_required_input_fields(meta):
620+
"""Extract required field names from input schema (fields without ? suffix)."""
621+
input_schema = meta.get("input", {}).get("schema", {})
622+
required = []
623+
for key in input_schema:
624+
if not key.endswith("?"):
625+
required.append(key)
626+
return required
627+
628+
629+
def validate_required_inputs(meta, variables):
630+
"""Check that all required input fields are present. Exit with error if not."""
631+
log("DEBUG validate_required_inputs meta: %s" % meta)
632+
log("DEBUG validate_required_inputs variables: %s" % variables)
633+
required = get_required_input_fields(meta)
634+
log("DEBUG required fields: %s" % required)
635+
missing = []
636+
for field in required:
637+
if field not in variables or variables[field] == "":
638+
missing.append(field)
639+
if missing:
640+
print("%sError: Missing required input field(s): %s%s" %
641+
(RED, ", ".join(missing), RESET), file=sys.stderr)
642+
input_schema = meta.get("input", {}).get("schema", {})
643+
print("Expected input schema:", file=sys.stderr)
644+
for key, value in input_schema.items():
645+
opt = " (optional)" if key.endswith("?") else " (required)"
646+
clean_key = key.rstrip("?")
647+
print(" %s: %s%s" % (clean_key, value, opt), file=sys.stderr)
648+
sys.exit(1)
649+
650+
609651
def apply_overrides(meta):
610652
"""Apply RUNPROMPT_* env vars to prompt metadata (for prompt-specific overrides)."""
611653
for key in CONFIG["env"]:

tests/runtests

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ run_test "cache tests" python3 tests/test-cache.py
3434

3535
run_test "tools tests" python3 tests/test-tools.py
3636

37+
run_test "required inputs tests" python3 tests/test-required-inputs.py
38+
3739
echo ""
3840
echo "Passed: $pass, Failed: $fail"
3941

0 commit comments

Comments
 (0)