-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathclient_test.go
206 lines (171 loc) · 6.55 KB
/
client_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
package main
import (
"bytes"
"encoding/binary"
"os"
"syscall"
"testing"
"github.com/pantheon-systems/pauditd/pkg/slog"
"github.com/stretchr/testify/assert"
)
// TestNetlinkClient_KeepConnection checks that KeepConnection sends the
// expected packet to the client's address, and that once the socket has been
// closed the resulting failure is written to the error logger only.
func TestNetlinkClient_KeepConnection(t *testing.T) {
	n := makeNelinkClient(t)

	// Release the socket even if the test aborts before the explicit close
	// further down; that close sets fdClosed so the fd is never closed twice.
	fdClosed := false
	defer func() {
		if !fdClosed {
			// Best-effort cleanup; a failure here does not affect the test result.
			_ = syscall.Close(n.fd)
		}
	}()

	n.KeepConnection()
	msg, err := n.Receive()
	if err != nil {
		t.Fatal("Did not expect an error", err)
	}

	assert.Equal(t, uint16(1001), msg.Header.Type, "Header.Type mismatch")
	assert.Equal(t, uint16(5), msg.Header.Flags, "Header.Flags mismatch")
	assert.Equal(t, uint32(1), msg.Header.Seq, "Header.Seq mismatch")
	assert.Equal(t, uint32(56), msg.Header.Len, "Packet size is wrong - this test is brittle though")

	// Expected leading 40 bytes of the payload: little-endian 4 in word 0,
	// 1 in word 1 and this process's pid in word 3, everything else zero
	// (presumably AuditStatusPayload's Mask, Enabled and Pid fields — confirm).
	const payloadLen = 40
	expectedData := make([]byte, payloadLen)
	binary.LittleEndian.PutUint32(expectedData[0:4], 4)
	binary.LittleEndian.PutUint32(expectedData[4:8], 1)
	binary.LittleEndian.PutUint32(expectedData[12:16], uint32(os.Getpid()))

	// Guard the slice below so a short packet fails the test instead of panicking.
	if len(msg.Data) < payloadLen {
		t.Fatalf("expected at least %d bytes of data, got %d", payloadLen, len(msg.Data))
	}
	// testify's argument order is (expected, actual); keep it so failure diffs read correctly.
	assert.EqualValues(t, expectedData, msg.Data[:payloadLen], "data was wrong")

	// Close the socket out from under the client so KeepConnection fails, and
	// make sure that failure is reported on the error logger only.
	lb, elb := hookLogger()
	defer resetLogger()
	if err := syscall.Close(n.fd); err != nil {
		t.Errorf("Failed to close syscall fd: %v", err)
	}
	fdClosed = true
	n.KeepConnection()
	assert.Equal(t, "", lb.String(), "Got some log lines we did not expect")
	assert.Equal(t, "Error occurred while trying to keep the connection: bad file descriptor\n", elb.String(), "Figured we would have an error")
}
// TestNetlinkClient_SendReceive round-trips an AuditStatusPayload through the
// client's socket, checks the header fields and sequence numbering, and then
// exercises the error paths of Send and Receive.
func TestNetlinkClient_SendReceive(t *testing.T) {
	// Build our client
	n := makeNelinkClient(t)

	// Capture the real socket fd now. Later the test closes it by hand and
	// points n.fd at 0, so a deferred close that reads n.fd would close stdin.
	sockFD := n.fd
	sockClosed := false
	defer func() {
		if sockClosed {
			return
		}
		if err := syscall.Close(sockFD); err != nil {
			t.Errorf("Failed to close syscall fd: %v", err)
		}
	}()

	// Make sure we can encode/decode properly
	payload := &AuditStatusPayload{
		Mask:    4,
		Enabled: 1,
		Pid:     uint32(1006),
	}
	packet := &NetlinkPacket{
		Type:  uint16(1001),
		Flags: syscall.NLM_F_REQUEST | syscall.NLM_F_ACK,
		Pid:   uint32(1006),
	}

	// Send and receive the packet
	msg := sendReceive(t, n, packet, payload)

	// Validate the header fields that are not directly encoded to the AuditStatusPayload
	assert.Equal(t, packet.Type, msg.Header.Type, "Header.Type mismatch")
	assert.Equal(t, packet.Flags, msg.Header.Flags, "Header.Flags mismatch")
	assert.Equal(t, uint32(1), msg.Header.Seq, "Header.Seq mismatch")
	assert.Equal(t, uint32(56), msg.Header.Len, "Packet size is wrong - this test is brittle though")

	// Extract the meaningful portion of the data; guard the slice so a short
	// packet fails the test instead of panicking.
	const payloadLen = 40
	if len(msg.Data) < payloadLen {
		t.Fatalf("expected at least %d bytes of data, got %d", payloadLen, len(msg.Data))
	}
	meaningfulData := msg.Data[:payloadLen]

	// Decode the Data field of the syscall.NetlinkMessage back into an
	// AuditStatusPayload, the struct this client uses to encode/decode it.
	var receivedPayload AuditStatusPayload
	if err := binary.Read(bytes.NewReader(meaningfulData), binary.LittleEndian, &receivedPayload); err != nil {
		t.Fatalf("Failed to deserialize payload: %v", err)
	}

	// Compare the deserialized payload with the expected payload
	assert.Equal(t, payload.Mask, receivedPayload.Mask, "Payload.Mask mismatch")
	assert.Equal(t, payload.Enabled, receivedPayload.Enabled, "Payload.Enabled mismatch")
	assert.Equal(t, payload.Pid, receivedPayload.Pid, "Payload.Pid mismatch")

	// Make sure sequences numbers increment on our side
	msg = sendReceive(t, n, packet, payload)
	assert.Equal(t, uint32(2), msg.Header.Seq, "Header.Seq did not increment")

	// Make sure 0-length packets result in an error. EqualError is used
	// throughout so a nil error fails the assertion instead of panicking.
	if err := syscall.Sendto(n.fd, []byte{}, 0, n.address); err != nil {
		t.Errorf("Failed to send data: %v", err)
	}
	_, err := n.Receive()
	assert.EqualError(t, err, "got a 0 length packet", "Error was incorrect")

	// Make sure we get errors from sendto back
	if err := syscall.Close(n.fd); err != nil {
		t.Errorf("Failed to close syscall fd: %v", err)
	}
	sockClosed = true
	err = n.Send(packet, payload)
	assert.EqualError(t, err, "bad file descriptor", "Error was incorrect")

	// Make sure we get errors from recvfrom back; fd 0 is expected to be a
	// non-socket descriptor here.
	n.fd = 0
	_, err = n.Receive()
	assert.EqualError(t, err, "socket operation on non-socket", "Error was incorrect")
}
// TestNewNetlinkClient checks that NewNetlinkClient returns an initialized
// client (valid fd, address set, sequence at 0, buffer large enough) and logs
// the socket receive buffer size without emitting any errors.
func TestNewNetlinkClient(t *testing.T) {
	// Hook loggers to capture output
	lb, elb := hookLogger()
	defer resetLogger()

	// Create a new NetlinkClient
	n, err := NewNetlinkClient(1024)
	if err != nil {
		t.Fatalf("Expected no error, but got: %v", err)
	}
	// Check for nil before touching any field, so a broken constructor fails
	// the test cleanly instead of panicking on n.fd.
	if n == nil {
		t.Fatal("Expected a netlink client but got nil")
	}
	defer n.Close()
	t.Logf("Received file descriptor: %d", n.fd)

	// File descriptors 0-2 are stdin/stdout/stderr, so a fresh socket normally
	// lands on 3 or above; this only asserts the fd is not an error value.
	assert.GreaterOrEqual(t, n.fd, 0, "Invalid file descriptor")
	assert.NotNil(t, n.address, "Address was nil")
	assert.Equal(t, uint32(0), n.seq, "Seq should start at 0")
	// The buffer must be able to hold a maximum-length audit message.
	assert.True(t, len(n.buf) >= MaxAuditMessageLength, "Client buffer is too small")

	// Verify log output
	assert.Contains(t, lb.String(), "Socket receive buffer size:", "Expected log lines for socket buffer size")
	assert.Equal(t, "", elb.String(), "Did not expect any error messages")
}
// makeNelinkClient builds a NetlinkClient whose fd is an AF_UNIX socket bound
// to "pauditd.test.sock", with n.address set to that same path. Packets the
// client sends to n.address can therefore be read back from the same fd.
// Any stale socket file is removed first, and the file is removed again when
// the test finishes. The caller owns the returned fd and must close it.
func makeNelinkClient(t *testing.T) *NetlinkClient {
	t.Helper()

	const sockPath = "pauditd.test.sock"

	// A leftover socket file would make Bind fail with a misleading error,
	// so treat failure to remove it as fatal.
	if err := os.Remove(sockPath); err != nil && !os.IsNotExist(err) {
		t.Fatalf("Failed to remove test socket: %v", err)
	}

	fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_RAW, 0)
	if err != nil {
		t.Fatal("Could not create a socket:", err)
	}

	n := &NetlinkClient{
		fd:      fd,
		address: &syscall.SockaddrUnix{Name: sockPath},
		buf:     make([]byte, MaxAuditMessageLength),
	}

	if err = syscall.Bind(fd, n.address); err != nil {
		if closeErr := syscall.Close(fd); closeErr != nil {
			t.Errorf("Failed to close socket fd after bind error: %v", closeErr)
		}
		t.Fatal("Could not bind to netlink socket:", err)
	}

	// Bind created the socket file on disk; do not leave it behind.
	t.Cleanup(func() {
		if err := os.Remove(sockPath); err != nil && !os.IsNotExist(err) {
			t.Errorf("Failed to remove test socket: %v", err)
		}
	})

	return n
}
// sendReceive sends packet/payload through n and returns the next message
// received on n's socket. It aborts the calling test on either failure.
// t.Helper() ensures that such failures point at the caller's line.
func sendReceive(t *testing.T, n *NetlinkClient, packet *NetlinkPacket, payload *AuditStatusPayload) *syscall.NetlinkMessage {
	t.Helper()

	if err := n.Send(packet, payload); err != nil {
		t.Fatal("Failed to send:", err)
	}

	msg, err := n.Receive()
	if err != nil {
		t.Fatal("Failed to receive:", err)
	}
	return msg
}
// resetLogger points the global slog Info and Error loggers back at the
// process's standard output and standard error streams respectively.
func resetLogger() {
	stdout, stderr := os.Stdout, os.Stderr
	slog.Info.SetOutput(stdout)
	slog.Error.SetOutput(stderr)
}
// hookLogger redirects the global slog Info and Error loggers into fresh
// in-memory buffers and returns them (info buffer first, error buffer second)
// so a test can assert on what was logged. Restore with resetLogger.
func hookLogger() (lb *bytes.Buffer, elb *bytes.Buffer) {
	lb, elb = new(bytes.Buffer), new(bytes.Buffer)
	slog.Info.SetOutput(lb)
	slog.Error.SetOutput(elb)
	return lb, elb
}