#include "coroutine.hpp"
#include "coctx.h"
#include "logger.hpp"

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <new>
#include <utility>
// #include <atomic>  // needed again if t_coroutine_count below is re-enabled

namespace tinyrpc {

// thread_local: each thread lazily gets its own main coroutine, which stands
// for the thread's original stack and is the only one allowed to resume others.
static thread_local Coroutine* t_main_coroutine = nullptr;
// Coroutine currently executing on this thread.
static thread_local Coroutine* t_curr_coroutine = nullptr;
// static std::atomic_int t_coroutine_count {0};

/// Returns the coroutine currently running on this thread.
/// Creates the thread's main coroutine first if it does not exist yet; the
/// main coroutine's constructor makes itself the current coroutine.
Coroutine* Coroutine::getCurrCoroutine() {
  if (t_main_coroutine == nullptr) {
    t_main_coroutine = new Coroutine();
  }
  return t_curr_coroutine;
}

/// Returns this thread's main coroutine, creating it on first use.
Coroutine* Coroutine::getMainCoroutine() {
  if (t_main_coroutine == nullptr) {
    t_main_coroutine = new Coroutine();
  }
  return t_main_coroutine;
}

/// Entry point of every user coroutine. The constructor below stores its
/// address in the kRETAddr slot and `co` in the kRDI slot, so coctx_swap
/// enters here the first time the coroutine is resumed.
void coFunction(Coroutine* co) {
  if (co != nullptr) {
    co->m_is_in_cofunc = true;
    (*co)();
    co->m_is_in_cofunc = false;
  }
  // This frame was not reached through a normal call, so there is no valid
  // return address above it: returning would jump to garbage. Keep handing
  // control back to the main coroutine instead. Resuming a coroutine whose
  // callback has already finished then simply yields straight back.
  for (;;) {
    Coroutine::yeild();
  }
}

/// Builds the main coroutine of the calling thread. It owns no stack of its
/// own (it runs on the thread's original stack) and becomes the current
/// coroutine.
Coroutine::Coroutine()
    : m_stack_sp(nullptr), m_stack_size(0) {
  // m_cor_id = t_coroutine_count++;
  // t_main_coroutine = this;
  t_curr_coroutine = this;
  logger() << "main coroutine has built";
}

/// Builds a user coroutine that will run `cb` on a freshly malloc'd stack.
/// @param cb          callback executed (via operator()) on first resume
/// @param stack_size  size in bytes of the private stack
/// @throws std::bad_alloc if the stack cannot be allocated
Coroutine::Coroutine(std::function<void()> cb,
                     std::size_t stack_size /* = 1 * 1024 * 1024 */)
    : m_stack_sp(static_cast<char*>(std::malloc(stack_size))),
      m_stack_size(stack_size),
      m_callback(std::move(cb)) {
  // m_cor_id = t_coroutine_count++;
  if (m_stack_sp == nullptr) {
    // Without a stack, the context set up below would point into address 0.
    throw std::bad_alloc();
  }

  // Make sure this thread has a main coroutine to yield back to.
  if (t_main_coroutine == nullptr) {
    t_main_coroutine = new Coroutine();
  }

  // Clear every saved register so the first swap loads known values.
  m_ctx = {};

  // The stack grows downward, so execution starts at the high end. Keep one
  // pointer-sized slot of headroom below the end and round down to a 16-byte
  // boundary (mask ~0xF).
  // NOTE(review): a libco-style coctx_swap pushes the kRETAddr value at
  // [kRSP]; the headroom keeps that write inside the buffer. Confirm against
  // the coctx_swap implementation in use.
  char* top = m_stack_sp + stack_size - sizeof(void*);
  top = reinterpret_cast<char*>(reinterpret_cast<std::uintptr_t>(top) &
                                ~static_cast<std::uintptr_t>(0xF));

  m_ctx.regs[reg::kRBP] = top;
  m_ctx.regs[reg::kRSP] = top;
  m_ctx.regs[reg::kRETAddr] = reinterpret_cast<void*>(&coFunction);
  // First argument register: coFunction receives this coroutine.
  m_ctx.regs[reg::kRDI] = static_cast<void*>(this);

  logger() << "user coroutine has built";
}

/// Suspends the current user coroutine and switches back to the main
/// coroutine. Logs and does nothing if there is no main or current coroutine,
/// or if called from the main coroutine itself.
void Coroutine::yeild() {
  if (t_main_coroutine == nullptr) {
    logger() << "main coroutine is null";
    return;
  }
  if (t_curr_coroutine == nullptr) {
    logger() << "current coroutine is null";
    return;
  }
  if (t_curr_coroutine == t_main_coroutine) {
    logger() << "current coroutine is main coroutine !";
    return;
  }

  Coroutine* cur = t_curr_coroutine;
  t_curr_coroutine = t_main_coroutine;
  coctx_swap(&(cur->m_ctx), &(t_main_coroutine->m_ctx));
}

/// Switches from the main coroutine into this coroutine. Only the main
/// coroutine may resume another one; any other caller is logged and ignored.
void Coroutine::resume() {
  if (t_curr_coroutine != t_main_coroutine) {
    logger() << "swap error, current coroutine must be main coroutine !";
    return;
  }

  t_curr_coroutine = this;
  coctx_swap(&(t_main_coroutine->m_ctx), &(this->m_ctx));
}

/// Releases the private stack (the main coroutine has none).
/// NOTE(review): destroying a coroutine that is still suspended mid-callback
/// frees the stack it would resume on; callers must not resume it afterwards.
Coroutine::~Coroutine() {
  if (m_stack_sp) {
    std::free(m_stack_sp);
  }
}

}  // namespace tinyrpc