use async_trait::async_trait;
use base64::engine::{
Engine,
general_purpose::URL_SAFE_NO_PAD as url_safe_base_64,
};
use core::mem::{MaybeUninit, transmute};
use core::pin::Pin;
use core::task::Poll;
use flate2::{Decompress, FlushDecompress};
use pin_project_lite::pin_project;
use std::path::{Path, PathBuf};
use tokio::fs::File;
use tokio::io::{AsyncRead, ReadBuf};
use crate::error::Error;
/// Source of files that can be opened for reading with a deferred hash check.
#[async_trait]
pub trait FileSystem {
    /// Reader type returned by [`FileSystem::open_hashed_file`].
    type HashedFileIn: HashedFileIn;

    /// Opens the file identified by `path` for hash-checked reading.
    async fn open_hashed_file(
        &self,
        path: impl Into<String> + Send,
    ) -> Result<Self::HashedFileIn, Error>;

    /// Opens the file identified by `path` and wraps it in a zlib decoder.
    ///
    /// Reads yield the decompressed bytes; `verify` on the result is forwarded
    /// to the underlying hashed file.
    async fn open_compressed_hashed_file(
        &self,
        path: impl Into<String> + Send,
    ) -> Result<CompressedHashedFileIn<Self::HashedFileIn>, Error> {
        self.open_hashed_file(path)
            .await
            .map(CompressedHashedFileIn::new)
    }
}
/// Asynchronous reader whose contents can be checked for integrity once reading is done.
#[async_trait]
pub trait HashedFileIn: AsyncRead + Send + Unpin {
    /// Consumes the reader and performs its integrity check.
    ///
    /// # Errors
    ///
    /// Returns an error when the check fails.
    async fn verify(self) -> Result<(), Error>;
}
pin_project! {
    /// Reader yielding the zlib-decompressed contents of an inner reader `R`.
    ///
    /// When `R` is a [`HashedFileIn`], so is this type; verification is
    /// delegated to the inner reader.
    pub struct CompressedHashedFileIn<R>
    where
        R: AsyncRead,
    {
        // Decoder wrapping the inner (compressed) reader.
        #[pin]
        decoder: AsyncZlibDecoder<R>,
    }
}
impl<R> CompressedHashedFileIn<R>
where
    R: AsyncRead,
{
    /// Wraps `r`, whose contents are decoded as a zlib stream.
    pub fn new(r: R) -> Self {
        let decoder = AsyncZlibDecoder::new(r);
        Self { decoder }
    }
}
impl<R> AsyncRead for CompressedHashedFileIn<R>
where
    R: AsyncRead,
{
    /// Forwards the read to the pinned zlib decoder.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let projection = self.project();
        projection.decoder.poll_read(cx, buf)
    }
}
#[async_trait]
impl<R> HashedFileIn for CompressedHashedFileIn<R>
where
R: HashedFileIn,
{
async fn verify(self) -> Result<(), Error> {
self.decoder.into_inner().verify().await
}
}
/// [`FileSystem`] backed by a directory on the local disk.
pub struct LocalFileSystem {
    // Directory onto which every path passed to `open_hashed_file` is joined.
    base_path: PathBuf,
}

impl LocalFileSystem {
    /// Creates a file system rooted at `base_path`.
    pub fn new(base_path: impl AsRef<Path>) -> Self {
        let root: &Path = base_path.as_ref();
        Self {
            base_path: root.to_owned(),
        }
    }
}
#[async_trait]
impl FileSystem for LocalFileSystem {
    type HashedFileIn = LocalHashedFileIn;

    /// Opens `path` joined onto the base directory.
    ///
    /// NOTE(review): `Path::join` replaces the base when `path` is absolute,
    /// and `..` components are not rejected, so `path` can reach outside the
    /// base directory. Confirm that callers only pass trusted names.
    async fn open_hashed_file(
        &self,
        path: impl Into<String> + Send,
    ) -> Result<Self::HashedFileIn, Error> {
        let relative: String = path.into();
        let full_path = self.base_path.join(relative);
        LocalHashedFileIn::open(full_path).await
    }
}
pin_project! {
    /// Local file reader that feeds every byte it returns into a SHA-256 digest,
    /// so `verify` can compare the result with the hash taken from the file name.
    // The type is a reader, not a future: the lint message describes what is
    // actually lost when the value is dropped unused.
    #[must_use = "the file must be read and then verified for the hash check to take place"]
    pub struct LocalHashedFileIn {
        #[pin]
        file: File,
        // Expected hash: the file stem. It is compared against the URL-safe,
        // unpadded base64 encoding of the SHA-256 digest.
        hash: String,
        // Running SHA-256 over all bytes returned by `poll_read`.
        digest: ring::digest::Context,
    }
}
impl LocalHashedFileIn {
async fn open(path: PathBuf) -> Result<Self, Error> {
let hash = path.file_stem()
.ok_or(Error::InvalidArgs(format!(
"file name must be hash: {}",
path.display(),
)))?
.to_string_lossy() .to_string();
let file = File::open(&path).await?;
Ok(Self {
file,
hash,
digest: ring::digest::Context::new(&ring::digest::SHA256),
})
}
}
#[async_trait]
impl HashedFileIn for LocalHashedFileIn {
async fn verify(self) -> Result<(), Error> {
let digest = self.digest.finish();
let hash = url_safe_base_64.encode(digest);
if self.hash == hash {
Ok(())
} else {
Err(Error::VerificationFailure(format!(
"hash discrepancy: expected {} but got {}",
self.hash,
hash,
)))
}
}
}
impl AsyncRead for LocalHashedFileIn {
    /// Reads from the underlying file and feeds the newly filled bytes into
    /// the digest. Pending and error results are passed through untouched.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let projected = self.project();
        let start = buf.filled().len();
        let polled = projected.file.poll_read(cx, buf);
        if matches!(polled, Poll::Ready(Ok(()))) {
            let fresh = &buf.filled()[start..];
            if !fresh.is_empty() {
                projected.digest.update(fresh);
            }
        }
        polled
    }
}
const INPUT_BUFFER_SIZE: usize = 1024;
pin_project! {
pub struct AsyncZlibDecoder<R> {
#[pin]
reader: R,
reader_finished: bool,
decoder: Decompress,
decoder_finished: bool,
input_buf: [MaybeUninit<u8>; INPUT_BUFFER_SIZE],
input_pos: usize,
}
}
impl<R> AsyncZlibDecoder<R> {
pub fn new(reader: R) -> Self {
Self {
reader,
reader_finished: false,
decoder: Decompress::new(true),
decoder_finished: false,
input_buf: unsafe { MaybeUninit::uninit().assume_init() },
input_pos: 0,
}
}
pub fn into_inner(self) -> R {
assert!(self.decoder_finished);
self.reader
}
}
impl<R> AsyncRead for AsyncZlibDecoder<R>
where
R: AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut core::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
macro_rules! assume_advance {
($buf:ident, $amount:expr) => {
unsafe { $buf.assume_init($buf.filled().len() + $amount); }
buf.advance($amount);
};
}
let mut this = self.project();
let initial_len = buf.filled().len();
let mut input_buf = ReadBuf::uninit(this.input_buf);
unsafe { input_buf.assume_init(*this.input_pos); }
input_buf.set_filled(*this.input_pos);
let mut had_buf_error = false;
loop {
if *this.input_pos < input_buf.filled().len()
&& buf.remaining() > 0
{
let last_total_in = this.decoder.total_in();
let last_total_out = this.decoder.total_out();
let input = &input_buf.filled()[*this.input_pos..];
match this.decoder.decompress(
input,
unsafe { transmute(buf.unfilled_mut()) },
if *this.reader_finished {
FlushDecompress::Finish
} else {
FlushDecompress::None
},
) {
Ok(flate2::Status::Ok) => {
let num_written =
this.decoder.total_out() - last_total_out;
assume_advance!(buf, num_written as usize);
let num_read = this.decoder.total_in() - last_total_in;
*this.input_pos += num_read as usize;
if *this.input_pos == input_buf.filled().len() {
input_buf.clear();
*this.input_pos = 0;
}
had_buf_error = false;
},
Ok(flate2::Status::BufError) => {
if had_buf_error {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
Error::InvalidContext(format!(
"got persisted decoder buffer error",
)),
)));
}
had_buf_error = true;
},
Ok(flate2::Status::StreamEnd) => {
*this.decoder_finished = true;
let num_written =
this.decoder.total_out() - last_total_out;
assume_advance!(buf, num_written as usize);
let num_read = this.decoder.total_in() - last_total_in;
*this.input_pos += num_read as usize;
if *this.input_pos == input_buf.filled().len() {
input_buf.clear();
*this.input_pos = 0;
} else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
Error::InvalidData(format!(
"extra bytes after compressed block",
)),
)));
}
had_buf_error = false;
},
Err(err) => {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
err,
)));
},
};
}
if !*this.reader_finished && input_buf.remaining() > 0 {
let last_len = input_buf.filled().len();
match this.reader
.as_mut()
.poll_read(cx, &mut input_buf)
{
Poll::Ready(Ok(_)) => {
if input_buf.filled().len() == last_len {
*this.reader_finished = true;
} else if *this.decoder_finished {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
Error::InvalidData(format!(
"extra bytes after compressed block",
)),
)));
}
},
Poll::Pending => {
if buf.filled().len() > initial_len {
return Poll::Ready(Ok(()));
} else {
return Poll::Pending;
}
},
Poll::Ready(err) => return Poll::Ready(err),
}
}
if *this.decoder_finished && *this.reader_finished {
return Poll::Ready(Ok(()));
}
}
}
}