# Potential PyTorch Tracing Interface

The sequence diagrams below involve the following participants: the user code, `./bindings/python`, `torch`, `./lib/ffi`, `./lib/pcg`, `./lib/runtime`, `./lib/compiler`, `RuntimeBacking`, and `legion`.

Example source code (adapted from the PyTorch tutorial):

```python
 1  import flexflow.torch as torch
 2  import flexflow.torch.nn as nn
 3
 4  class NeuralNetwork(nn.Module):
 5      def __init__(self):
 6          super(NeuralNetwork, self).__init__()
 7          self.flatten = nn.Flatten()
 8          self.linear_relu_stack = nn.Sequential(
 9              nn.Linear(28*28, 512),
10              nn.ReLU(),
11              nn.Linear(512, 512),
12              nn.ReLU(),
13              nn.Linear(512, 10),
14          )
15
16      def forward(self, x):
17          x = self.flatten(x)
18          logits = self.linear_relu_stack(x)
19          return logits
20
21  def top_level_task():
22      model = NeuralNetwork()
23
24      dataloader = ...
25
26      if tracing_mechanism == 'symbolic_trace':
27          compiled_model = model.compile(
28              algorithm=...,
29              optimizer=...
30          )
31      elif tracing_mechanism == 'dynamo':
32          compiled_model = torch.compile(
33              model,
34              backend='flexflow',
35              options=dict(
36                  algorithm=...,
37                  optimizer=...,
38              )
39          )
40
41      for batch_id, (X, y) in enumerate(dataloader):
42          pred = compiled_model(X)
43          loss = loss_fn(pred, y)
44          loss.backward()
45          optimizer.step()
46          optimizer.zero_grad()
47
48          if batch_id % 100 == 0:
49              loss, current = loss.item(), (batch_id + 1) * len(X)
50              print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
```

## Tracing

Tracing can be triggered in one of two ways (the `alt` branches in the diagram). Both paths produce a `torch.fx.GraphModule`, hand it to the `compilation()` step described below, and return a `CompiledModel` to the user.

**`symbolic_trace` path** (lines 26–30): `./bindings/python` handles `model.compile(algorithm=..., optimizer=...)` by calling

```python
from torch.fx import symbolic_trace

symbolic_traced = symbolic_trace(model)   # runs model.forward(<tracing tensor>)
```

which yields `symbolic_traced: torch.fx.GraphModule`. The bindings then call `compiled_model = compilation(symbolic_traced)` (see below).

**dynamo path** (lines 31–39): `torch.compile(model, backend='flexflow', options=dict(algorithm=..., optimizer=...))` causes `torch` to invoke the registered backend

```python
def flexflow_compiler(gm: torch.fx.GraphModule,
                      example_inputs: List[torch.Tensor]) -> CompiledModel
```

which likewise calls `compiled_model = compilation(gm)` (see below) and returns the resulting `CompiledModel`.
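For reference, the snippet below shows what the `symbolic_trace` path would receive for the example network using stock `torch.fx` (no FlexFlow involvement); it is only an illustration of the `GraphModule` passed into `compilation()`, not part of the proposed API.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        return self.linear_relu_stack(x)

# Tracing runs forward() with a proxy tensor and records the operations.
symbolic_traced: torch.fx.GraphModule = symbolic_trace(NeuralNetwork())
print(symbolic_traced.graph)   # call_module nodes for flatten and linear_relu_stack.*
```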
## Compilation

```python
def compilation(g: torch.fx.GraphModule) -> CompiledModel
```

Given the traced `g: torch.fx.GraphModule`, `./bindings/python` first lowers it to a FlexFlow computation graph:

```python
ff_model = flexflow.torch.from_fx(symbolic_traced)
```

where `flexflow.torch.from_fx` has the signature

```python
def from_fx(g: torch.fx.GraphModule) -> ComputationGraph
```

`from_fx` walks the graph and emits one FFI call per node; `./lib/ffi` forwards each call to `./lib/pcg`'s `ComputationGraphBuilder`:

- `flexflow_computation_graph_create(...)` constructs the builder, which the FFI wraps as

  ```c
  typedef struct {
    ComputationGraphBuilder *ptr;
  } flexflow_computation_graph_builder_t;
  ```

- `flexflow_computation_graph_add_op_flat(...)` calls `ComputationGraphBuilder::flat(...)`; the resulting `Tensor` is wrapped as

  ```c
  struct Tensor { ... };

  typedef struct {
    Tensor *ptr;
  } flexflow_tensor_t;
  ```

- `flexflow_computation_graph_add_op_dense(...)` calls `ComputationGraphBuilder::dense(...)`, again returning a `Tensor` / `flexflow_tensor_t`.
- `flexflow_computation_graph_add_op_relu(...)` calls `ComputationGraphBuilder::relu(...)`, again returning a `Tensor` / `flexflow_tensor_t`.
- ..., etc.

The result is `comp_graph: ComputationGraph`.

## Optimization

```python
def optimization(comp_graph: ComputationGraph) -> CompiledModel
```

The `comp_graph: ComputationGraph` is handed through the FFI to `./lib/compiler`:

```c
flexflow_error_t flexflow_computation_graph_compile(
    flexflow_computation_graph_t,
    flexflow_optimizer_t,
    flexflow_compilation_algorithm_t,
    flexflow_model_compilation_result_t *out);
```

which in turn calls

```cpp
ModelCompilationResult optimize(ComputationGraph const &,
                                AlgorithmConfig const &);

SearchResult optimize(ComputationGraph const &,
                      MachineSpecification const &,
                      CostEstimator const &,
                      AlgorithmConfig const &);

struct SearchResult {
  ParallelComputationGraph pcg;
  TensorMapping tensor_mapping;
  SearchSolution solution;
  CostValues cost_values;
};

struct ModelCompilationResult {
  ComputationGraph computation_graph;
  ParallelComputationGraph pcg;
  TensorMapping tensor_mapping;
};
```

The result is returned through the FFI as

```c
typedef struct {
  ModelCompilationResult *ptr;
} flexflow_model_compilation_result_t;
```

and surfaces in Python as `compiled_model: CompiledModel`.

## Serialization / deserialization

The `CompiledModel` can be serialized, e.g.:

```python
model_json = compiled_model.as_json()

with open('compiled.json', 'w') as f:
    compiled_model.dump(f)
```

The diagram also names a corresponding deserialization step, which is not detailed.

## Training

Training then runs the loop at lines 41–50 of the example. Within the loop the user may optionally read tensor elements (`get_tensor`) and write tensor elements (`set_tensor`).

### Forward (lines 42–43)

`pred = compiled_model(X)` drives the forward pass:

1. On the first iteration only, a training instance is created:

   ```c
   flexflow_error_t flexflow_start_training(
       flexflow_model_compilation_result_t,
       flexflow_model_training_instance_t *out);

   typedef struct {
     ModelTrainingInstance *ptr;
   } flexflow_model_training_instance_t;
   ```

   and stored on the model: `model.training_instance = ...`.

2. The call returns `pred: TensorFuture`, and the user computes `loss = loss_fn(pred, y)`.

3. The forward pass goes through the FFI:

   ```c
   flexflow_error_t flexflow_model_training_instance_forward(
       flexflow_model_training_instance_t);
   ```

   which calls `forward(ModelTrainingInstance const &)` in `./lib/runtime`. The runtime then, in a loop, calls `execute(OpTaskInvocation const &)`, each of which builds an `IndexLauncher` and submits it to Legion:

   ```cpp
   IndexLauncher launcher;
   ...
   runtime->execute_index_space(ctx, launcher);
   ```

   The result is wrapped in a `TaskReturnAccessor`:

   ```cpp
   TaskReturnAccessor ret_acc = ...;
   ret_acc.wait();
   ```

4. The loss comes back through `flexflow_tensor_t` as `loss: LossTensor`.

Between the forward and backward passes the user may again optionally read or write tensor elements.
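Putting the forward-pass pieces together, a rough sketch of how the Python-side `CompiledModel` could behave is shown below. This is hypothetical glue, not the proposed FFI itself: the C entry points (`flexflow_start_training`, `flexflow_model_training_instance_forward`) are represented as plain callables passed in by the bindings.

```python
from typing import Any, Callable, Optional


class CompiledModel:
    """Sketch only: wraps the compilation result, lazily creates the
    ModelTrainingInstance on the first call, and launches one forward
    pass per call, returning a future rather than a concrete tensor."""

    def __init__(self,
                 compilation_result: Any,
                 start_training: Callable[[Any], Any],
                 run_forward: Callable[[Any], Any]) -> None:
        self.compilation_result = compilation_result
        self.training_instance: Optional[Any] = None
        self._start_training = start_training   # stand-in for flexflow_start_training
        self._run_forward = run_forward          # stand-in for flexflow_model_training_instance_forward

    def __call__(self, batch: Any) -> Any:
        # opt [if first iteration]: create the training instance once
        if self.training_instance is None:
            self.training_instance = self._start_training(self.compilation_result)
        # Binding `batch` to the input tensor (e.g. via set_tensor) is elided here.
        # Launch the forward pass; the caller receives a TensorFuture-like handle.
        return self._run_forward(self.training_instance)
```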
### Backward (line 44)

`loss.backward()` goes through the FFI:

```c
flexflow_error_t flexflow_model_training_instance_backward(
    flexflow_model_training_instance_t);
```

which calls `backward(ModelTrainingInstance const &)` in `./lib/runtime`. As in the forward pass, the runtime loops over `execute(OpTaskInvocation const &)` calls, each building an `IndexLauncher` and submitting it to Legion:

```cpp
IndexLauncher launcher;
...
runtime->execute_index_space(ctx, launcher);

TaskReturnAccessor ret_acc = ...;
ret_acc.wait();
```

Between the backward pass and the update the user may again optionally read or write tensor elements.

### Update (lines 45–46)

`optimizer.step()` and `optimizer.zero_grad()` go through the FFI:

```c
flexflow_error_t flexflow_model_training_instance_update(
    flexflow_model_training_instance_t);
```

which calls `update(ModelTrainingInstance const &)` in `./lib/runtime`. The runtime loops over `execute(IndexTaskInvocation const &)` calls, each building an `IndexLauncher` and submitting it to Legion:

```cpp
IndexLauncher launcher;
...
runtime->execute_index_space(ctx, launcher);

TaskReturnAccessor ret_acc = ...;
ret_acc.wait();
```

### Training stops

When `compiled_model` goes out of scope, training is stopped:

```c
flexflow_error_t flexflow_stop_training(
    flexflow_model_training_instance_t);
```
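Since the diagram ties `flexflow_stop_training` to `compiled_model` going out of scope, one way the Python bindings could arrange that is via `weakref.finalize`. The sketch below is illustrative only; `attach_training_instance` and its parameters are assumptions, not part of the proposed API.

```python
import weakref


def attach_training_instance(compiled_model, training_instance, stop_training):
    """Arrange for stop_training(training_instance) to run once compiled_model
    is garbage-collected. weakref.finalize also fires at interpreter shutdown,
    which a bare __del__ does not guarantee."""
    weakref.finalize(compiled_model, stop_training, training_instance)
```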