Cleanup and fix compress/decompress command

This commit is contained in:
topjohnwu
2025-08-22 22:21:30 -07:00
committed by John Wu
parent c1491b8d2b
commit 7b706bb0cb
5 changed files with 94 additions and 83 deletions

View File

@@ -116,6 +116,26 @@ impl<T: Write> WriteExt for T {
} }
} }
pub enum FileOrStd {
StdIn,
StdOut,
StdErr,
File(File),
}
impl FileOrStd {
pub fn as_file(&self) -> &File {
let raw_fd_ref: &'static RawFd = match self {
FileOrStd::StdIn => &0,
FileOrStd::StdOut => &1,
FileOrStd::StdErr => &2,
FileOrStd::File(file) => return file,
};
// SAFETY: File is guaranteed to have the same ABI as RawFd
unsafe { mem::transmute(raw_fd_ref) }
}
}
fn open_fd(path: &Utf8CStr, flags: i32, mode: mode_t) -> OsResult<OwnedFd> { fn open_fd(path: &Utf8CStr, flags: i32, mode: mode_t) -> OsResult<OwnedFd> {
unsafe { unsafe {
let fd = libc::open(path.as_ptr(), flags, mode as c_uint).as_os_result( let fd = libc::open(path.as_ptr(), flags, mode as c_uint).as_os_result(

View File

@@ -1,4 +1,4 @@
use crate::compress::{compress, decompress}; use crate::compress::{compress_cmd, decompress_cmd};
use crate::cpio::{cpio_commands, print_cpio_usage}; use crate::cpio::{cpio_commands, print_cpio_usage};
use crate::dtb::{DtbAction, dtb_commands, print_dtb_usage}; use crate::dtb::{DtbAction, dtb_commands, print_dtb_usage};
use crate::ffi::{BootImage, FileFormat, cleanup, repack, split_image_dtb, unpack}; use crate::ffi::{BootImage, FileFormat, cleanup, repack, split_image_dtb, unpack};
@@ -414,14 +414,14 @@ fn boot_main(cmds: CmdArgs) -> LoggedResult<i32> {
cleanup(); cleanup();
} }
Action::Decompress(Decompress { mut file, mut out }) => { Action::Decompress(Decompress { mut file, mut out }) => {
decompress(&mut file, out.as_mut())?; decompress_cmd(&mut file, out.as_mut())?;
} }
Action::Compress(Compress { Action::Compress(Compress {
ref mut file, ref mut file,
ref format, ref format,
ref mut out, ref mut out,
}) => { }) => {
compress( compress_cmd(
FileFormat::from_str(format).unwrap_or(FileFormat::UNKNOWN), FileFormat::from_str(format).unwrap_or(FileFormat::UNKNOWN),
file, file,
out.as_mut(), out.as_mut(),

View File

@@ -1,6 +1,6 @@
use crate::ffi::{FileFormat, check_fmt}; use crate::ffi::{FileFormat, check_fmt};
use base::libc::{O_RDONLY, O_TRUNC, O_WRONLY}; use base::libc::{O_RDONLY, O_TRUNC, O_WRONLY};
use base::{Chunker, LoggedResult, Utf8CStr, WriteExt, error, log_err}; use base::{Chunker, FileOrStd, LoggedResult, Utf8CStr, Utf8CString, WriteExt, error, log_err};
use bytemuck::bytes_of_mut; use bytemuck::bytes_of_mut;
use bzip2::{Compression as BzCompression, write::BzDecoder, write::BzEncoder}; use bzip2::{Compression as BzCompression, write::BzDecoder, write::BzEncoder};
use flate2::{Compression as GzCompression, write::GzEncoder, write::MultiGzDecoder}; use flate2::{Compression as GzCompression, write::GzEncoder, write::MultiGzDecoder};
@@ -9,12 +9,13 @@ use lz4::{
EncoderBuilder as LZ4FrameEncoderBuilder, block::CompressionMode, liblz4::BlockChecksum, EncoderBuilder as LZ4FrameEncoderBuilder, block::CompressionMode, liblz4::BlockChecksum,
}; };
use std::cell::Cell; use std::cell::Cell;
use std::fmt::Write as FmtWrite;
use std::fs::File; use std::fs::File;
use std::io::{BufWriter, Read, Write, stdin, stdout}; use std::io::{BufWriter, Read, Write};
use std::mem::ManuallyDrop; use std::mem::ManuallyDrop;
use std::num::NonZeroU64; use std::num::NonZeroU64;
use std::ops::DerefMut; use std::ops::DerefMut;
use std::os::fd::{AsFd, AsRawFd, FromRawFd, RawFd}; use std::os::fd::{FromRawFd, RawFd};
use xz2::{ use xz2::{
stream::{Check as LzmaCheck, Filters as LzmaFilters, LzmaOptions, Stream as LzmaStream}, stream::{Check as LzmaCheck, Filters as LzmaFilters, LzmaOptions, Stream as LzmaStream},
write::{XzDecoder, XzEncoder}, write::{XzDecoder, XzEncoder},
@@ -383,29 +384,6 @@ pub fn get_decoder<'a, W: Write + 'a>(format: FileFormat, w: W) -> Box<dyn Write
// C++ FFI // C++ FFI
pub fn compress_fd(format: FileFormat, in_fd: RawFd, out_fd: RawFd) {
let mut in_file = unsafe { ManuallyDrop::new(File::from_raw_fd(in_fd)) };
let mut out_file = unsafe { ManuallyDrop::new(File::from_raw_fd(out_fd)) };
let mut encoder = get_encoder(format, out_file.deref_mut());
let _: LoggedResult<()> = try {
std::io::copy(in_file.deref_mut(), encoder.as_mut())?;
encoder.finish()?;
};
}
pub fn decompress_bytes_fd(format: FileFormat, in_bytes: &[u8], in_fd: RawFd, out_fd: RawFd) {
let mut in_file = unsafe { ManuallyDrop::new(File::from_raw_fd(in_fd)) };
let mut out_file = unsafe { ManuallyDrop::new(File::from_raw_fd(out_fd)) };
let mut decoder = get_decoder(format, out_file.deref_mut());
let _: LoggedResult<()> = try {
decoder.write_all(in_bytes)?;
std::io::copy(in_file.deref_mut(), decoder.as_mut())?;
decoder.finish()?;
};
}
pub fn compress_bytes(format: FileFormat, in_bytes: &[u8], out_fd: RawFd) { pub fn compress_bytes(format: FileFormat, in_bytes: &[u8], out_fd: RawFd) {
let mut out_file = unsafe { ManuallyDrop::new(File::from_raw_fd(out_fd)) }; let mut out_file = unsafe { ManuallyDrop::new(File::from_raw_fd(out_fd)) };
@@ -426,22 +404,31 @@ pub fn decompress_bytes(format: FileFormat, in_bytes: &[u8], out_fd: RawFd) {
}; };
} }
pub(crate) fn decompress(infile: &mut String, outfile: Option<&mut String>) -> LoggedResult<()> { // Command-line entry points
pub(crate) fn decompress_cmd(
infile: &mut String,
outfile: Option<&mut String>,
) -> LoggedResult<()> {
let infile = Utf8CStr::from_string(infile);
let outfile = outfile.map(Utf8CStr::from_string);
let in_std = infile == "-"; let in_std = infile == "-";
let mut rm_in = false; let mut rm_in = false;
let mut buf = [0u8; 4096]; let mut buf = [0u8; 4096];
let raw_in = if in_std {
super let mut stdin = stdin(); let input = if in_std {
let _ = stdin.read(&mut buf)?; FileOrStd::StdIn
stdin.as_fd()
} else { } else {
super let mut infile = Utf8CStr::from_string(infile).open(O_RDONLY)?; FileOrStd::File(infile.open(O_RDONLY)?)
let _ = infile.read(&mut buf)?;
infile.as_fd()
}; };
let format = check_fmt(&buf); // First read some bytes for format detection
let len = input.as_file().read(&mut buf)?;
let buf = &buf[..len];
let format = check_fmt(buf);
eprintln!("Detected format: {format}"); eprintln!("Detected format: {format}");
@@ -449,44 +436,46 @@ pub(crate) fn decompress(infile: &mut String, outfile: Option<&mut String>) -> L
return log_err!("Input file is not a supported type!"); return log_err!("Input file is not a supported type!");
} }
let raw_out = if let Some(outfile) = outfile { // If user did not provide outfile, infile has to be either
// <path>.[ext], or "-". Outfile will be either <path> or "-".
// If the input does not have proper format, abort.
let output = if let Some(outfile) = outfile {
if outfile == "-" { if outfile == "-" {
super let stdout = stdout(); FileOrStd::StdOut
stdout.as_fd()
} else { } else {
super let outfile = Utf8CStr::from_string(outfile).create(O_WRONLY | O_TRUNC, 0o644)?; FileOrStd::File(outfile.create(O_WRONLY | O_TRUNC, 0o644)?)
outfile.as_fd()
} }
} else if in_std { } else if in_std {
super let stdout = stdout(); FileOrStd::StdOut
stdout.as_fd()
} else { } else {
// strip the extension // Strip out extension and remove input
let outfile = if let Some((outfile, ext)) = infile.rsplit_once('.')
&& ext == format.ext()
{
Utf8CString::from(outfile)
} else {
return log_err!("Input file is not a supported type!");
};
rm_in = true; rm_in = true;
let mut outfile = if let Some((outfile, ext)) = infile.rsplit_once('.') {
if ext != format.ext() {
log_err!("Input file is not a supported type!")?;
}
outfile.to_owned()
} else {
infile.clone()
};
eprintln!("Decompressing to [{outfile}]"); eprintln!("Decompressing to [{outfile}]");
FileOrStd::File(outfile.create(O_WRONLY | O_TRUNC, 0o644)?)
super let outfile = Utf8CStr::from_string(&mut outfile).create(O_WRONLY | O_TRUNC, 0o644)?;
outfile.as_fd()
}; };
decompress_bytes_fd(format, &buf, raw_in.as_raw_fd(), raw_out.as_raw_fd()); let mut decoder = get_decoder(format, output.as_file());
decoder.write_all(buf)?;
std::io::copy(&mut input.as_file(), decoder.as_mut())?;
decoder.finish()?;
if rm_in { if rm_in {
Utf8CStr::from_string(infile).remove()?; infile.remove()?;
} }
Ok(()) Ok(())
} }
pub(crate) fn compress( pub(crate) fn compress_cmd(
method: FileFormat, method: FileFormat,
infile: &mut String, infile: &mut String,
outfile: Option<&mut String>, outfile: Option<&mut String>,
@@ -495,40 +484,43 @@ pub(crate) fn compress(
error!("Unsupported compression format"); error!("Unsupported compression format");
} }
let infile = Utf8CStr::from_string(infile);
let outfile = outfile.map(Utf8CStr::from_string);
let in_std = infile == "-"; let in_std = infile == "-";
let mut rm_in = false; let mut rm_in = false;
let raw_in = if in_std { let input = if in_std {
super let stdin = stdin(); FileOrStd::StdIn
stdin.as_fd()
} else { } else {
super let infile = Utf8CStr::from_string(infile).open(O_RDONLY)?; FileOrStd::File(infile.open(O_RDONLY)?)
infile.as_fd()
}; };
let raw_out = if let Some(outfile) = outfile { let output = if let Some(outfile) = outfile {
if outfile == "-" { if outfile == "-" {
super let stdout = stdout(); FileOrStd::StdOut
stdout.as_fd()
} else { } else {
super let outfile = Utf8CStr::from_string(outfile).create(O_WRONLY | O_TRUNC, 0o644)?; FileOrStd::File(outfile.create(O_WRONLY | O_TRUNC, 0o644)?)
outfile.as_fd()
} }
} else if in_std { } else if in_std {
super let stdout = stdout(); FileOrStd::StdOut
stdout.as_fd()
} else { } else {
let mut outfile = format!("{infile}.{}", method.ext()); let mut outfile = Utf8CString::default();
outfile.write_str(infile).ok();
outfile.write_char('.').ok();
outfile.write_str(method.ext()).ok();
eprintln!("Compressing to [{outfile}]"); eprintln!("Compressing to [{outfile}]");
rm_in = true; rm_in = true;
super let outfile = Utf8CStr::from_string(&mut outfile).create(O_WRONLY | O_TRUNC, 0o644)?; let outfile = outfile.create(O_WRONLY | O_TRUNC, 0o644)?;
outfile.as_fd() FileOrStd::File(outfile)
}; };
compress_fd(method, raw_in.as_raw_fd(), raw_out.as_raw_fd()); let mut encoder = get_encoder(method, output.as_file());
std::io::copy(&mut input.as_file(), encoder.as_mut())?;
encoder.finish()?;
if rm_in { if rm_in {
Utf8CStr::from_string(infile).remove()?; infile.remove()?;
} }
Ok(()) Ok(())
} }

View File

@@ -49,12 +49,12 @@ impl FileFormat {
impl FileFormat { impl FileFormat {
pub fn ext(&self) -> &'static str { pub fn ext(&self) -> &'static str {
match *self { match *self {
Self::GZIP | Self::ZOPFLI => ".gz", Self::GZIP | Self::ZOPFLI => "gz",
Self::LZOP => ".lzo", Self::LZOP => "lzo",
Self::XZ => ".xz", Self::XZ => "xz",
Self::LZMA => ".lzma", Self::LZMA => "lzma",
Self::BZIP2 => ".bz2", Self::BZIP2 => "bz2",
Self::LZ4 | Self::LZ4_LEGACY | Self::LZ4_LG => ".lz4", Self::LZ4 | Self::LZ4_LEGACY | Self::LZ4_LG => "lz4",
_ => "", _ => "",
} }
} }

View File

@@ -2,7 +2,6 @@
#![feature(btree_extract_if)] #![feature(btree_extract_if)]
#![feature(iter_intersperse)] #![feature(iter_intersperse)]
#![feature(try_blocks)] #![feature(try_blocks)]
#![feature(super_let)]
pub use base; pub use base;
use compress::{compress_bytes, decompress_bytes}; use compress::{compress_bytes, decompress_bytes};