tinyrpc/src/coroutine/coroutine_hook.cc
2024-12-25 19:40:27 +08:00

113 lines
3.8 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "coroutine_hook.hpp"
#include "coroutine.hpp"
#include "logger.hpp"
#include "fd_event.hpp"
#include "reactor.hpp"
#include <cerrno>
#include <dlfcn.h>
// Defines a global `g_sys_<name>_fun` of type `<name>_fun_ptr_t` that holds the
// next definition of `name` in symbol-lookup order (dlsym with RTLD_NEXT).
// That is the implementation that this file's own `read`/`write`/`accept`
// wrappers shadow. The hooks and wrappers call through these pointers to reach
// the real system functions.
// NOTE(review): the initializer calls dlsym(), so these globals are dynamically
// initialized and stay null until this translation unit's initializers run.
// Any read/write/accept issued earlier (e.g. from another TU's static
// constructor) would call through a null pointer. Confirm no such early
// callers exist.
#define HOOK_SYSTEM_FUN(name) name##_fun_ptr_t g_sys_##name##_fun = (name##_fun_ptr_t)dlsym(RTLD_NEXT, #name)
HOOK_SYSTEM_FUN(read);
HOOK_SYSTEM_FUN(write);
HOOK_SYSTEM_FUN(accept);
namespace tinyrpc {
static bool isEnableHook = false;
void enableHook() {
isEnableHook = true;
}
void disableHook() {
isEnableHook = false;
}
ssize_t read_hook(int fd, void *buf, size_t count) {
logger() << "read_hook is calling";
FdEvent fe(fd);
fe.addListenEvent(IOEvent::READ);
Coroutine* curCoro = Coroutine::getCurrCoroutine();
fe.setReadCallback([curCoro] () -> void{
curCoro->resume();
});
// fd 设置为 nonblock
fe.setNonblock();
// 尝试一下系统read 返回值大于0直接返回
int ret = g_sys_read_fun(fd, buf, count);
if(ret > 0) return ret;
// fd 添加到 epoll 中
Reactor::getReactor()->addFdEvent(&fe);
Coroutine::yeild(); // yeild
Reactor::getReactor()->delFdEvent(&fe);
// 调用系统 read 返回
return g_sys_read_fun(fd, buf, count);
}
ssize_t write_hook(int fd, const void *buf, size_t count) {
logger() << "write_hook is calling";
FdEvent fe(fd);
fe.addListenEvent(IOEvent::WRITE);
Coroutine* curCoro = Coroutine::getCurrCoroutine();
fe.setWriteCallback([curCoro] () -> void{
curCoro->resume();
});
// fd 设置为 nonblock
fe.setNonblock();
// 尝试一下系统 write 返回值大于0直接返回
int ret = g_sys_write_fun(fd, buf, count);
if(ret > 0) return ret;
// fd 添加到 epoll 中
Reactor::getReactor()->addFdEvent(&fe);
Coroutine::yeild(); // yeild
Reactor::getReactor()->delFdEvent(&fe);
// 调用系统 write 返回
return g_sys_write_fun(fd, buf, count);
}
int accept_hook(int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
logger() << "accept_hook is calling";
FdEvent fe(sockfd);
fe.addListenEvent(IOEvent::READ);
Coroutine* curCoro = Coroutine::getCurrCoroutine();
fe.setReadCallback([curCoro] () -> void{
curCoro->resume();
});
// logger() << "accept_hook fd = " << fe.getFd();
// fd 设置为 nonblock
fe.setNonblock();
// 尝试一下系统 accept 返回值大于 0 直接返回
int ret = g_sys_accept_fun(sockfd, addr, addrlen);
if(ret >= 0) return ret;
// fd 添加到 epoll 中
Reactor::getReactor()->addFdEvent(&fe);
logger() << "accept_hook cor yeild";
Coroutine::yeild(); // yeild
logger() << "accept_hook cor resume then call g_sys_accept_fun";
Reactor::getReactor()->delFdEvent(&fe);
// 调用系统 write 返回
return g_sys_accept_fun(sockfd, addr, addrlen);
}
}
// Interposed read(2). When hooks are enabled, the call is routed through the
// coroutine-aware tinyrpc::read_hook; otherwise it goes straight to the
// system read.
ssize_t read(int fd, void *buf, size_t count) {
    if (tinyrpc::isEnableHook) {
        return tinyrpc::read_hook(fd, buf, count);
    }
    // Hooks disabled: plain system call.
    return g_sys_read_fun(fd, buf, count);
}
// Interposed write(2). When hooks are enabled, the call is routed through the
// coroutine-aware tinyrpc::write_hook; otherwise it goes straight to the
// system write.
ssize_t write(int fd, const void *buf, size_t count) {
    if (tinyrpc::isEnableHook) {
        return tinyrpc::write_hook(fd, buf, count);
    }
    // Hooks disabled: plain system call.
    return g_sys_write_fun(fd, buf, count);
}
// Interposed accept(2). When hooks are enabled, the call is routed through the
// coroutine-aware tinyrpc::accept_hook; otherwise it goes straight to the
// system accept.
int accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
    if (tinyrpc::isEnableHook) {
        return tinyrpc::accept_hook(sockfd, addr, addrlen);
    }
    // Hooks disabled: plain system call.
    return g_sys_accept_fun(sockfd, addr, addrlen);
}