From ef6df24f1c6603ce40ed78cf0837c2bee5ba64e7 Mon Sep 17 00:00:00 2001 From: yf225 Date: Sun, 12 Jun 2022 12:26:27 -0700 Subject: [PATCH] Add pytorch dependencies to CI Dockerfile and enable test (#503) --- docker/unittest.Dockerfile | 12 ++++++++++++ docs/install.rst | 2 +- tests/{torch => }/test_torch_dict_input.py | 0 tests/{torch => }/test_torch_reshape.py | 0 tests/{torch => }/test_torch_simple.py | 0 tests/{torch => }/test_torch_zhen.py | 0 6 files changed, 13 insertions(+), 1 deletion(-) rename tests/{torch => }/test_torch_dict_input.py (100%) rename tests/{torch => }/test_torch_reshape.py (100%) rename tests/{torch => }/test_torch_simple.py (100%) rename tests/{torch => }/test_torch_zhen.py (100%) diff --git a/docker/unittest.Dockerfile b/docker/unittest.Dockerfile index f16b18398..49e19da5d 100644 --- a/docker/unittest.Dockerfile +++ b/docker/unittest.Dockerfile @@ -23,6 +23,18 @@ RUN source python3.9-env/bin/activate && pip install --upgrade pip \ tqdm scipy numba pulp tensorstore prospector yapf coverage cmake \ pybind11 ray[default] matplotlib +# Install PyTorch dependencies +RUN git clone https://github.com/pytorch/functorch /functorch +RUN source python3.7-env/bin/activate \ + && pip install torch torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + && pushd /functorch && python setup.py install && popd +RUN source python3.8-env/bin/activate \ + && pip install torch torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + && pushd /functorch && python setup.py install && popd +RUN source python3.9-env/bin/activate \ + && pip install torch torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + && pushd /functorch && python setup.py install && popd + # We determine the CUDA version at `docker build ...` phase ARG JAX_CUDA_VERSION=11.1 COPY scripts/install_cuda.sh /install_cuda.sh diff --git a/docs/install.rst b/docs/install.rst index 97e337b07..63e4e9b96 100644 --- a/docs/install.rst +++ b/docs/install.rst @@ -144,7 +144,7 @@ To enable Alpa for PyTorch, install the following dependencies: # Install nightly version of torch and torchdistx pip3 uninstall -y torch torchdistx - pip install torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu + pip install torch torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu # Build functorch from source git clone https://github.com/pytorch/functorch diff --git a/tests/torch/test_torch_dict_input.py b/tests/test_torch_dict_input.py similarity index 100% rename from tests/torch/test_torch_dict_input.py rename to tests/test_torch_dict_input.py diff --git a/tests/torch/test_torch_reshape.py b/tests/test_torch_reshape.py similarity index 100% rename from tests/torch/test_torch_reshape.py rename to tests/test_torch_reshape.py diff --git a/tests/torch/test_torch_simple.py b/tests/test_torch_simple.py similarity index 100% rename from tests/torch/test_torch_simple.py rename to tests/test_torch_simple.py diff --git a/tests/torch/test_torch_zhen.py b/tests/test_torch_zhen.py similarity index 100% rename from tests/torch/test_torch_zhen.py rename to tests/test_torch_zhen.py