Coming from pytorch/xla#8388, I tried running the SD2 perf-optimized implementation and ran into the following error:
[worker:] + sudo groupadd docker
[worker:] groupadd: group 'docker' already exists
[worker:] + sudo usermod -aG docker kojoe
[worker:] + newgrp docker
[worker:] WARNING: Your config file at [/home/kojoe/.docker/config.json] contains these credential helper entries:
[worker:]
[worker:] {
[worker:] "credHelpers": {
[worker:] "us-central1-docker.pkg.dev": "gcloud"
[worker:] }
[worker:] }
[worker:] Adding credentials for: us-central1-docker.pkg.dev
[worker:] gcloud credential helpers already registered correctly.
[worker:] Error response from daemon: Cannot kill container: kojoe-test: No such container: kojoe-test
[worker:] v4: Pulling from deeplearning-images/reproducibility/pytorch-tpu-diffusers
[worker:] Digest: sha256:38a7ced3c82f8288fa6c8a33f8c8c7ef9dfa403e4a76bbcab5ea453c0eced862
[worker:] Status: Image is up to date for us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-diffusers:v4
[worker:] us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-diffusers:v4
[worker:] WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
[worker:] Repo card metadata block was not found. Setting CardData to empty.
[worker:] WARNING:huggingface_hub.repocard:Repo card metadata block was not found. Setting CardData to empty.
Generating train split: 100%|██████████| 1221/1221 [00:02<00:00, 447.38 examples/s]
[worker:] ***** Running training *****
[worker:] Instantaneous batch size per device = 16
[worker:] Total train batch size (w. parallel, distributed & accumulation) = 128
[worker:] Total optimization steps = 50
[worker:] Traceback (most recent call last):
[worker:] File "/workspace/diffusers/examples/text_to_image/train_text_to_image_xla.py", line 579, in <module>
[worker:] Bad StatusOr access: INTERNAL: Mosaic failed to compile TPU kernel: failed to legalize operation 'tpu.mask_cast'
[worker:]
[worker:] at location: loc("/dot_general"(callsite("body"("/usr/local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":1211:0) at callsite("run"("/usr/local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":1302:0) at callsite("_flash_attention_dq_kernel"("/usr/local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":1301:0) at "_flash_attention_bwd_dq"("/usr/local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":1484:0))))))
[worker:]
[worker:] The MLIR operation involved:
[worker:] %8069 = "tpu.mask_cast"(%8068) : (vector<8x128xi1>) -> vector<8x128x2xi1>
[worker:]
[worker:] Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
[worker:]
[worker:] main(args)
[worker:] File "/workspace/diffusers/examples/text_to_image/train_text_to_image_xla.py", line 561, in main
[worker:] trainer.start_training()
[worker:] File "/workspace/diffusers/examples/text_to_image/train_text_to_image_xla.py", line 86, in start_training
[worker:] xm.mark_step()
[worker:] File "/usr/local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1046, in mark_step
[worker:] torch_xla._XLAC._xla_step_marker(
[worker:] RuntimeError: torch_xla/csrc/xla_graph_executor.cpp:689 : Check failed: tensor_data
[worker:] *** Begin stack trace ***
[worker:] tsl::CurrentStackTrace[abi:cxx11]()
[worker:] torch_xla::XLAGraphExecutor::CollectSyncTensors(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&)
[worker:] torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
[worker:] torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, bool, bool, bool)
[worker:] torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRef<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, bool)
[worker:]
[worker:]
[worker:]
[worker:]
[worker:] _PyObject_MakeTpCall
[worker:] _PyEval_EvalFrameDefault
[worker:]
[worker:] _PyEval_EvalFrameDefault
[worker:]
[worker:] _PyEval_EvalFrameDefault
[worker:]
[worker:] _PyEval_EvalFrameDefault
[worker:]
[worker:] PyEval_EvalCode
[worker:]
[worker:]
[worker:]
[worker:] _PyRun_SimpleFileObject
[worker:] _PyRun_AnyFileObject
[worker:] Py_RunMain
[worker:] Py_BytesMain
[worker:] __libc_start_main
[worker:] _start
[worker:] *** End stack trace ***
[worker:]
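The traceback points at the Pallas TPU flash-attention backward pass (`_flash_attention_bwd_dq` in `jax/experimental/pallas/ops/tpu/flash_attention.py`), where Mosaic fails to legalize `tpu.mask_cast`. Below is a minimal, hypothetical JAX-only sketch (my assumption, not the actual training code path or a confirmed reproduction) that exercises that same backward kernel on a TPU; the shapes and dtype are illustrative, not the ones used by the SD2 run.

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

def loss_fn(q, k, v):
    # Forward pass through the Pallas TPU flash-attention kernel.
    out = flash_attention(q, k, v, causal=False)
    return jnp.sum(out)

# Illustrative shapes (assumption): [batch, num_heads, seq_len, head_dim].
batch, heads, seq, dim = 1, 8, 1024, 128
key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (batch, heads, seq, dim), dtype=jnp.float32)
k = jax.random.normal(key, (batch, heads, seq, dim), dtype=jnp.float32)
v = jax.random.normal(key, (batch, heads, seq, dim), dtype=jnp.float32)

# Differentiating w.r.t. q forces Mosaic to compile the dq backward kernel,
# which is where the 'tpu.mask_cast' legalization failure is reported above.
grads = jax.grad(loss_fn, argnums=0)(q, k, v)
print(grads.shape)

If this stand-alone sketch trips the same Mosaic error, that would suggest the failure is in the libtpu/JAX Pallas kernel shipped in the v4 image rather than in the diffusers training script itself.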