From 10e47248de3bf5abe245a6baa268d5ce81e9fd32 Mon Sep 17 00:00:00 2001 From: topjohnwu Date: Sun, 29 Dec 2024 02:48:05 -0800 Subject: [PATCH] Use finer grain sqlite3 APIs --- native/src/core/bootstages.cpp | 3 +- native/src/core/db.cpp | 215 +++++++++++++++++++++------------ native/src/core/include/db.hpp | 1 + 3 files changed, 138 insertions(+), 81 deletions(-) diff --git a/native/src/core/bootstages.cpp b/native/src/core/bootstages.cpp index b055cbb44..6b54296ce 100644 --- a/native/src/core/bootstages.cpp +++ b/native/src/core/bootstages.cpp @@ -154,7 +154,6 @@ bool MagiskD::post_fs_data() const noexcept { LOGI("** post-fs-data mode running\n"); preserve_stub_apk(); - prune_su_access(); bool safe_mode = false; @@ -168,6 +167,8 @@ bool MagiskD::post_fs_data() const noexcept { } } + prune_su_access(); + if (!magisk_env()) { LOGE("* Magisk environment incomplete, abort\n"); safe_mode = true; diff --git a/native/src/core/db.cpp b/native/src/core/db.cpp index 5911eb01b..631d9b351 100644 --- a/native/src/core/db.cpp +++ b/native/src/core/db.cpp @@ -12,6 +12,7 @@ using namespace std; struct sqlite3; +struct sqlite3_stmt; static sqlite3 *mDB = nullptr; @@ -24,12 +25,22 @@ static sqlite3 *mDB = nullptr; #define SQLITE_OPEN_CREATE 0x00000004 /* Ok for sqlite3_open_v2() */ #define SQLITE_OPEN_FULLMUTEX 0x00010000 /* Ok for sqlite3_open_v2() */ -using sqlite3_callback = int (*)(void*, int, char**, char**); +#define SQLITE_OK 0 /* Successful result */ +#define SQLITE_ROW 100 /* sqlite3_step() has another row ready */ +#define SQLITE_DONE 101 /* sqlite3_step() has finished executing */ + static int (*sqlite3_open_v2)(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs); -static const char *(*sqlite3_errmsg)(sqlite3 *db); static int (*sqlite3_close)(sqlite3 *db); -static void (*sqlite3_free)(void *v); -static int (*sqlite3_exec)(sqlite3 *db, const char *sql, sqlite3_callback fn, void *v, char **errmsg); +static int (*sqlite3_prepare_v2)(sqlite3 *db, const char *zSql, int nByte, sqlite3_stmt **ppStmt, const char **pzTail); +static int (*sqlite3_bind_parameter_count)(sqlite3_stmt*); +static int (*sqlite3_bind_int)(sqlite3_stmt*, int, int); +static int (*sqlite3_bind_text)(sqlite3_stmt*,int,const char*,int,void(*)(void*)); +static int (*sqlite3_column_count)(sqlite3_stmt *pStmt); +static const char *(*sqlite3_column_name)(sqlite3_stmt*, int N); +static const char *(*sqlite3_column_text)(sqlite3_stmt*, int iCol); +static int (*sqlite3_step)(sqlite3_stmt*); +static int (*sqlite3_finalize)(sqlite3_stmt *pStmt); +static const char *(*sqlite3_errstr)(int); // Internal Android linker APIs @@ -53,9 +64,8 @@ constexpr char apex_path[] = "/apex/com.android.runtime/lib64:/apex/com.android. constexpr char apex_path[] = "/apex/com.android.runtime/lib:/apex/com.android.art/lib:/apex/com.android.i18n/lib:"; #endif -static int dl_init = 0; - static bool dload_sqlite() { + static int dl_init = 0; if (dl_init) return dl_init > 0; dl_init = -1; @@ -83,10 +93,17 @@ static bool dload_sqlite() { DLERR(sqlite); DLOAD(sqlite, sqlite3_open_v2); - DLOAD(sqlite, sqlite3_errmsg); DLOAD(sqlite, sqlite3_close); - DLOAD(sqlite, sqlite3_exec); - DLOAD(sqlite, sqlite3_free); + DLOAD(sqlite, sqlite3_prepare_v2); + DLOAD(sqlite, sqlite3_bind_parameter_count); + DLOAD(sqlite, sqlite3_bind_int); + DLOAD(sqlite, sqlite3_bind_text); + DLOAD(sqlite, sqlite3_step); + DLOAD(sqlite, sqlite3_column_count); + DLOAD(sqlite, sqlite3_column_name); + DLOAD(sqlite, sqlite3_column_text); + DLOAD(sqlite, sqlite3_finalize); + DLOAD(sqlite, sqlite3_errstr); dl_init = 1; return true; @@ -122,74 +139,118 @@ int db_settings::get_idx(string_view key) const { return idx; } -static int ver_cb(void *ver, int, char **data, char **) { - *((int *) ver) = parse_int(data[0]); - return 0; +static void ver_cb(void *ver, auto, rust::Slice data) { + *((int *) ver) = parse_int(data[0].c_str()); } +db_result::db_result(int code) : err(code == SQLITE_OK ? "" : (sqlite3_errstr(code) ?: "")) {} + bool db_result::check_err() { if (!err.empty()) { - LOGE("sqlite3_exec: %s\n", err.data()); + LOGE("sqlite3: %s\n", err.data()); return true; } return false; } -static db_result sql_exec(sqlite3 *db, const char *sql, sqlite3_callback fn, void *v) { - char *err = nullptr; - sqlite3_exec(db, sql, fn, v, &err); - if (err) { - db_result r = err; - sqlite3_free(err); - return r; +using StringVec = rust::Vec; +using StringSlice = rust::Slice; +using StrSlice = rust::Slice; +using sqlite_row_callback = void(*)(void*, StringSlice, StringSlice); + +#define fn_run_ret(fn, ...) if (int rc = fn(__VA_ARGS__); rc != SQLITE_OK) return rc + +static int sql_exec(sqlite3 *db, rust::Str zSql, StrSlice args, sqlite_row_callback callback, void *v) { + const char *sql = zSql.begin(); + auto arg_it = args.begin(); + unique_ptr stmt(nullptr, sqlite3_finalize); + + while (sql != zSql.end()) { + // Step 1: prepare statement + { + sqlite3_stmt *st = nullptr; + fn_run_ret(sqlite3_prepare_v2, db, sql, zSql.end() - sql, &st, &sql); + if (st == nullptr) continue; + stmt.reset(st); + } + + // Step 2: bind arguments + if (int count = sqlite3_bind_parameter_count(stmt.get())) { + for (int i = 1; i <= count && arg_it != args.end(); ++i, ++arg_it) { + fn_run_ret(sqlite3_bind_text, stmt.get(), i, arg_it->data(), arg_it->size(), nullptr); + } + } + + // Step 3: execute + bool first = true; + StringVec columns; + for (;;) { + int rc = sqlite3_step(stmt.get()); + if (rc == SQLITE_DONE) break; + if (rc != SQLITE_ROW) return rc; + if (callback == nullptr) continue; + if (first) { + int count = sqlite3_column_count(stmt.get()); + for (int i = 0; i < count; ++i) { + columns.emplace_back(sqlite3_column_name(stmt.get(), i)); + } + first = false; + } + StringVec data; + for (int i = 0; i < columns.size(); ++i) { + data.emplace_back(sqlite3_column_text(stmt.get(), i)); + } + callback(v, StringSlice(columns), StringSlice(data)); + } } - return {}; + + return SQLITE_OK; } -#define sql_exe_ret(...) if (auto r = sql_exec(__VA_ARGS__); !r) return r -#define fn_run_ret(fn) if (auto r = fn(); !r) return r +static int sql_exec(sqlite3 *db, const char *sql, sqlite_row_callback callback = nullptr, void *v = nullptr) { + return sql_exec(db, sql, {}, callback, v); +} -static db_result open_and_init_db(sqlite3 *&db) { +static db_result open_and_init_db() { if (!dload_sqlite()) return "Cannot load libsqlite.so"; - int ret = sqlite3_open_v2(MAGISKDB, &db, - SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_FULLMUTEX, nullptr); - if (ret) - return sqlite3_errmsg(db); + unique_ptr db(nullptr, sqlite3_close); + { + sqlite3 *sql; + fn_run_ret(sqlite3_open_v2, MAGISKDB, &sql, + SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_FULLMUTEX, nullptr); + db.reset(sql); + } + int ver = 0; bool upgrade = false; - sql_exe_ret(db, "PRAGMA user_version", ver_cb, &ver); + fn_run_ret(sql_exec, db.get(), "PRAGMA user_version", ver_cb, &ver); if (ver > DB_VERSION) { // Don't support downgrading database - sqlite3_close(db); return "Downgrading database is not supported"; } auto create_policy = [&] { - return sql_exec(db, + return sql_exec(db.get(), "CREATE TABLE IF NOT EXISTS policies " "(uid INT, policy INT, until INT, logging INT, " - "notification INT, PRIMARY KEY(uid))", - nullptr, nullptr); + "notification INT, PRIMARY KEY(uid))"); }; auto create_settings = [&] { - return sql_exec(db, + return sql_exec(db.get(), "CREATE TABLE IF NOT EXISTS settings " - "(key TEXT, value INT, PRIMARY KEY(key))", - nullptr, nullptr); + "(key TEXT, value INT, PRIMARY KEY(key))"); }; auto create_strings = [&] { - return sql_exec(db, + return sql_exec(db.get(), "CREATE TABLE IF NOT EXISTS strings " - "(key TEXT, value TEXT, PRIMARY KEY(key))", - nullptr, nullptr); + "(key TEXT, value TEXT, PRIMARY KEY(key))"); }; auto create_denylist = [&] { - return sql_exec(db, + return sql_exec(db.get(), "CREATE TABLE IF NOT EXISTS denylist " - "(package_name TEXT, process TEXT, PRIMARY KEY(package_name, process))", - nullptr, nullptr); + "(package_name TEXT, process TEXT, PRIMARY KEY(package_name, process))"); }; // Database changelog: @@ -214,48 +275,45 @@ static db_result open_and_init_db(sqlite3 *&db) { upgrade = true; } if (ver == 7) { - sql_exe_ret(db, + fn_run_ret(sql_exec, db.get(), "BEGIN TRANSACTION;" "ALTER TABLE hidelist RENAME TO hidelist_tmp;" "CREATE TABLE IF NOT EXISTS hidelist " "(package_name TEXT, process TEXT, PRIMARY KEY(package_name, process));" "INSERT INTO hidelist SELECT process as package_name, process FROM hidelist_tmp;" "DROP TABLE hidelist_tmp;" - "COMMIT;", - nullptr, nullptr); + "COMMIT;"); // Directly jump to version 9 ver = 9; upgrade = true; } if (ver == 8) { - sql_exe_ret(db, + fn_run_ret(sql_exec, db.get(), "BEGIN TRANSACTION;" "ALTER TABLE hidelist RENAME TO hidelist_tmp;" "CREATE TABLE IF NOT EXISTS hidelist " "(package_name TEXT, process TEXT, PRIMARY KEY(package_name, process));" "INSERT INTO hidelist SELECT * FROM hidelist_tmp;" "DROP TABLE hidelist_tmp;" - "COMMIT;", - nullptr, nullptr); + "COMMIT;"); ver = 9; upgrade = true; } if (ver == 9) { - sql_exe_ret(db, "DROP TABLE IF EXISTS logs", nullptr, nullptr); + fn_run_ret(sql_exec, db.get(), "DROP TABLE IF EXISTS logs", nullptr, nullptr); ver = 10; upgrade = true; } if (ver == 10) { - sql_exe_ret(db, + fn_run_ret(sql_exec, db.get(), "DROP TABLE IF EXISTS hidelist;" - "DELETE FROM settings WHERE key='magiskhide';", - nullptr, nullptr); + "DELETE FROM settings WHERE key='magiskhide';"); fn_run_ret(create_denylist); ver = 11; upgrade = true; } if (ver == 11) { - sql_exe_ret(db, + fn_run_ret(sql_exec, db.get(), "BEGIN TRANSACTION;" "ALTER TABLE policies RENAME TO policies_tmp;" "CREATE TABLE IF NOT EXISTS policies " @@ -264,8 +322,7 @@ static db_result open_and_init_db(sqlite3 *&db) { "INSERT INTO policies " "SELECT uid, policy, until, logging, notification FROM policies_tmp;" "DROP TABLE policies_tmp;" - "COMMIT;", - nullptr, nullptr); + "COMMIT;"); ver = 12; upgrade = true; } @@ -274,47 +331,45 @@ static db_result open_and_init_db(sqlite3 *&db) { // Set version char query[32]; sprintf(query, "PRAGMA user_version=%d", ver); - sql_exe_ret(db, query, nullptr, nullptr); + fn_run_ret(sql_exec, db.get(), query); + } + mDB = db.release(); + return {}; +} + +static db_result ensure_db() { + if (mDB == nullptr) { + auto res = open_and_init_db(); + if (res.check_err()) { + // Open fails, remove and reconstruct + unlink(MAGISKDB); + res = open_and_init_db(); + if (!res) return res; + } } return {}; } db_result db_exec(const char *sql) { - if (mDB == nullptr) { - auto res = open_and_init_db(mDB); - if (res.check_err()) { - // Open fails, remove and reconstruct - unlink(MAGISKDB); - res = open_and_init_db(mDB); - if (!res) return res; - } - } + if (auto res = ensure_db(); !res) return res; if (mDB) { - sql_exe_ret(mDB, sql, nullptr, nullptr); + return sql_exec(mDB, sql); } return {}; } -static int sqlite_db_row_callback(void *cb, int col_num, char **data, char **col_name) { +static void row_to_db_row(void *cb, rust::Slice columns, rust::Slice data) { auto &func = *static_cast(cb); db_row row; - for (int i = 0; i < col_num; ++i) - row[col_name[i]] = data[i]; - return func(row) ? 0 : 1; + for (int i = 0; i < columns.size(); ++i) + row[columns[i].c_str()] = data[i].c_str(); + func(row); } db_result db_exec(const char *sql, const db_row_cb &fn) { - if (mDB == nullptr) { - auto res = open_and_init_db(mDB); - if (res.check_err()) { - // Open fails, remove and reconstruct - unlink(MAGISKDB); - res = open_and_init_db(mDB); - if (!res) return res; - } - } + if (auto res = ensure_db(); !res) return res; if (mDB) { - sql_exe_ret(mDB, sql, sqlite_db_row_callback, (void *) &fn); + return sql_exec(mDB, sql, row_to_db_row, (void *) &fn); } return {}; } diff --git a/native/src/core/include/db.hpp b/native/src/core/include/db.hpp index 9198f915a..34b6320e3 100644 --- a/native/src/core/include/db.hpp +++ b/native/src/core/include/db.hpp @@ -129,6 +129,7 @@ struct owned_fd; struct db_result { db_result() = default; db_result(const char *s) : err(s) {} + db_result(int code); bool check_err(); operator bool() { return err.empty(); } private: