-
Notifications
You must be signed in to change notification settings - Fork 136
Description
Feature Request Type
Usability
Have you searched existing issues?
Yes
Is your feature request related to a problem?
Copyright 2023 Ant Group Co., Ltd.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Start nodes.
> bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/conf/2pc.json up
Run this example script.
> bazel run -c opt //examples/python/ml/flax_bert -- --config `pwd`/examples/python/conf/2pc.json
import argparse
import json
import os
import time
from contextlib import contextmanager
import flax.linen as fnn
import jax
import jax.nn as jnn
import jax.numpy as jnp
from jax import Array
from typing import Optional, Tuple, Union
from datasets import load_dataset
from transformers import AutoTokenizer, FlaxBertForSequenceClassification
import spu.intrinsic as intrinsic
import spu.spu_pb2 as spu_pb2
import spu.utils.distributed as ppd
# Compiler options passed to the SPU device when the model is evaluated.
copts = spu_pb2.CompilerOptions()
# Enable x / broadcast(y) -> x * broadcast(1/y), which accelerates the softmax.
copts.enable_optimize_denominator_with_broadcast = True

# Command-line interface: the only option is the path of the cluster config.
parser = argparse.ArgumentParser(description='distributed driver.')
parser.add_argument(
    "-c",
    "--config",
    default="examples/python/ml/flax_bert/2pc.json",
)
args = parser.parse_args()

# Read the JSON config, then initialise the distributed runtime from its
# "nodes" and "devices" entries.
with open(args.config, 'r') as config_file:
    conf = json.loads(config_file.read())

ppd.init(conf["nodes"], conf["devices"])
def _gelu(x: Array):
order = 2
left, right = -1.8, 0.5
zero_mask = x < left
right_mask = x > right
# poly_mask = jnp.logical_not(zero_mask) & jnp.logical_not(right_mask)
poly_mask = zero_mask ^ right_mask ^ True
if order == 2:
coeff = jnp.array([0.17025471961795552, 0.32421075145645123, -0.006946478426716862])
# tmp1 = x * coeff[0] + coeff[1]
# computed_v = tmp1 * x + coeff[2]
x2 = jnp.square(x)
computed_v = coeff[0] * x2 + coeff[1] * x + coeff[2]
else:
raise ValueError("error")
x = right_mask * x + poly_mask * computed_v # + zero_mask * 0
return x
def ours_fake_exp(x: Array) -> Array:
    """Evaluate a fixed cubic polynomial in ``x``, used in place of ``exp``.

    The polynomial is evaluated in Horner form, i.e.
    ``((c0 * x + c1) * x + c2) * x + c3``.

    Args:
        x: Input array.

    Returns:
        The polynomial evaluated element-wise on ``x``.
    """
    coeff = jnp.array(
        [
            0.0291786417536755,
            0.26332572942924737,
            0.8318065569430768,
            0.9708981428645548,
        ]
    )
    # Horner's scheme: start from the two leading coefficients, then fold in
    # the remaining ones one multiplication at a time.
    acc = x * coeff[0] + coeff[1]
    for idx in (2, 3):
        acc = x * acc + coeff[idx]
    return acc
def _softmax(
    x: Array,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
) -> Array:
    """Softmax built on the polynomial ``ours_fake_exp`` instead of ``exp``.

    The input is shifted so that its maximum along ``axis`` is 0, the
    polynomial is applied, entries whose shifted value is not above -3.9 are
    forced to 0, and the result is normalised by its sum along ``axis``.

    Args:
        x: Input array.
        axis: Axis or axes to normalise over. The maximum used for the shift
            is taken over the same axes as the normalising sum; previously the
            shift was always taken over the last axis, which disagreed with
            the sum whenever ``axis`` was not -1.

    Returns:
        Array of the same shape as ``x`` whose entries sum to 1 along ``axis``.
    """
    threshold = -3.9
    # Shift over the normalisation axes so that the largest entry of every
    # normalised slice is exactly 0 and therefore always survives the mask.
    x = x - jnp.max(x, axis=axis, keepdims=True)
    # exp on large negative is clipped to zero
    compute_mask = x > threshold
    computed_v = ours_fake_exp(x)
    x = computed_v * compute_mask
    divisor = jnp.sum(x, axis, keepdims=True)
    return x / divisor
@contextmanager
def hijack(enabled=True):
    """Temporarily replace the gelu/softmax functions of ``jnn`` and ``fnn``.

    Inside the ``with`` body, ``jnn.gelu`` / ``fnn.gelu`` point to ``_gelu``
    and ``jnn.softmax`` / ``fnn.softmax`` point to ``_softmax``. The original
    attributes are put back when the body finishes, including when it raises.

    Args:
        enabled: When false, nothing is patched and the body runs unchanged.
    """
    if not enabled:
        yield
        return

    # Remember the original functions so they can be restored afterwards.
    jnn_gelu = jnn.gelu
    fnn_gelu = fnn.gelu
    jnn_sm = jnn.softmax
    fnn_sm = fnn.softmax

    try:
        # hijack some target functions
        jnn.gelu = _gelu
        fnn.gelu = _gelu
        jnn.softmax = _softmax
        fnn.softmax = _softmax
        yield
    finally:
        # recover back, even if the with-body raised
        jnn.gelu = jnn_gelu
        fnn.gelu = fnn_gelu
        jnn.softmax = jnn_sm
        fnn.softmax = fnn_sm
def run_on_cpu(model, input_ids, attention_masks, labels):
    """Call ``model`` directly and print the elapsed time and the logits.

    Args:
        model: Object exposing ``params`` and callable as
            ``model(input_ids, attention_masks, params=...)``; the first
            element of its result is taken as the logits.
        input_ids: First positional argument forwarded to ``model``.
        attention_masks: Second positional argument forwarded to ``model``.
        labels: Unused.
    """
    print("Running on CPU ...")
    weights = model.params

    def forward(params, input_ids, attention_masks):
        # Only the first element of the model output is reported.
        return model(input_ids, attention_masks, params=params)[0]

    t0 = time.time()
    logits = forward(weights, input_ids, attention_masks)
    elapsed = time.time() - t0
    print(f"CPU runtime: {elapsed}s\noutput logits: {logits}")
def run_on_spu(model, input_ids, attention_masks, labels):
    """Evaluate ``model`` on the "SPU" device and print the time and logits.

    ``input_ids`` and ``attention_masks`` are passed through device "P1" and
    the model parameters through device "P2"; the model is then evaluated on
    device "SPU" with gelu/softmax replaced via ``hijack``. The printed
    runtime covers the SPU call only, not the final ``ppd.get``.

    Args:
        model: Object exposing ``params`` and callable as
            ``model(input_ids, attention_masks, params=...)``.
        input_ids: First positional argument forwarded to ``model``.
        attention_masks: Second positional argument forwarded to ``model``.
        labels: Unused.
    """
    print("Running on SPU ...")

    # The inner function keeps its original name: it is handed to
    # ppd.device("SPU"), which may use the name as a label.
    def eval(params, input_ids, attention_masks):
        with hijack(enabled=True):
            return model(input_ids, attention_masks, params=params)[0]

    # Identity functions place each value on its device.
    placed_ids = ppd.device("P1")(lambda x: x)(input_ids)
    placed_masks = ppd.device("P1")(lambda x: x)(attention_masks)
    placed_params = ppd.device("P2")(lambda x: x)(model.params)

    t0 = time.time()
    result_ref = ppd.device("SPU")(eval, copts=copts)(
        placed_params, placed_ids, placed_masks
    )
    elapsed = time.time() - t0

    logits_spu = ppd.get(result_ref)
    print(f"SPU runtime: {elapsed}s\noutput logits: {logits_spu}")
def main(
    tokenizer_func,
    model_func,
    checkpoint,
    data_file="/home/xiaoqi/pycharm/spu-nimbus/datasets/cola/test-00000-of-00001.parquet",
):
    """Run one sentence of the test split through the CPU and SPU paths.

    Args:
        tokenizer_func: Object whose ``from_pretrained(checkpoint)`` returns
            the tokenizer.
        model_func: Object whose ``from_pretrained(checkpoint)`` returns the
            model.
        checkpoint: Value passed to both ``from_pretrained`` calls.
        data_file: Parquet file loaded as the "test" split. Each record must
            provide the "sentence" and "label" keys. Defaults to the path that
            used to be hard-coded here.
    """
    dataset = load_dataset(
        "parquet",
        data_files={"test": data_file},
        split="test",
    )
    model = model_func.from_pretrained(checkpoint)
    tokenizer = tokenizer_func.from_pretrained(checkpoint)

    # just test one sentence; nothing to do when the split is empty
    sample = next(iter(dataset), None)
    if sample is None:
        return

    features, labels = sample["sentence"], sample["label"]
    # Tokenize once and read both tensors from the same encoding rather than
    # running the tokenizer twice on the same sentence.
    encoded = tokenizer(features, return_tensors="jax")
    input_ids = encoded["input_ids"]
    attention_masks = encoded["attention_mask"]

    run_on_cpu(model, input_ids, attention_masks, labels)
    run_on_spu(model, input_ids, attention_masks, labels)
if __name__ == "__main__":
    checkpoint = "bert-base-cased"
    # main() calls ``from_pretrained`` on both arguments itself, so pass the
    # loader classes. Passing already-loaded instances made every checkpoint
    # load twice, and the first load was discarded.
    tokenizer = AutoTokenizer
    model = FlaxBertForSequenceClassification
    # Alternative kept from the original: load from a local snapshot.
    # tokenizer = BertTokenizerFast
    # model = FlaxBertForSequenceClassification
    # checkpoint = "/home/xiaoqi/pycharm/spu-nimbus/models--bert-base-cased/snapshots/cd5ef92a9fb2f889e972770a36d4ed042daf221e"
    main(tokenizer, model, checkpoint)

# Reporter's note (translated from the issue text): this is the test code, and
# running it produces the following output:
# [2025-10-13 09:03:16.056] [info] [cheetah_dot.cc:394] empty@1x28996x768 => 1x1x1 Compute 268540.992 ms
[2025-10-13 09:03:18.425] [warning] [matmat_prot.cc:136] skiped 0, total 22268928, ratio 0
[2025-10-13 09:03:18.493] [info] [cheetah_dot.cc:481] 1@1x28996x768 => 1x1x1 Recv 5114.813 MiB, Response 0.243 MiB Pack 64.315 ms
[2025-10-13 09:03:18,497] [Process-2] Traceback (most recent call last):
File "/data1/bazel_cache/bazel/_bazel_xiaoqi/test0/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed.py", line 324, in Run
ret_objs = fn(self, *args, **kwargs)
File "/data1/bazel_cache/bazel/_bazel_xiaoqi/test0/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed.py", line 581, in builtin_spu_run
rt.run(spu_exec)
File "/data1/bazel_cache/bazel/_bazel_xiaoqi/test0/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/api.py", line 45, in run
return self._vm.Run(executable.SerializeToString())
RuntimeError: what:
[external/yacl/yacl/link/transport/channel.cc:351] Get data timeout, key=root:P2P-4:0->1
Stacktrace:
#0 yacl::link::Context::RecvInternal()+0x7ff740b0e2fb
#1 yacl::link::Context::Recv()+0x7ff740b108c6
#2 spu::mpc::cheetah::CheetahDot::Impl::doDotOLESenderRecvStep()+0x7ff73ff21c1e
#3 spu::mpc::cheetah::CheetahDot::Impl::doDotOLE()+0x7ff73ff28253
#4 spu::mpc::cheetah::CheetahDot::Impl::DotOLE()+0x7ff73ff28b21
#5 spu::mpc::cheetah::CheetahDot::DotOLE()+0x7ff73ff28c7a
#6 spu::mpc::cheetah::MatMulAA::proc()+0x7ff73ff01481
#7 spu::mpc::MatmulKernel::evaluate()+0x7ff740713040
#8 spu::dynDispatch<>()+0x7ff740762408
#9 spu::mpc::mmul_aa()+0x7ff740779e47
#10 spu::mpc::mmul_ss()+0x7ff74076b519
#11 spu::kernel::hal::_mmul_ss()+0x7ff7407549bd
#12 spu::kernel::hal::_mmul_impl()+0x7ff74073f0c2
#13 spu::kernel::hal::_mmul()+0x7ff740746537
#14 spu::kernel::hal::mixed_mmul()+0x7ff7406bb35e
#15 spu::kernel::hal::matmul()+0x7ff7406bf14d
stacktrace:
#0 yacl::link::Context::RecvInternal()+0x7ff740b0e2fb
#1 yacl::link::Context::Recv()+0x7ff740b108c6
#2 spu::mpc::cheetah::CheetahDot::Impl::doDotOLESenderRecvStep()+0x7ff73ff21c1e
#3 spu::mpc::cheetah::CheetahDot::Impl::doDotOLE()+0x7ff73ff28253
#4 spu::mpc::cheetah::CheetahDot::Impl::DotOLE()+0x7ff73ff28b21
#5 spu::mpc::cheetah::CheetahDot::DotOLE()+0x7ff73ff28c7a
#6 spu::mpc::cheetah::MatMulAA::proc()+0x7ff73ff01481
#7 spu::mpc::MatmulKernel::evaluate()+0x7ff740713040
#8 spu::dynDispatch<>()+0x7ff740762408
#9 spu::mpc::mmul_aa()+0x7ff740779e47
#10 spu::mpc::mmul_ss()+0x7ff74076b519
#11 spu::kernel::hal::_mmul_ss()+0x7ff7407549bd
#12 spu::kernel::hal::_mmul_impl()+0x7ff74073f0c2
#13 spu::kernel::hal::_mmul()+0x7ff740746537
#14 spu::kernel::hal::mixed_mmul()+0x7ff7406bb35e
#15 spu::kernel::hal::matmul()+0x7ff7406bf14d
Describe features you want to add to SPU
How to avoid connection timeouts on the SPU
Describe features you want to add to SPU
How to avoid connection timeouts on the SPU