Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
c00a57c
implement attn softcap in vanilla attention
dvruette Jul 31, 2025
db3a35e
add attention logits to attention output
dvruette Jul 31, 2025
40f9ec3
change check_interval type to int or float
dvruette Jul 31, 2025
f2cd8e5
[fix] make sure sharding gets applied to every array in batch
dvruette Aug 2, 2025
f81e3e3
prettier memory summary printing
dvruette Aug 2, 2025
0ec9171
improve memory monitoring display
dvruette Aug 2, 2025
42c6764
graceful handling of failing to save state
dvruette Aug 4, 2025
a261472
Merge branch 'erfanzar:main' into main
dvruette Aug 5, 2025
f8bf6e4
keep track of dataloading time
dvruette Aug 6, 2025
c86da8e
implement metrics aggregation
dvruette Aug 12, 2025
3116af9
make wandb tags configurable
dvruette Aug 12, 2025
38d8bf7
fix metrics aggregation
dvruette Aug 12, 2025
678a6fb
annotate profiler steps
dvruette Aug 19, 2025
b8b5424
fix training step annotation
dvruette Aug 19, 2025
8efe3e4
Merge branch 'main' of github.com:erfanzar/EasyDeL
dvruette Aug 19, 2025
6546900
fix vanilla attention
dvruette Aug 19, 2025
17acf24
add profiling based on env variable
dvruette Aug 21, 2025
408e565
fix tpu_setup script for v6e
dvruette Aug 21, 2025
b4b1fbd
Merge branch 'erfanzar:main' into main
dvruette Aug 21, 2025
bb13fa0
revert to_state function
dvruette Aug 21, 2025
b3663af
try fix for to_state function
dvruette Aug 21, 2025
ecf61be
debugging to_state method
dvruette Aug 21, 2025
7655aa9
debugging to_state method
dvruette Aug 21, 2025
a8dc4e1
fix to_state method
dvruette Aug 21, 2025
390f020
add logs for starting/stopping profiler
dvruette Aug 21, 2025
69403ea
disable profiling options
dvruette Aug 21, 2025
ccb05fe
minor logging change
dvruette Aug 22, 2025
ce11645
refactor wandb init to allow for arbitrary kwargs
dvruette Aug 22, 2025
ddb9e7f
fix case where `entity` is provided as wandb kwarg
dvruette Aug 22, 2025
1f13fd3
fix bug when resuming wandb run
dvruette Aug 22, 2025
16b03fe
fix resuming training with step_start_point > 0
dvruette Aug 22, 2025
eea7602
fix typo on state replacement
dvruette Aug 22, 2025
68b54c4
don't save last checkpoint if we're still before `step_start_point`
dvruette Aug 22, 2025
2d92321
fix resuming from checkpoint
dvruette Aug 22, 2025
5e20ced
print traceback
dvruette Aug 22, 2025
c71f1e2
try fix for multi-host checkpointing
dvruette Aug 22, 2025
b690d9d
only save files on process 0
dvruette Aug 22, 2025
a9c52c0
don't write index file to /dev/null
dvruette Aug 22, 2025
61c1a85
don't write to /dev/null
dvruette Aug 22, 2025
ac7c8ef
only display progress bar on process 0
dvruette Aug 22, 2025
971f19b
fix typo in _save_sharded_to_local
dvruette Aug 23, 2025
b16cc32
fix shard gathering for saving checkpoints
dvruette Aug 23, 2025
af94431
fix loading optimizer state
dvruette Aug 23, 2025
9406945
update dockerfile entry to bash
dvruette Aug 26, 2025
90ab11f
update docker entrypoint
dvruette Aug 27, 2025
3146c0e
update entrypoint and specify cmd
dvruette Aug 27, 2025
8df364c
install rsync in docker image for ray
dvruette Aug 27, 2025
f305382
add docker images for ray
dvruette Aug 27, 2025
15247e9
use root to build image
dvruette Aug 27, 2025
a1a48b6
install deps into venv
dvruette Aug 27, 2025
78e0068
fix default python env
dvruette Aug 27, 2025
9bad3dc
extract building ray docker images into separate workflow
dvruette Aug 27, 2025
673629c
update name of workflow
dvruette Aug 27, 2025
55b3365
use job matrix for building
dvruette Aug 27, 2025
ea92e72
preinstall torch on gpu
dvruette Aug 27, 2025
451a9e9
preinstall torch everywhere & don't fail build on python version check
dvruette Aug 27, 2025
c3cfcc6
don't free disk space because it takes too long
dvruette Aug 27, 2025
e533c67
add gcloud cli to ray docker image
dvruette Aug 27, 2025
1eebb3c
fix gcloud installation
dvruette Aug 27, 2025
b9c2968
add `cryptography` python dependency
dvruette Aug 27, 2025
30ac0d0
add fixes and patches from marin
dvruette Aug 27, 2025
ebb9ce9
fix patch location
dvruette Aug 27, 2025
1e3285a
fix libtpu installation (hopefully)
dvruette Aug 27, 2025
a7ccc6e
Merge branch 'main' of github.com:erfanzar/EasyDeL
dvruette Sep 1, 2025
916582e
add step skipping for resuming from checkpoint
dvruette Sep 1, 2025
4edfcfd
install eformer from dvruette/eformer
dvruette Sep 2, 2025
f187f79
fix checkpointing
dvruette Sep 2, 2025
903b632
add ray docker image to build
dvruette Sep 2, 2025
fa10e6e
update ray image
dvruette Sep 3, 2025
3c067f4
update ray
dvruette Sep 3, 2025
8ef48ce
update lockfile
dvruette Sep 3, 2025
827a355
Merge branch 'main' of github.com:erfanzar/EasyDeL into merge-upstream
dvruette Sep 3, 2025
d194cd8
remove old fix that is no longer necessary
dvruette Sep 3, 2025
d740552
update lockfile
dvruette Sep 3, 2025
9ccf74d
add support for uploading profiler traces to gcs
dvruette Sep 6, 2025
c8ffb7a
add orbax saving
dvruette Sep 6, 2025
a0c04af
fix orbax checkpointing
dvruette Sep 6, 2025
b080e20
load from orbax checkpoint
dvruette Sep 6, 2025
acf93b2
checkpointing fixes
dvruette Sep 7, 2025
5708b7c
more checkpointing safety
dvruette Sep 7, 2025
2bb8f83
fix type hint
dvruette Sep 11, 2025
0904c88
various minor fixes and improvements to training loop/checkpointing
dvruette Sep 11, 2025
b441979
remove barrier after checkpointing
dvruette Sep 11, 2025
85472a5
update ray to latest version
dvruette Sep 11, 2025
73eebea
fix continuing from checkpoint, fix unnecessary comms
dvruette Sep 13, 2025
3840afc
minor changes
dvruette Oct 6, 2025
38b2380
fix easydel auto flags
dvruette Oct 6, 2025
69f5617
make training loop compatible with new data sampler
dvruette Oct 28, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions .github/workflows/docker-image_ray.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
name: Docker Image CI/CD
on:
push:
branches: [main]
pull_request:
branches: [main]
workflow_dispatch:

jobs:
build-and-push:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
target:
- {file: Dockerfile.ray, tag: ray-cpu, tagsuf: "-ray-cpu", args: "HARDWARE_TYPE=cpu RAY_IMAGE=rayproject/ray:2.47.1-py311"}
- {file: Dockerfile.ray, tag: ray-gpu, tagsuf: "-ray-gpu", args: "HARDWARE_TYPE=gpu RAY_IMAGE=rayproject/ray:2.47.1-py311-gpu"}
- {file: Dockerfile.ray, tag: ray-tpu, tagsuf: "-ray-tpu", args: "HARDWARE_TYPE=tpu RAY_IMAGE=rayproject/ray:2.47.1-py311"}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- run: pip install toml
- name: Extract version
id: meta
run: |
VERSION=$(grep '^version = ' pyproject.toml | sed 's/version = "\(.*\)"/\1/')
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "date=$(date +'%Y%m%d')" >> $GITHUB_OUTPUT
- name: Log in to GHCR
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Buildx
uses: docker/setup-buildx-action@v3
- name: Build and push
uses: docker/build-push-action@v6
with:
context: .
file: ./${{ matrix.target.file }}
push: ${{ github.event_name != 'pull_request' }}
tags: |
ghcr.io/dvruette/easydel:${{ steps.meta.outputs.version }}${{ matrix.target.tagsuf }}
ghcr.io/dvruette/easydel:latest${{ matrix.target.tagsuf }}
ghcr.io/dvruette/easydel:${{ steps.meta.outputs.date }}${{ matrix.target.tagsuf }}
build-args: |
VERSION=${{ steps.meta.outputs.version }}
${{ matrix.target.args }}
cache-from: type=registry,ref=ghcr.io/dvruette/easydel:buildcache
cache-to: type=registry,ref=ghcr.io/dvruette/easydel:buildcache,mode=max
- name: Prune build cache
if: always()
run: |
docker buildx prune -af || true
docker system prune -af || true
43 changes: 17 additions & 26 deletions .github/workflows/docker-image_tpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,38 @@ on:
branches: [main]
pull_request:
branches: [main]
workflow_dispatch: # Allow manual triggers
workflow_dispatch:

jobs:
build-and-push:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write

strategy:
fail-fast: false
matrix:
target:
- {file: Dockerfile, tag: cpu, tagsuf: "", args: "HARDWARE_TYPE=cpu"}
- {file: Dockerfile, tag: gpu, tagsuf: "-gpu", args: "HARDWARE_TYPE=gpu"}
- {file: Dockerfile, tag: tpu, tagsuf: "-tpu", args: "HARDWARE_TYPE=tpu"}
- {file: Dockerfile.ray, tag: ray-cpu, tagsuf: "-ray-cpu", args: "HARDWARE_TYPE=cpu RAY_IMAGE=rayproject/ray:2.36.0-py311"}
- {file: Dockerfile.ray, tag: ray-gpu, tagsuf: "-ray-gpu", args: "HARDWARE_TYPE=gpu RAY_IMAGE=rayproject/ray:2.36.0-py311-gpu"}
- {file: Dockerfile.ray, tag: ray-tpu, tagsuf: "-ray-tpu", args: "HARDWARE_TYPE=tpu RAY_IMAGE=rayproject/ray:2.36.0-py311"}
steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.11'

- name: Install toml package
run: pip install toml

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Cache Docker layers
uses: actions/cache@v4
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }}
restore-keys: |
${{ runner.os }}-buildx-

- name: Extract version from pyproject.toml
- run: pip install toml
- name: Extract version
id: meta
run: |
VERSION=$(grep '^version = ' pyproject.toml | sed 's/version = "\(.*\)"/\1/')
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "date=$(date +'%Y%m%d')" >> $GITHUB_OUTPUT

- name: Log in to GitHub Container Registry
if: github.event_name != 'pull_request' # Avoid pushing on PRs
- name: Log in to GHCR
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
registry: ghcr.io
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,4 @@ ENV PYTHONPATH=/app:/app/easydel:. \
USER easydel
WORKDIR /app

CMD ["bash"]
CMD ["bash"]
60 changes: 60 additions & 0 deletions Dockerfile.ray
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# syntax=docker/dockerfile:1
ARG RAY_IMAGE=rayproject/ray:2.49.1-py311
FROM ${RAY_IMAGE}

WORKDIR /home/ray

RUN sudo apt-get update && sudo apt-get install -y --no-install-recommends libgomp1 rsync ca-certificates curl git build-essential && sudo rm -rf /var/lib/apt/lists/*

RUN python3.11 -m pip install uv

COPY pyproject.toml uv.lock* ./

ARG HARDWARE_TYPE=cpu
RUN --mount=type=cache,target=/root/.cache/uv uv venv /home/ray/.venv && \
if [ "$HARDWARE_TYPE" = "gpu" ]; then uv sync --frozen --no-dev --no-install-project --extra gpu; \
elif [ "$HARDWARE_TYPE" = "tpu" ]; then uv sync --frozen --no-dev --no-install-project --extra tpu; \
else uv sync --frozen --no-dev --no-install-project; fi

ENV PATH="/home/ray/.venv/bin:$PATH" \
VIRTUAL_ENV="/home/ray/.venv" \
PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1

RUN sed -i \
-e '/^# >>> conda initialize >>>/,/^# <<< conda initialize <<</d' \
-e '/^export \+PATH=\$HOME\/anaconda3\/bin:\$PATH$/d' \
/home/ray/.bashrc && \
printf 'export VIRTUAL_ENV=/home/ray/.venv\nexport PATH="/home/ray/.venv/bin:$PATH"\n' >> /home/ray/.bashrc

COPY . .

RUN uv pip install -e . --no-deps
RUN uv pip install \
google-api-python-client \
google-auth-httplib2 \
google-auth-oauthlib \
cryptography

RUN curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-linux-x86_64.tar.gz && \
tar -xf google-cloud-cli-linux-x86_64.tar.gz && \
printf "N\nY\n\n" | script -qec ./google-cloud-sdk/install.sh && \
rm google-cloud-cli-linux-x86_64.tar.gz

ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1

COPY ./preemptible-fix.patch .
# apply the patch to fix preemptible GCP in Ray in the installed ray package
RUN patch -d /home/ray/.venv/lib/python3.11/site-packages/ -p2 < preemptible-fix.patch

ARG VERSION
ENV VERSION=$VERSION \
HARDWARE_TYPE=$HARDWARE_TYPE

LABEL org.opencontainers.image.version=$VERSION \
org.opencontainers.image.description="EasyDeL on Ray base image" \
org.opencontainers.image.source="https://github.com/dvruette/EasyDeL"

ENTRYPOINT []
CMD ["bash"]
34 changes: 17 additions & 17 deletions easydel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,23 @@
if _check_bool_flag("EASYDEL_AUTO", True):
_sys.setrecursionlimit(10000)

# Tell jax xla bridge to stay quiet and only yied warnings or errors.
_getlogger("jax._src.xla_bridge").setLevel(30)
_getlogger("jax._src.mesh_utils").setLevel(30)
_getlogger("jax._src.distributed").setLevel(30)
# these people talk too much
_getlogger("eray-executor").setLevel(30)
_getlogger("absl").setLevel(30)
_getlogger("datasets").setLevel(30)
# # Tell jax xla bridge to stay quiet and only yied warnings or errors.
# _getlogger("jax._src.xla_bridge").setLevel(30)
# _getlogger("jax._src.mesh_utils").setLevel(30)
# _getlogger("jax._src.distributed").setLevel(30)
# # these people talk too much
# _getlogger("eray-executor").setLevel(30)
# _getlogger("absl").setLevel(30)
# # _getlogger("datasets").setLevel(30)

_os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
_os.environ["KMP_AFFINITY"] = "noverbose"
_os.environ["GRPC_VERBOSITY"] = "3"
_os.environ["GLOG_minloglevel"] = "3"
_os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
_os.environ["CACHE_TRITON_KERNELS"] = "1"
_os.environ["TPU_MIN_LOG_LEVEL"] = "2"
_os.environ["TPU_STDERR_LOG_LEVEL"] = "2"
# _os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# _os.environ["KMP_AFFINITY"] = "noverbose"
# _os.environ["GRPC_VERBOSITY"] = "3"
# _os.environ["GLOG_minloglevel"] = "3"
# _os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
# _os.environ["CACHE_TRITON_KERNELS"] = "1"
# _os.environ["TPU_MIN_LOG_LEVEL"] = "2"
# _os.environ["TPU_STDERR_LOG_LEVEL"] = "2"
_os.environ["XLA_FLAGS"] = (
_os.getenv("XLA_FLAGS", "") + " "
"--xla_gpu_triton_gemm_any=true "
Expand All @@ -70,7 +70,7 @@
"--xla_gpu_force_compilation_parallelism=4 "
"--xla_gpu_enable_shared_constants=true "
"--xla_gpu_enable_triton_gemm=true "
"--xla_gpu_graph_level=3 "
# "--xla_gpu_graph_level=3 "
"--xla_gpu_enable_command_buffer= "
)
_os.environ["LIBTPU_INIT_ARGS"] = (
Expand Down
2 changes: 2 additions & 0 deletions easydel/infra/base_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,8 @@ def save_optimizer(self, save_directory: str | ePathLike, float_dtype: jnp.dtype
)
except Exception as e:
logger.error(f"Optimizer save failed: {e!s}")
import traceback
traceback.print_exc()
raise
else:
logger.info("Current State don't contain any Optimizer.")
Expand Down
20 changes: 20 additions & 0 deletions easydel/infra/modeling_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def __reduce__(self):
class AttentionLayerOutput(ModelOutput):
attention_output: chex.Array
attention_weight: chex.Array | None = None
attention_logits: chex.Array | None = None
cache_view: TransformerCacheView | None = None


Expand All @@ -199,6 +200,7 @@ class EncoderLayerOutput(ModelOutput):
hidden_states: chex.Array
residual_states: chex.Array | None = None
attention_weight: chex.Array | None = None
attention_logits: chex.Array | None = None


@auto_pytree
Expand All @@ -207,6 +209,7 @@ class DecoderLayerOutput(ModelOutput):
residual_states: chex.Array | None = None
cross_attention: chex.Array | None = None
attention_weight: chex.Array | None = None
attention_logits: chex.Array | None = None
router_logits: chex.Array | None = None
gate_loss: chex.Array | None = None
cache_view: TransformerCacheView | None = None
Expand Down Expand Up @@ -236,6 +239,7 @@ class BaseModelOutput(ModelOutput):
last_hidden_state: chex.Array = None
hidden_states: tuple[chex.Array] | None = None
attentions: tuple[chex.Array] | None = None
attention_logits: tuple[chex.Array] | None = None
past_key_values: dict[str, chex.Array] | None = None
loss: chex.Array | None = None

Expand Down Expand Up @@ -329,6 +333,7 @@ class BaseModelOutputWithPast(ModelOutput):
past_key_values: dict[str, chex.Array] | None = None
hidden_states: tuple[chex.Array] | None = None
attentions: tuple[chex.Array] | None = None
attention_logits: tuple[chex.Array] | None = None
loss: chex.Array | None = None


Expand Down Expand Up @@ -361,6 +366,7 @@ class BaseModelOutputWithPooling(ModelOutput):
pooler_output: chex.Array = None
hidden_states: tuple[chex.Array] | None = None
attentions: tuple[chex.Array] | None = None
attention_logits: tuple[chex.Array] | None = None
loss: chex.Array | None = None


Expand Down Expand Up @@ -410,7 +416,9 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
hidden_states: tuple[chex.Array] | None = None
past_key_values: TransformerCache | None = None
attentions: tuple[chex.Array] | None = None
attention_logits: tuple[chex.Array] | None = None
cross_attentions: tuple[chex.Array] | None = None
cross_attention_logits: tuple[chex.Array] | None = None
loss: chex.Array | None = None


Expand Down Expand Up @@ -457,7 +465,9 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
past_key_values: TransformerCache | None = None
hidden_states: tuple[chex.Array] | None = None
attentions: tuple[chex.Array] | None = None
attention_logits: tuple[chex.Array] | None = None
cross_attentions: tuple[chex.Array] | None = None
cross_attention_logits: tuple[chex.Array] | None = None
loss: chex.Array | None = None


Expand Down Expand Up @@ -516,10 +526,13 @@ class Seq2SeqModelOutput(ModelOutput):
past_key_values: TransformerCache | None = None
decoder_hidden_states: tuple[chex.Array] | None = None
decoder_attentions: tuple[chex.Array] | None = None
decoder_attention_logits: tuple[chex.Array] | None = None
cross_attentions: tuple[chex.Array] | None = None
cross_attention_logits: tuple[chex.Array] | None = None
encoder_last_hidden_state: chex.Array | None = None
encoder_hidden_states: tuple[chex.Array] | None = None
encoder_attentions: tuple[chex.Array] | None = None
encoder_attention_logits: tuple[chex.Array] | None = None
loss: chex.Array | None = None


Expand Down Expand Up @@ -561,7 +574,9 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
past_key_values: TransformerCache | None = None
hidden_states: tuple[chex.Array] | None = None
attentions: tuple[chex.Array] | None = None
attention_logits: tuple[chex.Array] | None = None
cross_attentions: tuple[chex.Array] | None = None
cross_attention_logits: tuple[chex.Array] | None = None
loss: chex.Array | None = None


Expand Down Expand Up @@ -590,6 +605,7 @@ class MaskedLMOutput(ModelOutput):
hidden_states: tuple[chex.Array] | None = None
last_hidden_state: chex.Array | None = None
attentions: tuple[chex.Array] | None = None
attention_logits: tuple[chex.Array] | None = None
past_key_values: TransformerCache | None = None
loss: chex.Array | None = None

Expand Down Expand Up @@ -648,10 +664,13 @@ class Seq2SeqLMOutput(ModelOutput):
past_key_values: TransformerCache | None = None
decoder_hidden_states: tuple[chex.Array] | None = None
decoder_attentions: tuple[chex.Array] | None = None
decoder_attention_logits: tuple[chex.Array] | None = None
cross_attentions: tuple[chex.Array] | None = None
cross_attention_logits: tuple[chex.Array] | None = None
encoder_last_hidden_state: chex.Array | None = None
encoder_hidden_states: tuple[chex.Array] | None = None
encoder_attentions: tuple[chex.Array] | None = None
encoder_attention_logits: tuple[chex.Array] | None = None
loss: chex.Array | None = None


Expand Down Expand Up @@ -946,6 +965,7 @@ class MoeModelOutput(ModelOutput):
hidden_states: tuple[chex.Array] | None = None
past_key_values: TransformerCache | None = None
attentions: tuple[chex.Array] | None = None
attention_logits: tuple[chex.Array] | None = None
router_logits: tuple[chex.Array] | None = None
all_router_losses: tuple[chex.Array] | None = None
logits: chex.Array = None
Expand Down
Loading