Fix up keyword args and default args

This commit is contained in:
Bruce Hill 2024-02-22 22:15:09 -05:00
parent 54b8b7af12
commit 9e2645ade7
3 changed files with 60 additions and 12 deletions

8
ast.c
View File

@ -56,9 +56,9 @@ CORD arg_list_to_cord(arg_ast_t *args) {
if (args->name)
c = CORD_cat(c, args->name);
if (args->type)
CORD_sprintf(&c, "%r:%s", c, type_ast_to_cord(args->type));
CORD_sprintf(&c, "%r:%r", c, type_ast_to_cord(args->type));
if (args->default_val)
CORD_sprintf(&c, "%r=%s", c, ast_to_cord(args->default_val));
CORD_sprintf(&c, "%r=%r", c, ast_to_cord(args->default_val));
if (args->next) c = CORD_cat(c, ", ");
}
return CORD_cat(c, ")");
@ -97,7 +97,7 @@ CORD ast_to_cord(ast_t *ast)
T(Var, "(\x1b[36;1m%s\x1b[m)", data.name)
T(Int, "(\x1b[35m%ld\x1b[m, bits=\x1b[35m%ld\x1b[m)", data.i, data.bits)
T(Num, "(\x1b[35m%ld\x1b[m, bits=\x1b[35m%ld\x1b[m)", data.n, data.bits)
T(StringLiteral, "\x1b[35m\"%r\"\x1b[m", data.cord)
T(StringLiteral, "%r", Str__quoted(data.cord, true))
T(StringJoin, "(%r)", ast_list_to_cord(data.children))
T(Declare, "(var=%s, value=%r)", ast_to_cord(data.var), ast_to_cord(data.value))
T(Assign, "(targets=%r, values=%r)", ast_list_to_cord(data.targets), ast_list_to_cord(data.values))
@ -142,7 +142,7 @@ CORD ast_to_cord(ast_t *ast)
T(LinkerDirective, "(%s)", Str__quoted(data.directive, true))
#undef T
}
return NULL;
return "???";
}
CORD type_ast_to_cord(type_ast_t *t)

View File

@ -524,7 +524,8 @@ CORD compile(env_t *env, ast_t *ast)
CORD name = compile(env, fndef->name);
CORD_appendf(&env->code->staticdefs, "static %r %r_(", fndef->ret_type ? compile_type_ast(fndef->ret_type) : "void", name);
for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) {
CORD_appendf(&env->code->staticdefs, "%r %s", compile_type_ast(arg->type), arg->name);
type_t *arg_type = arg->type ? parse_type_ast(env, arg->type) : get_type(env, arg->default_val);
CORD_appendf(&env->code->staticdefs, "%r %s", compile_type(arg_type), arg->name);
if (arg->next) env->code->staticdefs = CORD_cat(env->code->staticdefs, ", ");
}
env->code->staticdefs = CORD_cat(env->code->staticdefs, ");\n");
@ -535,13 +536,14 @@ CORD compile(env_t *env, ast_t *ast)
env_t *body_scope = fresh_scope(env);
body_scope->locals->fallback = env->globals;
for (arg_ast_t *arg = fndef->args; arg; arg = arg->next) {
CORD arg_type = compile_type_ast(arg->type);
CORD_appendf(&env->code->funcs, "%r %s", arg_type, arg->name);
type_t *arg_type = arg->type ? parse_type_ast(env, arg->type) : get_type(env, arg->default_val);
CORD arg_typecode = compile_type(arg_type);
CORD_appendf(&env->code->funcs, "%r %s", arg_typecode, arg->name);
if (arg->next) env->code->funcs = CORD_cat(env->code->funcs, ", ");
CORD_appendf(&kwargs, "%r %s; ", arg_type, arg->name);
CORD_appendf(&kwargs, "%r %s; ", arg_typecode, arg->name);
CORD_appendf(&passed_args, "$args.%s", arg->name);
if (arg->next) passed_args = CORD_cat(passed_args, ", ");
set_binding(body_scope, arg->name, new(binding_t, .type=parse_type_ast(env, arg->type)));
set_binding(body_scope, arg->name, new(binding_t, .type=arg_type));
}
CORD_appendf(&kwargs, "} $args = {__VA_ARGS__}; %r_(%r); })\n", name, passed_args);
CORD_appendf(&env->code->staticdefs, "%r", kwargs);
@ -554,11 +556,57 @@ CORD compile(env_t *env, ast_t *ast)
}
case FunctionCall: {
auto call = Match(ast, FunctionCall);
type_t *fn_t = get_type(env, call->fn);
if (fn_t->tag != FunctionType)
code_err(call->fn, "This is not a function, it's a %T", fn_t);
CORD code = CORD_cat_char(compile(env, call->fn), '(');
// Pass 1: assign keyword args
// Pass 2: assign positional args
// Pass 3: compile and typecheck each arg
table_t arg_bindings = {};
for (ast_list_t *arg = call->args; arg; arg = arg->next) {
code = CORD_cat(code, compile(env, arg->ast));
if (arg->next) code = CORD_cat(code, ", ");
if (arg->ast->tag == KeywordArg)
Table_str_set(&arg_bindings, Match(arg->ast, KeywordArg)->name, Match(arg->ast, KeywordArg)->arg);
}
for (ast_list_t *call_arg = call->args; call_arg; call_arg = call_arg->next) {
if (call_arg->ast->tag == KeywordArg)
continue;
const char *name = NULL;
for (arg_t *fn_arg = Match(fn_t, FunctionType)->args; fn_arg; fn_arg = fn_arg->next) {
if (!Table_str_get(&arg_bindings, fn_arg->name)) {
name = fn_arg->name;
break;
}
}
if (name)
Table_str_set(&arg_bindings, name, call_arg->ast);
else
code_err(call_arg->ast, "This is too many arguments to the function: %T", fn_t);
}
for (arg_t *fn_arg = Match(fn_t, FunctionType)->args; fn_arg; fn_arg = fn_arg->next) {
ast_t *arg = Table_str_get(&arg_bindings, fn_arg->name);
if (arg) {
Table_str_remove(&arg_bindings, fn_arg->name);
} else {
arg = fn_arg->default_val;
}
if (!arg)
code_err(ast, "The required argument '%s' is not provided", fn_arg->name);
code = CORD_cat(code, compile(env, arg));
if (fn_arg->next) code = CORD_cat(code, ", ");
}
struct {
const char *name;
ast_t *ast;
} *invalid = Table_str_entry(&arg_bindings, 1);
if (invalid)
code_err(invalid->ast, "There is no argument named %s for %T", invalid->name, fn_t);
return CORD_cat_char(code, ')');
}
// Lambda,

View File

@ -161,7 +161,7 @@ type_t *get_function_def_type(env_t *env, ast_t *ast)
env_t *scope = fresh_scope(env);
for (arg_ast_t *arg = fn->args; arg; arg = arg->next) {
type_t *t = arg->type ? parse_type_ast(env, arg->type) : get_type(env, arg->default_val);
args = new(arg_t, .name=arg->name, .type=t, .next=args);
args = new(arg_t, .name=arg->name, .type=t, .default_val=arg->default_val, .next=args);
set_binding(scope, arg->name, new(binding_t, .type=t));
}
REVERSE_LIST(args);