diff options
| author | Bruce Hill <bruce@bruce-hill.com> | 2025-08-17 14:24:18 -0400 |
|---|---|---|
| committer | Bruce Hill <bruce@bruce-hill.com> | 2025-08-17 14:24:18 -0400 |
| commit | 724c0fcf9e227babe540dea175ce20438a466b88 (patch) | |
| tree | 41cf79a2563853688e64ec5d2d698ab8c52a8f6a /src | |
| parent | cdc6037af740566d5329cef9d303f06e81682780 (diff) | |
Major improvements to type inference to support `JSON({"key"=yes,
"key2"=[1,2,{"ok"=JSON.Null}]})` and similar fancy type inference stuff.
Diffstat (limited to 'src')
| -rw-r--r-- | src/compile.c | 23 | ||||
| -rw-r--r-- | src/typecheck.c | 118 |
2 files changed, 105 insertions, 36 deletions
diff --git a/src/compile.c b/src/compile.c index eb250d74..2b714acf 100644 --- a/src/compile.c +++ b/src/compile.c @@ -2123,6 +2123,7 @@ Text_t compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_depth, bo Text_t compile_to_type(env_t *env, ast_t *ast, type_t *t) { + assert(!is_incomplete_type(t)); if (ast->tag == Int && is_numeric_type(non_optional(t))) { return compile_int_to_type(env, ast, t); } else if (ast->tag == Num && t->tag == NumType) { @@ -2180,10 +2181,20 @@ Text_t compile_to_type(env_t *env, ast_t *ast, type_t *t) if (t->tag == PointerType && Match(t, PointerType)->is_stack && actual->tag != PointerType) return Texts("stack(", compile_to_type(env, ast, Match(t, PointerType)->pointed), ")"); - Text_t code = compile(env, ast); - if (!promote(env, ast, &code, actual, t)) - code_err(ast, "I expected a ", type_to_str(t), " here, but this is a ", type_to_str(actual)); - return code; + if (!is_incomplete_type(actual)) { + Text_t code = compile(env, ast); + if (promote(env, ast, &code, actual, t)) + return code; + } + + arg_ast_t *constructor_args = new(arg_ast_t, .value=ast); + binding_t *constructor = get_constructor(env, t, constructor_args, true); + if (constructor) { + arg_t *arg_spec = Match(constructor->type, FunctionType)->args; + return Texts(constructor->code, "(", compile_arguments(env, ast, arg_spec, constructor_args), ")"); + } + + code_err(ast, "I expected a ", type_to_str(t), " here, but this is a ", type_to_str(actual)); } Text_t compile_typed_list(env_t *env, ast_t *ast, type_t *list_type) @@ -2771,7 +2782,7 @@ Text_t compile(env_t *env, ast_t *ast) binding_t *b = get_namespace_binding(env, value, "negated"); if (b && b->type->tag == FunctionType) { DeclareMatch(fn, b->type, FunctionType); - if (fn->args && can_promote(t, get_arg_type(env, fn->args))) + if (fn->args && can_compile_to_type(env, value, get_arg_type(env, fn->args))) return Texts(b->code, "(", compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=value)), ")"); } @@ -2796,7 +2807,7 @@ Text_t compile(env_t *env, ast_t *ast) binding_t *b = get_namespace_binding(env, value, "negative"); if (b && b->type->tag == FunctionType) { DeclareMatch(fn, b->type, FunctionType); - if (fn->args && can_promote(t, get_arg_type(env, fn->args))) + if (fn->args && can_compile_to_type(env, value, get_arg_type(env, fn->args))) return Texts(b->code, "(", compile_arguments(env, ast, fn->args, new(arg_ast_t, .value=value)), ")"); } diff --git a/src/typecheck.c b/src/typecheck.c index 0ed2328f..d151c511 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -439,7 +439,10 @@ void bind_statement(env_t *env, ast_t *statement) for (tag_t *tag = tags; tag; tag = tag->next) { if (Match(tag->type, StructType)->fields) { // Constructor: type_t *constructor_t = Type(FunctionType, .args=Match(tag->type, StructType)->fields, .ret=type); - set_binding(ns_env, tag->name, constructor_t, namespace_name(env, env->namespace, Texts(def->name, "$tagged$", tag->name))); + Text_t tagged_name = namespace_name(env, env->namespace, Texts(def->name, "$tagged$", tag->name)); + set_binding(ns_env, tag->name, constructor_t, tagged_name); + binding_t binding = {.type=constructor_t, .code=tagged_name}; + List$insert(&ns_env->namespace->constructors, &binding, I(1), sizeof(binding)); } else if (has_any_tags_with_fields) { // Empty singleton value: Text_t code = Texts("((", namespace_name(env, env->namespace, Texts(def->name, "$$type")), "){", namespace_name(env, env->namespace, Texts(def->name, "$tag$", tag->name)), "})"); @@ -733,9 +736,7 @@ type_t *get_type(env_t *env, ast_t *ast) type_t *t2 = get_type(scope, item_ast); type_t *merged = item_type ? type_or_type(item_type, t2) : t2; if (!merged) - code_err(item->ast, - "This list item has type ", type_to_str(t2), - ", which is different from earlier list items which have type ", type_to_str(item_type)); + return Type(ListType, .item_type=NULL); item_type = merged; } @@ -760,9 +761,7 @@ type_t *get_type(env_t *env, ast_t *ast) type_t *this_item_type = get_type(scope, item_ast); type_t *item_merged = type_or_type(item_type, this_item_type); if (!item_merged) - code_err(item_ast, - "This set item has type ", type_to_str(this_item_type), - ", which is different from earlier set items which have type ", type_to_str(item_type)); + return Type(SetType, .item_type=NULL); item_type = item_merged; } @@ -774,6 +773,7 @@ type_t *get_type(env_t *env, ast_t *ast) case Table: { DeclareMatch(table, ast, Table); type_t *key_type = NULL, *value_type = NULL; + bool ambiguous_key_type = false, ambiguous_value_type = false; for (ast_list_t *entry = table->entries; entry; entry = entry->next) { ast_t *entry_ast = entry->ast; env_t *scope = env; @@ -790,19 +790,21 @@ type_t *get_type(env_t *env, ast_t *ast) type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t; if (!key_merged) - code_err(entry->ast, - "This table entry has type ", type_to_str(key_t), - ", which is different from earlier table entries which have type ", type_to_str(key_type)); + ambiguous_key_type = true; key_type = key_merged; type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t; if (!val_merged) - code_err(entry->ast, - "This table entry has type ", type_to_str(value_t), - ", which is different from earlier table entries which have type ", type_to_str(value_type)); + ambiguous_value_type = true; value_type = val_merged; } + if (ambiguous_key_type) + key_type = NULL; + + if (ambiguous_value_type) + value_type = NULL; + if ((key_type && has_stack_memory(key_type)) || (value_type && has_stack_memory(value_type))) code_err(ast, "Tables cannot hold stack references because the table may outlive the reference's stack frame."); @@ -898,7 +900,11 @@ type_t *get_type(env_t *env, ast_t *ast) else if (t->tag == StructType || t->tag == IntType || t->tag == BigIntType || t->tag == NumType || t->tag == ByteType || t->tag == TextType || t->tag == CStringType) return t; // Constructor - code_err(call->fn, "This is not a type that has a constructor"); + arg_t *arg_types = NULL; + for (arg_ast_t *arg = call->args; arg; arg = arg->next) + arg_types = new(arg_t, .type=get_type(env, arg->value), .name=arg->name, .next=arg_types); + REVERSE_LIST(arg_types); + code_err(call->fn, "I couldn't find a type constructor for ", type_to_text(Type(FunctionType, .args=arg_types, .ret=t))); } if (fn_type_t->tag == ClosureType) fn_type_t = Match(fn_type_t, ClosureType)->fn; @@ -1131,9 +1137,9 @@ type_t *get_type(env_t *env, ast_t *ast) } else if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) && lhs_t->tag != NumType && rhs_t->tag != NumType) { - if (can_promote(rhs_t, lhs_t)) + if (can_compile_to_type(env, binop.rhs, lhs_t)) return lhs_t; - else if (can_promote(lhs_t, rhs_t)) + else if (can_compile_to_type(env, binop.lhs, rhs_t)) return rhs_t; } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) { return lhs_t; @@ -1169,9 +1175,9 @@ type_t *get_type(env_t *env, ast_t *ast) if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) && lhs_t->tag != NumType && rhs_t->tag != NumType) { - if (can_promote(rhs_t, lhs_t)) + if (can_compile_to_type(env, binop.rhs, lhs_t)) return lhs_t; - else if (can_promote(lhs_t, rhs_t)) + else if (can_compile_to_type(env, binop.lhs, rhs_t)) return rhs_t; } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) { return lhs_t; @@ -1207,9 +1213,9 @@ type_t *get_type(env_t *env, ast_t *ast) if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) && lhs_t->tag != NumType && rhs_t->tag != NumType) { - if (can_promote(rhs_t, lhs_t)) + if (can_compile_to_type(env, binop.rhs, lhs_t)) return lhs_t; - else if (can_promote(lhs_t, rhs_t)) + else if (can_compile_to_type(env, binop.lhs, rhs_t)) return rhs_t; } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) { return lhs_t; @@ -1224,8 +1230,8 @@ type_t *get_type(env_t *env, ast_t *ast) if ((binop.lhs->tag == Int && is_numeric_type(rhs_t)) || (binop.rhs->tag == Int && is_numeric_type(lhs_t)) - || can_promote(rhs_t, lhs_t) - || can_promote(lhs_t, rhs_t)) + || can_compile_to_type(env, binop.rhs, lhs_t) + || can_compile_to_type(env, binop.lhs, rhs_t)) return ast->tag == Compare ? Type(IntType, .bits=TYPE_IBITS32) : Type(BoolType); code_err(ast, "I don't know how to compare ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); @@ -1299,7 +1305,7 @@ type_t *get_type(env_t *env, ast_t *ast) } } - type_t *overall_t = (can_promote(rhs_t, lhs_t) ? lhs_t : (can_promote(lhs_t, rhs_t) ? rhs_t : NULL)); + type_t *overall_t = (can_compile_to_type(env, binop.rhs, lhs_t) ? lhs_t : (can_compile_to_type(env, binop.lhs, rhs_t) ? rhs_t : NULL)); if (overall_t == NULL) code_err(ast, "I don't know how to do math operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); @@ -1316,7 +1322,7 @@ type_t *get_type(env_t *env, ast_t *ast) type_t *lhs_t = get_type(env, binop.lhs); type_t *rhs_t = get_type(env, binop.rhs); - type_t *overall_t = (can_promote(rhs_t, lhs_t) ? lhs_t : (can_promote(lhs_t, rhs_t) ? rhs_t : NULL)); + type_t *overall_t = (can_compile_to_type(env, binop.rhs, lhs_t) ? lhs_t : (can_compile_to_type(env, binop.lhs, rhs_t) ? rhs_t : NULL)); if (overall_t == NULL) code_err(ast, "I don't know how to do operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); @@ -1579,8 +1585,64 @@ type_t *get_arg_type(env_t *env, arg_t *arg) return get_type(env, arg->default_val); } +static Table_t *get_arg_bindings_with_promotion(env_t *env, arg_t *spec_args, arg_ast_t *call_args) +{ + Table_t used_args = {}; + + // Populate keyword args: + for (arg_ast_t *call_arg = call_args; call_arg; call_arg = call_arg->next) { + if (!call_arg->name) continue; + + for (arg_t *spec_arg = spec_args; spec_arg; spec_arg = spec_arg->next) { + if (!streq(call_arg->name, spec_arg->name)) continue; + type_t *spec_type = get_arg_type(env, spec_arg); + if (!can_compile_to_type(env, call_arg->value, spec_type)) + return NULL; + Table$str_set(&used_args, call_arg->name, call_arg); + goto next_call_arg; + } + return NULL; + next_call_arg:; + } + + arg_ast_t *unused_args = call_args; + for (arg_t *spec_arg = spec_args; spec_arg; spec_arg = spec_arg->next) { + arg_ast_t *keyworded = Table$str_get(used_args, spec_arg->name); + if (keyworded) continue; + + type_t *spec_type = get_arg_type(env, spec_arg); + for (; unused_args; unused_args = unused_args->next) { + if (unused_args->name) continue; // Already handled the keyword args + if (!can_compile_to_type(env, unused_args->value, spec_type)) + return NULL; // Positional arg trying to fill in + Table$str_set(&used_args, spec_arg->name, unused_args); + unused_args = unused_args->next; + goto found_it; + } + + if (spec_arg->default_val) + goto found_it; + + return NULL; + found_it: continue; + } + + while (unused_args && unused_args->name) + unused_args = unused_args->next; + + if (unused_args != NULL) + return NULL; + + Table_t *ret = new(Table_t); + *ret = used_args; + return ret; +} + Table_t *get_arg_bindings(env_t *env, arg_t *spec_args, arg_ast_t *call_args, bool promotion_allowed) { + if (promotion_allowed) + return get_arg_bindings_with_promotion(env, spec_args, call_args); + Table_t used_args = {}; // Populate keyword args: @@ -1593,9 +1655,7 @@ Table_t *get_arg_bindings(env_t *env, arg_t *spec_args, arg_ast_t *call_args, bo type_t *spec_type = get_arg_type(env, spec_arg); type_t *complete_call_type = is_incomplete_type(call_type) ? most_complete_type(call_type, spec_type) : call_type; if (!complete_call_type) return NULL; - if (!(type_eq(complete_call_type, spec_type) || (promotion_allowed && can_promote(complete_call_type, spec_type)) - || (promotion_allowed && call_arg->value->tag == Int && is_numeric_type(spec_type)) - || (promotion_allowed && call_arg->value->tag == Num && spec_type->tag == NumType))) + if (!type_eq(complete_call_type, spec_type)) return NULL; Table$str_set(&used_args, call_arg->name, call_arg); goto next_call_arg; @@ -1615,9 +1675,7 @@ Table_t *get_arg_bindings(env_t *env, arg_t *spec_args, arg_ast_t *call_args, bo type_t *call_type = get_arg_ast_type(env, unused_args); type_t *complete_call_type = is_incomplete_type(call_type) ? most_complete_type(call_type, spec_type) : call_type; if (!complete_call_type) return NULL; - if (!(type_eq(complete_call_type, spec_type) || (promotion_allowed && can_promote(complete_call_type, spec_type)) - || (promotion_allowed && unused_args->value->tag == Int && is_numeric_type(spec_type)) - || (promotion_allowed && unused_args->value->tag == Num && spec_type->tag == NumType))) + if (!type_eq(complete_call_type, spec_type)) return NULL; // Positional arg trying to fill in Table$str_set(&used_args, spec_arg->name, unused_args); unused_args = unused_args->next; |
