Skip to content

Commit ddd4db7

Browse files
authored
introduce torch_xla2.compile API, make sdxl to use it (#8269)
1 parent fa311ec commit ddd4db7

File tree

6 files changed

+121
-3
lines changed

6 files changed

+121
-3
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# How to run:
2+
3+
```
4+
python sdxl.py
5+
```
Loading
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import time
2+
import functools
3+
import jax
4+
import torch
5+
import torch_xla2
6+
from torch_xla2 import interop
7+
from torch_xla2.interop import JittableModule
8+
9+
from transformers.modeling_outputs import BaseModelOutputWithPooling
10+
11+
from jax.tree_util import register_pytree_node
12+
import jax
13+
14+
def base_model_output_with_pooling_flatten(v):
15+
return (v.last_hidden_state, v.pooler_output, v.hidden_states, v.attentions), None
16+
17+
def base_model_output_with_pooling_unflatten(aux_data, children):
18+
return BaseModelOutputWithPooling(*children)
19+
20+
register_pytree_node(
21+
BaseModelOutputWithPooling,
22+
base_model_output_with_pooling_flatten,
23+
base_model_output_with_pooling_unflatten
24+
)
25+
26+
27+
from diffusers import StableDiffusionPipeline
28+
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
29+
30+
prompt = "a photograph of an astronaut riding a horse"
31+
# image = pipe(prompt).images[0]
32+
33+
34+
env = torch_xla2.default_env()
35+
jax.config.update('jax_enable_x64', False)
36+
37+
def move_scheduler(scheduler):
38+
for k, v in scheduler.__dict__.items():
39+
if isinstance(v, torch.Tensor):
40+
setattr(scheduler, k, v.to('jax'))
41+
42+
43+
with env:
44+
pipe.to('jax:1')
45+
move_scheduler(pipe.scheduler)
46+
pipe.unet = torch_xla2.compile(
47+
pipe.unet, torch_xla2.CompileOptions(
48+
jax_jit_kwargs={'static_argnames': ('return_dict',)}
49+
)
50+
)
51+
import pdb; pdb.set_trace()
52+
pipe.text_encoder = torch_xla2.compile(pipe.text_encoder)
53+
54+
BS = 4
55+
prompt = [prompt] * BS
56+
pipe.vae = torch_xla2.compile(
57+
pipe.vae, torch_xla2.CompileOptions(
58+
jax_jit_kwargs={'static_argnames': ('return_dict',)},
59+
methods_to_compile=['decode'],
60+
)
61+
)
62+
image = pipe(prompt).images[0]
63+
64+
jax.profiler.start_trace('/tmp/sdxl')
65+
start = time.perf_counter()
66+
image = pipe(prompt, num_inference_steps=20).images[0]
67+
end = time.perf_counter()
68+
jax.profiler.stop_trace()
69+
print('Total time is ', end - start, 'bs = ', BS)
70+
image.save(f"astronaut_rides_horse.png")
71+
72+
73+
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
from diffusers import StableDiffusionPipeline
3+
4+
import torch_xla2
5+
env = torch_xla2.default_env()
6+
7+
# this is now contains torhc.Tensor
8+
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
9+
10+
with env:
11+
pipe.to('jax')
12+
prompt = "a photograph of an astronaut riding a horse"
13+
image = pipe(prompt, num_inference_steps=10).images[0]
14+
image.save(f"astronaut_rides_horse_orig.png")

experimental/torch_xla2/torch_xla2/__init__.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import List, Dict, Any, Optional
2+
import dataclasses
13
import jax
24
import os
35
import torch
@@ -91,4 +93,30 @@ def enable_accuracy_mode():
9193
def enable_performance_mode():
9294
jax.config.update('jax_enable_x64', False)
9395
jax.config.update('jax_default_matmul_precision', 'default')
94-
default_env().config.internal_respect_torch_return_dtypes = False
96+
default_env().config.internal_respect_torch_return_dtypes = False
97+
98+
99+
100+
@dataclasses.dataclass
101+
class CompileOptions:
102+
# only valid if compiling nn.Module
103+
methods_to_compile: List[str] = dataclasses.field(default_factory=lambda: ['forward'])
104+
jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
105+
mode: str = 'jax' # or dynamo or export
106+
107+
108+
def compile(fn, options: Optional[CompileOptions] = None):
109+
options = options or CompileOptions()
110+
if options.mode == 'jax':
111+
from torch_xla2 import interop
112+
if isinstance(fn, torch.nn.Module):
113+
module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs)
114+
for n in options.methods_to_compile:
115+
module.make_jitted(n)
116+
return module
117+
else:
118+
return interop.jax_jit(fn)
119+
elif options.mode == 'dynamo':
120+
raise RuntimeError('dynamo mode is not supported yet')
121+
elif options.mode == 'export':
122+
raise RuntimeError('export mode is not supported yet')

experimental/torch_xla2/torch_xla2/interop.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ def set_one(module, prefix):
4949

5050
class JittableModule(torch.nn.Module):
5151

52-
# TODO: add statedict loading hook
53-
5452
def __init__(self, m: torch.nn.Module, extra_jit_args={}):
5553
super().__init__()
5654
self.params, self.buffers = extract_all_buffers(m)

0 commit comments

Comments
 (0)