|
| 1 | +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. |
| 2 | +// |
| 3 | +// End-to-end runner for the HSTU torch.jit / torch.package artifacts produced |
| 4 | +// by generative_recommenders/dlrm_v3/inference/packager.py and exercised by |
| 5 | +// :end_to_end_test. |
| 6 | +// |
| 7 | +// CLI: |
| 8 | +// hstu_runner <sparse.pt> <dense.pt> <inputs.pt> <output.pt> |
| 9 | +// |
| 10 | +// Where: |
| 11 | +// sparse.pt ScriptModule whose forward(uih, candidates) returns |
| 12 | +// Tuple[Dict[str,Tensor], Dict[str,Tensor], |
| 13 | +// Dict[str,Tensor], Tensor, Tensor] |
| 14 | +// dense.pt ScriptModule (cuda:0, bf16) whose forward(...) returns |
| 15 | +// Tuple[Tensor, Optional[Tensor], Optional[Tensor]] |
| 16 | +// inputs.pt ScriptModule whose forward() returns |
| 17 | +// Tuple[KeyedJaggedTensor, KeyedJaggedTensor] |
| 18 | +// output.pt torch::pickle_save destination for the predictions tensor; |
| 19 | +// readable from Python as ``torch.load(output.pt)``. |
| 20 | + |
| 21 | +#include <fstream> |
| 22 | +#include <iostream> |
| 23 | +#include <stdexcept> |
| 24 | +#include <string> |
| 25 | +#include <vector> |
| 26 | + |
| 27 | +#include <torch/csrc/jit/serialization/import.h> |
| 28 | +#include <torch/script.h> |
| 29 | + |
| 30 | +namespace { |
| 31 | + |
| 32 | +torch::jit::Module loadModule(const std::string& path) { |
| 33 | + auto m = torch::jit::load(path); |
| 34 | + m.eval(); |
| 35 | + return m; |
| 36 | +} |
| 37 | + |
| 38 | +// Walk a Dict<str, Tensor> and replace every value with .to(device) (and |
| 39 | +// optionally .to(bfloat16)). C++ analog of move_sparse_output_to_device. |
| 40 | +void moveDictToDevice( |
| 41 | + c10::impl::GenericDict& d, |
| 42 | + const torch::Device& device, |
| 43 | + bool toBfloat16) { |
| 44 | + for (auto& kv : d) { |
| 45 | + auto t = kv.value().toTensor().to(device); |
| 46 | + if (toBfloat16) { |
| 47 | + t = t.to(torch::kBFloat16); |
| 48 | + } |
| 49 | + d.insert_or_assign(kv.key(), t); |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +void writePickle(const torch::Tensor& t, const std::string& path) { |
| 54 | + // torch::pickle_save returns a byte buffer in the same wire format as |
| 55 | + // ``torch.save(tensor, ...)``, so the Python side can read it with |
| 56 | + // ``torch.load(path)``. |
| 57 | + const auto data = torch::jit::pickle_save(c10::IValue(t)); |
| 58 | + std::ofstream out(path, std::ios::binary); |
| 59 | + if (!out) { |
| 60 | + throw std::runtime_error("failed to open output: " + path); |
| 61 | + } |
| 62 | + out.write(data.data(), static_cast<std::streamsize>(data.size())); |
| 63 | +} |
| 64 | + |
| 65 | +} // namespace |
| 66 | + |
| 67 | +int main(int argc, char** argv) { |
| 68 | + if (argc < 5) { |
| 69 | + std::cerr << "Usage: hstu_runner <sparse.pt> <dense.pt> <inputs.pt> " |
| 70 | + "<output.pt>\n"; |
| 71 | + return 1; |
| 72 | + } |
| 73 | + const std::string sparsePath{argv[1]}; |
| 74 | + const std::string densePath{argv[2]}; |
| 75 | + const std::string inputsPath{argv[3]}; |
| 76 | + const std::string outputPath{argv[4]}; |
| 77 | + |
| 78 | + // Log to a file next to the output so we can inspect even if |
| 79 | + // buck2 swallows stderr. |
| 80 | + const std::string logPath = outputPath + ".log"; |
| 81 | + std::ofstream logFile(logPath); |
| 82 | + auto log = [&](const std::string& msg) { |
| 83 | + logFile << msg << std::endl; |
| 84 | + logFile.flush(); |
| 85 | + std::cerr << msg << std::endl; |
| 86 | + }; |
| 87 | + |
| 88 | + try { |
| 89 | + log("[runner] step 1: loading sparse module from " + sparsePath); |
| 90 | + auto sparse = loadModule(sparsePath); |
| 91 | + |
| 92 | + log("[runner] step 2: loading dense module from " + densePath); |
| 93 | + auto dense = loadModule(densePath); |
| 94 | + |
| 95 | + log("[runner] step 3: loading inputs module from " + inputsPath); |
| 96 | + auto inputs = loadModule(inputsPath); |
| 97 | + |
| 98 | + log("[runner] step 4: running inputs.forward()"); |
| 99 | + auto inputsTuple = inputs.forward({}).toTuple(); |
| 100 | + auto uihLengths = inputsTuple->elements()[0]; |
| 101 | + auto uihValues = inputsTuple->elements()[1]; |
| 102 | + auto candidatesLengths = inputsTuple->elements()[2]; |
| 103 | + auto candidatesValues = inputsTuple->elements()[3]; |
| 104 | + log("[runner] step 4 done: got 4 input tensors"); |
| 105 | + |
| 106 | + log("[runner] step 5: running sparse.forward()"); |
| 107 | + std::vector<c10::IValue> sparseInputs{ |
| 108 | + uihLengths, uihValues, candidatesLengths, candidatesValues}; |
| 109 | + auto sparseOut = sparse.forward(sparseInputs).toTuple(); |
| 110 | + log("[runner] step 5 done: sparse forward returned " + |
| 111 | + std::to_string(sparseOut->elements().size()) + " elements"); |
| 112 | + |
| 113 | + log("[runner] step 6: unpacking sparse output dicts"); |
| 114 | + auto seqEmbValues = sparseOut->elements()[0].toGenericDict(); |
| 115 | + auto seqEmbLengths = sparseOut->elements()[1].toGenericDict(); |
| 116 | + auto payloadFeatures = sparseOut->elements()[2].toGenericDict(); |
| 117 | + auto uihSeqLengths = sparseOut->elements()[3].toTensor(); |
| 118 | + auto numCandidates = sparseOut->elements()[4].toTensor(); |
| 119 | + log("[runner] step 6 done: unpacked dicts"); |
| 120 | + |
| 121 | + log("[runner] step 7: moving dicts to cuda:0"); |
| 122 | + const auto device = torch::Device(torch::kCUDA, 0); |
| 123 | + moveDictToDevice(seqEmbValues, device, /*toBfloat16=*/true); |
| 124 | + log("[runner] step 7a: seqEmbValues moved"); |
| 125 | + moveDictToDevice(seqEmbLengths, device, /*toBfloat16=*/false); |
| 126 | + log("[runner] step 7b: seqEmbLengths moved"); |
| 127 | + moveDictToDevice(payloadFeatures, device, /*toBfloat16=*/false); |
| 128 | + log("[runner] step 7c: payloadFeatures moved"); |
| 129 | + uihSeqLengths = uihSeqLengths.to(device); |
| 130 | + numCandidates = numCandidates.to(device); |
| 131 | + log("[runner] step 7 done: all on cuda:0"); |
| 132 | + |
| 133 | + log("[runner] step 8: running dense.forward()"); |
| 134 | + std::vector<c10::IValue> denseInputs{ |
| 135 | + seqEmbValues, |
| 136 | + seqEmbLengths, |
| 137 | + payloadFeatures, |
| 138 | + uihSeqLengths, |
| 139 | + numCandidates, |
| 140 | + }; |
| 141 | + auto denseOut = dense.forward(denseInputs); |
| 142 | + log("[runner] step 8 done: dense forward returned"); |
| 143 | + |
| 144 | + auto preds = denseOut.toTensor().detach().cpu(); |
| 145 | + log("[runner] step 9: preds on cpu"); |
| 146 | + |
| 147 | + std::cout << "preds shape: " << preds.sizes() << '\n'; |
| 148 | + std::cout << "preds sum: " |
| 149 | + << preds.to(torch::kFloat32).sum().item<float>() << '\n'; |
| 150 | + |
| 151 | + writePickle(preds, outputPath); |
| 152 | + std::cout << "wrote " << outputPath << '\n'; |
| 153 | + log("[runner] step 10: done, wrote output"); |
| 154 | + return 0; |
| 155 | + } catch (const std::exception& e) { |
| 156 | + log(std::string("hstu_runner FAILED: ") + e.what()); |
| 157 | + return 1; |
| 158 | + } |
| 159 | +} |
0 commit comments