From ab55eee556ecbe6a8bd0a4f4cd92e38b021f6841 Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Mon, 21 Apr 2025 16:50:40 -0400 Subject: Add `assert` --- src/ast.c | 1 + src/ast.h | 5 +++- src/compile.c | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++- src/parse.c | 24 ++++++++++++++++--- src/typecheck.c | 10 +++++++- 5 files changed, 106 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/ast.c b/src/ast.c index 3380d927..bcf0a601 100644 --- a/src/ast.c +++ b/src/ast.c @@ -189,6 +189,7 @@ CORD ast_to_xml(ast_t *ast) T(Optional, "%r", ast_to_xml(data.value)) T(NonOptional, "%r", ast_to_xml(data.value)) T(DocTest, "%r%r", optional_tagged("expression", data.expr), optional_tagged("expected", data.expected)) + T(Assert, "%r%r", ast_to_xml(data.expr), optional_tagged("message", data.message)) T(Use, "%r%r", optional_tagged("var", data.var), xml_escape(data.path)) T(InlineCCode, "%r", ast_list_to_xml(data.chunks)) T(Deserialize, "%r%r", type_ast_to_xml(data.type), ast_to_xml(data.value)) diff --git a/src/ast.h b/src/ast.h index 2e1a035a..c628e38c 100644 --- a/src/ast.h +++ b/src/ast.h @@ -147,7 +147,7 @@ typedef enum { Extern, StructDef, EnumDef, LangDef, Index, FieldAccess, Optional, NonOptional, - DocTest, + DocTest, Assert, Use, InlineCCode, Deserialize, @@ -321,6 +321,9 @@ struct ast_s { ast_t *expr, *expected; bool skip_source:1; } DocTest; + struct { + ast_t *expr, *message; + } Assert; struct { ast_t *var; const char *path; diff --git a/src/compile.c b/src/compile.c index 02b23919..6b320b41 100644 --- a/src/compile.c +++ b/src/compile.c @@ -447,6 +447,11 @@ static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, DocTest)->expr); break; } + case Assert: { + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Assert)->expr); + add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Assert)->message); + break; + } case Deserialize: { add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Deserialize)->value); break; @@ -1236,6 +1241,71 @@ static CORD _compile_statement(env_t *env, ast_t *ast) (int64_t)(test->expr->end - test->expr->file->text)); } } + case Assert: { + ast_t *expr = Match(ast, Assert)->expr; + ast_t *message = Match(ast, Assert)->message; + const char *failure = NULL; + switch (expr->tag) { + case And: { + auto and_ = Match(expr, And); + return CORD_all( + compile_statement(env, WrapAST(ast, Assert, .expr=and_->lhs, .message=message)), + compile_statement(env, WrapAST(ast, Assert, .expr=and_->rhs, .message=message))); + } + case Equals: failure = "!="; goto assert_comparison; + case NotEquals: failure = "=="; goto assert_comparison; + case LessThan: failure = ">="; goto assert_comparison; + case LessThanOrEquals: failure = ">"; goto assert_comparison; + case GreaterThan: failure = "<="; goto assert_comparison; + case GreaterThanOrEquals: failure = "<"; goto assert_comparison; { + assert_comparison: + binary_operands_t cmp = BINARY_OPERANDS(expr); + type_t *lhs_t = get_type(env, cmp.lhs); + type_t *rhs_t = get_type(env, cmp.rhs); + type_t *operand_t; + if (cmp.lhs->tag == Int && is_numeric_type(rhs_t)) { + operand_t = rhs_t; + } else if (cmp.rhs->tag == Int && is_numeric_type(lhs_t)) { + operand_t = lhs_t; + } else if (can_compile_to_type(env, cmp.rhs, lhs_t)) { + operand_t = lhs_t; + } else if (can_compile_to_type(env, cmp.lhs, rhs_t)) { + operand_t = rhs_t; + } else { + code_err(ast, "I can't do comparisons between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + + ast_t *lhs_var = FakeAST(InlineCCode, .chunks=new(ast_list_t, .ast=FakeAST(TextLiteral, "_lhs")), .type=operand_t); + ast_t *rhs_var = FakeAST(InlineCCode, .chunks=new(ast_list_t, .ast=FakeAST(TextLiteral, "_rhs")), .type=operand_t); + ast_t *var_comparison = new(ast_t, .file=expr->file, .start=expr->start, .end=expr->end, .tag=expr->tag, + .__data.Equals={.lhs=lhs_var, .rhs=rhs_var}); + return CORD_all("{ // assertion\n", + compile_declaration(operand_t, "_lhs"), " = ", compile_to_type(env, cmp.lhs, operand_t), ";\n", + compile_declaration(operand_t, "_rhs"), " = ", compile_to_type(env, cmp.rhs, operand_t), ";\n", + "if (!(", compile_condition(env, var_comparison), "))\n", + CORD_asprintf("fail_source(%r, %ld, %ld, %r, \" (\", %r, \" %s \", %r, \")\");\n", + CORD_quoted(ast->file->filename), + (long)(expr->start - expr->file->text), + (long)(expr->end - expr->file->text), + message ? CORD_all("Text$as_c_string(", compile_to_type(env, message, Type(TextType)), ")") + : "\"This assertion failed!\"", + expr_as_text("_lhs", operand_t, "no"), + failure, + expr_as_text("_rhs", operand_t, "no")), + "}\n"); + + } + default: { + return CORD_all("if (!(", compile_condition(env, expr), "))\n", + CORD_asprintf("fail_source(%r, %ld, %ld, %r);\n", + CORD_quoted(ast->file->filename), + (long)(expr->start - expr->file->text), + (long)(expr->end - expr->file->text), + message ? CORD_all("Text$as_c_string(", compile_to_type(env, message, Type(TextType)), ")") + : "\"This assertion failed!\"")); + } + } + } case Declare: { DeclareMatch(decl, ast, Declare); const char *name = Match(decl->var, Var)->name; @@ -3867,7 +3937,7 @@ CORD compile(env_t *env, ast_t *ast) case Extern: code_err(ast, "Externs are not supported as expressions"); case TableEntry: code_err(ast, "Table entries should not be compiled directly"); case Declare: case Assign: case UPDATE_CASES: case For: case While: case Repeat: case StructDef: case LangDef: case Extend: - case EnumDef: case FunctionDef: case ConvertDef: case Skip: case Stop: case Pass: case Return: case DocTest: + case EnumDef: case FunctionDef: case ConvertDef: case Skip: case Stop: case Pass: case Return: case DocTest: case Assert: code_err(ast, "This is not a valid expression"); default: case Unknown: code_err(ast, "Unknown AST: ", ast_to_xml_str(ast)); } diff --git a/src/parse.c b/src/parse.c index 4eae21e5..b5ca090a 100644 --- a/src/parse.c +++ b/src/parse.c @@ -62,8 +62,8 @@ int op_tightness[] = { }; static const char *keywords[] = { - "C_code", "_max_", "_min_", "and", "break", "continue", "defer", "deserialize", "do", "else", "enum", - "extend", "extern", "for", "func", "if", "in", "lang", "mod", "mod1", "no", "none", + "C_code", "_max_", "_min_", "and", "assert", "break", "continue", "defer", "deserialize", "do", "else", + "enum", "extend", "extern", "for", "func", "if", "in", "lang", "mod", "mod1", "no", "none", "not", "or", "pass", "return", "skip", "skip", "stop", "struct", "then", "unless", "use", "when", "while", "xor", "yes", }; @@ -107,6 +107,7 @@ static PARSER(parse_declaration); static PARSER(parse_defer); static PARSER(parse_do); static PARSER(parse_doctest); +static PARSER(parse_assert); static PARSER(parse_enum_def); static PARSER(parse_expr); static PARSER(parse_extended_expr); @@ -1775,7 +1776,8 @@ PARSER(parse_assignment) { PARSER(parse_statement) { ast_t *stmt = NULL; if ((stmt=parse_declaration(ctx, pos)) - || (stmt=parse_doctest(ctx, pos))) + || (stmt=parse_doctest(ctx, pos)) + || (stmt=parse_assert(ctx, pos))) return stmt; if (!(false @@ -2311,6 +2313,22 @@ PARSER(parse_doctest) { return NewAST(ctx->file, start, pos, DocTest, .expr=expr, .expected=expected); } +PARSER(parse_assert) { + const char *start = pos; + if (!match_word(&pos, "assert")) return NULL; + spaces(&pos); + ast_t *expr = expect(ctx, start, &pos, parse_extended_expr, "I couldn't parse the expression for this assert"); + spaces(&pos); + ast_t *message = NULL; + if (match(&pos, ",")) { + whitespace(&pos); + message = expect(ctx, start, &pos, parse_extended_expr, "I couldn't parse the error message for this assert"); + } else { + pos = expr->end; + } + return NewAST(ctx->file, start, pos, Assert, .expr=expr, .message=message); +} + PARSER(parse_use) { const char *start = pos; diff --git a/src/typecheck.c b/src/typecheck.c index 7f7cb438..7f901010 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -223,6 +223,10 @@ void prebind_statement(env_t *env, ast_t *statement) prebind_statement(env, Match(statement, DocTest)->expr); break; } + case Assert: { + prebind_statement(env, Match(statement, Assert)->expr); + break; + } case StructDef: { DeclareMatch(def, statement, StructDef); if (get_binding(env, def->name)) @@ -298,6 +302,10 @@ void bind_statement(env_t *env, ast_t *statement) bind_statement(env, Match(statement, DocTest)->expr); break; } + case Assert: { + bind_statement(env, Match(statement, Assert)->expr); + break; + } case Declare: { DeclareMatch(decl, statement, Declare); const char *name = Match(decl->var, Var)->name; @@ -998,7 +1006,7 @@ type_t *get_type(env_t *env, ast_t *ast) case Extern: { return parse_type_ast(env, Match(ast, Extern)->type); } - case Declare: case Assign: case UPDATE_CASES: case DocTest: { + case Declare: case Assign: case UPDATE_CASES: case DocTest: case Assert: { return Type(VoidType); } case Use: { -- cgit v1.2.3