Skip to content

Commit 9a287a3

Browse files
committed
Rework how files being exported are tracked
1 parent d6a1e49 commit 9a287a3

3 files changed

Lines changed: 137 additions & 25 deletions

File tree

Source/santad/Logs/EndpointSecurity/Logger.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ class Logger : public Timer<Logger> {
5858

5959
virtual ~Logger() = default;
6060

61+
Logger(Logger &&) = default;
62+
Logger &operator=(Logger &&rhs) = default;
63+
Logger(Logger &) = default;
64+
Logger &operator=(Logger &rhs) = default;
65+
6166
virtual void Log(std::unique_ptr<santa::EnrichedMessage> msg);
6267

6368
void LogAllowlist(const santa::Message &msg, const std::string_view hash);
@@ -88,13 +93,60 @@ class Logger : public Timer<Logger> {
8893
friend class santa::LoggerPeer;
8994

9095
private:
96+
class ExportTracker {
97+
public:
98+
static ExportTracker Create() {
99+
dispatch_queue_t q = dispatch_queue_create("com.northpolesec.santa.daemon.export_tracker",
100+
DISPATCH_QUEUE_SERIAL_WITH_AUTORELEASE_POOL);
101+
return ExportTracker(q);
102+
}
103+
104+
ExportTracker(dispatch_queue_t q) : q_(q) {}
105+
ExportTracker(ExportTracker &&) = default;
106+
ExportTracker &operator=(ExportTracker &&rhs) = default;
107+
ExportTracker(ExportTracker &) = default;
108+
ExportTracker &operator=(ExportTracker &rhs) = default;
109+
110+
/// Track a new key. If the key isn't yet tracked, its value will be set
111+
/// to false. If the key is already tracked, its value will not be changed.
112+
void Track(std::string file_path) {
113+
dispatch_sync(q_, ^{
114+
file_state_.try_emplace(std::move(file_path), false);
115+
});
116+
}
117+
118+
/// Mark the given key as completed. If the key doesn't previously exist,
119+
/// it will automatically start being tracked.
120+
void AckCompleted(std::string file_path) {
121+
dispatch_sync(q_, ^{
122+
file_state_.insert_or_assign(std::move(file_path), true);
123+
});
124+
}
125+
126+
/// Empty the map and return the previous state
127+
absl::flat_hash_map<std::string, bool> Drain() {
128+
__block absl::flat_hash_map<std::string, bool> return_state;
129+
dispatch_sync(q_, ^{
130+
std::swap(return_state, file_state_);
131+
});
132+
return return_state;
133+
}
134+
135+
friend class santa::LoggerPeer;
136+
137+
private:
138+
absl::flat_hash_map<std::string, bool> file_state_;
139+
dispatch_queue_t q_;
140+
};
141+
91142
void ExportTelemetrySerialized();
92143

93144
SNTSyncdQueue *syncd_queue_;
94145
GetExportConfigBlock get_export_config_block_;
95146
TelemetryEvent telemetry_mask_;
96147
std::shared_ptr<santa::Serializer> serializer_;
97148
std::shared_ptr<santa::Writer> writer_;
149+
ExportTracker tracker_;
98150
dispatch_queue_t export_queue_;
99151
};
100152

Source/santad/Logs/EndpointSecurity/Logger.mm

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@
105105
get_export_config_block_(get_export_config_block),
106106
telemetry_mask_(telemetry_mask),
107107
serializer_(std::move(serializer)),
108-
writer_(std::move(writer)) {
108+
writer_(std::move(writer)),
109+
tracker_(ExportTracker::Create()) {
109110
// Provide a default block instead of leaving nil
110111
if (get_export_config_block_ == nil) {
111112
get_export_config_block_ = ^SNTExportConfiguration *() {
@@ -136,44 +137,46 @@
136137

137138
// Get a copy of the current export config to be used for the entire export
138139
SNTExportConfiguration *export_config = get_export_config_block_();
139-
140-
// Track which files have been provided by the writer for exporting as well
141-
// as their export status so the writer can know which ones to clean up.
142-
__block absl::flat_hash_map<std::string, bool> files_exported;
140+
if (!export_config) {
141+
LOGW(@"Telemetry export enabled, but no export configuration is set.");
142+
return;
143+
}
143144

144145
while (std::optional<std::string> file_to_export = writer_->NextFileToExport()) {
145146
NSString *path = @((*file_to_export).c_str());
146147

147148
NSFileHandle *handle = [NSFileHandle fileHandleForReadingAtPath:path];
148149
if (!handle) {
149150
LOGW(@"Failed to get a file handle for telemetry file to export: %@", path);
150-
files_exported.insert_or_assign(*file_to_export, true);
151+
tracker_.AckCompleted(*file_to_export);
151152
continue;
152153
}
153154

154155
struct stat sb;
155156
if (fstat(handle.fileDescriptor, &sb) != 0) {
156157
LOGW(@"Failed to stat telemetry file to export: %@", path);
157-
files_exported.insert_or_assign(*file_to_export, true);
158+
tracker_.AckCompleted(*file_to_export);
158159
continue;
159160
}
160161

161162
if (!S_ISREG(sb.st_mode)) {
162163
LOGW(@"Telemetry file to export is not a regular file: %@", path);
163-
files_exported.insert_or_assign(*file_to_export, true);
164+
tracker_.AckCompleted(*file_to_export);
164165
continue;
165166
}
166167

167168
// Track all files as initially unsuccessfully processed
168169
// in case the export times out.
169-
files_exported.insert_or_assign(*file_to_export, false);
170+
tracker_.Track(*file_to_export);
170171

171172
dispatch_group_enter(group);
172173
[syncd_queue_ exportTelemetryFile:handle
173174
config:export_config
174175
completionHandler:^(BOOL success) {
175176
[handle closeFile];
176-
files_exported.insert_or_assign(*file_to_export, success);
177+
if (success) {
178+
tracker_.AckCompleted(*file_to_export);
179+
}
177180
dispatch_group_leave(group);
178181
}];
179182
}
@@ -183,7 +186,7 @@
183186
LOGW(@"Timed out waiting for telemetry to export.");
184187
}
185188

186-
writer_->FilesExported(files_exported);
189+
writer_->FilesExported(tracker_.Drain());
187190
}
188191

189192
void Logger::Log(std::unique_ptr<EnrichedMessage> msg) {

Source/santad/Logs/EndpointSecurity/LoggerTest.mm

Lines changed: 71 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,16 @@
6060

6161
class LoggerPeer : public Logger {
6262
public:
63-
// Make base class constructors visible
63+
// Make base class constructors and members visible
6464
using Logger::Logger;
65+
using Logger::serializer_;
66+
using Logger::tracker_;
67+
using Logger::writer_;
6568

6669
LoggerPeer(std::unique_ptr<Logger> l)
6770
: Logger(nil, nil, TelemetryEvent::kEverything, l->serializer_, l->writer_) {}
6871

69-
std::shared_ptr<santa::Serializer> Serializer() { return serializer_; }
70-
71-
std::shared_ptr<santa::Writer> Writer() { return writer_; }
72+
absl::flat_hash_map<std::string, bool> TrackerState() { return tracker_.file_state_; }
7273
};
7374

7475
} // namespace santa
@@ -120,32 +121,32 @@ - (void)testCreate {
120121
LoggerPeer logger(Logger::Create(mockESApi, nil, nil, TelemetryEvent::kEverything,
121122
SNTEventLogTypeFilelog, nil, @"/tmp/temppy", @"/tmp/spool", 1, 1,
122123
1, 1));
123-
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<BasicString>(logger.Serializer()));
124-
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<File>(logger.Writer()));
124+
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<BasicString>(logger.serializer_));
125+
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<File>(logger.writer_));
125126

126127
logger = LoggerPeer(Logger::Create(mockESApi, nil, nil, TelemetryEvent::kEverything,
127128
SNTEventLogTypeSyslog, nil, @"/tmp/temppy", @"/tmp/spool", 1,
128129
1, 1, 1));
129-
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<BasicString>(logger.Serializer()));
130-
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Syslog>(logger.Writer()));
130+
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<BasicString>(logger.serializer_));
131+
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Syslog>(logger.writer_));
131132

132133
logger = LoggerPeer(Logger::Create(mockESApi, nil, nil, TelemetryEvent::kEverything,
133134
SNTEventLogTypeNull, nil, @"/tmp/temppy", @"/tmp/spool", 1, 1,
134135
1, 1));
135-
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Empty>(logger.Serializer()));
136-
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Null>(logger.Writer()));
136+
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Empty>(logger.serializer_));
137+
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Null>(logger.writer_));
137138

138139
logger = LoggerPeer(Logger::Create(mockESApi, nil, nil, TelemetryEvent::kEverything,
139140
SNTEventLogTypeProtobuf, nil, @"/tmp/temppy", @"/tmp/spool", 1,
140141
1, 1, 1));
141-
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Protobuf>(logger.Serializer()));
142-
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Spool>(logger.Writer()));
142+
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Protobuf>(logger.serializer_));
143+
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Spool>(logger.writer_));
143144

144145
logger = LoggerPeer(Logger::Create(mockESApi, nil, nil, TelemetryEvent::kEverything,
145146
SNTEventLogTypeJSON, nil, @"/tmp/temppy", @"/tmp/spool", 1, 1,
146147
1, 1));
147-
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Protobuf>(logger.Serializer()));
148-
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<File>(logger.Writer()));
148+
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Protobuf>(logger.serializer_));
149+
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<File>(logger.writer_));
149150
}
150151

151152
- (void)testLog {
@@ -263,4 +264,60 @@ - (void)testLogFileAccess {
263264
XCTBubbleMockVerifyAndClearExpectations(mockWriter.get());
264265
}
265266

267+
- (void)testExportTracker {
268+
auto mockESApi = std::make_shared<MockEndpointSecurityAPI>();
269+
LoggerPeer logger(Logger::Create(mockESApi, nil, nil, TelemetryEvent::kEverything,
270+
SNTEventLogTypeNull, nil, @"", @"", 1, 1, 1, 1));
271+
272+
// Nothing in the map initially
273+
auto map = logger.tracker_.Drain();
274+
XCTAssertEqual(logger.TrackerState().size(), 0);
275+
XCTAssertEqual(map.size(), 0);
276+
277+
// Start tracking a couple of keys
278+
logger.tracker_.Track("foo");
279+
XCTAssertEqual(logger.TrackerState().size(), 1);
280+
XCTAssertEqual(logger.TrackerState().at("foo"), false);
281+
282+
logger.tracker_.Track("bar");
283+
XCTAssertEqual(logger.TrackerState().size(), 2);
284+
XCTAssertEqual(logger.TrackerState().at("bar"), false);
285+
286+
// Change state of an existing key
287+
logger.tracker_.AckCompleted("bar");
288+
XCTAssertEqual(logger.TrackerState().at("bar"), true);
289+
290+
// Change state of a non-existing key, it should be created
291+
logger.tracker_.AckCompleted("cake");
292+
XCTAssertEqual(logger.TrackerState().at("cake"), true);
293+
294+
// Drain the tracker
295+
map = logger.tracker_.Drain();
296+
XCTAssertEqual(logger.TrackerState().size(), 0);
297+
XCTAssertEqual(map.size(), 3);
298+
XCTAssertEqual(map.at("foo"), false);
299+
XCTAssertEqual(map.at("bar"), true);
300+
XCTAssertEqual(map.at("cake"), true);
301+
302+
// Add some more keys after draining
303+
logger.tracker_.Track("baz");
304+
logger.tracker_.AckCompleted("qaz");
305+
XCTAssertEqual(logger.TrackerState().size(), 2);
306+
XCTAssertEqual(logger.TrackerState().at("baz"), false);
307+
XCTAssertEqual(logger.TrackerState().at("qaz"), true);
308+
309+
// Track something already ack'd, ensure value doesn't change
310+
logger.tracker_.Track("qaz");
311+
XCTAssertEqual(logger.TrackerState().size(), 2);
312+
XCTAssertEqual(logger.TrackerState().at("baz"), false);
313+
XCTAssertEqual(logger.TrackerState().at("qaz"), true);
314+
315+
// One last drain for fun
316+
map = logger.tracker_.Drain();
317+
XCTAssertEqual(logger.TrackerState().size(), 0);
318+
XCTAssertEqual(map.size(), 2);
319+
XCTAssertEqual(map.at("baz"), false);
320+
XCTAssertEqual(map.at("qaz"), true);
321+
}
322+
266323
@end

0 commit comments

Comments
 (0)