|
12 | 12 | class LlamaTest(test_base.TestCase):
|
13 | 13 |
|
14 | 14 | def test_can_run(self):
|
15 |
| - sample_args = ( |
16 |
| - torch.randint(0, 32000, (1, 2048)), |
17 |
| - torch.arange(0, 2048), |
18 |
| - ) |
19 |
| - sample_args = pytree.tree_map(tensor.t2j, sample_args) |
| 15 | + with torch_xla2.default_env(): |
| 16 | + sample_args = ( |
| 17 | + torch.randint(0, 32000, (1, 2048), device='jax:0'), |
| 18 | + torch.arange(0, 2048, device='jax:0'), |
| 19 | + ) |
20 | 20 |
|
21 |
| - model_args = llama_model.ModelArgs( |
22 |
| - block_size=2048, |
23 |
| - vocab_size=32000, |
24 |
| - n_layer=2, |
25 |
| - n_head=4, |
26 |
| - dim=256, |
27 |
| - ) |
28 |
| - m = llama_model.Transformer(model_args) |
29 |
| - m.to(torch.bfloat16) |
30 |
| - m.setup_caches(1, 2048) |
| 21 | + model_args = llama_model.ModelArgs( |
| 22 | + block_size=2048, |
| 23 | + vocab_size=32000, |
| 24 | + n_layer=2, |
| 25 | + n_head=4, |
| 26 | + dim=256, |
| 27 | + ) |
| 28 | + m = llama_model.Transformer(model_args) |
| 29 | + m.to(torch.bfloat16) |
| 30 | + m.setup_caches(1, 2048) |
| 31 | + m = m.to('jax') |
| 32 | + |
| 33 | + print(m(*sample_args)) |
31 | 34 |
|
32 |
| - # NOTE: this API does NOT use torch export |
33 |
| - weights, jax_func = torch_xla2.extract_jax(m) |
34 |
| - print(jax_func(weights, sample_args)) |
35 | 35 |
|
36 | 36 | def test_can_run_exportable(self):
|
37 |
| - model_args = model_exportable.ModelArgs( |
38 |
| - vocab_size=32000, |
39 |
| - n_layers=2, |
40 |
| - n_heads=4, |
41 |
| - dim=256, |
42 |
| - ) |
43 |
| - m = model_exportable.Transformer(model_args) |
44 |
| - context_length = 2048 |
45 |
| - input_shape_prefill = (1, context_length) |
46 |
| - input_shape_decode = (1, 1) |
| 37 | + model_args = model_exportable.ModelArgs( |
| 38 | + vocab_size=32000, |
| 39 | + n_layers=2, |
| 40 | + n_heads=4, |
| 41 | + dim=256, |
| 42 | + ) |
| 43 | + m = model_exportable.Transformer(model_args) |
| 44 | + context_length = 2048 |
| 45 | + input_shape_prefill = (1, context_length) |
| 46 | + input_shape_decode = (1, 1) |
47 | 47 |
|
48 |
| - def make_cache(args, batch_size): |
49 |
| - n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads |
50 |
| - n_local_heads = args.n_heads |
51 |
| - n_local_kv_heads = n_kv_heads |
52 |
| - n_rep = n_local_heads // n_local_kv_heads |
53 |
| - head_dim = args.dim // args.n_heads |
54 |
| - res = [] |
55 |
| - for i in range(args.n_layers): |
56 |
| - if batch_size is None: |
57 |
| - size = ( |
58 |
| - args.max_seq_len, |
59 |
| - n_local_kv_heads, |
60 |
| - head_dim, |
61 |
| - ) |
62 |
| - else: |
63 |
| - size = ( |
64 |
| - batch_size, |
65 |
| - args.max_seq_len, |
66 |
| - n_local_kv_heads, |
67 |
| - head_dim, |
68 |
| - ) |
69 |
| - res.append( |
70 |
| - (torch.zeros( |
71 |
| - size, |
72 |
| - dtype=torch.bfloat16 if args.bf16_enable else torch.float), |
73 |
| - torch.zeros( |
74 |
| - size, |
75 |
| - dtype=torch.bfloat16 if args.bf16_enable else torch.float))) |
76 |
| - return res |
| 48 | + def make_cache(args, batch_size): |
| 49 | + n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads |
| 50 | + n_local_heads = args.n_heads |
| 51 | + n_local_kv_heads = n_kv_heads |
| 52 | + n_rep = n_local_heads // n_local_kv_heads |
| 53 | + head_dim = args.dim // args.n_heads |
| 54 | + res = [] |
| 55 | + for i in range(args.n_layers): |
| 56 | + if batch_size is None: |
| 57 | + size = ( |
| 58 | + args.max_seq_len, |
| 59 | + n_local_kv_heads, |
| 60 | + head_dim, |
| 61 | + ) |
| 62 | + else: |
| 63 | + size = ( |
| 64 | + batch_size, |
| 65 | + args.max_seq_len, |
| 66 | + n_local_kv_heads, |
| 67 | + head_dim, |
| 68 | + ) |
| 69 | + res.append( |
| 70 | + (torch.zeros( |
| 71 | + size, |
| 72 | + dtype=torch.bfloat16 if args.bf16_enable else torch.float), |
| 73 | + torch.zeros( |
| 74 | + size, |
| 75 | + dtype=torch.bfloat16 if args.bf16_enable else torch.float))) |
| 76 | + return res |
77 | 77 |
|
78 |
| - prefill_caches = make_cache(model_args, 1) |
| 78 | + prefill_caches = make_cache(model_args, 1) |
79 | 79 |
|
80 |
| - sample_input_prefill = ( |
81 |
| - torch.randint(0, 1000, input_shape_prefill, |
82 |
| - dtype=torch.int32), # len seq length |
83 |
| - torch.arange(0, context_length, dtype=torch.int32), # input indexes |
84 |
| - torch.arange(0, context_length, dtype=torch.int32), # context indexes |
85 |
| - prefill_caches, |
86 |
| - True, # prefil |
87 |
| - ) |
88 |
| - with torch.no_grad(): |
89 |
| - m_prefill = torch.export.export(m, sample_input_prefill) |
| 80 | + sample_input_prefill = ( |
| 81 | + torch.randint(0, 1000, input_shape_prefill, |
| 82 | + dtype=torch.int32), # len seq length |
| 83 | + torch.arange(0, context_length, dtype=torch.int32), # input indexes |
| 84 | + torch.arange(0, context_length, dtype=torch.int32), # context indexes |
| 85 | + prefill_caches, |
| 86 | + True, # prefil |
| 87 | + ) |
| 88 | + with torch.no_grad(): |
| 89 | + m_prefill = torch.export.export(m, sample_input_prefill) |
90 | 90 |
|
91 |
| - weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill) |
92 |
| - sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, |
93 |
| - sample_input_prefill) |
94 |
| - print('Prefill', mj_prefill(weights, sample_inputs)) |
| 91 | + weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill) |
| 92 | + sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, |
| 93 | + sample_input_prefill) |
| 94 | + print('Prefill', mj_prefill(weights, sample_inputs)) |
95 | 95 |
|
96 |
| - sample_input_decode = ( |
97 |
| - torch.randint(0, 1000, input_shape_decode, |
98 |
| - dtype=torch.int32), # len = 1 |
99 |
| - torch.tensor([0], dtype=torch.int32), |
100 |
| - torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0), |
101 |
| - prefill_caches, |
102 |
| - False # prefill |
103 |
| - ) |
104 |
| - with torch.no_grad(): |
105 |
| - m_decode = torch.export.export(m, sample_input_decode) |
106 |
| - weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode) |
107 |
| - sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, |
108 |
| - sample_input_decode) |
109 |
| - print('Decode', mj_decode(weights, sample_inputs)) |
| 96 | + sample_input_decode = ( |
| 97 | + torch.randint(0, 1000, input_shape_decode, |
| 98 | + dtype=torch.int32), # len = 1 |
| 99 | + torch.tensor([0], dtype=torch.int32), |
| 100 | + torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0), |
| 101 | + prefill_caches, |
| 102 | + False # prefill |
| 103 | + ) |
| 104 | + with torch.no_grad(): |
| 105 | + m_decode = torch.export.export(m, sample_input_decode) |
| 106 | + weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode) |
| 107 | + sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, |
| 108 | + sample_input_decode) |
| 109 | + print('Decode', mj_decode(weights, sample_inputs)) |
110 | 110 |
|
111 | 111 |
|
112 | 112 | if __name__ == "__main__":
|
|
0 commit comments