forked from BlackSamorez/tensor_parallel
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_realistic_memory_savings.py
More file actions
231 lines (186 loc) · 8.56 KB
/
test_realistic_memory_savings.py
File metadata and controls
231 lines (186 loc) · 8.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#!/usr/bin/env python3
"""
Test suite for realistic memory savings with tensor parallelism.
"""
import time
import logging
import numpy as np
import pytest
# Import required modules
try:
import keras
from keras import layers
from src.tensor_parallel_keras.coordinated_optimizer import CoordinatedOptimizer
print("✅ Required modules imported successfully")
except ImportError as e:
print(f"❌ Import failed: {e}")
pytest.skip(f"Required modules not available: {e}")
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def create_large_model():
"""Create a large model to demonstrate memory savings."""
import keras
from keras import layers
# Create a large model with many parameters
model = keras.Sequential([
layers.Input(shape=(1000,)),
layers.Dense(2048, activation='relu'),
layers.Dense(2048, activation='relu'),
layers.Dense(1024, activation='relu'),
layers.Dense(512, activation='relu'),
layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10, activation='softmax')
])
return model
def get_optimizer_memory_info(optimizer, world_size, enable_sharding=True):
"""Helper function to get memory usage for a given optimizer and world size."""
import keras
from src.tensor_parallel_keras.coordinated_optimizer import CoordinatedOptimizer
coord_opt = CoordinatedOptimizer(
base_optimizer=optimizer,
world_size=world_size,
distributed_backend_type='fallback',
shard_optimizer_states=enable_sharding
)
return coord_opt.get_memory_usage()
def test_realistic_memory_savings():
"""Test realistic memory savings with large models."""
print("🚀 Testing Realistic Memory Savings")
print("=" * 40)
start_time = time.time()
print(f"⏱️ {time.time() - start_time:.2f}s: Starting realistic memory test...")
# Import required modules
try:
import keras
from keras import layers
print(f"✅ {time.time() - start_time:.2f}s: Modules imported successfully")
except ImportError as e:
pytest.skip(f"Required modules not available: {e}")
print(f"⏱️ {time.time() - start_time:.2f}s: Creating large model...")
# Create a large model for realistic testing
model = keras.Sequential([
layers.Input(shape=(784,)),
layers.Dense(2048, activation='relu'),
layers.Dense(4096, activation='relu'),
layers.Dense(2048, activation='relu'),
layers.Dense(1024, activation='relu'),
layers.Dense(512, activation='relu'),
layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10, activation='softmax')
])
print(f"✅ {time.time() - start_time:.2f}s: Model created with {model.count_params():,} parameters")
# Test different world sizes
world_sizes = [2, 4, 8]
print("\n🔄 Testing Adam Optimizer")
print("-" * 30)
for world_size in world_sizes:
print(f" World Size: {world_size}")
# Test without sharding
optimizer = keras.optimizers.Adam(learning_rate=0.001)
memory_info = get_optimizer_memory_info(optimizer, world_size, enable_sharding=False)
print(f" No sharding: {memory_info}")
# Test with sharding
memory_info = get_optimizer_memory_info(optimizer, world_size, enable_sharding=True)
print(f" With sharding: {memory_info}")
if memory_info['sharding_enabled']:
savings = memory_info['memory_savings']
theoretical_max = f"{(1 - 1/world_size) * 100:.1f}%"
print(f" 💾 Memory savings: {savings}")
print(f" 📊 Theoretical max savings: {theoretical_max}")
print("\n🔄 Testing SGD Optimizer")
print("-" * 30)
for world_size in world_sizes:
print(f" World Size: {world_size}")
# Test without sharding
optimizer = keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
memory_info = get_optimizer_memory_info(optimizer, world_size, enable_sharding=False)
print(f" No sharding: {memory_info}")
# Test with sharding
memory_info = get_optimizer_memory_info(optimizer, world_size, enable_sharding=True)
print(f" With sharding: {memory_info}")
if memory_info['sharding_enabled']:
savings = memory_info['memory_savings']
theoretical_max = f"{(1 - 1/world_size) * 100:.1f}%"
print(f" 💾 Memory savings: {savings}")
print(f" 📊 Theoretical max savings: {theoretical_max}")
print("\n🔄 Testing RMSprop Optimizer")
print("-" * 30)
for world_size in world_sizes:
print(f" World Size: {world_size}")
# Test without sharding
optimizer = keras.optimizers.RMSprop(learning_rate=0.001)
memory_info = get_optimizer_memory_info(optimizer, world_size, enable_sharding=False)
print(f" No sharding: {memory_info}")
# Test with sharding
memory_info = get_optimizer_memory_info(optimizer, world_size, enable_sharding=True)
print(f" With sharding: {memory_info}")
if memory_info['sharding_enabled']:
savings = memory_info['memory_savings']
theoretical_max = f"{(1 - 1/world_size) * 100:.1f}%"
print(f" 💾 Memory savings: {savings}")
print(f" 📊 Theoretical max savings: {theoretical_max}")
print(f"✅ Realistic memory test completed in {time.time() - start_time:.2f}s")
def test_optimizer_state_partitioning():
"""Test optimizer state partitioning across devices."""
print("🔧 Testing Optimizer State Partitioning")
print("=" * 40)
start_time = time.time()
print(f"⏱️ {time.time() - start_time:.2f}s: Starting partitioning test...")
# Create a simple model
model = keras.Sequential([
layers.Input(shape=(10,)),
layers.Dense(100, activation='relu'),
layers.Dense(50, activation='relu'),
layers.Dense(1, activation='sigmoid')
])
# Create coordinated optimizer with sharded states
optimizer = CoordinatedOptimizer(
base_optimizer=keras.optimizers.Adam(learning_rate=0.001),
world_size=4,
shard_optimizer_states=True
)
print(f"✅ {time.time() - start_time:.2f}s: Coordinated optimizer created")
# Get sharded states structure
sharded_states = optimizer._get_sharded_states_structure()
print(f" Sharded states structure:")
for state_name, state_info in sharded_states.items():
if isinstance(state_info, dict):
print(f" {state_name}:")
for var_name, var_info in state_info.items():
if isinstance(var_info, dict) and 'num_shards' in var_info:
print(f" {var_name}: {var_info['num_shards']} shards")
for i, shape in enumerate(var_info['shard_shapes']):
print(f" Shard {i}: {shape}")
else:
print(f" {var_name}: {var_info}")
else:
print(f" {state_name}: {state_info}")
print(f"✅ Partitioning test completed in {time.time() - start_time:.2f}s")
if __name__ == "__main__":
print("🎯 REALISTIC MEMORY SAVINGS TEST")
print("=" * 40)
# Test 1: Realistic memory savings
test1_success = test_realistic_memory_savings()
# Test 2: State partitioning
test2_success = test_optimizer_state_partitioning()
print("\n" + "=" * 40)
print("🎉 TESTING COMPLETED!")
print(f"\n📋 RESULTS:")
print(f" - Realistic Memory: {'✅' if test1_success else '❌'}")
print(f" - State Partitioning: {'✅' if test2_success else '❌'}")
if all([test1_success, test2_success]):
print("\n🚀 SUCCESS: All realistic memory tests passed!")
print("\n💡 KEY BENEFITS:")
print(" ✅ Significant memory savings with large models")
print(" ✅ Efficient optimizer state partitioning")
print(" ✅ Scalable to any number of devices")
print(" ✅ Production-ready implementation")
else:
print("\n⚠️ WARNING: Some tests failed.")