# Potential PyTorch Tracing Interface

The sequence diagrams in this document show the interactions between the following participants: the `User`, `./bindings/python`, `torch`, `./lib/ffi`, `./lib/pcg`, `./lib/runtime`, `./lib/compiler`, the `RuntimeBacking`, and `legion`. Their content is reproduced below, stage by stage.

## Example source code (adapted from the PyTorch tutorial)

```
 1 import flexflow.torch as torch
 2 import flexflow.torch.nn as nn
 3
 4 class NeuralNetwork(nn.Module):
 5   def __init__(self):
 6     super(NeuralNetwork, self).__init__()
 7     self.flatten = nn.Flatten()
 8     self.linear_relu_stack = nn.Sequential(
 9       nn.Linear(28*28, 512),
10       nn.ReLU(),
11       nn.Linear(512, 512),
12       nn.ReLU(),
13       nn.Linear(512, 10),
14     )
15
16   def forward(self, x):
17     x = self.flatten(x)
18     logits = self.linear_relu_stack(x)
19     return logits
20
21 def top_level_task():
22   model = NeuralNetwork()
23
24   dataloader = ...
25
26   if tracing_mechanism == 'symbolic_trace':
27     compiled_model = model.compile(
28       algorithm=...,
29       optimizer=...
30     )
31   elif tracing_mechanism == 'dynamo':
32     compiled_model = torch.compile(
33       model,
34       backend='flexflow',
35       options=dict(
36         algorithm=...,
37         optimizer=...,
38       )
39     )
40
41   for batch_id, (X, y) in enumerate(dataloader):
42     pred = compiled_model(X)
43     loss = loss_fn(pred, y)
44     loss.backward()
45     optimizer.step()
46     optimizer.zero_grad()
47
48     if batch_id % 100 == 0:
49       loss, current = loss.item(), (batch_id + 1) * len(X)
50       print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
```

## Tracing

The user defines the model and `top_level_task()` as above (lines 1–22). Two tracing mechanisms are supported; both produce a `torch.fx.GraphModule` that is handed to `compilation(...)` (see the Compilation section below), and both return a `compiled_model: CompiledModel` to the user.

**Case `[fx]`** (lines 26–30): the user calls

```python
compiled_model = model.compile(
  algorithm=...,
  optimizer=...
)
```

`./bindings/python` traces the model via

```python
from torch.fx import symbolic_trace

symbolic_traced = symbolic_trace(model)
```

which invokes `model.forward(<tracing tensor>)` in `torch` and yields `symbolic_traced: torch.fx.GraphModule`. The bindings then run `compiled_model = compilation(symbolic_traced)` (see below).

**Case `[dynamo]`** (lines 31–39): the user calls

```python
compiled_model = torch.compile(
  model,
  backend='flexflow',
  options=dict(
    algorithm=...,
    optimizer=...,
  )
)
```

`torch` calls back into the registered FlexFlow backend

```python
def flexflow_compiler(gm: torch.fx.GraphModule,
                      example_inputs: List[torch.Tensor]) -> CompiledModel
```

which runs `compiled_model = compilation(gm)` (see below).
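Both mechanisms ultimately hand a `torch.fx.GraphModule` to the FlexFlow side. For reference, the sketch below shows what the two mechanisms look like in plain PyTorch today, without any FlexFlow involvement; `TinyModel` and `my_backend` are illustrative stand-ins (the latter plays the role that `flexflow_compiler` plays above) and are not part of the proposed interface.

```python
from typing import Callable, List

import torch
import torch.fx
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.linear(self.flatten(x))


model = TinyModel()

# Mechanism 1: torch.fx symbolic tracing produces a GraphModule directly.
gm = torch.fx.symbolic_trace(model)
print(gm.graph)  # placeholder -> flatten -> linear -> output

# Mechanism 2: torch.compile with a custom backend. Dynamo hands the backend a
# GraphModule plus example inputs; the backend returns the callable that will
# actually be executed (here, just the traced forward, unchanged).
def my_backend(gm: torch.fx.GraphModule,
               example_inputs: List[torch.Tensor]) -> Callable:
    return gm.forward


compiled = torch.compile(model, backend=my_backend)
compiled(torch.randn(4, 28, 28))  # first call triggers tracing and my_backend
```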
## Compilation

`compilation(g: torch.fx.GraphModule) -> CompiledModel` is implemented in `./bindings/python` in two steps: `from_fx` converts the traced graph into a FlexFlow computation graph (`ff_model = flexflow.torch.from_fx(symbolic_traced)`), and `optimization` compiles that graph into a `CompiledModel`.

### `flexflow.torch.from_fx(g: torch.fx.GraphModule) -> ComputationGraph`

`./bindings/python` walks the `torch.fx.GraphModule` and drives the builder API in `./lib/ffi`, which forwards each call to `ComputationGraphBuilder` in `./lib/pcg`. The graph is created first:

```c
flexflow_computation_graph_create(...);

typedef struct {
  ComputationGraphBuilder *ptr;
} flexflow_computation_graph_builder_t;
```

Then one FFI call is issued per traced operator, each returning a tensor handle:

| `./lib/ffi` call | `./lib/pcg` call | returns |
|---|---|---|
| `flexflow_computation_graph_add_op_flat(...)` | `ComputationGraphBuilder::flat(...)` | `Tensor` / `flexflow_tensor_t` |
| `flexflow_computation_graph_add_op_dense(...)` | `ComputationGraphBuilder::dense(...)` | `Tensor` / `flexflow_tensor_t` |
| `flexflow_computation_graph_add_op_relu(...)` | `ComputationGraphBuilder::relu(...)` | `Tensor` / `flexflow_tensor_t` |
| ..., etc. | | |

```c
struct Tensor { ... };

typedef struct {
  Tensor *ptr;
} flexflow_tensor_t;
```

The result handed back to `./bindings/python` is `comp_graph: ComputationGraph`.
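To make the `from_fx` step concrete, here is a minimal sketch of how a traced graph can be walked and lowered to builder calls. The `Builder` class, `from_fx_sketch`, and `input_handle` are hypothetical stand-ins for the bindings and the FFI builder above, not the actual API.

```python
import torch
import torch.fx
import torch.nn as nn


class Builder:
    """Hypothetical stand-in for the ComputationGraphBuilder bindings above."""

    def flat(self, x):
        return ("flat", x)

    def dense(self, x, out_features):
        return ("dense", x, out_features)

    def relu(self, x):
        return ("relu", x)


def from_fx_sketch(gm: torch.fx.GraphModule, builder: Builder, input_handle):
    """Walk the traced graph in topological order, mapping each fx node to a
    builder call and threading tensor handles through an environment dict."""
    env = {}
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            env[node.name] = input_handle
        elif node.op == "call_module":
            submodule = gm.get_submodule(node.target)
            arg = env[node.args[0].name]
            if isinstance(submodule, nn.Flatten):
                env[node.name] = builder.flat(arg)
            elif isinstance(submodule, nn.Linear):
                env[node.name] = builder.dense(arg, submodule.out_features)
            elif isinstance(submodule, nn.ReLU):
                env[node.name] = builder.relu(arg)
            else:
                raise NotImplementedError(f"unhandled module {type(submodule)}")
        elif node.op == "output":
            return env[node.args[0].name]


# Walking the example stack visits flat, dense, relu, dense, relu, dense in order.
stack = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 512), nn.ReLU(),
                      nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
print(from_fx_sketch(torch.fx.symbolic_trace(stack), Builder(), input_handle="X"))
```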
### `optimization(comp_graph: ComputationGraph) -> CompiledModel`

`./bindings/python` passes `comp_graph: ComputationGraph` to `./lib/ffi`:

```c
flexflow_error_t flexflow_computation_graph_compile(
    flexflow_computation_graph_t,
    flexflow_optimizer_t,
    flexflow_compilation_algorithm_t,
    flexflow_model_compilation_result_t *out);
```

which calls into `./lib/compiler`:

```cpp
ModelCompilationResult optimize(ComputationGraph const &,
                                AlgorithmConfig const &);
```

backed by the search interface

```cpp
SearchResult optimize(ComputationGraph const &,
                      MachineSpecification const &,
                      CostEstimator const &,
                      AlgorithmConfig const &);

struct SearchResult {
  ParallelComputationGraph pcg;
  TensorMapping tensor_mapping;
  SearchSolution solution;
  CostValues cost_values;
};

struct ModelCompilationResult {
  ComputationGraph computation_graph;
  ParallelComputationGraph pcg;
  TensorMapping tensor_mapping;
};
```

The result is wrapped for the FFI as

```c
typedef struct {
  ModelCompilationResult *ptr;
} flexflow_model_compilation_result_t;
```

and surfaces to the user as `compiled_model: CompiledModel`.

## Serialization

```python
model_json = compiled_model.as_json()

with open('compiled.json', 'w') as f:
  compiled_model.dump(f)
```

## Deserialization

Not yet specified.

## Training starts

Training runs the loop at lines 41–50:

```python
for batch_id, (X, y) in enumerate(dataloader):
  pred = compiled_model(X)
  loss = loss_fn(pred, y)
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()

  if batch_id % 100 == 0:
    loss, current = loss.item(), (batch_id + 1) * len(X)
    print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
```

At any point the user may read tensor elements (`get_tensor`) or write them (`set_tensor`); both go through `./bindings/python` and `./lib/ffi` to the runtime. Each iteration of the training loop then runs the three phases below; reading or writing tensor elements is optional between phases.

### Forward (lines 42–43)

```python
pred = compiled_model(X)
loss = loss_fn(pred, y)
```

On the first iteration, `./lib/ffi` creates a training instance:

```c
flexflow_error_t flexflow_start_training(
    flexflow_model_compilation_result_t,
    flexflow_model_training_instance_t *out);

typedef struct {
  ModelTrainingInstance *ptr;
} flexflow_model_training_instance_t;
```

and `./bindings/python` stores it on the model (`model.training_instance = ...`). `pred = compiled_model(X)` returns a `pred: TensorFuture`. After `loss = loss_fn(pred, y)`, the forward pass runs through the FFI:

```c
flexflow_error_t flexflow_model_training_instance_forward(
    flexflow_model_training_instance_t);
```

which calls `forward(ModelTrainingInstance const &);` in `./lib/runtime`. The runtime loops over `execute(OpTaskInvocation const &);`, launching a Legion index task per operator:

```cpp
IndexLauncher launcher;
...
runtime->execute_index_space(ctx, launcher);
```

and waiting on the returned accessor:

```cpp
TaskReturnAccessor ret_acc = ...;
ret_acc.wait();
```

The loss is handed back through the FFI as a `flexflow_tensor_t` and surfaces to the user as `loss: LossTensor`.

### Backward (line 44)

```python
loss.backward()
```

`./bindings/python` calls

```c
flexflow_error_t flexflow_model_training_instance_backward(
    flexflow_model_training_instance_t);
```

which calls `backward(ModelTrainingInstance const &);` in `./lib/runtime`. As in the forward pass, the runtime loops over `execute(OpTaskInvocation const &);`, building an `IndexLauncher`, launching it via `runtime->execute_index_space(ctx, launcher);`, and waiting on the resulting `TaskReturnAccessor`.

### Update (lines 45–46)

```python
optimizer.step()
optimizer.zero_grad()
```

`./bindings/python` calls

```c
flexflow_error_t flexflow_model_training_instance_update(
    flexflow_model_training_instance_t);
```

which calls `update(ModelTrainingInstance const &);` in `./lib/runtime`. The runtime loops over `execute(IndexTaskInvocation const &);`, launching index tasks and waiting on the resulting `TaskReturnAccessor`.

## Training stops

When `compiled_model` goes out of scope, the training instance is torn down through the FFI:

```c
flexflow_error_t flexflow_stop_training(
    flexflow_model_training_instance_t);
```
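Since `flexflow_start_training` is issued lazily on the first iteration and `flexflow_stop_training` on scope exit, the Python bindings could also expose the pairing explicitly. The sketch below is a hypothetical illustration only: `start_training`, `stop_training`, and `training_instance` are stand-in names, not part of the proposed FFI or bindings.

```python
from contextlib import contextmanager


# Hypothetical stand-ins for Python wrappers over flexflow_start_training /
# flexflow_stop_training; real bindings would call into ./lib/ffi instead.
def start_training(compiled_model):
    print("flexflow_start_training(...)")
    return object()  # stands in for a flexflow_model_training_instance_t handle


def stop_training(instance):
    print("flexflow_stop_training(...)")


@contextmanager
def training_instance(compiled_model):
    """Pairs start_training with stop_training so teardown also runs on
    exceptions, mirroring the '<compiled_model goes out of scope>' teardown."""
    instance = start_training(compiled_model)
    try:
        yield instance
    finally:
        stop_training(instance)


# Usage sketch:
#
#   with training_instance(compiled_model) as instance:
#       for batch_id, (X, y) in enumerate(dataloader):
#           ...  # forward / backward / update as in lines 41-50
```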