diff --git a/Lua/lbp.c b/Lua/lbp.c index 2b6d62f..5e94890 100644 --- a/Lua/lbp.c +++ b/Lua/lbp.c @@ -31,6 +31,16 @@ static pat_t *builtins; static void push_match(lua_State *L, match_t *m, const char *start); +lua_State *cur_state = NULL; + +static void match_error(pat_t *pat, const char *msg) +{ + (void)pat; + recycle_all_matches(); + lua_pushstring(cur_state, msg); + lua_error(cur_state); +} + static inline void raise_parse_error(lua_State *L, maybe_pat_t m) { size_t err_len = (size_t)(m.value.error.end - m.value.error.start); @@ -153,7 +163,8 @@ static int Lmatch(lua_State *L) match_t *m = NULL; int ret = 0; - if (next_match(&m, text+index-1, &text[textlen], pat, builtins, NULL, false)) { + cur_state = L; + if (next_match_safe(&m, text+index-1, &text[textlen], pat, builtins, NULL, false, match_error)) { push_match(L, m, text); stop_matching(&m); ret = 1; @@ -191,7 +202,8 @@ static int Lreplace(lua_State *L) int replacements = 0; const char *prev = text; pat_t *rep_pat = maybe_replacement.value.pat; - for (match_t *m = NULL; next_match(&m, text, &text[textlen], rep_pat, builtins, NULL, false); ) { + cur_state = L; + for (match_t *m = NULL; next_match_safe(&m, text, &text[textlen], rep_pat, builtins, NULL, false, match_error); ) { fwrite(prev, sizeof(char), (size_t)(m->start - prev), out); fprint_match(out, text, m, NULL); prev = m->end; diff --git a/Lua/test.lua b/Lua/test.lua index 731b044..d812009 100644 --- a/Lua/test.lua +++ b/Lua/test.lua @@ -54,3 +54,10 @@ print(pat:replace("{@0}", "...baz...")) for m in pat:matches("hello world") do print(m) end + + +local ok, err = pcall(function() + bp.match("nonexistent", "xxx") +end) +assert(not ok) +print("(Successfully caught pattern matching error)") diff --git a/match.c b/match.c index 86ae2bd..cb7918c 100644 --- a/match.c +++ b/match.c @@ -49,6 +49,13 @@ typedef struct match_ctx_s { static match_t *unused_matches = NULL; static match_t *in_use_matches = NULL; +static void default_error_handler(pat_t *pat, const char *msg) { + (void)pat; + errx(EXIT_FAILURE, "%s", msg); +} + +static bp_errhand_t error_handler = default_error_handler; + #define MATCHES(...) (match_t*[]){__VA_ARGS__, NULL} __attribute__((hot, nonnull(1,2,3))) @@ -56,6 +63,18 @@ static match_t *match(match_ctx_t *ctx, const char *str, pat_t *pat); __attribute__((returns_nonnull)) static match_t *new_match(pat_t *pat, const char *start, const char *end, match_t *children[]); +__attribute__((format(printf,2,3))) +static inline void match_error(pat_t *pat, const char *fmt, ...) +{ + va_list args; + va_start(args, fmt); + char buf[256]; + vsnprintf(buf, sizeof(buf)-1, fmt, args); + va_end(args); + if (error_handler) + error_handler(pat, buf); +} + static match_t *clone_match(match_t *m) { if (!m) return NULL; @@ -207,7 +226,8 @@ static pat_t *_lookup_def(pat_t *defs, const char *name, size_t namelen) return defs->args.def.meaning; defs = defs->args.def.next_def; } else { - errx(1, "Invalid pattern type in definitions"); + match_error(defs, "Invalid pattern type in definitions"); + return NULL; } } return NULL; @@ -646,8 +666,10 @@ static match_t *match(match_ctx_t *ctx, const char *str, pat_t *pat) return NULL; pat_t *ref = lookup_ctx(ctx, pat->args.ref.name, pat->args.ref.len); - if (ref == NULL) - errx(EXIT_FAILURE, "Unknown identifier: '%.*s'", (int)pat->args.ref.len, pat->args.ref.name); + if (ref == NULL) { + match_error(pat, "Unknown pattern: '%.*s'", (int)pat->args.ref.len, pat->args.ref.name); + return NULL; + } if (ref->type == BP_LEFTRECURSION) return match(ctx, str, ref); @@ -732,7 +754,7 @@ static match_t *match(match_ctx_t *ctx, const char *str, pat_t *pat) return new_match(pat, str, str, NULL); } default: { - errx(EXIT_FAILURE, "Unknown pattern type: %u", pat->type); + match_error(pat, "Unknown pattern type: %u", pat->type); return NULL; } } @@ -830,7 +852,10 @@ bool next_match(match_t **m, const char *start, const char *end, pat_t *pat, pat pos = start; } - if (!pat) return false; + if (!pat) { + error_handler = default_error_handler; + return false; + } match_ctx_t ctx = { .cache = &(cache_t){0}, @@ -844,6 +869,17 @@ bool next_match(match_t **m, const char *start, const char *end, pat_t *pat, pat return *m != NULL; } +// +// Wrapper for next_match() that sets an error handler +// +bool next_match_safe(match_t **m, const char *start, const char *end, pat_t *pat, pat_t *defs, pat_t *skip, bool ignorecase, bp_errhand_t errhand) +{ + error_handler = errhand; + bool ret = next_match(m, start, end, pat, defs, skip, ignorecase); + error_handler = default_error_handler; + return ret; +} + // // Helper function to track state while doing a depth-first search. // diff --git a/match.h b/match.h index 6c875c4..64ae84c 100644 --- a/match.h +++ b/match.h @@ -25,12 +25,15 @@ typedef struct match_s { struct match_s *_children[3]; } match_t; +typedef void (*bp_errhand_t)(pat_t *pat, const char *err_msg); + __attribute__((nonnull)) void recycle_match(match_t **at_m); size_t free_all_matches(void); size_t recycle_all_matches(void); bool next_match(match_t **m, const char *start, const char *end, pat_t *pat, pat_t *defs, pat_t *skip, bool ignorecase); #define stop_matching(m) next_match(m, NULL, NULL, NULL, NULL, NULL, 0) +bool next_match_safe(match_t **m, const char *start, const char *end, pat_t *pat, pat_t *defs, pat_t *skip, bool ignorecase, bp_errhand_t errhand); __attribute__((nonnull)) match_t *get_numbered_capture(match_t *m, int n); __attribute__((nonnull, pure))