Skip to content

Commit 43169dd

Browse files
Python: Handle PrepareRecipe, Visit, and Generate (#6551)
1 parent d890c71 commit 43169dd

File tree

3 files changed

+452
-0
lines changed

3 files changed

+452
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ dist/
3131
.pytest_cache/
3232
.mypy_cache/
3333
.ruff_cache/
34+
.venv/

rewrite-python/rewrite/src/rewrite/rpc/server.py

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,10 @@ def handle_reset(params: dict) -> bool:
408408
local_objects.clear()
409409
remote_objects.clear()
410410
remote_refs.clear()
411+
_prepared_recipes.clear()
412+
_execution_contexts.clear()
413+
_recipe_accumulators.clear()
414+
_recipe_phases.clear()
411415
logger.info("Reset: cleared all cached state")
412416
return True
413417

@@ -521,6 +525,281 @@ def _serialize_value(value) -> Any:
521525
return str(value)
522526

523527

528+
# Prepared recipes storage - maps recipe IDs to recipe instances
529+
_prepared_recipes: Dict[str, Any] = {}
530+
# Execution contexts storage - maps context IDs to ExecutionContext instances
531+
_execution_contexts: Dict[str, Any] = {}
532+
# Accumulator storage for ScanningRecipes - maps recipe IDs to accumulators
533+
_recipe_accumulators: Dict[str, Any] = {}
534+
# Phase tracking for recipes - maps recipe IDs to 'scan' or 'edit'
535+
_recipe_phases: Dict[str, str] = {}
536+
537+
538+
def handle_prepare_recipe(params: dict) -> dict:
539+
"""Handle a PrepareRecipe RPC request.
540+
541+
Prepares a recipe for execution by:
542+
1. Looking up the recipe in the marketplace
543+
2. Instantiating it with the provided options
544+
3. Storing it with a unique ID
545+
4. Returning the descriptor and visitor info
546+
547+
Args:
548+
params: dict with 'id' (recipe name) and optional 'options'
549+
550+
Returns:
551+
dict with 'id', 'descriptor', 'editVisitor', and precondition info
552+
"""
553+
recipe_name = params.get('id')
554+
if recipe_name is None:
555+
raise ValueError("Recipe 'id' is required")
556+
options = params.get('options', {})
557+
558+
logger.info(f"PrepareRecipe: id={recipe_name}, options={options}")
559+
560+
marketplace = _get_marketplace()
561+
562+
# Look up the recipe - returns (RecipeDescriptor, Type[Recipe]) tuple
563+
recipe_info = marketplace.find_recipe(recipe_name)
564+
if recipe_info is None:
565+
raise ValueError(f"Recipe not found: {recipe_name}")
566+
567+
_descriptor, recipe_class = recipe_info
568+
if recipe_class is None:
569+
raise ValueError(f"Recipe class not found for: {recipe_name}")
570+
571+
# Instantiate the recipe with options
572+
if options:
573+
recipe = recipe_class(**options)
574+
else:
575+
recipe = recipe_class()
576+
577+
# Generate a unique ID for this prepared recipe
578+
prepared_id = generate_id()
579+
_prepared_recipes[prepared_id] = recipe
580+
581+
# Build the response
582+
descriptor = recipe.descriptor()
583+
584+
# Determine if this is a scanning recipe
585+
from rewrite.recipe import ScanningRecipe
586+
is_scanning = isinstance(recipe, ScanningRecipe)
587+
588+
response = {
589+
'id': prepared_id,
590+
'descriptor': _recipe_descriptor_to_dict(descriptor),
591+
'editVisitor': f'edit:{prepared_id}',
592+
'editPreconditions': _get_preconditions(recipe, 'edit'),
593+
'scanVisitor': f'scan:{prepared_id}' if is_scanning else None,
594+
'scanPreconditions': _get_preconditions(recipe, 'scan') if is_scanning else [],
595+
}
596+
597+
logger.info(f"PrepareRecipe response: {response}")
598+
return response
599+
600+
601+
def _get_preconditions(recipe, phase: str) -> List[dict]:
602+
"""Get preconditions for a recipe phase.
603+
604+
For now, we add a type precondition to ensure only Python files are visited.
605+
"""
606+
# Add precondition to only visit Python source files
607+
return [{
608+
'visitorName': 'org.openrewrite.rpc.internal.FindTreesOfType',
609+
'visitorOptions': {'type': 'org.openrewrite.python.tree.Py'}
610+
}]
611+
612+
613+
def handle_visit(params: dict) -> dict:
614+
"""Handle a Visit RPC request.
615+
616+
Applies a visitor to a tree and returns whether it was modified.
617+
618+
Args:
619+
params: dict with 'visitor', 'sourceFileType', 'treeId', 'p' (context id), 'cursor'
620+
621+
Returns:
622+
dict with 'modified' boolean
623+
"""
624+
visitor_name = params.get('visitor')
625+
source_file_type = params.get('sourceFileType')
626+
tree_id = params.get('treeId')
627+
p_id = params.get('p')
628+
cursor_ids = params.get('cursor')
629+
630+
if visitor_name is None:
631+
raise ValueError("'visitor' is required")
632+
if tree_id is None:
633+
raise ValueError("'treeId' is required")
634+
635+
logger.info(f"Visit: visitor={visitor_name}, treeId={tree_id}, p={p_id}")
636+
637+
# Get or create execution context
638+
if p_id and p_id in _execution_contexts:
639+
ctx = _execution_contexts[p_id]
640+
else:
641+
from rewrite import InMemoryExecutionContext
642+
ctx = InMemoryExecutionContext()
643+
if p_id:
644+
_execution_contexts[p_id] = ctx
645+
646+
# Get the tree - fetch from Java if we don't have it locally
647+
tree = local_objects.get(tree_id)
648+
if tree is None:
649+
tree = get_object_from_java(tree_id, source_file_type)
650+
651+
if tree is None:
652+
raise ValueError(f"Tree not found: {tree_id}")
653+
654+
# Instantiate the visitor
655+
visitor = _instantiate_visitor(visitor_name, ctx)
656+
657+
# Apply the visitor
658+
from rewrite.visitor import Cursor
659+
cursor = Cursor(None, Cursor.ROOT_VALUE)
660+
661+
before = tree
662+
after = visitor.visit(tree, ctx, cursor)
663+
664+
# Update local objects with the result and determine if modified
665+
# Use referential equality (identity comparison) to detect modifications
666+
if after is None:
667+
# Tree was deleted
668+
if tree_id in local_objects:
669+
del local_objects[tree_id]
670+
modified = True
671+
elif after is not before:
672+
# Tree object changed - update both the tree_id entry and the new id entry
673+
local_objects[tree_id] = after
674+
if str(after.id) != tree_id:
675+
local_objects[str(after.id)] = after
676+
modified = True
677+
else:
678+
modified = False
679+
680+
logger.info(f"Visit result: modified={modified}, tree_id={tree_id}, before.id={before.id}, after.id={after.id if after else None}")
681+
return {'modified': modified}
682+
683+
684+
def _instantiate_visitor(visitor_name: str, ctx):
685+
"""Instantiate a visitor from its name.
686+
687+
Visitor names can be:
688+
- 'edit:<recipe_id>' - get the editor from a prepared recipe
689+
- 'scan:<recipe_id>' - get the scanner from a prepared scanning recipe
690+
691+
For ScanningRecipes, the accumulator is persisted across calls so that
692+
data collected during the scan phase is available during the edit and
693+
generate phases.
694+
"""
695+
if visitor_name.startswith('edit:'):
696+
recipe_id = visitor_name[5:]
697+
recipe = _prepared_recipes.get(recipe_id)
698+
if recipe is None:
699+
raise ValueError(f"Prepared recipe not found: {recipe_id}")
700+
701+
# Track phase transition
702+
_recipe_phases[recipe_id] = 'edit'
703+
704+
# For ScanningRecipe, use the accumulated data from scan phase
705+
from rewrite.recipe import ScanningRecipe
706+
if isinstance(recipe, ScanningRecipe):
707+
# Get existing accumulator or create new one
708+
if recipe_id not in _recipe_accumulators:
709+
_recipe_accumulators[recipe_id] = recipe.initial_value(ctx)
710+
acc = _recipe_accumulators[recipe_id]
711+
return recipe.editor_with_data(acc)
712+
713+
return recipe.editor()
714+
715+
elif visitor_name.startswith('scan:'):
716+
recipe_id = visitor_name[5:]
717+
recipe = _prepared_recipes.get(recipe_id)
718+
if recipe is None:
719+
raise ValueError(f"Prepared recipe not found: {recipe_id}")
720+
from rewrite.recipe import ScanningRecipe
721+
if not isinstance(recipe, ScanningRecipe):
722+
raise ValueError(f"Recipe is not a scanning recipe: {recipe_id}")
723+
724+
# Check for phase transition (edit -> scan = new cycle)
725+
# If we're transitioning from edit back to scan, clear the accumulator
726+
if _recipe_phases.get(recipe_id) == 'edit':
727+
_recipe_accumulators.pop(recipe_id, None)
728+
_recipe_phases[recipe_id] = 'scan'
729+
730+
# Get existing accumulator or create new one
731+
if recipe_id not in _recipe_accumulators:
732+
_recipe_accumulators[recipe_id] = recipe.initial_value(ctx)
733+
acc = _recipe_accumulators[recipe_id]
734+
735+
return recipe.scanner(acc)
736+
737+
else:
738+
raise ValueError(f"Unknown visitor name format: {visitor_name}")
739+
740+
741+
def handle_generate(params: dict) -> dict:
742+
"""Handle a Generate RPC request.
743+
744+
Called by the recipe run cycle to generate new source files from scanning recipes.
745+
For non-scanning recipes, returns an empty list.
746+
747+
The accumulator used here is the same one that was populated during the scan phase,
748+
allowing recipes to generate files based on data collected across all source files.
749+
750+
Args:
751+
params: dict with 'id' (prepared recipe id) and 'p' (context id)
752+
753+
Returns:
754+
dict with 'ids' list and 'sourceFileTypes' list
755+
"""
756+
recipe_id = params.get('id')
757+
p_id = params.get('p')
758+
759+
if recipe_id is None:
760+
raise ValueError("'id' is required")
761+
762+
logger.info(f"Generate: id={recipe_id}, p={p_id}")
763+
764+
recipe = _prepared_recipes.get(recipe_id)
765+
if recipe is None:
766+
raise ValueError(f"Prepared recipe not found: {recipe_id}")
767+
768+
# Get or create execution context
769+
if p_id and p_id in _execution_contexts:
770+
ctx = _execution_contexts[p_id]
771+
else:
772+
from rewrite import InMemoryExecutionContext
773+
ctx = InMemoryExecutionContext()
774+
if p_id:
775+
_execution_contexts[p_id] = ctx
776+
777+
# Only scanning recipes can generate files
778+
from rewrite.recipe import ScanningRecipe
779+
if isinstance(recipe, ScanningRecipe):
780+
# Use the persisted accumulator from the scan phase, or create new one if not available
781+
if recipe_id in _recipe_accumulators:
782+
acc = _recipe_accumulators[recipe_id]
783+
else:
784+
acc = recipe.initial_value(ctx)
785+
_recipe_accumulators[recipe_id] = acc
786+
787+
generated = recipe.generate(acc, ctx)
788+
789+
ids = []
790+
source_file_types = []
791+
for sf in generated:
792+
sf_id = str(sf.id)
793+
local_objects[sf_id] = sf
794+
ids.append(sf_id)
795+
source_file_types.append(sf.__class__.__module__ + '.' + sf.__class__.__name__)
796+
797+
return {'ids': ids, 'sourceFileTypes': source_file_types}
798+
799+
# Non-scanning recipes don't generate files
800+
return {'ids': [], 'sourceFileTypes': []}
801+
802+
524803
def handle_request(method: str, params: dict) -> Any:
525804
"""Handle an RPC request."""
526805
handlers = {
@@ -531,6 +810,9 @@ def handle_request(method: str, params: dict) -> Any:
531810
'Print': handle_print,
532811
'Reset': handle_reset,
533812
'GetMarketplace': handle_get_marketplace,
813+
'PrepareRecipe': handle_prepare_recipe,
814+
'Visit': handle_visit,
815+
'Generate': handle_generate,
534816
}
535817

536818
handler = handlers.get(method)

0 commit comments

Comments
 (0)