Skip to content

Commit 9eec9e1

Browse files
authored
Burn lstm (#351)
* add rust burn lstm * pre-convert lstm weights to burn format * remove duplicate uv install * rename rust lstm
1 parent 97617af commit 9eec9e1

File tree

1 file changed

+26
-8
lines changed

1 file changed

+26
-8
lines changed

docker/Dockerfile

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ RUN dnf update -y && \
1414
vim libgfortran sqlite \
1515
bzip2 expat udunits2 zlib \
1616
mpich hdf5 netcdf netcdf-fortran netcdf-cxx netcdf-cxx4-mpich
17+
# Install UV and setup cargo bin for rust tools (uv and cargo)
18+
ENV PATH="/root/.cargo/bin:${PATH}"
19+
ENV UV_INSTALL_DIR=/root/.cargo/bin
20+
ENV UV_COMPILE_BYTECODE=1
21+
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
22+
RUN uv self update
1723

1824
FROM base AS build_base
1925
# no dnf update to keep devel packages consistent with versions installed in base
@@ -33,7 +39,7 @@ FROM build_base AS boost_build
3339
RUN wget https://archives.boost.io/release/1.79.0/source/boost_1_79_0.tar.gz
3440
RUN tar -xzf boost_1_79_0.tar.gz
3541
WORKDIR /boost_1_79_0
36-
RUN ./bootstrap.sh && ./b2 && ./b2 headers
42+
RUN ./bootstrap.sh && ./b2 headers
3743
ENV BOOST_ROOT=/boost_1_79_0
3844

3945

@@ -46,7 +52,7 @@ ENV FC=gfortran NETCDF=/usr/lib64/gfortran/modules/
4652
RUN ln -s /usr/bin/python3 /usr/bin/python
4753

4854
WORKDIR /ngen/
49-
RUN pip3 install uv && uv venv
55+
RUN uv venv
5056
ENV PATH="/ngen/.venv/bin:$PATH"
5157
## make sure clone isn't cached if repo is updated
5258
ADD https://api.github.com/repos/${TROUTE_REPO}/git/refs/heads/${TROUTE_BRANCH} /tmp/version.json
@@ -161,11 +167,25 @@ ENTRYPOINT ["./HelloNGEN.sh"]
161167

162168
FROM build_base AS lstm_weights
163169
RUN git clone --depth=1 --branch example_weights https://github.com/ciroh-ua/lstm.git /lstm_weights
170+
# add the rust weight conversion
171+
RUN uv run --with pyyaml --with numpy --with torch --extra-index-url https://download.pytorch.org/whl/cpu \
172+
https://raw.githubusercontent.com/CIROH-UA/rust-lstm-1025/refs/tags/v0.1.0/scripts/convert.py \
173+
all /lstm_weights/trained_neuralhydrology_models/
164174
# replace the relative path with the absolute path in the model config files
165175
RUN shopt -s globstar
166176
RUN sed -i 's|\.\.|/ngen/ngen/extern/lstm|g' /lstm_weights/trained_neuralhydrology_models/**/config.yml
167177

168178

179+
FROM build_base AS burn_lstm
180+
RUN dnf install -y clang
181+
WORKDIR /build
182+
RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
183+
RUN echo 'source $HOME/.cargo/env' >> $HOME/.bashrc
184+
RUN git clone --depth=1 https://github.com/aaraney/bmi-rs
185+
RUN git clone --depth=1 --branch v0.1.0 https://github.com/ciroh-ua/rust-lstm-1025
186+
WORKDIR /build/rust-lstm-1025
187+
RUN cargo build --release
188+
169189
FROM base AS final
170190

171191
WORKDIR /ngen
@@ -177,12 +197,7 @@ COPY --from=troute_build /ngen/t-route/src/troute-*/dist/*.whl /tmp/
177197

178198
RUN ln -s /dmod/bin/ngen /usr/local/bin/ngen
179199

180-
ENV UV_INSTALL_DIR=/root/.cargo/bin
181-
ENV UV_COMPILE_BYTECODE=1
182-
183-
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
184-
ENV PATH="/root/.cargo/bin:${PATH}"
185-
RUN uv self update && uv venv && \
200+
RUN uv venv && \
186201
uv pip install --no-cache-dir /tmp/*.whl netCDF4==1.6.3
187202
# Clean up some stuff, this doesn't make the image any smaller
188203
RUN rm -rf /tmp/*.whl
@@ -211,6 +226,9 @@ RUN uv pip install --no-cache-dir /ngen/ngen/extern/lstm --extra-index-url https
211226
RUN rm -rf /ngen/ngen/extern/lstm/trained_neuralhydrology_models
212227
COPY --from=lstm_weights /lstm_weights/trained_neuralhydrology_models /ngen/ngen/extern/lstm/trained_neuralhydrology_models
213228

229+
# Copy the rust version of the lstm over.
230+
COPY --from=burn_lstm /build/rust-lstm-1025/target/release/librust_lstm_1025.so /dmod/shared_libs/librust_lstm_1025.so
231+
214232
## add some metadata to the image
215233
COPY --from=troute_build /tmp/troute_url /ngen/troute_url
216234
COPY --from=ngen_build /tmp/ngen_url /ngen/ngen_url

0 commit comments

Comments
 (0)