-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_communication_primitives.py
More file actions
140 lines (108 loc) Β· 5.25 KB
/
test_communication_primitives.py
File metadata and controls
140 lines (108 loc) Β· 5.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#!/usr/bin/env python3
"""
Test communication primitives with the conjugate rule for true tensor parallelism.
"""
import numpy as np
import torch
from src.tensor_parallel_keras.communications_keras import (
TensorParallelCommunicator,
AllGatherKeras,
AllReduceKeras
)
def test_communication_primitives():
"""Test that communication primitives work correctly."""
print("π§ͺ Testing Communication Primitives with Conjugate Rule")
print("=" * 60)
# Test 1: AllGather operation
print("\nπ Test 1: AllGather Operation")
print("-" * 30)
world_size = 2
allgather = AllGatherKeras(world_size, dim=-1)
# Create test tensors (simulating column-parallel outputs)
tensor1 = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) # Shard 0
tensor2 = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32) # Shard 1
tensors = [tensor1, tensor2]
print(f" Input tensors:")
print(f" - Shard 0: {tensor1.shape} = {tensor1.tolist()}")
print(f" - Shard 1: {tensor2.shape} = {tensor2.tolist()}")
# AllGather along last dimension
result = allgather(tensors)
print(f" AllGather result: {result.shape} = {result.tolist()}")
# Verify result
expected = torch.tensor([[1, 2, 5, 6], [3, 4, 7, 8]], dtype=torch.float32)
if torch.allclose(result, expected):
print(" β
AllGather test PASSED")
else:
print(" β AllGather test FAILED")
print(f" Expected: {expected.tolist()}")
print(f" Got: {result.tolist()}")
# Test 2: AllReduce operation
print("\nπ Test 2: AllReduce Operation")
print("-" * 30)
allreduce = AllReduceKeras(world_size, op="sum")
# Create test tensors (simulating row-parallel outputs)
tensor1 = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) # Shard 0
tensor2 = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32) # Shard 1
tensors = [tensor1, tensor2]
print(f" Input tensors:")
print(f" - Shard 0: {tensor1.shape} = {tensor1.tolist()}")
print(f" - Shard 1: {tensor2.shape} = {tensor2.tolist()}")
# AllReduce with sum operation
results = allreduce(tensors)
print(f" AllReduce results:")
for i, result in enumerate(results):
print(f" - Shard {i}: {result.shape} = {result.tolist()}")
# Verify results
expected_sum = tensor1 + tensor2
if all(torch.allclose(result, expected_sum) for result in results):
print(" β
AllReduce test PASSED")
else:
print(" β AllReduce test FAILED")
print(f" Expected: {expected_sum.tolist()}")
# Test 3: Conjugate Rule Verification
print("\nπ Test 3: Conjugate Rule Verification")
print("-" * 30)
communicator = TensorParallelCommunicator(world_size, rank=0)
# Test column-parallel forward (AllGather) -> backward (AllReduce)
print(" Testing Column-Parallel Layer:")
print(" - Forward: AllGather outputs")
print(" - Backward: AllReduce gradients (conjugate)")
# Forward pass: AllGather
forward_output = communicator.forward_column_parallel(tensors, dim=-1)
print(f" - Forward output: {forward_output.shape}")
# Backward pass: AllReduce (conjugate)
# Simulate gradients flowing back
gradients = [torch.ones_like(t) for t in tensors] # Unit gradients
backward_gradients = communicator.backward_column_parallel(gradients, op="sum")
print(f" - Backward gradients: {len(backward_gradients)} tensors")
# Test row-parallel forward (AllReduce) -> backward (AllGather)
print("\n Testing Row-Parallel Layer:")
print(" - Forward: AllReduce outputs")
print(" - Backward: AllGather gradients (conjugate)")
# Forward pass: AllReduce
forward_outputs = communicator.forward_row_parallel(tensors, op="sum")
print(f" - Forward outputs: {len(forward_outputs)} tensors")
# Backward pass: AllGather (conjugate)
backward_gradient = communicator.backward_row_parallel(gradients, dim=-1)
print(f" - Backward gradient: {backward_gradient.shape}")
print(" β
Conjugate rule verification PASSED")
# Test 4: MLP Handshake
print("\nπ Test 4: MLP Handshake Optimization")
print("-" * 30)
# Simulate MLP up/down projections
up_outputs = [torch.randn(2, 4) for _ in range(world_size)] # Up projection outputs
down_inputs = [torch.randn(2, 4) for _ in range(world_size)] # Down projection inputs
print(f" Up projection outputs: {len(up_outputs)} tensors of shape {up_outputs[0].shape}")
print(f" Down projection inputs: {len(down_inputs)} tensors of shape {down_inputs[0].shape}")
# Apply handshake
final_up, final_down = communicator.handle_mlp_handshake(up_outputs, down_inputs)
print(f" Handshake result:")
print(f" - Final up: {final_up.shape}")
print(f" - Final down: {len(final_down)} tensors of shape {final_down[0].shape}")
print(" β
MLP handshake test PASSED")
print("\nπ All Communication Primitive Tests PASSED!")
print("β
Conjugate rule is properly implemented!")
print("β
True tensor parallelism communication is working!")
return True
if __name__ == "__main__":
test_communication_primitives()