2024-12-17 15:47:10 +08:00
|
|
|
#include "coroutine.hpp"
|
|
|
|
#include "coctx.h"
|
2024-12-20 21:17:21 +08:00
|
|
|
#include "logger.hpp"
|
|
|
|
// #include <atomic>
|
|
|
|
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <new>
#include <utility>
|
2024-12-17 15:47:10 +08:00
|
|
|
#include <functional>
|
|
|
|
|
|
|
|
namespace tinyrpc {
|
|
|
|
// Per-thread coroutine bookkeeping: `thread_local` gives every thread its own copy.

// This thread's main coroutine; created lazily on first demand (see getMainCoroutine()).
static thread_local Coroutine* t_main_coroutine{nullptr};

// The coroutine currently executing on this thread.
static thread_local Coroutine* t_curr_coroutine{nullptr};
|
2024-12-20 21:17:21 +08:00
|
|
|
// static std::atomic_int t_coroutine_count {0};
|
2024-12-17 15:47:10 +08:00
|
|
|
|
2024-12-25 19:40:27 +08:00
|
|
|
// Returns the coroutine currently running on this thread.
// Lazily creates the thread's main coroutine first; the main coroutine's constructor
// marks itself as current, so a thread that has never switched gets its main coroutine back.
Coroutine* Coroutine::getCurrCoroutine() {
    getMainCoroutine();  // called only for its lazy-initialization side effect
    return t_curr_coroutine;
}
|
|
|
|
|
|
|
|
// Returns this thread's main coroutine, heap-allocating it on first use.
// NOTE(review): the object is never deleted in this file, so it lives until process exit.
Coroutine* Coroutine::getMainCoroutine() {
    return t_main_coroutine != nullptr
               ? t_main_coroutine
               : (t_main_coroutine = new Coroutine());
}
|
|
|
|
|
2024-12-17 15:47:10 +08:00
|
|
|
// Entry trampoline for user coroutines: the user constructor stores this function as the
// context's return address and `this` as RDI, so the first switch into a coroutine lands here.
// Runs the coroutine's callback (flagging m_is_in_cofunc while it executes), then yields
// back to the main coroutine.
// NOTE(review): nothing here handles being resumed again after the final yield; this
// function would then return into whatever lies above the initial stack top.
void coFunction(Coroutine* co) {
    const bool has_coroutine = (co != nullptr);
    if (has_coroutine) {
        co->m_is_in_cofunc = true;
        co->operator()();
        co->m_is_in_cofunc = false;
    }

    Coroutine::yeild();
}
|
|
|
|
|
|
|
|
/**
 * @brief Build the calling thread's main coroutine.
 *
 * No stack buffer or callback is set up: the main coroutine's context is only saved and
 * restored by coctx_swap in yeild()/resume(). Constructing it makes it this thread's
 * current coroutine; storing it in t_main_coroutine is left to the caller
 * (`t_main_coroutine = new Coroutine()`).
 */
Coroutine::Coroutine()
    : m_stack_sp(nullptr),  // explicit: ~Coroutine() frees m_stack_sp only when it is non-null
      m_stack_size(0) {
    t_curr_coroutine = this;
    logger() << "main coroutine has built";
}
|
|
|
|
|
2025-01-10 15:00:50 +08:00
|
|
|
/**
 * @brief Build a user coroutine that runs `cb` on its own malloc()-allocated stack.
 *
 * @param cb          Callable invoked (through coFunction) when the coroutine is first resumed.
 * @param stack_size  Size in bytes of the stack buffer, which ~Coroutine() frees. It must be
 *                    comfortably larger than a couple of pointers, otherwise the computed stack
 *                    top would fall before the start of the buffer.
 * @throws std::bad_alloc if the stack buffer cannot be allocated.
 *
 * Also lazily creates this thread's main coroutine, which yeild()/resume() switch to and from.
 */
Coroutine::Coroutine(std::function<void()> cb, std::size_t stack_size /* = 1 * 1024 * 1024 *//* , char* stack_sp */)
    : m_stack_sp(static_cast<char*>(malloc(stack_size))),
      m_stack_size(stack_size),
      m_callback(std::move(cb)) {
    // Deriving a stack top from a null buffer is undefined behavior, and the first resume()
    // would crash somewhere far away; fail at construction instead.
    if (m_stack_sp == nullptr) {
        throw std::bad_alloc();
    }

    if (t_main_coroutine == nullptr) {
        t_main_coroutine = new Coroutine();
    }

    // The stack grows downward, so start from the end of the buffer. Reserve one pointer-sized
    // slot, then round down to a 16-byte boundary (the layout libco's coctx_make uses).
    // NOTE(review): assumes a libco-style coctx_swap, which on the first switch-in does
    // `leaq 8(%rsp),%rsp; pushq <kRETAddr>; ret`. With RSP exactly at the end of the buffer, that
    // push and coFunction's return-address slot would land past the allocation. Confirm against
    // the coctx_swap implementation in use.
    const std::uintptr_t top_addr =
        reinterpret_cast<std::uintptr_t>(m_stack_sp + stack_size - sizeof(void*));
    char* top = reinterpret_cast<char*>(top_addr & ~static_cast<std::uintptr_t>(0xf));  // 16-byte alignment

    m_ctx.regs[reg::kRBP] = top;
    m_ctx.regs[reg::kRSP] = top;
    // The first switch "returns" into coFunction with RDI (first integer argument register on
    // x86-64 System V) holding `this`.
    m_ctx.regs[reg::kRETAddr] = reinterpret_cast<char*>(&coFunction);
    m_ctx.regs[reg::kRDI] = reinterpret_cast<char*>(this);

    logger() << "user coroutine has built";
}
|
|
|
|
|
|
|
|
void Coroutine::yeild() {
|
|
|
|
|
2024-12-20 21:17:21 +08:00
|
|
|
if (t_main_coroutine == nullptr) {
|
|
|
|
logger() << "main coroutine is null";
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (t_curr_coroutine == nullptr) {
|
|
|
|
logger() << "current coroutine is null";
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
2024-12-17 15:47:10 +08:00
|
|
|
if (t_curr_coroutine == t_main_coroutine) {
|
|
|
|
logger() << "current coroutine is main coroutine !";
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
Coroutine* cur = t_curr_coroutine;
|
|
|
|
t_curr_coroutine = t_main_coroutine;
|
|
|
|
coctx_swap(&(cur->m_ctx), &(t_main_coroutine->m_ctx));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
void Coroutine::resume() {
|
|
|
|
if (t_curr_coroutine != t_main_coroutine) {
|
|
|
|
logger() << "swap error, current coroutine must be main coroutine !";
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
t_curr_coroutine = this;
|
|
|
|
coctx_swap(&(t_main_coroutine->m_ctx), &(this->m_ctx));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// Releases the malloc()-allocated stack buffer of a user coroutine, if there is one.
Coroutine::~Coroutine() {
    if (m_stack_sp != nullptr) {
        free(m_stack_sp);
    }
}
|
|
|
|
|
|
|
|
}
|