aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2025-04-21 16:50:40 -0400
committerBruce Hill <bruce@bruce-hill.com>2025-04-21 16:50:40 -0400
commitab55eee556ecbe6a8bd0a4f4cd92e38b021f6841 (patch)
tree47c339c48c9aaeffb931588c1b6241d35690c5a2 /src
parentf2eab0d205d1a60e9ce7a8e2420196e12d7eed10 (diff)
Add `assert`
Diffstat (limited to 'src')
-rw-r--r--src/ast.c1
-rw-r--r--src/ast.h5
-rw-r--r--src/compile.c72
-rw-r--r--src/parse.c24
-rw-r--r--src/typecheck.c10
5 files changed, 106 insertions, 6 deletions
diff --git a/src/ast.c b/src/ast.c
index 3380d927..bcf0a601 100644
--- a/src/ast.c
+++ b/src/ast.c
@@ -189,6 +189,7 @@ CORD ast_to_xml(ast_t *ast)
T(Optional, "<Optional>%r</Optional>", ast_to_xml(data.value))
T(NonOptional, "<NonOptional>%r</NonOptional>", ast_to_xml(data.value))
T(DocTest, "<DocTest>%r%r</DocTest>", optional_tagged("expression", data.expr), optional_tagged("expected", data.expected))
+ T(Assert, "<Assert>%r%r</Assert>", ast_to_xml(data.expr), optional_tagged("message", data.message))
T(Use, "<Use>%r%r</Use>", optional_tagged("var", data.var), xml_escape(data.path))
T(InlineCCode, "<InlineCode>%r</InlineCode>", ast_list_to_xml(data.chunks))
T(Deserialize, "<Deserialize><type>%r</type>%r</Deserialize>", type_ast_to_xml(data.type), ast_to_xml(data.value))
diff --git a/src/ast.h b/src/ast.h
index 2e1a035a..c628e38c 100644
--- a/src/ast.h
+++ b/src/ast.h
@@ -147,7 +147,7 @@ typedef enum {
Extern,
StructDef, EnumDef, LangDef,
Index, FieldAccess, Optional, NonOptional,
- DocTest,
+ DocTest, Assert,
Use,
InlineCCode,
Deserialize,
@@ -322,6 +322,9 @@ struct ast_s {
bool skip_source:1;
} DocTest;
struct {
+ ast_t *expr, *message;
+ } Assert;
+ struct {
ast_t *var;
const char *path;
enum { USE_LOCAL, USE_MODULE, USE_SHARED_OBJECT, USE_HEADER, USE_C_CODE, USE_ASM } what;
diff --git a/src/compile.c b/src/compile.c
index 02b23919..6b320b41 100644
--- a/src/compile.c
+++ b/src/compile.c
@@ -447,6 +447,11 @@ static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t
add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, DocTest)->expr);
break;
}
+ case Assert: {
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Assert)->expr);
+ add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Assert)->message);
+ break;
+ }
case Deserialize: {
add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, Deserialize)->value);
break;
@@ -1236,6 +1241,71 @@ static CORD _compile_statement(env_t *env, ast_t *ast)
(int64_t)(test->expr->end - test->expr->file->text));
}
}
+ case Assert: {
+ ast_t *expr = Match(ast, Assert)->expr;
+ ast_t *message = Match(ast, Assert)->message;
+ const char *failure = NULL;
+ switch (expr->tag) {
+ case And: {
+ auto and_ = Match(expr, And);
+ return CORD_all(
+ compile_statement(env, WrapAST(ast, Assert, .expr=and_->lhs, .message=message)),
+ compile_statement(env, WrapAST(ast, Assert, .expr=and_->rhs, .message=message)));
+ }
+ case Equals: failure = "!="; goto assert_comparison;
+ case NotEquals: failure = "=="; goto assert_comparison;
+ case LessThan: failure = ">="; goto assert_comparison;
+ case LessThanOrEquals: failure = ">"; goto assert_comparison;
+ case GreaterThan: failure = "<="; goto assert_comparison;
+ case GreaterThanOrEquals: failure = "<"; goto assert_comparison; {
+ assert_comparison:
+ binary_operands_t cmp = BINARY_OPERANDS(expr);
+ type_t *lhs_t = get_type(env, cmp.lhs);
+ type_t *rhs_t = get_type(env, cmp.rhs);
+ type_t *operand_t;
+ if (cmp.lhs->tag == Int && is_numeric_type(rhs_t)) {
+ operand_t = rhs_t;
+ } else if (cmp.rhs->tag == Int && is_numeric_type(lhs_t)) {
+ operand_t = lhs_t;
+ } else if (can_compile_to_type(env, cmp.rhs, lhs_t)) {
+ operand_t = lhs_t;
+ } else if (can_compile_to_type(env, cmp.lhs, rhs_t)) {
+ operand_t = rhs_t;
+ } else {
+ code_err(ast, "I can't do comparisons between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t));
+ }
+
+ ast_t *lhs_var = FakeAST(InlineCCode, .chunks=new(ast_list_t, .ast=FakeAST(TextLiteral, "_lhs")), .type=operand_t);
+ ast_t *rhs_var = FakeAST(InlineCCode, .chunks=new(ast_list_t, .ast=FakeAST(TextLiteral, "_rhs")), .type=operand_t);
+ ast_t *var_comparison = new(ast_t, .file=expr->file, .start=expr->start, .end=expr->end, .tag=expr->tag,
+ .__data.Equals={.lhs=lhs_var, .rhs=rhs_var});
+ return CORD_all("{ // assertion\n",
+ compile_declaration(operand_t, "_lhs"), " = ", compile_to_type(env, cmp.lhs, operand_t), ";\n",
+ compile_declaration(operand_t, "_rhs"), " = ", compile_to_type(env, cmp.rhs, operand_t), ";\n",
+ "if (!(", compile_condition(env, var_comparison), "))\n",
+ CORD_asprintf("fail_source(%r, %ld, %ld, %r, \" (\", %r, \" %s \", %r, \")\");\n",
+ CORD_quoted(ast->file->filename),
+ (long)(expr->start - expr->file->text),
+ (long)(expr->end - expr->file->text),
+ message ? CORD_all("Text$as_c_string(", compile_to_type(env, message, Type(TextType)), ")")
+ : "\"This assertion failed!\"",
+ expr_as_text("_lhs", operand_t, "no"),
+ failure,
+ expr_as_text("_rhs", operand_t, "no")),
+ "}\n");
+
+ }
+ default: {
+ return CORD_all("if (!(", compile_condition(env, expr), "))\n",
+ CORD_asprintf("fail_source(%r, %ld, %ld, %r);\n",
+ CORD_quoted(ast->file->filename),
+ (long)(expr->start - expr->file->text),
+ (long)(expr->end - expr->file->text),
+ message ? CORD_all("Text$as_c_string(", compile_to_type(env, message, Type(TextType)), ")")
+ : "\"This assertion failed!\""));
+ }
+ }
+ }
case Declare: {
DeclareMatch(decl, ast, Declare);
const char *name = Match(decl->var, Var)->name;
@@ -3867,7 +3937,7 @@ CORD compile(env_t *env, ast_t *ast)
case Extern: code_err(ast, "Externs are not supported as expressions");
case TableEntry: code_err(ast, "Table entries should not be compiled directly");
case Declare: case Assign: case UPDATE_CASES: case For: case While: case Repeat: case StructDef: case LangDef: case Extend:
- case EnumDef: case FunctionDef: case ConvertDef: case Skip: case Stop: case Pass: case Return: case DocTest:
+ case EnumDef: case FunctionDef: case ConvertDef: case Skip: case Stop: case Pass: case Return: case DocTest: case Assert:
code_err(ast, "This is not a valid expression");
default: case Unknown: code_err(ast, "Unknown AST: ", ast_to_xml_str(ast));
}
diff --git a/src/parse.c b/src/parse.c
index 4eae21e5..b5ca090a 100644
--- a/src/parse.c
+++ b/src/parse.c
@@ -62,8 +62,8 @@ int op_tightness[] = {
};
static const char *keywords[] = {
- "C_code", "_max_", "_min_", "and", "break", "continue", "defer", "deserialize", "do", "else", "enum",
- "extend", "extern", "for", "func", "if", "in", "lang", "mod", "mod1", "no", "none",
+ "C_code", "_max_", "_min_", "and", "assert", "break", "continue", "defer", "deserialize", "do", "else",
+ "enum", "extend", "extern", "for", "func", "if", "in", "lang", "mod", "mod1", "no", "none",
"not", "or", "pass", "return", "skip", "skip", "stop", "struct", "then", "unless", "use", "when",
"while", "xor", "yes",
};
@@ -107,6 +107,7 @@ static PARSER(parse_declaration);
static PARSER(parse_defer);
static PARSER(parse_do);
static PARSER(parse_doctest);
+static PARSER(parse_assert);
static PARSER(parse_enum_def);
static PARSER(parse_expr);
static PARSER(parse_extended_expr);
@@ -1775,7 +1776,8 @@ PARSER(parse_assignment) {
PARSER(parse_statement) {
ast_t *stmt = NULL;
if ((stmt=parse_declaration(ctx, pos))
- || (stmt=parse_doctest(ctx, pos)))
+ || (stmt=parse_doctest(ctx, pos))
+ || (stmt=parse_assert(ctx, pos)))
return stmt;
if (!(false
@@ -2311,6 +2313,22 @@ PARSER(parse_doctest) {
return NewAST(ctx->file, start, pos, DocTest, .expr=expr, .expected=expected);
}
+PARSER(parse_assert) {
+ const char *start = pos;
+ if (!match_word(&pos, "assert")) return NULL;
+ spaces(&pos);
+ ast_t *expr = expect(ctx, start, &pos, parse_extended_expr, "I couldn't parse the expression for this assert");
+ spaces(&pos);
+ ast_t *message = NULL;
+ if (match(&pos, ",")) {
+ whitespace(&pos);
+ message = expect(ctx, start, &pos, parse_extended_expr, "I couldn't parse the error message for this assert");
+ } else {
+ pos = expr->end;
+ }
+ return NewAST(ctx->file, start, pos, Assert, .expr=expr, .message=message);
+}
+
PARSER(parse_use) {
const char *start = pos;
diff --git a/src/typecheck.c b/src/typecheck.c
index 7f7cb438..7f901010 100644
--- a/src/typecheck.c
+++ b/src/typecheck.c
@@ -223,6 +223,10 @@ void prebind_statement(env_t *env, ast_t *statement)
prebind_statement(env, Match(statement, DocTest)->expr);
break;
}
+ case Assert: {
+ prebind_statement(env, Match(statement, Assert)->expr);
+ break;
+ }
case StructDef: {
DeclareMatch(def, statement, StructDef);
if (get_binding(env, def->name))
@@ -298,6 +302,10 @@ void bind_statement(env_t *env, ast_t *statement)
bind_statement(env, Match(statement, DocTest)->expr);
break;
}
+ case Assert: {
+ bind_statement(env, Match(statement, Assert)->expr);
+ break;
+ }
case Declare: {
DeclareMatch(decl, statement, Declare);
const char *name = Match(decl->var, Var)->name;
@@ -998,7 +1006,7 @@ type_t *get_type(env_t *env, ast_t *ast)
case Extern: {
return parse_type_ast(env, Match(ast, Extern)->type);
}
- case Declare: case Assign: case UPDATE_CASES: case DocTest: {
+ case Declare: case Assign: case UPDATE_CASES: case DocTest: case Assert: {
return Type(VoidType);
}
case Use: {