diff --git a/AUTHORS b/AUTHORS index d86fd74..23a03fe 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1 +1,5 @@ Patrick Stuedi +Adrian Schuepbach +Jonas Pfefferle +Animesh Trivedi + diff --git a/pom.xml b/pom.xml index cca234f..6b8f28a 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ com.ibm.darpc darpc jar - 1.3 + 1.4 darpc DaRPC (Data Center RPC) is a Java library for low latency Remote Procedure Call (RPC) http://github.com/zrlio/darpc @@ -30,7 +30,11 @@ Animesh Trivedi atr@zurich.ibm.com - + + Adrian Schuepbach + dri@zurich.ibm.com + + scm:git:git://github.com/zrlio/darpc.git diff --git a/src/main/java/com/ibm/darpc/DaRPCClientGroup.java b/src/main/java/com/ibm/darpc/DaRPCClientGroup.java index aa4652b..e06e54a 100644 --- a/src/main/java/com/ibm/darpc/DaRPCClientGroup.java +++ b/src/main/java/com/ibm/darpc/DaRPCClientGroup.java @@ -9,17 +9,17 @@ import com.ibm.disni.rdma.verbs.RdmaCmId; public class DaRPCClientGroup extends DaRPCEndpointGroup, R, T> { - public static DaRPCClientGroup createClientGroup(DaRPCProtocol protocol, int timeout, int maxinline, int recvQueue, int sendQueue) throws Exception { - DaRPCClientGroup group = new DaRPCClientGroup(protocol, timeout, maxinline, recvQueue, sendQueue); + public static DaRPCClientGroup createClientGroup(DaRPCProtocol protocol, DaRPCMemPool, R, T> memPool, int timeout, int maxinline, int recvQueue, int sendQueue) throws Exception { + DaRPCClientGroup group = new DaRPCClientGroup(protocol, memPool, timeout, maxinline, recvQueue, sendQueue); group.init(new RpcClientFactory(group)); return group; - } - - private DaRPCClientGroup(DaRPCProtocol protocol, int timeout, int maxinline, int recvQueue, int sendQueue) + } + + private DaRPCClientGroup(DaRPCProtocol protocol, DaRPCMemPool, R, T> memPool, int timeout, int maxinline, int recvQueue, int sendQueue) throws Exception { - super(protocol, timeout, maxinline, recvQueue, sendQueue); + super(protocol, memPool, timeout, maxinline, recvQueue, sendQueue); } - + @Override public void allocateResources(DaRPCClientEndpoint endpoint) throws Exception { @@ -38,18 +38,18 @@ public IbvQP createQpProvider(DaRPCClientEndpoint endpoint) throws IOExcep IbvQP qp = this.createQP(endpoint.getIdPriv(), endpoint.getPd(), cq); return qp; } - + public static class RpcClientFactory implements RdmaEndpointFactory> { private DaRPCClientGroup group; - + public RpcClientFactory(DaRPCClientGroup group){ this.group = group; } - + @Override public DaRPCClientEndpoint createEndpoint(RdmaCmId id, boolean serverSide) throws IOException { return new DaRPCClientEndpoint(group, id, serverSide); } - } + } } diff --git a/src/main/java/com/ibm/darpc/DaRPCEndpoint.java b/src/main/java/com/ibm/darpc/DaRPCEndpoint.java index eabb4e7..01ed96c 100644 --- a/src/main/java/com/ibm/darpc/DaRPCEndpoint.java +++ b/src/main/java/com/ibm/darpc/DaRPCEndpoint.java @@ -27,7 +27,6 @@ import java.util.LinkedList; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import org.slf4j.Logger; @@ -39,16 +38,12 @@ public abstract class DaRPCEndpoint extends RdmaEndpoint { private static final Logger logger = LoggerFactory.getLogger("com.ibm.darpc"); - private static final int headerSize = 4; - + static final int HEADERSIZE = 4; //size of the ticket + public abstract void dispatchReceive(ByteBuffer buffer, int ticket, int recvIndex) throws IOException; public abstract void dispatchSend(int ticket) throws IOException; - + private DaRPCEndpointGroup, R, T> rpcGroup; - private ByteBuffer dataBuffer; - private IbvMr dataMr; - private ByteBuffer receiveBuffer; - private ByteBuffer sendBuffer; private ByteBuffer[] recvBufs; private ByteBuffer[] sendBufs; private SVCPostRecv[] recvCall; @@ -56,96 +51,90 @@ public abstract class DaRPCEndpoint pendingPostSend; private ArrayBlockingQueue freePostSend; private AtomicLong ticketCount; - private int pipelineLength; + final private int sendPipelineLength; + final private int recvPipelineLength; private int payloadSize; private int rawBufferSize; private int maxinline; private AtomicLong messagesSent; private AtomicLong messagesReceived; - - + + public DaRPCEndpoint(DaRPCEndpointGroup, R, T> endpointGroup, RdmaCmId idPriv, boolean serverSide) throws IOException { super(endpointGroup, idPriv, serverSide); this.rpcGroup = endpointGroup; this.maxinline = rpcGroup.getMaxInline(); this.payloadSize = rpcGroup.getBufferSize(); - this.rawBufferSize = headerSize + this.payloadSize; - this.pipelineLength = rpcGroup.recvQueueSize(); - this.freePostSend = new ArrayBlockingQueue(pipelineLength); + this.rawBufferSize = HEADERSIZE + this.payloadSize; + this.sendPipelineLength = rpcGroup.sendQueueSize(); + this.recvPipelineLength = rpcGroup.recvQueueSize(); + this.freePostSend = new ArrayBlockingQueue(sendPipelineLength); this.pendingPostSend = new ConcurrentHashMap(); - this.recvBufs = new ByteBuffer[pipelineLength]; - this.sendBufs = new ByteBuffer[pipelineLength]; - this.recvCall = new SVCPostRecv[pipelineLength]; - this.sendCall = new SVCPostSend[pipelineLength]; + this.recvBufs = new ByteBuffer[recvPipelineLength]; + this.sendBufs = new ByteBuffer[sendPipelineLength]; + this.recvCall = new SVCPostRecv[recvPipelineLength]; + this.sendCall = new SVCPostSend[sendPipelineLength]; this.ticketCount = new AtomicLong(0); this.messagesSent = new AtomicLong(0); this.messagesReceived = new AtomicLong(0); - logger.info("RPC client endpoint, with payload buffer size = " + payloadSize + ", pipeline " + pipelineLength); + logger.info("RPC client endpoint, with payload buffer size = " + payloadSize + ", send pipeline " + + sendPipelineLength + ", receive pipeline " + recvPipelineLength); } - + public void init() throws IOException { - int sendBufferOffset = pipelineLength * rawBufferSize; - - /* Main data buffer for sends and receives. Will be split into two regions, - * one for sends and one for receives. - */ - dataBuffer = ByteBuffer.allocateDirect(pipelineLength * rawBufferSize * 2); - /* Only do one memory registration with the IB card. */ - dataMr = registerMemory(dataBuffer).execute().free().getMr(); - - /* Receive memory region is the first half of the main buffer. */ - dataBuffer.limit(dataBuffer.position() + sendBufferOffset); - receiveBuffer = dataBuffer.slice(); - - /* Send memory region is the second half of the main buffer. */ - dataBuffer.position(sendBufferOffset); - dataBuffer.limit(dataBuffer.position() + sendBufferOffset); - sendBuffer = dataBuffer.slice(); - - for(int i = 0; i < pipelineLength; i++) { - /* Create single receive buffers within the receive region in form of slices. */ - receiveBuffer.position(i * rawBufferSize); - receiveBuffer.limit(receiveBuffer.position() + rawBufferSize); - recvBufs[i] = receiveBuffer.slice(); - - /* Create single send buffers within the send region in form of slices. */ - sendBuffer.position(i * rawBufferSize); - sendBuffer.limit(sendBuffer.position() + rawBufferSize); - sendBufs[i] = sendBuffer.slice(); + for(int i = 0; i < sendPipelineLength; i++) { + try { + sendBufs[i] = rpcGroup.getWRBuffer(this); + } catch (Exception e) { + throw new IOException(e); + } - this.recvCall[i] = setupRecvTask(i); this.sendCall[i] = setupSendTask(i); freePostSend.add(sendCall[i]); + } + for(int i = 0; i < recvPipelineLength; i++) { + try { + recvBufs[i] = rpcGroup.getWRBuffer(this); + } catch (Exception e) { + throw new IOException(e); + } + + this.recvCall[i] = setupRecvTask(i); recvCall[i].execute(); } } @Override public synchronized void close() throws IOException, InterruptedException { + for(int i = 0; i < sendPipelineLength; i++) { + rpcGroup.freeBuffer(this, sendBufs[i]); + } + for(int i = 0; i < recvPipelineLength; i++) { + rpcGroup.freeBuffer(this, recvBufs[i]); + } super.close(); - deregisterMemory(dataMr); - } - + } + public long getMessagesSent() { return messagesSent.get(); } - + public long getMessagesReceived() { return messagesReceived.get(); } - + protected boolean sendMessage(DaRPCMessage message, int ticket) throws IOException { SVCPostSend postSend = freePostSend.poll(); if (postSend != null){ int index = (int) postSend.getWrMod(0).getWr_id(); sendBufs[index].putInt(0, ticket); - sendBufs[index].position(4); - int written = 4 + message.write(sendBufs[index]); + sendBufs[index].position(HEADERSIZE); + int written = HEADERSIZE + message.write(sendBufs[index]); postSend.getWrMod(0).getSgeMod(0).setLength(written); postSend.getWrMod(0).setSend_flags(IbvSendWR.IBV_SEND_SIGNALED); if (written <= maxinline) { postSend.getWrMod(0).setSend_flags(postSend.getWrMod(0).getSend_flags() | IbvSendWR.IBV_SEND_INLINE); - } + } pendingPostSend.put(ticket, postSend); postSend.execute(); messagesSent.incrementAndGet(); @@ -154,33 +143,33 @@ protected boolean sendMessage(DaRPCMessage message, int ticket) throws IOExcepti return false; } } - + protected void postRecv(int index) throws IOException { recvCall[index].execute(); - } - + } + public void freeSend(int ticket) throws IOException { SVCPostSend sendOperation = pendingPostSend.remove(ticket); if (sendOperation == null) { throw new IOException("no pending ticket " + ticket + ", current ticket count " + ticketCount.get()); } this.freePostSend.add(sendOperation); - } - + } + public void dispatchCqEvent(IbvWC wc) throws IOException { if (wc.getStatus() == 5){ //flush return; } else if (wc.getStatus() != 0){ throw new IOException("Faulty operation! wc.status " + wc.getStatus()); - } - + } + if (wc.getOpcode() == 128){ //receiving a message int index = (int) wc.getWr_id(); ByteBuffer recvBuffer = recvBufs[index]; int ticket = recvBuffer.getInt(0); - recvBuffer.position(4); + recvBuffer.position(HEADERSIZE); dispatchReceive(recvBuffer, ticket, index); } else if (wc.getOpcode() == 0) { //send completion @@ -190,9 +179,9 @@ public void dispatchCqEvent(IbvWC wc) throws IOException { dispatchSend(ticket); } else { throw new IOException("Unkown opcode " + wc.getOpcode()); - } - } - + } + } + private SVCPostSend setupSendTask(int wrid) throws IOException { ArrayList sendWRs = new ArrayList(1); LinkedList sgeList = new LinkedList(); @@ -200,7 +189,7 @@ private SVCPostSend setupSendTask(int wrid) throws IOException { IbvSge sge = new IbvSge(); sge.setAddr(MemoryUtils.getAddress(sendBufs[wrid])); sge.setLength(rawBufferSize); - sge.setLkey(dataMr.getLkey()); + sge.setLkey(rpcGroup.getLKey(sendBufs[wrid])); sgeList.add(sge); IbvSendWR sendWR = new IbvSendWR(); @@ -220,7 +209,7 @@ private SVCPostRecv setupRecvTask(int wrid) throws IOException { IbvSge sge = new IbvSge(); sge.setAddr(MemoryUtils.getAddress(recvBufs[wrid])); sge.setLength(rawBufferSize); - sge.setLkey(dataMr.getLkey()); + sge.setLkey(rpcGroup.getLKey(recvBufs[wrid])); sgeList.add(sge); IbvRecvWR recvWR = new IbvRecvWR(); diff --git a/src/main/java/com/ibm/darpc/DaRPCEndpointGroup.java b/src/main/java/com/ibm/darpc/DaRPCEndpointGroup.java index d95f746..acff4b7 100644 --- a/src/main/java/com/ibm/darpc/DaRPCEndpointGroup.java +++ b/src/main/java/com/ibm/darpc/DaRPCEndpointGroup.java @@ -24,7 +24,7 @@ import java.io.IOException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - +import java.nio.ByteBuffer; import com.ibm.disni.rdma.verbs.*; import com.ibm.disni.rdma.*; @@ -32,26 +32,34 @@ public abstract class DaRPCEndpointGroup, R extends DaRPCMessage, T extends DaRPCMessage> extends RdmaEndpointGroup { private static final Logger logger = LoggerFactory.getLogger("com.ibm.darpc"); private static int DARPC_VERSION = 50; - + private int recvQueueSize; private int sendQueueSize; private int timeout; private int bufferSize; private int maxInline; - + private DaRPCMemPool memPool; + public static int getVersion(){ return DARPC_VERSION; - } - - protected DaRPCEndpointGroup(DaRPCProtocol protocol, int timeout, int maxinline, int recvQueue, int sendQueue) throws Exception { + } + + protected DaRPCEndpointGroup(DaRPCProtocol protocol, DaRPCMemPool memPool, int timeout, int maxinline, int recvQueue, int sendQueue) throws Exception { super(timeout); this.recvQueueSize = recvQueue; - this.sendQueueSize = Math.max(recvQueue, sendQueue); + this.sendQueueSize = sendQueue; this.timeout = timeout; this.bufferSize = Math.max(protocol.createRequest().size(), protocol.createResponse().size()); this.maxInline = maxinline; - } - + this.memPool = memPool; + } + + @Override + public void init(RdmaEndpointFactory factory) { + super.init(factory); + memPool.init(this); + } + protected synchronized IbvQP createQP(RdmaCmId id, IbvPd pd, IbvCQ cq) throws IOException{ IbvQPInitAttr attr = new IbvQPInitAttr(); attr.cap().setMax_recv_wr(recvQueueSize); @@ -61,33 +69,47 @@ protected synchronized IbvQP createQP(RdmaCmId id, IbvPd pd, IbvCQ cq) throws IO attr.cap().setMax_inline_data(maxInline); attr.setQp_type(IbvQP.IBV_QPT_RC); attr.setRecv_cq(cq); - attr.setSend_cq(cq); + attr.setSend_cq(cq); IbvQP qp = id.createQP(pd, attr); return qp; } - + public int getTimeout() { return timeout; } - + public int getBufferSize() { return bufferSize; - } + } + public void close() throws IOException, InterruptedException { super.close(); + memPool.close(); logger.info("rpc group down"); - } - + } + public int recvQueueSize() { return recvQueueSize; } - + public int sendQueueSize() { return sendQueueSize; - } - + } + public int getMaxInline() { return maxInline; } + + ByteBuffer getWRBuffer(RdmaEndpoint endpoint) throws Exception { + return memPool.getBuffer(endpoint); + } + + void freeBuffer(RdmaEndpoint endpoint, ByteBuffer b) throws IOException { + memPool.freeBuffer(endpoint, b); + } + + int getLKey(ByteBuffer b) { + return memPool.getLKey(b); + } } diff --git a/src/main/java/com/ibm/darpc/DaRPCMemPool.java b/src/main/java/com/ibm/darpc/DaRPCMemPool.java new file mode 100644 index 0000000..9d30255 --- /dev/null +++ b/src/main/java/com/ibm/darpc/DaRPCMemPool.java @@ -0,0 +1,15 @@ +package com.ibm.darpc; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.NoSuchElementException; + +import com.ibm.disni.rdma.RdmaEndpoint; + +public interface DaRPCMemPool, R extends DaRPCMessage, T extends DaRPCMessage> { + void init(DaRPCEndpointGroup endpointGroup); + void close() throws IOException; + ByteBuffer getBuffer(RdmaEndpoint endpoint) throws IOException, NoSuchElementException; + void freeBuffer(RdmaEndpoint endpoint, ByteBuffer buffer) throws IOException; + public int getLKey(ByteBuffer b) throws IllegalArgumentException; +} diff --git a/src/main/java/com/ibm/darpc/DaRPCMemPoolImpl.java b/src/main/java/com/ibm/darpc/DaRPCMemPoolImpl.java new file mode 100644 index 0000000..20aa5c4 --- /dev/null +++ b/src/main/java/com/ibm/darpc/DaRPCMemPoolImpl.java @@ -0,0 +1,202 @@ +package com.ibm.darpc; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.FileChannel.MapMode; +import java.util.LinkedList; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.ibm.disni.rdma.RdmaEndpoint; +import com.ibm.disni.rdma.verbs.IbvMr; +import com.ibm.disni.rdma.verbs.IbvPd; +import com.ibm.disni.util.MemoryUtils; + +public class DaRPCMemPoolImpl, R extends DaRPCMessage, T extends DaRPCMessage> implements DaRPCMemPool { + private static final Logger logger = LoggerFactory.getLogger("com.ibm.darpc"); + private static final int defaultAllocationSize = 16 * 1024 * 1024; // 16MB + private static final String hugePageFileName = "/darpcmempoolimpl.mem"; + private final int allocationSize; + private final int alignmentSize; + private final int allocationLimit; + private int currentAllocationSize; + private final String hugePagePath; + private ConcurrentHashMap memoryRegions; + private int access; + private DaRPCEndpointGroup endpointGroup; + private ConcurrentHashMap> pdMap; + private List mrs; + private List hugePageFiles; + + public DaRPCMemPoolImpl(String hugePagePath, int allocationSize, int alignmentSize, int allocationLimit) throws IllegalArgumentException { + if (hugePagePath == null) { + logger.error("Hugepage path must be set"); + throw new IllegalArgumentException("Hugepage path must be set"); + } + this.hugePagePath = hugePagePath; + this.allocationSize = allocationSize; + this.alignmentSize = alignmentSize; + this.allocationLimit = allocationLimit; + this.currentAllocationSize = 0; + this.access = IbvMr.IBV_ACCESS_LOCAL_WRITE | IbvMr.IBV_ACCESS_REMOTE_WRITE | IbvMr.IBV_ACCESS_REMOTE_READ; + this.pdMap = new ConcurrentHashMap>(); + this.mrs = new LinkedList(); + memoryRegions = new ConcurrentHashMap(); + hugePageFiles = new LinkedList(); + } + + public DaRPCMemPoolImpl(String hugePagePath) throws IllegalArgumentException { + this(hugePagePath, defaultAllocationSize, 0, 16 * defaultAllocationSize); + } + + @Override + public void init(DaRPCEndpointGroup endpointGroup) { + this.endpointGroup = endpointGroup; + } + + @Override + public void close() throws IOException { + synchronized(this) { + for (IbvMr m : mrs) { + try { + m.deregMr().execute().free(); + } catch (IOException e) { + logger.error("Could not unregister memory region."); + e.printStackTrace(); + } + } + mrs = null; + for (String fileName : hugePageFiles) { + File f = new File(fileName); + f.delete(); + } + hugePageFiles = null; + } + } + + @Override + public ByteBuffer getBuffer(RdmaEndpoint endpoint) throws IOException, NoSuchElementException { + LinkedBlockingQueue freeList = pdMap.get(endpoint.getPd()); + + if (freeList == null) { + synchronized(this) { + freeList = pdMap.get(endpoint.getPd()); + if (freeList == null) { + freeList = new LinkedBlockingQueue(); + pdMap.put(endpoint.getPd(), freeList); + } + } + } + + ByteBuffer r = freeList.poll(); + + if (r == null) { + synchronized(this) { + r = freeList.poll(); + if (r == null) { + allocateHugePageBuffer(freeList, endpoint.getPd()); + } + r = freeList.poll(); + if (r == null) { + logger.error("Failed to allocate more buffers."); + throw new NoSuchElementException("Failed to allocate more buffers."); + } + } + } + r.clear(); + return r; + } + + @Override + public void freeBuffer(RdmaEndpoint endpoint, ByteBuffer buffer) { + LinkedBlockingQueue freeList = pdMap.get(endpoint.getPd()); + freeList.add(buffer); + } + + @Override + public int getLKey(ByteBuffer buffer) throws IllegalArgumentException { + return memoryRegions.get(MemoryUtils.getAddress(buffer)).getLkey(); + } + + // allocate a buffer from hugepages + private void allocateHugePageBuffer(LinkedBlockingQueue freeList, IbvPd pd) throws IOException { + int totalAllocationSize = allocationSize + alignmentSize; + if ((currentAllocationSize + totalAllocationSize) > allocationLimit) { + logger.error("Out of memory. Cannot allocate more buffers from hugepages. " + + "allocationSize = " + allocationSize + + ", alignmentSize = " + alignmentSize + + ", currentAllocationSize = " + currentAllocationSize + + ", allocationLimit = " + allocationLimit); + throw new IOException("Out of memory. Cannot allocate more buffers from hugepages." + + "allocationSize = " + allocationSize + + ", alignmentSize = " + alignmentSize + + ", currentAllocationSize = " + currentAllocationSize + + ", allocationLimit = " + allocationLimit); + + } + String newFile = this.hugePagePath + hugePageFileName + System.currentTimeMillis(); + RandomAccessFile randomFile = null; + try { + randomFile = new RandomAccessFile(newFile, "rw"); + } catch (FileNotFoundException e) { + logger.error("Path " + newFile + " to huge page path/file cannot be accessed."); + throw e; + } + hugePageFiles.add(newFile); + try { + randomFile.setLength(totalAllocationSize); + } catch (IOException e) { + logger.error("Could not set allocation length of mapped random access file on huge page directory."); + logger.error("allocaiton size = " + allocationSize + " , alignment size = " + alignmentSize); + logger.error("allocation size and alignment must be a multiple of the hugepage size."); + randomFile.close(); + throw e; + } + FileChannel channel = randomFile.getChannel(); + MappedByteBuffer mappedBuffer = null; + try { + mappedBuffer = channel.map(MapMode.READ_WRITE, 0, + totalAllocationSize); + } catch (IOException e) { + logger.error("Could not map the huge page file on path " + newFile); + randomFile.close(); + throw e; + } + randomFile.close(); + + currentAllocationSize += totalAllocationSize; + + long rawBufferAddress = MemoryUtils.getAddress(mappedBuffer); + if (alignmentSize > 0) { + long alignmentOffset = rawBufferAddress % alignmentSize; + if (alignmentOffset != 0) { + mappedBuffer.position(alignmentSize - (int)alignmentOffset); + } + } + + ByteBuffer alignedBuffer = mappedBuffer.slice(); + + IbvMr mr = pd.regMr(alignedBuffer, access).execute().free().getMr(); + mrs.add(mr); + int sliceSize = endpointGroup.getBufferSize() + DaRPCEndpoint.HEADERSIZE; + int i = 0; + while ((i * sliceSize + sliceSize) < alignedBuffer.capacity()) { + alignedBuffer.position(i * sliceSize); + alignedBuffer.limit(i * sliceSize + sliceSize); + ByteBuffer buffer = alignedBuffer.slice(); + freeList.add(buffer); + memoryRegions.put(MemoryUtils.getAddress(buffer), mr); + i++; + } + } +} diff --git a/src/main/java/com/ibm/darpc/DaRPCServerEndpoint.java b/src/main/java/com/ibm/darpc/DaRPCServerEndpoint.java index 06fd8e6..14ea773 100644 --- a/src/main/java/com/ibm/darpc/DaRPCServerEndpoint.java +++ b/src/main/java/com/ibm/darpc/DaRPCServerEndpoint.java @@ -12,37 +12,42 @@ public class DaRPCServerEndpoint extends DaRPCEndpoint { private static final Logger logger = LoggerFactory.getLogger("com.ibm.darpc"); - + private DaRPCServerGroup group; + final private int eventPoolSize; private ArrayBlockingQueue> eventPool; private ArrayBlockingQueue> lazyEvents; private int getClusterId; - + public DaRPCServerEndpoint(DaRPCServerGroup group, RdmaCmId idPriv, boolean serverSide) throws IOException { super(group, idPriv, serverSide); this.group = group; this.getClusterId = group.newClusterId(); + this.eventPoolSize = Math.max(group.recvQueueSize(), group.sendQueueSize()); this.eventPool = new ArrayBlockingQueue>(group.recvQueueSize()); this.lazyEvents = new ArrayBlockingQueue>(group.recvQueueSize()); + } + public void init() throws IOException { super.init(); - for(int i = 0; i < group.recvQueueSize(); i++){ + for(int i = 0; i < this.eventPoolSize; i++){ DaRPCServerEvent event = new DaRPCServerEvent(this, group.createRequest(), group.createResponse()); this.eventPool.add(event); - + } } - + void sendResponse(DaRPCServerEvent event) throws IOException { if (sendMessage(event.getSendMessage(), event.getTicket())){ eventPool.add(event); } else { lazyEvents.add(event); } - } - + } + + public synchronized void dispatchCmEvent(RdmaCmEvent cmEvent) throws IOException { super.dispatchCmEvent(cmEvent); try { @@ -53,16 +58,17 @@ public synchronized void dispatchCmEvent(RdmaCmEvent cmEvent) throws IOException } else if (eventType == RdmaCmEvent.EventType.RDMA_CM_EVENT_DISCONNECTED.ordinal()) { logger.info("RPC disconnection, eid " + this.getEndpointId()); group.close(this); - } + } } catch (Exception e) { e.printStackTrace(); } - } + } public int clusterId() { return getClusterId; } - + + public void dispatchReceive(ByteBuffer recvBuffer, int ticket, int recvIndex) throws IOException { DaRPCServerEvent event = eventPool.poll(); if (event == null){ @@ -72,11 +78,12 @@ public void dispatchReceive(ByteBuffer recvBuffer, int ticket, int recvIndex) th event.getReceiveMessage().update(recvBuffer); event.stamp(ticket); postRecv(recvIndex); - group.processServerEvent(event); + group.processServerEvent(event); } - + + public void dispatchSend(int ticket) throws IOException { - freeSend(ticket); + freeSend(ticket); DaRPCServerEvent event = lazyEvents.poll(); if (event != null){ sendResponse(event); diff --git a/src/main/java/com/ibm/darpc/DaRPCServerGroup.java b/src/main/java/com/ibm/darpc/DaRPCServerGroup.java index 8b77a52..47dca76 100644 --- a/src/main/java/com/ibm/darpc/DaRPCServerGroup.java +++ b/src/main/java/com/ibm/darpc/DaRPCServerGroup.java @@ -15,7 +15,7 @@ public class DaRPCServerGroup extends DaRPCEndpointGroup, R, T> { private static final Logger logger = LoggerFactory.getLogger("com.ibm.darpc"); - + private ConcurrentHashMap> deviceInstance; private DaRPCResourceManager resourceManager; private long[] computeAffinities; @@ -26,20 +26,23 @@ public class DaRPCServerGroup ex private boolean polling; private int pollSize; private int clusterSize; - - public static DaRPCServerGroup createServerGroup(DaRPCService rpcService, long[] clusterAffinities, int timeout, int maxinline, boolean polling, int recvQueue, int sendQueue, int pollSize, int clusterSize) throws Exception { - DaRPCServerGroup group = new DaRPCServerGroup(rpcService, clusterAffinities, timeout, maxinline, polling, recvQueue, sendQueue, pollSize, clusterSize); + + public static DaRPCServerGroup createServerGroup(DaRPCService rpcService, DaRPCMemPool, R, T> memPool, long[] clusterAffinities, int timeout, int maxinline, boolean polling, + int recvQueue, int sendQueue, int pollSize, int clusterSize) throws Exception { + DaRPCServerGroup group = new DaRPCServerGroup(rpcService, memPool, clusterAffinities, timeout, maxinline, polling, + recvQueue, sendQueue, pollSize, clusterSize); group.init(new RpcServerFactory(group)); return group; } - private DaRPCServerGroup(DaRPCService rpcService, long[] clusterAffinities, int timeout, int maxinline, boolean polling, int recvQueue, int sendQueue, int pollSize, int clusterSize) throws Exception { - super(rpcService, timeout, maxinline, recvQueue, sendQueue); - + private DaRPCServerGroup(DaRPCService rpcService, DaRPCMemPool, R, T> memPool,long[] clusterAffinities, int timeout, int maxinline, + boolean polling, int recvQueue, int sendQueue, int pollSize, int clusterSize) throws Exception { + super(rpcService, memPool, timeout, maxinline, recvQueue, sendQueue); + this.rpcService = rpcService; deviceInstance = new ConcurrentHashMap>(); this.computeAffinities = clusterAffinities; - this.resourceAffinities = clusterAffinities; + this.resourceAffinities = clusterAffinities; this.nbrOfClusters = computeAffinities.length; this.currentCluster = 0; resourceManager = new DaRPCResourceManager(resourceAffinities, timeout); @@ -47,7 +50,7 @@ private DaRPCServerGroup(DaRPCService rpcService, long[] clusterAffinities this.pollSize = pollSize; this.clusterSize = clusterSize; } - + public RdmaCqProvider createCqProvider(DaRPCServerEndpoint endpoint) throws IOException { logger.info("setting up cq processor (multicore)"); IbvContext context = endpoint.getIdPriv().getVerbs(); @@ -64,27 +67,27 @@ public RdmaCqProvider createCqProvider(DaRPCServerEndpoint endpoint) throws rpcInstance = deviceInstance.get(context.getCmd_fd()); DaRPCCluster cqProcessor = rpcInstance.getProcessor(endpoint.clusterId()); return cqProcessor; - } - + } + public IbvQP createQpProvider(DaRPCServerEndpoint endpoint) throws IOException{ logger.info("setting up QP"); DaRPCCluster cqProcessor = this.lookupCqProcessor(endpoint); IbvCQ cq = cqProcessor.getCQ(); - IbvQP qp = this.createQP(endpoint.getIdPriv(), endpoint.getPd(), cq); + IbvQP qp = this.createQP(endpoint.getIdPriv(), endpoint.getPd(), cq); cqProcessor.registerQP(qp.getQp_num(), endpoint); return qp; - } - + } + public void allocateResources(DaRPCServerEndpoint endpoint) throws Exception { resourceManager.allocateResources(endpoint); } - + synchronized int newClusterId() { int newClusterId = currentCluster; currentCluster = (currentCluster + 1) % nbrOfClusters; return newClusterId; } - + protected synchronized DaRPCCluster lookupCqProcessor(DaRPCServerEndpoint endpoint) throws IOException{ IbvContext context = endpoint.getIdPriv().getVerbs(); if (context == null) { @@ -99,17 +102,17 @@ protected synchronized DaRPCCluster lookupCqProcessor(DaRPCServerEndpoint cqProcessor = rpcInstance.getProcessor(endpoint.clusterId()); return cqProcessor; } - } - + } + public void close() throws IOException, InterruptedException { super.close(); for (DaRPCInstance rpcInstance : deviceInstance.values()){ rpcInstance.close(); - } + } resourceManager.close(); logger.info("rpc group down"); - } - + } + public R createRequest() { return rpcService.createRequest(); } @@ -117,33 +120,33 @@ public R createRequest() { public T createResponse() { return rpcService.createResponse(); } - + public void processServerEvent(DaRPCServerEvent event) throws IOException { rpcService.processServerEvent(event); } - + public void open(DaRPCServerEndpoint endpoint){ rpcService.open(endpoint); - } - + } + public void close(DaRPCServerEndpoint endpoint){ rpcService.close(endpoint); } - + public DaRPCService getRpcService() { return rpcService; } - + public static class RpcServerFactory implements RdmaEndpointFactory> { private DaRPCServerGroup group; - + public RpcServerFactory(DaRPCServerGroup group){ this.group = group; } - + @Override public DaRPCServerEndpoint createEndpoint(RdmaCmId id, boolean serverSide) throws IOException { return new DaRPCServerEndpoint(group, id, serverSide); } - } + } } diff --git a/src/test/java/com/ibm/darpc/examples/client/DaRPCClient.java b/src/test/java/com/ibm/darpc/examples/client/DaRPCClient.java index fb5a8c3..283aa2f 100644 --- a/src/test/java/com/ibm/darpc/examples/client/DaRPCClient.java +++ b/src/test/java/com/ibm/darpc/examples/client/DaRPCClient.java @@ -22,8 +22,6 @@ package com.ibm.darpc.examples.client; import java.io.FileOutputStream; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; @@ -42,6 +40,8 @@ import com.ibm.darpc.DaRPCClientGroup; import com.ibm.darpc.DaRPCEndpoint; import com.ibm.darpc.DaRPCFuture; +import com.ibm.darpc.DaRPCMemPool; +import com.ibm.darpc.DaRPCMemPoolImpl; import com.ibm.darpc.DaRPCStream; import com.ibm.darpc.examples.protocol.RdmaRpcProtocol; import com.ibm.darpc.examples.protocol.RdmaRpcRequest; @@ -51,8 +51,8 @@ public class DaRPCClient { public static enum BenchmarkType { UNDEFINED - }; - + }; + public static class ClientThread implements Runnable { public static final int FUTURE_POLL = 0; public static final int STREAM_POLL = 1; @@ -60,19 +60,19 @@ public static class ClientThread implements Runnable { public static final int STREAM_TAKE = 3; public static final int BATCH_STREAM_TAKE = 4; public static final int BATCH_STREAM_POLL = 5; - + private DaRPCClientEndpoint clientEp; private int loop; private int queryMode; private int clienttimeout; private ArrayBlockingQueue freeResponses; - + protected double throughput; protected double latency; protected double readOps; protected double writeOps; - protected double errorOps; - + protected double errorOps; + public ClientThread(DaRPCClientEndpoint clientEp, int loop, URI uri, int mode, int rpcpipeline, int clienttimeout){ this.clientEp = clientEp; this.loop = loop; @@ -82,9 +82,9 @@ public ClientThread(DaRPCClientEndpoint clientE for (int i = 0; i < rpcpipeline; i++){ RdmaRpcResponse response = new RdmaRpcResponse(); freeResponses.add(response); - } + } } - + @Override public void run() { try { @@ -97,15 +97,15 @@ public void run() { while(freeResponses.isEmpty()){ DaRPCFuture future = stream.poll(); if (future != null){ - freeResponses.add(future.getReceiveMessage()); + freeResponses.add(future.getReceiveMessage()); consumed++; } } - + request.setParam(issued); RdmaRpcResponse response = freeResponses.poll(); DaRPCFuture future = stream.request(request, response, streamMode); - + switch (queryMode) { case FUTURE_POLL: while (!future.isDone()) { @@ -121,12 +121,12 @@ public void run() { } consumed++; freeResponses.add(future.getReceiveMessage()); - break; + break; case FUTURE_TAKE: future.get(clienttimeout, TimeUnit.MILLISECONDS); consumed++; freeResponses.add(future.getReceiveMessage()); - break; + break; case STREAM_TAKE: future = stream.take(clienttimeout); consumed++; @@ -135,7 +135,7 @@ public void run() { case BATCH_STREAM_TAKE: break; case BATCH_STREAM_POLL: - break; + break; } } while (consumed < issued){ @@ -152,7 +152,7 @@ public void run() { public void close() throws Exception { clientEp.close(); } - + public double getThroughput() { return throughput; } @@ -171,15 +171,15 @@ public double getWriteOps() { public double getErrorOps() { return this.errorOps; - } - + } + public double getOps(){ return loop; - } + } } - + public void launch(String[] args) throws Exception { - String ipAddress = ""; + String ipAddress = ""; int size = 24; int loop = 100; int threadCount = 1; @@ -190,6 +190,7 @@ public void launch(String[] args) throws Exception { int maxinline = 0; int recvQueue = batchSize; int sendQueue = batchSize; + String hugePagePath = null; Option addressOption = Option.builder("a").required().desc("server address").hasArg().build(); Option loopOption = Option.builder("k").desc("loop count").hasArg().build(); @@ -202,6 +203,7 @@ public void launch(String[] args) throws Exception { Option sendQueueOption = Option.builder("s").desc("send queue").hasArg().build(); Option recvQueueOption = Option.builder("r").desc("receive queue").hasArg().build(); Option serializedSizeOption = Option.builder("l").desc("serialized size").hasArg().build(); + Option hugepagePathOption = Option.builder("h").required().desc("memory pool hugepage path").hasArg().build(); Options options = new Options(); options.addOption(addressOption); options.addOption(loopOption); @@ -214,12 +216,15 @@ public void launch(String[] args) throws Exception { options.addOption(sendQueueOption); options.addOption(recvQueueOption); options.addOption(serializedSizeOption); + options.addOption(hugepagePathOption); CommandLineParser parser = new DefaultParser(); - + try { CommandLine line = parser.parse(options, args); ipAddress = line.getOptionValue(addressOption.getOpt()); + hugePagePath = line.getOptionValue(hugepagePathOption.getOpt()); + if (line.hasOption(loopOption.getOpt())) { loop = Integer.parseInt(line.getOptionValue(loopOption.getOpt())); } @@ -273,15 +278,16 @@ public void launch(String[] args) throws Exception { if ((threadCount % connections) != 0){ throw new Exception("thread count needs to be a multiple of connections"); } - + int threadsperconnection = threadCount / connections; DaRPCEndpoint[] rpcConnections = new DaRPCEndpoint[connections]; Thread[] workers = new Thread[threadCount]; ClientThread[] benchmarkTask = new ClientThread[threadCount]; - + RdmaRpcProtocol rpcProtocol = new RdmaRpcProtocol(); + DaRPCMemPool, RdmaRpcRequest, RdmaRpcResponse> memPool = new DaRPCMemPoolImpl, RdmaRpcRequest, RdmaRpcResponse>(hugePagePath); System.out.println("starting.. threads " + threadCount + ", connections " + connections + ", server " + ipAddress + ", recvQueue " + recvQueue + ", sendQueue" + sendQueue + ", batchSize " + batchSize + ", mode " + mode); - DaRPCClientGroup group = DaRPCClientGroup.createClientGroup(rpcProtocol, 100, maxinline, recvQueue, sendQueue); + DaRPCClientGroup group = DaRPCClientGroup.createClientGroup(rpcProtocol, memPool, 100, maxinline, recvQueue, sendQueue); URI uri = URI.create("rdma://" + ipAddress + ":" + 1919); int k = 0; for (int i = 0; i < rpcConnections.length; i++){ @@ -295,7 +301,7 @@ public void launch(String[] args) throws Exception { } StopWatch stopWatchThroughput = new StopWatch(); - stopWatchThroughput.start(); + stopWatchThroughput.start(); for(int i = 0; i < threadCount;i++){ workers[i] = new Thread(benchmarkTask[i]); workers[i].start(); @@ -319,7 +325,7 @@ public void launch(String[] args) throws Exception { double throughputperclient = throughput / _threadcount; double norm = 1.0; latency = norm / throughputperclient * 1000000.0; - } + } System.out.println("throughput " + throughput); String dataFilename = "datalog-client.dat"; @@ -336,18 +342,18 @@ public void launch(String[] args) throws Exception { + "\n"; ByteBuffer buffer = ByteBuffer.wrap(logdata.getBytes()); dataChannel.write(buffer); - dataChannel.close(); + dataChannel.close(); dataStream.close(); - + for (int i = 0; i < rpcConnections.length; i++){ rpcConnections[i].close(); } group.close(); } - - public static void main(String[] args) throws Exception { + + public static void main(String[] args) throws Exception { DaRPCClient rpcClient = new DaRPCClient(); - rpcClient.launch(args); + rpcClient.launch(args); System.exit(0); } } diff --git a/src/test/java/com/ibm/darpc/examples/server/DaRPCServer.java b/src/test/java/com/ibm/darpc/examples/server/DaRPCServer.java index 57efef4..58dc4a8 100644 --- a/src/test/java/com/ibm/darpc/examples/server/DaRPCServer.java +++ b/src/test/java/com/ibm/darpc/examples/server/DaRPCServer.java @@ -21,8 +21,6 @@ package com.ibm.darpc.examples.server; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import org.apache.commons.cli.CommandLine; @@ -33,6 +31,8 @@ import org.apache.commons.cli.Options; import org.apache.commons.cli.ParseException; +import com.ibm.darpc.DaRPCMemPool; +import com.ibm.darpc.DaRPCMemPoolImpl; import com.ibm.darpc.DaRPCServerEndpoint; import com.ibm.darpc.DaRPCServerGroup; import com.ibm.darpc.examples.protocol.RdmaRpcRequest; @@ -40,7 +40,7 @@ import com.ibm.disni.rdma.*; public class DaRPCServer { - private String ipAddress; + private String ipAddress; private int poolsize = 3; private int recvQueue = 16; private int sendQueue = 16; @@ -49,7 +49,8 @@ public class DaRPCServer { private boolean polling = false; private int maxinline = 0; private int connections = 16; - + String hugePagePath = null; + public void run() throws Exception{ long[] clusterAffinities = new long[poolsize]; for (int i = 0; i < poolsize; i++){ @@ -58,15 +59,16 @@ public void run() throws Exception{ } System.out.println("running...server " + ipAddress + ", poolsize " + poolsize + ", maxinline " + maxinline + ", polling " + polling + ", recvQueue " + recvQueue + ", sendQueue " + sendQueue + ", wqSize " + wqSize + ", rpcservice-timeout " + servicetimeout); RdmaRpcService rpcService = new RdmaRpcService(servicetimeout); - DaRPCServerGroup group = DaRPCServerGroup.createServerGroup(rpcService, clusterAffinities, -1, maxinline, polling, recvQueue, sendQueue, wqSize, 32); + DaRPCMemPool, RdmaRpcRequest, RdmaRpcResponse> memPool = new DaRPCMemPoolImpl, RdmaRpcRequest, RdmaRpcResponse>(hugePagePath); + DaRPCServerGroup group = DaRPCServerGroup.createServerGroup(rpcService, memPool, clusterAffinities, -1, maxinline, polling, recvQueue, sendQueue, wqSize, 32); RdmaServerEndpoint> serverEp = group.createServerEndpoint(); URI uri = URI.create("rdma://" + ipAddress + ":" + 1919); serverEp.bind(uri); while(true){ serverEp.accept(); - } + } } - + public void launch(String[] args) throws Exception { Option addressOption = Option.builder("a").required().desc("server address").hasArg().build(); Option poolsizeOption = Option.builder("p").desc("pool size").hasArg().build(); @@ -78,6 +80,7 @@ public void launch(String[] args) throws Exception { Option recvQueueOption = Option.builder("r").desc("receive queue").hasArg().build(); Option sendQueueOption = Option.builder("s").desc("send queue").hasArg().build(); Option serializedSizeOption = Option.builder("l").desc("serialized size").hasArg().build(); + Option hugepagePathOption = Option.builder("h").required().desc("memory pool hugepage path").hasArg().build(); Options options = new Options(); options.addOption(addressOption); options.addOption(poolsizeOption); @@ -89,12 +92,15 @@ public void launch(String[] args) throws Exception { options.addOption(recvQueueOption); options.addOption(sendQueueOption); options.addOption(serializedSizeOption); + options.addOption(hugepagePathOption); CommandLineParser parser = new DefaultParser(); try { CommandLine line = parser.parse(options, args); ipAddress = line.getOptionValue(addressOption.getOpt()); + hugePagePath = line.getOptionValue(hugepagePathOption.getOpt()); + if (line.hasOption(poolsizeOption.getOpt())) { poolsize = Integer.parseInt(line.getOptionValue(poolsizeOption.getOpt())); } @@ -130,9 +136,9 @@ public void launch(String[] args) throws Exception { } this.run(); } - - public static void main(String[] args) throws Exception { + + public static void main(String[] args) throws Exception { DaRPCServer rpcServer = new DaRPCServer(); - rpcServer.launch(args); - } + rpcServer.launch(args); + } }