
WATonomous/troubleshoot-heterogenous-distributed-operations

 
 


Troubleshooting Collective Operations on a Heterogeneous Cluster with UCC and UCX

This repo is a proof-of-concept setup to run workloads across ROCm and CUDA. This is a fork of the original (excellent) work by Rafał Siwek (https://github.com/RafalSiwek/troubleshoot-heterogenous-distributed-operations) with some bug fixes and additional experiments.

Getting started

Start the containers:

docker compose up --build

Extract the SSH keys from the containers:

(docker compose exec rocm1 cat /root/.ssh/id_ed25519.pub && docker compose exec cuda1 cat /root/.ssh/id_ed25519.pub) > tmp/ssh-pub-keys.txt

Copy the public keys into the /root/.ssh/authorized_keys file of each worker container (tee overwrites any existing file):

cat tmp/ssh-pub-keys.txt | docker compose exec --no-TTY rocm1 tee /root/.ssh/authorized_keys
cat tmp/ssh-pub-keys.txt | docker compose exec --no-TTY cuda1 tee /root/.ssh/authorized_keys
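The tee commands above replace each container's authorized_keys with the combined key file. If a container already holds keys worth keeping, it is safer to append only the missing lines. Below is a minimal sketch of such a merge. It assumes one public key per line, and merge_authorized_keys is a hypothetical helper, not part of this repo:

```python
def merge_authorized_keys(existing: str, incoming: str) -> str:
    """Append keys from `incoming` that are not already in `existing`.

    Both arguments are the text of an authorized_keys-style file
    (assumed: one public key per line). Blank lines are dropped and
    the original order is kept.
    """
    merged: list[str] = []
    seen: set[str] = set()
    for line in existing.splitlines() + incoming.splitlines():
        key = line.strip()  # also removes a stray trailing "\r"
        if key and key not in seen:
            merged.append(key)
            seen.add(key)
    return "\n".join(merged) + "\n" if merged else ""
```

Stripping each line also drops any trailing carriage return, which may appear when the keys were captured from a TTY-attached docker compose exec. The sketch leaves out reading and writing the file itself (for example with pathlib.Path.read_text and write_text inside the container).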

Check that each machine can ssh into the other:

# should print "cuda1"
docker compose exec rocm1 ssh root@cuda1 hostname

# should print "rocm1"
docker compose exec cuda1 ssh root@rocm1 hostname
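With more than two containers, checking every ordered pair by hand gets tedious. The sketch below runs the same check for every pair from the host. The helper names are hypothetical. It assumes the docker compose CLI is on the host's PATH and that each container's hostname equals its service name, as in this setup. It also adds -o BatchMode=yes so that ssh fails instead of waiting at a password or host-key prompt:

```python
import subprocess
from itertools import permutations


def ssh_check_command(src: str, dst: str) -> list[str]:
    """Same check as above: from container `src`, ssh to `dst` and print its hostname."""
    return ["docker", "compose", "exec", "--no-TTY", src,
            # BatchMode=yes: fail rather than prompt for a password or host key
            "ssh", "-o", "BatchMode=yes", f"root@{dst}", "hostname"]


def failing_pairs(hosts: list[str]) -> list[tuple[str, str]]:
    """Try every ordered (src, dst) pair and return those where ssh fails
    or the printed hostname does not match the service name."""
    failures = []
    for src, dst in permutations(hosts, 2):
        result = subprocess.run(ssh_check_command(src, dst),
                                capture_output=True, text=True)
        if result.returncode != 0 or result.stdout.strip() != dst:
            failures.append((src, dst))
    return failures
```

For example, failing_pairs(["rocm1", "cuda1"]) returns an empty list when both directions work.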

Run tests

send_recv

Compile the test:

docker compose exec rocm1 bash

# in the container
cp -r tests/send_recv /tmp/
pushd /tmp/send_recv
hipcc test_send_recv_rocm.cpp -lmpi
echo $?
popd
exit

docker compose exec cuda1 bash

# in the container
cp -r tests/send_recv /tmp/
pushd /tmp/send_recv
nvcc test_send_recv_cuda.cpp -lmpi
echo $?
popd
exit

Run the test (in one of the containers):

docker compose exec rocm1 bash

mpirun --allow-run-as-root -np 2 -H rocm1,cuda1 \
-mca pml ucx -mca coll_ucc_enable 1 -mca coll_ucc_priority 100 \
-mca coll_ucc_verbose 3 -mca pml_ucx_verbose 3 \
/tmp/send_recv/a.out

# make sure the command exited successfully. Should print "0"
echo $?
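Every run in this README repeats the same mpirun flags and changes only the host list and the binary. The sketch below assembles that argument list so it can be passed to subprocess.run. The function name and the extra_mca parameter are hypothetical conveniences. The base flags are exactly the ones used above:

```python
from typing import Optional

# MCA settings shared by the mpirun invocations in this README.
BASE_MCA = [
    ("pml", "ucx"),
    ("coll_ucc_enable", "1"),
    ("coll_ucc_priority", "100"),
    ("coll_ucc_verbose", "3"),
    ("pml_ucx_verbose", "3"),
]


def build_mpirun_cmd(hosts: list[str], binary: str,
                     extra_mca: Optional[dict[str, str]] = None) -> list[str]:
    """Return an argv list like the mpirun commands above.

    Assumes one rank per host, so -np is simply the number of hosts.
    """
    cmd = ["mpirun", "--allow-run-as-root",
           "-np", str(len(hosts)), "-H", ",".join(hosts)]
    for key, value in BASE_MCA + list((extra_mca or {}).items()):
        cmd += ["-mca", key, value]
    cmd.append(binary)
    return cmd
```

For example, subprocess.run(build_mpirun_cmd(["rocm1", "cuda1"], "/tmp/send_recv/a.out"), check=True) raises subprocess.CalledProcessError on a non-zero exit status. That takes the place of the manual echo $? check.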

bidirectional_send_recv

Compile the test:

docker compose exec rocm1 bash -c 'rm -rf /tmp/bidirectional_send_recv && cp -r tests/bidirectional_send_recv /tmp/ && cd /tmp/bidirectional_send_recv && hipcc test_bidirectional_send_recv_rocm.cpp -lmpi && echo $?'
docker compose exec cuda1 bash -c 'rm -rf /tmp/bidirectional_send_recv && cp -r tests/bidirectional_send_recv /tmp/ && cd /tmp/bidirectional_send_recv && nvcc test_bidirectional_send_recv_cuda.cpp -lmpi && echo $?'

Run the test (CUDA as rank 0, ROCm as rank 1):

docker compose exec cuda1 bash

mpirun --allow-run-as-root -np 2 -H cuda1,rocm1 \
-mca pml ucx -mca coll_ucc_enable 1 -mca coll_ucc_priority 100 \
-mca coll_ucc_verbose 3 -mca pml_ucx_verbose 3 \
/tmp/bidirectional_send_recv/a.out

# make sure the command exited successfully. Should print "0"
echo $?

Run the test (ROCm as rank 0, CUDA as rank 1):

docker compose exec rocm1 bash

mpirun --allow-run-as-root -np 2 -H rocm1,cuda1 \
-mca pml ucx -mca coll_ucc_enable 1 -mca coll_ucc_priority 100 \
-mca coll_ucc_verbose 3 -mca pml_ucx_verbose 3 \
/tmp/bidirectional_send_recv/a.out

# make sure the command exited successfully. Should print "0"
echo $?

allreduce

Compile the test:

docker compose exec rocm1 bash -c 'rm -rf /tmp/allreduce && cp -r tests/allreduce /tmp/ && cd /tmp/allreduce && hipcc test_allreduce_rocm.cpp -lmpi && echo $?'
docker compose exec cuda1 bash -c 'rm -rf /tmp/allreduce && cp -r tests/allreduce /tmp/ && cd /tmp/allreduce && nvcc test_allreduce_cuda.cpp -lmpi && echo $?'

Run the test (in one of the containers):

docker compose exec rocm1 bash

UCX_TLS=tcp UCX_NET_DEVICES=eth0 UCX_WARN_UNUSED_ENV_VARS=n \
UCX_ROCM_COPY_D2H_THRESH=0 UCX_ROCM_COPY_H2D_THRESH=0 UCC_EC_ROCM_REDUCE_HOST_LIMIT=0 UCC_EC_ROCM_COPY_HOST_LIMIT=0 OMPI_MCA_mpi_accelerator_rocm_memcpyD2H_limit=0 OMPI_MCA_mpi_accelerator_rocm_memcpyH2D_limit=0 \
mpirun --allow-run-as-root -np 2 -H rocm1,cuda1 \
-mca pml ucx -mca coll_ucc_enable 1 -mca coll_ucc_priority 100 \
-mca coll_ucc_verbose 3 -mca pml_ucx_verbose 3 \
-mca pml_ucx_tls tcp -mca pml_ucx_devices eth0 \
-x UCC_TL_UCP_TUNE=inf -x UCX_LOG_LEVEL=DEBUG \
/tmp/allreduce/a.out

# make sure the command exited successfully. Should print "0"
echo $?

CUDA-only setup

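To test communication between CUDA nodes only, use docker-compose.cuda-only.yml. Exporting COMPOSE_FILE points the docker compose exec commands below at the same file (unset it to return to the heterogeneous setup):

export COMPOSE_FILE=docker-compose.cuda-only.yml
docker compose up -d --build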

(docker compose exec cuda1 cat /root/.ssh/id_ed25519.pub && docker compose exec cuda2 cat /root/.ssh/id_ed25519.pub) > tmp/ssh-pub-keys.txt

cat tmp/ssh-pub-keys.txt | docker compose exec --no-TTY cuda1 tee /root/.ssh/authorized_keys
cat tmp/ssh-pub-keys.txt | docker compose exec --no-TTY cuda2 tee /root/.ssh/authorized_keys

docker compose exec cuda1 ssh root@cuda2 hostname
docker compose exec cuda2 ssh root@cuda1 hostname

Run tests

send_recv

Compile the test:

docker compose exec cuda1 bash -c 'rm -rf /tmp/send_recv && cp -r tests/send_recv /tmp/ && cd /tmp/send_recv && nvcc test_send_recv_cuda.cpp -lmpi && echo $?'
docker compose exec cuda2 bash -c 'rm -rf /tmp/send_recv && cp -r tests/send_recv /tmp/ && cd /tmp/send_recv && nvcc test_send_recv_cuda.cpp -lmpi && echo $?'

Run the test (in one of the containers):

docker compose exec cuda1 bash

mpirun --allow-run-as-root -np 2 -H cuda1,cuda2 \
-mca pml ucx -mca coll_ucc_enable 1 -mca coll_ucc_priority 100 \
-mca coll_ucc_verbose 3 -mca pml_ucx_verbose 3 \
/tmp/send_recv/a.out

# make sure the command exited successfully. Should print "0"
echo $?

bidirectional_send_recv

Compile the test:

docker compose exec cuda1 bash -c 'rm -rf /tmp/bidirectional_send_recv && cp -r tests/bidirectional_send_recv /tmp/ && cd /tmp/bidirectional_send_recv && nvcc test_bidirectional_send_recv_cuda.cpp -lmpi && echo $?'
docker compose exec cuda2 bash -c 'rm -rf /tmp/bidirectional_send_recv && cp -r tests/bidirectional_send_recv /tmp/ && cd /tmp/bidirectional_send_recv && nvcc test_bidirectional_send_recv_cuda.cpp -lmpi && echo $?'

Run the test (in one of the containers):

docker compose exec cuda1 bash

mpirun --allow-run-as-root -np 2 -H cuda1,cuda2 \
-mca pml ucx -mca coll_ucc_enable 1 -mca coll_ucc_priority 100 \
-mca coll_ucc_verbose 3 -mca pml_ucx_verbose 3 \
/tmp/bidirectional_send_recv/a.out

# make sure the command exited successfully. Should print "0"
echo $?

allreduce

Compile the test:

docker compose exec cuda1 bash -c 'rm -rf /tmp/allreduce && cp -r tests/allreduce /tmp/ && cd /tmp/allreduce && nvcc test_allreduce_cuda.cpp -lmpi && echo $?'
docker compose exec cuda2 bash -c 'rm -rf /tmp/allreduce && cp -r tests/allreduce /tmp/ && cd /tmp/allreduce && nvcc test_allreduce_cuda.cpp -lmpi && echo $?'

Run the test (in one of the containers):

docker compose exec cuda1 bash

mpirun --allow-run-as-root -np 2 -H cuda1,cuda2 \
-mca pml ucx -mca coll_ucc_enable 1 -mca coll_ucc_priority 100 \
-mca coll_ucc_verbose 3 -mca pml_ucx_verbose 3 \
/tmp/allreduce/a.out

# make sure the command exited successfully. Should print "0"
echo $?

Custom UCC

CUDA

# Derived from install_ucc.sh

# Initial build
docker compose exec cuda1 bash -c "mkdir -p /tmp/ucc-overlay /tmp/ucc-overlay-workdir /tmp/ucc-overlay-upper && mount -t overlay overlay -o lowerdir=/opt/ucc,upperdir=/tmp/ucc-overlay-upper,workdir=/tmp/ucc-overlay-workdir /tmp/ucc-overlay"
docker compose exec cuda1 bash -c 'cd /tmp/ucc-overlay && ./autogen.sh && ./configure --prefix=/usr --with-ucx=/usr --with-cuda=/usr/local/cuda --with-nvcc-gencode="-gencode=arch=compute_75,code=sm_75" --with-tls=ucp --enable-debug && time make -j && make install'

docker compose exec cuda2 bash -c "mkdir -p /tmp/ucc-overlay /tmp/ucc-overlay-workdir /tmp/ucc-overlay-upper && mount -t overlay overlay -o lowerdir=/opt/ucc,upperdir=/tmp/ucc-overlay-upper,workdir=/tmp/ucc-overlay-workdir /tmp/ucc-overlay"
docker compose exec cuda2 bash -c 'cd /tmp/ucc-overlay && ./autogen.sh && ./configure --prefix=/usr --with-ucx=/usr --with-cuda=/usr/local/cuda --with-nvcc-gencode="-gencode=arch=compute_75,code=sm_75" --with-tls=ucp --enable-debug && time make -j && make install'

# Subsequent builds
docker compose exec cuda1 bash -c "cd /tmp/ucc-overlay && time make -j && make install"
docker compose exec cuda2 bash -c "cd /tmp/ucc-overlay && time make -j && make install"

# Verify with:
docker compose exec cuda1 bash -c "ucc_info -v"
docker compose exec cuda2 bash -c "ucc_info -v"

PyTorch

Install PyTorch:

docker compose exec rocm1 bash -c "python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4/"
docker compose exec cuda1 bash -c "python -m pip install torch torchvision torchaudio"

Run PyTorch tests:

docker compose exec rocm1 bash

mpirun --allow-run-as-root -np 2 -H cuda1,rocm1 \
-x MASTER_ADDR=rocm1 -x MASTER_PORT=1234 \
-mca pml ucx \
-mca coll_ucc_enable 1 -mca coll_ucc_priority 100 \
-mca coll_ucc_verbose 3 -mca pml_ucx_verbose 3 \
tests/pytorch/test_bidirectional_send_recv.py

# mpirun --allow-run-as-root -np 2 -H <cuda_ip>,<rocm_ip> -x MASTER_ADDR=<rocm_ip> -x MASTER_PORT=1234 -mca pml ucx -x UCX_ROCM_COPY_D2H_THRESH=0 -x UCX_ROCM_COPY_H2D_THRESH=0 -x OMPI_MCA_mpi_accelerator_rocm_memcpyD2H_limit=0 -x OMPI_MCA_mpi_accelerator_rocm_memcpyH2D_limit=0 /opt/conda/envs/py_3.12/bin/python /test_allreduce.py

MASTER_ADDR=rocm1 MASTER_PORT=1234 \
torchrun \
  --nproc_per_node=1 \
  --nnodes=2 \
  --node_rank=0 \
  --rdzv_id=123 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=rocm1:1234 \
  tests/pytorch/test_bidirectional_send_recv_torchrun.py

MASTER_ADDR=rocm1 MASTER_PORT=1234 \
torchrun \
  --nproc_per_node=1 \
  --nnodes=2 \
  --node_rank=1 \
  --rdzv_id=123 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=rocm1:1234 \
  tests/pytorch/test_bidirectional_send_recv_torchrun.py


# on rocm1 (rank 0 has to run on the MASTER_ADDR host)
MASTER_ADDR=rocm1 MASTER_PORT=12345 LOCAL_RANK=0 WORLD_SIZE=2 RANK=0 python tests/pytorch/test_bidirectional_send_recv_torchrun.py
# on cuda1
MASTER_ADDR=rocm1 MASTER_PORT=12345 LOCAL_RANK=0 WORLD_SIZE=2 RANK=1 python tests/pytorch/test_bidirectional_send_recv_torchrun.py
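The script sees the same environment whether torchrun starts it (torchrun exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT to each worker) or the variables are set by hand as above. The test script's contents are not reproduced here. The following is a minimal sketch of validating those variables before calling torch.distributed.init_process_group with init_method="env://", which reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment. The DistEnv class and the read_dist_env function are hypothetical:

```python
import os
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass(frozen=True)
class DistEnv:
    master_addr: str
    master_port: int
    rank: int
    local_rank: int
    world_size: int


def read_dist_env(env: Mapping[str, str] = os.environ) -> DistEnv:
    """Parse the rendezvous variables, failing early on obvious mistakes."""
    required = ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
    missing = [name for name in required if name not in env]
    if missing:
        raise KeyError(f"missing environment variables: {', '.join(missing)}")
    cfg = DistEnv(
        master_addr=env["MASTER_ADDR"],
        master_port=int(env["MASTER_PORT"]),
        rank=int(env["RANK"]),
        # LOCAL_RANK selects the GPU on this node; default 0 (one GPU per container).
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env["WORLD_SIZE"]),
    )
    if not 0 <= cfg.rank < cfg.world_size:
        raise ValueError(f"RANK={cfg.rank} must be in [0, {cfg.world_size})")
    return cfg
```

A clear error at startup is easier to diagnose than a rendezvous timeout. In the script, cfg.local_rank would then pick the device, for example with torch.cuda.set_device(cfg.local_rank). ROCm builds of PyTorch expose HIP devices through the same torch.cuda API.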

CUDA-only:

# on cuda1 (rank 0 has to run on the MASTER_ADDR host)
MASTER_ADDR=cuda1 MASTER_PORT=12345 LOCAL_RANK=0 WORLD_SIZE=2 RANK=0 python tests/pytorch/test_bidirectional_send_recv_torchrun.py
# on cuda2
MASTER_ADDR=cuda1 MASTER_PORT=12345 LOCAL_RANK=0 WORLD_SIZE=2 RANK=1 python tests/pytorch/test_bidirectional_send_recv_torchrun.py

MASTER_ADDR=cuda1 MASTER_PORT=1234 \
torchrun \
  --nproc_per_node=1 \
  --nnodes=2 \
  --node_rank=0 \
  --rdzv_id=123 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=cuda1:1234 \
  tests/pytorch/test_bidirectional_send_recv_torchrun.py

MASTER_ADDR=cuda1 MASTER_PORT=1234 \
torchrun \
  --nproc_per_node=1 \
  --nnodes=2 \
  --node_rank=1 \
  --rdzv_id=123 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=cuda1:1234 \
  tests/pytorch/test_bidirectional_send_recv_torchrun.py

Building PyTorch

PyTorch has to be built from source to enable MPI support, because the prebuilt wheels do not include the MPI backend. Instructions are available here.

git clone https://github.com/pytorch/pytorch --branch v2.7.0 --depth 1 --recursive
cd pytorch
conda install -y cmake ninja
pip install -r requirements.txt

ROCm:

python tools/amd_build/build_amd.py

All platforms:

export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}"
python setup.py develop

Old README

This repository provides a proof-of-concept setup to run HPC workloads on AWS public cloud instances with different GPU accelerators. It also documents the issues encountered during deployment.

Infrastructure Setup

This workload is configured for two types of g4 instances, each with distinct GPU types:

  • g4ad.xlarge: One AMD Radeon Pro V520 GPU (gfx1011) using ROCm 6.2.2.
  • g4dn.xlarge: One NVIDIA T4 GPU using CUDA 12.4.

These instances are running Ubuntu 22.04 (ubuntu-eks/k8s_1.30/images/hvm-ssd/ubuntu-jammy-22.04-amd64-server-* AMI images). Since g4ad instances don’t support EFA, networking was configured for standard performance with bandwidth up to 10 Gbps and security groups allowing traffic on all ports.

Collective Operations Environment

Collective operations run as distributed PyTorch jobs in OCI containers.

Separate Docker images were built for each instance type and published to Docker Hub:

Additional tests were conducted using PyTorch distributed backends for the torchrun setup, but with no successful outcomes:

Following guidance from the UCX team (response here), we switched to using MPI with UCC and UCX.

Running Collective Operations

To test collective operations, start three containers using the Docker images with host network configurations (pre-installed PyTorch images are available):

  • g4ad:

    docker run -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
    --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host \
    --shm-size 8G --user root --rm --network host \
    rafalsiwek/g4ad_distributed_ml:1.0_pytorch bash
  • g4dn:

    docker run --gpus all -it --rm --user root --cap-add=SYS_PTRACE \
    --security-opt seccomp=unconfined --ipc=host --shm-size 8G \
    --network host --group-add video \
    rafalsiwek/g4dn_distributed_ml:1.0_pytorch bash
  • t3:

    docker run -it --rm --network host rafalsiwek/opmpi_ucx_simple:latest

The MPI worker containers need passwordless SSH between them. To set it up, follow these steps:

  1. In the "master container," generate an SSH key:

    ssh-keygen -t rsa
  2. Copy the public key to each "worker container" by pasting it into the ~/.ssh/authorized_keys file.

  3. Update the SSH daemon (sshd) port in each worker container to a port not used by the host:

    vi /etc/ssh/sshd_config
  4. Change the SSH port in the "master container":

    vi /etc/ssh/ssh_config
  5. Start the SSH server in each worker container:

    /usr/sbin/sshd -D
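Steps 3 and 4 both come down to setting a Port directive: in /etc/ssh/sshd_config on the workers and in /etc/ssh/ssh_config on the master. The sketch below edits the file text instead of using vi. set_ssh_port is a hypothetical helper. It replaces only the first Port line (commented or not) and appends one if none exists. The port number in the usage note is only an example:

```python
import re

# Matches "Port 22", "#Port 22" or "#   Port 22" on a single line.
# [ \t] instead of \s keeps the match from spilling across newlines.
_PORT_LINE = re.compile(r"^[ \t]*#?[ \t]*Port[ \t]+\d+[ \t]*$",
                        re.IGNORECASE | re.MULTILINE)


def set_ssh_port(config_text: str, port: int) -> str:
    """Return config_text with its first (possibly commented) Port line set to `port`.

    If no Port line exists, one is appended. Assumes the file does not rely
    on several active Port lines (sshd would listen on each of them).
    """
    new_line = f"Port {port}"
    if _PORT_LINE.search(config_text):
        return _PORT_LINE.sub(new_line, config_text, count=1)
    suffix = "" if config_text.endswith("\n") or not config_text else "\n"
    return f"{config_text}{suffix}{new_line}\n"
```

For example: path = Path("/etc/ssh/sshd_config"), then path.write_text(set_ssh_port(path.read_text(), 2222)), then start sshd as in step 5.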

Tests Run

The tests directory contains scripts and log outputs for the tests conducted.

The send_recv test sends and receives a GPU memory buffer between the nodes. The CUDA node filled its buffer with cudaMemcpy, while the ROCm node used hipMemcpy.

After compiling with nvcc on the CUDA node and hipcc on the ROCm node, the MPI job was triggered with:

mpirun --allow-run-as-root -np 2 -H <rocm_ip>,<cuda_ip> \
-mca pml ucx -mca coll_ucc_enable 1 -mca coll_ucc_priority 100 \
-mca coll_ucc_verbose 3 -mca pml_ucx_verbose 3 \
/test_send_recv

This job completed successfully (log output). The same job was also run with -x UCC_LOG_LEVEL=DEBUG (log output).

The bidirectional_send_recv test runs a simple bidirectional send and receive in which rank 0 (CUDA) both sends data to and receives data from rank 1 (ROCm).

The job was triggered with:

mpirun --allow-run-as-root -np 1 -host <cuda_ip> \
-mca pml ucx -mca coll_ucc_enable 1 -mca coll_ucc_priority 100 \
-x UCC_LOG_LEVEL=DEBUG -x UCC_COLL_TRACE=DEBUG \
/test_bidirectional_send_recv : \
-np 1 -host <rocm_ip> \
-mca pml ucx -mca coll_ucc_enable 1 -mca coll_ucc_priority 100 \
-x UCC_LOG_LEVEL=DEBUG -x UCC_COLL_TRACE=DEBUG \
/test_bidirectional_send_recv

An error occurred on the ROCm node/rank with the ucp_mem_type_unpack method (log output with backtrace).

The allreduce test runs an allreduce operation in a heterogeneous ring. As in the send_recv test, the CUDA node filled its buffer with cudaMemcpy, while the ROCm node used hipMemcpy.
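For reference, the textbook ring allreduce works in two passes. Each rank's buffer is split into one chunk per rank. A reduce-scatter pass follows, after which every rank owns the complete sum of one chunk. An allgather pass then circulates the finished chunks. In total, each rank sends 2(n - 1) messages of roughly 1/n of the buffer. This README does not say which algorithm UCC actually selects for this test. The following is therefore only a single-process simulation of the textbook algorithm (sum reduction, hypothetical function name), meant to make the communication steps concrete:

```python
def ring_allreduce(buffers: list[list[float]]) -> list[list[float]]:
    """Simulate a sum ring allreduce; buffers[r] is rank r's input.

    Returns one output buffer per rank (each equal to the element-wise sum).
    The inputs are not modified.
    """
    if not buffers:
        return []
    n = len(buffers)
    length = len(buffers[0])
    if any(len(b) != length for b in buffers):
        raise ValueError("all ranks need equal-sized buffers")
    data = [list(b) for b in buffers]  # per-rank working copies
    # Split [0, length) into n contiguous chunks (sizes differ by at most one).
    bounds = [(i * length // n, (i + 1) * length // n) for i in range(n)]

    def outgoing(chunk_of):
        # Snapshot every rank's outgoing chunk first: in a real ring all ranks
        # send to their right neighbour concurrently within a step.
        result = []
        for r in range(n):
            lo, hi = bounds[chunk_of(r) % n]
            result.append(((r + 1) % n, lo, data[r][lo:hi]))
        return result

    # Phase 1, reduce-scatter: in step s rank r sends chunk (r - s) and the
    # receiver adds it. After n - 1 steps rank r owns the full sum of chunk r + 1.
    for s in range(n - 1):
        for dst, lo, payload in outgoing(lambda r: r - s):
            for k, v in enumerate(payload):
                data[dst][lo + k] += v

    # Phase 2, allgather: in step s rank r forwards chunk (r + 1 - s), which is
    # already fully reduced; the receiver overwrites its copy.
    for s in range(n - 1):
        for dst, lo, payload in outgoing(lambda r: r + 1 - s):
            data[dst][lo:lo + len(payload)] = payload

    return data
```

With two ranks, as in this setup, each pass is a single exchange of half the buffer.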

The job was run with:

mpirun --allow-run-as-root -np 2 -H <rocm_ip>,<cuda_ip> \
-mca pml ucx -mca coll_ucc_enable 1 -mca coll_ucc_priority 100 \
-mca coll_ucc_verbose 3 -mca pml_ucx_verbose 3 \
/test_allreduce

An error occurred on the ROCm node/rank with the ucp_mem_type_unpack method (log output with backtrace).

Suspecting that UCX was using different transport layers (TLs), the job was re-run with -mca pml_ucx_tls tcp to force TCP and -mca pml_ucx_devices ens4 to select the ens4 network device. However, the failure persisted (log output).

Following recommendations in UCC Issue #1039, the job was run with -x UCC_TL_UCP_TUNE=inf to adjust the UCP transport layer for UCC. Although the ROCm node/rank failed, the CUDA node/rank completed the allreduce job (log output).

Still suspecting integration issues between UCC and UCX, I enabled UCX debug logging with -x UCX_LOG_LEVEL=DEBUG (log output) and also re-ran the send_recv example with UCX debug logs (log output). However, the logs did not point to a clear cause.

Additional tests were run for the allreduce collective on a homogeneous setup with CUDA-only and ROCm-only environments. The CUDA-only ring tests were successful (UCC debug and trace logs here). However, the ROCm-only ring encountered errors with the uct_rocm_copy_ep_put_short function in the ucx/libuct_rocm library, which is consistent with the errors seen in previous tests.

To gather more details, I ran additional jobs to capture UCX debug logs (UCX logs with -x UCX_LOG_LEVEL=DEBUG here and logs with -mca pml_ucx_verbose 100 here).

To verify that the issue is specific to ROCm-only ring communication, I also tested simple send_recv operations. These showed the same error (logs with -x UCX_LOG_LEVEL=DEBUG here and logs with -mca pml_ucx_verbose 100 here). In contrast, the send_recv test on the CUDA-only ring completed without issues.

As it turned out, the uct_rocm_copy_ep_put_short failures on the ROCm ranks were caused by the AWS EC2 g4ad machines not supporting the Large BAR setting (https://github.com/openucx/ucx/wiki/Build-and-run-ROCM-UCX-OpenMPI#sanity-check-for-large-bar-setting). The issue can be circumvented with the following variables (for more context look here):

UCX_ROCM_COPY_D2H_THRESH=0
UCX_ROCM_COPY_H2D_THRESH=0
UCC_EC_ROCM_REDUCE_HOST_LIMIT=0
UCC_EC_ROCM_COPY_HOST_LIMIT=0
OMPI_MCA_mpi_accelerator_rocm_memcpyD2H_limit=0
OMPI_MCA_mpi_accelerator_rocm_memcpyH2D_limit=0
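Open MPI's mpirun can set such variables with -x NAME=VALUE and export them to every launched rank, including remote ones. The sketch below turns the six variables above into those flags. LARGE_BAR_WORKAROUND and env_to_mpirun_flags are hypothetical names:

```python
# The Large BAR workaround variables listed above.
LARGE_BAR_WORKAROUND = {
    "UCX_ROCM_COPY_D2H_THRESH": "0",
    "UCX_ROCM_COPY_H2D_THRESH": "0",
    "UCC_EC_ROCM_REDUCE_HOST_LIMIT": "0",
    "UCC_EC_ROCM_COPY_HOST_LIMIT": "0",
    "OMPI_MCA_mpi_accelerator_rocm_memcpyD2H_limit": "0",
    "OMPI_MCA_mpi_accelerator_rocm_memcpyH2D_limit": "0",
}


def env_to_mpirun_flags(env: dict[str, str]) -> list[str]:
    """Return ["-x", "NAME=VALUE", ...] for each entry, in insertion order."""
    flags: list[str] = []
    for name, value in env.items():
        flags += ["-x", f"{name}={value}"]
    return flags
```

The resulting list goes before the binary in the mpirun argument list. The two OMPI_MCA_* variables could equivalently be given as -mca mpi_accelerator_rocm_memcpyD2H_limit 0 and -mca mpi_accelerator_rocm_memcpyH2D_limit 0.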

To test ML workflows, I built PyTorch (v2.5.1) from source with MPI support following this guide. Experiments were run using mpirun with UCX PML and UCC collective configurations.

Running a bi-directional send/recv test was successful, confirming basic communication across GPUs (logs available here).

However, testing collective operations with UCC, such as the allreduce test, and also with the official PyTorch distributed examples, led to a failure on the ucp_tag_send_nbx operation for both ranks (logs and backtrace available here).

Running the allreduce test with UCC collectives disabled for MPI yielded partial success, where the operation completed successfully on the CUDA rank but failed on the ROCm rank (logs and backtrace here).

The failure appeared to be related to the UCX configuration, as the worker failed at the ucs_async_check_owner_thread(&(ep->worker)->async) assertion. UCX had been configured with multi-threading support (the --enable-mt flag), which this setup might not fully support. After rebuilding UCX with multi-threading disabled, everything appears to work correctly.
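Before rebuilding, you can check whether an installed UCX was built with --enable-mt by inspecting the build configuration that ucx_info -b prints. The sketch below assumes that output contains a C-style "#define ENABLE_MT 1" line when multi-threading support is compiled in. Verify this against your UCX version's output. Both function names are hypothetical:

```python
import re
import subprocess

# Assumed format of the relevant line in `ucx_info -b` output.
_ENABLE_MT = re.compile(r"^#define[ \t]+ENABLE_MT[ \t]+(\d+)[ \t]*$", re.MULTILINE)


def ucx_mt_enabled(build_info: str) -> bool:
    """Return True if the build info defines ENABLE_MT with a non-zero value."""
    match = _ENABLE_MT.search(build_info)
    return bool(match) and int(match.group(1)) != 0


def installed_ucx_mt_enabled() -> bool:
    """Run `ucx_info -b` (requires UCX on PATH) and parse its output."""
    out = subprocess.run(["ucx_info", "-b"], capture_output=True,
                         text=True, check=True).stdout
    return ucx_mt_enabled(out)
```

If it returns True, UCX was built with multi-threading support and, per the finding above, would need to be rebuilt without --enable-mt.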

To verify the UCX portion of the software stack, OSU Micro-Benchmarks 7.4 were run. Separate builds of the benchmark were created for each type of rank to match its GPU hardware: one with CUDA for NVIDIA and one with ROCm for AMD. The benchmarks were run using the GPU devices, with the following results:

  • P2P Bi-Directional Bandwidth Benchmark (results and logs here) - This test completed successfully for all data types available, validating both the benchmark and the setup. The results reflect the bandwidth limitations of the setup.

  • Collective AllGather Latency Benchmark (results and logs here) - The first collective test ran successfully across all available data types, showing expected operation latencies without any errors.

  • Collective AllReduce Latency Benchmark - This is the primary and most extensive collective test. Running the benchmark with validation disabled was successful (results and logs here). However, enabling validation led to failures across all tested data types, with logs showing significant discrepancies between expected and actual values (results and logs here). It is worth noting that, since version 7.2, the OSU benchmark suite updated its validation methods for reduction tests (see changelog). When re-running the benchmark with version 7.2, validation passed without issues (results and logs here).
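The instances were provisioned with network bandwidth of up to 10 Gbps, which gives a reference point for the bandwidth limitations mentioned above. That link rate caps what the P2P benchmark can report, regardless of the GPU stack. The sketch below converts a link rate into decimal MB/s (10^6 bytes per second), the unit the OSU bandwidth tests report in. It assumes the quoted rate applies in each direction and ignores protocol overhead and burst-versus-sustained limits, so measured numbers should land below it. The function name is hypothetical:

```python
def link_ceiling_mb_per_s(link_gbps: float, bidirectional: bool = False) -> float:
    """Upper bound on payload bandwidth in MB/s (10**6 bytes/s) for a link in Gbit/s.

    A full-duplex link carries `link_gbps` in each direction, so a
    bi-directional test can at most double the one-way figure.
    """
    one_way = link_gbps * 1e9 / 8 / 1e6  # bits/s -> bytes/s -> MB/s
    return 2 * one_way if bidirectional else one_way
```

For a 10 Gbps link this gives 1250 MB/s one way and 2500 MB/s for a bi-directional test.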

After these tests, it appears that this heterogeneous setup can handle collective communication with some inherent bandwidth and latency limitations. This raises the question of why collective operations in PyTorch fail on the ucp_tag_send_nbx operation, while basic bi-directional send_recv operations complete without issue. Could this be due to the way PyTorch implements collective operations? It seems possible, even though PyTorch uses similar functions to those in my allreduce test.

Solution

This exercise led to two main solutions to the problem:

  • To enable collective operations with a machine that does not support Large BAR (in this case the AMD ROCm rank's node), the following variables have to be passed with mpirun:
    UCX_ROCM_COPY_D2H_THRESH=0
    UCX_ROCM_COPY_H2D_THRESH=0
    UCC_EC_ROCM_REDUCE_HOST_LIMIT=0
    UCC_EC_ROCM_COPY_HOST_LIMIT=0
    OMPI_MCA_mpi_accelerator_rocm_memcpyD2H_limit=0
    OMPI_MCA_mpi_accelerator_rocm_memcpyH2D_limit=0
  • To keep PyTorch jobs from failing on the UCX multi-threading assertion described above, UCX has to be rebuilt without the --enable-mt flag.

Thanks for help to:

About

Testing distributed training on heterogeneous (CUDA+ROCm) clusters
