use super::MapError;
use crate::fs::VFile;
use crate::imgact::orbis::{Elf, ProgramFlags, ProgramType};
use crate::process::VProc;
use crate::vm::{MappingFlags, MemoryUpdateError, Protections};
use gmtx::{Gutex, GutexGroup, GutexWriteGuard};
use std::alloc::Layout;
use std::fmt::{Debug, Formatter};
use std::marker::PhantomData;
use std::sync::Arc;
use thiserror::Error;

/// A memory of the loaded module.
///
/// The memory is a single mapping holding one segment per loadable (S)ELF program, followed by two
/// custom segments: a workspace for our own code and a workspace for our own data.
pub struct Memory {
    // This will be removed from here soon once we start working on multi-processes.
    proc: Arc<VProc>,
    /// Start of the whole mapping.
    ptr: *mut u8,
    /// Length of the whole mapping, including the custom segments.
    len: usize,
    /// All segments ordered by offset; the custom segments always come last.
    segments: Vec<MemorySegment>,
    /// Offset of the module base within the mapping.
    base: usize,
    /// Index into `segments` of the text (executable `PT_LOAD`) segment.
    text: usize,
    /// Index into `segments` of the data (non-executable `PT_LOAD`) segment.
    data: usize,
    /// Index into `segments` of the `PT_SCE_RELRO` segment, if the image has one.
    relro: Option<usize>,
    /// Index into `segments` of our code workspace.
    obcode: usize,
    /// Number of bytes at the start of the code workspace that have been sealed.
    obcode_sealed: Gutex<usize>,
    /// Index into `segments` of our data workspace.
    obdata: usize,
    /// Number of bytes at the start of the data workspace that are already in use.
    obdata_sealed: Gutex<usize>,
    /// Destructors of the values placed by `push_data`; run in reverse order when dropped.
    destructors: Gutex<Vec<Box<dyn FnOnce()>>>,
}

impl Memory {
    /// Creates the memory for a module described by `image`.
    ///
    /// Each `PT_LOAD` and `PT_SCE_RELRO` program with a non-zero memory size gets a segment that
    /// starts at offset `base + prog.addr()` within the mapping. Two custom segments are appended
    /// after the last program: a 4 MiB read/execute workspace for our own code followed by a
    /// 1 MiB read/write workspace for our own data. The whole range is mapped as anonymous,
    /// private memory named `name` with no access, then each segment is given the protection
    /// derived from its program flags. No program content is copied here.
    ///
    /// # Errors
    /// Fails if there is more than one executable `PT_LOAD`, more than one non-executable
    /// `PT_LOAD` or more than one `PT_SCE_RELRO` program, if a segment start is not 16 KiB
    /// aligned, or if mapping or protecting the memory fails.
    ///
    /// # Panics
    /// Not implemented yet for an image without an executable or a non-executable `PT_LOAD`
    /// program, or with overlapping programs.
    pub(super) fn new<N: Into<String>>(
        proc: &Arc<VProc>,
        image: &Elf<VFile>,
        base: usize,
        name: N,
    ) -> Result<Self, MapError> {
        // It seems like the PS4 expects only one program for each of text, data and relro. The
        // extra 2 are for our own code and data workspaces.
        let mut segments: Vec<MemorySegment> = Vec::with_capacity(3 + 2);

        // These hold *program* indices. They are translated into segment indices once the
        // segments have been sorted by address; recording segment indices here would leave them
        // stale whenever the program headers are not already in address order.
        let mut text: Option<usize> = None;
        let mut relro: Option<usize> = None;
        let mut data: Option<usize> = None;

        for (i, prog) in image.programs().iter().enumerate() {
            // Skip if memory size is zero.
            if prog.memory_size() == 0 {
                continue;
            }

            // Check type.
            match prog.ty() {
                ProgramType::PT_LOAD => {
                    if prog.flags().contains(ProgramFlags::EXECUTE) {
                        if text.is_some() {
                            return Err(MapError::MultipleExecProgram);
                        }
                        text = Some(i);
                    } else if data.is_some() {
                        return Err(MapError::MultipleDataProgram);
                    } else {
                        data = Some(i);
                    }
                }
                ProgramType::PT_SCE_RELRO => {
                    if relro.replace(i).is_some() {
                        return Err(MapError::MultipleRelroProgram);
                    }
                }
                _ => continue,
            }

            // Get offset and length.
            let start = base + prog.addr();
            let len = prog.aligned_size();

            // Segment starts must be 16 KiB aligned.
            if start & 0x3fff != 0 {
                return Err(MapError::InvalidProgramAlignment(i));
            }

            // Get protection.
            let flags = prog.flags();
            let mut prot = Protections::empty();

            if flags.contains(ProgramFlags::EXECUTE) {
                prot |= Protections::CPU_EXEC;
            }

            if flags.contains(ProgramFlags::READ) {
                prot |= Protections::CPU_READ;
            }

            if flags.contains(ProgramFlags::WRITE) {
                prot |= Protections::CPU_WRITE;
            }

            // Construct the segment info.
            segments.push(MemorySegment {
                start,
                len,
                program: Some(i),
                prot,
            });
        }

        let text = text.unwrap_or_else(|| todo!("(S)ELF with no executable program"));
        let data = data.unwrap_or_else(|| todo!("(S)ELF with no data program"));

        // Make sure no segment overlaps another one. Afterwards `len` is the end of the last
        // program segment.
        let mut len = base;

        segments.sort_unstable_by_key(|s| s.start);

        for s in &segments {
            if s.start < len {
                // We need to check the PS4 kernel to see how it is handled this case.
                todo!("(S)ELF with overlapped programs");
            }

            len = s.start + s.len;
        }

        // Translate the program indices into indices of the now sorted segments.
        let segment_of = |prog: usize| {
            segments
                .iter()
                .position(|s| s.program == Some(prog))
                .expect("every recorded program must have a segment")
        };

        let text = segment_of(text);
        let data = segment_of(data);
        let relro = relro.map(segment_of);

        // Create workspace for our code.
        let obcode = segments.len();
        let segment = MemorySegment {
            start: len,
            len: (1024 * 1024) * 4,
            program: None,
            prot: Protections::CPU_READ | Protections::CPU_EXEC,
        };

        len += segment.len;
        segments.push(segment);

        // Create workspace for our data. We cannot mix this with the code because the executable-space
        // protection on some system don't allow execution on writable page.
        let obdata = segments.len();
        let segment = MemorySegment {
            start: len,
            len: 1024 * 1024,
            program: None,
            prot: Protections::CPU_READ | Protections::CPU_WRITE,
        };

        len += segment.len;
        segments.push(segment);

        // TODO: Use separate name for our code and data.
        let mut pages = match proc.vm_space().mmap(
            0,
            len,
            Protections::empty(),
            name,
            MappingFlags::MAP_ANON | MappingFlags::MAP_PRIVATE,
            -1,
            0,
        ) {
            Ok(v) => v,
            Err(e) => return Err(MapError::MemoryAllocationFailed(len, e)),
        };

        // Apply memory protection. Anything not covered by a segment keeps no access.
        for seg in &segments {
            let addr = unsafe { pages.as_mut_ptr().add(seg.start) };
            let len = seg.len;
            let prot = seg.prot;

            if let Err(e) = proc.vm_space().mprotect(addr, len, prot) {
                return Err(MapError::ProtectMemoryFailed(addr as _, len, prot, e));
            }
        }

        let gg = GutexGroup::new();

        Ok(Self {
            proc: proc.clone(),
            ptr: pages.into_raw(),
            len,
            segments,
            base,
            text,
            data,
            relro,
            obcode,
            obcode_sealed: gg.spawn(0),
            obdata,
            obdata_sealed: gg.spawn(0),
            destructors: gg.spawn(Vec::new()),
        })
    }

    /// Gets the address where the whole module memory starts.
    pub fn addr(&self) -> usize {
        self.ptr as _
    }

    /// Gets the length of the whole module memory, including our custom workspaces.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Gets all segments ordered by offset; the custom workspaces come last.
    pub fn segments(&self) -> &[MemorySegment] {
        self.segments.as_ref()
    }

    /// Gets the offset of the module base within the module memory.
    pub fn base(&self) -> usize {
        self.base
    }

    /// Gets the segment of the executable `PT_LOAD` program.
    pub fn text_segment(&self) -> &MemorySegment {
        &self.segments[self.text]
    }

    /// Gets the segment of the non-executable `PT_LOAD` program.
    pub fn data_segment(&self) -> &MemorySegment {
        &self.segments[self.data]
    }

    /// Gets the segment of the `PT_SCE_RELRO` program, if the image has one.
    pub fn relro_segment(&self) -> Option<&MemorySegment> {
        self.relro.as_ref().map(|i| &self.segments[*i])
    }

    /// Views the whole module memory as bytes.
    ///
    /// # Safety
    /// Some part of the returned slice may not readable.
    pub unsafe fn as_bytes(&self) -> &[u8] {
        std::slice::from_raw_parts(self.ptr, self.len)
    }

    /// Gets exclusive access to the unsealed part of the code workspace.
    ///
    /// The workspace segment is read/write (and therefore not executable) until the returned
    /// [`CodeWorkspace`] is sealed or dropped. Beware of deadlock because this method holds the
    /// lock on the sealed size until then.
    ///
    /// # Safety
    /// No other threads may execute the memory in the segment until the returned [`CodeWorkspace`]
    /// has been dropped.
    ///
    /// # Errors
    /// Fails if the workspace segment cannot be unprotected.
    pub unsafe fn code_workspace(&self) -> Result<CodeWorkspace<'_>, CodeWorkspaceError> {
        let sealed = self.obcode_sealed.write();
        let seg = match self.unprotect_segment(self.obcode) {
            Ok(v) => v,
            Err(e) => {
                return Err(CodeWorkspaceError::UnprotectSegmentFailed(self.obcode, e));
            }
        };

        Ok(CodeWorkspace {
            ptr: unsafe { seg.ptr.add(*sealed) },
            len: seg.len - *sealed,
            seg,
            sealed,
        })
    }

    /// Moves `value` into the data workspace and returns a pointer to it.
    ///
    /// The value is placed at the first suitably aligned address after the data already pushed.
    /// It is dropped when this [`Memory`] is dropped, in reverse order of pushing.
    ///
    /// Returns [`None`] if the remaining space in the data workspace cannot fit `value`, in which
    /// case `value` is dropped immediately.
    pub fn push_data<T: 'static>(&self, value: T) -> Option<*mut T> {
        let mut sealed = self.obdata_sealed.write();
        let seg = &self.segments[self.obdata];
        let ptr = unsafe { self.ptr.add(seg.start + *sealed) };
        let available = seg.len - *sealed;

        // Check if the remaining space is enough, including the padding needed for alignment.
        let layout = Layout::new::<T>();
        let offset = match (ptr as usize) % layout.align() {
            0 => 0,
            v => layout.align() - v,
        };

        if offset + layout.size() > available {
            return None;
        }

        // Move value to the workspace.
        let ptr = unsafe { ptr.add(offset) } as *mut T;

        unsafe { std::ptr::write(ptr, value) };

        self.destructors
            .write()
            .push(Box::new(move || unsafe { std::ptr::drop_in_place(ptr) }));

        // Seal the memory.
        *sealed += offset + layout.size();

        Some(ptr)
    }

    /// Makes segment `seg` read/write until the returned [`UnprotectedSegment`] is dropped, at
    /// which point the segment's own protection is restored.
    ///
    /// # Safety
    /// No other threads may access the memory in the segment until the returned
    /// [`UnprotectedSegment`] has been dropped.
    ///
    /// # Errors
    /// Fails if the protection of the segment cannot be changed.
    ///
    /// # Panics
    /// `seg` is not a valid segment.
    pub unsafe fn unprotect_segment(
        &self,
        seg: usize,
    ) -> Result<UnprotectedSegment<'_>, UnprotectSegmentError> {
        let seg = &self.segments[seg];
        let ptr = self.ptr.add(seg.start);
        let len = seg.len;
        let prot = Protections::CPU_READ | Protections::CPU_WRITE;

        if let Err(e) = self.proc.vm_space().mprotect(ptr, len, prot) {
            return Err(UnprotectSegmentError::MprotectFailed(
                ptr as _, len, prot, e,
            ));
        }

        Ok(UnprotectedSegment {
            proc: &self.proc,
            ptr,
            len,
            prot: seg.prot,
            phantom: PhantomData,
        })
    }

    /// Makes everything from the start of the module memory up to the end of the last program
    /// segment read/write. Our custom workspaces are not included.
    ///
    /// # Safety
    /// No other threads may access the memory until the returned [`UnprotectedMemory`] has been
    /// dropped.
    ///
    /// # Errors
    /// Fails if the protection of the memory cannot be changed.
    pub unsafe fn unprotect(&self) -> Result<UnprotectedMemory<'_>, UnprotectError> {
        // Get the end offset of non-custom segments. This relies on the custom segments always
        // being placed after all program segments.
        let mut end = 0;

        for s in &self.segments {
            // Check if segment is a custom segment.
            if s.program().is_none() {
                break;
            }

            // Update end offset.
            end = s.end();
        }

        // Unprotect the memory.
        let prot = Protections::CPU_READ | Protections::CPU_WRITE;

        if let Err(e) = self.proc.vm_space().mprotect(self.ptr, end, prot) {
            return Err(UnprotectError::MprotectFailed(self.ptr as _, end, prot, e));
        }

        Ok(UnprotectedMemory {
            proc: &self.proc,
            ptr: self.ptr,
            len: end,
            segments: &self.segments,
        })
    }
}

impl Drop for Memory {
    /// Destroys every value placed with `push_data` (newest first), then releases the mapping.
    fn drop(&mut self) {
        let pending = self.destructors.get_mut();

        // Popping from the back runs the destructors in reverse order of registration.
        while let Some(destructor) = pending.pop() {
            destructor();
        }

        // Release the whole mapping, including our custom workspaces.
        self.proc
            .vm_space()
            .munmap(self.ptr, self.len)
            .unwrap();
    }
}

impl Debug for Memory {
    /// Formats the layout of the memory. `proc` is intentionally left out, and only the number
    /// of pending destructors is shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Memory")
            .field("ptr", &self.ptr)
            .field("len", &self.len)
            .field("segments", &self.segments)
            .field("base", &self.base)
            .field("text", &self.text)
            .field("data", &self.data)
            .field("relro", &self.relro)
            .field("obcode", &self.obcode)
            .field("obcode_sealed", &self.obcode_sealed)
            .field("obdata", &self.obdata)
            .field("obdata_sealed", &self.obdata_sealed)
            .field("destructors", &self.destructors.read().len())
            .finish()
    }
}

// SAFETY: `Memory` holds a raw pointer to its mapping and boxed destructors that are neither
// `Send` nor `Sync`, so these impls have to be asserted manually. Its mutable state is behind
// `Gutex`. NOTE(review): `push_data` accepts any `T: 'static` (not necessarily `Send`), so a
// value pushed on one thread may be dropped on another thread — confirm this is acceptable.
unsafe impl Send for Memory {}
unsafe impl Sync for Memory {}

/// A segment in the [`Memory`].
#[derive(Debug)]
pub struct MemorySegment {
    /// Offset from the start of the module memory; the module base is already included.
    start: usize,
    /// Size of the segment in bytes.
    len: usize,
    /// Index of the (S)ELF program this segment was created from, or `None` for our custom
    /// workspaces.
    program: Option<usize>,
    /// Protection the segment has while it is not unprotected.
    prot: Protections,
}

impl MemorySegment {
    /// Gets the position of the segment as an **offset (not an address)** from the start of the
    /// module memory.
    ///
    /// The module base is already part of this offset.
    pub fn start(&self) -> usize {
        self.start
    }

    /// Gets the size of the segment in bytes.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Gets the offset one past the last byte of the segment.
    pub fn end(&self) -> usize {
        let (offset, size) = (self.start(), self.len());
        offset + size
    }

    /// Gets the index of the (S)ELF program backing this segment, if there is one.
    pub fn program(&self) -> Option<usize> {
        self.program
    }

    /// Gets the protection of the segment.
    pub fn prot(&self) -> Protections {
        self.prot
    }
}

/// A memory segment in an unprotected form.
///
/// The segment's original protection is put back when this value is dropped.
pub struct UnprotectedSegment<'a> {
    proc: &'a VProc,
    /// Start address of the segment.
    ptr: *mut u8,
    /// Size of the segment in bytes.
    len: usize,
    /// Protection to restore on drop.
    prot: Protections,
    phantom: PhantomData<&'a [u8]>,
}

impl<'a> AsMut<[u8]> for UnprotectedSegment<'a> {
    /// Views the whole segment as writable bytes.
    fn as_mut(&mut self) -> &mut [u8] {
        let (start, size) = (self.ptr, self.len);

        // SAFETY: `start`/`size` describe a segment that was made read/write when this value was
        // created, and it stays that way until `self` is dropped.
        unsafe { std::slice::from_raw_parts_mut(start, size) }
    }
}

impl<'a> Drop for UnprotectedSegment<'a> {
    /// Puts back the protection the segment had before it was unprotected.
    fn drop(&mut self) {
        let (addr, size, original) = (self.ptr, self.len, self.prot);
        let vm = self.proc.vm_space();

        vm.mprotect(addr, size, original).unwrap();
    }
}

/// The unprotected form of [`Memory`], not including our custom segments.
///
/// While this value is alive, the range from the start of the module memory up to the end of the
/// last program segment is read/write.
pub struct UnprotectedMemory<'a> {
    proc: &'a VProc,
    /// Start of the module memory.
    ptr: *mut u8,
    /// Offset of the end of the last program segment.
    len: usize,
    /// All segments of the [`Memory`]; the custom segments come last.
    segments: &'a [MemorySegment],
}

impl<'a> Drop for UnprotectedMemory<'a> {
    fn drop(&mut self) {
        for s in self.segments {
            if s.program().is_none() {
                break;
            }

            let addr = unsafe { self.ptr.add(s.start()) };

            self.proc
                .vm_space()
                .mprotect(addr, s.len(), s.prot())
                .unwrap();
        }
    }
}

impl<'a> AsRef<[u8]> for UnprotectedMemory<'a> {
    /// Views the unprotected part of the module memory as bytes.
    fn as_ref(&self) -> &[u8] {
        let (start, size) = (self.ptr as *const u8, self.len);

        // SAFETY: the whole `[0, len)` range was made read/write when this value was created.
        unsafe { std::slice::from_raw_parts(start, size) }
    }
}

impl<'a> AsMut<[u8]> for UnprotectedMemory<'a> {
    /// Views the unprotected part of the module memory as writable bytes.
    fn as_mut(&mut self) -> &mut [u8] {
        let (start, size) = (self.ptr, self.len);

        // SAFETY: the whole `[0, len)` range was made read/write when this value was created.
        unsafe { std::slice::from_raw_parts_mut(start, size) }
    }
}

/// An exclusive access to the unsealed code workspace.
pub struct CodeWorkspace<'a> {
    /// Address of the first unsealed byte.
    ptr: *mut u8,
    /// Number of unsealed bytes available starting at `ptr`.
    len: usize,
    /// Keeps the workspace segment read/write while alive.
    seg: UnprotectedSegment<'a>,
    /// Lock on the number of sealed bytes. Fields drop in declaration order, so the segment is
    /// re-protected before this lock is released.
    sealed: GutexWriteGuard<'a, usize>,
}

impl<'a> CodeWorkspace<'a> {
    /// Gets the address of the first unsealed byte in the code workspace.
    pub fn addr(&self) -> usize {
        self.ptr as usize
    }

    /// Marks the first `len` bytes of the available space as used, then restores the protection
    /// of the workspace.
    ///
    /// # Panics
    /// If `len` is larger than the available space.
    pub fn seal(mut self, len: usize) {
        assert!(
            len <= self.len,
            "The amount to seal is larger than available space."
        );

        *self.sealed += len;

        // Re-protect the segment now; the lock on the sealed size is released when the rest of
        // `self` goes out of scope at the end of this function.
        let segment = self.seg;
        drop(segment);
    }
}

impl<'a> AsMut<[u8]> for CodeWorkspace<'a> {
    /// Views the unsealed part of the code workspace as writable bytes.
    fn as_mut(&mut self) -> &mut [u8] {
        let (start, size) = (self.ptr, self.len);

        // SAFETY: `start`/`size` lie inside the workspace segment, which `self.seg` keeps
        // read/write for as long as `self` is alive.
        unsafe { std::slice::from_raw_parts_mut(start, size) }
    }
}

/// Represents an error when [`Memory::code_workspace()`] fails.
#[derive(Debug, Error)]
pub enum CodeWorkspaceError {
    /// The code workspace segment (its index is the first field) could not be unprotected.
    #[error("cannot unprotect segment {0}")]
    UnprotectSegmentFailed(usize, #[source] UnprotectSegmentError),
}

/// Represents an error when [`Memory::unprotect_segment()`] fails.
#[derive(Debug, Error)]
pub enum UnprotectSegmentError {
    /// Changing the protection failed; the fields are the address, the length and the requested
    /// protection.
    #[error("cannot protect {1:#018x} bytes starting at {0:#x} with {2}")]
    MprotectFailed(usize, usize, Protections, #[source] MemoryUpdateError),
}

/// Represents an error when [`Memory::unprotect()`] fails.
#[derive(Debug, Error)]
pub enum UnprotectError {
    /// Changing the protection failed; the fields are the address, the length and the requested
    /// protection.
    #[error("cannot protect {1:#018x} bytes starting at {0:#x} with {2}")]
    MprotectFailed(usize, usize, Protections, #[source] MemoryUpdateError),
}
