Skip to content

concord-server: refactor WebSocketChannelManager, allow message sources in plugins #1056

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import com.walmartlabs.concord.server.boot.servlets.FormServletHolder;
import com.walmartlabs.concord.server.boot.statics.StaticResourcesConfigurator;
import com.walmartlabs.concord.server.boot.validation.ValidationModule;
import com.walmartlabs.concord.server.websocket.ConcordWebSocketServlet;
import com.walmartlabs.concord.server.agent.websocket.ConcordWebSocketServlet;
import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.web.mgt.WebSecurityManager;
import org.eclipse.jetty.ee8.servlet.FilterHolder;
Expand Down Expand Up @@ -60,10 +60,6 @@ public void configure(Binder binder) {

newSetBinder(binder, FilterHolder.class).addBinding().to(ShiroFilterHolder.class).in(SINGLETON);

// HttpServlet

newSetBinder(binder, HttpServlet.class).addBinding().to(ConcordWebSocketServlet.class).in(SINGLETON);

// ServletHolder

bindServletHolder(binder, FormServletHolder.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.inject.Module;
import com.walmartlabs.concord.server.boot.BackgroundTasks;
import com.walmartlabs.concord.server.boot.HttpServer;
import com.walmartlabs.concord.server.message.MessageChannelManager;

import javax.inject.Inject;
import java.util.Collection;
Expand All @@ -43,6 +44,9 @@ public final class ConcordServer {
@Inject
private HttpServer server;

@Inject
private MessageChannelManager messageChannelManager;

private final Lock controlMutex = new ReentrantLock();

public static ConcordServer withModules(Module... modules) throws Exception {
Expand Down Expand Up @@ -73,6 +77,11 @@ public ConcordServer start() throws Exception {
public void stop() throws Exception {
controlMutex.lock();
try {
if (messageChannelManager != null) {
messageChannelManager.shutdown();
messageChannelManager = null;
}

if (server != null) {
server.stop();
server = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
import com.walmartlabs.concord.db.DatabaseModule;
import com.walmartlabs.concord.dependencymanager.DependencyManagerConfiguration;
import com.walmartlabs.concord.server.agent.AgentModule;
import com.walmartlabs.concord.server.agent.websocket.WebSocketModule;
import com.walmartlabs.concord.server.audit.AuditLogModule;
import com.walmartlabs.concord.server.boot.BackgroundTasks;
import com.walmartlabs.concord.server.cfg.ConfigurationModule;
import com.walmartlabs.concord.server.cfg.DatabaseConfigurationModule;
import com.walmartlabs.concord.server.console.ConsoleModule;
import com.walmartlabs.concord.server.events.EventModule;
import com.walmartlabs.concord.server.message.MessageChannelManager;
import com.walmartlabs.concord.server.metrics.MetricModule;
import com.walmartlabs.concord.server.org.OrganizationModule;
import com.walmartlabs.concord.server.policy.PolicyModule;
Expand All @@ -46,7 +48,6 @@
import com.walmartlabs.concord.server.task.TaskSchedulerModule;
import com.walmartlabs.concord.server.template.TemplateModule;
import com.walmartlabs.concord.server.user.UserModule;
import com.walmartlabs.concord.server.websocket.WebSocketModule;

import javax.inject.Named;
import java.security.SecureRandom;
Expand Down Expand Up @@ -88,10 +89,11 @@ public void configure(Binder binder) {
binder.bind(Listeners.class).in(SINGLETON);
binder.bind(SecureRandom.class).toProvider(SecureRandomProvider.class);

binder.bind(MessageChannelManager.class).in(SINGLETON);

binder.bind(DependencyManagerConfiguration.class).toProvider(DependencyManagerConfigurationProvider.class);

binder.install(new ApiServerModule());

binder.install(new AgentModule());
binder.install(new ApiKeyModule());
binder.install(new AuditLogModule());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import com.walmartlabs.concord.server.boot.BackgroundTasks;
import com.walmartlabs.concord.server.sdk.rest.Resource;
import com.walmartlabs.concord.server.task.TaskScheduler;
import com.walmartlabs.concord.server.websocket.WebSocketChannelManager;
import com.walmartlabs.concord.server.message.MessageChannelManager;
import org.jooq.Configuration;

import javax.inject.Inject;
Expand All @@ -41,18 +41,18 @@ public class ServerResource implements Resource {

private final TaskScheduler taskScheduler;
private final BackgroundTasks backgroundTasks;
private final WebSocketChannelManager webSocketChannelManager;
private final MessageChannelManager messageChannelManager;
private final PingDao pingDao;

@Inject
public ServerResource(TaskScheduler taskScheduler,
BackgroundTasks backgroundTasks,
WebSocketChannelManager webSocketChannelManager,
MessageChannelManager messageChannelManager,
PingDao pingDao) {

this.taskScheduler = taskScheduler;
this.backgroundTasks = backgroundTasks;
this.webSocketChannelManager = webSocketChannelManager;
this.messageChannelManager = messageChannelManager;
this.pingDao = pingDao;
}

Expand All @@ -75,9 +75,8 @@ public VersionResponse version() {
@POST
@Path("/maintenance-mode")
public void maintenanceMode() {
messageChannelManager.shutdown();
backgroundTasks.stop();

webSocketChannelManager.shutdown();
taskScheduler.stop();
}

Expand All @@ -88,7 +87,8 @@ public TestBean test() {
return new TestBean(OffsetDateTime.now());
}

static class PingDao extends AbstractDao {
@SuppressWarnings("resource")
public static class PingDao extends AbstractDao {

@Inject
public PingDao(@MainDB Configuration cfg) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
* =====
*/

import com.walmartlabs.concord.server.message.MessageChannel;
import com.walmartlabs.concord.server.message.MessageChannelManager;
import com.walmartlabs.concord.server.agent.websocket.WebSocketChannel;
import com.walmartlabs.concord.server.UuidGenerator;
import com.walmartlabs.concord.server.queueclient.message.MessageType;
import com.walmartlabs.concord.server.queueclient.message.ProcessRequest;
import com.walmartlabs.concord.server.sdk.ProcessKey;
import com.walmartlabs.concord.server.websocket.WebSocketChannel;
import com.walmartlabs.concord.server.websocket.WebSocketChannelManager;
import org.jooq.DSLContext;

import javax.inject.Inject;
Expand All @@ -40,25 +41,27 @@
public class AgentManager {

private final AgentCommandsDao commandQueue;
private final WebSocketChannelManager channelManager;
private final MessageChannelManager channelManager;
private final UuidGenerator uuidGenerator;

@Inject
public AgentManager(AgentCommandsDao commandQueue,
WebSocketChannelManager channelManager, UuidGenerator uuidGenerator) {
MessageChannelManager channelManager,
UuidGenerator uuidGenerator) {

this.commandQueue = requireNonNull(commandQueue);
this.channelManager = requireNonNull(channelManager);
this.uuidGenerator = requireNonNull(uuidGenerator);
}

public Collection<AgentWorkerEntry> getAvailableAgents() {
Map<WebSocketChannel, ProcessRequest> reqs = channelManager.getRequests(MessageType.PROCESS_REQUEST);
Map<MessageChannel, ProcessRequest> reqs = channelManager.getRequests(MessageType.PROCESS_REQUEST);
return reqs.entrySet().stream()
.filter(r -> r.getKey() instanceof WebSocketChannel) // TODO a better way
.map(r -> AgentWorkerEntry.builder()
.channelId(r.getKey().getChannelId())
.agentId(r.getKey().getAgentId())
.userAgent(r.getKey().getUserAgent())
.userAgent(((WebSocketChannel) r.getKey()).getUserAgent())
.capabilities(r.getValue().getCapabilities())
.build())
.collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@

import javax.annotation.Nullable;
import java.util.Map;
import java.util.UUID;

@Value.Immutable
@JsonInclude(JsonInclude.Include.NON_EMPTY)
@JsonSerialize(as = ImmutableAgentWorkerEntry.class)
@JsonDeserialize(as = ImmutableAgentWorkerEntry.class)
public interface AgentWorkerEntry {

UUID channelId();
String channelId();

@Nullable // backward-compatibility with old agent versions
String agentId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@
import com.walmartlabs.concord.server.PeriodicTask;
import com.walmartlabs.concord.server.agent.AgentCommand;
import com.walmartlabs.concord.server.agent.Commands;
import com.walmartlabs.concord.server.message.MessageChannel;
import com.walmartlabs.concord.server.message.MessageChannelManager;
import com.walmartlabs.concord.server.cfg.AgentConfiguration;
import com.walmartlabs.concord.server.jooq.tables.records.AgentCommandsRecord;
import com.walmartlabs.concord.server.queueclient.message.CommandRequest;
import com.walmartlabs.concord.server.queueclient.message.CommandResponse;
import com.walmartlabs.concord.server.queueclient.message.MessageType;
import com.walmartlabs.concord.server.sdk.metrics.WithTimer;
import com.walmartlabs.concord.server.websocket.WebSocketChannel;
import com.walmartlabs.concord.server.websocket.WebSocketChannelManager;
import org.jooq.Configuration;
import org.jooq.DSLContext;
import org.slf4j.Logger;
Expand All @@ -61,12 +61,12 @@ public class Dispatcher extends PeriodicTask {
private static final int BATCH_SIZE = 10;

private final DispatcherDao dao;
private final WebSocketChannelManager channelManager;
private final MessageChannelManager channelManager;

@Inject
public Dispatcher(AgentConfiguration cfg,
DispatcherDao dao,
WebSocketChannelManager channelManager) {
MessageChannelManager channelManager) {

super(cfg.getCommandPollDelay().toMillis(), ERROR_DELAY);
this.dao = dao;
Expand All @@ -75,7 +75,7 @@ public Dispatcher(AgentConfiguration cfg,

@Override
protected boolean performTask() {
Map<WebSocketChannel, CommandRequest> requests = this.channelManager.getRequests(MessageType.COMMAND_REQUEST);
Map<MessageChannel, CommandRequest> requests = this.channelManager.getRequests(MessageType.COMMAND_REQUEST);
if (requests.isEmpty()) {
return false;
}
Expand Down Expand Up @@ -148,7 +148,6 @@ private AgentCommand findCandidate(CommandRequest req, List<AgentCommand> candid
}

private void sendResponse(Match match, AgentCommand response) {
WebSocketChannel channel = match.request.channel;
long correlationId = match.request.request.getCorrelationId();

CommandType type = CommandType.valueOf((String) response.getData().remove(Commands.TYPE_KEY));
Expand All @@ -157,7 +156,8 @@ private void sendResponse(Match match, AgentCommand response) {
payload.put("type", type.toString());
payload.putAll(response.getData());

boolean success = channelManager.sendResponse(channel.getChannelId(), CommandResponse.cancel(correlationId, payload));
MessageChannel channel = match.request.channel;
boolean success = channelManager.sendMessage(channel.getChannelId(), CommandResponse.cancel(correlationId, payload));
if (success) {
log.info("sendResponse ['{}'] -> done", correlationId);
} else {
Expand Down Expand Up @@ -223,25 +223,10 @@ private AgentCommand convert(AgentCommandsRecord r) {
}
}

private static final class Match {
private record Match(Request request, AgentCommand command) {

private final Request request;
private final AgentCommand command;

private Match(Request request, AgentCommand command) {
this.request = request;
this.command = command;
}
}

private static final class Request {

private final WebSocketChannel channel;
private final CommandRequest request;

private Request(WebSocketChannel channel, CommandRequest request) {
this.channel = channel;
this.request = request;
}
private record Request(MessageChannel channel, CommandRequest request) {
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.walmartlabs.concord.server.websocket;
package com.walmartlabs.concord.server.agent.websocket;

/*-
* *****
Expand All @@ -20,11 +20,11 @@
* =====
*/

import com.walmartlabs.concord.server.message.MessageChannelManager;
import com.walmartlabs.concord.server.security.apikey.ApiKeyDao;
import org.eclipse.jetty.ee8.websocket.server.JettyWebSocketServlet;
import org.eclipse.jetty.ee8.websocket.server.JettyWebSocketServletFactory;


import javax.inject.Inject;
import javax.servlet.annotation.WebServlet;

Expand All @@ -33,11 +33,11 @@ public class ConcordWebSocketServlet extends JettyWebSocketServlet {

private static final long serialVersionUID = 1L;

private final WebSocketChannelManager channelManager;
private final MessageChannelManager channelManager;
private final ApiKeyDao apiKeyDao;

@Inject
public ConcordWebSocketServlet(WebSocketChannelManager channelManager, ApiKeyDao apiKeyDao) {
public ConcordWebSocketServlet(MessageChannelManager channelManager, ApiKeyDao apiKeyDao) {
this.channelManager = channelManager;
this.apiKeyDao = apiKeyDao;
}
Expand Down
Loading