diff --git a/src/ast.rs b/src/ast.rs index 3927eae..1a4190f 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,4 +1,5 @@ mod arg; +mod case; mod const_expr; mod costume; mod enum_; @@ -24,6 +25,7 @@ mod value; mod var; pub use arg::*; +pub use case::*; pub use const_expr::*; pub use costume::*; pub use enum_::*; diff --git a/src/ast/case.rs b/src/ast/case.rs new file mode 100644 index 0000000..0a879b2 --- /dev/null +++ b/src/ast/case.rs @@ -0,0 +1,17 @@ +use logos::Span; +use serde::{ + Deserialize, + Serialize, +}; + +use crate::ast::{ + Expr, + Stmt, +}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Case { + pub value: Box, + pub body: Vec, + pub span: Span, +} diff --git a/src/ast/expr.rs b/src/ast/expr.rs index 61dd504..2ea12f0 100644 --- a/src/ast/expr.rs +++ b/src/ast/expr.rs @@ -130,3 +130,17 @@ impl From for Expr { } } } + +impl PartialEq for Expr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + // Simple expressions that can be compared + (Expr::Value { value: v1, .. }, Expr::Value { value: v2, .. }) => v1 == v2, + (Expr::Name(n1), Expr::Name(n2)) => n1 == n2, + (Expr::Arg(n1), Expr::Arg(n2)) => n1 == n2, + + // Complex expressions always return false + _ => false, + } + } +} diff --git a/src/ast/name.rs b/src/ast/name.rs index 25a83a7..922efd1 100644 --- a/src/ast/name.rs +++ b/src/ast/name.rs @@ -66,3 +66,26 @@ impl Name { } } } + +impl PartialEq for Name { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Name::Name { name: name1, .. }, Name::Name { name: name2, .. }) => name1 == name2, + ( + Name::DotName { + lhs: lhs1, + rhs: rhs1, + .. + }, + Name::DotName { + lhs: lhs2, + rhs: rhs2, + .. + }, + ) => lhs1 == lhs2 && rhs1 == rhs2, + _ => false, + } + } +} + +impl Eq for Name {} diff --git a/src/ast/stmt.rs b/src/ast/stmt.rs index a86a8b8..43df291 100644 --- a/src/ast/stmt.rs +++ b/src/ast/stmt.rs @@ -12,6 +12,7 @@ use super::{ Value, }; use crate::{ + ast::Case, blocks::{ BinOp, Block, @@ -19,7 +20,7 @@ use crate::{ misc::SmolStr, }; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub enum Stmt { Repeat { times: Box, @@ -92,6 +93,11 @@ pub enum Stmt { value: Box, visited: bool, }, + Switch { + value: Box, + cases: Vec, + span: Span, + }, } impl Stmt { @@ -114,6 +120,7 @@ impl Stmt { Stmt::ProcCall { span, .. } => span.clone(), Stmt::FuncCall { span, .. } => span.clone(), Stmt::Return { value, .. } => value.span(), + Stmt::Switch { span, .. } => span.clone(), } } diff --git a/src/ast/value.rs b/src/ast/value.rs index 39a8ded..44c48f8 100644 --- a/src/ast/value.rs +++ b/src/ast/value.rs @@ -192,3 +192,22 @@ pub enum ListIndex { Index(usize), All, } + +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Value::Boolean(a), Value::Boolean(b)) => a == b, + (Value::Number(a), Value::Number(b)) => { + // Handle NaN case - NaN != NaN in IEEE 754 + if a.is_nan() && b.is_nan() { + true + } else { + a == b + } + } + (Value::String(a), Value::String(b)) => a == b, + // Different variants are not equal + _ => false, + } + } +} diff --git a/src/codegen/sb3.rs b/src/codegen/sb3.rs index f8edbc3..4344571 100644 --- a/src/codegen/sb3.rs +++ b/src/codegen/sb3.rs @@ -300,6 +300,7 @@ impl Stmt { Stmt::ProcCall { .. } => "procedures_call", Stmt::FuncCall { .. } => "procedures_call", Stmt::Return { .. } => "data_setvariableto", + Stmt::Switch { .. } => "", } } } @@ -1302,6 +1303,7 @@ where T: Write + Seek args, ), Stmt::Return { .. } => panic!(), + Stmt::Switch { .. } => unreachable!(), } } diff --git a/src/diagnostic/diagnostic_kind.rs b/src/diagnostic/diagnostic_kind.rs index 90b5438..855239f 100644 --- a/src/diagnostic/diagnostic_kind.rs +++ b/src/diagnostic/diagnostic_kind.rs @@ -85,6 +85,7 @@ pub enum DiagnosticKind { field_name: SmolStr, }, EmptyStruct(SmolStr), + DuplicateSwitchCasePattern, // Warnings FollowedByUnreachableCode, UnrecognizedKey(SmolStr), @@ -221,6 +222,9 @@ impl DiagnosticKind { format!("struct {struct_name} is missing field {field_name}") } DiagnosticKind::EmptyStruct(name) => format!("struct {name} is empty"), + DiagnosticKind::DuplicateSwitchCasePattern => { + "duplicate switch case pattern".to_string() + } } } @@ -377,6 +381,7 @@ impl From<&DiagnosticKind> for Level { | DiagnosticKind::MissingField { .. } | DiagnosticKind::StructDoesNotHaveField { .. } | DiagnosticKind::EmptyStruct(_) + | DiagnosticKind::DuplicateSwitchCasePattern | DiagnosticKind::InvalidCostumeName(_) | DiagnosticKind::InvalidBackdropName(_) => Level::Error, diff --git a/src/lexer/token.rs b/src/lexer/token.rs index 7304b8a..17bb7bc 100644 --- a/src/lexer/token.rs +++ b/src/lexer/token.rs @@ -210,6 +210,10 @@ pub enum Token { Enum, #[token("struct")] Struct, + #[token("switch")] + Switch, + #[token("case")] + Case, #[token("true")] True, #[token("false")] @@ -342,6 +346,8 @@ impl Display for Token { Token::As => write!(f, "as"), Token::Enum => write!(f, "enum"), Token::Struct => write!(f, "struct"), + Token::Switch => write!(f, "switch"), + Token::Case => write!(f, "case"), Token::True => write!(f, "true"), Token::False => write!(f, "false"), Token::List => write!(f, "list"), diff --git a/src/parser/grammar.lalrpop b/src/parser/grammar.lalrpop index 9f9a18a..a6a866b 100644 --- a/src/parser/grammar.lalrpop +++ b/src/parser/grammar.lalrpop @@ -348,6 +348,19 @@ Stmt: Stmt = { SET_ROTATION_STYLE_LEFT_RIGHT ";" => Stmt::Block { block: Block::SetRotationStyleLeftRight, span: l..r, args: vec![], kwargs: Default::default() }, SET_ROTATION_STYLE_ALL_AROUND ";" => Stmt::Block { block: Block::SetRotationStyleAllAround, span: l..r, args: vec![], kwargs: Default::default() }, SET_ROTATION_STYLE_DO_NOT_ROTATE ";" => Stmt::Block { block: Block::SetRotationStyleDoNotRotate, span: l..r, args: vec![], kwargs: Default::default() }, + SWITCH "{" "}" => Stmt::Switch { + value, + cases, + span: l..r + } +} + +Case: Case = { + CASE => Case { + value, + body, + span: l..r + } } Elif: Stmt = { @@ -708,5 +721,7 @@ extern { SET_ROTATION_STYLE_DO_NOT_ROTATE => Token::SetRotationStyleDoNotRotate, SET_LAYER_ORDER => Token::SetLayerOrder, VAR => Token::Var, + SWITCH => Token::Switch, + CASE => Token::Case, } } diff --git a/src/visitor.rs b/src/visitor.rs index 9d827e7..672897d 100644 --- a/src/visitor.rs +++ b/src/visitor.rs @@ -3,4 +3,5 @@ pub mod pass1; pub mod pass2; pub mod pass3; pub mod pass4; +mod switchcase; mod transformations; diff --git a/src/visitor/pass1.rs b/src/visitor/pass1.rs index 760901e..372bd71 100644 --- a/src/visitor/pass1.rs +++ b/src/visitor/pass1.rs @@ -197,6 +197,17 @@ fn visit_stmt(stmt: &mut Stmt, s: &mut S) -> Vec { } } } + Stmt::Switch { + value, + cases, + span: _, + } => { + visit_expr(value, &mut before, s); + for case in cases { + visit_expr(&mut case.value, &mut before, s); + visit_stmts(&mut case.body, s); + } + } } if let Some(replace) = replace { *stmt = replace; diff --git a/src/visitor/pass2.rs b/src/visitor/pass2.rs index 468f681..5c159ba 100644 --- a/src/visitor/pass2.rs +++ b/src/visitor/pass2.rs @@ -18,6 +18,7 @@ use crate::{ SpriteDiagnostics, }, misc::SmolStr, + visitor::switchcase::switchcase, }; #[derive(Copy, Clone)] @@ -193,6 +194,7 @@ fn visit_stmts(stmts: &mut Vec, s: S, d: D, top_level: bool) { visit_stmt_return(value) } } + Stmt::Switch { value, cases, span } => Some(vec![switchcase(value, cases, span, d)]), _ => None, }; if let Some(replace) = replace { @@ -309,6 +311,17 @@ fn visit_stmt(stmt: &mut Stmt, s: S, d: D) { } } Stmt::Return { value, .. } => visit_expr(value, s, d), + Stmt::Switch { + value, + cases, + span: _, + } => { + visit_expr(value, s, d); + for case in cases { + visit_expr(&mut case.value, s, d); + visit_stmts(&mut case.body, s, d, false); + } + } } } diff --git a/src/visitor/pass3.rs b/src/visitor/pass3.rs index a436fc1..1d7108e 100644 --- a/src/visitor/pass3.rs +++ b/src/visitor/pass3.rs @@ -158,6 +158,17 @@ fn visit_stmt(stmt: &Stmt, s: &mut S) { value: _, visited: _, } => {} + Stmt::Switch { + value, + cases, + span: _, + } => { + visit_expr(value, s); + for case in cases { + visit_expr(&case.value, s); + visit_stmts(&case.body, s); + } + } } } diff --git a/src/visitor/switchcase.rs b/src/visitor/switchcase.rs new file mode 100644 index 0000000..0f1fb73 --- /dev/null +++ b/src/visitor/switchcase.rs @@ -0,0 +1,131 @@ +use logos::Span; + +use crate::{ + ast::{ + Case, + Expr, + Stmt, + Value, + }, + blocks::{ + BinOp, + UnOp, + }, + codegen::sb3::D, + diagnostic::DiagnosticKind, +}; + +fn casearm(value: &Expr, cases: &[Case], span: &Span, index: usize) -> Stmt { + Stmt::Branch { + cond: Box::new(BinOp::Eq.to_expr(span.clone(), value.clone(), *cases[index].value.clone())), + if_body: cases[index].body.clone(), + else_body: if index < cases.len() - 1 { + vec![casearm(value, cases, span, index + 1)] + } else { + vec![] + }, + } +} + +fn get_number(expr: &Expr) -> f64 { + if let Expr::Value { + value: Value::Number(number), + .. + } = expr + { + return *number; + } + unreachable!() +} + +fn searcharm( + value: &Expr, + cases: &[Case], + span: &Span, + nums: &[(usize, f64)], + low: usize, + high: usize, +) -> Stmt { + let mid = low + (1 + high - low) / 2; + Stmt::Branch { + cond: Box::new(BinOp::Lt.to_expr( + span.clone(), + value.clone(), + Value::from(nums[mid].1).to_expr(span.clone()), + )), + if_body: if mid - low == 1 { + cases[nums[low].0].body.clone() + } else { + vec![searcharm(value, cases, span, nums, low, mid - 1)] + }, + else_body: if mid == high { + cases[nums[mid].0].body.clone() + } else { + vec![searcharm(value, cases, span, nums, mid, high)] + }, + } +} + +fn searchtree(value: &Expr, cases: &[Case], span: &Span) -> Stmt { + let mut nums = cases + .iter() + .enumerate() + .map(|(i, case)| (i, get_number(&case.value))) + .collect::>(); + nums.sort_by(|a, b| a.1.total_cmp(&b.1)); + let low_minus_1 = nums[0].1 - 1.0; + let high_plus_1 = nums[nums.len() - 1].1 + 1.0; + Stmt::Branch { + cond: Box::new(BinOp::And.to_expr( + span.clone(), + BinOp::And.to_expr( + span.clone(), + BinOp::Lt.to_expr( + span.clone(), + Value::from(low_minus_1).to_expr(span.clone()), + value.clone(), + ), + BinOp::Lt.to_expr( + span.clone(), + value.clone(), + Value::from(high_plus_1).to_expr(span.clone()), + ), + ), + BinOp::Eq.to_expr( + span.clone(), + value.clone(), + UnOp::Round.to_expr(span.clone(), value.clone()), + ), + )), + if_body: vec![searcharm(value, cases, span, &nums, 0, nums.len() - 1)], + else_body: vec![], + } +} + +pub fn switchcase(value: &Expr, cases: &[Case], span: &Span, d: D) -> Stmt { + // Check for duplicate case patterns + let mut seen_cases: Vec<&Box> = Vec::new(); + for case in cases { + for seen_case in &seen_cases { + if case.value.as_ref() == seen_case.as_ref() { + d.report(DiagnosticKind::DuplicateSwitchCasePattern, &case.span); + break; + } + } + seen_cases.push(&case.value); + } + + let all_integers = cases.iter().all(|case| { + matches!( + *case.value, + Expr::Value {value: Value::Number(n),.. + } if n.fract() == 0.0 + ) + }); + // <25 doesn't benefit from search tree + if cases.len() > 25 && all_integers { + searchtree(value, cases, span) + } else { + casearm(value, cases, span, 0) + } +}