diff --git a/native/src/base/files.rs b/native/src/base/files.rs index 7a98f167b..77431e4d0 100644 --- a/native/src/base/files.rs +++ b/native/src/base/files.rs @@ -116,6 +116,26 @@ impl WriteExt for T { } } +pub enum FileOrStd { + StdIn, + StdOut, + StdErr, + File(File), +} + +impl FileOrStd { + pub fn as_file(&self) -> &File { + let raw_fd_ref: &'static RawFd = match self { + FileOrStd::StdIn => &0, + FileOrStd::StdOut => &1, + FileOrStd::StdErr => &2, + FileOrStd::File(file) => return file, + }; + // SAFETY: File is guaranteed to have the same ABI as RawFd + unsafe { mem::transmute(raw_fd_ref) } + } +} + fn open_fd(path: &Utf8CStr, flags: i32, mode: mode_t) -> OsResult { unsafe { let fd = libc::open(path.as_ptr(), flags, mode as c_uint).as_os_result( diff --git a/native/src/boot/cli.rs b/native/src/boot/cli.rs index 71cecf33b..cab1a2612 100644 --- a/native/src/boot/cli.rs +++ b/native/src/boot/cli.rs @@ -1,4 +1,4 @@ -use crate::compress::{compress, decompress}; +use crate::compress::{compress_cmd, decompress_cmd}; use crate::cpio::{cpio_commands, print_cpio_usage}; use crate::dtb::{DtbAction, dtb_commands, print_dtb_usage}; use crate::ffi::{BootImage, FileFormat, cleanup, repack, split_image_dtb, unpack}; @@ -414,14 +414,14 @@ fn boot_main(cmds: CmdArgs) -> LoggedResult { cleanup(); } Action::Decompress(Decompress { mut file, mut out }) => { - decompress(&mut file, out.as_mut())?; + decompress_cmd(&mut file, out.as_mut())?; } Action::Compress(Compress { ref mut file, ref format, ref mut out, }) => { - compress( + compress_cmd( FileFormat::from_str(format).unwrap_or(FileFormat::UNKNOWN), file, out.as_mut(), diff --git a/native/src/boot/compress.rs b/native/src/boot/compress.rs index e0ad1b5fc..f91bf153a 100644 --- a/native/src/boot/compress.rs +++ b/native/src/boot/compress.rs @@ -1,6 +1,6 @@ use crate::ffi::{FileFormat, check_fmt}; use base::libc::{O_RDONLY, O_TRUNC, O_WRONLY}; -use base::{Chunker, LoggedResult, Utf8CStr, WriteExt, error, log_err}; +use base::{Chunker, FileOrStd, LoggedResult, Utf8CStr, Utf8CString, WriteExt, error, log_err}; use bytemuck::bytes_of_mut; use bzip2::{Compression as BzCompression, write::BzDecoder, write::BzEncoder}; use flate2::{Compression as GzCompression, write::GzEncoder, write::MultiGzDecoder}; @@ -9,12 +9,13 @@ use lz4::{ EncoderBuilder as LZ4FrameEncoderBuilder, block::CompressionMode, liblz4::BlockChecksum, }; use std::cell::Cell; +use std::fmt::Write as FmtWrite; use std::fs::File; -use std::io::{BufWriter, Read, Write, stdin, stdout}; +use std::io::{BufWriter, Read, Write}; use std::mem::ManuallyDrop; use std::num::NonZeroU64; use std::ops::DerefMut; -use std::os::fd::{AsFd, AsRawFd, FromRawFd, RawFd}; +use std::os::fd::{FromRawFd, RawFd}; use xz2::{ stream::{Check as LzmaCheck, Filters as LzmaFilters, LzmaOptions, Stream as LzmaStream}, write::{XzDecoder, XzEncoder}, @@ -383,29 +384,6 @@ pub fn get_decoder<'a, W: Write + 'a>(format: FileFormat, w: W) -> Box = try { - std::io::copy(in_file.deref_mut(), encoder.as_mut())?; - encoder.finish()?; - }; -} - -pub fn decompress_bytes_fd(format: FileFormat, in_bytes: &[u8], in_fd: RawFd, out_fd: RawFd) { - let mut in_file = unsafe { ManuallyDrop::new(File::from_raw_fd(in_fd)) }; - let mut out_file = unsafe { ManuallyDrop::new(File::from_raw_fd(out_fd)) }; - - let mut decoder = get_decoder(format, out_file.deref_mut()); - let _: LoggedResult<()> = try { - decoder.write_all(in_bytes)?; - std::io::copy(in_file.deref_mut(), decoder.as_mut())?; - decoder.finish()?; - }; -} - pub fn compress_bytes(format: FileFormat, in_bytes: &[u8], out_fd: RawFd) { let mut out_file = unsafe { ManuallyDrop::new(File::from_raw_fd(out_fd)) }; @@ -426,22 +404,31 @@ pub fn decompress_bytes(format: FileFormat, in_bytes: &[u8], out_fd: RawFd) { }; } -pub(crate) fn decompress(infile: &mut String, outfile: Option<&mut String>) -> LoggedResult<()> { +// Command-line entry points + +pub(crate) fn decompress_cmd( + infile: &mut String, + outfile: Option<&mut String>, +) -> LoggedResult<()> { + let infile = Utf8CStr::from_string(infile); + let outfile = outfile.map(Utf8CStr::from_string); + let in_std = infile == "-"; let mut rm_in = false; let mut buf = [0u8; 4096]; - let raw_in = if in_std { - super let mut stdin = stdin(); - let _ = stdin.read(&mut buf)?; - stdin.as_fd() + + let input = if in_std { + FileOrStd::StdIn } else { - super let mut infile = Utf8CStr::from_string(infile).open(O_RDONLY)?; - let _ = infile.read(&mut buf)?; - infile.as_fd() + FileOrStd::File(infile.open(O_RDONLY)?) }; - let format = check_fmt(&buf); + // First read some bytes for format detection + let len = input.as_file().read(&mut buf)?; + let buf = &buf[..len]; + + let format = check_fmt(buf); eprintln!("Detected format: {format}"); @@ -449,44 +436,46 @@ pub(crate) fn decompress(infile: &mut String, outfile: Option<&mut String>) -> L return log_err!("Input file is not a supported type!"); } - let raw_out = if let Some(outfile) = outfile { + // If user did not provide outfile, infile has to be either + // .[ext], or "-". Outfile will be either or "-". + // If the input does not have proper format, abort. + + let output = if let Some(outfile) = outfile { if outfile == "-" { - super let stdout = stdout(); - stdout.as_fd() + FileOrStd::StdOut } else { - super let outfile = Utf8CStr::from_string(outfile).create(O_WRONLY | O_TRUNC, 0o644)?; - outfile.as_fd() + FileOrStd::File(outfile.create(O_WRONLY | O_TRUNC, 0o644)?) } } else if in_std { - super let stdout = stdout(); - stdout.as_fd() + FileOrStd::StdOut } else { - // strip the extension - rm_in = true; - let mut outfile = if let Some((outfile, ext)) = infile.rsplit_once('.') { - if ext != format.ext() { - log_err!("Input file is not a supported type!")?; - } - outfile.to_owned() + // Strip out extension and remove input + let outfile = if let Some((outfile, ext)) = infile.rsplit_once('.') + && ext == format.ext() + { + Utf8CString::from(outfile) } else { - infile.clone() + return log_err!("Input file is not a supported type!"); }; - eprintln!("Decompressing to [{outfile}]"); - super let outfile = Utf8CStr::from_string(&mut outfile).create(O_WRONLY | O_TRUNC, 0o644)?; - outfile.as_fd() + rm_in = true; + eprintln!("Decompressing to [{outfile}]"); + FileOrStd::File(outfile.create(O_WRONLY | O_TRUNC, 0o644)?) }; - decompress_bytes_fd(format, &buf, raw_in.as_raw_fd(), raw_out.as_raw_fd()); + let mut decoder = get_decoder(format, output.as_file()); + decoder.write_all(buf)?; + std::io::copy(&mut input.as_file(), decoder.as_mut())?; + decoder.finish()?; if rm_in { - Utf8CStr::from_string(infile).remove()?; + infile.remove()?; } Ok(()) } -pub(crate) fn compress( +pub(crate) fn compress_cmd( method: FileFormat, infile: &mut String, outfile: Option<&mut String>, @@ -495,40 +484,43 @@ pub(crate) fn compress( error!("Unsupported compression format"); } + let infile = Utf8CStr::from_string(infile); + let outfile = outfile.map(Utf8CStr::from_string); + let in_std = infile == "-"; let mut rm_in = false; - let raw_in = if in_std { - super let stdin = stdin(); - stdin.as_fd() + let input = if in_std { + FileOrStd::StdIn } else { - super let infile = Utf8CStr::from_string(infile).open(O_RDONLY)?; - infile.as_fd() + FileOrStd::File(infile.open(O_RDONLY)?) }; - let raw_out = if let Some(outfile) = outfile { + let output = if let Some(outfile) = outfile { if outfile == "-" { - super let stdout = stdout(); - stdout.as_fd() + FileOrStd::StdOut } else { - super let outfile = Utf8CStr::from_string(outfile).create(O_WRONLY | O_TRUNC, 0o644)?; - outfile.as_fd() + FileOrStd::File(outfile.create(O_WRONLY | O_TRUNC, 0o644)?) } } else if in_std { - super let stdout = stdout(); - stdout.as_fd() + FileOrStd::StdOut } else { - let mut outfile = format!("{infile}.{}", method.ext()); + let mut outfile = Utf8CString::default(); + outfile.write_str(infile).ok(); + outfile.write_char('.').ok(); + outfile.write_str(method.ext()).ok(); eprintln!("Compressing to [{outfile}]"); rm_in = true; - super let outfile = Utf8CStr::from_string(&mut outfile).create(O_WRONLY | O_TRUNC, 0o644)?; - outfile.as_fd() + let outfile = outfile.create(O_WRONLY | O_TRUNC, 0o644)?; + FileOrStd::File(outfile) }; - compress_fd(method, raw_in.as_raw_fd(), raw_out.as_raw_fd()); + let mut encoder = get_encoder(method, output.as_file()); + std::io::copy(&mut input.as_file(), encoder.as_mut())?; + encoder.finish()?; if rm_in { - Utf8CStr::from_string(infile).remove()?; + infile.remove()?; } Ok(()) } diff --git a/native/src/boot/format.rs b/native/src/boot/format.rs index ea8eba03a..87bf9b7a3 100644 --- a/native/src/boot/format.rs +++ b/native/src/boot/format.rs @@ -49,12 +49,12 @@ impl FileFormat { impl FileFormat { pub fn ext(&self) -> &'static str { match *self { - Self::GZIP | Self::ZOPFLI => ".gz", - Self::LZOP => ".lzo", - Self::XZ => ".xz", - Self::LZMA => ".lzma", - Self::BZIP2 => ".bz2", - Self::LZ4 | Self::LZ4_LEGACY | Self::LZ4_LG => ".lz4", + Self::GZIP | Self::ZOPFLI => "gz", + Self::LZOP => "lzo", + Self::XZ => "xz", + Self::LZMA => "lzma", + Self::BZIP2 => "bz2", + Self::LZ4 | Self::LZ4_LEGACY | Self::LZ4_LG => "lz4", _ => "", } } diff --git a/native/src/boot/lib.rs b/native/src/boot/lib.rs index 6f1d72667..fb98f7d92 100644 --- a/native/src/boot/lib.rs +++ b/native/src/boot/lib.rs @@ -2,7 +2,6 @@ #![feature(btree_extract_if)] #![feature(iter_intersperse)] #![feature(try_blocks)] -#![feature(super_let)] pub use base; use compress::{compress_bytes, decompress_bytes};