tinyrpc/src/coroutine/coroutine.cc

78 lines
2.3 KiB
C++
Raw Normal View History

2024-12-17 15:47:10 +08:00
#include "coroutine.hpp"
#include "coctx.h"
#include "logger.h"
#include <atomic>
#include <functional>
namespace tinyrpc {
// Per-thread bookkeeping (thread_local: every thread gets its own main coroutine).
// t_main_coroutine: the coroutine that yeild() switches back to on this thread.
static thread_local Coroutine* t_main_coroutine{nullptr};
// t_curr_coroutine: the coroutine currently recorded as running on this thread.
static thread_local Coroutine* t_curr_coroutine{nullptr};
// Source of coroutine ids. Unlike the two pointers above this counter is NOT
// thread_local: it is a single atomic shared by every thread.
static std::atomic<int> t_coroutine_count{0};
/**
 * Entry routine installed as the start address of every user coroutine
 * (the user-coroutine constructor stores its address in the saved context).
 *
 * Invokes the coroutine object, keeping m_is_in_cofunc set for the duration
 * of that call, and then yields. A null `co` skips the call and only yields.
 *
 * NOTE(review): if Coroutine::yeild() ever returns here (for example when a
 * finished coroutine is resumed again), this function returns to its caller;
 * presumably there is no valid return address on a coroutine stack - confirm
 * against coctx_swap.
 */
void coFunction(Coroutine* co) {
    if (co == nullptr) {
        Coroutine::yeild();
        return;
    }

    Coroutine& self = *co;
    self.m_is_in_cofunc = true;
    self();
    self.m_is_in_cofunc = false;

    Coroutine::yeild();
}
/**
 * Builds the main coroutine of the calling thread.
 *
 * Takes the next coroutine id and registers this object as both the main and
 * the current coroutine of the calling thread.
 */
Coroutine::Coroutine() {
    m_cor_id = t_coroutine_count++;
    // This constructor receives no stack buffer, yet ~Coroutine() passes
    // m_stack_sp to free() unconditionally. Make sure it holds a null pointer
    // (free(nullptr) is a no-op) instead of relying on an initializer that is
    // not visible in this file.
    m_stack_sp = nullptr;
    m_stack_size = 0;
    t_main_coroutine = t_curr_coroutine = this;
    logger() << "main coroutine has built";
}
/**
 * Builds a user coroutine that runs on a caller-supplied stack buffer.
 *
 * @param stack_size size in bytes of the buffer that starts at stack_sp
 * @param stack_sp   lowest address of the stack buffer; it is stored in
 *                   m_stack_sp and later handed to free() by the destructor
 * @param cb         callable stored in m_callback
 *
 * If the calling thread has no main coroutine yet, one is created first.
 * The saved context is then primed with the stack top, the address of
 * coFunction and `this`; presumably coctx_swap enters coFunction(this) on the
 * first switch - confirm against the coctx implementation.
 */
Coroutine::Coroutine(std::size_t stack_size, char* stack_sp, std::function<void()> cb) :
    m_stack_sp(stack_sp),
    m_stack_size(stack_size),
    m_callback(cb)
{
    m_cor_id = t_coroutine_count++;

    // The default constructor registers the new object as this thread's main
    // (and current) coroutine by itself.
    if (t_main_coroutine == nullptr) {
        t_main_coroutine = new Coroutine();
    }

    // Start from the end of the buffer and round the address DOWN to a
    // 16-byte boundary: the mask ~0xf clears the low four bits.
    char* top = stack_sp + stack_size;
    top = reinterpret_cast<char*>(reinterpret_cast<unsigned long long>(top) & (~0xfull));

    m_ctx.regs[reg::kRBP] = top;
    m_ctx.regs[reg::kRSP] = top;
    m_ctx.regs[reg::kRETAddr] = reinterpret_cast<char*>(&coFunction);
    // Single store of the coFunction argument (the original wrote kRDI twice
    // with the same address).
    m_ctx.regs[reg::kRDI] = reinterpret_cast<char*>(this);

    logger() << "user coroutine has built";
}
void Coroutine::yeild() {
if (t_curr_coroutine == t_main_coroutine) {
logger() << "current coroutine is main coroutine !";
return;
}
Coroutine* cur = t_curr_coroutine;
t_curr_coroutine = t_main_coroutine;
coctx_swap(&(cur->m_ctx), &(t_main_coroutine->m_ctx));
}
void Coroutine::resume() {
if (t_curr_coroutine != t_main_coroutine) {
logger() << "swap error, current coroutine must be main coroutine !";
return;
}
t_curr_coroutine = this;
coctx_swap(&(t_main_coroutine->m_ctx), &(this->m_ctx));
}
/**
 * Releases the coroutine.
 *
 * The stack buffer stored in m_stack_sp is passed to free(), so the buffer
 * given to the constructor is assumed to come from malloc/calloc/realloc -
 * TODO confirm against the callers that allocate coroutine stacks.
 *
 * NOTE(review): free() is declared in <cstdlib>, which is not among the
 * includes visible at the top of this file.
 */
Coroutine::~Coroutine() {
    // The constructors and resume() store `this` in the thread-local
    // pointers. Drop those registrations so a later yeild()/resume() on this
    // thread cannot dereference a destroyed coroutine.
    if (t_main_coroutine == this) {
        t_main_coroutine = nullptr;
    }
    if (t_curr_coroutine == this) {
        t_curr_coroutine = nullptr;
    }
    free(m_stack_sp);
}
}