Skip to content

Commit adaafec

Browse files
zoranzhaometa-codesync[bot]
authored andcommitted
TorchScript HSTU sparse + dense for C++ deployment
Reviewed By: LinjianMa Differential Revision: D102661041 fbshipit-source-id: 188566e15775df8e4400366010e1d3ecf2ef8797
1 parent fb64fd4 commit adaafec

16 files changed

Lines changed: 1524 additions & 19 deletions
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
//
3+
// End-to-end runner for the HSTU torch.jit / torch.package artifacts produced
4+
// by generative_recommenders/dlrm_v3/inference/packager.py and exercised by
5+
// :end_to_end_test.
6+
//
7+
// CLI:
8+
// hstu_runner <sparse.pt> <dense.pt> <inputs.pt> <output.pt>
9+
//
10+
// Where:
11+
// sparse.pt ScriptModule whose forward(uih, candidates) returns
12+
// Tuple[Dict[str,Tensor], Dict[str,Tensor],
13+
// Dict[str,Tensor], Tensor, Tensor]
14+
// dense.pt ScriptModule (cuda:0, bf16) whose forward(...) returns
15+
// Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
16+
// inputs.pt ScriptModule whose forward() returns
17+
// Tuple[KeyedJaggedTensor, KeyedJaggedTensor]
18+
// output.pt torch::pickle_save destination for the predictions tensor;
19+
// readable from Python as ``torch.load(output.pt)``.
20+
21+
#include <fstream>
22+
#include <iostream>
23+
#include <stdexcept>
24+
#include <string>
25+
#include <vector>
26+
27+
#include <torch/csrc/jit/serialization/import.h>
28+
#include <torch/script.h>
29+
30+
namespace {
31+
32+
torch::jit::Module loadModule(const std::string& path) {
33+
auto m = torch::jit::load(path);
34+
m.eval();
35+
return m;
36+
}
37+
38+
// Walk a Dict<str, Tensor> and replace every value with .to(device) (and
39+
// optionally .to(bfloat16)). C++ analog of move_sparse_output_to_device.
40+
void moveDictToDevice(
41+
c10::impl::GenericDict& d,
42+
const torch::Device& device,
43+
bool toBfloat16) {
44+
for (auto& kv : d) {
45+
auto t = kv.value().toTensor().to(device);
46+
if (toBfloat16) {
47+
t = t.to(torch::kBFloat16);
48+
}
49+
d.insert_or_assign(kv.key(), t);
50+
}
51+
}
52+
53+
void writePickle(const torch::Tensor& t, const std::string& path) {
54+
// torch::pickle_save returns a byte buffer in the same wire format as
55+
// ``torch.save(tensor, ...)``, so the Python side can read it with
56+
// ``torch.load(path)``.
57+
const auto data = torch::jit::pickle_save(c10::IValue(t));
58+
std::ofstream out(path, std::ios::binary);
59+
if (!out) {
60+
throw std::runtime_error("failed to open output: " + path);
61+
}
62+
out.write(data.data(), static_cast<std::streamsize>(data.size()));
63+
}
64+
65+
} // namespace
66+
67+
int main(int argc, char** argv) {
68+
if (argc < 5) {
69+
std::cerr << "Usage: hstu_runner <sparse.pt> <dense.pt> <inputs.pt> "
70+
"<output.pt>\n";
71+
return 1;
72+
}
73+
const std::string sparsePath{argv[1]};
74+
const std::string densePath{argv[2]};
75+
const std::string inputsPath{argv[3]};
76+
const std::string outputPath{argv[4]};
77+
78+
// Log to a file next to the output so we can inspect even if
79+
// buck2 swallows stderr.
80+
const std::string logPath = outputPath + ".log";
81+
std::ofstream logFile(logPath);
82+
auto log = [&](const std::string& msg) {
83+
logFile << msg << std::endl;
84+
logFile.flush();
85+
std::cerr << msg << std::endl;
86+
};
87+
88+
try {
89+
log("[runner] step 1: loading sparse module from " + sparsePath);
90+
auto sparse = loadModule(sparsePath);
91+
92+
log("[runner] step 2: loading dense module from " + densePath);
93+
auto dense = loadModule(densePath);
94+
95+
log("[runner] step 3: loading inputs module from " + inputsPath);
96+
auto inputs = loadModule(inputsPath);
97+
98+
log("[runner] step 4: running inputs.forward()");
99+
auto inputsTuple = inputs.forward({}).toTuple();
100+
auto uihLengths = inputsTuple->elements()[0];
101+
auto uihValues = inputsTuple->elements()[1];
102+
auto candidatesLengths = inputsTuple->elements()[2];
103+
auto candidatesValues = inputsTuple->elements()[3];
104+
log("[runner] step 4 done: got 4 input tensors");
105+
106+
log("[runner] step 5: running sparse.forward()");
107+
std::vector<c10::IValue> sparseInputs{
108+
uihLengths, uihValues, candidatesLengths, candidatesValues};
109+
auto sparseOut = sparse.forward(sparseInputs).toTuple();
110+
log("[runner] step 5 done: sparse forward returned " +
111+
std::to_string(sparseOut->elements().size()) + " elements");
112+
113+
log("[runner] step 6: unpacking sparse output dicts");
114+
auto seqEmbValues = sparseOut->elements()[0].toGenericDict();
115+
auto seqEmbLengths = sparseOut->elements()[1].toGenericDict();
116+
auto payloadFeatures = sparseOut->elements()[2].toGenericDict();
117+
auto uihSeqLengths = sparseOut->elements()[3].toTensor();
118+
auto numCandidates = sparseOut->elements()[4].toTensor();
119+
log("[runner] step 6 done: unpacked dicts");
120+
121+
log("[runner] step 7: moving dicts to cuda:0");
122+
const auto device = torch::Device(torch::kCUDA, 0);
123+
moveDictToDevice(seqEmbValues, device, /*toBfloat16=*/true);
124+
log("[runner] step 7a: seqEmbValues moved");
125+
moveDictToDevice(seqEmbLengths, device, /*toBfloat16=*/false);
126+
log("[runner] step 7b: seqEmbLengths moved");
127+
moveDictToDevice(payloadFeatures, device, /*toBfloat16=*/false);
128+
log("[runner] step 7c: payloadFeatures moved");
129+
uihSeqLengths = uihSeqLengths.to(device);
130+
numCandidates = numCandidates.to(device);
131+
log("[runner] step 7 done: all on cuda:0");
132+
133+
log("[runner] step 8: running dense.forward()");
134+
std::vector<c10::IValue> denseInputs{
135+
seqEmbValues,
136+
seqEmbLengths,
137+
payloadFeatures,
138+
uihSeqLengths,
139+
numCandidates,
140+
};
141+
auto denseOut = dense.forward(denseInputs);
142+
log("[runner] step 8 done: dense forward returned");
143+
144+
auto preds = denseOut.toTensor().detach().cpu();
145+
log("[runner] step 9: preds on cpu");
146+
147+
std::cout << "preds shape: " << preds.sizes() << '\n';
148+
std::cout << "preds sum: "
149+
<< preds.to(torch::kFloat32).sum().item<float>() << '\n';
150+
151+
writePickle(preds, outputPath);
152+
std::cout << "wrote " << outputPath << '\n';
153+
log("[runner] step 10: done, wrote output");
154+
return 0;
155+
} catch (const std::exception& e) {
156+
log(std::string("hstu_runner FAILED: ") + e.what());
157+
return 1;
158+
}
159+
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pyre-strict
16+
17+
"""
18+
TorchScript-friendly wrapper for the HSTU dense path (GPU transformer).
19+
20+
``HSTUDenseScriptModule`` accepts the *flattened* sparse-output dicts produced
21+
by :class:`HSTUSparseScriptModule`, reconstructs ``Dict[str,
22+
SequenceEmbedding]`` for the existing :meth:`DlrmHSTU.main_forward` and
23+
returns a 3-tuple of ``(preds, labels, weights)`` -- the only fields the
24+
predictor actually consumes.
25+
"""
26+
27+
from typing import Dict
28+
29+
import torch
30+
from generative_recommenders.dlrm_v3.inference.inference_modules import get_hstu_model
31+
from generative_recommenders.dlrm_v3.inference.ts_types import (
32+
SeqEmbLengths,
33+
SeqEmbValues,
34+
unflatten_seq_embeddings,
35+
)
36+
from generative_recommenders.modules.dlrm_hstu import DlrmHSTU, DlrmHSTUConfig
37+
from torchrec.modules.embedding_configs import EmbeddingConfig
38+
39+
40+
class HSTUDenseScriptModule(torch.nn.Module):
41+
"""Script-friendly dense module.
42+
43+
The wrapper owns a dense-only :class:`DlrmHSTU` (no
44+
``_embedding_collection``) and delegates to ``main_forward`` after
45+
reconstructing the ``SequenceEmbedding`` NamedTuple form.
46+
"""
47+
48+
def __init__(
49+
self,
50+
hstu_config: DlrmHSTUConfig,
51+
table_config: Dict[str, EmbeddingConfig],
52+
) -> None:
53+
super().__init__()
54+
self._hstu_model: DlrmHSTU = get_hstu_model(
55+
table_config=table_config,
56+
hstu_config=hstu_config,
57+
table_device="cpu",
58+
is_dense=True,
59+
)
60+
61+
def forward(
62+
self,
63+
seq_emb_values: SeqEmbValues,
64+
seq_emb_lengths: SeqEmbLengths,
65+
payload_features: Dict[str, torch.Tensor],
66+
uih_seq_lengths: torch.Tensor,
67+
num_candidates: torch.Tensor,
68+
) -> torch.Tensor:
69+
# TorchScript supports ``int(tensor.item())`` on a 0-d tensor.
70+
max_uih_len: int = int(uih_seq_lengths.max().item())
71+
max_num_candidates: int = int(num_candidates.max().item())
72+
73+
seq_embeddings = unflatten_seq_embeddings(seq_emb_values, seq_emb_lengths)
74+
75+
(
76+
_,
77+
_,
78+
_,
79+
mt_target_preds,
80+
_mt_target_labels,
81+
_mt_target_weights,
82+
) = self._hstu_model.main_forward(
83+
seq_embeddings=seq_embeddings,
84+
payload_features=payload_features,
85+
max_uih_len=max_uih_len,
86+
uih_seq_lengths=uih_seq_lengths,
87+
max_num_candidates=max_num_candidates,
88+
num_candidates=num_candidates,
89+
)
90+
assert mt_target_preds is not None
91+
# Return just the predictions tensor; labels/weights are unused by
92+
# the predictor at inference time and would force ``Optional[Tensor]``
93+
# in the return type, which torch.jit.trace rejects ("Only tensors,
94+
# lists, tuples of tensors, or dictionary of tensors can be output
95+
# from traced functions").
96+
return mt_target_preds

0 commit comments

Comments
 (0)