Hi team. I am adding the privacy engine to a customized GRU (a GRU with an attention layer), and when I run loss.backward() it returns this error: RuntimeError: output with shape [1, 256] doesn't match the broadcast shape [238, 1, 256]. Also, I originally set the batch size to 256, but some batches come out smaller; I think that's because I am using Poisson sampling. My question is: why is the broadcast shape different? I did set batch_first=True. Hope someone can help me with this. Thanks a lot!
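For reference, on the batch-size part: under Poisson sampling each example is drawn into a batch independently with probability q = batch_size / len(dataset), so the realized batch size fluctuates around the nominal 256, and a value like 238 is expected. Below is a minimal sketch of that behaviour (standalone; it only simulates the sampling with numpy and is not an Opacus API):

import numpy as np

# Each of the N training examples is included independently with probability q,
# so the batch size follows a Binomial(N, q) distribution centred on N * q = 256.
N = 20000          # number of training samples, as in the script below
q = 256 / N        # per-example inclusion probability

rng = np.random.default_rng(0)
print(rng.binomial(N, q, size=10))  # simulated batch sizes, scattered around 256

The full script that reproduces the error is below.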
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

num_train_samples = 20000  # Number of training samples
num_test_samples = 1000    # Number of test samples
embedding_size = 500       # Size of each embedding
num_classes = 4            # Number of output classes

# Generate random embeddings for training and testing
X_ori = np.random.rand(num_train_samples, embedding_size).astype(np.float32)  # Training embeddings
X_sec = np.random.rand(num_test_samples, embedding_size).astype(np.float32)   # Testing embeddings

# Generate random labels for training and testing
y_ori = np.random.randint(0, num_classes, size=num_train_samples).astype(np.int64)  # Training labels
y_sec = np.random.randint(0, num_classes, size=num_test_samples).astype(np.int64)   # Testing labels
# Step 1: Define the Dataset
class EmbeddingDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        embedding, label = self.data[idx]
        return torch.tensor(embedding, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
class AttentionLayer(nn.Module):
    def __init__(self, hidden_size):
        super(AttentionLayer, self).__init__()
        self.attention_weights = nn.Linear(hidden_size, 1)

    def forward(self, gru_outputs):
        print("GRU Outputs Shape:", gru_outputs.shape)   # (batch_size, seq_len, hidden_size)
        scores = self.attention_weights(gru_outputs)     # (batch_size, seq_len, 1)
        if scores.shape[1] == 1:
            scores = scores.view(scores.size(0), -1)
        else:
            scores = scores.squeeze(-1)                  # (batch_size, seq_len)
        weights = F.softmax(scores, dim=-1)
        context_vector = torch.sum(weights.unsqueeze(-1) * gru_outputs, dim=1)  # (batch_size, hidden_size)
        return context_vector, weights
class AdvancedGRUClassifier2(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout=0.3):
        super(AdvancedGRUClassifier2, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True,
                          dropout=dropout, bidirectional=True)
        self.attention = AttentionLayer(hidden_size * 2)
        self.fc = nn.Sequential(nn.Dropout(0.2),
                                nn.Linear(hidden_size * 2, output_size))

    def forward(self, x):
        # if x.dim() == 2:       # If input is (batch_size, feature_size)
        #     x = x.unsqueeze(1) # Shape becomes (batch_size, seq_length=1, feature_size)
        x = x.unsqueeze(1)                                           # (batch_size, 1, input_size)
        gru_outputs, _ = self.gru(x)                                 # (batch_size, 1, hidden_size * 2)
        context_vector, attn_weights = self.attention(gru_outputs)   # (batch_size, hidden_size * 2)
        out = self.fc(context_vector)
        return F.log_softmax(out, dim=1)
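# Quick standalone shape check (illustrative addition; plain PyTorch, no privacy engine yet):
# a dummy batch of 8 embeddings goes (8, 500) -> (8, 1, 500) -> (8, 1, 512) -> (8, 512) -> (8, 4).
_check_model = AdvancedGRUClassifier2(input_size=500, hidden_size=256, output_size=4)
_check_out = _check_model(torch.randn(8, 500))
print("Shape-check output:", _check_out.shape)  # expected: torch.Size([8, 4])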
data = [(embedding, label) for embedding, label in zip(X_ori, y_ori)
        # Add more data as needed
        ]
dataset = EmbeddingDataset(data)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

test_data = [(embedding, label) for embedding, label in zip(X_sec, y_sec)]
test_dataset = EmbeddingDataset(test_data)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# Instantiate the model
model_3 = AdvancedGRUClassifier2(
    input_size=500,   # Replace with the size of your embedding
    hidden_size=256,  # First hidden layer dimension
    output_size=4,    # Number of classes
    num_layers=6,
    dropout=0.2       # Dropout rate to prevent overfitting
)
# Step 3: Define Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_3.parameters(), lr=0.001)
# Define parameters for early stopping
max_epochs = 75 # Maximum number of epochs
train_losses = []
test_losses = []
from opacus import PrivacyEngine

privacy_engine = PrivacyEngine()
model_3, optimizer, dataloader = privacy_engine.make_private(
    module=model_3,
    optimizer=optimizer,
    data_loader=dataloader,
    noise_multiplier=1.1,
    max_grad_norm=1.0,
    # poisson_sampling=False
)
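# Note (added): as far as I can tell, make_private defaults to poisson_sampling=True, so it
# replaces the dataloader with one that samples each example independently (hence the varying
# batch sizes) and wraps the module so per-sample gradients are computed for clipping.
# Un-commenting poisson_sampling=False above should give fixed-size batches of 256.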
for epoch in range(max_epochs):
    # Training
    model_3.train()
    epoch_train_loss = 0
    for embeddings, labels in dataloader:
        optimizer.zero_grad()
        outputs = model_3(embeddings)
        loss = criterion(outputs, labels)
        print("Model output shape:", outputs.shape)  # Expect: (batch_size, num_classes)
        print("Labels shape:", labels.shape)         # Expect: (batch_size,)
        print("loss is ", loss)
        for name, param in model_3.named_parameters():
            if param.grad is None:
                print(f"No gradients for {name}")
            else:
                print(f"Gradients for {name} have shape: {param.grad.shape}")
        # torch.autograd.set_detect_anomaly(True)
        # Perform backward pass
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item()
    avg_train_loss = epoch_train_loss / len(dataloader)
    train_losses.append(avg_train_loss)

    # Testing
    model_3.eval()
    epoch_test_loss = 0
    with torch.no_grad():
        for embeddings, labels in test_dataloader:
            outputs = model_3(embeddings)
            loss = criterion(outputs, labels)
            epoch_test_loss += loss.item()
    avg_test_loss = epoch_test_loss / len(test_dataloader)
    test_losses.append(avg_test_loss)

    # Print progress
    print(f'Epoch [{epoch+1}/{max_epochs}], Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}')

    # Early stopping condition
    if avg_test_loss > avg_train_loss:
        print(f"Stopping early at epoch {epoch+1}: Test Loss ({avg_test_loss:.4f}) > Train Loss ({avg_train_loss:.4f})")
        break
# Optional: Visualize Loss Trends
import matplotlib.pyplot as plt
plt.plot(train_losses, label='Training Loss')
plt.plot(test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Test Loss Over Epochs')
plt.show()