
nvidia-smi -i 0 -lgc <min_clock>,<max_clock>
"""

import time

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import transformers
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import LlamaForCausalLM

# NOTE: assumed setup; the FLASH/CUDNN SDPA backends require a CUDA device and
# a half-precision dtype, and building the model under a half-precision default
# dtype keeps its parameters consistent with the inputs created below.
device = "cuda"
dtype = torch.bfloat16
torch.set_default_dtype(dtype)  # assumed: build the model directly in bf16

model_name = "meta-llama/Llama-3.2-1B"
config = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=True)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name, trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
model = LlamaForCausalLM(config).to(device).train()  # set norm layers to training mode
loss_fct = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
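# Note that constructing LlamaForCausalLM from the config yields randomly
# initialized weights; that is fine here, since only step time is measured.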
# Configuration matrix to test
batch_seqlen = [(24, 768), (12, 1024), (6, 2048), (3, 4096), (2, 8192), (1, 16384)]
backends = [
    SDPBackend.CUDNN_ATTENTION,
    SDPBackend.EFFICIENT_ATTENTION,
    SDPBackend.FLASH_ATTENTION,
]

# Run timing experiments
warmup_iterations = 5  # num of training iterations to run for warmup
measure_iterations = 100  # num of training iterations to run to measure for timing
data = []
for batch_size, seq_len in batch_seqlen:
    assert (
        seq_len < tokenizer.model_max_length
    ), "seqlen must be less than the model max length"
    # create random tensors
    # - input embedding tensor to simulate a batch of input token sequences converted into embeddings
    # - attention mask of all ones for full attention
    # - random target to compute cross entropy loss in training loop
    shape = (batch_size, seq_len, config.hidden_size)
    inputs_embeds = torch.randn(*shape, dtype=dtype, device=device)
    attention_mask = torch.ones(*shape[:2], dtype=torch.int64, device=device)
    target = torch.randint(
        2, config.vocab_size - 2, shape[:2], dtype=torch.int64, device=device
    )
    for backend in backends:
        backend_name = str(backend).split(".")[-1]
        print(
            f"Timing {backend_name} with batch_size={batch_size} and seq_len={seq_len}"
        )
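        # sdpa_kernel() restricts scaled_dot_product_attention to the listed
        # backends, so each timed run exercises exactly one SDPA implementation.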
        with sdpa_kernel(backends=[backend]):
            # warmup iterations: to minimize the effect of system cache
            for _ in range(warmup_iterations):
                output = model.forward(
                    inputs_embeds=inputs_embeds, attention_mask=attention_mask
                )
                loss = loss_fct(
                    output.logits.view(-1, config.vocab_size), target.view(-1)
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Assumed synchronization point: CUDA launches are asynchronous, so
            # pending warmup kernels must finish before the timer starts.
            torch.cuda.synchronize()

            start = time.time()
            # measure iterations: per-iteration time obtained by averaging
            for _ in range(measure_iterations):
                output = model.forward(
                    inputs_embeds=inputs_embeds, attention_mask=attention_mask
                )
                loss = loss_fct(
                    output.logits.view(-1, config.vocab_size), target.view(-1)
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            torch.cuda.synchronize()  # wait for all kernels to finish for accurate timing
            duration = time.time() - start
            data.append(
                (backend_name, batch_size, seq_len, duration / measure_iterations)
            )
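            # Each appended time is the mean wall-clock cost of one full training
            # step (forward, loss, backward, optimizer update) for this backend.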

# Process stats
df = pd.DataFrame(data, columns=["backend", "batch_size", "seq_len", "time"])
df["label"] = "BS=" + df["batch_size"].astype(str) + " SL=" + df["seq_len"].astype(str)
# compute the speedup w.r.t. EFFICIENT_ATTENTION
df["speedup_label"] = df["backend"] + " vs EFFICIENT_ATTENTION"
df["speedup"] = df.apply(
    lambda row: df.loc[
        (df["backend"] == "EFFICIENT_ATTENTION")
        & (df["batch_size"] == row["batch_size"])
        & (df["seq_len"] == row["seq_len"]),
        "time",
    ].values[0]
    / row["time"],
    axis=1,
)
df.to_csv("training_timing.csv", index=False)
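# EFFICIENT_ATTENTION rows compare against themselves, so their speedup is
# exactly 1.0; they are filtered out of the speedup plot below.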

# Create plots
label_order = [f"BS={b} SL={s}" for b, s in batch_seqlen]  # x-axis order
hue_order = ["CUDNN_ATTENTION", "FLASH_ATTENTION", "EFFICIENT_ATTENTION"]
g = sns.barplot(
    data=df,
    x="label",
    y="time",
    hue="backend",
    palette=["#76B900", "orchid", "royalblue"],
    order=label_order,
    hue_order=hue_order,
)
g.set_title("\nTraining Iteration Time")
g.set(
    xlabel="Batch size and sequence length",
    ylabel="Mean iteration time (s), lower is better",
)
g.get_legend().set_title("")
plt.legend(fontsize=8)
plt.xticks(rotation=10, size=8)
plt.tight_layout()
plt.savefig("iteration_time.png", dpi=300)
115 | 155 |
|
116 | 156 | plt.clf() |
117 | | -hue_order = ["CUDNN_ATTENTION vs EFFICIENT_ATTENTION", "FLASH_ATTENTION vs EFFICIENT_ATTENTION"] |
118 | | -g = sns.barplot(data=df[df["speedup_label"]!="EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION"], |
119 | | - x="label", y="speedup", hue="speedup_label", |
120 | | - palette=["#76B900", "orchid"], order=label_order, hue_order=hue_order) |
| 157 | +hue_order = [ |
| 158 | + "CUDNN_ATTENTION vs EFFICIENT_ATTENTION", |
| 159 | + "FLASH_ATTENTION vs EFFICIENT_ATTENTION", |
| 160 | +] |
| 161 | +g = sns.barplot( |
| 162 | + data=df[df["speedup_label"] != "EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION"], |
| 163 | + x="label", |
| 164 | + y="speedup", |
| 165 | + hue="speedup_label", |
| 166 | + palette=["#76B900", "orchid"], |
| 167 | + order=label_order, |
| 168 | + hue_order=hue_order, |
| 169 | +) |
for container in g.containers:
    g.bar_label(container, fmt="%.2f", fontsize=6)
g.set_title(
    "Per-iteration Speed-up of\ncuDNN/Flash Attention Backend vs Efficient Attention"
)
g.set(
    xlabel="Batch size and sequence length", ylabel="Speed-up ratio, higher is better"
)
g.get_legend().set_title("")
plt.legend(fontsize=8)
plt.xticks(rotation=10, size=8)
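plt.tight_layout()
plt.savefig("speedup.png", dpi=300)  # NOTE: assumed filename, mirroring the first figure's save call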