aboutsummaryrefslogtreecommitdiff
path: root/typecheck.c
diff options
context:
space:
mode:
Diffstat (limited to 'typecheck.c')
-rw-r--r--typecheck.c32
1 files changed, 32 insertions, 0 deletions
diff --git a/typecheck.c b/typecheck.c
index 2d84c622..da810258 100644
--- a/typecheck.c
+++ b/typecheck.c
@@ -126,6 +126,28 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast)
errx(1, "Unreachable");
}
+static PUREFUNC bool risks_zero_or_inf(ast_t *ast)
+{
+ switch (ast->tag) {
+ case Int: {
+ const char *str = Match(ast, Int)->str;
+ OptionalInt_t int_val = Int$from_str(str);
+ return (int_val.small == 0x1); // zero
+ }
+ case Num: {
+ return Match(ast, Num)->n == 0.0;
+ }
+ case BinaryOp: {
+ auto binop = Match(ast, BinaryOp);
+ if (binop->op == BINOP_MULT || binop->op == BINOP_DIVIDE || binop->op == BINOP_MIN || binop->op == BINOP_MAX)
+ return risks_zero_or_inf(binop->lhs) || risks_zero_or_inf(binop->rhs);
+ else
+ return true;
+ }
+ default: return true;
+ }
+}
+
PUREFUNC type_t *get_math_type(env_t *env, ast_t *ast, type_t *lhs_t, type_t *rhs_t)
{
(void)env;
@@ -1011,6 +1033,16 @@ type_t *get_type(env_t *env, ast_t *ast)
return result;
return Type(NumType, .bits=TYPE_NBITS64);
}
+ case BINOP_MULT: case BINOP_DIVIDE: {
+ type_t *math_type = get_math_type(env, ast, value_type(lhs_t), value_type(rhs_t));
+ if (value_type(lhs_t)->tag == NumType || value_type(rhs_t)->tag == NumType) {
+ if (risks_zero_or_inf(binop->lhs) && risks_zero_or_inf(binop->rhs))
+ return Type(OptionalType, math_type);
+ else
+ return math_type;
+ }
+ return math_type;
+ }
default: {
return get_math_type(env, ast, lhs_t, rhs_t);
}