Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions protos/Crane.proto
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ message StreamCforedRequest {
TASK_REQUEST = 1;
TASK_COMPLETION_REQUEST = 2;
CFORED_GRACEFUL_EXIT = 3;
TASK_META_REQUEST = 4;
}
CforedRequestType type = 1;

Expand All @@ -620,11 +621,18 @@ message StreamCforedRequest {
string cfored_name = 1;
}

message TaskMetaReq {
uint32 uid = 1;
uint32 task_id = 2;
int32 cattach_pid = 3;
}

oneof payload {
CforedReg payload_cfored_reg = 2;
TaskReq payload_task_req = 3;
TaskCompleteReq payload_task_complete_req = 4;
GracefulExitReq payload_graceful_exit_req = 5;
TaskMetaReq payload_task_meta_req = 6;
}
}

Expand All @@ -636,6 +644,7 @@ message StreamCtldReply {
TASK_COMPLETION_ACK_REPLY = 3;
CFORED_REGISTRATION_ACK = 4;
CFORED_GRACEFUL_EXIT_ACK = 5;
TASK_META_REPLY = 6;
}

message TaskIdReply {
Expand Down Expand Up @@ -670,6 +679,13 @@ message StreamCtldReply {
bool ok = 1;
}

message TaskMetaReply {
bool ok = 1;
string failure_reason = 2;
TaskToCtld task = 3;
int32 cattach_pid = 4;
}

CtldReplyType type = 1;

oneof payload {
Expand All @@ -679,6 +695,7 @@ message StreamCtldReply {
TaskCompletionAckReply payload_task_completion_ack = 5;
TaskIdReply payload_task_id_reply = 6;
CforedGracefulExitAck payload_graceful_exit_ack = 7;
TaskMetaReply payload_task_meta_reply = 8;
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/CraneCtld/CtldPublicDefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,10 @@ struct TaskInCtld {
crane::grpc::TaskToD GetTaskToD(const CranedId& craned_id) const;

crane::grpc::JobToD GetJobToD(const CranedId& craned_id) const;

const std::string& GetAllocatedCranedsRegex() const {
return allocated_craneds_regex;
}
};

struct Qos {
Expand Down
102 changes: 77 additions & 25 deletions src/CraneCtld/RpcService/CtldGrpcServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,15 @@ grpc::Status CtldForInternalServiceImpl::CforedStream(
kWaitRegReq = 0,
kWaitMsg,
kCleanData,
kWaitReConnect,
};

bool ok;

StreamCforedRequest cfored_request;

auto stream_writer = std::make_shared<CforedStreamWriter>(stream);
std::weak_ptr<CforedStreamWriter> writer_weak_ptr(stream_writer);
std::string cfored_name;
std::weak_ptr<StreamWriterProxy> proxy_weak_ptr;

CRANE_TRACE("CforedStream from {} created.", context->peer());

Expand All @@ -210,6 +210,21 @@ grpc::Status CtldForInternalServiceImpl::CforedStream(
}

cfored_name = cfored_request.payload_cfored_reg().cfored_name();

m_ctld_server_->m_stream_proxy_mtx_.Lock();
auto iter =
m_ctld_server_->m_cfored_stream_proxy_map_.find(cfored_name);
if (iter != m_ctld_server_->m_cfored_stream_proxy_map_.end()) {
iter->second->SetWriter(stream_writer);
proxy_weak_ptr = iter->second;
} else {
auto proxy = std::make_shared<StreamWriterProxy>();
proxy->SetWriter(stream_writer);
m_ctld_server_->m_cfored_stream_proxy_map_[cfored_name] = proxy;
proxy_weak_ptr = proxy;
}
m_ctld_server_->m_stream_proxy_mtx_.Unlock();

CRANE_INFO("Cfored {} registered.", cfored_name);

ok = stream_writer->WriteCforedRegistrationAck({});
Expand All @@ -220,11 +235,11 @@ grpc::Status CtldForInternalServiceImpl::CforedStream(
"Failed to send msg to cfored {}. Connection is broken. "
"Exiting...",
cfored_name);
state = StreamState::kCleanData;
state = StreamState::kWaitReConnect;
}

} else {
state = StreamState::kCleanData;
state = StreamState::kWaitReConnect;
}

break;
Expand All @@ -239,36 +254,43 @@ grpc::Status CtldForInternalServiceImpl::CforedStream(
task->SetFieldsByTaskToCtld(payload.task());

auto &meta = std::get<InteractiveMetaInTask>(task->meta);

meta.cb_task_res_allocated =
[writer_weak_ptr](task_id_t task_id,
std::string const &allocated_craned_regex,
std::list<std::string> const &craned_ids) {
if (auto writer = writer_weak_ptr.lock(); writer)
writer->WriteTaskResAllocReply(
task_id,
{std::make_pair(allocated_craned_regex, craned_ids)});
[proxy_weak_ptr](task_id_t task_id,
std::string const &allocated_craned_regex,
std::list<std::string> const &craned_ids) {
if (auto proxy = proxy_weak_ptr.lock(); proxy) {
proxy->WithWriter([&](CforedStreamWriter &writer) {
writer.WriteTaskResAllocReply(
task_id,
{std::make_pair(allocated_craned_regex, craned_ids)});
});
}
};

meta.cb_task_cancel = [writer_weak_ptr](task_id_t task_id) {
meta.cb_task_cancel = [proxy_weak_ptr](task_id_t task_id) {
CRANE_TRACE("Sending TaskCancelRequest in task_cancel", task_id);
if (auto writer = writer_weak_ptr.lock(); writer)
writer->WriteTaskCancelRequest(task_id);
if (auto proxy = proxy_weak_ptr.lock(); proxy) {
proxy->WithWriter([&](CforedStreamWriter &writer) {
writer.WriteTaskCancelRequest(task_id);
});
}
};

meta.cb_task_completed = [this, cfored_name, writer_weak_ptr](
meta.cb_task_completed = [this, cfored_name, proxy_weak_ptr](
task_id_t task_id,
bool send_completion_ack) {
CRANE_TRACE("The completion callback of task #{} has been called.",
task_id);
if (auto writer = writer_weak_ptr.lock(); writer) {
if (auto proxy = proxy_weak_ptr.lock(); proxy) {
if (send_completion_ack)
writer->WriteTaskCompletionAckReply(task_id);
} else {
CRANE_ERROR(
"Stream writer of ia task #{} has been destroyed. "
"TaskCompletionAckReply will not be sent.",
task_id);
proxy->WithWriter([&](CforedStreamWriter &writer) {
writer.WriteTaskCompletionAckReply(task_id);
});
else
CRANE_ERROR(
"Stream writer of ia task #{} has been destroyed. "
"TaskCompletionAckReply will not be sent.",
task_id);
}

m_ctld_server_->m_mtx_.Lock();
Expand Down Expand Up @@ -301,7 +323,7 @@ grpc::Status CtldForInternalServiceImpl::CforedStream(
"Failed to send msg to cfored {}. Connection is broken. "
"Exiting...",
cfored_name);
state = StreamState::kCleanData;
state = StreamState::kWaitReConnect;
} else {
if (result.has_value()) {
m_ctld_server_->m_mtx_.Lock();
Expand All @@ -312,6 +334,26 @@ grpc::Status CtldForInternalServiceImpl::CforedStream(
}
} break;

case StreamCforedRequest::TASK_META_REQUEST: {
auto const &payload = cfored_request.payload_task_meta_req();
CRANE_TRACE("Recv TaskMetaReq of Task #{}", payload.task_id());
std::string failure_reason;
bool ok = true;
crane::grpc::TaskToCtld task;
if (!g_task_scheduler->QueryTaskUseId(payload.task_id(), &task)) {
ok = false;
failure_reason = "Task not found";
} else {
if (payload.uid() != task.uid() &&
!g_account_manager->CheckUidIsAdmin(payload.uid())) {
ok = false;
failure_reason = "permission denied";
}
}
stream_writer->WriteTaskMetaReply(ok, failure_reason, task,
payload.cattach_pid());
} break;

case StreamCforedRequest::TASK_COMPLETION_REQUEST: {
auto const &payload = cfored_request.payload_task_complete_req();
CRANE_TRACE("Recv TaskCompletionReq of Task #{}", payload.task_id());
Expand Down Expand Up @@ -339,10 +381,16 @@ grpc::Status CtldForInternalServiceImpl::CforedStream(
return Status::CANCELLED;
}
} else {
state = StreamState::kCleanData;
state = StreamState::kWaitReConnect;
}
} break;

case StreamState::kWaitReConnect: {
CRANE_INFO("Cfored {} unexpectedly disconnected. Wait for restart.....",
cfored_name);
stream_writer->Invalidate();
return Status::OK;
}
case StreamState::kCleanData: {
CRANE_INFO("Cfored {} disconnected. Cleaning its data...", cfored_name);
stream_writer->Invalidate();
Expand All @@ -359,6 +407,10 @@ grpc::Status CtldForInternalServiceImpl::CforedStream(
g_task_scheduler->TerminateRunningTask(task_id);
}

m_ctld_server_->m_stream_proxy_mtx_.Lock();
m_ctld_server_->m_cfored_stream_proxy_map_.erase(cfored_name);
m_ctld_server_->m_stream_proxy_mtx_.Unlock();

return Status::OK;
}
}
Expand Down
43 changes: 43 additions & 0 deletions src/CraneCtld/RpcService/CtldGrpcServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,22 @@ class CforedStreamWriter {
return m_stream_->Write(reply);
}

bool WriteTaskMetaReply(bool ok, const std::string &failure_reason,
const crane::grpc::TaskToCtld &task, int32_t pid) {
LockGuard guard(&m_stream_mtx_);
if (!m_valid_) return false;

StreamCtldReply reply;
reply.set_type(StreamCtldReply::TASK_META_REPLY);
auto *task_meta_reply = reply.mutable_payload_task_meta_reply();
task_meta_reply->set_ok(ok);
task_meta_reply->set_failure_reason(failure_reason);
task_meta_reply->set_cattach_pid(pid);
task_meta_reply->mutable_task()->CopyFrom(task);

return m_stream_->Write(reply);
}

void Invalidate() {
LockGuard guard(&m_stream_mtx_);
m_valid_ = false;
Expand All @@ -165,6 +181,29 @@ class CforedStreamWriter {
ABSL_GUARDED_BY(m_stream_mtx_);
};

class StreamWriterProxy {
public:
void SetWriter(std::shared_ptr<CforedStreamWriter> writer) {
std::lock_guard<std::mutex> lock(mtx_);
writer_ = std::move(writer);
}

std::shared_ptr<CforedStreamWriter> GetWriter() {
std::lock_guard<std::mutex> lock(mtx_);
return writer_;
}

template <typename Func>
void WithWriter(Func &&func) {
std::shared_ptr<CforedStreamWriter> writer = GetWriter();
if (writer) func(*writer);
}

private:
std::mutex mtx_;
std::shared_ptr<CforedStreamWriter> writer_;
};

class CtldServer;

class CtldForInternalServiceImpl final
Expand Down Expand Up @@ -392,6 +431,10 @@ class CtldServer {
HashMap<std::string /* cfored_name */, HashSet<task_id_t>>
m_cfored_running_tasks_ ABSL_GUARDED_BY(m_mtx_);

Mutex m_stream_proxy_mtx_;
HashMap<std::string /* cfored_name */, std::shared_ptr<StreamWriterProxy>>
m_cfored_stream_proxy_map_ ABSL_GUARDED_BY(m_stream_proxy_mtx_);

std::unique_ptr<CtldForInternalServiceImpl> m_internal_service_impl_;
std::unique_ptr<Server> m_internal_server_;

Expand Down
11 changes: 11 additions & 0 deletions src/CraneCtld/TaskScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2308,6 +2308,17 @@ void TaskScheduler::QueryTasksInRam(
ranges::for_each(filtered_rng, append_fn);
}

bool TaskScheduler::QueryTaskUseId(task_id_t task_id,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QueryTaskRegex

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QueryTaskNodeRegex

crane::grpc::TaskToCtld* task) {
LockGuard running_guard(&m_running_task_map_mtx_);
auto iter = m_running_task_map_.find(task_id);
if (iter == m_running_task_map_.end()) return false;

*task = iter->second->TaskToCtld();
task->set_nodelist(iter->second->GetAllocatedCranedsRegex());
return true;
}

void TaskScheduler::QueryRnJobOnCtldForNodeConfig(
const CranedId& craned_id, crane::grpc::ConfigureCranedRequest* req) {
LockGuard running_job_guard(&m_running_task_map_mtx_);
Expand Down
2 changes: 2 additions & 0 deletions src/CraneCtld/TaskScheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,8 @@ class TaskScheduler {
void QueryTasksInRam(const crane::grpc::QueryTasksInfoRequest* request,
crane::grpc::QueryTasksInfoReply* response);

bool QueryTaskUseId(task_id_t task_id, crane::grpc::TaskToCtld* task);

void QueryRnJobOnCtldForNodeConfig(const CranedId& craned_id,
crane::grpc::ConfigureCranedRequest* req);

Expand Down
Loading