use std::borrow::Cow;

use chumsky::{
    input::MappedInput,
    pratt::{infix, left, postfix, prefix},
    prelude::*,
};

use super::ast::{BinOp, Expr, Ident, Token, UnOp};

pub fn lexer<'src>()
-> impl Parser<'src, &'src str, Vec<Spanned<Token<'src>>>, extra::Err<Rich<'src, char>>> {
    recursive(|_token| {
        choice((
            text::ident().map(|s| match s {
                "true" => Token::Bool(true),
                "false" => Token::Bool(false),
                s => Token::Ident(Ident(Cow::from(s))),
            }),
            just('"')
                .ignore_then(none_of('"').repeated().to_slice())
                .then_ignore(just('"'))
                .map(|s| Token::String(Cow::from(s))),
            just("+").to(Token::Add),
            just("-").to(Token::Sub),
            just("*").to(Token::Mul),
            just("/").to(Token::Div),
            just("%").to(Token::Mod),
            just("==").to(Token::Equal),
            just("!=").to(Token::NotEqual),
            just(">").to(Token::GreaterThan),
            just("<").to(Token::LessThan),
            just(">=").to(Token::GreaterEqual),
            just("<=").to(Token::LessEqual),
            just("!").to(Token::Not),
            just("&&").to(Token::And),
            just("||").to(Token::Or),
            just("(").to(Token::ParenOpen),
            just(")").to(Token::ParenClose),
            just("[").to(Token::BracketOpen),
            just("]").to(Token::BracketClose),
            just(",").to(Token::Comma),
            just(".").to(Token::Period),
            text::int(10)
                .then(just('.').then(text::digits(10)).or_not())
                .to_slice()
                .from_str()
                .unwrapped()
                .map(Token::Decimal),
            text::int(10)
                .to_slice()
                .from_str()
                .unwrapped()
                .map(Token::Integer),
        ))
        .spanned()
        .padded()
    })
    .repeated()
    .collect()
}

#[cfg(test)]
mod test_lexer {
    use chumsky::prelude::*;

    use super::*;

    /// Lexes `input` and checks that it produces exactly `expected` with no errors.
    ///
    /// Spans are byte ranges into `input` and exclude surrounding whitespace.
    #[rstest::rstest]
    #[case(
        r#"note["link-data"] == "fanfiction""#,
        vec![
            Token::Ident(Ident("note".into())).with_span(SimpleSpan::from(0..4)),
            Token::BracketOpen.with_span(SimpleSpan::from(4..5)),
            Token::String("link-data".into()).with_span(SimpleSpan::from(5..16)),
            Token::BracketClose.with_span(SimpleSpan::from(16..17)),
            Token::Equal.with_span(SimpleSpan::from(18..20)),
            Token::String("fanfiction".into()).with_span(SimpleSpan::from(21..33)),
        ],
    )]
    #[case(
        r#"!flag && true || false"#,
        vec![
            Token::Not.with_span(SimpleSpan::from(0..1)),
            Token::Ident(Ident("flag".into())).with_span(SimpleSpan::from(1..5)),
            Token::And.with_span(SimpleSpan::from(6..8)),
            Token::Bool(true).with_span(SimpleSpan::from(9..13)),
            Token::Or.with_span(SimpleSpan::from(14..16)),
            Token::Bool(false).with_span(SimpleSpan::from(17..22)),
        ],
    )]
    #[case(
        r#"a != note.get("x", y)"#,
        vec![
            Token::Ident(Ident("a".into())).with_span(SimpleSpan::from(0..1)),
            Token::NotEqual.with_span(SimpleSpan::from(2..4)),
            Token::Ident(Ident("note".into())).with_span(SimpleSpan::from(5..9)),
            Token::Period.with_span(SimpleSpan::from(9..10)),
            Token::Ident(Ident("get".into())).with_span(SimpleSpan::from(10..13)),
            Token::ParenOpen.with_span(SimpleSpan::from(13..14)),
            Token::String("x".into()).with_span(SimpleSpan::from(14..17)),
            Token::Comma.with_span(SimpleSpan::from(17..18)),
            Token::Ident(Ident("y".into())).with_span(SimpleSpan::from(19..20)),
            Token::ParenClose.with_span(SimpleSpan::from(20..21)),
        ],
    )]
    fn test(#[case] input: &str, #[case] expected: Vec<Spanned<Token<'_>>>) {
        let tokens = lexer().parse(input);
        assert!(!tokens.has_errors());
        assert_eq!(tokens.into_output(), Some(expected));
    }
}

pub fn parser<'tokens, 'src: 'tokens>() -> impl Parser<
    'tokens,
    MappedInput<'tokens, Token<'src>, SimpleSpan, &'tokens [Spanned<Token<'src>>]>,
    Spanned<Expr<'src>>,
    extra::Err<Rich<'tokens, Token<'src>>>,
> {
    recursive(|expr| {
        let value = select! {
            Token::Bool(b) => Expr::Bool(b),
            Token::Integer(i) => Expr::Integer(i),
            Token::Decimal(f) => Expr::Decimal(f),
            Token::String(s) => Expr::String(s),
        }
        .labelled("value");

        let ident = select! { Token::Ident(x) => x }.labelled("identifier");
        let ident_expr = select! { Token::Ident(x) => Expr::Ident(x) }.labelled("identifier");

        let params = expr
            .clone()
            .separated_by(just(Token::Comma))
            .allow_trailing()
            .collect::<Vec<_>>()
            .delimited_by(just(Token::ParenOpen), just(Token::ParenClose))
            .spanned();

        let group = expr
            .clone()
            .delimited_by(just(Token::ParenOpen), just(Token::ParenClose))
            .map(|inner| Expr::Group(Box::new(inner)))
            .boxed();

        let indexed = expr
            .clone()
            .delimited_by(just(Token::BracketOpen), just(Token::BracketClose));

        choice((value, ident_expr, group)).spanned().pratt((
            postfix(
                16,
                just(Token::Period)
                    .ignore_then(ident.spanned())
                    .then(params.clone()),
                |name, (method, params), e| {
                    Expr::Method(Box::new(name), method, params).with_span(e.span())
                },
            ),
            postfix(
                15,
                just(Token::Period).ignore_then(ident.spanned()),
                |name, method, e| Expr::Member(Box::new(name), method).with_span(e.span()),
            ),
            postfix(14, params.clone(), |name, params, e| {
                Expr::Function(Box::new(name), params).with_span(e.span())
            }),
            postfix(13, indexed, |name, params, e| {
                Expr::Index(Box::new(name), Box::new(params)).with_span(e.span())
            }),
            //
            infix(
                left(12),
                select! { Token::Or = e => BinOp::Or.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            infix(
                left(11),
                select! { Token::And = e => BinOp::And.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            //
            infix(
                left(7),
                select! { Token::Equal = e => BinOp::Equal.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            infix(
                left(7),
                select! { Token::NotEqual = e => BinOp::NotEqual.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            //
            infix(
                left(6),
                select! { Token::GreaterThan = e => BinOp::GreaterThan.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            infix(
                left(6),
                select! { Token::GreaterEqual = e => BinOp::GreaterEqual.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            infix(
                left(6),
                select! { Token::LessThan = e => BinOp::LessThan.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            infix(
                left(6),
                select! { Token::LessEqual = e => BinOp::LessEqual.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            //
            infix(
                left(4),
                select! { Token::Add = e => BinOp::Add.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            infix(
                left(4),
                select! { Token::Sub = e => BinOp::Sub.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            //
            infix(
                left(3),
                select! { Token::Mul = e => BinOp::Mul.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            infix(
                left(3),
                select! { Token::Div = e => BinOp::Div.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            infix(
                left(3),
                select! { Token::Mod = e => BinOp::Mod.with_span(e.span()) },
                |x, op, y, e| Expr::BinOp(op, Box::new(x), Box::new(y)).with_span(e.span()),
            ),
            //
            prefix(
                2,
                select! { Token::Sub = e => UnOp::Neg.with_span(e.span()) },
                |op, x, e| Expr::UnOp(op, Box::new(x)).with_span(e.span()),
            ),
            prefix(
                2,
                select! { Token::Not = e => UnOp::Not.with_span(e.span()) },
                |op, x, e| Expr::UnOp(op, Box::new(x)).with_span(e.span()),
            ),
        ))
    })
}

#[cfg(test)]
mod test_parser {
    use chumsky::prelude::*;

    use super::*;

    /// Checks the token stream that a parser test for this expression would consume.
    ///
    /// NOTE(review): this currently only runs `lexer()`; it does not invoke `parser()`.
    #[test]
    fn setup_test() {
        let source = r#"note["link-data"] == "fanfiction""#;
        let expected = vec![
            Token::Ident(Ident("note".into())).with_span(SimpleSpan::from(0..4)),
            Token::BracketOpen.with_span(SimpleSpan::from(4..5)),
            Token::String("link-data".into()).with_span(SimpleSpan::from(5..16)),
            Token::BracketClose.with_span(SimpleSpan::from(16..17)),
            Token::Equal.with_span(SimpleSpan::from(18..20)),
            Token::String("fanfiction".into()).with_span(SimpleSpan::from(21..33)),
        ];

        let lexed = lexer().parse(source);
        assert!(!lexed.has_errors());
        assert_eq!(lexed.into_output(), Some(expected));
    }
}
