aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/ast.h3
-rw-r--r--src/compile/expressions.c3
-rw-r--r--src/compile/whens.c12
-rw-r--r--test/values.tm8
4 files changed, 20 insertions, 6 deletions
diff --git a/src/ast.h b/src/ast.h
index 7fa9092a..5307fc2c 100644
--- a/src/ast.h
+++ b/src/ast.h
@@ -23,6 +23,9 @@
#define LiteralCode(code, ...) \
new (ast_t, .tag = InlineCCode, \
.__data.InlineCCode = {.chunks = new (ast_list_t, .ast = FakeAST(TextLiteral, code)), __VA_ARGS__})
+#define WrapLiteralCode(ast, code, ...) \
+ new (ast_t, .tag = InlineCCode, .file = (ast)->file, .start = (ast)->start, .end = (ast)->end, \
+ .__data.InlineCCode = {.chunks = new (ast_list_t, .ast = WrapAST(ast, TextLiteral, code)), __VA_ARGS__})
#define Match(x, _tag) \
((x)->tag == _tag ? &(x)->__data._tag \
: (errx(1, __FILE__ ":%d This was supposed to be a " #_tag "\n", __LINE__), &(x)->__data._tag))
diff --git a/src/compile/expressions.c b/src/compile/expressions.c
index a69a7d56..19e0672e 100644
--- a/src/compile/expressions.c
+++ b/src/compile/expressions.c
@@ -30,8 +30,7 @@ Text_t compile_maybe_incref(env_t *env, ast_t *ast, type_t *t) {
} else {
Text_t code = Texts("({ ", compile_declaration(t, Text("_tmp")), " = ", compile_to_type(env, ast, t), "; ",
"((", compile_type(t), "){");
- ast_t *tmp = WrapAST(ast, InlineCCode,
- .chunks = new (ast_list_t, .ast = WrapAST(ast, TextLiteral, Text("_tmp"))), .type = t);
+ ast_t *tmp = WrapLiteralCode(ast, Text("_tmp"), .type = t);
for (arg_t *field = Match(t, StructType)->fields; field; field = field->next) {
Text_t val = compile_maybe_incref(env, WrapAST(ast, FieldAccess, .fielded = tmp, .field = field->name),
get_arg_type(env, field));
diff --git a/src/compile/whens.c b/src/compile/whens.c
index 4f6a2a40..122c581c 100644
--- a/src/compile/whens.c
+++ b/src/compile/whens.c
@@ -83,8 +83,10 @@ Text_t compile_when_statement(env_t *env, ast_t *ast) {
const char *var_name = Match(args->value, Var)->name;
if (!streq(var_name, "_")) {
Text_t var = Texts("_$", var_name);
- code = Texts(code, compile_declaration(tag_type, var), " = _when_subject.",
- valid_c_name(clause_tag_name), ";\n");
+ ast_t *member =
+ WrapLiteralCode(ast, Texts("_when_subject.", valid_c_name(clause_tag_name)), .type = tag_type);
+ code = Texts(code, compile_declaration(tag_type, var), " = ",
+ compile_maybe_incref(env, member, tag_type), ";\n");
scope = fresh_scope(scope);
set_binding(scope, Match(args->value, Var)->name, tag_type, EMPTY_TEXT);
}
@@ -101,8 +103,10 @@ Text_t compile_when_statement(env_t *env, ast_t *ast) {
const char *var_name = Match(arg->value, Var)->name;
if (!streq(var_name, "_")) {
Text_t var = Texts("_$", var_name);
- code = Texts(code, compile_declaration(field->type, var), " = _when_subject.",
- valid_c_name(clause_tag_name), ".", valid_c_name(field->name), ";\n");
+ ast_t *member =
+ WrapLiteralCode(ast, Texts("_when_subject.", valid_c_name(clause_tag_name)), .type = tag_type);
+ code = Texts(code, compile_declaration(field->type, var), " = ",
+ compile_maybe_incref(env, member, tag_type), ".", valid_c_name(field->name), ";\n");
set_binding(scope, Match(arg->value, Var)->name, field->type, var);
}
field = field->next;
diff --git a/test/values.tm b/test/values.tm
index 86f34a89..9f86c012 100644
--- a/test/values.tm
+++ b/test/values.tm
@@ -3,6 +3,8 @@ struct Inner(xs:[Int32])
struct Outer(inner:Inner)
+enum HoldsList(HasList(xs:[Int32]))
+
func sneaky(outer:Outer)
(&outer.inner.xs)[1] = 99
@@ -45,3 +47,9 @@ func main()
assert foo.inner.xs == [99, 20, 30]
assert copy.inner.xs == [10, 20, 30]
+ do
+ x := HoldsList.HasList([10, 20, 30])
+ when x is HasList(list)
+ (&list)[1] = 99
+
+ assert x == HoldsList.HasList([10, 20, 30])