diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2024-11-25 14:57:58 -0500 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2024-11-25 14:57:58 -0500 |
| commit | 3e23ea6a67e4bbfea4dc667beb6285a90dff2877 (patch) | |
| tree | 15d537f0da478e5281d164af9a4696fa37d5e5f1 | |
| parent | 369c83e8c517116de560643c2c9bdcedfb3850dc (diff) | |
Improve handling of update assignments
| -rw-r--r-- | compile.c | 26 | ||||
| -rw-r--r-- | types.c | 2 | ||||
| -rw-r--r-- | types.h | 1 |
3 files changed, 17 insertions, 12 deletions
@@ -684,16 +684,17 @@ CORD compile_statement(env_t *env, ast_t *ast) type_t *lhs_t = get_type(env, update->lhs); type_t *rhs_t = get_type(env, update->rhs); - if (!promote(env, update->rhs, &rhs, rhs_t, lhs_t)) { - if (update->rhs->tag == Int && (lhs_t->tag == IntType || lhs_t->tag == ByteType)) - rhs = compile_int_to_type(env, update->rhs, lhs_t); - else if (!(lhs_t->tag == ArrayType && promote(env, update->rhs, &rhs, rhs_t, Match(lhs_t, ArrayType)->item_type))) + if (update->rhs->tag == Int && is_numeric_type(non_optional(lhs_t))) { + rhs = compile_int_to_type(env, update->rhs, lhs_t); + } else if (!promote(env, update->rhs, &rhs, rhs_t, lhs_t)) { + if (!(lhs_t->tag == ArrayType && promote(env, update->rhs, &rhs, rhs_t, Match(lhs_t, ArrayType)->item_type))) code_err(ast, "I can't do operations between %T and %T", lhs_t, rhs_t); } + bool lhs_is_optional_num = (lhs_t->tag == OptionalType && Match(lhs_t, OptionalType)->type && Match(lhs_t, OptionalType)->type->tag == NumType); switch (update->op) { case BINOP_MULT: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType) + if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) code_err(ast, "I can't do a multiply assignment with this operator between %T and %T", lhs_t, rhs_t); if (lhs_t->tag == NumType) { // 0*INF -> NaN, needs checking return CORD_asprintf("%r *= %r;\n" @@ -706,7 +707,7 @@ CORD compile_statement(env_t *env, ast_t *ast) } return CORD_all(lhs, " *= ", rhs, ";"); case BINOP_DIVIDE: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType) + if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) code_err(ast, "I can't do a divide assignment with this operator between %T and %T", lhs_t, rhs_t); if (lhs_t->tag == NumType) { // 0/0 or INF/INF -> NaN, needs checking return CORD_asprintf("%r /= %r;\n" @@ -719,23 +720,23 @@ CORD compile_statement(env_t *env, ast_t *ast) } return CORD_all(lhs, " /= ", rhs, ";"); case BINOP_MOD: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType) + if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) code_err(ast, "I can't do a mod assignment with this operator between %T and %T", lhs_t, rhs_t); return CORD_all(lhs, " = ", lhs, " % ", rhs); case BINOP_MOD1: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType) + if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) code_err(ast, "I can't do a mod assignment with this operator between %T and %T", lhs_t, rhs_t); return CORD_all(lhs, " = (((", lhs, ") - 1) % ", rhs, ") + 1;"); case BINOP_PLUS: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType) + if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) code_err(ast, "I can't do an addition assignment with this operator between %T and %T", lhs_t, rhs_t); return CORD_all(lhs, " += ", rhs, ";"); case BINOP_MINUS: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType) + if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) code_err(ast, "I can't do a subtraction assignment with this operator between %T and %T", lhs_t, rhs_t); return CORD_all(lhs, " -= ", rhs, ";"); case BINOP_POWER: { - if (lhs_t->tag != NumType) + if (lhs_t->tag != NumType && !lhs_is_optional_num) code_err(ast, "'^=' is only supported for Num types"); if (lhs_t->tag == NumType && Match(lhs_t, NumType)->bits == TYPE_NBITS32) return CORD_all(lhs, " = powf(", lhs, ", ", rhs, ");"); @@ -1612,6 +1613,9 @@ CORD compile_int_to_type(env_t *env, ast_t *ast, type_t *target) if (target->tag == BigIntType) return compile(env, ast); + if (target->tag == OptionalType && Match(target, OptionalType)->type) + return compile_int_to_type(env, ast, Match(target, OptionalType)->type); + const char *literal = Match(ast, Int)->str; OptionalInt_t int_val = Int$from_str(literal); if (int_val.small == 0) @@ -135,7 +135,7 @@ bool type_is_a(type_t *t, type_t *req) return false; } -static type_t *non_optional(type_t *t) +type_t *non_optional(type_t *t) { return t->tag == OptionalType ? Match(t, OptionalType)->type : t; } @@ -152,6 +152,7 @@ PUREFUNC bool is_numeric_type(type_t *t); PUREFUNC size_t type_size(type_t *t); PUREFUNC size_t type_align(type_t *t); PUREFUNC size_t unpadded_struct_size(type_t *t); +PUREFUNC type_t *non_optional(type_t *t); type_t *get_field_type(type_t *t, const char *field_name); PUREFUNC type_t *get_iterated_type(type_t *t); |
