From 0b8074154e2671691050bdb3bcb33245625a056c Mon Sep 17 00:00:00 2001 From: Bruce Hill Date: Fri, 4 Apr 2025 17:06:09 -0400 Subject: First working compile of refactor to add explicit typing to declarations and support untyped empty collections and `none`s --- src/ast.c | 119 +++-- src/ast.h | 55 ++- src/compile.c | 1271 ++++++++++++++++++++++++----------------------------- src/environment.c | 28 +- src/environment.h | 1 + src/parse.c | 180 +++----- src/repl.c | 120 ++--- src/typecheck.c | 539 ++++++++++++----------- src/typecheck.h | 1 + src/types.c | 109 +++++ src/types.h | 2 + 11 files changed, 1191 insertions(+), 1234 deletions(-) (limited to 'src') diff --git a/src/ast.c b/src/ast.c index 84d25db0..67b54f9a 100644 --- a/src/ast.c +++ b/src/ast.c @@ -10,23 +10,47 @@ #include "stdlib/text.h" #include "cordhelpers.h" -static const char *OP_NAMES[] = { - [BINOP_UNKNOWN]="unknown", - [BINOP_POWER]="^", [BINOP_MULT]="*", [BINOP_DIVIDE]="/", - [BINOP_MOD]="mod", [BINOP_MOD1]="mod1", [BINOP_PLUS]="+", [BINOP_MINUS]="minus", - [BINOP_CONCAT]="++", [BINOP_LSHIFT]="<<", [BINOP_ULSHIFT]="<<<", - [BINOP_RSHIFT]=">>", [BINOP_URSHIFT]=">>>", [BINOP_MIN]="min", - [BINOP_MAX]="max", [BINOP_EQ]="==", [BINOP_NE]="!=", [BINOP_LT]="<", - [BINOP_LE]="<=", [BINOP_GT]=">", [BINOP_GE]=">=", [BINOP_CMP]="<>", - [BINOP_AND]="and", [BINOP_OR]="or", [BINOP_XOR]="xor", +CONSTFUNC const char *binop_method_name(ast_e tag) { + switch (tag) { + case Power: case PowerUpdate: return "power"; + case Multiply: case MultiplyUpdate: return "times"; + case Divide: case DivideUpdate: return "divided_by"; + case Mod: case ModUpdate: return "modulo"; + case Mod1: case Mod1Update: return "modulo1"; + case Plus: case PlusUpdate: return "plus"; + case Minus: case MinusUpdate: return "minus"; + case Concat: case ConcatUpdate: return "concatenated_with"; + case LeftShift: case LeftShiftUpdate: return "left_shifted"; + case RightShift: case RightShiftUpdate: return "right_shifted"; + case UnsignedLeftShift: case UnsignedLeftShiftUpdate: return "unsigned_left_shifted"; + case UnsignedRightShift: case UnsignedRightShiftUpdate: return "unsigned_right_shifted"; + case And: case AndUpdate: return "bit_and"; + case Or: case OrUpdate: return "bit_or"; + case Xor: case XorUpdate: return "bit_xor"; + default: return NULL; + } }; -const char *binop_method_names[BINOP_XOR+1] = { - [BINOP_POWER]="power", [BINOP_MULT]="times", [BINOP_DIVIDE]="divided_by", - [BINOP_MOD]="modulo", [BINOP_MOD1]="modulo1", [BINOP_PLUS]="plus", [BINOP_MINUS]="minus", - [BINOP_CONCAT]="concatenated_with", [BINOP_LSHIFT]="left_shifted", [BINOP_RSHIFT]="right_shifted", - [BINOP_ULSHIFT]="unsigned_left_shifted", [BINOP_URSHIFT]="unsigned_right_shifted", - [BINOP_AND]="bit_and", [BINOP_OR]="bit_or", [BINOP_XOR]="bit_xor", +CONSTFUNC const char *binop_operator(ast_e tag) { + switch (tag) { + case Multiply: case MultiplyUpdate: return "*"; + case Divide: case DivideUpdate: return "/"; + case Mod: case ModUpdate: return "%"; + case Plus: case PlusUpdate: return "+"; + case Minus: case MinusUpdate: return "-"; + case LeftShift: case LeftShiftUpdate: return "<<"; + case RightShift: case RightShiftUpdate: return ">>"; + case And: case AndUpdate: return "&"; + case Or: case OrUpdate: return "|"; + case Xor: case XorUpdate: return "^"; + case Equals: return "=="; + case NotEquals: return "!="; + case LessThan: return "<"; + case LessThanOrEquals: return "<="; + case GreaterThan: return ">"; + case GreaterThanOrEquals: return ">="; + default: return NULL; + } }; static CORD ast_list_to_xml(ast_list_t *asts); @@ -100,7 +124,7 @@ CORD ast_to_xml(ast_t *ast) switch (ast->tag) { #define T(type, ...) case type: { auto data = ast->__data.type; (void)data; return CORD_asprintf(__VA_ARGS__); } T(Unknown, "") - T(None, "%r", type_ast_to_xml(data.type)) + T(None, "") T(Bool, "", data.b ? "yes" : "no") T(Var, "%s", data.name) T(Int, "%s", data.str) @@ -108,22 +132,24 @@ CORD ast_to_xml(ast_t *ast) T(TextLiteral, "%r", xml_escape(data.cord)) T(TextJoin, "%r", data.lang ? CORD_all(" lang=\"", data.lang, "\"") : CORD_EMPTY, ast_list_to_xml(data.children)) T(Path, "%s", data.path) - T(Declare, "%r", ast_to_xml(data.var), ast_to_xml(data.value)) + T(Declare, "%r%r", ast_to_xml(data.var), type_ast_to_xml(data.type), ast_to_xml(data.value)) T(Assign, "%r%r", ast_list_to_xml(data.targets), ast_list_to_xml(data.values)) - T(BinaryOp, "%r %r", xml_escape(OP_NAMES[data.op]), ast_to_xml(data.lhs), ast_to_xml(data.rhs)) - T(UpdateAssign, "%r %r", xml_escape(OP_NAMES[data.op]), ast_to_xml(data.lhs), ast_to_xml(data.rhs)) +#define BINOP(name) T(name, "<" #name ">%r %r", data.lhs, data.rhs) + BINOP(Power) BINOP(PowerUpdate) BINOP(Multiply) BINOP(MultiplyUpdate) BINOP(Divide) BINOP(DivideUpdate) BINOP(Mod) BINOP(ModUpdate) + BINOP(Mod1) BINOP(Mod1Update) BINOP(Plus) BINOP(PlusUpdate) BINOP(Minus) BINOP(MinusUpdate) BINOP(Concat) BINOP(ConcatUpdate) + BINOP(LeftShift) BINOP(LeftShiftUpdate) BINOP(RightShift) BINOP(RightShiftUpdate) BINOP(UnsignedLeftShift) BINOP(UnsignedLeftShiftUpdate) + BINOP(UnsignedRightShift) BINOP(UnsignedRightShiftUpdate) BINOP(And) BINOP(AndUpdate) BINOP(Or) BINOP(OrUpdate) + BINOP(Xor) BINOP(XorUpdate) +#undef BINOP T(Negative, "%r", ast_to_xml(data.value)) T(Not, "%r", ast_to_xml(data.value)) T(HeapAllocate, "%r", ast_to_xml(data.value)) T(StackReference, "%r", ast_to_xml(data.value)) T(Min, "%r%r%r", ast_to_xml(data.lhs), ast_to_xml(data.rhs), optional_tagged("key", data.key)) T(Max, "%r%r%r", ast_to_xml(data.lhs), ast_to_xml(data.rhs), optional_tagged("key", data.key)) - T(Array, "%r%r", optional_tagged_type("item-type", data.item_type), ast_list_to_xml(data.items)) - T(Set, "%r%r", - optional_tagged_type("item-type", data.item_type), - ast_list_to_xml(data.items)) - T(Table, "%r%r%r%r
", - optional_tagged_type("key-type", data.key_type), optional_tagged_type("value-type", data.value_type), + T(Array, "%r", ast_list_to_xml(data.items)) + T(Set, "%r", ast_list_to_xml(data.items)) + T(Table, "%r%r
", optional_tagged("default-value", data.default_value), ast_list_to_xml(data.entries), optional_tagged("fallback", data.fallback)) T(TableEntry, "%r%r", ast_to_xml(data.key), ast_to_xml(data.value)) @@ -145,7 +171,7 @@ CORD ast_to_xml(ast_t *ast) T(Repeat, "%r", optional_tagged("body", data.body)) T(If, "%r%r%r", optional_tagged("condition", data.condition), optional_tagged("body", data.body), optional_tagged("else", data.else_body)) T(When, "%r%r%r", ast_to_xml(data.subject), when_clauses_to_xml(data.clauses), optional_tagged("else", data.else_body)) - T(Reduction, "%r", xml_escape(OP_NAMES[data.op]), optional_tagged("key", data.key), + T(Reduction, "%r", xml_escape(binop_method_name(data.op)), optional_tagged("key", data.key), optional_tagged("iterable", data.iter)) T(Skip, "%r", data.target) T(Stop, "%r", data.target) @@ -313,4 +339,45 @@ void visit_topologically(ast_list_t *asts, Closure_t fn) } } +CONSTFUNC bool is_binary_operation(ast_t *ast) +{ + switch (ast->tag) { + case BINOP_CASES: return true; + default: return false; + } +} + +CONSTFUNC bool is_update_assignment(ast_t *ast) +{ + switch (ast->tag) { + case PowerUpdate: case MultiplyUpdate: case DivideUpdate: case ModUpdate: case Mod1Update: + case PlusUpdate: case MinusUpdate: case ConcatUpdate: case LeftShiftUpdate: case UnsignedLeftShiftUpdate: + case RightShiftUpdate: case UnsignedRightShiftUpdate: case AndUpdate: case OrUpdate: case XorUpdate: + return true; + default: return false; + } +} + +CONSTFUNC ast_e binop_tag(ast_e tag) +{ + switch (tag) { + case PowerUpdate: return Power; + case MultiplyUpdate: return Multiply; + case DivideUpdate: return Divide; + case ModUpdate: return Mod; + case Mod1Update: return Mod1; + case PlusUpdate: return Plus; + case MinusUpdate: return Minus; + case ConcatUpdate: return Concat; + case LeftShiftUpdate: return LeftShift; + case UnsignedLeftShiftUpdate: return UnsignedLeftShift; + case RightShiftUpdate: return RightShift; + case UnsignedRightShiftUpdate: return UnsignedRightShift; + case AndUpdate: return And; + case OrUpdate: return Or; + case XorUpdate: return Xor; + default: return Unknown; + } +} + // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/src/ast.h b/src/ast.h index 766b484b..4f370e11 100644 --- a/src/ast.h +++ b/src/ast.h @@ -20,6 +20,7 @@ #define WrapAST(ast, ast_tag, ...) (new(ast_t, .file=(ast)->file, .start=(ast)->start, .end=(ast)->end, .tag=ast_tag, .__data.ast_tag={__VA_ARGS__})) #define TextAST(ast, _str) WrapAST(ast, TextLiteral, .str=GC_strdup(_str)) #define Match(x, _tag) ((x)->tag == _tag ? &(x)->__data._tag : (errx(1, __FILE__ ":%d This was supposed to be a " # _tag "\n", __LINE__), &(x)->__data._tag)) +#define BINARY_OPERANDS(ast) ({ if (!is_binary_operation(ast)) errx(1, __FILE__ ":%d This is not a binary operation!", __LINE__); (ast)->__data.Plus; }) #define REVERSE_LIST(list) do { \ __typeof(list) _prev = NULL; \ @@ -37,6 +38,9 @@ struct binding_s; typedef struct type_ast_s type_ast_t; typedef struct ast_s ast_t; +typedef struct { + ast_t *lhs, *rhs; +} binary_operands_t; typedef struct ast_list_s { ast_t *ast; @@ -55,17 +59,6 @@ typedef struct when_clause_s { struct when_clause_s *next; } when_clause_t; -typedef enum { - BINOP_UNKNOWN, - BINOP_POWER=100, BINOP_MULT, BINOP_DIVIDE, BINOP_MOD, BINOP_MOD1, BINOP_PLUS, - BINOP_MINUS, BINOP_CONCAT, BINOP_LSHIFT, BINOP_ULSHIFT, BINOP_RSHIFT, BINOP_URSHIFT, BINOP_MIN, - BINOP_MAX, BINOP_EQ, BINOP_NE, BINOP_LT, BINOP_LE, BINOP_GT, BINOP_GE, - BINOP_CMP, - BINOP_AND, BINOP_OR, BINOP_XOR, -} binop_e; - -extern const char *binop_method_names[BINOP_XOR+1]; - typedef enum { UnknownTypeAST, VarTypeAST, @@ -117,6 +110,15 @@ struct type_ast_s { } __data; }; +#define BINOP_CASES Power: case Multiply: case Divide: case Mod: case Mod1: case Plus: case Minus: case Concat: case LeftShift: case UnsignedLeftShift: \ + case RightShift: case UnsignedRightShift: case Equals: case NotEquals: case LessThan: case LessThanOrEquals: case GreaterThan: \ + case GreaterThanOrEquals: case Compare: case And: case Or: case Xor: \ + case PowerUpdate: case MultiplyUpdate: case DivideUpdate: case ModUpdate: case Mod1Update: case PlusUpdate: case MinusUpdate: case ConcatUpdate: \ + case LeftShiftUpdate: case UnsignedLeftShiftUpdate +#define UPDATE_CASES PowerUpdate: case MultiplyUpdate: case DivideUpdate: case ModUpdate: case Mod1Update: case PlusUpdate: case MinusUpdate: \ + case ConcatUpdate: case LeftShiftUpdate: case UnsignedLeftShiftUpdate: case RightShiftUpdate: case UnsignedRightShiftUpdate: \ + case AndUpdate: case OrUpdate: case XorUpdate + typedef enum { Unknown = 0, None, Bool, Var, @@ -124,7 +126,11 @@ typedef enum { TextLiteral, TextJoin, PrintStatement, Path, Declare, Assign, - BinaryOp, UpdateAssign, + Power, Multiply, Divide, Mod, Mod1, Plus, Minus, Concat, LeftShift, UnsignedLeftShift, + RightShift, UnsignedRightShift, Equals, NotEquals, LessThan, LessThanOrEquals, GreaterThan, + GreaterThanOrEquals, Compare, And, Or, Xor, + PowerUpdate, MultiplyUpdate, DivideUpdate, ModUpdate, Mod1Update, PlusUpdate, MinusUpdate, ConcatUpdate, LeftShiftUpdate, UnsignedLeftShiftUpdate, + RightShiftUpdate, UnsignedRightShiftUpdate, AndUpdate, OrUpdate, XorUpdate, Not, Negative, HeapAllocate, StackReference, Min, Max, Array, Set, Table, TableEntry, Comprehension, @@ -152,9 +158,7 @@ struct ast_s { const char *start, *end; union { struct {} Unknown; - struct { - type_ast_t *type; - } None; + struct {} None; struct { bool b; } Bool; @@ -182,16 +186,17 @@ struct ast_s { } PrintStatement; struct { ast_t *var; + type_ast_t *type; ast_t *value; } Declare; struct { ast_list_t *targets, *values; } Assign; - struct { - ast_t *lhs; - binop_e op; - ast_t *rhs; - } BinaryOp, UpdateAssign; + binary_operands_t Power, Multiply, Divide, Mod, Mod1, Plus, Minus, Concat, LeftShift, UnsignedLeftShift, + RightShift, UnsignedRightShift, Equals, NotEquals, LessThan, LessThanOrEquals, GreaterThan, + GreaterThanOrEquals, Compare, And, Or, Xor, + PowerUpdate, MultiplyUpdate, DivideUpdate, ModUpdate, Mod1Update, PlusUpdate, MinusUpdate, ConcatUpdate, LeftShiftUpdate, UnsignedLeftShiftUpdate, + RightShiftUpdate, UnsignedRightShiftUpdate, AndUpdate, OrUpdate, XorUpdate; struct { ast_t *value; } Not, Negative, HeapAllocate, StackReference; @@ -199,15 +204,12 @@ struct ast_s { ast_t *lhs, *rhs, *key; } Min, Max; struct { - type_ast_t *item_type; ast_list_t *items; } Array; struct { - type_ast_t *item_type; ast_list_t *items; } Set; struct { - type_ast_t *key_type, *value_type; ast_t *default_value; ast_t *fallback; ast_list_t *entries; @@ -272,7 +274,7 @@ struct ast_s { } When; struct { ast_t *iter, *key; - binop_e op; + ast_e op; } Reduction; struct { const char *target; @@ -345,5 +347,10 @@ const char *ast_source(ast_t *ast); CORD type_ast_to_xml(type_ast_t *ast); PUREFUNC bool is_idempotent(ast_t *ast); void visit_topologically(ast_list_t *ast, Closure_t fn); +CONSTFUNC bool is_update_assignment(ast_t *ast); +CONSTFUNC const char *binop_method_name(ast_e tag); +CONSTFUNC const char *binop_operator(ast_e tag); +CONSTFUNC ast_e binop_tag(ast_e tag); +CONSTFUNC bool is_binary_operation(ast_t *ast); // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/src/compile.c b/src/compile.c index 95cf5c9a..d893c674 100644 --- a/src/compile.c +++ b/src/compile.c @@ -24,7 +24,6 @@ typedef ast_t* (*comprehension_body_t)(ast_t*, ast_t*); static CORD compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_depth, bool needs_incref); -static CORD compile_math_method(env_t *env, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type); static CORD compile_string(env_t *env, ast_t *ast, CORD color); static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args); static CORD compile_maybe_incref(env_t *env, ast_t *ast, type_t *t); @@ -33,9 +32,17 @@ static CORD compile_unsigned_type(type_t *t); static CORD promote_to_optional(type_t *t, CORD code); static CORD compile_none(type_t *t); static CORD compile_to_type(env_t *env, ast_t *ast, type_t *t); +static CORD compile_typed_array(env_t *env, ast_t *ast, type_t *array_type); +static CORD compile_typed_set(env_t *env, ast_t *ast, type_t *set_type); +static CORD compile_typed_table(env_t *env, ast_t *ast, type_t *table_type); +static CORD compile_typed_allocation(env_t *env, ast_t *ast, type_t *pointer_type); static CORD check_none(type_t *t, CORD value); static CORD optional_into_nonnone(type_t *t, CORD value); static CORD compile_string_literal(CORD literal); +static ast_t *add_to_array_comprehension(ast_t *item, ast_t *subject); +static ast_t *add_to_table_comprehension(ast_t *entry, ast_t *subject); +static ast_t *add_to_set_comprehension(ast_t *item, ast_t *subject); +static CORD compile_lvalue(env_t *env, ast_t *ast); CORD promote_to_optional(type_t *t, CORD code) { @@ -79,6 +86,11 @@ static bool promote(env_t *env, ast_t *ast, CORD *code, type_t *actual, type_t * return true; } + // Empty promotion: + type_t *more_complete = most_complete_type(actual, needed); + if (more_complete) + return true; + // Optional promotion: if (needed->tag == OptionalType && type_eq(actual, Match(needed, OptionalType)->type)) { *code = promote_to_optional(actual, *code); @@ -218,14 +230,10 @@ static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t add_closed_vars(closed_vars, enclosing_scope, env, value->ast); break; } - case BinaryOp: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, BinaryOp)->lhs); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, BinaryOp)->rhs); - break; - } - case UpdateAssign: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, UpdateAssign)->lhs); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, UpdateAssign)->rhs); + case BINOP_CASES: { + binary_operands_t binop = BINARY_OPERANDS(ast); + add_closed_vars(closed_vars, enclosing_scope, env, binop.lhs); + add_closed_vars(closed_vars, enclosing_scope, env, binop.rhs); break; } case Not: case Negative: case HeapAllocate: case StackReference: { @@ -481,6 +489,226 @@ CORD compile_declaration(type_t *t, CORD name) } } +static CORD compile_update_assignment(env_t *env, ast_t *ast) +{ + if (!is_update_assignment(ast)) + code_err(ast, "This is not an update assignment"); + + binary_operands_t update = BINARY_OPERANDS(ast); + + type_t *lhs_t = get_type(env, update.lhs); + + bool needs_idemotency_fix = !is_idempotent(update.lhs); + CORD lhs = needs_idemotency_fix ? "(*lhs)" : compile_lvalue(env, update.lhs); + + CORD update_assignment = CORD_EMPTY; + switch (ast->tag) { + case PlusUpdate: { + if (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType) + update_assignment = CORD_all(lhs, " += ", compile_to_type(env, update.rhs, lhs_t), ";"); + break; + } + case MinusUpdate: { + if (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType) + update_assignment = CORD_all(lhs, " -= ", compile_to_type(env, update.rhs, lhs_t), ";"); + break; + } + case MultiplyUpdate: { + if (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType) + update_assignment = CORD_all(lhs, " *= ", compile_to_type(env, update.rhs, lhs_t), ";"); + break; + } + case DivideUpdate: { + if (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType) + update_assignment = CORD_all(lhs, " /= ", compile_to_type(env, update.rhs, lhs_t), ";"); + break; + } + case LeftShiftUpdate: { + if (lhs_t->tag == IntType || lhs_t->tag == ByteType) + update_assignment = CORD_all(lhs, " <<= ", compile_to_type(env, update.rhs, lhs_t), ";"); + break; + } + case RightShiftUpdate: { + if (lhs_t->tag == IntType || lhs_t->tag == ByteType) + update_assignment = CORD_all(lhs, " >>= ", compile_to_type(env, update.rhs, lhs_t), ";"); + break; + } + case AndUpdate: { + if (lhs_t->tag == BoolType) + update_assignment = CORD_all("if (", lhs, ") ", lhs, " = ", compile_to_type(env, update.rhs, Type(BoolType)), ";"); + break; + } + case OrUpdate: { + if (lhs_t->tag == BoolType) + update_assignment = CORD_all("if (!", lhs, ") ", lhs, " = ", compile_to_type(env, update.rhs, Type(BoolType)), ";"); + break; + } + default: break; + } + + if (update_assignment == CORD_EMPTY) { + ast_t *binop = new(ast_t); + *binop = *ast; + binop->tag = binop_tag(binop->tag); + if (needs_idemotency_fix) + binop->__data.Plus.lhs = WrapAST(update.lhs, InlineCCode, .code="*lhs", .type=lhs_t); + update_assignment = CORD_all(lhs, " = ", compile_to_type(env, binop, lhs_t)); + } + + if (needs_idemotency_fix) + return CORD_all("{ ", compile_declaration(Type(PointerType, .pointed=lhs_t), "lhs"), " = &", compile_lvalue(env, update.lhs), "; ", + update_assignment, "; }"); + else + return update_assignment; +} + +static CORD compile_binary_op(env_t *env, ast_t *ast) +{ + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + type_t *overall_t = get_type(env, ast); + + binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, overall_t); + if (b) { + arg_ast_t *args = new(arg_ast_t, .value=binop.lhs, .next=new(arg_ast_t, .value=binop.rhs)); + auto fn = Match(b->type, FunctionType); + return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, args), ")"); + } + + if (ast->tag == Or && lhs_t->tag == OptionalType) { + if (is_incomplete_type(rhs_t)) { + type_t *complete = most_complete_type(rhs_t, Match(lhs_t, OptionalType)->type); + if (complete == NULL) + code_err(binop.rhs, "I don't know how to convert a ", type_to_str(rhs_t), " to a ", type_to_str(Match(lhs_t, OptionalType)->type)); + rhs_t = complete; + } + + if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { + return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop.lhs), "; ", + "if (", check_none(lhs_t, "lhs"), ") ", compile_statement(env, binop.rhs), " ", + optional_into_nonnone(lhs_t, "lhs"), "; })"); + } else if (rhs_t->tag == OptionalType && type_eq(lhs_t, rhs_t)) { + return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop.lhs), "; ", + check_none(lhs_t, "lhs"), " ? ", compile(env, binop.rhs), " : lhs; })"); + } else if (rhs_t->tag != OptionalType && type_eq(Match(lhs_t, OptionalType)->type, rhs_t)) { + return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop.lhs), "; ", + check_none(lhs_t, "lhs"), " ? ", compile(env, binop.rhs), " : ", + optional_into_nonnone(lhs_t, "lhs"), "; })"); + } else if (rhs_t->tag == BoolType) { + return CORD_all("((!", check_none(lhs_t, compile(env, binop.lhs)), ") || ", compile(env, binop.rhs), ")"); + } else { + code_err(ast, "I don't know how to do an 'or' operation between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + } + + CORD lhs = compile_to_type(env, binop.lhs, overall_t); + CORD rhs = compile_to_type(env, binop.rhs, overall_t); + + switch (ast->tag) { + case Power: { + if (overall_t->tag != NumType) + code_err(ast, "Exponentiation is only supported for Num types, not ", type_to_str(overall_t)); + if (overall_t->tag == NumType && Match(overall_t, NumType)->bits == TYPE_NBITS32) + return CORD_all("powf(", lhs, ", ", rhs, ")"); + else + return CORD_all("pow(", lhs, ", ", rhs, ")"); + } + case Multiply: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " * ", rhs, ")"); + } + case Divide: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " / ", rhs, ")"); + } + case Mod: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " % ", rhs, ")"); + } + case Mod1: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("((((", lhs, ")-1) % (", rhs, ")) + 1)"); + } + case Plus: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " + ", rhs, ")"); + } + case Minus: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " - ", rhs, ")"); + } + case LeftShift: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " << ", rhs, ")"); + } + case RightShift: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " >> ", rhs, ")"); + } + case UnsignedLeftShift: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", compile_type(overall_t), ")((", compile_unsigned_type(lhs_t), ")", lhs, " << ", rhs, ")"); + } + case UnsignedRightShift: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", compile_type(overall_t), ")((", compile_unsigned_type(lhs_t), ")", lhs, " >> ", rhs, ")"); + } + case And: { + if (overall_t->tag == BoolType) + return CORD_all("(", lhs, " && ", rhs, ")"); + else if (overall_t->tag == IntType || overall_t->tag == ByteType) + return CORD_all("(", lhs, " & ", rhs, ")"); + else + code_err(ast, "The 'and' operator isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); + } + case Compare: { + return CORD_all("generic_compare(stack(", lhs, "), stack(", rhs, "), ", compile_type_info(overall_t), ")"); + } + case Or: { + if (overall_t->tag == BoolType) { + return CORD_all("(", lhs, " || ", rhs, ")"); + } else if (overall_t->tag == IntType || overall_t->tag == ByteType) { + return CORD_all("(", lhs, " | ", rhs, ")"); + } else { + code_err(ast, "The 'or' operator isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); + } + } + case Xor: { + // TODO: support optional values in `xor` expressions + if (overall_t->tag == BoolType || overall_t->tag == IntType || overall_t->tag == ByteType) + return CORD_all("(", lhs, " ^ ", rhs, ")"); + else + code_err(ast, "The 'xor' operator isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); + } + case Concat: { + if (overall_t == PATH_TYPE) + return CORD_all("Path$concat(", lhs, ", ", rhs, ")"); + switch (overall_t->tag) { + case TextType: { + return CORD_all("Text$concat(", lhs, ", ", rhs, ")"); + } + case ArrayType: { + return CORD_all("Array$concat(", lhs, ", ", rhs, ", sizeof(", compile_type(Match(overall_t, ArrayType)->item_type), "))"); + } + default: + code_err(ast, "Concatenation isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); + } + } + default: errx(1, "Not a valid binary operation: ", ast_to_xml_str(ast)); + } +} + PUREFUNC CORD compile_unsigned_type(type_t *t) { if (t->tag != IntType) @@ -570,7 +798,7 @@ CORD compile_type(type_t *t) } } -static CORD compile_lvalue(env_t *env, ast_t *ast) +CORD compile_lvalue(env_t *env, ast_t *ast) { if (!can_be_mutated(env, ast)) { if (ast->tag == Index) { @@ -756,7 +984,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast) CORD code = CORD_EMPTY; for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - ast_t *comparison = WrapAST(clause->pattern, BinaryOp, .lhs=subject, .op=BINOP_EQ, .rhs=clause->pattern); + ast_t *comparison = WrapAST(clause->pattern, Equals, .lhs=subject, .rhs=clause->pattern); (void)get_type(env, comparison); if (code != CORD_EMPTY) code = CORD_all(code, "else "); @@ -865,7 +1093,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast) if (streq(varname, "_")) return compile_statement(env, WrapAST(ast, DocTest, .expr=decl->value, .expected=test->expected, .skip_source=test->skip_source)); CORD var = CORD_all("_$", Match(decl->var, Var)->name); - type_t *t = get_type(env, decl->value); + type_t *t = decl->type ? parse_type_ast(env, decl->type) : get_type(env, decl->value); if (!t) code_err(decl->value, "I couldn't figure out the type of this value!"); CORD val_code = compile_maybe_incref(env, decl->value, t); if (t->tag == FunctionType) { @@ -916,21 +1144,20 @@ static CORD _compile_statement(env_t *env, ast_t *ast) test_code = CORD_all(test_code, "$1; })"); } - } else if (test->expr->tag == UpdateAssign) { - type_t *lhs_t = get_type(env, Match(test->expr, UpdateAssign)->lhs); - auto update = Match(test->expr, UpdateAssign); - - if (update->lhs->tag == Index) { - type_t *indexed = value_type(get_type(env, Match(update->lhs, Index)->indexed)); + } else if (is_update_assignment(test->expr)) { + binary_operands_t update = BINARY_OPERANDS(test->expr); + type_t *lhs_t = get_type(env, update.lhs); + if (update.lhs->tag == Index) { + type_t *indexed = value_type(get_type(env, Match(update.lhs, Index)->indexed)); if (indexed->tag == TableType && Match(indexed, TableType)->default_value == NULL) - code_err(update->lhs, "Update assignments are not currently supported for tables"); + code_err(update.lhs, "Update assignments are not currently supported for tables"); } - ast_t *update_var = WrapAST(ast, UpdateAssign, - .lhs=WrapAST(update->lhs, InlineCCode, .code="(*expr)", .type=lhs_t), - .op=update->op, .rhs=update->rhs); + ast_t *update_var = new(ast_t); + *update_var = *ast; + update_var->__data.PlusUpdate.lhs = WrapAST(update.lhs, InlineCCode, .code="(*expr)", .type=lhs_t); // UNSAFE test_code = CORD_all("({", - compile_declaration(Type(PointerType, lhs_t), "expr"), " = &(", compile_lvalue(env, update->lhs), "); ", + compile_declaration(Type(PointerType, lhs_t), "expr"), " = &(", compile_lvalue(env, update.lhs), "); ", compile_statement(env, update_var), "; *expr; })"); expr_t = lhs_t; } else if (expr_t->tag == VoidType || expr_t->tag == AbortType || expr_t->tag == ReturnType) { @@ -939,14 +1166,10 @@ static CORD _compile_statement(env_t *env, ast_t *ast) test_code = compile(env, test->expr); } if (test->expected) { - type_t *expected_type = get_type(env, test->expected); - if (!type_eq(expr_t, expected_type)) - code_err(ast, "The type on the top of this test (", type_to_str(expr_t), - ") is different from the type on the bottom (", type_to_str(expected_type), ")"); return CORD_asprintf( "%rtest(%r, %r, %r, %ld, %ld);", setup, test_code, - compile(env, test->expected), + compile_to_type(env, test->expected, expr_t), compile_type_info(expr_t), (int64_t)(test->expr->start - test->expr->file->text), (int64_t)(test->expr->end - test->expr->file->text)); @@ -965,7 +1188,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast) if (streq(name, "_")) { // Explicit discard return CORD_all("(void)", compile(env, decl->value), ";"); } else { - type_t *t = get_type(env, decl->value); + type_t *t = decl->type ? parse_type_ast(env, decl->type) : get_type(env, decl->value); if (t->tag == AbortType || t->tag == VoidType || t->tag == ReturnType) code_err(ast, "You can't declare a variable with a ", type_to_str(t), " value"); @@ -1011,155 +1234,44 @@ static CORD _compile_statement(env_t *env, ast_t *ast) } return CORD_cat(code, "\n}"); } - case UpdateAssign: { - auto update = Match(ast, UpdateAssign); - - if (update->lhs->tag == Index) { - type_t *indexed = value_type(get_type(env, Match(update->lhs, Index)->indexed)); - if (indexed->tag == TableType && Match(indexed, TableType)->default_value == NULL) - code_err(update->lhs, "Update assignments are not currently supported for tables"); - } - - if (!is_idempotent(update->lhs)) { - type_t *lhs_t = get_type(env, update->lhs); - return CORD_all("{ ", compile_declaration(Type(PointerType, lhs_t), "update_lhs"), " = &", - compile_lvalue(env, update->lhs), ";\n", - "*update_lhs = ", compile(env, WrapAST(ast, BinaryOp, - .lhs=WrapAST(update->lhs, InlineCCode, .code="(*update_lhs)", .type=lhs_t), - .op=update->op, .rhs=update->rhs)), "; }"); - } - + case PlusUpdate: { + auto update = Match(ast, PlusUpdate); type_t *lhs_t = get_type(env, update->lhs); - CORD lhs = compile_lvalue(env, update->lhs); - - if (update->lhs->tag == Index && value_type(get_type(env, Match(update->lhs, Index)->indexed))->tag == TableType) { - ast_t *lhs_placeholder = WrapAST(update->lhs, InlineCCode, .code="(*lhs)", .type=lhs_t); - CORD method_call = compile_math_method(env, update->op, lhs_placeholder, update->rhs, lhs_t); - if (method_call) - return CORD_all("{ ", compile_declaration(Type(PointerType, .pointed=lhs_t), "lhs"), " = &", lhs, "; *lhs = ", method_call, "; }"); - } else { - CORD method_call = compile_math_method(env, update->op, update->lhs, update->rhs, lhs_t); - if (method_call) - return CORD_all(lhs, " = ", method_call, ";"); - } - - CORD rhs = compile(env, update->rhs); - - type_t *rhs_t = get_type(env, update->rhs); - if (update->rhs->tag == Int && is_numeric_type(non_optional(lhs_t))) { - rhs = compile_int_to_type(env, update->rhs, lhs_t); - } else if (!promote(env, update->rhs, &rhs, rhs_t, lhs_t)) { - code_err(ast, "I can't do operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - } - - bool lhs_is_optional_num = (lhs_t->tag == OptionalType && Match(lhs_t, OptionalType)->type && Match(lhs_t, OptionalType)->type->tag == NumType); - switch (update->op) { - case BINOP_MULT: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) - code_err(ast, "I can't do a multiply assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - if (lhs_t->tag == NumType) { // 0*INF -> NaN, needs checking - return CORD_asprintf("%r *= %r;\n" - "if (isnan(%r))\n" - "fail_source(%r, %ld, %ld, \"This update assignment created a NaN value (probably multiplying zero with infinity), but the type is not optional!\");\n", - lhs, rhs, lhs, - CORD_quoted(ast->file->filename), - (long)(ast->start - ast->file->text), - (long)(ast->end - ast->file->text)); - } - return CORD_all(lhs, " *= ", rhs, ";"); - case BINOP_DIVIDE: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) - code_err(ast, "I can't do a divide assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - if (lhs_t->tag == NumType) { // 0/0 or INF/INF -> NaN, needs checking - return CORD_asprintf("%r /= %r;\n" - "if (isnan(%r))\n" - "fail_source(%r, %ld, %ld, \"This update assignment created a NaN value (probably 0/0 or INF/INF), but the type is not optional!\");\n", - lhs, rhs, lhs, - CORD_quoted(ast->file->filename), - (long)(ast->start - ast->file->text), - (long)(ast->end - ast->file->text)); - } - return CORD_all(lhs, " /= ", rhs, ";"); - case BINOP_MOD: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) - code_err(ast, "I can't do a mod assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " = ", lhs, " % ", rhs); - case BINOP_MOD1: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) - code_err(ast, "I can't do a mod assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " = (((", lhs, ") - 1) % ", rhs, ") + 1;"); - case BINOP_PLUS: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) - code_err(ast, "I can't do an addition assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " += ", rhs, ";"); - case BINOP_MINUS: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) - code_err(ast, "I can't do a subtraction assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " -= ", rhs, ";"); - case BINOP_POWER: { - if (lhs_t->tag != NumType && !lhs_is_optional_num) - code_err(ast, "'^=' is only supported for Num types"); - if (lhs_t->tag == NumType && Match(lhs_t, NumType)->bits == TYPE_NBITS32) - return CORD_all(lhs, " = powf(", lhs, ", ", rhs, ");"); - else - return CORD_all(lhs, " = pow(", lhs, ", ", rhs, ");"); - } - case BINOP_LSHIFT: - if (lhs_t->tag != IntType && lhs_t->tag != ByteType) - code_err(ast, "I can't do a shift assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " <<= ", rhs, ";"); - case BINOP_RSHIFT: - if (lhs_t->tag != IntType && lhs_t->tag != ByteType) - code_err(ast, "I can't do a shift assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " >>= ", rhs, ";"); - case BINOP_ULSHIFT: - if (lhs_t->tag != IntType && lhs_t->tag != ByteType) - code_err(ast, "I can't do a shift assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("{ ", compile_unsigned_type(lhs_t), " *dest = (void*)&(", lhs, "); *dest <<= ", rhs, "; }"); - case BINOP_URSHIFT: - if (lhs_t->tag != IntType && lhs_t->tag != ByteType) - code_err(ast, "I can't do a shift assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("{ ", compile_unsigned_type(lhs_t), " *dest = (void*)&(", lhs, "); *dest >>= ", rhs, "; }"); - case BINOP_AND: { - if (lhs_t->tag == BoolType) - return CORD_all("if (", lhs, ") ", lhs, " = ", rhs, ";"); - else if (lhs_t->tag == IntType || lhs_t->tag == ByteType) - return CORD_all(lhs, " &= ", rhs, ";"); - else if (lhs_t->tag == OptionalType) - return CORD_all("if (!(", check_none(lhs_t, lhs), ")) ", lhs, " = ", promote_to_optional(rhs_t, rhs), ";"); - else - code_err(ast, "'or=' is not implemented for ", type_to_str(lhs_t), " types"); - } - case BINOP_OR: { - if (lhs_t->tag == BoolType) - return CORD_all("if (!(", lhs, ")) ", lhs, " = ", rhs, ";"); - else if (lhs_t->tag == IntType || lhs_t->tag == ByteType) - return CORD_all(lhs, " |= ", rhs, ";"); - else if (lhs_t->tag == OptionalType) - return CORD_all("if (", check_none(lhs_t, lhs), ") ", lhs, " = ", promote_to_optional(rhs_t, rhs), ";"); - else - code_err(ast, "'or=' is not implemented for ", type_to_str(lhs_t), " types"); - } - case BINOP_XOR: - if (lhs_t->tag != IntType && lhs_t->tag != BoolType && lhs_t->tag != ByteType) - code_err(ast, "I can't do an xor assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " ^= ", rhs, ";"); - case BINOP_CONCAT: { - if (lhs_t->tag == TextType) { - return CORD_all(lhs, " = Texts(", lhs, ", ", rhs, ");"); - } else if (lhs_t->tag == ArrayType) { - CORD padded_item_size = CORD_all("sizeof(", compile_type(Match(lhs_t, ArrayType)->item_type), ")"); - // arr ++= [...] - if (update->lhs->tag == Var) - return CORD_all("Array$insert_all(&", lhs, ", ", rhs, ", I(0), ", padded_item_size, ");"); - else - return CORD_all(lhs, " = Array$concat(", lhs, ", ", rhs, ", ", padded_item_size, ");"); - } else { - code_err(ast, "'++=' is not implemented for ", type_to_str(lhs_t), " types"); - } - } - default: code_err(ast, "Update assignments are not implemented for this operation"); - } + if (is_idempotent(update->lhs) && (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType)) + return CORD_all(compile_lvalue(env, update->lhs), " += ", compile_to_type(env, update->rhs, lhs_t), ";"); + return compile_update_assignment(env, ast); + } + case MinusUpdate: { + auto update = Match(ast, MinusUpdate); + type_t *lhs_t = get_type(env, update->lhs); + if (is_idempotent(update->lhs) && (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType)) + return CORD_all(compile_lvalue(env, update->lhs), " -= ", compile_to_type(env, update->rhs, lhs_t), ";"); + return compile_update_assignment(env, ast); + } + case MultiplyUpdate: { + auto update = Match(ast, MultiplyUpdate); + type_t *lhs_t = get_type(env, update->lhs); + if (is_idempotent(update->lhs) && (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType)) + return CORD_all(compile_lvalue(env, update->lhs), " *= ", compile_to_type(env, update->rhs, lhs_t), ";"); + return compile_update_assignment(env, ast); + } + case DivideUpdate: { + auto update = Match(ast, DivideUpdate); + type_t *lhs_t = get_type(env, update->lhs); + if (is_idempotent(update->lhs) && (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType)) + return CORD_all(compile_lvalue(env, update->lhs), " /= ", compile_to_type(env, update->rhs, lhs_t), ";"); + return compile_update_assignment(env, ast); + } + case ModUpdate: { + auto update = Match(ast, ModUpdate); + type_t *lhs_t = get_type(env, update->lhs); + if (is_idempotent(update->lhs) && (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType)) + return CORD_all(compile_lvalue(env, update->lhs), " %= ", compile_to_type(env, update->rhs, lhs_t), ";"); + return compile_update_assignment(env, ast); + } + case PowerUpdate: case Mod1Update: case ConcatUpdate: case LeftShiftUpdate: case UnsignedLeftShiftUpdate: + case RightShiftUpdate: case UnsignedRightShiftUpdate: case AndUpdate: case OrUpdate: case XorUpdate: { + return compile_update_assignment(env, ast); } case StructDef: case EnumDef: case LangDef: case Extend: case FunctionDef: case ConvertDef: { return CORD_EMPTY; @@ -1839,22 +1951,20 @@ CORD compile_to_type(env_t *env, ast_t *ast, type_t *t) case TYPE_NBITS32: return CORD_asprintf("N32(%.10g)", n); default: code_err(ast, "This is not a valid number bit width"); } - } else if (ast->tag == None && Match(ast, None)->type == NULL) { + } else if (ast->tag == None) { return compile_none(t); - } else if (t->tag == ArrayType && ast->tag == Array && !Match(ast, Array)->item_type && !Match(ast, Array)->items) { - return compile(env, ast); + } else if (t->tag == PointerType && (ast->tag == HeapAllocate || ast->tag == StackReference)) { + return compile_typed_allocation(env, ast, t); + } else if (t->tag == ArrayType && ast->tag == Array) { + return compile_typed_array(env, ast, t); } else if (t->tag == TableType && ast->tag == Table) { - auto table = Match(ast, Table); - if (!table->key_type && !table->value_type && !table->default_value && !table->fallback && !table->entries) - return compile(env, ast); + return compile_typed_table(env, ast, t); } else if (t->tag == SetType && ast->tag == Set) { - auto set = Match(ast, Set); - if (!set->item_type && !set->items) - return compile(env, ast); + return compile_typed_set(env, ast, t); } else if (t->tag == SetType && ast->tag == Table) { auto table = Match(ast, Table); - if (!table->key_type && !table->value_type && !table->default_value && !table->fallback && !table->entries) - return compile(env, ast); + if (!table->default_value && !table->fallback && !table->entries) + return compile_to_type(env, WrapAST(ast, Set), t); } type_t *actual = get_type(env, ast); @@ -1869,6 +1979,193 @@ CORD compile_to_type(env_t *env, ast_t *ast, type_t *t) return code; } +CORD compile_typed_array(env_t *env, ast_t *ast, type_t *array_type) +{ + auto array = Match(ast, Array); + if (!array->items) + return "(Array_t){.length=0}"; + + type_t *item_type = Match(array_type, ArrayType)->item_type; + + int64_t n = 0; + for (ast_list_t *item = array->items; item; item = item->next) { + ++n; + if (item->ast->tag == Comprehension) + goto array_comprehension; + } + + { + env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : env; + CORD code = CORD_all("TypedArrayN(", compile_type(item_type), CORD_asprintf(", %ld", n)); + for (ast_list_t *item = array->items; item; item = item->next) { + code = CORD_all(code, ", ", compile_to_type(scope, item->ast, item_type)); + } + return CORD_cat(code, ")"); + } + + array_comprehension: + { + env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : fresh_scope(env); + static int64_t comp_num = 1; + const char *comprehension_name = String("arr$", comp_num++); + ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), + .type=Type(PointerType, .pointed=array_type, .is_stack=true)); + Closure_t comp_action = {.fn=add_to_array_comprehension, .userdata=comprehension_var}; + scope->comprehension_action = &comp_action; + CORD code = CORD_all("({ Array_t ", comprehension_name, " = {};"); + // set_binding(scope, comprehension_name, array_type, comprehension_name); + for (ast_list_t *item = array->items; item; item = item->next) { + if (item->ast->tag == Comprehension) + code = CORD_all(code, "\n", compile_statement(scope, item->ast)); + else + code = CORD_all(code, compile_statement(env, add_to_array_comprehension(item->ast, comprehension_var))); + } + code = CORD_all(code, " ", comprehension_name, "; })"); + return code; + } +} + +CORD compile_typed_set(env_t *env, ast_t *ast, type_t *set_type) +{ + auto set = Match(ast, Set); + if (!set->items) + return "((Table_t){})"; + + type_t *item_type = Match(set_type, SetType)->item_type; + + size_t n = 0; + for (ast_list_t *item = set->items; item; item = item->next) { + ++n; + if (item->ast->tag == Comprehension) + goto set_comprehension; + } + + { // No comprehension: + CORD code = CORD_all("Set(", + compile_type(item_type), ", ", + compile_type_info(item_type)); + CORD_appendf(&code, ", %zu", n); + env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : env; + for (ast_list_t *item = set->items; item; item = item->next) { + code = CORD_all(code, ", ", compile_to_type(scope, item->ast, item_type)); + } + return CORD_cat(code, ")"); + } + + set_comprehension: + { + static int64_t comp_num = 1; + env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : fresh_scope(env); + const char *comprehension_name = String("set$", comp_num++); + ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), + .type=Type(PointerType, .pointed=set_type, .is_stack=true)); + CORD code = CORD_all("({ Table_t ", comprehension_name, " = {};"); + Closure_t comp_action = {.fn=add_to_set_comprehension, .userdata=comprehension_var}; + scope->comprehension_action = &comp_action; + for (ast_list_t *item = set->items; item; item = item->next) { + if (item->ast->tag == Comprehension) + code = CORD_all(code, "\n", compile_statement(scope, item->ast)); + else + code = CORD_all(code, compile_statement(env, add_to_set_comprehension(item->ast, comprehension_var))); + } + code = CORD_all(code, " ", comprehension_name, "; })"); + return code; + } +} + +CORD compile_typed_table(env_t *env, ast_t *ast, type_t *table_type) +{ + auto table = Match(ast, Table); + if (!table->entries) { + CORD code = "((Table_t){"; + if (table->fallback) + code = CORD_all(code, ".fallback=heap(", compile(env, table->fallback),")"); + return CORD_cat(code, "})"); + } + + type_t *key_t = Match(table_type, TableType)->key_type; + type_t *value_t = Match(table_type, TableType)->value_type; + + if (value_t->tag == OptionalType) + code_err(ast, "Tables whose values are optional (", type_to_str(value_t), ") are not currently supported."); + + for (ast_list_t *entry = table->entries; entry; entry = entry->next) { + if (entry->ast->tag == Comprehension) + goto table_comprehension; + } + + { // No comprehension: + env_t *key_scope = key_t->tag == EnumType ? with_enum_scope(env, key_t) : env; + env_t *value_scope = value_t->tag == EnumType ? with_enum_scope(env, value_t) : env; + CORD code = CORD_all("Table(", + compile_type(key_t), ", ", + compile_type(value_t), ", ", + compile_type_info(key_t), ", ", + compile_type_info(value_t)); + if (table->fallback) + code = CORD_all(code, ", /*fallback:*/ heap(", compile(env, table->fallback), ")"); + else + code = CORD_all(code, ", /*fallback:*/ NULL"); + + size_t n = 0; + for (ast_list_t *entry = table->entries; entry; entry = entry->next) + ++n; + CORD_appendf(&code, ", %zu", n); + + for (ast_list_t *entry = table->entries; entry; entry = entry->next) { + auto e = Match(entry->ast, TableEntry); + code = CORD_all(code, ",\n\t{", compile_to_type(key_scope, e->key, key_t), ", ", + compile_to_type(value_scope, e->value, value_t), "}"); + } + return CORD_cat(code, ")"); + } + + table_comprehension: + { + static int64_t comp_num = 1; + env_t *scope = fresh_scope(env); + const char *comprehension_name = String("table$", comp_num++); + ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), + .type=Type(PointerType, .pointed=table_type, .is_stack=true)); + + CORD code = CORD_all("({ Table_t ", comprehension_name, " = {"); + if (table->fallback) + code = CORD_all(code, ".fallback=heap(", compile(env, table->fallback), "), "); + + code = CORD_cat(code, "};"); + + Closure_t comp_action = {.fn=add_to_table_comprehension, .userdata=comprehension_var}; + scope->comprehension_action = &comp_action; + for (ast_list_t *entry = table->entries; entry; entry = entry->next) { + if (entry->ast->tag == Comprehension) + code = CORD_all(code, "\n", compile_statement(scope, entry->ast)); + else + code = CORD_all(code, compile_statement(env, add_to_table_comprehension(entry->ast, comprehension_var))); + } + code = CORD_all(code, " ", comprehension_name, "; })"); + return code; + } +} + +CORD compile_typed_allocation(env_t *env, ast_t *ast, type_t *pointer_type) +{ + // TODO: for constructors, do new(T, ...) instead of heap((T){...}) + type_t *pointed = Match(pointer_type, PointerType)->pointed; + switch (ast->tag) { + case HeapAllocate: { + return CORD_asprintf("heap(%r)", compile_to_type(env, Match(ast, HeapAllocate)->value, pointed)); + } + case StackReference: { + ast_t *subject = Match(ast, StackReference)->value; + if (can_be_mutated(env, subject) && type_eq(pointed, get_type(env, subject))) + return CORD_all("(&", compile_lvalue(env, subject), ")"); + else + return CORD_all("stack(", compile_to_type(env, subject, pointed), ")"); + } + default: code_err(ast, "Not an allocation!"); + } +} + CORD compile_int_to_type(env_t *env, ast_t *ast, type_t *target) { if (ast->tag != Int) { @@ -2027,98 +2324,6 @@ CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t return code; } -CORD compile_math_method(env_t *env, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type) -{ - // Math methods are things like plus(), minus(), etc. If we don't find a - // matching method, return CORD_EMPTY. - const char *method_name = binop_method_names[op]; - if (!method_name) - return CORD_EMPTY; - - type_t *lhs_t = get_type(env, lhs); - type_t *rhs_t = get_type(env, rhs); -#define binding_works(b, lhs_t, rhs_t, ret_t) \ - (b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \ - (type_eq(fn->ret, ret_t) \ - && (fn->args && type_eq(fn->args->type, lhs_t)) \ - && (fn->args->next && can_promote(rhs_t, fn->args->next->type)) \ - && (!required_type || type_eq(required_type, fn->ret))); })) - arg_ast_t *args = new(arg_ast_t, .value=lhs, .next=new(arg_ast_t, .value=rhs)); - switch (op) { - case BINOP_MULT: { - if (type_eq(lhs_t, rhs_t)) { - binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); - if (binding_works(b, lhs_t, rhs_t, lhs_t)) - return CORD_all(b->code, "(", compile_arguments(env, lhs, Match(b->type, FunctionType)->args, args), ")"); - } else if (lhs_t->tag == NumType || lhs_t->tag == IntType || lhs_t->tag == BigIntType) { - binding_t *b = get_namespace_binding(env, rhs, "scaled_by"); - if (binding_works(b, rhs_t, lhs_t, rhs_t)) { - REVERSE_LIST(args); - return CORD_all(b->code, "(", compile_arguments(env, lhs, Match(b->type, FunctionType)->args, args), ")"); - } - } else if (rhs_t->tag == NumType || rhs_t->tag == IntType|| rhs_t->tag == BigIntType) { - binding_t *b = get_namespace_binding(env, lhs, "scaled_by"); - if (binding_works(b, lhs_t, rhs_t, lhs_t)) - return CORD_all(b->code, "(", compile_arguments(env, lhs, Match(b->type, FunctionType)->args, args), ")"); - } - break; - } - case BINOP_OR: case BINOP_CONCAT: { - if (lhs_t->tag == SetType) { - return CORD_all("Table$with(", compile(env, lhs), ", ", compile(env, rhs), ", ", compile_type_info(lhs_t), ")"); - } - goto fallthrough; - } - case BINOP_AND: { - if (lhs_t->tag == SetType) { - return CORD_all("Table$overlap(", compile(env, lhs), ", ", compile(env, rhs), ", ", compile_type_info(lhs_t), ")"); - } - goto fallthrough; - } - case BINOP_MINUS: { - if (lhs_t->tag == SetType) { - return CORD_all("Table$without(", compile(env, lhs), ", ", compile(env, rhs), ", ", compile_type_info(lhs_t), ")"); - } - goto fallthrough; - } - case BINOP_PLUS: case BINOP_XOR: { - fallthrough: - if (type_eq(lhs_t, rhs_t)) { - binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); - if (binding_works(b, lhs_t, rhs_t, lhs_t)) - return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); - } - break; - } - case BINOP_DIVIDE: case BINOP_MOD: case BINOP_MOD1: { - if (is_numeric_type(rhs_t)) { - binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); - if (binding_works(b, lhs_t, rhs_t, lhs_t)) - return CORD_all(b->code, "(", compile_arguments(env, lhs, Match(b->type, FunctionType)->args, args), ")"); - } - break; - } - case BINOP_LSHIFT: case BINOP_RSHIFT: case BINOP_ULSHIFT: case BINOP_URSHIFT: { - if (rhs_t->tag == IntType || rhs_t->tag == BigIntType) { - binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); - if (binding_works(b, lhs_t, rhs_t, lhs_t)) - return CORD_all(b->code, "(", compile_arguments(env, lhs, Match(b->type, FunctionType)->args, args), ")"); - } - break; - } - case BINOP_POWER: { - if (rhs_t->tag == NumType || rhs_t->tag == IntType || rhs_t->tag == BigIntType) { - binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); - if (binding_works(b, lhs_t, rhs_t, lhs_t)) - return CORD_all(b->code, "(", compile_arguments(env, lhs, Match(b->type, FunctionType)->args, args), ")"); - } - break; - } - default: break; - } - return CORD_EMPTY; -} - CORD compile_string_literal(CORD literal) { CORD code = "\""; @@ -2209,19 +2414,19 @@ CORD compile_none(type_t *t) } } -static ast_t *add_to_table_comprehension(ast_t *entry, ast_t *subject) +ast_t *add_to_table_comprehension(ast_t *entry, ast_t *subject) { auto e = Match(entry, TableEntry); return WrapAST(entry, MethodCall, .name="set", .self=subject, .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))); } -static ast_t *add_to_array_comprehension(ast_t *item, ast_t *subject) +ast_t *add_to_array_comprehension(ast_t *item, ast_t *subject) { return WrapAST(item, MethodCall, .name="insert", .self=subject, .args=new(arg_ast_t, .value=item)); } -static ast_t *add_to_set_comprehension(ast_t *item, ast_t *subject) +ast_t *add_to_set_comprehension(ast_t *item, ast_t *subject) { return WrapAST(item, MethodCall, .name="add", .self=subject, .args=new(arg_ast_t, .value=item)); } @@ -2230,10 +2435,7 @@ CORD compile(env_t *env, ast_t *ast) { switch (ast->tag) { case None: { - if (!Match(ast, None)->type) - code_err(ast, "This 'none' needs to specify what type it is using `none:Type` syntax"); - type_t *t = parse_type_ast(env, Match(ast, None)->type); - return compile_none(t); + code_err(ast, "This 'none' needs to specify what type it is using `none:Type` syntax"); } case Bool: return Match(ast, Bool)->b ? "yes" : "no"; case Var: { @@ -2303,14 +2505,8 @@ CORD compile(env_t *env, ast_t *ast) code_err(ast, "I don't know how to get the negative value of type ", type_to_str(t)); } - // TODO: for constructors, do new(T, ...) instead of heap((T){...}) - case HeapAllocate: return CORD_asprintf("heap(%r)", compile(env, Match(ast, HeapAllocate)->value)); - case StackReference: { - ast_t *subject = Match(ast, StackReference)->value; - if (can_be_mutated(env, subject)) - return CORD_all("(&", compile_lvalue(env, subject), ")"); - else - return CORD_all("stack(", compile(env, subject), ")"); + case HeapAllocate: case StackReference: { + return compile_typed_allocation(env, ast, get_type(env, ast)); } case Optional: { ast_t *value = Match(ast, Optional)->value; @@ -2329,264 +2525,67 @@ CORD compile(env_t *env, ast_t *ast) (long)(value->end - value->file->text)), optional_into_nonnone(t, "opt"), "; })"); } - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - CORD method_call = compile_math_method(env, binop->op, binop->lhs, binop->rhs, NULL); - if (method_call != CORD_EMPTY) - return method_call; - - type_t *lhs_t = get_type(env, binop->lhs); - type_t *rhs_t = get_type(env, binop->rhs); - - if (binop->op == BINOP_OR && lhs_t->tag == OptionalType) { - if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { - return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop->lhs), "; ", - "if (", check_none(lhs_t, "lhs"), ") ", compile_statement(env, binop->rhs), " ", - optional_into_nonnone(lhs_t, "lhs"), "; })"); - } else if (rhs_t->tag == OptionalType && type_eq(lhs_t, rhs_t)) { - return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop->lhs), "; ", - check_none(lhs_t, "lhs"), " ? ", compile(env, binop->rhs), " : lhs; })"); - } else if (rhs_t->tag != OptionalType && type_eq(Match(lhs_t, OptionalType)->type, rhs_t)) { - return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop->lhs), "; ", - check_none(lhs_t, "lhs"), " ? ", compile(env, binop->rhs), " : ", - optional_into_nonnone(lhs_t, "lhs"), "; })"); - } else if (rhs_t->tag == BoolType) { - return CORD_all("((!", check_none(lhs_t, compile(env, binop->lhs)), ") || ", compile(env, binop->rhs), ")"); - } else { - code_err(ast, "I don't know how to do an 'or' operation between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - } - } else if (binop->op == BINOP_AND && lhs_t->tag == OptionalType) { - if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { - return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop->lhs), "; ", - "if (!", check_none(lhs_t, "lhs"), ") ", compile_statement(env, binop->rhs), " ", - optional_into_nonnone(lhs_t, "lhs"), "; })"); - } else if (rhs_t->tag == OptionalType && type_eq(lhs_t, rhs_t)) { - return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop->lhs), "; ", - check_none(lhs_t, "lhs"), " ? lhs : ", compile(env, binop->rhs), "; })"); - } else if (rhs_t->tag == BoolType) { - return CORD_all("((!", check_none(lhs_t, compile(env, binop->lhs)), ") && ", compile(env, binop->rhs), ")"); - } else { - code_err(ast, "I don't know how to do an 'or' operation between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - } - } - - type_t *non_optional_lhs = lhs_t; - if (lhs_t->tag == OptionalType) non_optional_lhs = Match(lhs_t, OptionalType)->type; - type_t *non_optional_rhs = rhs_t; - if (rhs_t->tag == OptionalType) non_optional_rhs = Match(rhs_t, OptionalType)->type; - - if (!non_optional_lhs && !non_optional_rhs) - code_err(ast, "Both of these values do not specify a type"); - else if (!non_optional_lhs) - non_optional_lhs = non_optional_rhs; - else if (!non_optional_rhs) - non_optional_rhs = non_optional_lhs; - - bool lhs_is_optional_num = (lhs_t->tag == OptionalType && non_optional_lhs->tag == NumType); - if (lhs_is_optional_num) - lhs_t = Match(lhs_t, OptionalType)->type; - bool rhs_is_optional_num = (rhs_t->tag == OptionalType && non_optional_rhs->tag == NumType); - if (rhs_is_optional_num) - rhs_t = Match(rhs_t, OptionalType)->type; + case Power: case Multiply: case Divide: case Mod: case Mod1: case Plus: case Minus: case Concat: + case LeftShift: case UnsignedLeftShift: case RightShift: case UnsignedRightShift: case And: case Or: case Xor: { + return compile_binary_op(env, ast); + } + case Equals: case NotEquals: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + type_t *operand_t; CORD lhs, rhs; - if (lhs_t->tag == BigIntType && rhs_t->tag != BigIntType && is_numeric_type(rhs_t) && binop->lhs->tag == Int) { - lhs = compile_int_to_type(env, binop->lhs, rhs_t); - lhs_t = rhs_t; - rhs = compile(env, binop->rhs); - } else if (rhs_t->tag == BigIntType && lhs_t->tag != BigIntType && is_numeric_type(lhs_t) && binop->rhs->tag == Int) { - lhs = compile(env, binop->lhs); - rhs = compile_int_to_type(env, binop->rhs, lhs_t); - rhs_t = lhs_t; + if (can_compile_to_type(env, binop.rhs, lhs_t)) { + lhs = compile(env, binop.lhs); + rhs = compile_to_type(env, binop.rhs, lhs_t); + operand_t = lhs_t; + } else if (can_compile_to_type(env, binop.lhs, rhs_t)) { + rhs = compile(env, binop.rhs); + lhs = compile_to_type(env, binop.lhs, rhs_t); + operand_t = rhs_t; } else { - lhs = compile(env, binop->lhs); - rhs = compile(env, binop->rhs); + code_err(ast, "I can't do comparisons between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + + switch (operand_t->tag) { + case BigIntType: + return CORD_all(ast->tag == Equals ? CORD_EMPTY : "!", "Int$equal_value(", lhs, ", ", rhs, ")"); + case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: + return CORD_all("(", lhs, ast->tag == Equals ? " == " : " != ", rhs, ")"); + default: + return CORD_asprintf(ast->tag == Equals ? CORD_EMPTY : "!", + "generic_equal(stack(%r), stack(%r), %r)", lhs, rhs, compile_type_info(operand_t)); } + } + case LessThan: case LessThanOrEquals: case GreaterThan: case GreaterThanOrEquals: { + binary_operands_t cmp = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, cmp.lhs); + type_t *rhs_t = get_type(env, cmp.rhs); type_t *operand_t; - if (promote(env, binop->rhs, &rhs, rhs_t, lhs_t)) + CORD lhs, rhs; + if (can_compile_to_type(env, cmp.rhs, lhs_t)) { + lhs = compile(env, cmp.lhs); + rhs = compile_to_type(env, cmp.rhs, lhs_t); operand_t = lhs_t; - else if (promote(env, binop->lhs, &lhs, lhs_t, rhs_t)) + } else if (can_compile_to_type(env, cmp.lhs, rhs_t)) { + rhs = compile(env, cmp.rhs); + lhs = compile_to_type(env, cmp.lhs, rhs_t); operand_t = rhs_t; - else - code_err(ast, "I can't do operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - - switch (binop->op) { - case BINOP_POWER: { - if (operand_t->tag != NumType) - code_err(ast, "Exponentiation is only supported for Num types, not ", type_to_str(operand_t)); - if (operand_t->tag == NumType && Match(operand_t, NumType)->bits == TYPE_NBITS32) - return CORD_all("powf(", lhs, ", ", rhs, ")"); - else - return CORD_all("pow(", lhs, ", ", rhs, ")"); - } - case BINOP_MULT: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " * ", rhs, ")"); - } - case BINOP_DIVIDE: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " / ", rhs, ")"); - } - case BINOP_MOD: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " % ", rhs, ")"); - } - case BINOP_MOD1: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("((((", lhs, ")-1) % (", rhs, ")) + 1)"); - } - case BINOP_PLUS: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " + ", rhs, ")"); - } - case BINOP_MINUS: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " - ", rhs, ")"); - } - case BINOP_LSHIFT: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " << ", rhs, ")"); - } - case BINOP_RSHIFT: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " >> ", rhs, ")"); - } - case BINOP_ULSHIFT: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", compile_type(operand_t), ")((", compile_unsigned_type(lhs_t), ")", lhs, " << ", rhs, ")"); - } - case BINOP_URSHIFT: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", compile_type(operand_t), ")((", compile_unsigned_type(lhs_t), ")", lhs, " >> ", rhs, ")"); - } - case BINOP_EQ: { - switch (operand_t->tag) { - case BigIntType: - return CORD_all("Int$equal_value(", lhs, ", ", rhs, ")"); - case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: - if (lhs_is_optional_num || rhs_is_optional_num) - return CORD_asprintf("generic_equal(stack(%r), stack(%r), %r)", lhs, rhs, compile_type_info(Type(OptionalType, operand_t))); - return CORD_all("(", lhs, " == ", rhs, ")"); - default: - return CORD_asprintf("generic_equal(stack(%r), stack(%r), %r)", lhs, rhs, compile_type_info(operand_t)); - } - } - case BINOP_NE: { - switch (operand_t->tag) { - case BigIntType: - return CORD_all("!Int$equal_value(", lhs, ", ", rhs, ")"); - case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: - if (lhs_is_optional_num || rhs_is_optional_num) - return CORD_asprintf("!generic_equal(stack(%r), stack(%r), %r)", lhs, rhs, compile_type_info(Type(OptionalType, operand_t))); - return CORD_all("(", lhs, " != ", rhs, ")"); - default: - return CORD_asprintf("!generic_equal(stack(%r), stack(%r), %r)", lhs, rhs, compile_type_info(operand_t)); - } - } - case BINOP_LT: { - switch (operand_t->tag) { - case BigIntType: - return CORD_all("(Int$compare_value(", lhs, ", ", rhs, ") < 0)"); - case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: - if (lhs_is_optional_num || rhs_is_optional_num) - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) < 0)", lhs, rhs, compile_type_info(Type(OptionalType, operand_t))); - return CORD_all("(", lhs, " < ", rhs, ")"); - default: - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) < 0)", lhs, rhs, compile_type_info(operand_t)); - } - } - case BINOP_LE: { - switch (operand_t->tag) { - case BigIntType: - return CORD_all("(Int$compare_value(", lhs, ", ", rhs, ") <= 0)"); - case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: - if (lhs_is_optional_num || rhs_is_optional_num) - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) <= 0)", lhs, rhs, compile_type_info(Type(OptionalType, operand_t))); - return CORD_all("(", lhs, " <= ", rhs, ")"); - default: - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) <= 0)", lhs, rhs, compile_type_info(operand_t)); - } - } - case BINOP_GT: { - switch (operand_t->tag) { - case BigIntType: - return CORD_all("(Int$compare_value(", lhs, ", ", rhs, ") > 0)"); - case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: - if (lhs_is_optional_num || rhs_is_optional_num) - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) > 0)", lhs, rhs, compile_type_info(Type(OptionalType, operand_t))); - return CORD_all("(", lhs, " > ", rhs, ")"); - default: - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) > 0)", lhs, rhs, compile_type_info(operand_t)); - } - } - case BINOP_GE: { - switch (operand_t->tag) { - case BigIntType: - return CORD_all("(Int$compare_value(", lhs, ", ", rhs, ") >= 0)"); - case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: - if (lhs_is_optional_num || rhs_is_optional_num) - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) >= 0)", lhs, rhs, compile_type_info(Type(OptionalType, operand_t))); - return CORD_all("(", lhs, " >= ", rhs, ")"); - default: - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) >= 0)", lhs, rhs, compile_type_info(operand_t)); - } - } - case BINOP_AND: { - if (operand_t->tag == BoolType) - return CORD_all("(", lhs, " && ", rhs, ")"); - else if (operand_t->tag == IntType || operand_t->tag == ByteType) - return CORD_all("(", lhs, " & ", rhs, ")"); - else - code_err(ast, "The 'and' operator isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); - } - case BINOP_CMP: { - if (lhs_is_optional_num || rhs_is_optional_num) - operand_t = Type(OptionalType, operand_t); - return CORD_all("generic_compare(stack(", lhs, "), stack(", rhs, "), ", compile_type_info(operand_t), ")"); - } - case BINOP_OR: { - if (operand_t->tag == BoolType) - return CORD_all("(", lhs, " || ", rhs, ")"); - else if (operand_t->tag == IntType || operand_t->tag == ByteType) - return CORD_all("(", lhs, " | ", rhs, ")"); - else - code_err(ast, "The 'or' operator isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); - } - case BINOP_XOR: { - // TODO: support optional values in `xor` expressions - if (operand_t->tag == BoolType || operand_t->tag == IntType || operand_t->tag == ByteType) - return CORD_all("(", lhs, " ^ ", rhs, ")"); - else - code_err(ast, "The 'xor' operator isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); - } - case BINOP_CONCAT: { - if (operand_t == PATH_TYPE) - return CORD_all("Path$concat(", lhs, ", ", rhs, ")"); - switch (operand_t->tag) { - case TextType: { - return CORD_all("Text$concat(", lhs, ", ", rhs, ")"); - } - case ArrayType: { - return CORD_all("Array$concat(", lhs, ", ", rhs, ", sizeof(", compile_type(Match(operand_t, ArrayType)->item_type), "))"); - } - default: - code_err(ast, "Concatenation isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); - } + } else { + code_err(ast, "I can't do comparisons between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); } - default: break; + + const char *op = binop_operator(ast->tag); + switch (operand_t->tag) { + case BigIntType: + return CORD_all("(Int$compare_value(", lhs, ", ", rhs, ") ", op, " 0)"); + case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: + return CORD_all("(", lhs, " ", op, " ", rhs, ")"); + default: + return CORD_all("(generic_compare(stack(", lhs, "), stack(", rhs, "), ", compile_type_info(Type(OptionalType, operand_t)), ") ", op, " 0)"); } - code_err(ast, "unimplemented binop"); } case TextLiteral: { CORD literal = Match(ast, TextLiteral)->cord; @@ -2715,44 +2714,7 @@ CORD compile(env_t *env, ast_t *ast) return "(Array_t){.length=0}"; type_t *array_type = get_type(env, ast); - type_t *item_type = Match(array_type, ArrayType)->item_type; - - int64_t n = 0; - for (ast_list_t *item = array->items; item; item = item->next) { - ++n; - if (item->ast->tag == Comprehension) - goto array_comprehension; - } - - { - env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : env; - CORD code = CORD_all("TypedArrayN(", compile_type(item_type), CORD_asprintf(", %ld", n)); - for (ast_list_t *item = array->items; item; item = item->next) { - code = CORD_all(code, ", ", compile_to_type(scope, item->ast, item_type)); - } - return CORD_cat(code, ")"); - } - - array_comprehension: - { - env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : fresh_scope(env); - static int64_t comp_num = 1; - const char *comprehension_name = String("arr$", comp_num++); - ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), - .type=Type(PointerType, .pointed=array_type, .is_stack=true)); - Closure_t comp_action = {.fn=add_to_array_comprehension, .userdata=comprehension_var}; - scope->comprehension_action = &comp_action; - CORD code = CORD_all("({ Array_t ", comprehension_name, " = {};"); - // set_binding(scope, comprehension_name, array_type, comprehension_name); - for (ast_list_t *item = array->items; item; item = item->next) { - if (item->ast->tag == Comprehension) - code = CORD_all(code, "\n", compile_statement(scope, item->ast)); - else - code = CORD_all(code, compile_statement(env, add_to_array_comprehension(item->ast, comprehension_var))); - } - code = CORD_all(code, " ", comprehension_name, "; })"); - return code; - } + return compile_typed_array(env, ast, array_type); } case Table: { auto table = Match(ast, Table); @@ -2764,69 +2726,7 @@ CORD compile(env_t *env, ast_t *ast) } type_t *table_type = get_type(env, ast); - type_t *key_t = Match(table_type, TableType)->key_type; - type_t *value_t = Match(table_type, TableType)->value_type; - - if (value_t->tag == OptionalType) - code_err(ast, "Tables whose values are optional (", type_to_str(value_t), ") are not currently supported."); - - for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - if (entry->ast->tag == Comprehension) - goto table_comprehension; - } - - { // No comprehension: - env_t *key_scope = key_t->tag == EnumType ? with_enum_scope(env, key_t) : env; - env_t *value_scope = value_t->tag == EnumType ? with_enum_scope(env, value_t) : env; - CORD code = CORD_all("Table(", - compile_type(key_t), ", ", - compile_type(value_t), ", ", - compile_type_info(key_t), ", ", - compile_type_info(value_t)); - if (table->fallback) - code = CORD_all(code, ", /*fallback:*/ heap(", compile(env, table->fallback), ")"); - else - code = CORD_all(code, ", /*fallback:*/ NULL"); - - size_t n = 0; - for (ast_list_t *entry = table->entries; entry; entry = entry->next) - ++n; - CORD_appendf(&code, ", %zu", n); - - for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - auto e = Match(entry->ast, TableEntry); - code = CORD_all(code, ",\n\t{", compile_to_type(key_scope, e->key, key_t), ", ", - compile_to_type(value_scope, e->value, value_t), "}"); - } - return CORD_cat(code, ")"); - } - - table_comprehension: - { - static int64_t comp_num = 1; - env_t *scope = fresh_scope(env); - const char *comprehension_name = String("table$", comp_num++); - ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), - .type=Type(PointerType, .pointed=table_type, .is_stack=true)); - - CORD code = CORD_all("({ Table_t ", comprehension_name, " = {"); - if (table->fallback) - code = CORD_all(code, ".fallback=heap(", compile(env, table->fallback), "), "); - - code = CORD_cat(code, "};"); - - Closure_t comp_action = {.fn=add_to_table_comprehension, .userdata=comprehension_var}; - scope->comprehension_action = &comp_action; - for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - if (entry->ast->tag == Comprehension) - code = CORD_all(code, "\n", compile_statement(scope, entry->ast)); - else - code = CORD_all(code, compile_statement(env, add_to_table_comprehension(entry->ast, comprehension_var))); - } - code = CORD_all(code, " ", comprehension_name, "; })"); - return code; - } - + return compile_typed_table(env, ast, table_type); } case Set: { auto set = Match(ast, Set); @@ -2834,47 +2734,7 @@ CORD compile(env_t *env, ast_t *ast) return "((Table_t){})"; type_t *set_type = get_type(env, ast); - type_t *item_type = Match(set_type, SetType)->item_type; - - size_t n = 0; - for (ast_list_t *item = set->items; item; item = item->next) { - ++n; - if (item->ast->tag == Comprehension) - goto set_comprehension; - } - - { // No comprehension: - CORD code = CORD_all("Set(", - compile_type(item_type), ", ", - compile_type_info(item_type)); - CORD_appendf(&code, ", %zu", n); - env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : env; - for (ast_list_t *item = set->items; item; item = item->next) { - code = CORD_all(code, ", ", compile_to_type(scope, item->ast, item_type)); - } - return CORD_cat(code, ")"); - } - - set_comprehension: - { - static int64_t comp_num = 1; - env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : fresh_scope(env); - const char *comprehension_name = String("set$", comp_num++); - ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), - .type=Type(PointerType, .pointed=set_type, .is_stack=true)); - CORD code = CORD_all("({ Table_t ", comprehension_name, " = {};"); - Closure_t comp_action = {.fn=add_to_set_comprehension, .userdata=comprehension_var}; - scope->comprehension_action = &comp_action; - for (ast_list_t *item = set->items; item; item = item->next) { - if (item->ast->tag == Comprehension) - code = CORD_all(code, "\n", compile_statement(scope, item->ast)); - else - code = CORD_all(code, compile_statement(env, add_to_set_comprehension(item->ast, comprehension_var))); - } - code = CORD_all(code, " ", comprehension_name, "; })"); - return code; - } - + return compile_typed_set(env, ast, set_type); } case Comprehension: { ast_t *base = Match(ast, Comprehension)->expr; @@ -3040,8 +2900,7 @@ CORD compile(env_t *env, ast_t *ast) self = compile_to_pointer_depth(env, call->self, 0, false); arg_t *arg_spec = new(arg_t, .name="count", .type=INT_TYPE, .next=new(arg_t, .name="weights", .type=Type(ArrayType, .item_type=Type(NumType)), - .default_val=FakeAST(None, .type=new(type_ast_t, .tag=ArrayTypeAST, - .__data.ArrayTypeAST.item=new(type_ast_t, .tag=VarTypeAST, .__data.VarTypeAST.name="Num"))), + .default_val=FakeAST(None), .next=new(arg_t, .name="random", .type=random_num_type, .default_val=none_rng))); return CORD_all("Array$sample(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", padded_item_size, ")"); @@ -3473,7 +3332,7 @@ CORD compile(env_t *env, ast_t *ast) } case Reduction: { auto reduction = Match(ast, Reduction); - binop_e op = reduction->op; + ast_e op = reduction->op; type_t *iter_t = get_type(env, reduction->iter); type_t *item_t = get_iterated_type(iter_t); @@ -3484,7 +3343,7 @@ CORD compile(env_t *env, ast_t *ast) ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body); env_t *body_scope = for_scope(env, loop); - if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE) { + if (op == Equals || op == NotEquals || op == LessThan || op == LessThanOrEquals || op == GreaterThan || op == GreaterThanOrEquals) { // Chained comparisons like ==, <, etc. CORD code = CORD_all( "({ // Reduction:\n", @@ -3492,7 +3351,8 @@ CORD compile(env_t *env, ast_t *ast) "OptionalBool_t result = NONE_BOOL;\n" ); - ast_t *comparison = WrapAST(ast, BinaryOp, .op=op, .lhs=FakeAST(InlineCCode, .code="prev", .type=item_t), .rhs=item); + ast_t *comparison = new(ast_t, .file=ast->file, .start=ast->start, .end=ast->end, + .tag=op, .__data.Plus.lhs=FakeAST(InlineCCode, .code="prev", .type=item_t), .__data.Plus.rhs=item); body->__data.InlineCCode.code = CORD_all( "if (result == NONE_BOOL) {\n" " prev = ", compile(body_scope, item), ";\n" @@ -3507,9 +3367,9 @@ CORD compile(env_t *env, ast_t *ast) "}\n"); code = CORD_all(code, compile_statement(env, loop), "\nresult;})"); return code; - } else if (op == BINOP_MIN || op == BINOP_MAX) { + } else if (op == Min || op == Max) { // Min/max: - const char *superlative = op == BINOP_MIN ? "min" : "max"; + const char *superlative = op == Min ? "min" : "max"; CORD code = CORD_all( "({ // Reduction:\n", compile_declaration(item_t, superlative), ";\n" @@ -3517,17 +3377,18 @@ CORD compile(env_t *env, ast_t *ast) ); CORD item_code = compile(body_scope, item); - binop_e cmp_op = op == BINOP_MIN ? BINOP_LT : BINOP_GT; + ast_e cmp_op = op == Min ? LessThan : GreaterThan; if (reduction->key) { env_t *key_scope = fresh_scope(env); set_binding(key_scope, "$", item_t, item_code); type_t *key_type = get_type(key_scope, reduction->key); - const char *superlative_key = op == BINOP_MIN ? "min_key" : "max_key"; + const char *superlative_key = op == Min ? "min_key" : "max_key"; code = CORD_all(code, compile_declaration(key_type, superlative_key), ";\n"); - ast_t *comparison = WrapAST(ast, BinaryOp, .op=cmp_op, - .lhs=FakeAST(InlineCCode, .code="key", .type=key_type), - .rhs=FakeAST(InlineCCode, .code=superlative_key, .type=key_type)); + ast_t *comparison = new(ast_t, .file=ast->file, .start=ast->start, .end=ast->end, + .tag=cmp_op, .__data.Plus.lhs=FakeAST(InlineCCode, .code="key", .type=key_type), + .__data.Plus.rhs=FakeAST(InlineCCode, .code=superlative_key, .type=key_type)); + body->__data.InlineCCode.code = CORD_all( compile_declaration(key_type, "key"), " = ", compile(key_scope, reduction->key), ";\n", "if (!has_value || ", compile(body_scope, comparison), ") {\n" @@ -3536,7 +3397,9 @@ CORD compile(env_t *env, ast_t *ast) " has_value = yes;\n" "}\n"); } else { - ast_t *comparison = WrapAST(ast, BinaryOp, .op=cmp_op, .lhs=item, .rhs=FakeAST(InlineCCode, .code=superlative, .type=item_t)); + ast_t *comparison = new(ast_t, .file=ast->file, .start=ast->start, .end=ast->end, + .tag=cmp_op, .__data.Plus.lhs=item, + .__data.Plus.rhs=FakeAST(InlineCCode, .code=superlative, .type=item_t)); body->__data.InlineCCode.code = CORD_all( "if (!has_value || ", compile(body_scope, comparison), ") {\n" " ", superlative, " = ", compile(body_scope, item), ";\n" @@ -3558,22 +3421,24 @@ CORD compile(env_t *env, ast_t *ast) // For the special case of (or)/(and), we need to early out if we can: CORD early_out = CORD_EMPTY; - if (op == BINOP_CMP) { + if (op == Compare) { if (item_t->tag != IntType || Match(item_t, IntType)->bits != TYPE_IBITS32) code_err(ast, "<> reductions are only supported for Int32 values"); - } else if (op == BINOP_AND) { + } else if (op == And) { if (item_t->tag == BoolType) early_out = "if (!reduction) break;"; else if (item_t->tag == OptionalType) early_out = CORD_all("if (", check_none(item_t, "reduction"), ") break;"); - } else if (op == BINOP_OR) { + } else if (op == Or) { if (item_t->tag == BoolType) early_out = "if (reduction) break;"; else if (item_t->tag == OptionalType) early_out = CORD_all("if (!", check_none(item_t, "reduction"), ") break;"); } - ast_t *combination = WrapAST(ast, BinaryOp, .op=op, .lhs=FakeAST(InlineCCode, .code="reduction", .type=item_t), .rhs=item); + ast_t *combination = new(ast_t, .file=ast->file, .start=ast->start, .end=ast->end, + .tag=op, .__data.Plus.lhs=FakeAST(InlineCCode, .code="reduction", .type=item_t), + .__data.Plus.rhs=item); body->__data.InlineCCode.code = CORD_all( "if (!has_value) {\n" " reduction = ", compile(body_scope, item), ";\n" @@ -3764,7 +3629,7 @@ CORD compile(env_t *env, ast_t *ast) case Defer: code_err(ast, "Compiling 'defer' as expression!"); case Extern: code_err(ast, "Externs are not supported as expressions"); case TableEntry: code_err(ast, "Table entries should not be compiled directly"); - case Declare: case Assign: case UpdateAssign: case For: case While: case Repeat: case StructDef: case LangDef: case Extend: + case Declare: case Assign: case UPDATE_CASES: case For: case While: case Repeat: case StructDef: case LangDef: case Extend: case EnumDef: case FunctionDef: case ConvertDef: case Skip: case Stop: case Pass: case Return: case DocTest: case PrintStatement: code_err(ast, "This is not a valid expression"); default: case Unknown: code_err(ast, "Unknown AST"); diff --git a/src/environment.c b/src/environment.c index 18818a9c..a6a450e8 100644 --- a/src/environment.c +++ b/src/environment.c @@ -122,7 +122,7 @@ env_t *global_env(void) {"right_shifted", "Int$right_shifted", "func(x,y:Int -> Int)"}, {"sqrt", "Int$sqrt", "func(x:Int -> Int?)"}, {"times", "Int$times", "func(x,y:Int -> Int)"}, - {"to", "Int$to", "func(first:Int,last:Int,step=none:Int -> func(->Int?))"}, + {"to", "Int$to", "func(first:Int,last:Int,step:Int?=none -> func(->Int?))"}, )}, {"Int64", Type(IntType, .bits=TYPE_IBITS64), "Int64_t", "Int64$info", TypedArray(ns_entry_t, {"abs", "labs", "func(i:Int64 -> Int64)"}, @@ -139,7 +139,7 @@ env_t *global_env(void) {"modulo1", "Int64$modulo1", "func(x,y:Int64 -> Int64)"}, {"octal", "Int64$octal", "func(i:Int64, digits=0, prefix=yes -> Text)"}, {"onward", "Int64$onward", "func(first:Int64,step=Int64(1) -> func(->Int64?))"}, - {"to", "Int64$to", "func(first:Int64,last:Int64,step=none:Int64 -> func(->Int64?))"}, + {"to", "Int64$to", "func(first:Int64,last:Int64,step:Int64?=none -> func(->Int64?))"}, {"unsigned_left_shifted", "Int64$unsigned_left_shifted", "func(x:Int64,y:Int64 -> Int64)"}, {"unsigned_right_shifted", "Int64$unsigned_right_shifted", "func(x:Int64,y:Int64 -> Int64)"}, {"wrapping_minus", "Int64$wrapping_minus", "func(x:Int64,y:Int64 -> Int64)"}, @@ -160,7 +160,7 @@ env_t *global_env(void) {"modulo1", "Int32$modulo1", "func(x,y:Int32 -> Int32)"}, {"octal", "Int32$octal", "func(i:Int32, digits=0, prefix=yes -> Text)"}, {"onward", "Int32$onward", "func(first:Int32,step=Int32(1) -> func(->Int32?))"}, - {"to", "Int32$to", "func(first:Int32,last:Int32,step=none:Int32 -> func(->Int32?))"}, + {"to", "Int32$to", "func(first:Int32,last:Int32,step:Int32?=none -> func(->Int32?))"}, {"unsigned_left_shifted", "Int32$unsigned_left_shifted", "func(x:Int32,y:Int32 -> Int32)"}, {"unsigned_right_shifted", "Int32$unsigned_right_shifted", "func(x:Int32,y:Int32 -> Int32)"}, {"wrapping_minus", "Int32$wrapping_minus", "func(x:Int32,y:Int32 -> Int32)"}, @@ -181,7 +181,7 @@ env_t *global_env(void) {"modulo1", "Int16$modulo1", "func(x,y:Int16 -> Int16)"}, {"octal", "Int16$octal", "func(i:Int16, digits=0, prefix=yes -> Text)"}, {"onward", "Int16$onward", "func(first:Int16,step=Int16(1) -> func(->Int16?))"}, - {"to", "Int16$to", "func(first:Int16,last:Int16,step=none:Int16 -> func(->Int16?))"}, + {"to", "Int16$to", "func(first:Int16,last:Int16,step:Int16?=none -> func(->Int16?))"}, {"unsigned_left_shifted", "Int16$unsigned_left_shifted", "func(x:Int16,y:Int16 -> Int16)"}, {"unsigned_right_shifted", "Int16$unsigned_right_shifted", "func(x:Int16,y:Int16 -> Int16)"}, {"wrapping_minus", "Int16$wrapping_minus", "func(x:Int16,y:Int16 -> Int16)"}, @@ -202,7 +202,7 @@ env_t *global_env(void) {"modulo1", "Int8$modulo1", "func(x,y:Int8 -> Int8)"}, {"octal", "Int8$octal", "func(i:Int8, digits=0, prefix=yes -> Text)"}, {"onward", "Int8$onward", "func(first:Int8,step=Int8(1) -> func(->Int8?))"}, - {"to", "Int8$to", "func(first:Int8,last:Int8,step=none:Int8 -> func(->Int8?))"}, + {"to", "Int8$to", "func(first:Int8,last:Int8,step:Int8?=none -> func(->Int8?))"}, {"unsigned_left_shifted", "Int8$unsigned_left_shifted", "func(x:Int8,y:Int8 -> Int8)"}, {"unsigned_right_shifted", "Int8$unsigned_right_shifted", "func(x:Int8,y:Int8 -> Int8)"}, {"wrapping_minus", "Int8$wrapping_minus", "func(x:Int8,y:Int8 -> Int8)"}, @@ -310,11 +310,11 @@ env_t *global_env(void) {"owner", "Path$owner", "func(path:Path, follow_symlinks=yes -> Text?)"}, {"parent", "Path$parent", "func(path:Path -> Path)"}, {"read", "Path$read", "func(path:Path -> Text?)"}, - {"read_bytes", "Path$read_bytes", "func(path:Path, limit=none:Int -> [Byte]?)"}, + {"read_bytes", "Path$read_bytes", "func(path:Path, limit:Int?=none -> [Byte]?)"}, {"relative_to", "Path$relative_to", "func(path:Path, relative_to:Path -> Path)"}, {"remove", "Path$remove", "func(path:Path, ignore_missing=no)"}, {"resolved", "Path$resolved", "func(path:Path, relative_to=(./) -> Path)"}, - {"set_owner", "Path$set_owner", "func(path:Path, owner=none:Text, group=none:Text, follow_symlinks=yes)"}, + {"set_owner", "Path$set_owner", "func(path:Path, owner:Text?=none, group:Text?=none, follow_symlinks=yes)"}, {"subdirectories", "Path$children", "func(path:Path, include_hidden=no -> [Path])"}, {"unique_directory", "Path$unique_directory", "func(path:Path -> Path)"}, {"write", "Path$write", "func(path:Path, text:Text, permissions=Int32(0o644))"}, @@ -508,7 +508,7 @@ env_t *global_env(void) {"say", "say", "func(text:Text, newline=yes)"}, {"print", "say", "func(text:Text, newline=yes)"}, {"ask", "ask", "func(prompt:Text, bold=yes, force_tty=yes -> Text?)"}, - {"exit", "tomo_exit", "func(message=none:Text, code=Int32(1) -> Abort)"}, + {"exit", "tomo_exit", "func(message:Text?=none, code=Int32(1) -> Abort)"}, {"fail", "fail_text", "func(message:Text -> Abort)"}, {"sleep", "sleep_num", "func(seconds:Num)"}, }; @@ -749,6 +749,18 @@ PUREFUNC binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args) return NULL; } +PUREFUNC binding_t *get_metamethod_binding(env_t *env, ast_e tag, ast_t *lhs, ast_t *rhs, type_t *ret) +{ + const char *method_name = binop_method_name(tag); + if (!method_name) return NULL; + binding_t *b = get_namespace_binding(env, lhs, method_name); + if (!b || b->type->tag != FunctionType) return NULL; + auto fn = Match(b->type, FunctionType); + if (!type_eq(fn->ret, ret)) return NULL; + arg_ast_t *args = new(arg_ast_t, .value=lhs, .next=new(arg_ast_t, .value=rhs)); + return is_valid_call(env, fn->args, args, true) ? b : NULL; +} + void set_binding(env_t *env, const char *name, type_t *type, CORD code) { assert(name); diff --git a/src/environment.h b/src/environment.h index fce6bc91..cbaae09b 100644 --- a/src/environment.h +++ b/src/environment.h @@ -85,6 +85,7 @@ env_t *namespace_env(env_t *env, const char *namespace_name); }) binding_t *get_binding(env_t *env, const char *name); binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args); +PUREFUNC binding_t *get_metamethod_binding(env_t *env, ast_e tag, ast_t *lhs, ast_t *rhs, type_t *ret); void set_binding(env_t *env, const char *name, type_t *type, CORD code); binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name); #define code_err(ast, ...) compiler_err((ast)->file, (ast)->start, (ast)->end, __VA_ARGS__) diff --git a/src/parse.c b/src/parse.c index 63f5deb0..0aa26002 100644 --- a/src/parse.c +++ b/src/parse.c @@ -48,16 +48,16 @@ typedef struct { #define PARSER(name) ast_t *name(parse_ctx_t *ctx, const char *pos) int op_tightness[] = { - [BINOP_POWER]=9, - [BINOP_MULT]=8, [BINOP_DIVIDE]=8, [BINOP_MOD]=8, [BINOP_MOD1]=8, - [BINOP_PLUS]=7, [BINOP_MINUS]=7, - [BINOP_CONCAT]=6, - [BINOP_LSHIFT]=5, [BINOP_RSHIFT]=5, - [BINOP_MIN]=4, [BINOP_MAX]=4, - [BINOP_EQ]=3, [BINOP_NE]=3, - [BINOP_LT]=2, [BINOP_LE]=2, [BINOP_GT]=2, [BINOP_GE]=2, - [BINOP_CMP]=2, - [BINOP_AND]=1, [BINOP_OR]=1, [BINOP_XOR]=1, + [Power]=9, + [Multiply]=8, [Divide]=8, [Mod]=8, [Mod1]=8, + [Plus]=7, [Minus]=7, + [Concat]=6, + [LeftShift]=5, [RightShift]=5, [UnsignedLeftShift]=5, [UnsignedRightShift]=5, + [Min]=4, [Max]=4, + [Equals]=3, [NotEquals]=3, + [LessThan]=2, [LessThanOrEquals]=2, [GreaterThan]=2, [GreaterThanOrEquals]=2, + [Compare]=2, + [And]=1, [Or]=1, [Xor]=1, }; static const char *keywords[] = { @@ -79,7 +79,7 @@ static INLINE const char* get_word(const char **pos); static INLINE const char* get_id(const char **pos); static INLINE bool comment(const char **pos); static INLINE bool indent(parse_ctx_t *ctx, const char **pos); -static INLINE binop_e match_binary_operator(const char **pos); +static INLINE ast_e match_binary_operator(const char **pos); static ast_t *parse_comprehension_suffix(parse_ctx_t *ctx, ast_t *expr); static ast_t *parse_field_suffix(parse_ctx_t *ctx, ast_t *lhs); static ast_t *parse_fncall_suffix(parse_ctx_t *ctx, ast_t *fn); @@ -685,15 +685,6 @@ PARSER(parse_array) { whitespace(&pos); ast_list_t *items = NULL; - type_ast_t *item_type = NULL; - if (match(&pos, ":")) { - whitespace(&pos); - item_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a type for this array"); - whitespace(&pos); - match(&pos, ","); - whitespace(&pos); - } - for (;;) { ast_t *item = optional(ctx, &pos, parse_extended_expr); if (!item) break; @@ -711,7 +702,7 @@ PARSER(parse_array) { expect_closing(ctx, &pos, "]", "I wasn't able to parse the rest of this array"); REVERSE_LIST(items); - return NewAST(ctx->file, start, pos, Array, .item_type=item_type, .items=items); + return NewAST(ctx->file, start, pos, Array, .items=items); } PARSER(parse_table) { @@ -722,20 +713,6 @@ PARSER(parse_table) { whitespace(&pos); ast_list_t *entries = NULL; - type_ast_t *key_type = NULL, *value_type = NULL; - if (match(&pos, ":")) { - whitespace(&pos); - key_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a key type for this table"); - whitespace(&pos); - if (match(&pos, "=")) { - value_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse the value type for this table"); - } else { - return NULL; - } - whitespace(&pos); - match(&pos, ","); - } - for (;;) { const char *entry_start = pos; ast_t *key = optional(ctx, &pos, parse_extended_expr); @@ -787,8 +764,7 @@ PARSER(parse_table) { whitespace(&pos); expect_closing(ctx, &pos, "}", "I wasn't able to parse the rest of this table"); - return NewAST(ctx->file, start, pos, Table, .key_type=key_type, .value_type=value_type, - .default_value=default_value, .entries=entries, .fallback=fallback); + return NewAST(ctx->file, start, pos, Table, .default_value=default_value, .entries=entries, .fallback=fallback); } PARSER(parse_set) { @@ -801,18 +777,6 @@ PARSER(parse_set) { whitespace(&pos); ast_list_t *items = NULL; - type_ast_t *item_type = NULL; - if (match(&pos, ":")) { - whitespace(&pos); - item_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a key type for this set"); - whitespace(&pos); - if (match(&pos, ",")) - return NULL; - whitespace(&pos); - match(&pos, ","); - whitespace(&pos); - } - for (;;) { ast_t *item = optional(ctx, &pos, parse_extended_expr); if (!item) break; @@ -834,7 +798,7 @@ PARSER(parse_set) { whitespace(&pos); expect_closing(ctx, &pos, "}", "I wasn't able to parse the rest of this set"); - return NewAST(ctx->file, start, pos, Set, .item_type=item_type, .items=items); + return NewAST(ctx->file, start, pos, Set, .items=items); } ast_t *parse_field_suffix(parse_ctx_t *ctx, ast_t *lhs) { @@ -874,11 +838,11 @@ PARSER(parse_reduction) { if (!match(&pos, "(")) return NULL; whitespace(&pos); - binop_e op = match_binary_operator(&pos); - if (op == BINOP_UNKNOWN) return NULL; + ast_e op = match_binary_operator(&pos); + if (op == Unknown) return NULL; ast_t *key = NULL; - if (op == BINOP_MIN || op == BINOP_MAX) { + if (op == Min || op == Max) { key = NewAST(ctx->file, pos, pos, Var, .name="$"); for (bool progress = true; progress; ) { ast_t *new_term; @@ -1425,16 +1389,7 @@ PARSER(parse_none) { const char *start = pos; if (!match_word(&pos, "none")) return NULL; - - const char *none_end = pos; - spaces(&pos); - if (!match(&pos, ":")) - return NewAST(ctx->file, start, none_end, None, .type=NULL); - - spaces(&pos); - type_ast_t *type = parse_type(ctx, pos); - if (!type) return NULL; - return NewAST(ctx->file, start, type->end, None, .type=type); + return NewAST(ctx->file, start, pos, None); } PARSER(parse_deserialize) { @@ -1602,53 +1557,53 @@ ast_t *parse_fncall_suffix(parse_ctx_t *ctx, ast_t *fn) { return NewAST(ctx->file, start, pos, FunctionCall, .fn=fn, .args=args); } -binop_e match_binary_operator(const char **pos) +ast_e match_binary_operator(const char **pos) { switch (**pos) { case '+': { *pos += 1; - return match(pos, "+") ? BINOP_CONCAT : BINOP_PLUS; + return match(pos, "+") ? Concat : Plus; } case '-': { *pos += 1; if ((*pos)[0] != ' ' && (*pos)[-2] == ' ') // looks like `fn -5` - return BINOP_UNKNOWN; - return BINOP_MINUS; + return Unknown; + return Minus; } - case '*': *pos += 1; return BINOP_MULT; - case '/': *pos += 1; return BINOP_DIVIDE; - case '^': *pos += 1; return BINOP_POWER; + case '*': *pos += 1; return Multiply; + case '/': *pos += 1; return Divide; + case '^': *pos += 1; return Power; case '<': { *pos += 1; - if (match(pos, "=")) return BINOP_LE; // "<=" - else if (match(pos, ">")) return BINOP_CMP; // "<>" + if (match(pos, "=")) return LessThanOrEquals; // "<=" + else if (match(pos, ">")) return Compare; // "<>" else if (match(pos, "<")) { if (match(pos, "<")) - return BINOP_ULSHIFT; // "<<<" - return BINOP_LSHIFT; // "<<" - } else return BINOP_LT; + return UnsignedLeftShift; // "<<<" + return LeftShift; // "<<" + } else return LessThan; } case '>': { *pos += 1; - if (match(pos, "=")) return BINOP_GE; // ">=" + if (match(pos, "=")) return GreaterThanOrEquals; // ">=" if (match(pos, ">")) { if (match(pos, ">")) - return BINOP_URSHIFT; // ">>>" - return BINOP_RSHIFT; // ">>" + return UnsignedRightShift; // ">>>" + return RightShift; // ">>" } - return BINOP_GT; + return GreaterThan; } default: { - if (match(pos, "!=")) return BINOP_NE; - else if (match(pos, "==") && **pos != '=') return BINOP_EQ; - else if (match_word(pos, "and")) return BINOP_AND; - else if (match_word(pos, "or")) return BINOP_OR; - else if (match_word(pos, "xor")) return BINOP_XOR; - else if (match_word(pos, "mod1")) return BINOP_MOD1; - else if (match_word(pos, "mod")) return BINOP_MOD; - else if (match_word(pos, "_min_")) return BINOP_MIN; - else if (match_word(pos, "_max_")) return BINOP_MAX; - else return BINOP_UNKNOWN; + if (match(pos, "!=")) return NotEquals; + else if (match(pos, "==") && **pos != '=') return Equals; + else if (match_word(pos, "and")) return And; + else if (match_word(pos, "or")) return Or; + else if (match_word(pos, "xor")) return Xor; + else if (match_word(pos, "mod1")) return Mod1; + else if (match_word(pos, "mod")) return Mod; + else if (match_word(pos, "_min_")) return Min; + else if (match_word(pos, "_max_")) return Max; + else return Unknown; } } } @@ -1660,9 +1615,9 @@ static ast_t *parse_infix_expr(parse_ctx_t *ctx, const char *pos, int min_tightn int64_t starting_line = get_line_number(ctx->file, pos); int64_t starting_indent = get_indent(ctx, pos); spaces(&pos); - for (binop_e op; (op=match_binary_operator(&pos)) != BINOP_UNKNOWN && op_tightness[op] >= min_tightness; spaces(&pos)) { + for (ast_e op; (op=match_binary_operator(&pos)) != Unknown && op_tightness[op] >= min_tightness; spaces(&pos)) { ast_t *key = NULL; - if (op == BINOP_MIN || op == BINOP_MAX) { + if (op == Min || op == Max) { key = NewAST(ctx->file, pos, pos, Var, .name="$"); for (bool progress = true; progress; ) { ast_t *new_term; @@ -1688,12 +1643,12 @@ static ast_t *parse_infix_expr(parse_ctx_t *ctx, const char *pos, int min_tightn if (!rhs) break; pos = rhs->end; - if (op == BINOP_MIN) { + if (op == Min) { return NewAST(ctx->file, lhs->start, rhs->end, Min, .lhs=lhs, .rhs=rhs, .key=key); - } else if (op == BINOP_MAX) { + } else if (op == Max) { return NewAST(ctx->file, lhs->start, rhs->end, Max, .lhs=lhs, .rhs=rhs, .key=key); } else { - lhs = NewAST(ctx->file, lhs->start, rhs->end, BinaryOp, .lhs=lhs, .op=op, .rhs=rhs); + lhs = new(ast_t, .file=ctx->file, .start=lhs->start, .end=rhs->end, .tag=op, .__data.Plus.lhs=lhs, .__data.Plus.rhs=rhs); } } return lhs; @@ -1709,8 +1664,11 @@ PARSER(parse_declaration) { if (!var) return NULL; pos = var->end; spaces(&pos); - if (!match(&pos, ":=")) return NULL; + if (!match(&pos, ":")) return NULL; + spaces(&pos); + type_ast_t *type = optional(ctx, &pos, parse_type); spaces(&pos); + if (!match(&pos, "=")) return NULL; ast_t *val = optional(ctx, &pos, parse_extended_expr); if (!val) { if (optional(ctx, &pos, parse_use)) @@ -1718,7 +1676,7 @@ PARSER(parse_declaration) { else parser_err(ctx, pos, eol(pos), "This is not a valid expression"); } - return NewAST(ctx->file, start, pos, Declare, .var=var, .value=val); + return NewAST(ctx->file, start, pos, Declare, .var=var, .type=type, .value=val); } PARSER(parse_update) { @@ -1726,23 +1684,23 @@ PARSER(parse_update) { ast_t *lhs = optional(ctx, &pos, parse_expr); if (!lhs) return NULL; spaces(&pos); - binop_e op; - if (match(&pos, "+=")) op = BINOP_PLUS; - else if (match(&pos, "++=")) op = BINOP_CONCAT; - else if (match(&pos, "-=")) op = BINOP_MINUS; - else if (match(&pos, "*=")) op = BINOP_MULT; - else if (match(&pos, "/=")) op = BINOP_DIVIDE; - else if (match(&pos, "^=")) op = BINOP_POWER; - else if (match(&pos, "<<=")) op = BINOP_LSHIFT; - else if (match(&pos, "<<<=")) op = BINOP_ULSHIFT; - else if (match(&pos, ">>=")) op = BINOP_RSHIFT; - else if (match(&pos, ">>>=")) op = BINOP_URSHIFT; - else if (match(&pos, "and=")) op = BINOP_AND; - else if (match(&pos, "or=")) op = BINOP_OR; - else if (match(&pos, "xor=")) op = BINOP_XOR; + ast_e op; + if (match(&pos, "+=")) op = Plus; + else if (match(&pos, "++=")) op = Concat; + else if (match(&pos, "-=")) op = Minus; + else if (match(&pos, "*=")) op = Multiply; + else if (match(&pos, "/=")) op = Divide; + else if (match(&pos, "^=")) op = Power; + else if (match(&pos, "<<=")) op = LeftShift; + else if (match(&pos, "<<<=")) op = UnsignedLeftShift; + else if (match(&pos, ">>=")) op = RightShift; + else if (match(&pos, ">>>=")) op = UnsignedRightShift; + else if (match(&pos, "and=")) op = And; + else if (match(&pos, "or=")) op = Or; + else if (match(&pos, "xor=")) op = Xor; else return NULL; ast_t *rhs = expect(ctx, start, &pos, parse_extended_expr, "I expected an expression here"); - return NewAST(ctx->file, start, pos, UpdateAssign, .lhs=lhs, .rhs=rhs, .op=op); + return new(ast_t, .file=ctx->file, .start=start, .end=pos, .tag=op, .__data.PlusUpdate.lhs=lhs, .__data.PlusUpdate.rhs=rhs); } PARSER(parse_assignment) { diff --git a/src/repl.c b/src/repl.c index 2f5c60f6..463e7ff8 100644 --- a/src/repl.c +++ b/src/repl.c @@ -186,31 +186,31 @@ static Int_t ast_to_int(env_t *env, ast_t *ast) } } -static double ast_to_num(env_t *env, ast_t *ast) -{ - type_t *t = get_type(env, ast); - switch (t->tag) { - case BigIntType: case IntType: { - number_t num; - eval(env, ast, &num); - if (t->tag == BigIntType) - return Num$from_int(num.integer, false); - switch (Match(t, IntType)->bits) { - case TYPE_IBITS64: return Num$from_int64(num.i64, false); - case TYPE_IBITS32: return Num$from_int32(num.i32); - case TYPE_IBITS16: return Num$from_int16(num.i16); - case TYPE_IBITS8: return Num$from_int8(num.i8); - default: print_err("Invalid int bits"); - } - } - case NumType: { - number_t num; - eval(env, ast, &num); - return Match(t, NumType)->bits == TYPE_NBITS32 ? (double)num.n32 : (double)num.n64; - } - default: print_err("Cannot convert to number"); - } -} +// static double ast_to_num(env_t *env, ast_t *ast) +// { +// type_t *t = get_type(env, ast); +// switch (t->tag) { +// case BigIntType: case IntType: { +// number_t num; +// eval(env, ast, &num); +// if (t->tag == BigIntType) +// return Num$from_int(num.integer, false); +// switch (Match(t, IntType)->bits) { +// case TYPE_IBITS64: return Num$from_int64(num.i64, false); +// case TYPE_IBITS32: return Num$from_int32(num.i32); +// case TYPE_IBITS16: return Num$from_int16(num.i16); +// case TYPE_IBITS8: return Num$from_int8(num.i8); +// default: print_err("Invalid int bits"); +// } +// } +// case NumType: { +// number_t num; +// eval(env, ast, &num); +// return Match(t, NumType)->bits == TYPE_NBITS32 ? (double)num.n32 : (double)num.n64; +// } +// default: print_err("Cannot convert to number"); +// } +// } static Text_t obj_to_text(type_t *t, const void *obj, bool use_color) { @@ -386,76 +386,6 @@ void eval(env_t *env, ast_t *ast, void *dest) if (dest) *(CORD*)dest = ret; break; } - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - if (t->tag == IntType || t->tag == BigIntType) { -#define CASE_OP(OP_NAME, method_name) case BINOP_##OP_NAME: {\ - Int_t lhs = ast_to_int(env, binop->lhs); \ - Int_t rhs = ast_to_int(env, binop->rhs); \ - Int_t result = Int$ ## method_name (lhs, rhs); \ - if (t->tag == BigIntType) {\ - *(Int_t*)dest = result; \ - return; \ - } \ - switch (Match(t, IntType)->bits) { \ - case 64: *(int64_t*)dest = Int64$from_int(result, false); return; \ - case 32: *(int32_t*)dest = Int32$from_int(result, false); return; \ - case 16: *(int16_t*)dest = Int16$from_int(result, false); return; \ - case 8: *(int8_t*)dest = Int8$from_int(result, false); return; \ - default: print_err("Invalid int bits"); \ - } \ - break; \ - } - switch (binop->op) { - CASE_OP(MULT, times) CASE_OP(DIVIDE, divided_by) CASE_OP(PLUS, plus) CASE_OP(MINUS, minus) - CASE_OP(RSHIFT, right_shifted) CASE_OP(LSHIFT, left_shifted) - CASE_OP(MOD, modulo) CASE_OP(MOD1, modulo1) - CASE_OP(AND, bit_and) CASE_OP(OR, bit_or) CASE_OP(XOR, bit_xor) - default: break; - } -#undef CASE_OP - } else if (t->tag == NumType) { -#define CASE_OP(OP_NAME, C_OP) case BINOP_##OP_NAME: {\ - double lhs = ast_to_num(env, binop->lhs); \ - double rhs = ast_to_num(env, binop->rhs); \ - if (Match(t, NumType)->bits == 64) \ - *(double*)dest = (double)(lhs C_OP rhs); \ - else \ - *(float*)dest = (float)(lhs C_OP rhs); \ - return; \ - } - switch (binop->op) { - CASE_OP(MULT, *) CASE_OP(DIVIDE, /) CASE_OP(PLUS, +) CASE_OP(MINUS, -) - default: break; - } -#undef CASE_OP - } - switch (binop->op) { - case BINOP_EQ: case BINOP_NE: case BINOP_LT: case BINOP_LE: case BINOP_GT: case BINOP_GE: { - type_t *t_lhs = get_type(env, binop->lhs); - if (!type_eq(t_lhs, get_type(env, binop->rhs))) - print_err("Comparisons between different types aren't supported"); - const TypeInfo_t *info = type_to_type_info(t_lhs); - size_t value_size = type_size(t_lhs); - char lhs[value_size], rhs[value_size]; - eval(env, binop->lhs, lhs); - eval(env, binop->rhs, rhs); - int cmp = generic_compare(lhs, rhs, info); - switch (binop->op) { - case BINOP_EQ: *(bool*)dest = (cmp == 0); break; - case BINOP_NE: *(bool*)dest = (cmp != 0); break; - case BINOP_GT: *(bool*)dest = (cmp > 0); break; - case BINOP_GE: *(bool*)dest = (cmp >= 0); break; - case BINOP_LT: *(bool*)dest = (cmp < 0); break; - case BINOP_LE: *(bool*)dest = (cmp <= 0); break; - default: break; - } - break; - } - default: print_err(1, "Binary op not implemented for ", type_to_str(t), ": ", ast_to_xml_str(ast)); - } - break; - } case Index: { auto index = Match(ast, Index); type_t *indexed_t = get_type(env, index->indexed); diff --git a/src/typecheck.c b/src/typecheck.c index 8a2ee32b..cd6ff1c2 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -135,27 +135,27 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast) errx(1, "Unreachable"); } -static PUREFUNC bool risks_zero_or_inf(ast_t *ast) -{ - switch (ast->tag) { - case Int: { - const char *str = Match(ast, Int)->str; - OptionalInt_t int_val = Int$from_str(str); - return (int_val.small == 0x1); // zero - } - case Num: { - return Match(ast, Num)->n == 0.0; - } - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - if (binop->op == BINOP_MULT || binop->op == BINOP_DIVIDE || binop->op == BINOP_MIN || binop->op == BINOP_MAX) - return risks_zero_or_inf(binop->lhs) || risks_zero_or_inf(binop->rhs); - else - return true; - } - default: return true; - } -} +// static PUREFUNC bool risks_zero_or_inf(ast_t *ast) +// { +// switch (ast->tag) { +// case Int: { +// const char *str = Match(ast, Int)->str; +// OptionalInt_t int_val = Int$from_str(str); +// return (int_val.small == 0x1); // zero +// } +// case Num: { +// return Match(ast, Num)->n == 0.0; +// } +// case BINOP_CASES: { +// binary_operands_t binop = BINARY_OPERANDS(ast); +// if (ast->tag == Multiply || ast->tag == Divide || ast->tag == Min || ast->tag == Max) +// return risks_zero_or_inf(binop.lhs) || risks_zero_or_inf(binop.rhs); +// else +// return true; +// } +// default: return true; +// } +// } PUREFUNC type_t *get_math_type(env_t *env, ast_t *ast, type_t *lhs_t, type_t *rhs_t) { @@ -312,7 +312,7 @@ void bind_statement(env_t *env, ast_t *statement) if (get_binding(env, name)) code_err(decl->var, "A ", type_to_str(get_binding(env, name)->type), " called ", quoted(name), " has already been defined"); bind_statement(env, decl->value); - type_t *type = get_type(env, decl->value); + type_t *type = decl->type ? parse_type_ast(env, decl->type) : get_type(env, decl->value); if (!type) code_err(decl->value, "I couldn't figure out the type of this value"); if (type->tag == FunctionType) @@ -617,12 +617,7 @@ type_t *get_type(env_t *env, ast_t *ast) #endif switch (ast->tag) { case None: { - if (!Match(ast, None)->type) - return Type(OptionalType, .type=NULL); - type_t *t = parse_type_ast(env, Match(ast, None)->type); - if (t->tag == OptionalType) - code_err(ast, "Nested optional types are not supported. This should be: `none:", type_to_str(Match(t, OptionalType)->type), "`"); - return Type(OptionalType, .type=t); + return Type(OptionalType, .type=NULL); } case Bool: { return Type(BoolType); @@ -714,115 +709,91 @@ type_t *get_type(env_t *env, ast_t *ast) case Array: { auto array = Match(ast, Array); type_t *item_type = NULL; - if (array->item_type) { - item_type = parse_type_ast(env, array->item_type); - } else if (array->items) { - for (ast_list_t *item = array->items; item; item = item->next) { - ast_t *item_ast = item->ast; - env_t *scope = env; - while (item_ast->tag == Comprehension) { - auto comp = Match(item_ast, Comprehension); - scope = for_scope( - scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); - item_ast = comp->expr; - } - type_t *t2 = get_type(scope, item_ast); - type_t *merged = item_type ? type_or_type(item_type, t2) : t2; - if (!merged) - code_err(item->ast, - "This array item has type ", type_to_str(t2), - ", which is different from earlier array items which have type ", type_to_str(item_type)); - item_type = merged; + for (ast_list_t *item = array->items; item; item = item->next) { + ast_t *item_ast = item->ast; + env_t *scope = env; + while (item_ast->tag == Comprehension) { + auto comp = Match(item_ast, Comprehension); + scope = for_scope( + scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); + item_ast = comp->expr; } - } else { - code_err(ast, "I can't figure out what type this array has because it has no members or explicit type"); + type_t *t2 = get_type(scope, item_ast); + type_t *merged = item_type ? type_or_type(item_type, t2) : t2; + if (!merged) + code_err(item->ast, + "This array item has type ", type_to_str(t2), + ", which is different from earlier array items which have type ", type_to_str(item_type)); + item_type = merged; } - if (has_stack_memory(item_type)) - code_err(ast, "Arrays cannot hold stack references, because the array may outlive the stack frame the reference was created in."); - if (!item_type) - code_err(ast, "I couldn't figure out the item type for this array!"); + if (item_type && has_stack_memory(item_type)) + code_err(ast, "Arrays cannot hold stack references, because the array may outlive the stack frame the reference was created in."); return Type(ArrayType, .item_type=item_type); } case Set: { auto set = Match(ast, Set); type_t *item_type = NULL; - if (set->item_type) { - item_type = parse_type_ast(env, set->item_type); - } else { - for (ast_list_t *item = set->items; item; item = item->next) { - ast_t *item_ast = item->ast; - env_t *scope = env; - while (item_ast->tag == Comprehension) { - auto comp = Match(item_ast, Comprehension); - scope = for_scope( - scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); - item_ast = comp->expr; - } - - type_t *this_item_type = get_type(scope, item_ast); - type_t *item_merged = type_or_type(item_type, this_item_type); - if (!item_merged) - code_err(item_ast, - "This set item has type ", type_to_str(this_item_type), - ", which is different from earlier set items which have type ", type_to_str(item_type)); - item_type = item_merged; + for (ast_list_t *item = set->items; item; item = item->next) { + ast_t *item_ast = item->ast; + env_t *scope = env; + while (item_ast->tag == Comprehension) { + auto comp = Match(item_ast, Comprehension); + scope = for_scope( + scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); + item_ast = comp->expr; } - } - if (!item_type) - code_err(ast, "I couldn't figure out the item type for this set!"); + type_t *this_item_type = get_type(scope, item_ast); + type_t *item_merged = type_or_type(item_type, this_item_type); + if (!item_merged) + code_err(item_ast, + "This set item has type ", type_to_str(this_item_type), + ", which is different from earlier set items which have type ", type_to_str(item_type)); + item_type = item_merged; + } - if (has_stack_memory(item_type)) + if (item_type && has_stack_memory(item_type)) code_err(ast, "Sets cannot hold stack references because the set may outlive the reference's stack frame."); + return Type(SetType, .item_type=item_type); } case Table: { auto table = Match(ast, Table); type_t *key_type = NULL, *value_type = NULL; - if (table->key_type && table->value_type) { - key_type = parse_type_ast(env, table->key_type); - value_type = parse_type_ast(env, table->value_type); - } else if (table->key_type && table->default_value) { - key_type = parse_type_ast(env, table->key_type); - value_type = get_type(env, table->default_value); - } else { - for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - ast_t *entry_ast = entry->ast; - env_t *scope = env; - while (entry_ast->tag == Comprehension) { - auto comp = Match(entry_ast, Comprehension); - scope = for_scope( - scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); - entry_ast = comp->expr; - } - - auto e = Match(entry_ast, TableEntry); - type_t *key_t = get_type(scope, e->key); - type_t *value_t = get_type(scope, e->value); - - type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t; - if (!key_merged) - code_err(entry->ast, - "This table entry has type ", type_to_str(key_t), - ", which is different from earlier table entries which have type ", type_to_str(key_type)); - key_type = key_merged; - - type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t; - if (!val_merged) - code_err(entry->ast, - "This table entry has type ", type_to_str(value_t), - ", which is different from earlier table entries which have type ", type_to_str(value_type)); - value_type = val_merged; + for (ast_list_t *entry = table->entries; entry; entry = entry->next) { + ast_t *entry_ast = entry->ast; + env_t *scope = env; + while (entry_ast->tag == Comprehension) { + auto comp = Match(entry_ast, Comprehension); + scope = for_scope( + scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); + entry_ast = comp->expr; } - } - if (!key_type || !value_type) - code_err(ast, "I couldn't figure out the key and value types for this table!"); + auto e = Match(entry_ast, TableEntry); + type_t *key_t = get_type(scope, e->key); + type_t *value_t = get_type(scope, e->value); + + type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t; + if (!key_merged) + code_err(entry->ast, + "This table entry has type ", type_to_str(key_t), + ", which is different from earlier table entries which have type ", type_to_str(key_type)); + key_type = key_merged; - if (has_stack_memory(key_type) || has_stack_memory(value_type)) + type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t; + if (!val_merged) + code_err(entry->ast, + "This table entry has type ", type_to_str(value_t), + ", which is different from earlier table entries which have type ", type_to_str(value_type)); + value_type = val_merged; + } + + if ((key_type && has_stack_memory(key_type)) || (value_type && has_stack_memory(value_type))) code_err(ast, "Tables cannot hold stack references because the table may outlive the reference's stack frame."); + return Type(TableType, .key_type=key_type, .value_type=value_type, .default_value=table->default_value, .env=env); } case TableEntry: { @@ -998,7 +969,7 @@ type_t *get_type(env_t *env, ast_t *ast) // Early out if the type is knowable without any context from the block: switch (last->ast->tag) { - case UpdateAssign: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: case Extend: + case UPDATE_CASES: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: case Extend: return Type(VoidType); default: break; } @@ -1022,7 +993,7 @@ type_t *get_type(env_t *env, ast_t *ast) case Extern: { return parse_type_ast(env, Match(ast, Extern)->type); } - case Declare: case Assign: case DocTest: { + case Declare: case Assign: case UPDATE_CASES: case DocTest: { return Type(VoidType); } case Use: { @@ -1078,169 +1049,160 @@ type_t *get_type(env_t *env, ast_t *ast) } code_err(ast, "I only know how to get 'not' of boolean, numeric, and optional pointer types, not ", type_to_str(t)); } - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - type_t *lhs_t = get_type(env, binop->lhs), - *rhs_t = get_type(env, binop->rhs); - - if (lhs_t->tag == BigIntType && rhs_t->tag != BigIntType && is_numeric_type(rhs_t) && binop->lhs->tag == Int) { - lhs_t = rhs_t; - } else if (rhs_t->tag == BigIntType && lhs_t->tag != BigIntType && is_numeric_type(lhs_t) && binop->rhs->tag == Int) { + case Or: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); - rhs_t = lhs_t; + // `opt? or (x == y)` / `(x == y) or opt?` is a boolean conditional: + if ((lhs_t->tag == OptionalType && rhs_t->tag == BoolType) + || (lhs_t->tag == BoolType && rhs_t->tag == OptionalType)) { + return Type(BoolType); } -#define binding_works(name, self, lhs_t, rhs_t, ret_t) \ - ({ binding_t *b = get_namespace_binding(env, self, name); \ - (b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \ - (type_eq(fn->ret, ret_t) \ - && (fn->args && type_eq(fn->args->type, lhs_t)) \ - && (fn->args->next && can_promote(rhs_t, fn->args->next->type))); })); }) - // Check for a binop method like plus() etc: - switch (binop->op) { - case BINOP_MULT: { - if (is_numeric_type(lhs_t) && binding_works("scaled_by", binop->rhs, rhs_t, lhs_t, rhs_t)) - return rhs_t; - else if (is_numeric_type(rhs_t) && binding_works("scaled_by", binop->lhs, lhs_t, rhs_t, lhs_t)) - return lhs_t; - else if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + if (lhs_t->tag == OptionalType) { + if (rhs_t->tag == OptionalType) { + type_t *result = most_complete_type(lhs_t, rhs_t); + if (result == NULL) + code_err(ast, "I could not determine the type of ", type_to_str(lhs_t), " `or` ", type_to_str(rhs_t)); + return result; + } else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { + return Match(lhs_t, OptionalType)->type; + } + type_t *non_opt = Match(lhs_t, OptionalType)->type; + non_opt = most_complete_type(non_opt, rhs_t); + if (non_opt != NULL) + return non_opt; + } else if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) + && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) + && lhs_t->tag != NumType && rhs_t->tag != NumType) { + if (can_promote(rhs_t, lhs_t)) return lhs_t; - break; + else if (can_promote(lhs_t, rhs_t)) + return rhs_t; + } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) { + return lhs_t; } - case BINOP_PLUS: case BINOP_MINUS: case BINOP_AND: case BINOP_OR: case BINOP_XOR: case BINOP_CONCAT: { - if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) - return lhs_t; - break; + code_err(ast, "I couldn't figure out how to do `or` between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case And: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + + // `and` between optionals/bools is a boolean expression like `if opt? and opt?:` or `if x > 0 and opt?:` + if ((lhs_t->tag == OptionalType || lhs_t->tag == BoolType) + && (rhs_t->tag == OptionalType || rhs_t->tag == BoolType)) { + return Type(BoolType); } - case BINOP_DIVIDE: case BINOP_MOD: case BINOP_MOD1: { - if (is_numeric_type(rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) + + // Bitwise AND: + if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) + && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) + && lhs_t->tag != NumType && rhs_t->tag != NumType) { + if (can_promote(rhs_t, lhs_t)) return lhs_t; - break; - } - case BINOP_LSHIFT: case BINOP_RSHIFT: case BINOP_ULSHIFT: case BINOP_URSHIFT: { + else if (can_promote(lhs_t, rhs_t)) + return rhs_t; + } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) { return lhs_t; } - case BINOP_POWER: { - if (is_numeric_type(rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) - return lhs_t; - break; - } - default: break; - } -#undef binding_works + code_err(ast, "I couldn't figure out how to do `and` between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Xor: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); - switch (binop->op) { - case BINOP_AND: { - if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) { - return lhs_t; - } else if ((lhs_t->tag == BoolType && rhs_t->tag == OptionalType) || - (lhs_t->tag == OptionalType && rhs_t->tag == BoolType)) { - return Type(BoolType); - } else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) { - return lhs_t; - } else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { - return lhs_t; - } else if (rhs_t->tag == OptionalType) { - if (can_promote(lhs_t, rhs_t)) - return rhs_t; - } else if (lhs_t->tag == PointerType && rhs_t->tag == PointerType) { - auto lhs_ptr = Match(lhs_t, PointerType); - auto rhs_ptr = Match(rhs_t, PointerType); - if (type_eq(lhs_ptr->pointed, rhs_ptr->pointed)) - return Type(PointerType, .pointed=lhs_ptr->pointed); - } else if ((is_int_type(lhs_t) && is_int_type(rhs_t)) - || (lhs_t->tag == ByteType && rhs_t->tag == ByteType)) { - return get_math_type(env, ast, lhs_t, rhs_t); - } - code_err(ast, "I can't figure out the type of this `and` expression between a ", type_to_str(lhs_t), " and a ", type_to_str(rhs_t)); + // `xor` between optionals/bools is a boolean expression like `if opt? xor opt?:` or `if x > 0 xor opt?:` + if ((lhs_t->tag == OptionalType || lhs_t->tag == BoolType) + && (rhs_t->tag == OptionalType || rhs_t->tag == BoolType)) { + return Type(BoolType); } - case BINOP_OR: { - if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) { - return lhs_t; - } else if ((lhs_t->tag == BoolType && rhs_t->tag == OptionalType) || - (lhs_t->tag == OptionalType && rhs_t->tag == BoolType)) { - return Type(BoolType); - } else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) { + + // Bitwise XOR: + if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) + && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) + && lhs_t->tag != NumType && rhs_t->tag != NumType) { + if (can_promote(rhs_t, lhs_t)) return lhs_t; - } else if ((is_int_type(lhs_t) && is_int_type(rhs_t)) - || (lhs_t->tag == ByteType && rhs_t->tag == ByteType)) { - return get_math_type(env, ast, lhs_t, rhs_t); - } else if (lhs_t->tag == OptionalType) { - if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) - return Match(lhs_t, OptionalType)->type; - if (can_promote(rhs_t, lhs_t)) - return rhs_t; - } else if (lhs_t->tag == PointerType) { - auto lhs_ptr = Match(lhs_t, PointerType); - if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { - return Type(PointerType, .pointed=lhs_ptr->pointed); - } else if (rhs_t->tag == PointerType) { - auto rhs_ptr = Match(rhs_t, PointerType); - if (type_eq(rhs_ptr->pointed, lhs_ptr->pointed)) - return Type(PointerType, .pointed=lhs_ptr->pointed); - } - } else if (rhs_t->tag == OptionalType) { - return type_or_type(lhs_t, rhs_t); - } - code_err(ast, "I can't figure out the type of this `or` expression between a ", type_to_str(lhs_t), " and a ", type_to_str(rhs_t)); + else if (can_promote(lhs_t, rhs_t)) + return rhs_t; + } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) { + return lhs_t; } - case BINOP_XOR: { - if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) { - return lhs_t; - } else if ((lhs_t->tag == BoolType && rhs_t->tag == OptionalType) || - (lhs_t->tag == OptionalType && rhs_t->tag == BoolType)) { - return Type(BoolType); - } else if ((is_int_type(lhs_t) && is_int_type(rhs_t)) - || (lhs_t->tag == ByteType && rhs_t->tag == ByteType)) { - return get_math_type(env, ast, lhs_t, rhs_t); - } + code_err(ast, "I couldn't figure out how to do `xor` between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Compare: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); - code_err(ast, "I can't figure out the type of this `xor` expression between a ", type_to_str(lhs_t), " and a ", type_to_str(rhs_t)); - } - case BINOP_CONCAT: { - if (!type_eq(lhs_t, rhs_t)) - code_err(ast, "The type on the left side of this concatenation doesn't match the right side: ", type_to_str(lhs_t), - " vs. ", type_to_str(rhs_t)); - if (lhs_t->tag == ArrayType || lhs_t->tag == TextType || lhs_t->tag == SetType) - return lhs_t; + if (can_promote(rhs_t, lhs_t) || can_promote(lhs_t, rhs_t)) + return Type(IntType, .bits=TYPE_IBITS32); - code_err(ast, "Only array/set/text value types support concatenation, not ", type_to_str(lhs_t)); - } - case BINOP_EQ: case BINOP_NE: case BINOP_LT: case BINOP_LE: case BINOP_GT: case BINOP_GE: { - if (!can_promote(lhs_t, rhs_t) && !can_promote(rhs_t, lhs_t)) - code_err(ast, "I can't compare these two different types: ", type_to_str(lhs_t), " vs ", type_to_str(rhs_t)); + code_err(ast, "I don't know how to compare ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Equals: case NotEquals: case LessThan: case LessThanOrEquals: case GreaterThan: case GreaterThanOrEquals: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + if (can_promote(rhs_t, lhs_t) || can_promote(lhs_t, rhs_t)) return Type(BoolType); + + code_err(ast, "I don't know how to compare ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Power: case Multiply: case Divide: case Mod: case Mod1: case Plus: case Minus: case LeftShift: + case UnsignedLeftShift: case RightShift: case UnsignedRightShift: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + + if (ast->tag == LeftShift || ast->tag == UnsignedLeftShift || ast->tag == RightShift || ast->tag == UnsignedRightShift) { + if (!is_int_type(rhs_t)) + code_err(binop.rhs, "I only know how to do bit shifting by integer amounts, not ", type_to_str(rhs_t)); } - case BINOP_CMP: - return Type(IntType, .bits=TYPE_IBITS32); - case BINOP_POWER: { - type_t *result = get_math_type(env, ast, lhs_t, rhs_t); - if (result->tag == NumType) - return result; - return Type(NumType, .bits=TYPE_NBITS64); - } - case BINOP_MULT: case BINOP_DIVIDE: { - type_t *math_type = get_math_type(env, ast, value_type(lhs_t), value_type(rhs_t)); - if (value_type(lhs_t)->tag == NumType || value_type(rhs_t)->tag == NumType) { - if (risks_zero_or_inf(binop->lhs) && risks_zero_or_inf(binop->rhs)) - return Type(OptionalType, math_type); - else - return math_type; - } - return math_type; - } - default: { - return get_math_type(env, ast, lhs_t, rhs_t); - } + + type_t *overall_t = (can_promote(rhs_t, lhs_t) ? lhs_t : (can_promote(lhs_t, rhs_t) ? rhs_t : NULL)); + if (ast->tag == Multiply || ast->tag == Divide) { + binding_t *b = is_numeric_type(lhs_t) ? get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, lhs_t) + : get_metamethod_binding(env, ast->tag, binop.rhs, binop.lhs, rhs_t); + if (b) return overall_t; + } else { + if (overall_t == NULL) + code_err(ast, "I don't know how to do math operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + + binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, overall_t); + if (b) return overall_t; } + if (is_numeric_type(lhs_t) && is_numeric_type(rhs_t)) + return overall_t; + code_err(ast, "I don't know how to do math operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Concat: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + + type_t *overall_t = (can_promote(rhs_t, lhs_t) ? lhs_t : (can_promote(lhs_t, rhs_t) ? rhs_t : NULL)); + if (overall_t == NULL) + code_err(ast, "I don't know how to do operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + + binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, overall_t); + if (b) return overall_t; + + if (overall_t->tag == ArrayType || overall_t->tag == SetType || overall_t->tag == TextType) + return overall_t; + + code_err(ast, "I don't know how to do concatenation between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); } case Reduction: { auto reduction = Match(ast, Reduction); type_t *iter_t = get_type(env, reduction->iter); - if (reduction->op == BINOP_EQ || reduction->op == BINOP_NE || reduction->op == BINOP_LT - || reduction->op == BINOP_LE || reduction->op == BINOP_GT || reduction->op == BINOP_GE) + if (reduction->op == Equals || reduction->op == NotEquals || reduction->op == LessThan + || reduction->op == LessThanOrEquals || reduction->op == GreaterThan || reduction->op == GreaterThanOrEquals) return Type(OptionalType, .type=Type(BoolType)); type_t *iterated = get_iterated_type(iter_t); @@ -1249,9 +1211,6 @@ type_t *get_type(env_t *env, ast_t *ast) return iterated->tag == OptionalType ? iterated : Type(OptionalType, .type=iterated); } - case UpdateAssign: - return Type(VoidType); - case Min: case Max: { // Unsafe! These types *should* have the same fields and this saves a lot of duplicate code: ast_t *lhs = ast->__data.Min.lhs, *rhs = ast->__data.Min.rhs; @@ -1310,8 +1269,9 @@ type_t *get_type(env_t *env, ast_t *ast) env_t *truthy_scope = env; env_t *falsey_scope = env; if (if_->condition->tag == Declare) { - type_t *condition_type = get_type(env, Match(if_->condition, Declare)->value); - const char *varname = Match(Match(if_->condition, Declare)->var, Var)->name; + auto decl = Match(if_->condition, Declare); + type_t *condition_type = decl->type ? parse_type_ast(env, decl->type) : get_type(env, decl->value); + const char *varname = Match(decl->var, Var)->name; if (streq(varname, "_")) code_err(if_->condition, "To use `if var := ...:`, you must choose a real variable name, not `_`"); @@ -1456,7 +1416,7 @@ type_t *get_type(env_t *env, ast_t *ast) PUREFUNC bool is_discardable(env_t *env, ast_t *ast) { switch (ast->tag) { - case UpdateAssign: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: + case UPDATE_CASES: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: case Use: case Extend: return true; default: break; @@ -1610,13 +1570,13 @@ PUREFUNC bool is_constant(env_t *env, ast_t *ast) } case Not: return is_constant(env, Match(ast, Not)->value); case Negative: return is_constant(env, Match(ast, Negative)->value); - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - switch (binop->op) { - case BINOP_UNKNOWN: case BINOP_POWER: case BINOP_CONCAT: case BINOP_MIN: case BINOP_MAX: case BINOP_CMP: + case BINOP_CASES: { + binary_operands_t binop = BINARY_OPERANDS(ast); + switch (ast->tag) { + case Power: case Concat: case Min: case Max: case Compare: return false; default: - return is_constant(env, binop->lhs) && is_constant(env, binop->rhs); + return is_constant(env, binop.lhs) && is_constant(env, binop.rhs); } } case Use: return true; @@ -1626,4 +1586,49 @@ PUREFUNC bool is_constant(env_t *env, ast_t *ast) } } +PUREFUNC bool can_compile_to_type(env_t *env, ast_t *ast, type_t *needed) +{ + if (needed->tag == OptionalType && ast->tag == None) { + return true; + } + + needed = non_optional(needed); + if (needed->tag == ArrayType && ast->tag == Array) { + type_t *item_type = Match(needed, ArrayType)->item_type; + for (ast_list_t *item = Match(ast, Array)->items; item; item = item->next) { + if (!can_compile_to_type(env, item->ast, item_type)) + return false; + } + return true; + } else if (needed->tag == SetType && ast->tag == Set) { + type_t *item_type = Match(needed, SetType)->item_type; + for (ast_list_t *item = Match(ast, Set)->items; item; item = item->next) { + if (!can_compile_to_type(env, item->ast, item_type)) + return false; + } + return true; + } else if (needed->tag == TableType && ast->tag == Table) { + type_t *key_type = Match(needed, TableType)->key_type; + type_t *value_type = Match(needed, TableType)->value_type; + for (ast_list_t *entry = Match(ast, Table)->entries; entry; entry = entry->next) { + if (entry->ast->tag != TableEntry) + continue; // TODO: fix this + auto e = Match(entry->ast, TableEntry); + if (!can_compile_to_type(env, e->key, key_type) || !can_compile_to_type(env, e->value, value_type)) + return false; + } + return true; + } else if (needed->tag == PointerType) { + auto ptr = Match(needed, PointerType); + if (ast->tag == HeapAllocate) + return !ptr->is_stack && can_compile_to_type(env, Match(ast, HeapAllocate)->value, ptr->pointed); + else if (ast->tag == StackReference) + return ptr->is_stack && can_compile_to_type(env, Match(ast, StackReference)->value, ptr->pointed); + else + return can_promote(needed, get_type(env, ast)); + } else { + return can_promote(needed, get_type(env, ast)); + } +} + // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/src/typecheck.h b/src/typecheck.h index cc5cb18c..4342acc2 100644 --- a/src/typecheck.h +++ b/src/typecheck.h @@ -29,5 +29,6 @@ type_t *get_method_type(env_t *env, ast_t *self, const char *name); PUREFUNC bool is_constant(env_t *env, ast_t *ast); Table_t *get_arg_bindings(env_t *env, arg_t *spec_args, arg_ast_t *call_args, bool promotion_allowed); bool is_valid_call(env_t *env, arg_t *spec_args, arg_ast_t *call_args, bool promotion_allowed); +PUREFUNC bool can_compile_to_type(env_t *env, ast_t *ast, type_t *needed); // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/src/types.c b/src/types.c index d8ff377b..0b9bc72b 100644 --- a/src/types.c +++ b/src/types.c @@ -360,6 +360,16 @@ PUREFUNC bool can_promote(type_t *actual, type_t *needed) return true; } + // Empty literals: + if (actual->tag == ArrayType && needed->tag == ArrayType && Match(actual, ArrayType)->item_type == NULL) + return true; // [] -> [T] + if (actual->tag == SetType && needed->tag == SetType && Match(actual, SetType)->item_type == NULL) + return true; // {/} -> {T} + if (actual->tag == TableType && needed->tag == SetType && Match(actual, TableType)->key_type == NULL && Match(actual, TableType)->value_type == NULL) + return true; // {} -> {T} + if (actual->tag == TableType && needed->tag == TableType && Match(actual, TableType)->key_type == NULL && Match(actual, TableType)->value_type == NULL) + return true; // {} -> {K=V} + // Cross-promotion between tables with default values and without if (needed->tag == TableType && actual->tag == TableType) { auto actual_table = Match(actual, TableType); @@ -708,4 +718,103 @@ PUREFUNC type_t *get_iterated_type(type_t *t) } } +CONSTFUNC bool is_incomplete_type(type_t *t) +{ + if (t == NULL) return true; + switch (t->tag) { + case ReturnType: return is_incomplete_type(Match(t, ReturnType)->ret); + case OptionalType: return is_incomplete_type(Match(t, OptionalType)->type); + case ArrayType: return is_incomplete_type(Match(t, ArrayType)->item_type); + case SetType: return is_incomplete_type(Match(t, SetType)->item_type); + case TableType: { + auto table = Match(t, TableType); + return is_incomplete_type(table->key_type) || is_incomplete_type(table->value_type); + } + case FunctionType: { + auto fn = Match(t, FunctionType); + for (arg_t *arg = fn->args; arg; arg = arg->next) { + if (arg->type == NULL || is_incomplete_type(arg->type)) + return true; + } + return fn->ret ? is_incomplete_type(fn->ret) : false; + } + case ClosureType: return is_incomplete_type(Match(t, ClosureType)->fn); + case PointerType: return is_incomplete_type(Match(t, PointerType)->pointed); + default: return false; + } +} + +CONSTFUNC type_t *most_complete_type(type_t *t1, type_t *t2) +{ + if (!t1) return t2; + if (!t2) return t1; + + if (is_incomplete_type(t1) && is_incomplete_type(t2)) + return NULL; + else if (!is_incomplete_type(t1) && !is_incomplete_type(t2) && type_eq(t1, t2)) + return t1; + + if (t1->tag != t2->tag) + return NULL; + + switch (t1->tag) { + case ReturnType: { + type_t *ret = most_complete_type(Match(t1, ReturnType)->ret, Match(t1, ReturnType)->ret); + return ret ? Type(ReturnType, ret) : NULL; + } + case OptionalType: { + type_t *opt = most_complete_type(Match(t1, OptionalType)->type, Match(t2, OptionalType)->type); + return opt ? Type(OptionalType, opt) : NULL; + } + case ArrayType: { + type_t *item = most_complete_type(Match(t1, ArrayType)->item_type, Match(t2, ArrayType)->item_type); + return item ? Type(ArrayType, item) : NULL; + } + case SetType: { + type_t *item = most_complete_type(Match(t1, SetType)->item_type, Match(t2, SetType)->item_type); + return item ? Type(SetType, item) : NULL; + } + case TableType: { + auto table1 = Match(t1, TableType); + auto table2 = Match(t2, TableType); + type_t *key = most_complete_type(table1->key_type, table2->key_type); + type_t *value = most_complete_type(table1->value_type, table2->value_type); + return (key && value) ? Type(TableType, key, value) : NULL; + } + case FunctionType: { + auto fn1 = Match(t1, FunctionType); + auto fn2 = Match(t2, FunctionType); + arg_t *args = NULL; + for (arg_t *arg1 = fn1->args, *arg2 = fn2->args; arg1 || arg2; arg1 = arg1->next, arg2 = arg2->next) { + if (!arg1 || !arg2) + return NULL; + + type_t *arg_type = most_complete_type(arg1->type, arg2->type); + if (!arg_type) return NULL; + args = new(arg_t, .type=arg_type, .next=args); + } + REVERSE_LIST(args); + type_t *ret = most_complete_type(fn1->ret, fn2->ret); + return ret ? Type(FunctionType, .args=args, .ret=ret) : NULL; + } + case ClosureType: { + type_t *fn = most_complete_type(Match(t1, ClosureType)->fn, Match(t1, ClosureType)->fn); + return fn ? Type(ClosureType, fn) : NULL; + } + case PointerType: { + auto ptr1 = Match(t1, PointerType); + auto ptr2 = Match(t2, PointerType); + if (ptr1->is_stack != ptr2->is_stack) + return NULL; + type_t *pointed = most_complete_type(ptr1->pointed, ptr2->pointed); + return pointed ? Type(PointerType, .is_stack=ptr1->is_stack, .pointed=pointed) : NULL; + } + default: { + if (is_incomplete_type(t1) || is_incomplete_type(t2)) + return NULL; + return type_eq(t1, t2) ? t1 : NULL; + } + } +} + // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/src/types.h b/src/types.h index a5b2ad04..53488584 100644 --- a/src/types.h +++ b/src/types.h @@ -147,6 +147,8 @@ PUREFUNC const char *enum_single_value_tag(type_t *enum_type, type_t *t); PUREFUNC bool is_int_type(type_t *t); PUREFUNC bool is_numeric_type(type_t *t); PUREFUNC bool is_packed_data(type_t *t); +CONSTFUNC bool is_incomplete_type(type_t *t); +CONSTFUNC type_t *most_complete_type(type_t *t1, type_t *t2); PUREFUNC size_t type_size(type_t *t); PUREFUNC size_t type_align(type_t *t); PUREFUNC size_t unpadded_struct_size(type_t *t); -- cgit v1.2.3