diff --git a/src/nfa.rs b/src/nfa.rs index 17c1282..e15a1be 100644 --- a/src/nfa.rs +++ b/src/nfa.rs @@ -315,7 +315,7 @@ impl NfaBuilder { } fn group(mut self, start: StateId, end: StateId) -> Self { - self.capture_groups.push(CaptureGroup { start, end }); + self.capture_groups.insert(0, CaptureGroup { start, end }); self } @@ -358,19 +358,19 @@ mod tests { #[test] fn test_concatenation() { - let nfa = NfaBuilder::default() + let expected = NfaBuilder::default() .transition(0, TransitionKind::Character('h'), 1) .transition(1, TransitionKind::Epsilon, 2) .transition(2, TransitionKind::Character('i'), 3) .build(); - let expected = to_nfa("hi"); + let nfa = to_nfa("hi"); - assert_eq!(nfa, expected); + assert_eq!(expected, nfa); } #[test] fn test_alternation() { - let nfa = NfaBuilder::default() + let expected = NfaBuilder::default() .transition(0, TransitionKind::Epsilon, 1) .transition(0, TransitionKind::Epsilon, 3) .transition(1, TransitionKind::Character('a'), 2) @@ -378,50 +378,50 @@ mod tests { .transition(3, TransitionKind::Character('b'), 4) .transition(4, TransitionKind::Epsilon, 5) .build(); - let expected = to_nfa("a|b"); + let nfa = to_nfa("a|b"); - assert_eq!(nfa, expected); + assert_eq!(expected, nfa); } #[test] fn test_range_excat() { - let nfa = NfaBuilder::default() + let expected = NfaBuilder::default() .transition(0, TransitionKind::Character('e'), 1) .transition(1, TransitionKind::Epsilon, 2) .transition(2, TransitionKind::Character('e'), 3) .transition(3, TransitionKind::Epsilon, 4) .transition(4, TransitionKind::Character('e'), 5) .build(); - let expected = to_nfa("e{3}"); + let nfa = to_nfa("e{3}"); - assert_eq!(nfa, expected); + assert_eq!(expected, nfa); } #[test] fn test_range_between() { - let nfa = NfaBuilder::default() + let expected = NfaBuilder::default() .transition(0, TransitionKind::Character('e'), 1) .transition(1, TransitionKind::Epsilon, 2) .transition(2, TransitionKind::Character('e'), 3) .transition(2, TransitionKind::Epsilon, 3) .build(); - let expected = to_nfa("e{1,2}"); + let nfa = to_nfa("e{1,2}"); - assert_eq!(nfa, expected); + assert_eq!(expected, nfa); } #[test] fn test_range_minimum() { - let nfa = NfaBuilder::default() + let expected = NfaBuilder::default() .transition(0, TransitionKind::Character('e'), 1) .transition(1, TransitionKind::Epsilon, 2) .transition(2, TransitionKind::Character('e'), 3) .transition(3, TransitionKind::Epsilon, 1) .transition(3, TransitionKind::Epsilon, 4) .build(); - let expected = to_nfa("e{2,}"); + let nfa = to_nfa("e{2,}"); - assert_eq!(nfa, expected); + assert_eq!(expected, nfa); } #[test] @@ -454,4 +454,17 @@ mod tests { assert_eq!(eclosure, expected); } + + #[test] + fn test_capture_group_order() { + let nfa = to_nfa("a(b(c)(d))(e)"); + let expected = vec![ + CaptureGroup { start: 2, end: 7 }, + CaptureGroup { start: 4, end: 5 }, + CaptureGroup { start: 6, end: 7 }, + CaptureGroup { start: 8, end: 9 }, + ]; + + assert_eq!(nfa.capture_groups, expected); + } } diff --git a/src/regex.rs b/src/regex.rs index 4ec9c05..bee7d72 100644 --- a/src/regex.rs +++ b/src/regex.rs @@ -144,23 +144,14 @@ impl<'a> Regex { continue; } - captures.insert(0, (Some(start), end)); // full match - let captures = captures .into_iter() - .flat_map(|(index, (start, end))| { - start - .zip(end) - .map(|(start, end)| (index, Match::new(start, end, &input[start..end]))) - }) + .chain([(0, (Some(start), end))]) + .flat_map(|(index, (start, end))| Some(index).zip(self.new_mach(input, start, end))) .collect(); let named_captures = named_captures .into_iter() - .flat_map(|(name, (start, end))| { - start - .zip(end) - .map(|(start, end)| (name, Match::new(start, end, &input[start..end]))) - }) + .flat_map(|(name, (start, end))| Some(name).zip(self.new_mach(input, start, end))) .collect(); let capture = Capture { @@ -186,6 +177,17 @@ impl<'a> Regex { states.iter().any(|s| self.nfa.is_accepting(*s)) } + fn new_mach( + &self, + input: &'a str, + start: Option, + end: Option, + ) -> Option> { + start + .zip(end) + .map(|(start, end)| Match::new(start, end, &input[start..end])) + } + fn update_captures( &self, captures: &mut HashMap, Option)>, @@ -195,48 +197,66 @@ impl<'a> Regex { ) { for state in states { if let Some(groups) = self.start_capture.get(state) { - self.update_capture_starts(captures, named_captures, groups, position); + for group in groups { + self.update_capture_start(captures, named_captures, group, position); + } } if let Some(groups) = self.end_capture.get(state) { - self.update_capture_ends(captures, named_captures, groups, position); + for group in groups { + self.update_capture_end(captures, named_captures, group, position); + } } } } - fn update_capture_starts( + fn update_capture_start( &self, captures: &mut HashMap, Option)>, named_captures: &mut HashMap, Option)>, - groups: &[CaptureKind], + group: &CaptureKind, position: usize, ) { - for group in groups { - match group { - CaptureKind::Indexed(index) => { - captures.entry(*index + 1).or_default().0 = Some(position) - } - CaptureKind::Named(name) => { - named_captures.entry(name.to_owned()).or_default().0 = Some(position) - } + match group { + CaptureKind::Indexed(index) => { + captures + .entry(*index + 1) + .and_modify(|(start, end)| { + if end.is_none() { + *start = Some(position) + } + }) + .or_insert((Some(position), None)); + } + CaptureKind::Named(name) => { + named_captures + .entry(name.to_owned()) + .and_modify(|(start, end)| { + if end.is_none() { + *start = Some(position) + } + }) + .or_insert((Some(position), None)); } } } - fn update_capture_ends( + fn update_capture_end( &self, captures: &mut HashMap, Option)>, named_captures: &mut HashMap, Option)>, - groups: &[CaptureKind], + group: &CaptureKind, position: usize, ) { - for group in groups { - match group { - CaptureKind::Indexed(index) => { - captures.entry(*index + 1).or_default().1 = Some(position) - } - CaptureKind::Named(name) => { - named_captures.entry(name.to_owned()).or_default().1 = Some(position) - } + match group { + CaptureKind::Indexed(index) => { + captures + .entry(*index + 1) + .and_modify(|(_, end)| *end = Some(position)); + } + CaptureKind::Named(name) => { + named_captures + .entry(name.to_owned()) + .and_modify(|(_, end)| *end = Some(position)); } } } @@ -334,7 +354,12 @@ impl<'a> Match<'a> { #[cfg(test)] mod test { - use crate::regex::{Match, Regex}; + use std::collections::HashMap; + + use crate::{ + regex::{Match, Regex}, + Capture, + }; #[test] fn test_simple_match() { @@ -432,6 +457,24 @@ mod test { assert_eq!(matches.get_name("minute"), Some(&Match::new(3, 5, "30"))); } + // #[test] + // fn test_repeated_group() { + // let regex = Regex::new(r#"(hi)+(ah)+"#).unwrap(); + // let matches = regex.captures("hihiah").unwrap(); + // let expected = Capture { + // captures: vec![ + // (0, Match::new(0, 6, "hihiah")), + // (1, Match::new(2, 4, "hi")), + // (2, Match::new(4, 6, "ah")), + // ] + // .into_iter() + // .collect(), + // named_captures: HashMap::new(), + // }; + // + // assert_eq!(matches, expected); + // } + #[test] fn test_find() { let regex = Regex::new(r#"wh(at|o|y)"#).unwrap(); diff --git a/src/wasm.rs b/src/wasm.rs index e47bd9a..88858a4 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -74,7 +74,7 @@ impl RegexGroup { } #[wasm_bindgen] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct RegexCapture { groups: Vec, } @@ -125,32 +125,31 @@ fn get_char_index(input: &str) -> HashMap { .collect() } -// #[cfg(test)] -// mod tests { -// use super::{RegexEngine, RegexGroup}; -// -// #[test] -// fn test_unicode_range() { -// let regex = RegexEngine::new(r#"ここ"#); -// let matches = regex.find_all("ここでここで"); -// -// assert_eq!( -// matches, -// vec![ -// RegexGroup { start: 0, end: 2 }, -// RegexGroup { start: 3, end: 5 }, -// ] -// ); -// -// let regex = RegexEngine::new(r#"日本語"#); -// let matches = regex.find_all("これは日本語のテストです。日本語"); -// -// assert_eq!( -// matches, -// vec![ -// RegexGroup { start: 3, end: 6 }, -// RegexGroup { start: 13, end: 16 }, -// ] -// ); -// } -// } +#[cfg(test)] +mod tests { + use super::{RegexCapture, RegexEngine, RegexGroup}; + + #[test] + fn test_unicode_range() { + let regex = RegexEngine::new(r#"ここ"#); + let matches = regex.captures_all("ここでここで"); + let expected = vec![ + RegexCapture { + groups: vec![RegexGroup { + name: "0".to_string(), + start: 0, + end: 2, + }], + }, + RegexCapture { + groups: vec![RegexGroup { + name: "0".to_string(), + start: 3, + end: 5, + }], + }, + ]; + + assert_eq!(matches, expected); + } +} diff --git a/web/src/App.tsx b/web/src/App.tsx index 919cdc7..2d6d904 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -4,7 +4,7 @@ import Navbar from "./components/Navbar"; import ExpressionsPopup from "./components/ExpressionsPopup"; import { RiQuestionFill } from "react-icons/ri"; import { RegexEngine, RegexCapture } from "regex-potata"; -import { graphFromRegex } from "./utils/viz"; +import { graphFromRegex } from "./utils/graphiz"; import TestInput from "./components/TestInput"; import Footer from "./components/Footer"; import RegexInput from "./components/RegexInput"; diff --git a/web/src/components/ExpressionsPopup.tsx b/web/src/components/ExpressionsPopup.tsx index 74b63a6..c2d57aa 100644 --- a/web/src/components/ExpressionsPopup.tsx +++ b/web/src/components/ExpressionsPopup.tsx @@ -3,11 +3,11 @@ import { useRef } from "react"; import Snippet from "./Snippet"; const expressions = [ - { desc: "Basic regex", pat: ["foo", "(bar)", "foo|bar", "fo."] }, + { desc: "Basic regex", pat: ["foo", "(bar)", "foo|bar"] }, { desc: "Quantifiers", pat: ["+", "*", "?", "{x}", "{x,y}", "{x,}"] }, { desc: "Character class", - pat: ["a-z]", "[^x]", "\\d", "\\D", "\\w", "\\W", "\\s", "\\S"], + pat: [".", "[a-z]", "[^x]", "\\d", "\\D", "\\w", "\\W", "\\s", "\\S"], }, { desc: "Capture groups", pat: ["(foo)", "(:?bar)", "(?foo)"] }, ]; diff --git a/web/src/utils/extensions.ts b/web/src/utils/extensions.ts index 6d9425d..e1d89cd 100644 --- a/web/src/utils/extensions.ts +++ b/web/src/utils/extensions.ts @@ -20,7 +20,7 @@ function getMatchHighlight(captures: RegexCapture[]) { const decorationBuilder = new RangeSetBuilder(); for (let i = 0; i < captures.length; i++) { - const groups = captures[i].groups(); + const groups = captures[i].groups().sort((a, b) => a.start - b.start); decorationBuilder.add( groups[0].start, @@ -37,7 +37,7 @@ function getMatchHighlight(captures: RegexCapture[]) { groups[j].end, decoration( `color: #0f172a; background-color: ${ - palette[i + (j % palette.length)] + palette[(i + j) % palette.length] }` ) ); diff --git a/web/src/utils/viz.ts b/web/src/utils/graphiz.ts similarity index 100% rename from web/src/utils/viz.ts rename to web/src/utils/graphiz.ts