
[Feature]: When testing the BERT model on the Nimbus framework, running it on the SPU may result in communication timeouts #1267

@qqqqzt

Description


Feature Request Type

Usability

Have you searched existing issues?

Yes

Is your feature request related to a problem?

Copyright 2023 Ant Group Co., Ltd.

Licensed under the Apache License, Version 2.0 (the "License");

you may not use this file except in compliance with the License.

You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software

distributed under the License is distributed on an "AS IS" BASIS,

WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

See the License for the specific language governing permissions and

limitations under the License.

Start nodes.

> bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/conf/2pc.json up

Run this example script.

> bazel run -c opt //examples/python/ml/flax_bert -- --config `pwd`/examples/python/conf/2pc.json

import argparse
import json
import os
import time
from contextlib import contextmanager

import flax.linen as fnn
import jax
import jax.nn as jnn
import jax.numpy as jnp
from jax import Array
from typing import Optional, Tuple, Union
from datasets import load_dataset
from transformers import AutoTokenizer, FlaxBertForSequenceClassification

import spu.intrinsic as intrinsic
import spu.spu_pb2 as spu_pb2
import spu.utils.distributed as ppd

copts = spu_pb2.CompilerOptions()

# enable x / broadcast(y) -> x * broadcast(1/y), which accelerates the softmax

copts.enable_optimize_denominator_with_broadcast = True
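# i.e. a division by a broadcast denominator, such as softmax's
# exp(x) / sum(exp(x)), becomes exp(x) * (1 / sum(exp(x))): the reciprocal is
# computed once on the reduced shape and then broadcast, rather than running
# an element-wise secret division.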

parser = argparse.ArgumentParser(description='distributed driver.')
parser.add_argument("-c", "--config", default="examples/python/ml/flax_bert/2pc.json")
args = parser.parse_args()

with open(args.config, 'r') as file:
    conf = json.load(file)

ppd.init(conf["nodes"], conf["devices"])

def _gelu(x: Array):
    order = 2
    left, right = -1.8, 0.5

    zero_mask = x < left
    right_mask = x > right
    # poly_mask = jnp.logical_not(zero_mask) & jnp.logical_not(right_mask)
    poly_mask = zero_mask ^ right_mask ^ True

    if order == 2:
        coeff = jnp.array([0.17025471961795552, 0.32421075145645123, -0.006946478426716862])

        # tmp1 = x * coeff[0] + coeff[1]
        # computed_v = tmp1 * x + coeff[2]
        x2 = jnp.square(x)
        computed_v = coeff[0] * x2 + coeff[1] * x + coeff[2]
    else:
        raise ValueError("error")
    x = right_mask * x + poly_mask * computed_v  # + zero_mask * 0

    return x
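
def _check_gelu_approx():
    # Hypothetical sanity check, not part of the original script: measure how
    # far the piecewise quadratic above is from JAX's reference GELU on a grid.
    # The fit targets [-1.8, 0.5]; expect the gap to grow near the boundaries.
    xs = jnp.linspace(-4.0, 4.0, 401)
    return jnp.max(jnp.abs(_gelu(xs) - jnn.gelu(xs)))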

def ours_fake_exp(x: Array) -> Array:
    coeff = jnp.array([0.0291786417536755, 0.26332572942924737, 0.8318065569430768, 0.9708981428645548])

    tmp1 = x * coeff[0] + coeff[1]
    tmp2 = x * tmp1 + coeff[2]
    computed_v = x * tmp2 + coeff[3]

    # x2 = jnp.square(x)
    # x3 = x2 * x
    # computed_v = coeff[0] * x3 + coeff[1] * x2 + coeff[2] * x + coeff[3]
    return computed_v
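
def _check_fake_exp():
    # Hypothetical sanity check, not part of the original script: the cubic is
    # only meant to track exp on (threshold, 0]; it diverges outside that window.
    xs = jnp.linspace(-3.9, 0.0, 401)
    return jnp.max(jnp.abs(ours_fake_exp(xs) - jnp.exp(xs)))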

def _softmax(
    x: Array,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
) -> Array:
    threshold = -3.9

    x = x - jnp.max(x, axis=-1, keepdims=True)
    # exp on large negative is clipped to zero
    compute_mask = x > threshold

    computed_v = ours_fake_exp(x)

    x = computed_v * compute_mask

    divisor = jnp.sum(x, axis, keepdims=True)
    return x / divisor
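
def _check_softmax():
    # Hypothetical sanity check, not part of the original script: rows of
    # _softmax sum to one by construction; compare against jax.nn.softmax to
    # see the error contributed by the polynomial exp and clipping.
    xs = jax.random.normal(jax.random.PRNGKey(0), (4, 16))
    return jnp.max(jnp.abs(_softmax(xs) - jnn.softmax(xs, axis=-1)))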

@contextmanager
def hijack(enabled=True):
    if not enabled:
        yield
        return
    # hijack some target functions
    jnn_gelu = jnn.gelu
    fnn_gelu = fnn.gelu
    jnn_sm = jnn.softmax
    fnn_sm = fnn.softmax

    jnn.gelu = _gelu
    fnn.gelu = _gelu
    jnn.softmax = _softmax
    fnn.softmax = _softmax

    yield
    # recover back
    jnn.gelu = jnn_gelu
    fnn.gelu = fnn_gelu
    jnn.softmax = jnn_sm
    fnn.softmax = fnn_sm

def run_on_cpu(model, input_ids, attention_masks, labels):
    print("Running on CPU ...")
    params = model.params

    def eval(params, input_ids, attention_masks):
        logits = model(input_ids, attention_masks, params=params)[0]
        return logits

    start = time.time()
    logits = eval(params, input_ids, attention_masks)
    end = time.time()
    print(f"CPU runtime: {(end - start)}s\noutput logits: {logits}")

def run_on_spu(model, input_ids, attention_masks, labels):
    print("Running on SPU ...")
    params = model.params

    def eval(params, input_ids, attention_masks):
        with hijack(enabled=True):
            logits = model(input_ids, attention_masks, params=params)[0]
        return logits

    spu_input_ids = ppd.device("P1")(lambda x: x)(input_ids)
    spu_attention_masks = ppd.device("P1")(lambda x: x)(attention_masks)
    spu_params = ppd.device("P2")(lambda x: x)(params)
    start = time.time()
    logits_spu = ppd.device("SPU")(eval, copts=copts)(
        spu_params, spu_input_ids, spu_attention_masks
    )
    end = time.time()
    logits_spu = ppd.get(logits_spu)
    print(f"SPU runtime: {(end - start)}s\noutput logits: {logits_spu}")

def main(tokenizer_func, model_func, checkpoint):
    dataset = load_dataset(
        "parquet",
        data_files={"test": "/home/xiaoqi/pycharm/spu-nimbus/datasets/cola/test-00000-of-00001.parquet"},
        split="test",
    )
    model = model_func.from_pretrained(checkpoint)
    tokenizer = tokenizer_func.from_pretrained(checkpoint)

    for dummy_input in dataset:
        features, labels = dummy_input["sentence"], dummy_input["label"]

        input_ids, attention_masks = (
            tokenizer(
                features,
                return_tensors="jax",
            )["input_ids"],
            tokenizer(
                features,
                return_tensors="jax",
            )["attention_mask"],
        )

        run_on_cpu(model, input_ids, attention_masks, labels)
        run_on_spu(model, input_ids, attention_masks, labels)
        break  # just test one sentence

if __name__ == "__main__":
    checkpoint = "bert-base-cased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = FlaxBertForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

    # tokenizer = BertTokenizerFast
    # model = FlaxBertForSequenceClassification
    # checkpoint = "/home/xiaoqi/pycharm/spu-nimbus/models--bert-base-cased/snapshots/cd5ef92a9fb2f889e972770a36d4ed042daf221e"
    main(tokenizer, model, checkpoint)

This is the test code; running it produces:

[2025-10-13 09:03:16.056] [info] [cheetah_dot.cc:394] empty@1x28996x768 => 1x1x1 Compute 268540.992 ms

[2025-10-13 09:03:18.425] [warning] [matmat_prot.cc:136] skiped 0, total 22268928, ratio 0
[2025-10-13 09:03:18.493] [info] [cheetah_dot.cc:481] 1@1x28996x768 => 1x1x1 Recv 5114.813 MiB, Response 0.243 MiB Pack 64.315 ms
[2025-10-13 09:03:18,497] [Process-2] Traceback (most recent call last):
File "/data1/bazel_cache/bazel/_bazel_xiaoqi/test0/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed.py", line 324, in Run
ret_objs = fn(self, *args, **kwargs)
File "/data1/bazel_cache/bazel/_bazel_xiaoqi/test0/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed.py", line 581, in builtin_spu_run
rt.run(spu_exec)
File "/data1/bazel_cache/bazel/_bazel_xiaoqi/test0/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/api.py", line 45, in run
return self._vm.Run(executable.SerializeToString())
RuntimeError: what:
[external/yacl/yacl/link/transport/channel.cc:351] Get data timeout, key=root:P2P-4:0->1
Stacktrace:
#0 yacl::link::Context::RecvInternal()+0x7ff740b0e2fb
#1 yacl::link::Context::Recv()+0x7ff740b108c6
#2 spu::mpc::cheetah::CheetahDot::Impl::doDotOLESenderRecvStep()+0x7ff73ff21c1e
#3 spu::mpc::cheetah::CheetahDot::Impl::doDotOLE()+0x7ff73ff28253
#4 spu::mpc::cheetah::CheetahDot::Impl::DotOLE()+0x7ff73ff28b21
#5 spu::mpc::cheetah::CheetahDot::DotOLE()+0x7ff73ff28c7a
#6 spu::mpc::cheetah::MatMulAA::proc()+0x7ff73ff01481
#7 spu::mpc::MatmulKernel::evaluate()+0x7ff740713040
#8 spu::dynDispatch<>()+0x7ff740762408
#9 spu::mpc::mmul_aa()+0x7ff740779e47
#10 spu::mpc::mmul_ss()+0x7ff74076b519
#11 spu::kernel::hal::_mmul_ss()+0x7ff7407549bd
#12 spu::kernel::hal::_mmul_impl()+0x7ff74073f0c2
#13 spu::kernel::hal::_mmul()+0x7ff740746537
#14 spu::kernel::hal::mixed_mmul()+0x7ff7406bb35e
#15 spu::kernel::hal::matmul()+0x7ff7406bf14d


Describe features you want to add to SPU

How can I avoid these communication timeouts on the SPU?

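
For context: the timeout fires in yacl's link layer while one party is still inside the roughly 268-second local computation for the large matmul, so the waiting party's Recv gives up before any data arrives. One workaround is to raise the link's receive timeout where the SPU link is created. Below is a minimal sketch, assuming the libspu.link.Desc binding used by spu/utils/distributed.py exposes a recv_timeout_ms field (the field name, its default, and the exact place to patch should all be verified against your SPU version):

import spu.libspu.link as link

# Hedged sketch: build the link with a receive timeout large enough to cover
# the longest local compute phase, so the idle peer's Recv does not time out.
# `addrs` and `my_rank` stand for the spu_internal_addrs list and this party's
# rank from 2pc.json; recv_timeout_ms is assumed to be in milliseconds.
desc = link.Desc()
for rank, addr in enumerate(addrs):
    desc.add_party(f"node:{rank}", addr)
desc.recv_timeout_ms = 3600 * 1000  # one hour instead of the default
lctx = link.create_brpc(desc, my_rank)

Alternatively, splitting the oversized matmul into smaller chunks would keep each communication round under the existing timeout.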
