1use crate::ast::*;
2use crate::lexer::Token;
3use chumsky::prelude::*;
4
5pub type ParseError = Simple<Token>;
6
7#[derive(Clone)]
8enum PostfixOp {
9 Member(String),
10 Index(Expression),
11 Call(Vec<Expression>),
12}
13
14#[derive(Clone)]
15enum TargetOp {
16 Member(String),
17 Index(Expression),
18}
19
20fn fold_postfix(lhs: Expression, op: PostfixOp) -> Expression {
21 match op {
22 PostfixOp::Member(name) => Expression::Member(Box::new(lhs), name),
23 PostfixOp::Index(idx) => Expression::Index(Box::new(lhs), Box::new(idx)),
24 PostfixOp::Call(args) => Expression::Call(Box::new(lhs), args),
25 }
26}
27
28fn fold_unary(op: UnaryOp, rhs: Expression) -> Expression {
29 Expression::Unary(op, Box::new(rhs))
30}
31
32fn fold_binary(left: Expression, (op, right): (BinaryOp, Expression)) -> Expression {
33 Expression::Binary(op, Box::new(left), Box::new(right))
34}
35
36fn fold_target(lhs: Expression, op: TargetOp) -> Expression {
37 match op {
38 TargetOp::Member(name) => Expression::Member(Box::new(lhs), name),
39 TargetOp::Index(idx) => Expression::Index(Box::new(lhs), Box::new(idx)),
40 }
41}
42
43fn fold_field_init((name, value): (String, Expression)) -> (String, Expression) {
44 (name, value)
45}
46
47fn fold_struct_init((name, fields): (String, Vec<(String, Expression)>)) -> Expression {
48 Expression::StructInit(name, fields)
49}
50
51pub fn parse_program(tokens: Vec<Token>) -> Result<Program, Vec<ParseError>> {
52 program_parser().parse(tokens)
53}
54
55pub fn parse_from_source(source: &str) -> Result<Program, Vec<ParseError>> {
56 use crate::lexer::PyraLexer;
57
58 let lexer = PyraLexer::new(source);
59 let tokens: Vec<Token> = lexer.collect();
60
61 let tokens: Vec<Token> = tokens.into_iter().filter(|t| !matches!(t, Token::Comment)).collect();
62
63 parse_program(tokens)
64}
65
66fn program_parser() -> impl Parser<Token, Program, Error = ParseError> {
67 nl()
68 .ignore_then(
69 choice((
70 function_parser().map(Item::Function),
71 struct_parser().map(Item::Struct),
72 event_parser().map(Item::Event),
73 const_item_parser().map(Item::Const),
74 ))
75 .then_ignore(nl()),
76 )
77 .repeated()
78 .map(|items| Program {
79 items,
80 span: Span { start: 0, end: 0 },
81 })
82 .then_ignore(end())
83}
84
85fn function_parser() -> impl Parser<Token, Function, Error = ParseError> {
86 just(Token::Def)
87 .ignore_then(identifier())
88 .then_ignore(just(Token::LParen))
89 .then(parameter_list())
90 .then_ignore(just(Token::RParen))
91 .then(return_type().or_not())
92 .then_ignore(just(Token::Colon))
93 .then(suite_parser(statement_parser()))
94 .map(|(((name, params), return_type), body)| Function {
95 name,
96 params,
97 return_type,
98 body,
99 span: Span { start: 0, end: 0 },
100 })
101}
102
103fn nl() -> impl Parser<Token, (), Error = ParseError> {
104 just(Token::Newline).repeated().ignored()
105}
106
107fn nl1() -> impl Parser<Token, (), Error = ParseError> {
108 just(Token::Newline).repeated().at_least(1).ignored()
109}
110
111fn parameter_list() -> impl Parser<Token, Vec<Parameter>, Error = ParseError> {
112 parameter_parser()
113 .separated_by(just(Token::Comma))
114 .allow_trailing()
115}
116
117fn parameter_parser() -> impl Parser<Token, Parameter, Error = ParseError> {
118 identifier()
119 .then_ignore(just(Token::Colon))
120 .then(type_parser())
121 .map(|(name, type_)| Parameter {
122 name,
123 type_,
124 span: Span { start: 0, end: 0 },
125 })
126}
127
128fn return_type() -> impl Parser<Token, Type, Error = ParseError> {
129 just(Token::Arrow).ignore_then(type_parser())
130}
131
132fn type_parser() -> impl Parser<Token, Type, Error = ParseError> {
133 choice((
134 just(Token::Uint8).to(Type::Uint8),
135 just(Token::Uint256).to(Type::Uint256),
136 just(Token::Int256).to(Type::Int256),
137 just(Token::Bool).to(Type::Bool),
138 just(Token::Address).to(Type::Address),
139 just(Token::Bytes).to(Type::Bytes),
140 just(Token::String).to(Type::String),
141 identifier().map(Type::Custom),
142 ))
143}
144
145fn generic_params_parser() -> impl Parser<Token, (), Error = ParseError> {
146 let param = identifier()
147 .then(just(Token::Colon).ignore_then(type_parser()).or_not())
148 .ignored();
149
150 just(Token::Less)
151 .ignore_then(param.separated_by(just(Token::Comma)).allow_trailing())
152 .then_ignore(just(Token::Greater))
153 .ignored()
154}
155
156fn struct_parser() -> impl Parser<Token, StructDef, Error = ParseError> {
157 let sep = choice((just(Token::Comma).ignore_then(nl()).ignored(), nl1()));
158 just(Token::Struct)
159 .ignore_then(identifier())
160 .then_ignore(generic_params_parser().or_not())
161 .then_ignore(nl())
162 .then_ignore(just(Token::LBrace))
163 .then_ignore(nl())
164 .then_ignore(just(Token::Indent).or_not())
165 .then_ignore(nl())
166 .then(struct_field().separated_by(sep).allow_leading().allow_trailing())
167 .then_ignore(nl())
168 .then_ignore(just(Token::Dedent).or_not())
169 .then_ignore(nl())
170 .then_ignore(just(Token::RBrace))
171 .map(|(name, fields)| StructDef {
172 name,
173 fields,
174 span: Span { start: 0, end: 0 },
175 })
176}
177
178fn struct_field() -> impl Parser<Token, StructField, Error = ParseError> {
179 identifier()
180 .then_ignore(just(Token::Colon))
181 .then(type_parser())
182 .map(|(name, type_)| StructField {
183 name,
184 type_,
185 span: Span { start: 0, end: 0 },
186 })
187}
188
189fn const_item_parser() -> impl Parser<Token, ConstDecl, Error = ParseError> {
190 choice((just(Token::Const), just(Token::Let)))
191 .ignore_then(identifier())
192 .then(just(Token::Colon).ignore_then(type_parser()).or_not())
193 .then_ignore(just(Token::Assign))
194 .then(expression_parser())
195 .map(|((name, type_), value)| ConstDecl {
196 name,
197 type_: type_.unwrap_or(Type::Uint256),
198 value,
199 span: Span { start: 0, end: 0 },
200 })
201}
202
203fn event_parser() -> impl Parser<Token, EventDef, Error = ParseError> {
204 just(Token::Event)
205 .ignore_then(identifier())
206 .then_ignore(just(Token::LParen))
207 .then(parameter_list())
208 .then_ignore(just(Token::RParen))
209 .map(|(name, fields)| EventDef {
210 name,
211 fields,
212 span: Span { start: 0, end: 0 },
213 })
214}
215
216fn expression_parser() -> impl Parser<Token, Expression, Error = ParseError> {
217 recursive(|expr| {
218 let field_init = identifier()
219 .then_ignore(just(Token::Colon))
220 .then(expr.clone())
221 .map(fold_field_init as fn((String, Expression)) -> (String, Expression));
222
223 let sep = choice((just(Token::Comma).ignore_then(nl()).ignored(), nl1()));
224
225 let struct_init = identifier()
226 .then(
227 just(Token::LBrace)
228 .ignore_then(nl())
229 .ignore_then(just(Token::Indent).or_not())
230 .ignore_then(nl())
231 .ignore_then(field_init.separated_by(sep).allow_leading().allow_trailing())
232 .then_ignore(nl())
233 .then_ignore(just(Token::Dedent).or_not())
234 .then_ignore(nl())
235 .then_ignore(just(Token::RBrace)),
236 )
237 .map(fold_struct_init as fn((String, Vec<(String, Expression)>)) -> Expression);
238
239 let atom = choice((
240 select! { Token::Number(n) => Expression::Number(n) },
241 select! { Token::HexNumber(n) => Expression::HexNumber(n) },
242 select! { Token::StringLiteral(s) => Expression::String(s) },
243 select! { Token::BytesLiteral(b) => Expression::Bytes(b) },
244 just(Token::True).to(Expression::Bool(true)),
245 just(Token::False).to(Expression::Bool(false)),
246 struct_init,
247 identifier().map(Expression::Identifier),
248 expr.clone().delimited_by(just(Token::LParen), just(Token::RParen)),
249 ));
250
251 let postfix_ops = choice((
252 just(Token::Dot)
253 .ignore_then(identifier())
254 .map(PostfixOp::Member),
255 just(Token::LBracket)
256 .ignore_then(expr.clone())
257 .then_ignore(just(Token::RBracket))
258 .map(PostfixOp::Index),
259 just(Token::LParen)
260 .ignore_then(expr.clone().separated_by(just(Token::Comma)).allow_trailing())
261 .then_ignore(just(Token::RParen))
262 .map(PostfixOp::Call),
263 ))
264 .repeated();
265
266 let postfix = atom
267 .then(postfix_ops)
268 .foldl(fold_postfix as fn(Expression, PostfixOp) -> Expression)
269 .boxed();
270
271 let unary = choice((
272 just(Token::Not).to(UnaryOp::Not),
273 just(Token::Minus).to(UnaryOp::Minus),
274 ))
275 .repeated()
276 .then(postfix)
277 .foldr(fold_unary as fn(UnaryOp, Expression) -> Expression)
278 .boxed();
279
280 let product = unary
281 .clone()
282 .then(
283 choice((
284 just(Token::Multiply).to(BinaryOp::Mul),
285 just(Token::Divide).to(BinaryOp::Div),
286 just(Token::Modulo).to(BinaryOp::Mod),
287 ))
288 .then(unary.clone())
289 .repeated(),
290 )
291 .foldl(fold_binary as fn(Expression, (BinaryOp, Expression)) -> Expression)
292 .boxed();
293
294 let sum = product
295 .clone()
296 .then(
297 choice((just(Token::Plus).to(BinaryOp::Add), just(Token::Minus).to(BinaryOp::Sub)))
298 .then(product)
299 .repeated(),
300 )
301 .foldl(fold_binary as fn(Expression, (BinaryOp, Expression)) -> Expression)
302 .boxed();
303
304 let cmp = sum
305 .clone()
306 .then(
307 choice((
308 just(Token::Equal).to(BinaryOp::Equal),
309 just(Token::NotEqual).to(BinaryOp::NotEqual),
310 just(Token::LessEqual).to(BinaryOp::LessEqual),
311 just(Token::GreaterEqual).to(BinaryOp::GreaterEqual),
312 just(Token::Less).to(BinaryOp::Less),
313 just(Token::Greater).to(BinaryOp::Greater),
314 ))
315 .then(sum)
316 .repeated(),
317 )
318 .foldl(fold_binary as fn(Expression, (BinaryOp, Expression)) -> Expression)
319 .boxed();
320
321 let and_expr = cmp
322 .clone()
323 .then(just(Token::And).to(BinaryOp::And).then(cmp).repeated())
324 .foldl(fold_binary as fn(Expression, (BinaryOp, Expression)) -> Expression)
325 .boxed();
326
327 and_expr
328 .clone()
329 .then(just(Token::Or).to(BinaryOp::Or).then(and_expr).repeated())
330 .foldl(fold_binary as fn(Expression, (BinaryOp, Expression)) -> Expression)
331 })
332}
333
334fn return_statement() -> impl Parser<Token, Statement, Error = ParseError> {
335 just(Token::Return)
336 .ignore_then(expression_parser().or_not())
337 .map(Statement::Return)
338}
339
340fn require_statement() -> impl Parser<Token, Statement, Error = ParseError> {
341 just(Token::Require)
342 .ignore_then(expression_parser())
343 .map(Statement::Require)
344}
345
346fn identifier() -> impl Parser<Token, String, Error = ParseError> {
347 select! { Token::Identifier(name) => name }
348}
349
350fn let_statement() -> impl Parser<Token, Statement, Error = ParseError> {
351 just(Token::Let)
352 .ignore_then(just(Token::Mut).or_not())
353 .then(identifier())
354 .then(just(Token::Colon).ignore_then(type_parser()).or_not())
355 .then(just(Token::Assign).ignore_then(expression_parser()).or_not())
356 .map(|(((mutable, name), type_), value)| {
357 Statement::Let(LetStatement {
358 name,
359 type_,
360 value,
361 mutable: mutable.is_some(),
362 span: Span { start: 0, end: 0 },
363 })
364 })
365}
366
367fn assign_statement() -> impl Parser<Token, Statement, Error = ParseError> {
368 let target = assignment_target_parser();
369
370 let op = choice((
371 just(Token::Assign).to(None),
372 just(Token::PlusAssign).to(Some(BinaryOp::Add)),
373 just(Token::MinusAssign).to(Some(BinaryOp::Sub)),
374 just(Token::MultiplyAssign).to(Some(BinaryOp::Mul)),
375 just(Token::DivideAssign).to(Some(BinaryOp::Div)),
376 ));
377
378 target
379 .then(op)
380 .then(expression_parser())
381 .map(|((target, op), rhs)| {
382 let value = match op {
383 None => rhs,
384 Some(bin_op) => Expression::Binary(bin_op, Box::new(target.clone()), Box::new(rhs)),
385 };
386
387 Statement::Assign(AssignStatement {
388 target,
389 value,
390 span: Span { start: 0, end: 0 },
391 })
392 })
393}
394
395fn assignment_target_parser() -> impl Parser<Token, Expression, Error = ParseError> {
396 let base = identifier().map(Expression::Identifier).boxed();
397 let ops = choice((
398 just(Token::Dot)
399 .ignore_then(identifier())
400 .map(TargetOp::Member),
401 just(Token::LBracket)
402 .ignore_then(expression_parser())
403 .then_ignore(just(Token::RBracket))
404 .map(TargetOp::Index),
405 ))
406 .repeated();
407
408 base.then(ops).foldl(fold_target as fn(Expression, TargetOp) -> Expression)
409}
410
411fn emit_statement() -> impl Parser<Token, Statement, Error = ParseError> {
412 just(Token::Emit)
413 .ignore_then(identifier())
414 .then_ignore(just(Token::LParen))
415 .then(expression_parser().separated_by(just(Token::Comma)).allow_trailing())
416 .then_ignore(just(Token::RParen))
417 .map(|(name, args)| {
418 Statement::Emit(EmitStatement {
419 name,
420 args,
421 span: Span { start: 0, end: 0 },
422 })
423 })
424}
425
426fn statement_parser() -> BoxedParser<'static, Token, Statement, ParseError> {
427 recursive(|stmt| {
428 let suite = suite_parser(stmt.clone().boxed());
429
430 let if_stmt = just(Token::If)
431 .ignore_then(expression_parser())
432 .then_ignore(just(Token::Colon))
433 .then(suite.clone())
434 .then(
435 nl1()
436 .ignore_then(
437 just(Token::Elif)
438 .ignore_then(expression_parser())
439 .then_ignore(just(Token::Colon))
440 .then(suite.clone()),
441 )
442 .repeated(),
443 )
444 .then(
445 nl1()
446 .ignore_then(just(Token::Else).ignore_then(just(Token::Colon)).ignore_then(suite.clone()))
447 .or_not(),
448 )
449 .map(|(((cond, then_branch), elifs), else_branch)| {
450 let mut else_acc = else_branch;
451 for (elif_cond, elif_body) in elifs.into_iter().rev() {
452 let nested = IfStatement {
453 condition: elif_cond,
454 then_branch: elif_body,
455 else_branch: else_acc,
456 span: Span { start: 0, end: 0 },
457 };
458
459 else_acc = Some(Block {
460 statements: vec![Statement::If(nested)],
461 span: Span { start: 0, end: 0 },
462 });
463 }
464
465 Statement::If(IfStatement {
466 condition: cond,
467 then_branch,
468 else_branch: else_acc,
469 span: Span { start: 0, end: 0 },
470 })
471 });
472
473 let for_stmt = just(Token::For)
474 .ignore_then(identifier())
475 .then_ignore(just(Token::In))
476 .then(expression_parser())
477 .then_ignore(just(Token::Colon))
478 .then(suite.clone())
479 .map(|((var, iterable), body)| {
480 Statement::For(ForStatement {
481 var,
482 iterable,
483 body,
484 span: Span { start: 0, end: 0 },
485 })
486 });
487
488 let while_stmt = just(Token::While)
489 .ignore_then(expression_parser())
490 .then_ignore(just(Token::Colon))
491 .then(suite)
492 .map(|(condition, body)| {
493 Statement::While(WhileStatement {
494 condition,
495 body,
496 span: Span { start: 0, end: 0 },
497 })
498 });
499
500 choice((
501 if_stmt,
502 for_stmt,
503 while_stmt,
504 emit_statement(),
505 require_statement(),
506 let_statement(),
507 return_statement(),
508 assign_statement(),
509 ))
510 .boxed()
511 })
512 .boxed()
513}
514
515fn suite_parser<S>(stmt: S) -> BoxedParser<'static, Token, Block, ParseError>
516where
517 S: Parser<Token, Statement, Error = ParseError> + Clone + 'static,
518{
519 let single = stmt.clone().map(|st| Block {
520 statements: vec![st],
521 span: Span { start: 0, end: 0 },
522 });
523
524 let indented = nl1()
525 .ignore_then(just(Token::Indent))
526 .ignore_then(nl())
527 .ignore_then(stmt.separated_by(nl1()).allow_leading().allow_trailing())
528 .then_ignore(nl())
529 .then_ignore(just(Token::Dedent))
530 .map(|statements| Block {
531 statements,
532 span: Span { start: 0, end: 0 },
533 });
534
535 choice((indented, single)).boxed()
536}
537
/// End-to-end parser tests: each one lexes and parses a small source snippet
/// through `parse_from_source` and inspects the resulting AST.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_function() {
        let source = "def transfer(to: address, amount: uint256) -> bool: return true";

        let result = parse_from_source(source);
        assert!(result.is_ok(), "Parser should handle simple function");

        let program = result.unwrap();
        assert_eq!(program.items.len(), 1);

        // Fail loudly if the item is not a function; an `if let` here would
        // skip every assertion below and let the test pass vacuously.
        let Item::Function(func) = &program.items[0] else {
            panic!("first item should be a function")
        };
        assert_eq!(func.name, "transfer");
        assert_eq!(func.params.len(), 2);
        assert_eq!(func.params[0].name, "to");
        assert!(matches!(func.params[0].type_, Type::Address));
    }

    #[test]
    fn test_expression_parsing() {
        let source = "def test() -> uint256: return 42";

        let result = parse_from_source(source);
        assert!(result.is_ok(), "Should parse simple return statement");
    }

    #[test]
    fn parses_multiline_block_with_require() {
        let source = "def t() -> bool:\n require true\n return true\n";
        let program = parse_from_source(source).unwrap();
        assert_eq!(program.items.len(), 1);
        let Item::Function(f) = &program.items[0] else { panic!() };
        assert_eq!(f.body.statements.len(), 2);
        assert!(matches!(f.body.statements[0], Statement::Require(_)));
        assert!(matches!(f.body.statements[1], Statement::Return(_)));
    }

    #[test]
    fn parses_if_elif_else() {
        let source = "def t() -> uint256:\n if true: return 1\n elif false: return 2\n else: return 3\n";
        let program = parse_from_source(source).unwrap();
        let Item::Function(f) = &program.items[0] else { panic!() };
        // The whole if/elif/else chain is a single (nested) `If` statement.
        assert_eq!(f.body.statements.len(), 1);
        assert!(matches!(f.body.statements[0], Statement::If(_)));
    }

    #[test]
    fn parses_augmented_assignment() {
        let source = "def t() -> uint256:\n let mut x = 1\n x += 2\n return x\n";
        let program = parse_from_source(source).unwrap();
        let Item::Function(f) = &program.items[0] else { panic!() };
        assert_eq!(f.body.statements.len(), 3);
        assert!(matches!(f.body.statements[1], Statement::Assign(_)));
    }

    #[test]
    fn parses_const_item() {
        let source = "const total_supply: uint256 = 100\n\ndef t() -> uint256: return total_supply\n";
        let program = parse_from_source(source).unwrap();
        assert_eq!(program.items.len(), 2);
        assert!(matches!(program.items[0], Item::Const(_)));
    }

    #[test]
    fn parses_for_loop() {
        let source = "def t():\n for i in items:\n let x = i\n";
        let program = parse_from_source(source).unwrap();
        let Item::Function(f) = &program.items[0] else { panic!() };
        assert_eq!(f.body.statements.len(), 1);
        assert!(matches!(f.body.statements[0], Statement::For(_)));
    }

    #[test]
    fn parses_while_loop() {
        let source = "def t():\n while true:\n let x = 1\n";
        let program = parse_from_source(source).unwrap();
        let Item::Function(f) = &program.items[0] else { panic!() };
        assert_eq!(f.body.statements.len(), 1);
        assert!(matches!(f.body.statements[0], Statement::While(_)));
    }

    #[test]
    fn parses_event_declaration() {
        let source = "event Transfer(from: address, to: address, amount: uint256)\n\ndef t() -> bool: return true\n";
        let program = parse_from_source(source).unwrap();
        assert_eq!(program.items.len(), 2);
        assert!(matches!(program.items[0], Item::Event(_)));
    }

    #[test]
    fn parses_emit_statement() {
        let source = "def t():\n emit Transfer(a, b, c)\n";
        let program = parse_from_source(source).unwrap();
        let Item::Function(f) = &program.items[0] else { panic!() };
        assert_eq!(f.body.statements.len(), 1);
        assert!(matches!(f.body.statements[0], Statement::Emit(_)));
    }
}