//! `pyra_compiler/parser.rs`: builds the Pyra AST (`Program`) from a lexed token stream.

1use crate::ast::*;
2use crate::lexer::Token;
3use chumsky::prelude::*;
4
5pub type ParseError = Simple<Token>;
6
7#[derive(Clone)]
8enum PostfixOp {
9    Member(String),
10    Index(Expression),
11    Call(Vec<Expression>),
12}
13
14#[derive(Clone)]
15enum TargetOp {
16    Member(String),
17    Index(Expression),
18}
19
20fn fold_postfix(lhs: Expression, op: PostfixOp) -> Expression {
21    match op {
22        PostfixOp::Member(name) => Expression::Member(Box::new(lhs), name),
23        PostfixOp::Index(idx) => Expression::Index(Box::new(lhs), Box::new(idx)),
24        PostfixOp::Call(args) => Expression::Call(Box::new(lhs), args),
25    }
26}
27
28fn fold_unary(op: UnaryOp, rhs: Expression) -> Expression {
29    Expression::Unary(op, Box::new(rhs))
30}
31
32fn fold_binary(left: Expression, (op, right): (BinaryOp, Expression)) -> Expression {
33    Expression::Binary(op, Box::new(left), Box::new(right))
34}
35
36fn fold_target(lhs: Expression, op: TargetOp) -> Expression {
37    match op {
38        TargetOp::Member(name) => Expression::Member(Box::new(lhs), name),
39        TargetOp::Index(idx) => Expression::Index(Box::new(lhs), Box::new(idx)),
40    }
41}
42
43fn fold_field_init((name, value): (String, Expression)) -> (String, Expression) {
44    (name, value)
45}
46
47fn fold_struct_init((name, fields): (String, Vec<(String, Expression)>)) -> Expression {
48    Expression::StructInit(name, fields)
49}
50
51pub fn parse_program(tokens: Vec<Token>) -> Result<Program, Vec<ParseError>> {
52    program_parser().parse(tokens)
53}
54
55pub fn parse_from_source(source: &str) -> Result<Program, Vec<ParseError>> {
56    use crate::lexer::PyraLexer;
57
58    let lexer = PyraLexer::new(source);
59    let tokens: Vec<Token> = lexer.collect();
60
61    let tokens: Vec<Token> = tokens.into_iter().filter(|t| !matches!(t, Token::Comment)).collect();
62
63    parse_program(tokens)
64}
65
66fn program_parser() -> impl Parser<Token, Program, Error = ParseError> {
67    nl()
68        .ignore_then(
69            choice((
70                function_parser().map(Item::Function),
71                struct_parser().map(Item::Struct),
72                event_parser().map(Item::Event),
73                const_item_parser().map(Item::Const),
74            ))
75            .then_ignore(nl()),
76        )
77        .repeated()
78        .map(|items| Program {
79            items,
80            span: Span { start: 0, end: 0 },
81        })
82        .then_ignore(end())
83}
84
85fn function_parser() -> impl Parser<Token, Function, Error = ParseError> {
86    just(Token::Def)
87        .ignore_then(identifier())
88        .then_ignore(just(Token::LParen))
89        .then(parameter_list())
90        .then_ignore(just(Token::RParen))
91        .then(return_type().or_not())
92        .then_ignore(just(Token::Colon))
93        .then(suite_parser(statement_parser()))
94        .map(|(((name, params), return_type), body)| Function {
95            name,
96            params,
97            return_type,
98            body,
99            span: Span { start: 0, end: 0 },
100        })
101}
102
103fn nl() -> impl Parser<Token, (), Error = ParseError> {
104    just(Token::Newline).repeated().ignored()
105}
106
107fn nl1() -> impl Parser<Token, (), Error = ParseError> {
108    just(Token::Newline).repeated().at_least(1).ignored()
109}
110
111fn parameter_list() -> impl Parser<Token, Vec<Parameter>, Error = ParseError> {
112    parameter_parser()
113        .separated_by(just(Token::Comma))
114        .allow_trailing()
115}
116
117fn parameter_parser() -> impl Parser<Token, Parameter, Error = ParseError> {
118    identifier()
119        .then_ignore(just(Token::Colon))
120        .then(type_parser())
121        .map(|(name, type_)| Parameter {
122            name,
123            type_,
124            span: Span { start: 0, end: 0 },
125        })
126}
127
128fn return_type() -> impl Parser<Token, Type, Error = ParseError> {
129    just(Token::Arrow).ignore_then(type_parser())
130}
131
132fn type_parser() -> impl Parser<Token, Type, Error = ParseError> {
133    choice((
134        just(Token::Uint8).to(Type::Uint8),
135        just(Token::Uint256).to(Type::Uint256),
136        just(Token::Int256).to(Type::Int256),
137        just(Token::Bool).to(Type::Bool),
138        just(Token::Address).to(Type::Address),
139        just(Token::Bytes).to(Type::Bytes),
140        just(Token::String).to(Type::String),
141        identifier().map(Type::Custom),
142    ))
143}
144
145fn generic_params_parser() -> impl Parser<Token, (), Error = ParseError> {
146    let param = identifier()
147        .then(just(Token::Colon).ignore_then(type_parser()).or_not())
148        .ignored();
149
150    just(Token::Less)
151        .ignore_then(param.separated_by(just(Token::Comma)).allow_trailing())
152        .then_ignore(just(Token::Greater))
153        .ignored()
154}
155
156fn struct_parser() -> impl Parser<Token, StructDef, Error = ParseError> {
157    let sep = choice((just(Token::Comma).ignore_then(nl()).ignored(), nl1()));
158    just(Token::Struct)
159        .ignore_then(identifier())
160        .then_ignore(generic_params_parser().or_not())
161        .then_ignore(nl())
162        .then_ignore(just(Token::LBrace))
163        .then_ignore(nl())
164        .then_ignore(just(Token::Indent).or_not())
165        .then_ignore(nl())
166        .then(struct_field().separated_by(sep).allow_leading().allow_trailing())
167        .then_ignore(nl())
168        .then_ignore(just(Token::Dedent).or_not())
169        .then_ignore(nl())
170        .then_ignore(just(Token::RBrace))
171        .map(|(name, fields)| StructDef {
172            name,
173            fields,
174            span: Span { start: 0, end: 0 },
175        })
176}
177
178fn struct_field() -> impl Parser<Token, StructField, Error = ParseError> {
179    identifier()
180        .then_ignore(just(Token::Colon))
181        .then(type_parser())
182        .map(|(name, type_)| StructField {
183            name,
184            type_,
185            span: Span { start: 0, end: 0 },
186        })
187}
188
189fn const_item_parser() -> impl Parser<Token, ConstDecl, Error = ParseError> {
190    choice((just(Token::Const), just(Token::Let)))
191        .ignore_then(identifier())
192        .then(just(Token::Colon).ignore_then(type_parser()).or_not())
193        .then_ignore(just(Token::Assign))
194        .then(expression_parser())
195        .map(|((name, type_), value)| ConstDecl {
196            name,
197            type_: type_.unwrap_or(Type::Uint256),
198            value,
199            span: Span { start: 0, end: 0 },
200        })
201}
202
203fn event_parser() -> impl Parser<Token, EventDef, Error = ParseError> {
204    just(Token::Event)
205        .ignore_then(identifier())
206        .then_ignore(just(Token::LParen))
207        .then(parameter_list())
208        .then_ignore(just(Token::RParen))
209        .map(|(name, fields)| EventDef {
210            name,
211            fields,
212            span: Span { start: 0, end: 0 },
213        })
214}
215
216fn expression_parser() -> impl Parser<Token, Expression, Error = ParseError> {
217    recursive(|expr| {
218        let field_init = identifier()
219            .then_ignore(just(Token::Colon))
220            .then(expr.clone())
221            .map(fold_field_init as fn((String, Expression)) -> (String, Expression));
222
223        let sep = choice((just(Token::Comma).ignore_then(nl()).ignored(), nl1()));
224
225        let struct_init = identifier()
226            .then(
227                just(Token::LBrace)
228                    .ignore_then(nl())
229                    .ignore_then(just(Token::Indent).or_not())
230                    .ignore_then(nl())
231                    .ignore_then(field_init.separated_by(sep).allow_leading().allow_trailing())
232                    .then_ignore(nl())
233                    .then_ignore(just(Token::Dedent).or_not())
234                    .then_ignore(nl())
235                    .then_ignore(just(Token::RBrace)),
236            )
237            .map(fold_struct_init as fn((String, Vec<(String, Expression)>)) -> Expression);
238
239        let atom = choice((
240            select! { Token::Number(n) => Expression::Number(n) },
241            select! { Token::HexNumber(n) => Expression::HexNumber(n) },
242            select! { Token::StringLiteral(s) => Expression::String(s) },
243            select! { Token::BytesLiteral(b) => Expression::Bytes(b) },
244            just(Token::True).to(Expression::Bool(true)),
245            just(Token::False).to(Expression::Bool(false)),
246            struct_init,
247            identifier().map(Expression::Identifier),
248            expr.clone().delimited_by(just(Token::LParen), just(Token::RParen)),
249        ));
250
251        let postfix_ops = choice((
252            just(Token::Dot)
253                .ignore_then(identifier())
254                .map(PostfixOp::Member),
255            just(Token::LBracket)
256                .ignore_then(expr.clone())
257                .then_ignore(just(Token::RBracket))
258                .map(PostfixOp::Index),
259            just(Token::LParen)
260                .ignore_then(expr.clone().separated_by(just(Token::Comma)).allow_trailing())
261                .then_ignore(just(Token::RParen))
262                .map(PostfixOp::Call),
263        ))
264        .repeated();
265
266        let postfix = atom
267            .then(postfix_ops)
268            .foldl(fold_postfix as fn(Expression, PostfixOp) -> Expression)
269            .boxed();
270
271        let unary = choice((
272            just(Token::Not).to(UnaryOp::Not),
273            just(Token::Minus).to(UnaryOp::Minus),
274        ))
275        .repeated()
276        .then(postfix)
277        .foldr(fold_unary as fn(UnaryOp, Expression) -> Expression)
278        .boxed();
279
280        let product = unary
281            .clone()
282            .then(
283                choice((
284                    just(Token::Multiply).to(BinaryOp::Mul),
285                    just(Token::Divide).to(BinaryOp::Div),
286                    just(Token::Modulo).to(BinaryOp::Mod),
287                ))
288                .then(unary.clone())
289                .repeated(),
290            )
291            .foldl(fold_binary as fn(Expression, (BinaryOp, Expression)) -> Expression)
292            .boxed();
293
294        let sum = product
295            .clone()
296            .then(
297                choice((just(Token::Plus).to(BinaryOp::Add), just(Token::Minus).to(BinaryOp::Sub)))
298                    .then(product)
299                    .repeated(),
300            )
301            .foldl(fold_binary as fn(Expression, (BinaryOp, Expression)) -> Expression)
302            .boxed();
303
304        let cmp = sum
305            .clone()
306            .then(
307                choice((
308                    just(Token::Equal).to(BinaryOp::Equal),
309                    just(Token::NotEqual).to(BinaryOp::NotEqual),
310                    just(Token::LessEqual).to(BinaryOp::LessEqual),
311                    just(Token::GreaterEqual).to(BinaryOp::GreaterEqual),
312                    just(Token::Less).to(BinaryOp::Less),
313                    just(Token::Greater).to(BinaryOp::Greater),
314                ))
315                .then(sum)
316                .repeated(),
317            )
318            .foldl(fold_binary as fn(Expression, (BinaryOp, Expression)) -> Expression)
319            .boxed();
320
321        let and_expr = cmp
322            .clone()
323            .then(just(Token::And).to(BinaryOp::And).then(cmp).repeated())
324            .foldl(fold_binary as fn(Expression, (BinaryOp, Expression)) -> Expression)
325            .boxed();
326
327        and_expr
328            .clone()
329            .then(just(Token::Or).to(BinaryOp::Or).then(and_expr).repeated())
330            .foldl(fold_binary as fn(Expression, (BinaryOp, Expression)) -> Expression)
331    })
332}
333
334fn return_statement() -> impl Parser<Token, Statement, Error = ParseError> {
335    just(Token::Return)
336        .ignore_then(expression_parser().or_not())
337        .map(Statement::Return)
338}
339
340fn require_statement() -> impl Parser<Token, Statement, Error = ParseError> {
341    just(Token::Require)
342        .ignore_then(expression_parser())
343        .map(Statement::Require)
344}
345
346fn identifier() -> impl Parser<Token, String, Error = ParseError> {
347    select! { Token::Identifier(name) => name }
348}
349
350fn let_statement() -> impl Parser<Token, Statement, Error = ParseError> {
351    just(Token::Let)
352        .ignore_then(just(Token::Mut).or_not())
353        .then(identifier())
354        .then(just(Token::Colon).ignore_then(type_parser()).or_not())
355        .then(just(Token::Assign).ignore_then(expression_parser()).or_not())
356        .map(|(((mutable, name), type_), value)| {
357            Statement::Let(LetStatement {
358                name,
359                type_,
360                value,
361                mutable: mutable.is_some(),
362                span: Span { start: 0, end: 0 },
363            })
364        })
365}
366
367fn assign_statement() -> impl Parser<Token, Statement, Error = ParseError> {
368    let target = assignment_target_parser();
369
370    let op = choice((
371        just(Token::Assign).to(None),
372        just(Token::PlusAssign).to(Some(BinaryOp::Add)),
373        just(Token::MinusAssign).to(Some(BinaryOp::Sub)),
374        just(Token::MultiplyAssign).to(Some(BinaryOp::Mul)),
375        just(Token::DivideAssign).to(Some(BinaryOp::Div)),
376    ));
377
378    target
379        .then(op)
380        .then(expression_parser())
381        .map(|((target, op), rhs)| {
382            let value = match op {
383                None => rhs,
384                Some(bin_op) => Expression::Binary(bin_op, Box::new(target.clone()), Box::new(rhs)),
385            };
386
387            Statement::Assign(AssignStatement {
388                target,
389                value,
390                span: Span { start: 0, end: 0 },
391            })
392        })
393}
394
395fn assignment_target_parser() -> impl Parser<Token, Expression, Error = ParseError> {
396    let base = identifier().map(Expression::Identifier).boxed();
397    let ops = choice((
398        just(Token::Dot)
399            .ignore_then(identifier())
400            .map(TargetOp::Member),
401        just(Token::LBracket)
402            .ignore_then(expression_parser())
403            .then_ignore(just(Token::RBracket))
404            .map(TargetOp::Index),
405    ))
406    .repeated();
407
408    base.then(ops).foldl(fold_target as fn(Expression, TargetOp) -> Expression)
409}
410
411fn emit_statement() -> impl Parser<Token, Statement, Error = ParseError> {
412    just(Token::Emit)
413        .ignore_then(identifier())
414        .then_ignore(just(Token::LParen))
415        .then(expression_parser().separated_by(just(Token::Comma)).allow_trailing())
416        .then_ignore(just(Token::RParen))
417        .map(|(name, args)| {
418            Statement::Emit(EmitStatement {
419                name,
420                args,
421                span: Span { start: 0, end: 0 },
422            })
423        })
424}
425
426fn statement_parser() -> BoxedParser<'static, Token, Statement, ParseError> {
427    recursive(|stmt| {
428        let suite = suite_parser(stmt.clone().boxed());
429
430        let if_stmt = just(Token::If)
431            .ignore_then(expression_parser())
432            .then_ignore(just(Token::Colon))
433            .then(suite.clone())
434            .then(
435                nl1()
436                    .ignore_then(
437                        just(Token::Elif)
438                            .ignore_then(expression_parser())
439                            .then_ignore(just(Token::Colon))
440                            .then(suite.clone()),
441                    )
442                    .repeated(),
443            )
444            .then(
445                nl1()
446                    .ignore_then(just(Token::Else).ignore_then(just(Token::Colon)).ignore_then(suite.clone()))
447                    .or_not(),
448            )
449            .map(|(((cond, then_branch), elifs), else_branch)| {
450                let mut else_acc = else_branch;
451                for (elif_cond, elif_body) in elifs.into_iter().rev() {
452                    let nested = IfStatement {
453                        condition: elif_cond,
454                        then_branch: elif_body,
455                        else_branch: else_acc,
456                        span: Span { start: 0, end: 0 },
457                    };
458
459                    else_acc = Some(Block {
460                        statements: vec![Statement::If(nested)],
461                        span: Span { start: 0, end: 0 },
462                    });
463                }
464
465                Statement::If(IfStatement {
466                    condition: cond,
467                    then_branch,
468                    else_branch: else_acc,
469                    span: Span { start: 0, end: 0 },
470                })
471            });
472
473        let for_stmt = just(Token::For)
474            .ignore_then(identifier())
475            .then_ignore(just(Token::In))
476            .then(expression_parser())
477            .then_ignore(just(Token::Colon))
478            .then(suite.clone())
479            .map(|((var, iterable), body)| {
480                Statement::For(ForStatement {
481                    var,
482                    iterable,
483                    body,
484                    span: Span { start: 0, end: 0 },
485                })
486            });
487
488        let while_stmt = just(Token::While)
489            .ignore_then(expression_parser())
490            .then_ignore(just(Token::Colon))
491            .then(suite)
492            .map(|(condition, body)| {
493                Statement::While(WhileStatement {
494                    condition,
495                    body,
496                    span: Span { start: 0, end: 0 },
497                })
498            });
499
500        choice((
501            if_stmt,
502            for_stmt,
503            while_stmt,
504            emit_statement(),
505            require_statement(),
506            let_statement(),
507            return_statement(),
508            assign_statement(),
509        ))
510        .boxed()
511    })
512    .boxed()
513}
514
515fn suite_parser<S>(stmt: S) -> BoxedParser<'static, Token, Block, ParseError>
516where
517    S: Parser<Token, Statement, Error = ParseError> + Clone + 'static,
518{
519    let single = stmt.clone().map(|st| Block {
520        statements: vec![st],
521        span: Span { start: 0, end: 0 },
522    });
523
524    let indented = nl1()
525        .ignore_then(just(Token::Indent))
526        .ignore_then(nl())
527        .ignore_then(stmt.separated_by(nl1()).allow_leading().allow_trailing())
528        .then_ignore(nl())
529        .then_ignore(just(Token::Dedent))
530        .map(|statements| Block {
531            statements,
532            span: Span { start: 0, end: 0 },
533        });
534
535    choice((indented, single)).boxed()
536}
537
/// End-to-end parser tests driven from source text through the lexer.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_function() {
        let source = "def transfer(to: address, amount: uint256) -> bool: return true";

        let result = parse_from_source(source);
        assert!(result.is_ok(), "Parser should handle simple function");

        let program = result.unwrap();
        assert_eq!(program.items.len(), 1);

        // Fail loudly when the item is not a function; an `if let` here would
        // skip every assertion below and let the test pass.
        let Item::Function(func) = &program.items[0] else {
            panic!("expected the only item to be a function")
        };
        assert_eq!(func.name, "transfer");
        assert_eq!(func.params.len(), 2);
        assert_eq!(func.params[0].name, "to");
        assert!(matches!(func.params[0].type_, Type::Address));
    }

    #[test]
    fn test_expression_parsing() {
        let source = "def test() -> uint256: return 42";

        let result = parse_from_source(source);
        assert!(result.is_ok(), "Should parse simple return statement");
    }

    #[test]
    fn parses_multiline_block_with_require() {
        let source = "def t() -> bool:\n    require true\n    return true\n";
        let program = parse_from_source(source).unwrap();
        assert_eq!(program.items.len(), 1);
        let Item::Function(f) = &program.items[0] else { panic!() };
        assert_eq!(f.body.statements.len(), 2);
        assert!(matches!(f.body.statements[0], Statement::Require(_)));
        assert!(matches!(f.body.statements[1], Statement::Return(_)));
    }

    #[test]
    fn parses_if_elif_else() {
        let source = "def t() -> uint256:\n    if true: return 1\n    elif false: return 2\n    else: return 3\n";
        let program = parse_from_source(source).unwrap();
        let Item::Function(f) = &program.items[0] else { panic!() };
        assert_eq!(f.body.statements.len(), 1);
        assert!(matches!(f.body.statements[0], Statement::If(_)));
    }

    #[test]
    fn parses_augmented_assignment() {
        let source = "def t() -> uint256:\n    let mut x = 1\n    x += 2\n    return x\n";
        let program = parse_from_source(source).unwrap();
        let Item::Function(f) = &program.items[0] else { panic!() };
        assert_eq!(f.body.statements.len(), 3);
        assert!(matches!(f.body.statements[1], Statement::Assign(_)));
    }

    #[test]
    fn parses_const_item() {
        let source = "const total_supply: uint256 = 100\n\ndef t() -> uint256: return total_supply\n";
        let program = parse_from_source(source).unwrap();
        assert_eq!(program.items.len(), 2);
        assert!(matches!(program.items[0], Item::Const(_)));
    }

    #[test]
    fn parses_for_loop() {
        let source = "def t():\n    for i in items:\n        let x = i\n";
        let program = parse_from_source(source).unwrap();
        let Item::Function(f) = &program.items[0] else { panic!() };
        assert_eq!(f.body.statements.len(), 1);
        assert!(matches!(f.body.statements[0], Statement::For(_)));
    }

    #[test]
    fn parses_while_loop() {
        let source = "def t():\n    while true:\n        let x = 1\n";
        let program = parse_from_source(source).unwrap();
        let Item::Function(f) = &program.items[0] else { panic!() };
        assert_eq!(f.body.statements.len(), 1);
        assert!(matches!(f.body.statements[0], Statement::While(_)));
    }

    #[test]
    fn parses_event_declaration() {
        let source = "event Transfer(from: address, to: address, amount: uint256)\n\ndef t() -> bool: return true\n";
        let program = parse_from_source(source).unwrap();
        assert_eq!(program.items.len(), 2);
        assert!(matches!(program.items[0], Item::Event(_)));
    }

    #[test]
    fn parses_emit_statement() {
        let source = "def t():\n    emit Transfer(a, b, c)\n";
        let program = parse_from_source(source).unwrap();
        let Item::Function(f) = &program.items[0] else { panic!() };
        assert_eq!(f.body.statements.len(), 1);
        assert!(matches!(f.body.statements[0], Statement::Emit(_)));
    }
}
639}