//! CSRF (cross-site request forgery) protection.
//!
//! CSRF is a common attack that tricks your users into unknowingly
//! making changes to their data managed by your application.
//! CSRF protection ensures that forms submitted via POST
//! to the web app originate from forms generated by the same website.
//!
//! ### Usage
//! CSRF protection is enabled by default. To make it work, include a Rwf-generated token
//! in all forms submitted via POST:
//!
//! ```html
//! <form method="post">
//!     <%= csrf_token() %>
//! </form>
//! ```
//!
//! If used via AJAX, include the CSRF token in the `X-CSRF-Token` header.
//! You can obtain the token by calling the `csrf_token_raw` template function:
//!
//! ```html
//! <script>
//!     window.csrf_token = "<%= csrf_token_raw() %>";
//! </script>
//! ```
//!
//! ### Configuration
//! Toggle `csrf_protection` in the configuration to enable/disable CSRF protection application-wide, e.g.:
//!
//! ```toml
//! [general]
//! csrf_protection = false
//! ```
use super::prelude::*;
use crate::{crypto::csrf_token_validate, http::Method};

/// Name of the form field that carries the CSRF token in submitted forms.
pub static CSRF_INPUT: &str = "rwf_csrf_token";
/// Name of the HTTP header that carries the CSRF token (e.g. for AJAX requests).
pub static CSRF_HEADER: &str = "X-CSRF-Token";

/// CSRF protection middleware.
///
/// A stateless unit struct (it has no fields); token checking is done by its
/// `Middleware` implementation. Construct it with [`Csrf::new`] or
/// [`Default::default`] — both produce the same value.
#[derive(Debug, Default)]
pub struct Csrf;

impl Csrf {
    /// Create CSRF protection middleware.
    pub fn new() -> Self {
        Self
    }
}

#[async_trait]
impl Middleware for Csrf {
    /// Reject state-changing requests that do not carry a valid CSRF token.
    ///
    /// The request is forwarded unchanged when any of these holds:
    /// - the request opted out of CSRF checks (`skip_csrf()`);
    /// - its method is not PUT, POST, PATCH or DELETE;
    /// - the `X-CSRF-Token` header holds a token that validates against the
    ///   request's session id;
    /// - the form body has a `rwf_csrf_token` field that validates against the
    ///   request's session id.
    ///
    /// Otherwise the request is stopped with `Response::csrf_error()`.
    async fn handle_request(&self, request: Request) -> Result<Outcome, Error> {
        if request.skip_csrf() {
            return Ok(Outcome::Forward(request));
        }

        // Only methods that change state need a token. DELETE mutates data
        // just like PUT/PATCH, so it must not bypass the check.
        if !matches!(
            request.method(),
            Method::Put | Method::Post | Method::Patch | Method::Delete
        ) {
            return Ok(Outcome::Forward(request));
        }

        let header = request.header(CSRF_HEADER);
        let session_id = request.session_id().to_string();

        // Header-based token (e.g. AJAX clients).
        if let Some(header) = header {
            if csrf_token_validate(header, &session_id) {
                return Ok(Outcome::Forward(request));
            }
        }

        // Form-based token (hidden input). A body that cannot be parsed as
        // form data simply carries no token and falls through to the error.
        if let Ok(form_data) = request.form_data() {
            if let Some(token) = form_data.get::<String>(CSRF_INPUT) {
                if csrf_token_validate(&token, &session_id) {
                    return Ok(Outcome::Forward(request));
                }
            }
        }

        Ok(Outcome::Stop(request, Response::csrf_error()))
    }
}
