-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbasic_mlp.py
More file actions
31 lines (24 loc) · 900 Bytes
/
basic_mlp.py
File metadata and controls
31 lines (24 loc) · 900 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn as nn
class MLP(nn.Module):
    """Three-layer fully connected network that outputs class probabilities.

    Architecture::

        Linear(input_size, hidden_size) -> ReLU
        -> Linear(hidden_size, hidden_size) -> ReLU
        -> Linear(hidden_size, output_size) -> softmax

    Args:
        input_size: Number of features in the last dimension of the input.
        hidden_size: Width of both hidden layers.
        output_size: Number of output classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return class probabilities for ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_size``. Any
                leading (batch) dimensions are allowed, because ``nn.Linear``
                acts on the last dimension only.

        Returns:
            Tensor with the same leading shape as ``x`` and a last dimension
            of size ``output_size``; each slice along the last dimension sums
            to 1.

        NOTE(review): if this model is trained with ``nn.CrossEntropyLoss``,
        that loss expects unnormalized logits, so feed it ``self.fc3``'s
        output rather than these probabilities.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        logits = self.fc3(x)
        # Normalize over the class axis, which is always the last one.
        # A hard-coded dim=1 only matches it for 2-D (batch, features) input:
        # it fails on a single 1-D sample and normalizes the wrong axis when
        # the input has more than two dimensions.
        return torch.softmax(logits, dim=-1)
# Layer sizes: 3 input features, 6 units in each hidden layer, and
# 2 outputs (binary classification, so the output size is 2).
input_size, hidden_size, output_size = 3, 6, 2

# Instantiate the MLP model.
mlp = MLP(input_size=input_size, hidden_size=hidden_size, output_size=output_size)

# Example input with shape (batch_size, input_size), here (1, 3).
input_data = torch.tensor([[0.1, 0.2, 0.3]])

# Run a forward pass and show the resulting class probabilities.
output = mlp(input_data)
print("Output probabilities:", output)