fix: zygiskd companion, companion responses, write fd function and early client close

This commit fixes numerous issues in zygiskd code: The zygiskd companion code not loading the right entry, the companion not sending the correct responses, the write fd function not working properly and early client close when connecting to the companion.
This commit is contained in:
ThePedroo
2024-09-08 15:53:47 -03:00
parent a549f0e5ae
commit c2abef8826
12 changed files with 721 additions and 603 deletions

View File

@@ -1,221 +1,263 @@
#include <linux/un.h>
#include <sys/socket.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <dirent.h>
#include <fcntl.h>
#include "daemon.h"
#include "dl.h"
#include "socket_utils.h"
namespace zygiskd {
static std::string TMP_PATH;
void Init(const char *path) {
TMP_PATH = path;
static std::string TMP_PATH;
void Init(const char *path) {
TMP_PATH = path;
}
std::string GetTmpPath() {
return TMP_PATH;
}
int Connect(uint8_t retry) {
retry++;
int fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
struct sockaddr_un addr = {
.sun_family = AF_UNIX,
.sun_path = { 0 }
};
auto socket_path = TMP_PATH + kCPSocketName;
strcpy(addr.sun_path, socket_path.c_str());
socklen_t socklen = sizeof(addr);
while (--retry) {
int r = connect(fd, (struct sockaddr *)&addr, socklen);
if (r == 0) return fd;
if (retry) {
PLOGE("Retrying to connect to zygiskd, sleep 1s");
sleep(1);
}
}
std::string GetTmpPath() {
return TMP_PATH;
close(fd);
return -1;
}
bool PingHeartbeat() {
int fd = Connect(5);
if (fd == -1) {
PLOGE("Connect to zygiskd");
return false;
}
int Connect(uint8_t retry) {
int fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
struct sockaddr_un addr = {
.sun_family = AF_UNIX,
.sun_path = { 0 },
};
socket_utils::write_u8(fd, (uint8_t) SocketAction::PingHeartBeat);
auto socket_path = TMP_PATH + kCPSocketName;
strcpy(addr.sun_path, socket_path.c_str());
socklen_t socklen = sizeof(addr);
close(fd);
while (retry--) {
int r = connect(fd, reinterpret_cast<struct sockaddr*>(&addr), socklen);
if (r == 0) return fd;
if (retry) {
PLOGE("Retrying to connect to zygiskd, sleep 1s");
sleep(1);
}
}
return true;
}
int RequestLogcatFd() {
int fd = Connect(1);
if (fd == -1) {
PLOGE("RequestLogcatFd");
return -1;
}
socket_utils::write_u8(fd, (uint8_t) SocketAction::RequestLogcatFd);
return fd;
}
uint32_t GetProcessFlags(uid_t uid) {
int fd = Connect(1);
if (fd == -1) {
PLOGE("GetProcessFlags");
return 0;
}
socket_utils::write_u8(fd, (uint8_t) SocketAction::GetProcessFlags);
socket_utils::write_u32(fd, uid);
uint32_t res = socket_utils::read_u32(fd);
close(fd);
return res;
}
std::vector<Module> ReadModules() {
std::vector<Module> modules;
int fd = Connect(1);
if (fd == -1) {
PLOGE("ReadModules");
return modules;
}
socket_utils::write_u8(fd, (uint8_t) SocketAction::ReadModules);
size_t len = socket_utils::read_usize(fd);
for (size_t i = 0; i < len; i++) {
std::string name = socket_utils::read_string(fd);
int module_fd = socket_utils::recv_fd(fd);
modules.emplace_back(name, module_fd);
}
close(fd);
return modules;
}
int ConnectCompanion(size_t index) {
int fd = Connect(1);
if (fd == -1) {
PLOGE("ConnectCompanion");
return -1;
}
socket_utils::write_u8(fd, (uint8_t) SocketAction::RequestCompanionSocket);
socket_utils::write_usize(fd, index);
uint8_t res = socket_utils::read_u8(fd);
if (res == 1) return fd;
else {
close(fd);
return -1;
}
}
int GetModuleDir(size_t index) {
int fd = Connect(1);
if (fd == -1) {
PLOGE("GetModuleDir");
return -1;
}
socket_utils::write_u8(fd, (uint8_t) SocketAction::GetModuleDir);
socket_utils::write_usize(fd, index);
int nfd = socket_utils::recv_fd(fd);
close(fd);
return nfd;
}
void ZygoteRestart() {
int fd = Connect(1);
if (fd == -1) {
if (errno == ENOENT) LOGD("Could not notify ZygoteRestart (maybe it hasn't been created)");
else PLOGE("Could not notify ZygoteRestart");
return;
}
if (!socket_utils::write_u8(fd, (uint8_t) SocketAction::ZygoteRestart))
PLOGE("Failed to request ZygoteRestart");
close(fd);
}
void SystemServerStarted() {
int fd = Connect(1);
if (fd == -1) PLOGE("Failed to report system server started");
else {
if (!socket_utils::write_u8(fd, (uint8_t) SocketAction::SystemServerStarted))
PLOGE("Failed to report system server started");
}
close(fd);
}
void GetInfo(struct zygote_info *info) {
/* TODO: Optimize and avoid re-connect twice here */
int fd = Connect(1);
if (fd != -1) {
info->running = true;
socket_utils::write_u8(fd, (uint8_t) SocketAction::GetInfo);
int flags = socket_utils::read_u32(fd);
if (flags & (1 << 27)) {
info->root_impl = ZYGOTE_ROOT_IMPL_APATCH;
} else if (flags & (1 << 29)) {
info->root_impl = ZYGOTE_ROOT_IMPL_KERNELSU;
} else if (flags & (1 << 30)) {
info->root_impl = ZYGOTE_ROOT_IMPL_MAGISK;
} else {
info->root_impl = ZYGOTE_ROOT_IMPL_NONE;
}
info->pid = socket_utils::read_u32(fd);
info->modules = (struct zygote_modules *)malloc(sizeof(struct zygote_modules));
if (info->modules == NULL) {
info->modules->modules_count = 0;
close(fd);
return -1;
}
bool PingHeartbeat() {
UniqueFd fd = Connect(5);
if (fd == -1) {
PLOGE("Connect to zygiskd");
return false;
}
socket_utils::write_u8(fd, (uint8_t) SocketAction::PingHeartBeat);
return true;
}
return;
}
int RequestLogcatFd() {
int fd = Connect(1);
if (fd == -1) {
PLOGE("RequestLogcatFd");
return -1;
}
socket_utils::write_u8(fd, (uint8_t) SocketAction::RequestLogcatFd);
return fd;
}
info->modules->modules_count = socket_utils::read_usize(fd);
uint32_t GetProcessFlags(uid_t uid) {
UniqueFd fd = Connect(1);
if (fd == -1) {
PLOGE("GetProcessFlags");
return 0;
}
socket_utils::write_u8(fd, (uint8_t) SocketAction::GetProcessFlags);
socket_utils::write_u32(fd, uid);
if (info->modules->modules_count == 0) {
info->modules->modules = NULL;
return socket_utils::read_u32(fd);
}
close(fd);
std::vector<Module> ReadModules() {
std::vector<Module> modules;
UniqueFd fd = Connect(1);
if (fd == -1) {
PLOGE("ReadModules");
return modules;
}
socket_utils::write_u8(fd, (uint8_t) SocketAction::ReadModules);
size_t len = socket_utils::read_usize(fd);
for (size_t i = 0; i < len; i++) {
std::string name = socket_utils::read_string(fd);
int module_fd = socket_utils::recv_fd(fd);
modules.emplace_back(name, module_fd);
}
return modules;
}
return;
}
int ConnectCompanion(size_t index) {
int fd = Connect(1);
if (fd == -1) {
PLOGE("ConnectCompanion");
return -1;
}
socket_utils::write_u8(fd, (uint8_t) SocketAction::RequestCompanionSocket);
socket_utils::write_usize(fd, index);
if (socket_utils::read_u8(fd) == 1) {
return fd;
info->modules->modules = (char **)malloc(sizeof(char *) * info->modules->modules_count);
if (info->modules->modules == NULL) {
free(info->modules);
info->modules = NULL;
info->modules->modules_count = 0;
close(fd);
return;
}
for (size_t i = 0; i < info->modules->modules_count; i++) {
/* INFO by ThePedroo: Ugly solution to read with std::string existance (temporary) */
std::string name = socket_utils::read_string(fd);
char module_path[PATH_MAX];
snprintf(module_path, sizeof(module_path), "/data/adb/modules/%s/module.prop", name.c_str());
FILE *module_prop = fopen(module_path, "r");
if (module_prop == NULL) {
info->modules->modules[i] = strdup(name.c_str());
} else {
close(fd);
return -1;
}
}
char line[1024];
while (fgets(line, sizeof(line), module_prop) != NULL) {
if (strncmp(line, "name=", 5) == 0) {
info->modules->modules[i] = strndup(line + 5, strlen(line) - 6);
int GetModuleDir(size_t index) {
UniqueFd fd = Connect(1);
if (fd == -1) {
PLOGE("GetModuleDir");
return -1;
}
socket_utils::write_u8(fd, (uint8_t) SocketAction::GetModuleDir);
socket_utils::write_usize(fd, index);
return socket_utils::recv_fd(fd);
}
void ZygoteRestart() {
UniqueFd fd = Connect(1);
if (fd == -1) {
if (errno == ENOENT) {
LOGD("Could not notify ZygoteRestart (maybe it hasn't been created)");
} else {
PLOGE("Could not notify ZygoteRestart");
break;
}
return;
}
if (!socket_utils::write_u8(fd, (uint8_t) SocketAction::ZygoteRestart)) {
PLOGE("Failed to request ZygoteRestart");
}
}
void SystemServerStarted() {
UniqueFd fd = Connect(1);
if (fd == -1) {
PLOGE("Failed to report system server started");
} else {
if (!socket_utils::write_u8(fd, (uint8_t) SocketAction::SystemServerStarted)) {
PLOGE("Failed to report system server started");
}
}
}
void GetInfo(struct zygote_info *info) {
/* TODO: Optimize and avoid re-connect twice here */
int fd = Connect(1);
if (fd != -1) {
socket_utils::write_u8(fd, (uint8_t) SocketAction::GetInfo);
int flags = socket_utils::read_u32(fd);
if (flags & (1 << 27)) {
info->root_impl = ZYGOTE_ROOT_IMPL_APATCH;
} else if (flags & (1 << 29)) {
info->root_impl = ZYGOTE_ROOT_IMPL_KERNELSU;
} else if (flags & (1 << 30)) {
info->root_impl = ZYGOTE_ROOT_IMPL_MAGISK;
} else {
info->root_impl = ZYGOTE_ROOT_IMPL_NONE;
}
info->modules = (struct zygote_modules *)malloc(sizeof(struct zygote_modules));
if (info->modules == NULL) {
info->modules->modules_count = 0;
close(fd);
return;
}
info->modules->modules_count = socket_utils::read_usize(fd);
if (info->modules->modules_count == 0) {
info->modules->modules = NULL;
close(fd);
return;
}
info->modules->modules = (char **)malloc(sizeof(char *) * info->modules->modules_count);
if (info->modules->modules == NULL) {
free(info->modules);
info->modules = NULL;
info->modules->modules_count = 0;
close(fd);
return;
}
for (size_t i = 0; i < info->modules->modules_count; i++) {
/* INFO by ThePedroo: Ugly solution to read with std::string existance (temporary) */
std::string name = socket_utils::read_string(fd);
char module_path[PATH_MAX];
snprintf(module_path, sizeof(module_path), "/data/adb/modules/%s/module.prop", name.c_str());
FILE *module_prop = fopen(module_path, "r");
if (module_prop == NULL) {
info->modules->modules[i] = strdup(name.c_str());
} else {
char line[1024];
while (fgets(line, sizeof(line), module_prop) != NULL) {
if (strncmp(line, "name=", 5) == 0) {
info->modules->modules[i] = strndup(line + 5, strlen(line) - 6);
break;
}
}
fclose(module_prop);
}
}
close(fd);
} else info->running = false;
}
fclose(module_prop);
}
}
close(fd);
} else info->running = false;
}
}

View File

@@ -5,131 +5,133 @@
#include "socket_utils.h"
namespace socket_utils {
ssize_t xread(int fd, void* buf, size_t count) {
size_t read_sz = 0;
ssize_t ret;
do {
ret = read(fd, (std::byte*) buf + read_sz, count - read_sz);
if (ret < 0) {
if (errno == EINTR) continue;
PLOGE("read");
return ret;
}
read_sz += ret;
} while (read_sz != count && ret != 0);
if (read_sz != count) {
PLOGE("read (%zu != %zu)", count, read_sz);
}
return read_sz;
}
ssize_t xread(int fd, void* buf, size_t count) {
size_t read_sz = 0;
ssize_t ret;
do {
ret = read(fd, (std::byte*) buf + read_sz, count - read_sz);
if (ret < 0) {
if (errno == EINTR) continue;
PLOGE("read");
return ret;
}
read_sz += ret;
} while (read_sz != count && ret != 0);
if (read_sz != count) {
PLOGE("read (%zu != %zu)", count, read_sz);
size_t xwrite(int fd, const void* buf, size_t count) {
size_t write_sz = 0;
ssize_t ret;
do {
ret = write(fd, (std::byte*) buf + write_sz, count - write_sz);
if (ret < 0) {
if (errno == EINTR) continue;
PLOGE("write");
return write_sz;
}
return read_sz;
write_sz += ret;
} while (write_sz != count && ret != 0);
if (write_sz != count) {
PLOGE("write (%zu != %zu)", count, write_sz);
}
return write_sz;
}
size_t xwrite(int fd, const void* buf, size_t count) {
size_t write_sz = 0;
ssize_t ret;
do {
ret = write(fd, (std::byte*) buf + write_sz, count - write_sz);
if (ret < 0) {
if (errno == EINTR) continue;
PLOGE("write");
return write_sz;
}
write_sz += ret;
} while (write_sz != count && ret != 0);
if (write_sz != count) {
PLOGE("write (%zu != %zu)", count, write_sz);
}
return write_sz;
}
ssize_t xrecvmsg(int sockfd, struct msghdr* msg, int flags) {
int rec = recvmsg(sockfd, msg, flags);
if (rec < 0) PLOGE("recvmsg");
return rec;
}
ssize_t xrecvmsg(int sockfd, struct msghdr* msg, int flags) {
int rec = recvmsg(sockfd, msg, flags);
if (rec < 0) PLOGE("recvmsg");
return rec;
}
template<typename T>
inline T read_exact_or(int fd, T fail) {
T res;
return sizeof(T) == xread(fd, &res, sizeof(T)) ? res : fail;
}
void* recv_fds(int sockfd, char* cmsgbuf, size_t bufsz, int cnt) {
iovec iov = {
.iov_base = &cnt,
.iov_len = sizeof(cnt),
};
msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = cmsgbuf,
.msg_controllen = bufsz
};
template<typename T>
inline bool write_exact(int fd, T val) {
return sizeof(T) == xwrite(fd, &val, sizeof(T));
}
xrecvmsg(sockfd, &msg, MSG_WAITALL);
cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
uint8_t read_u8(int fd) {
return read_exact_or<uint8_t>(fd, 0);
}
if (msg.msg_controllen != bufsz ||
cmsg == nullptr ||
// TODO: pass from rust: 20, expected: 16
// cmsg->cmsg_len != CMSG_LEN(sizeof(int) * cnt) ||
cmsg->cmsg_level != SOL_SOCKET ||
cmsg->cmsg_type != SCM_RIGHTS) {
return nullptr;
}
uint32_t read_u32(int fd) {
return read_exact_or<uint32_t>(fd, 0);
}
return CMSG_DATA(cmsg);
}
size_t read_usize(int fd) {
return read_exact_or<size_t>(fd, 0);
}
template<typename T>
inline T read_exact_or(int fd, T fail) {
T res;
return sizeof(T) == xread(fd, &res, sizeof(T)) ? res : fail;
}
bool write_usize(int fd, size_t val) {
return write_exact<size_t>(fd, val);
}
template<typename T>
inline bool write_exact(int fd, T val) {
return sizeof(T) == xwrite(fd, &val, sizeof(T));
}
std::string read_string(int fd) {
size_t len = read_usize(fd);
uint8_t read_u8(int fd) {
return read_exact_or<uint8_t>(fd, 0);
}
char buf[len + 1];
xread(fd, buf, len);
uint32_t read_u32(int fd) {
return read_exact_or<uint32_t>(fd, 0);
}
buf[len] = '\0';
size_t read_usize(int fd) {
return read_exact_or<size_t>(fd, 0);
}
return buf;
}
bool write_usize(int fd, size_t val) {
return write_exact<size_t>(fd, val);
}
bool write_u8(int fd, uint8_t val) {
return write_exact<uint8_t>(fd, val);
}
std::string read_string(int fd) {
auto len = read_usize(fd);
char buf[len + 1];
buf[len] = '\0';
xread(fd, buf, len);
return buf;
}
void* recv_fds(int sockfd, char* cmsgbuf, size_t bufsz, int cnt) {
iovec iov = {
.iov_base = &cnt,
.iov_len = sizeof(cnt),
};
msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = cmsgbuf,
.msg_controllen = bufsz
};
bool write_u8(int fd, uint8_t val) {
return write_exact<uint8_t>(fd, val);
}
xrecvmsg(sockfd, &msg, MSG_WAITALL);
cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
bool write_u32(int fd, uint32_t val) {
return write_exact<uint32_t>(fd, val);
}
if (msg.msg_controllen != bufsz ||
cmsg == nullptr ||
// TODO: pass from rust: 20, expected: 16
// cmsg->cmsg_len != CMSG_LEN(sizeof(int) * cnt) ||
cmsg->cmsg_level != SOL_SOCKET ||
cmsg->cmsg_type != SCM_RIGHTS) {
return nullptr;
}
bool write_string(int fd, std::string_view str) {
return write_usize(fd, str.size()) && str.size() == xwrite(fd, str.data(), str.size());
}
return CMSG_DATA(cmsg);
}
int recv_fd(int sockfd) {
char cmsgbuf[CMSG_SPACE(sizeof(int))];
char cmsgbuf[CMSG_SPACE(sizeof(int))];
void* data = recv_fds(sockfd, cmsgbuf, sizeof(cmsgbuf), 1);
if (data == nullptr) return -1;
void* data = recv_fds(sockfd, cmsgbuf, sizeof(cmsgbuf), 1);
if (data == nullptr) return -1;
int result;
memcpy(&result, data, sizeof(int));
return result;
}
int result;
memcpy(&result, data, sizeof(int));
return result;
}
bool write_u32(int fd, uint32_t val) {
return write_exact<uint32_t>(fd, val);
}
bool write_string(int fd, std::string_view str) {
return write_usize(fd, str.size()) && str.size() == xwrite(fd, str.data(), str.size());
}
}