diff --git a/docs/arrays.md b/docs/arrays.md index e3da6e0..437ab76 100644 --- a/docs/arrays.md +++ b/docs/arrays.md @@ -18,7 +18,7 @@ you want to have an empty array, you must specify what type goes inside the arra like this: ```tomo -empty := [:Int] +empty : [Int] = [] ``` For type annotations, an array that holds items with type `T` is written as `[T]`. diff --git a/docs/functions.md b/docs/functions.md index edd1261..9f95277 100644 --- a/docs/functions.md +++ b/docs/functions.md @@ -82,7 +82,7 @@ func _add(x, y:Int -> Int): return x + y struct add_args(x,y:Int) -add_cache := @{:add_args,Int} +add_cache : @{add_args=Int} = @{} func add(x, y:Int -> Int): args := add_args(x, y) diff --git a/docs/integers.md b/docs/integers.md index d4e2fc3..6e45f1b 100644 --- a/docs/integers.md +++ b/docs/integers.md @@ -361,7 +361,7 @@ An iterator function that counts onward from the starting integer. **Example:** ```tomo -nums := &[:Int] +nums : &[Int] = &[] for i in 5:onward(): nums:insert(i) stop if i == 10 diff --git a/docs/operators.md b/docs/operators.md index a4b68fa..a304cf3 100644 --- a/docs/operators.md +++ b/docs/operators.md @@ -75,7 +75,7 @@ first option is to not account for it, in which case you'll get a runtime error if you use a reducer on something that has no values: ```tomo ->> nums := [:Int] +>> nums : [Int] = [] >> (+: nums)! Error: this collection was empty! diff --git a/docs/reductions.md b/docs/reductions.md index 58cfa31..959d970 100644 --- a/docs/reductions.md +++ b/docs/reductions.md @@ -18,7 +18,7 @@ a runtime check and error if there's a null value, or you can use `or` to provide a fallback value: ```tomo -nums := [:Int] +nums : [Int] = [] sum := (+: nums) >> sum diff --git a/docs/sets.md b/docs/sets.md index 778740a..740a37f 100644 --- a/docs/sets.md +++ b/docs/sets.md @@ -21,7 +21,7 @@ nums := {10, 20, 30} Empty sets must specify the item type explicitly: ```tomo -empty := {:Int} +empty : {Int} = {} ``` For type annotations, a set that holds items with type `T` is written as `{T}`. diff --git a/docs/tables.md b/docs/tables.md index f0045ef..83c80b2 100644 --- a/docs/tables.md +++ b/docs/tables.md @@ -17,7 +17,7 @@ table := {"A"=10, "B"=20} Empty tables must specify the key and value types explicitly: ```tomo -empty := {:Text=Int} +empty : {Text=Int} = {} ``` For type annotations, a table that maps keys with type `K` to values of type diff --git a/examples/colorful/colorful.tm b/examples/colorful/colorful.tm index 78a831c..57a2a3d 100644 --- a/examples/colorful/colorful.tm +++ b/examples/colorful/colorful.tm @@ -24,7 +24,7 @@ lang Colorful: say(c:for_terminal(), newline=newline) -func main(texts:[Text], files=[:Path], by_line=no): +func main(texts:[Text], files:[Path]=[], by_line=no): for i,text in texts: colorful := Colorful.from_text(text) colorful:print(newline=no) @@ -141,7 +141,7 @@ struct _TermState( ): func apply(old,new:_TermState -> Text): - sequences := &[:Text] + sequences : &[Text] = &[] _toggle2(sequences, old.bold, old.dim, new.bold, new.dim, "1", "2", "22") _toggle2(sequences, old.italic, old.fraktur, new.italic, new.fraktur, "3", "20", "23") _toggle(sequences, old.underline, new.underline, "4", "24") diff --git a/examples/commands/commands.tm b/examples/commands/commands.tm index cbf5439..ddd04d7 100644 --- a/examples/commands/commands.tm +++ b/examples/commands/commands.tm @@ -46,16 +46,16 @@ struct ProgramResult(stdout:[Byte], stderr:[Byte], exit_type:ExitType): else: return no -struct Command(command:Text, args=[:Text], env={:Text=Text}): - func from_path(path:Path, args=[:Text], env={:Text=Text} -> Command): +struct Command(command:Text, args:[Text]=[], env:{Text=Text}={}): + func from_path(path:Path, args:[Text]=[], env:{Text=Text}={} -> Command): return Command(Text(path), args, env) - func result(command:Command, input="", input_bytes=[:Byte] -> ProgramResult): + func result(command:Command, input="", input_bytes:[Byte]=[] -> ProgramResult): if input.length > 0: (&input_bytes):insert_all(input:bytes()) - stdout := [:Byte] - stderr := [:Byte] + stdout : [Byte] = [] + stderr : [Byte] = [] status := run_command(command.command, command.args, command.env, input_bytes, &stdout, &stderr) if inline C : Bool { WIFEXITED(_$status) }: @@ -80,7 +80,7 @@ struct Command(command:Text, args=[:Text], env={:Text=Text}): func get_output(command:Command, input="", trim_newline=yes -> Text?): return command:result(input=input):output_text(trim_newline=trim_newline) - func get_output_bytes(command:Command, input="", input_bytes=[:Byte] -> [Byte]?): + func get_output_bytes(command:Command, input="", input_bytes:[Byte]=[] -> [Byte]?): result := command:result(input=input, input_bytes=input_bytes) when result.exit_type is Exited(status): if status == 0: return result.stdout diff --git a/examples/game/game.tm b/examples/game/game.tm index b034f68..36ef14f 100644 --- a/examples/game/game.tm +++ b/examples/game/game.tm @@ -10,7 +10,7 @@ func main(map=(./map.txt)): world := @World( player=@Player(Vector2(0,0), Vector2(0,0)), goal=@Box(Vector2(0,0), Vector2(50,50), color=Color(0x10,0xa0,0x10)), - boxes=@[:@Box], + boxes=@[], ) world:load_map(map_contents) diff --git a/examples/game/world.tm b/examples/game/world.tm index 809f1f8..8b721c5 100644 --- a/examples/game/world.tm +++ b/examples/game/world.tm @@ -70,7 +70,7 @@ struct World(player:@Player, goal:@Box, boxes:@[@Box], dt_accum=Num32(0.0), won= func load_map(w:@World, map:Text): if map:has("[]"): map = map:translate({"[]"="#", "@ "="@", " "=" "}) - w.boxes = @[:@Box] + w.boxes = @[] box_size := Vector2(50., 50.) for y,line in map:lines(): for x,cell in line:split(): diff --git a/examples/http-server/connection-queue.tm b/examples/http-server/connection-queue.tm index a198f09..362dab7 100644 --- a/examples/http-server/connection-queue.tm +++ b/examples/http-server/connection-queue.tm @@ -3,7 +3,7 @@ use pthreads func _assert_success(name:Text, val:Int32; inline): fail("$name() failed!") if val < 0 -struct ConnectionQueue(_connections=@[:Int32], _mutex=pthread_mutex_t.new(), _cond=pthread_cond_t.new()): +struct ConnectionQueue(_connections:@[Int32]=@[], _mutex=pthread_mutex_t.new(), _cond=pthread_cond_t.new()): func enqueue(queue:ConnectionQueue, connection:Int32): queue._mutex:lock() queue._connections:insert(connection) diff --git a/examples/http-server/http-server.tm b/examples/http-server/http-server.tm index f7338b9..56ba368 100644 --- a/examples/http-server/http-server.tm +++ b/examples/http-server/http-server.tm @@ -17,7 +17,7 @@ use ./connection-queue.tm func serve(port:Int32, handler:func(request:HTTPRequest -> HTTPResponse), num_threads=16): connections := ConnectionQueue() - workers := &[:@pthread_t] + workers : &[@pthread_t] = &[] for i in num_threads: workers:insert(pthread_t.new(func(): repeat: @@ -82,7 +82,7 @@ struct HTTPRequest(method:Text, path:Text, version:Text, headers:[Text], body:Te body := rest[-1] return HTTPRequest(method, path, version, headers, body) -struct HTTPResponse(body:Text, status=200, content_type="text/plain", headers={:Text=Text}): +struct HTTPResponse(body:Text, status=200, content_type="text/plain", headers:{Text=Text}={}): func bytes(r:HTTPResponse -> [Byte]): body_bytes := r.body:bytes() extra_headers := (++: "$k: $v$(\r\n)" for k,v in r.headers) or "" @@ -114,7 +114,7 @@ enum RouteEntry(ServeFile(file:Path), Redirect(destination:Text)): return HTTPResponse("Found", 302, headers={"Location"=destination}) func load_routes(directory:Path -> {Text=RouteEntry}): - routes := &{:Text=RouteEntry} + routes : &{Text=RouteEntry} = &{} for file in (directory ++ (./*)):glob(): skip unless file:is_file() contents := file:read() or skip diff --git a/examples/http/http.tm b/examples/http/http.tm index 12203fa..ee1b8c7 100644 --- a/examples/http/http.tm +++ b/examples/http/http.tm @@ -7,8 +7,8 @@ struct HTTPResponse(code:Int, body:Text) enum _Method(GET, POST, PUT, PATCH, DELETE) -func _send(method:_Method, url:Text, data:Text?, headers=[:Text] -> HTTPResponse): - chunks := @[:Text] +func _send(method:_Method, url:Text, data:Text?, headers:[Text]=[] -> HTTPResponse): + chunks : @[Text] = @[] save_chunk := func(chunk:CString, size:Int64, n:Int64): chunks:insert(inline C:Text { Text$format("%.*s", _$size*_$n, _$chunk) @@ -81,7 +81,7 @@ func _send(method:_Method, url:Text, data:Text?, headers=[:Text] -> HTTPResponse } return HTTPResponse(Int(code), "":join(chunks)) -func get(url:Text, headers=[:Text] -> HTTPResponse): +func get(url:Text, headers:[Text]=[] -> HTTPResponse): return _send(GET, url, none, headers) func post(url:Text, data="", headers=["Content-Type: application/json", "Accept: application/json"] -> HTTPResponse): diff --git a/examples/ini/ini.tm b/examples/ini/ini.tm index 1e8e015..d67dd0d 100644 --- a/examples/ini/ini.tm +++ b/examples/ini/ini.tm @@ -11,8 +11,8 @@ _HELP := " func parse_ini(path:Path -> {Text={Text=Text}}): text := path:read() or exit("Could not read INI file: $\[31;1]$(path)$\[]") - sections := @{:Text=@{Text=Text}} - current_section := @{:Text=Text} + sections : @{Text=@{Text=Text}} = @{} + current_section : @{Text=Text} = @{} # Line wraps: text = text:replace_pattern($Pat/\{1 nl}{0+space}/, " ") @@ -22,7 +22,7 @@ func parse_ini(path:Path -> {Text={Text=Text}}): skip if line:starts_with(";") or line:starts_with("#") if line:matches_pattern($Pat/[?]/): section_name := line:replace($Pat/[?]/, "\1"):trim():lower() - current_section = @{:Text=Text} + current_section = @{} sections[section_name] = current_section else if line:matches_pattern($Pat/{..}={..}/): key := line:replace_pattern($Pat/{..}={..}/, "\1"):trim():lower() diff --git a/examples/learnxiny.tm b/examples/learnxiny.tm index a394e58..b585a2a 100644 --- a/examples/learnxiny.tm +++ b/examples/learnxiny.tm @@ -59,7 +59,7 @@ func main(): my_numbers := [10, 20, 30] # Empty arrays require specifying the type: - empty_array := [:Int] + empty_array : [Int] = [] >> empty_array.length = 0 @@ -123,7 +123,7 @@ func main(): = 0 # Empty tables require specifying the key and value types: - empty_table := {:Text=Int} + empty_table : {Text=Int} = {} # Tables can be iterated over either by key or key,value: for key in table: diff --git a/examples/log/log.tm b/examples/log/log.tm index 5e32a2b..f4b0b39 100644 --- a/examples/log/log.tm +++ b/examples/log/log.tm @@ -3,7 +3,7 @@ use timestamp_format := CString("%F %T") -logfiles := @{:Path} +logfiles : @{Path} = @{} func _timestamp(->Text): c_str := inline C:CString { diff --git a/examples/pthreads/pthreads.tm b/examples/pthreads/pthreads.tm index 975b981..fb79e82 100644 --- a/examples/pthreads/pthreads.tm +++ b/examples/pthreads/pthreads.tm @@ -65,7 +65,7 @@ struct pthread_t(; extern, opaque): func detatch(p:pthread_t): inline C { pthread_detach(_$p); } struct IntQueue(_queue:@[Int], _mutex:@pthread_mutex_t, _cond:@pthread_cond_t): - func new(initial=[:Int] -> IntQueue): + func new(initial:[Int]=[] -> IntQueue): return IntQueue(@initial, pthread_mutex_t.new(), pthread_cond_t.new()) func give(q:IntQueue, n:Int): diff --git a/examples/random/README.md b/examples/random/README.md index 697f3f7..6233c1b 100644 --- a/examples/random/README.md +++ b/examples/random/README.md @@ -110,7 +110,7 @@ A copy of the given RNG. **Example:** ```tomo ->> rng := RNG.new([:Byte]) +>> rng := RNG.new([]) >> copy := rng:copy() >> rng:bytes(10) diff --git a/examples/random/random.tm b/examples/random/random.tm index 1d6e560..0a0167a 100644 --- a/examples/random/random.tm +++ b/examples/random/random.tm @@ -4,7 +4,7 @@ use ./sysrandom.h use ./chacha.h struct chacha_ctx(j0,j1,j2,j3,j4,j5,j6,j7,j8,j9,j10,j11,j12,j13,j14,j15:Int32; extern, secret): - func from_seed(seed=[:Byte] -> chacha_ctx): + func from_seed(seed:[Byte]=[] -> chacha_ctx): return inline C : chacha_ctx { chacha_ctx ctx; uint8_t seed_bytes[KEYSZ + IVSZ] = {}; @@ -24,10 +24,10 @@ func _os_random_bytes(count:Int64 -> [Byte]): (Array_t){.length=_$count, .data=random_bytes, .stride=1, .atomic=1}; } -struct RandomNumberGenerator(_chacha:chacha_ctx, _random_bytes=[:Byte]; secret): +struct RandomNumberGenerator(_chacha:chacha_ctx, _random_bytes:[Byte]=[]; secret): func new(seed=none:[Byte], -> @RandomNumberGenerator): ctx := chacha_ctx.from_seed(seed or _os_random_bytes(40)) - return @RandomNumberGenerator(ctx, [:Byte]) + return @RandomNumberGenerator(ctx, []) func _rekey(rng:&RandomNumberGenerator): rng._random_bytes = inline C : [Byte] { diff --git a/examples/shell/shell.tm b/examples/shell/shell.tm index 9ca9e05..dd26428 100644 --- a/examples/shell/shell.tm +++ b/examples/shell/shell.tm @@ -24,7 +24,7 @@ lang Shell: func command(shell:Shell -> Command): return Command("sh", ["-c", shell.text]) - func result(shell:Shell, input="", input_bytes=[:Byte] -> ProgramResult): + func result(shell:Shell, input="", input_bytes:[Byte]=[] -> ProgramResult): return shell:command():result(input=input, input_bytes=input_bytes) func run(shell:Shell -> ExitType): @@ -33,7 +33,7 @@ lang Shell: func get_output(shell:Shell, input="", trim_newline=yes -> Text?): return shell:command():get_output(input=input, trim_newline=trim_newline) - func get_output_bytes(shell:Shell, input="", input_bytes=[:Byte] -> [Byte]?): + func get_output_bytes(shell:Shell, input="", input_bytes:[Byte]=[] -> [Byte]?): return shell:command():get_output_bytes(input=input, input_bytes=input_bytes) func by_line(shell:Shell -> func(->Text?)?): diff --git a/examples/tomo-install/tomo-install.tm b/examples/tomo-install/tomo-install.tm index e584fe6..c705af1 100644 --- a/examples/tomo-install/tomo-install.tm +++ b/examples/tomo-install/tomo-install.tm @@ -11,7 +11,7 @@ _HELP := " " func find_urls(path:Path -> [Text]): - urls := @[:Text] + urls : @[Text] = @[] if path:is_directory(): for f in path:children(): urls:insert_all(find_urls(f)) @@ -25,7 +25,7 @@ func main(paths:[Path]): if paths.length == 0: paths = [(./)] - urls := (++: find_urls(p) for p in paths) or [:Text] + urls := (++: find_urls(p) for p in paths) or [] github_token := (~/.config/tomo/github-token):read() diff --git a/examples/tomodeps/tomodeps.tm b/examples/tomodeps/tomodeps.tm index 96838a6..8da64eb 100644 --- a/examples/tomodeps/tomodeps.tm +++ b/examples/tomodeps/tomodeps.tm @@ -14,9 +14,9 @@ enum Dependency(File(path:Path), Module(name:Text)) func _get_file_dependencies(file:Path -> {Dependency}): if not file:is_file(): !! Could not read file: $file - return {:Dependency} + return {} - deps := @{:Dependency} + deps : @{Dependency} = @{} if lines := file:by_line(): for line in lines: if line:matches_pattern($Pat/use {..}.tm/): @@ -30,14 +30,14 @@ func _get_file_dependencies(file:Path -> {Dependency}): func _build_dependency_graph(dep:Dependency, dependencies:@{Dependency,{Dependency}}): return if dependencies:has(dep) - dependencies[dep] = {:Dependency} # Placeholder + dependencies[dep] = {} # Placeholder dep_deps := when dep is File(path): _get_file_dependencies(path) is Module(module): dir := (~/.local/share/tomo/installed/$module) - module_deps := @{:Dependency} - visited := @{:Path} + module_deps : @{Dependency} = @{} + visited : @{Path} = @{} unvisited := @{f:resolved() for f in dir:files() if f:extension() == ".tm"} while unvisited.length > 0: file := unvisited.items[-1] @@ -58,7 +58,7 @@ func _build_dependency_graph(dep:Dependency, dependencies:@{Dependency,{Dependen _build_dependency_graph(dep2, dependencies) func get_dependency_graph(dep:Dependency -> {Dependency,{Dependency}}): - graph := @{:Dependency,{Dependency}} + graph : @{Dependency={Dependency}} = @{} _build_dependency_graph(dep, graph) return graph @@ -82,16 +82,16 @@ func _draw_tree(dep:Dependency, dependencies:{Dependency,{Dependency}}, already_ child_prefix := prefix ++ (if is_last: " " else: "│ ") - children := dependencies[dep] or {:Dependency} + children := dependencies[dep] or {} for i,child in children.items: is_child_last := (i == children.length) _draw_tree(child, dependencies, already_printed, child_prefix, is_child_last) func draw_tree(dep:Dependency, dependencies:{Dependency,{Dependency}}): - printed := @{:Dependency} + printed : @{Dependency} = @{} say(_printable_name(dep)) printed:add(dep) - deps := dependencies[dep] or {:Dependency} + deps := dependencies[dep] or {} for i,child in deps.items: is_child_last := (i == deps.length) _draw_tree(child, dependencies, already_printed=printed, is_last=is_child_last) diff --git a/examples/wrap/wrap.tm b/examples/wrap/wrap.tm index c90713a..61ca582 100644 --- a/examples/wrap/wrap.tm +++ b/examples/wrap/wrap.tm @@ -33,7 +33,7 @@ func wrap(text:Text, width:Int, min_split=3, hyphen="-" -> Text): ... and I can't split it without splitting into chunks smaller than $min_split. ") - lines := @[:Text] + lines : @[Text] = @[] line := "" for word in text:split($/{whitespace}/): letters := word:split() @@ -93,7 +93,7 @@ func main(files:[Path], width=80, inplace=no, min_split=3, rewrap=yes, hyphen=UN (/dev/stdout) first := yes - wrapped_paragraphs := @[:Text] + wrapped_paragraphs : @[Text] = @[] for paragraph in text:split($/{2+ nl}/): wrapped_paragraphs:insert( wrap(paragraph, width=width, min_split=min_split, hyphen=hyphen) diff --git a/src/ast.c b/src/ast.c index 84d25db..67b54f9 100644 --- a/src/ast.c +++ b/src/ast.c @@ -10,23 +10,47 @@ #include "stdlib/text.h" #include "cordhelpers.h" -static const char *OP_NAMES[] = { - [BINOP_UNKNOWN]="unknown", - [BINOP_POWER]="^", [BINOP_MULT]="*", [BINOP_DIVIDE]="/", - [BINOP_MOD]="mod", [BINOP_MOD1]="mod1", [BINOP_PLUS]="+", [BINOP_MINUS]="minus", - [BINOP_CONCAT]="++", [BINOP_LSHIFT]="<<", [BINOP_ULSHIFT]="<<<", - [BINOP_RSHIFT]=">>", [BINOP_URSHIFT]=">>>", [BINOP_MIN]="min", - [BINOP_MAX]="max", [BINOP_EQ]="==", [BINOP_NE]="!=", [BINOP_LT]="<", - [BINOP_LE]="<=", [BINOP_GT]=">", [BINOP_GE]=">=", [BINOP_CMP]="<>", - [BINOP_AND]="and", [BINOP_OR]="or", [BINOP_XOR]="xor", +CONSTFUNC const char *binop_method_name(ast_e tag) { + switch (tag) { + case Power: case PowerUpdate: return "power"; + case Multiply: case MultiplyUpdate: return "times"; + case Divide: case DivideUpdate: return "divided_by"; + case Mod: case ModUpdate: return "modulo"; + case Mod1: case Mod1Update: return "modulo1"; + case Plus: case PlusUpdate: return "plus"; + case Minus: case MinusUpdate: return "minus"; + case Concat: case ConcatUpdate: return "concatenated_with"; + case LeftShift: case LeftShiftUpdate: return "left_shifted"; + case RightShift: case RightShiftUpdate: return "right_shifted"; + case UnsignedLeftShift: case UnsignedLeftShiftUpdate: return "unsigned_left_shifted"; + case UnsignedRightShift: case UnsignedRightShiftUpdate: return "unsigned_right_shifted"; + case And: case AndUpdate: return "bit_and"; + case Or: case OrUpdate: return "bit_or"; + case Xor: case XorUpdate: return "bit_xor"; + default: return NULL; + } }; -const char *binop_method_names[BINOP_XOR+1] = { - [BINOP_POWER]="power", [BINOP_MULT]="times", [BINOP_DIVIDE]="divided_by", - [BINOP_MOD]="modulo", [BINOP_MOD1]="modulo1", [BINOP_PLUS]="plus", [BINOP_MINUS]="minus", - [BINOP_CONCAT]="concatenated_with", [BINOP_LSHIFT]="left_shifted", [BINOP_RSHIFT]="right_shifted", - [BINOP_ULSHIFT]="unsigned_left_shifted", [BINOP_URSHIFT]="unsigned_right_shifted", - [BINOP_AND]="bit_and", [BINOP_OR]="bit_or", [BINOP_XOR]="bit_xor", +CONSTFUNC const char *binop_operator(ast_e tag) { + switch (tag) { + case Multiply: case MultiplyUpdate: return "*"; + case Divide: case DivideUpdate: return "/"; + case Mod: case ModUpdate: return "%"; + case Plus: case PlusUpdate: return "+"; + case Minus: case MinusUpdate: return "-"; + case LeftShift: case LeftShiftUpdate: return "<<"; + case RightShift: case RightShiftUpdate: return ">>"; + case And: case AndUpdate: return "&"; + case Or: case OrUpdate: return "|"; + case Xor: case XorUpdate: return "^"; + case Equals: return "=="; + case NotEquals: return "!="; + case LessThan: return "<"; + case LessThanOrEquals: return "<="; + case GreaterThan: return ">"; + case GreaterThanOrEquals: return ">="; + default: return NULL; + } }; static CORD ast_list_to_xml(ast_list_t *asts); @@ -100,7 +124,7 @@ CORD ast_to_xml(ast_t *ast) switch (ast->tag) { #define T(type, ...) case type: { auto data = ast->__data.type; (void)data; return CORD_asprintf(__VA_ARGS__); } T(Unknown, "") - T(None, "%r", type_ast_to_xml(data.type)) + T(None, "") T(Bool, "", data.b ? "yes" : "no") T(Var, "%s", data.name) T(Int, "%s", data.str) @@ -108,22 +132,24 @@ CORD ast_to_xml(ast_t *ast) T(TextLiteral, "%r", xml_escape(data.cord)) T(TextJoin, "%r", data.lang ? CORD_all(" lang=\"", data.lang, "\"") : CORD_EMPTY, ast_list_to_xml(data.children)) T(Path, "%s", data.path) - T(Declare, "%r", ast_to_xml(data.var), ast_to_xml(data.value)) + T(Declare, "%r%r", ast_to_xml(data.var), type_ast_to_xml(data.type), ast_to_xml(data.value)) T(Assign, "%r%r", ast_list_to_xml(data.targets), ast_list_to_xml(data.values)) - T(BinaryOp, "%r %r", xml_escape(OP_NAMES[data.op]), ast_to_xml(data.lhs), ast_to_xml(data.rhs)) - T(UpdateAssign, "%r %r", xml_escape(OP_NAMES[data.op]), ast_to_xml(data.lhs), ast_to_xml(data.rhs)) +#define BINOP(name) T(name, "<" #name ">%r %r", data.lhs, data.rhs) + BINOP(Power) BINOP(PowerUpdate) BINOP(Multiply) BINOP(MultiplyUpdate) BINOP(Divide) BINOP(DivideUpdate) BINOP(Mod) BINOP(ModUpdate) + BINOP(Mod1) BINOP(Mod1Update) BINOP(Plus) BINOP(PlusUpdate) BINOP(Minus) BINOP(MinusUpdate) BINOP(Concat) BINOP(ConcatUpdate) + BINOP(LeftShift) BINOP(LeftShiftUpdate) BINOP(RightShift) BINOP(RightShiftUpdate) BINOP(UnsignedLeftShift) BINOP(UnsignedLeftShiftUpdate) + BINOP(UnsignedRightShift) BINOP(UnsignedRightShiftUpdate) BINOP(And) BINOP(AndUpdate) BINOP(Or) BINOP(OrUpdate) + BINOP(Xor) BINOP(XorUpdate) +#undef BINOP T(Negative, "%r", ast_to_xml(data.value)) T(Not, "%r", ast_to_xml(data.value)) T(HeapAllocate, "%r", ast_to_xml(data.value)) T(StackReference, "%r", ast_to_xml(data.value)) T(Min, "%r%r%r", ast_to_xml(data.lhs), ast_to_xml(data.rhs), optional_tagged("key", data.key)) T(Max, "%r%r%r", ast_to_xml(data.lhs), ast_to_xml(data.rhs), optional_tagged("key", data.key)) - T(Array, "%r%r", optional_tagged_type("item-type", data.item_type), ast_list_to_xml(data.items)) - T(Set, "%r%r", - optional_tagged_type("item-type", data.item_type), - ast_list_to_xml(data.items)) - T(Table, "%r%r%r%r
", - optional_tagged_type("key-type", data.key_type), optional_tagged_type("value-type", data.value_type), + T(Array, "%r", ast_list_to_xml(data.items)) + T(Set, "%r", ast_list_to_xml(data.items)) + T(Table, "%r%r
", optional_tagged("default-value", data.default_value), ast_list_to_xml(data.entries), optional_tagged("fallback", data.fallback)) T(TableEntry, "%r%r", ast_to_xml(data.key), ast_to_xml(data.value)) @@ -145,7 +171,7 @@ CORD ast_to_xml(ast_t *ast) T(Repeat, "%r", optional_tagged("body", data.body)) T(If, "%r%r%r", optional_tagged("condition", data.condition), optional_tagged("body", data.body), optional_tagged("else", data.else_body)) T(When, "%r%r%r", ast_to_xml(data.subject), when_clauses_to_xml(data.clauses), optional_tagged("else", data.else_body)) - T(Reduction, "%r", xml_escape(OP_NAMES[data.op]), optional_tagged("key", data.key), + T(Reduction, "%r", xml_escape(binop_method_name(data.op)), optional_tagged("key", data.key), optional_tagged("iterable", data.iter)) T(Skip, "%r", data.target) T(Stop, "%r", data.target) @@ -313,4 +339,45 @@ void visit_topologically(ast_list_t *asts, Closure_t fn) } } +CONSTFUNC bool is_binary_operation(ast_t *ast) +{ + switch (ast->tag) { + case BINOP_CASES: return true; + default: return false; + } +} + +CONSTFUNC bool is_update_assignment(ast_t *ast) +{ + switch (ast->tag) { + case PowerUpdate: case MultiplyUpdate: case DivideUpdate: case ModUpdate: case Mod1Update: + case PlusUpdate: case MinusUpdate: case ConcatUpdate: case LeftShiftUpdate: case UnsignedLeftShiftUpdate: + case RightShiftUpdate: case UnsignedRightShiftUpdate: case AndUpdate: case OrUpdate: case XorUpdate: + return true; + default: return false; + } +} + +CONSTFUNC ast_e binop_tag(ast_e tag) +{ + switch (tag) { + case PowerUpdate: return Power; + case MultiplyUpdate: return Multiply; + case DivideUpdate: return Divide; + case ModUpdate: return Mod; + case Mod1Update: return Mod1; + case PlusUpdate: return Plus; + case MinusUpdate: return Minus; + case ConcatUpdate: return Concat; + case LeftShiftUpdate: return LeftShift; + case UnsignedLeftShiftUpdate: return UnsignedLeftShift; + case RightShiftUpdate: return RightShift; + case UnsignedRightShiftUpdate: return UnsignedRightShift; + case AndUpdate: return And; + case OrUpdate: return Or; + case XorUpdate: return Xor; + default: return Unknown; + } +} + // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/src/ast.h b/src/ast.h index 766b484..4f370e1 100644 --- a/src/ast.h +++ b/src/ast.h @@ -20,6 +20,7 @@ #define WrapAST(ast, ast_tag, ...) (new(ast_t, .file=(ast)->file, .start=(ast)->start, .end=(ast)->end, .tag=ast_tag, .__data.ast_tag={__VA_ARGS__})) #define TextAST(ast, _str) WrapAST(ast, TextLiteral, .str=GC_strdup(_str)) #define Match(x, _tag) ((x)->tag == _tag ? &(x)->__data._tag : (errx(1, __FILE__ ":%d This was supposed to be a " # _tag "\n", __LINE__), &(x)->__data._tag)) +#define BINARY_OPERANDS(ast) ({ if (!is_binary_operation(ast)) errx(1, __FILE__ ":%d This is not a binary operation!", __LINE__); (ast)->__data.Plus; }) #define REVERSE_LIST(list) do { \ __typeof(list) _prev = NULL; \ @@ -37,6 +38,9 @@ struct binding_s; typedef struct type_ast_s type_ast_t; typedef struct ast_s ast_t; +typedef struct { + ast_t *lhs, *rhs; +} binary_operands_t; typedef struct ast_list_s { ast_t *ast; @@ -55,17 +59,6 @@ typedef struct when_clause_s { struct when_clause_s *next; } when_clause_t; -typedef enum { - BINOP_UNKNOWN, - BINOP_POWER=100, BINOP_MULT, BINOP_DIVIDE, BINOP_MOD, BINOP_MOD1, BINOP_PLUS, - BINOP_MINUS, BINOP_CONCAT, BINOP_LSHIFT, BINOP_ULSHIFT, BINOP_RSHIFT, BINOP_URSHIFT, BINOP_MIN, - BINOP_MAX, BINOP_EQ, BINOP_NE, BINOP_LT, BINOP_LE, BINOP_GT, BINOP_GE, - BINOP_CMP, - BINOP_AND, BINOP_OR, BINOP_XOR, -} binop_e; - -extern const char *binop_method_names[BINOP_XOR+1]; - typedef enum { UnknownTypeAST, VarTypeAST, @@ -117,6 +110,15 @@ struct type_ast_s { } __data; }; +#define BINOP_CASES Power: case Multiply: case Divide: case Mod: case Mod1: case Plus: case Minus: case Concat: case LeftShift: case UnsignedLeftShift: \ + case RightShift: case UnsignedRightShift: case Equals: case NotEquals: case LessThan: case LessThanOrEquals: case GreaterThan: \ + case GreaterThanOrEquals: case Compare: case And: case Or: case Xor: \ + case PowerUpdate: case MultiplyUpdate: case DivideUpdate: case ModUpdate: case Mod1Update: case PlusUpdate: case MinusUpdate: case ConcatUpdate: \ + case LeftShiftUpdate: case UnsignedLeftShiftUpdate +#define UPDATE_CASES PowerUpdate: case MultiplyUpdate: case DivideUpdate: case ModUpdate: case Mod1Update: case PlusUpdate: case MinusUpdate: \ + case ConcatUpdate: case LeftShiftUpdate: case UnsignedLeftShiftUpdate: case RightShiftUpdate: case UnsignedRightShiftUpdate: \ + case AndUpdate: case OrUpdate: case XorUpdate + typedef enum { Unknown = 0, None, Bool, Var, @@ -124,7 +126,11 @@ typedef enum { TextLiteral, TextJoin, PrintStatement, Path, Declare, Assign, - BinaryOp, UpdateAssign, + Power, Multiply, Divide, Mod, Mod1, Plus, Minus, Concat, LeftShift, UnsignedLeftShift, + RightShift, UnsignedRightShift, Equals, NotEquals, LessThan, LessThanOrEquals, GreaterThan, + GreaterThanOrEquals, Compare, And, Or, Xor, + PowerUpdate, MultiplyUpdate, DivideUpdate, ModUpdate, Mod1Update, PlusUpdate, MinusUpdate, ConcatUpdate, LeftShiftUpdate, UnsignedLeftShiftUpdate, + RightShiftUpdate, UnsignedRightShiftUpdate, AndUpdate, OrUpdate, XorUpdate, Not, Negative, HeapAllocate, StackReference, Min, Max, Array, Set, Table, TableEntry, Comprehension, @@ -152,9 +158,7 @@ struct ast_s { const char *start, *end; union { struct {} Unknown; - struct { - type_ast_t *type; - } None; + struct {} None; struct { bool b; } Bool; @@ -182,16 +186,17 @@ struct ast_s { } PrintStatement; struct { ast_t *var; + type_ast_t *type; ast_t *value; } Declare; struct { ast_list_t *targets, *values; } Assign; - struct { - ast_t *lhs; - binop_e op; - ast_t *rhs; - } BinaryOp, UpdateAssign; + binary_operands_t Power, Multiply, Divide, Mod, Mod1, Plus, Minus, Concat, LeftShift, UnsignedLeftShift, + RightShift, UnsignedRightShift, Equals, NotEquals, LessThan, LessThanOrEquals, GreaterThan, + GreaterThanOrEquals, Compare, And, Or, Xor, + PowerUpdate, MultiplyUpdate, DivideUpdate, ModUpdate, Mod1Update, PlusUpdate, MinusUpdate, ConcatUpdate, LeftShiftUpdate, UnsignedLeftShiftUpdate, + RightShiftUpdate, UnsignedRightShiftUpdate, AndUpdate, OrUpdate, XorUpdate; struct { ast_t *value; } Not, Negative, HeapAllocate, StackReference; @@ -199,15 +204,12 @@ struct ast_s { ast_t *lhs, *rhs, *key; } Min, Max; struct { - type_ast_t *item_type; ast_list_t *items; } Array; struct { - type_ast_t *item_type; ast_list_t *items; } Set; struct { - type_ast_t *key_type, *value_type; ast_t *default_value; ast_t *fallback; ast_list_t *entries; @@ -272,7 +274,7 @@ struct ast_s { } When; struct { ast_t *iter, *key; - binop_e op; + ast_e op; } Reduction; struct { const char *target; @@ -345,5 +347,10 @@ const char *ast_source(ast_t *ast); CORD type_ast_to_xml(type_ast_t *ast); PUREFUNC bool is_idempotent(ast_t *ast); void visit_topologically(ast_list_t *ast, Closure_t fn); +CONSTFUNC bool is_update_assignment(ast_t *ast); +CONSTFUNC const char *binop_method_name(ast_e tag); +CONSTFUNC const char *binop_operator(ast_e tag); +CONSTFUNC ast_e binop_tag(ast_e tag); +CONSTFUNC bool is_binary_operation(ast_t *ast); // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/src/compile.c b/src/compile.c index 95cf5c9..d893c67 100644 --- a/src/compile.c +++ b/src/compile.c @@ -24,7 +24,6 @@ typedef ast_t* (*comprehension_body_t)(ast_t*, ast_t*); static CORD compile_to_pointer_depth(env_t *env, ast_t *ast, int64_t target_depth, bool needs_incref); -static CORD compile_math_method(env_t *env, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type); static CORD compile_string(env_t *env, ast_t *ast, CORD color); static CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t *call_args); static CORD compile_maybe_incref(env_t *env, ast_t *ast, type_t *t); @@ -33,9 +32,17 @@ static CORD compile_unsigned_type(type_t *t); static CORD promote_to_optional(type_t *t, CORD code); static CORD compile_none(type_t *t); static CORD compile_to_type(env_t *env, ast_t *ast, type_t *t); +static CORD compile_typed_array(env_t *env, ast_t *ast, type_t *array_type); +static CORD compile_typed_set(env_t *env, ast_t *ast, type_t *set_type); +static CORD compile_typed_table(env_t *env, ast_t *ast, type_t *table_type); +static CORD compile_typed_allocation(env_t *env, ast_t *ast, type_t *pointer_type); static CORD check_none(type_t *t, CORD value); static CORD optional_into_nonnone(type_t *t, CORD value); static CORD compile_string_literal(CORD literal); +static ast_t *add_to_array_comprehension(ast_t *item, ast_t *subject); +static ast_t *add_to_table_comprehension(ast_t *entry, ast_t *subject); +static ast_t *add_to_set_comprehension(ast_t *item, ast_t *subject); +static CORD compile_lvalue(env_t *env, ast_t *ast); CORD promote_to_optional(type_t *t, CORD code) { @@ -79,6 +86,11 @@ static bool promote(env_t *env, ast_t *ast, CORD *code, type_t *actual, type_t * return true; } + // Empty promotion: + type_t *more_complete = most_complete_type(actual, needed); + if (more_complete) + return true; + // Optional promotion: if (needed->tag == OptionalType && type_eq(actual, Match(needed, OptionalType)->type)) { *code = promote_to_optional(actual, *code); @@ -218,14 +230,10 @@ static void add_closed_vars(Table_t *closed_vars, env_t *enclosing_scope, env_t add_closed_vars(closed_vars, enclosing_scope, env, value->ast); break; } - case BinaryOp: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, BinaryOp)->lhs); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, BinaryOp)->rhs); - break; - } - case UpdateAssign: { - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, UpdateAssign)->lhs); - add_closed_vars(closed_vars, enclosing_scope, env, Match(ast, UpdateAssign)->rhs); + case BINOP_CASES: { + binary_operands_t binop = BINARY_OPERANDS(ast); + add_closed_vars(closed_vars, enclosing_scope, env, binop.lhs); + add_closed_vars(closed_vars, enclosing_scope, env, binop.rhs); break; } case Not: case Negative: case HeapAllocate: case StackReference: { @@ -481,6 +489,226 @@ CORD compile_declaration(type_t *t, CORD name) } } +static CORD compile_update_assignment(env_t *env, ast_t *ast) +{ + if (!is_update_assignment(ast)) + code_err(ast, "This is not an update assignment"); + + binary_operands_t update = BINARY_OPERANDS(ast); + + type_t *lhs_t = get_type(env, update.lhs); + + bool needs_idemotency_fix = !is_idempotent(update.lhs); + CORD lhs = needs_idemotency_fix ? "(*lhs)" : compile_lvalue(env, update.lhs); + + CORD update_assignment = CORD_EMPTY; + switch (ast->tag) { + case PlusUpdate: { + if (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType) + update_assignment = CORD_all(lhs, " += ", compile_to_type(env, update.rhs, lhs_t), ";"); + break; + } + case MinusUpdate: { + if (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType) + update_assignment = CORD_all(lhs, " -= ", compile_to_type(env, update.rhs, lhs_t), ";"); + break; + } + case MultiplyUpdate: { + if (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType) + update_assignment = CORD_all(lhs, " *= ", compile_to_type(env, update.rhs, lhs_t), ";"); + break; + } + case DivideUpdate: { + if (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType) + update_assignment = CORD_all(lhs, " /= ", compile_to_type(env, update.rhs, lhs_t), ";"); + break; + } + case LeftShiftUpdate: { + if (lhs_t->tag == IntType || lhs_t->tag == ByteType) + update_assignment = CORD_all(lhs, " <<= ", compile_to_type(env, update.rhs, lhs_t), ";"); + break; + } + case RightShiftUpdate: { + if (lhs_t->tag == IntType || lhs_t->tag == ByteType) + update_assignment = CORD_all(lhs, " >>= ", compile_to_type(env, update.rhs, lhs_t), ";"); + break; + } + case AndUpdate: { + if (lhs_t->tag == BoolType) + update_assignment = CORD_all("if (", lhs, ") ", lhs, " = ", compile_to_type(env, update.rhs, Type(BoolType)), ";"); + break; + } + case OrUpdate: { + if (lhs_t->tag == BoolType) + update_assignment = CORD_all("if (!", lhs, ") ", lhs, " = ", compile_to_type(env, update.rhs, Type(BoolType)), ";"); + break; + } + default: break; + } + + if (update_assignment == CORD_EMPTY) { + ast_t *binop = new(ast_t); + *binop = *ast; + binop->tag = binop_tag(binop->tag); + if (needs_idemotency_fix) + binop->__data.Plus.lhs = WrapAST(update.lhs, InlineCCode, .code="*lhs", .type=lhs_t); + update_assignment = CORD_all(lhs, " = ", compile_to_type(env, binop, lhs_t)); + } + + if (needs_idemotency_fix) + return CORD_all("{ ", compile_declaration(Type(PointerType, .pointed=lhs_t), "lhs"), " = &", compile_lvalue(env, update.lhs), "; ", + update_assignment, "; }"); + else + return update_assignment; +} + +static CORD compile_binary_op(env_t *env, ast_t *ast) +{ + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + type_t *overall_t = get_type(env, ast); + + binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, overall_t); + if (b) { + arg_ast_t *args = new(arg_ast_t, .value=binop.lhs, .next=new(arg_ast_t, .value=binop.rhs)); + auto fn = Match(b->type, FunctionType); + return CORD_all(b->code, "(", compile_arguments(env, ast, fn->args, args), ")"); + } + + if (ast->tag == Or && lhs_t->tag == OptionalType) { + if (is_incomplete_type(rhs_t)) { + type_t *complete = most_complete_type(rhs_t, Match(lhs_t, OptionalType)->type); + if (complete == NULL) + code_err(binop.rhs, "I don't know how to convert a ", type_to_str(rhs_t), " to a ", type_to_str(Match(lhs_t, OptionalType)->type)); + rhs_t = complete; + } + + if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { + return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop.lhs), "; ", + "if (", check_none(lhs_t, "lhs"), ") ", compile_statement(env, binop.rhs), " ", + optional_into_nonnone(lhs_t, "lhs"), "; })"); + } else if (rhs_t->tag == OptionalType && type_eq(lhs_t, rhs_t)) { + return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop.lhs), "; ", + check_none(lhs_t, "lhs"), " ? ", compile(env, binop.rhs), " : lhs; })"); + } else if (rhs_t->tag != OptionalType && type_eq(Match(lhs_t, OptionalType)->type, rhs_t)) { + return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop.lhs), "; ", + check_none(lhs_t, "lhs"), " ? ", compile(env, binop.rhs), " : ", + optional_into_nonnone(lhs_t, "lhs"), "; })"); + } else if (rhs_t->tag == BoolType) { + return CORD_all("((!", check_none(lhs_t, compile(env, binop.lhs)), ") || ", compile(env, binop.rhs), ")"); + } else { + code_err(ast, "I don't know how to do an 'or' operation between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + } + + CORD lhs = compile_to_type(env, binop.lhs, overall_t); + CORD rhs = compile_to_type(env, binop.rhs, overall_t); + + switch (ast->tag) { + case Power: { + if (overall_t->tag != NumType) + code_err(ast, "Exponentiation is only supported for Num types, not ", type_to_str(overall_t)); + if (overall_t->tag == NumType && Match(overall_t, NumType)->bits == TYPE_NBITS32) + return CORD_all("powf(", lhs, ", ", rhs, ")"); + else + return CORD_all("pow(", lhs, ", ", rhs, ")"); + } + case Multiply: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " * ", rhs, ")"); + } + case Divide: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " / ", rhs, ")"); + } + case Mod: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " % ", rhs, ")"); + } + case Mod1: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("((((", lhs, ")-1) % (", rhs, ")) + 1)"); + } + case Plus: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " + ", rhs, ")"); + } + case Minus: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " - ", rhs, ")"); + } + case LeftShift: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " << ", rhs, ")"); + } + case RightShift: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", lhs, " >> ", rhs, ")"); + } + case UnsignedLeftShift: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", compile_type(overall_t), ")((", compile_unsigned_type(lhs_t), ")", lhs, " << ", rhs, ")"); + } + case UnsignedRightShift: { + if (overall_t->tag != IntType && overall_t->tag != NumType && overall_t->tag != ByteType) + code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + return CORD_all("(", compile_type(overall_t), ")((", compile_unsigned_type(lhs_t), ")", lhs, " >> ", rhs, ")"); + } + case And: { + if (overall_t->tag == BoolType) + return CORD_all("(", lhs, " && ", rhs, ")"); + else if (overall_t->tag == IntType || overall_t->tag == ByteType) + return CORD_all("(", lhs, " & ", rhs, ")"); + else + code_err(ast, "The 'and' operator isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); + } + case Compare: { + return CORD_all("generic_compare(stack(", lhs, "), stack(", rhs, "), ", compile_type_info(overall_t), ")"); + } + case Or: { + if (overall_t->tag == BoolType) { + return CORD_all("(", lhs, " || ", rhs, ")"); + } else if (overall_t->tag == IntType || overall_t->tag == ByteType) { + return CORD_all("(", lhs, " | ", rhs, ")"); + } else { + code_err(ast, "The 'or' operator isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); + } + } + case Xor: { + // TODO: support optional values in `xor` expressions + if (overall_t->tag == BoolType || overall_t->tag == IntType || overall_t->tag == ByteType) + return CORD_all("(", lhs, " ^ ", rhs, ")"); + else + code_err(ast, "The 'xor' operator isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); + } + case Concat: { + if (overall_t == PATH_TYPE) + return CORD_all("Path$concat(", lhs, ", ", rhs, ")"); + switch (overall_t->tag) { + case TextType: { + return CORD_all("Text$concat(", lhs, ", ", rhs, ")"); + } + case ArrayType: { + return CORD_all("Array$concat(", lhs, ", ", rhs, ", sizeof(", compile_type(Match(overall_t, ArrayType)->item_type), "))"); + } + default: + code_err(ast, "Concatenation isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); + } + } + default: errx(1, "Not a valid binary operation: ", ast_to_xml_str(ast)); + } +} + PUREFUNC CORD compile_unsigned_type(type_t *t) { if (t->tag != IntType) @@ -570,7 +798,7 @@ CORD compile_type(type_t *t) } } -static CORD compile_lvalue(env_t *env, ast_t *ast) +CORD compile_lvalue(env_t *env, ast_t *ast) { if (!can_be_mutated(env, ast)) { if (ast->tag == Index) { @@ -756,7 +984,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast) CORD code = CORD_EMPTY; for (when_clause_t *clause = when->clauses; clause; clause = clause->next) { - ast_t *comparison = WrapAST(clause->pattern, BinaryOp, .lhs=subject, .op=BINOP_EQ, .rhs=clause->pattern); + ast_t *comparison = WrapAST(clause->pattern, Equals, .lhs=subject, .rhs=clause->pattern); (void)get_type(env, comparison); if (code != CORD_EMPTY) code = CORD_all(code, "else "); @@ -865,7 +1093,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast) if (streq(varname, "_")) return compile_statement(env, WrapAST(ast, DocTest, .expr=decl->value, .expected=test->expected, .skip_source=test->skip_source)); CORD var = CORD_all("_$", Match(decl->var, Var)->name); - type_t *t = get_type(env, decl->value); + type_t *t = decl->type ? parse_type_ast(env, decl->type) : get_type(env, decl->value); if (!t) code_err(decl->value, "I couldn't figure out the type of this value!"); CORD val_code = compile_maybe_incref(env, decl->value, t); if (t->tag == FunctionType) { @@ -916,21 +1144,20 @@ static CORD _compile_statement(env_t *env, ast_t *ast) test_code = CORD_all(test_code, "$1; })"); } - } else if (test->expr->tag == UpdateAssign) { - type_t *lhs_t = get_type(env, Match(test->expr, UpdateAssign)->lhs); - auto update = Match(test->expr, UpdateAssign); - - if (update->lhs->tag == Index) { - type_t *indexed = value_type(get_type(env, Match(update->lhs, Index)->indexed)); + } else if (is_update_assignment(test->expr)) { + binary_operands_t update = BINARY_OPERANDS(test->expr); + type_t *lhs_t = get_type(env, update.lhs); + if (update.lhs->tag == Index) { + type_t *indexed = value_type(get_type(env, Match(update.lhs, Index)->indexed)); if (indexed->tag == TableType && Match(indexed, TableType)->default_value == NULL) - code_err(update->lhs, "Update assignments are not currently supported for tables"); + code_err(update.lhs, "Update assignments are not currently supported for tables"); } - ast_t *update_var = WrapAST(ast, UpdateAssign, - .lhs=WrapAST(update->lhs, InlineCCode, .code="(*expr)", .type=lhs_t), - .op=update->op, .rhs=update->rhs); + ast_t *update_var = new(ast_t); + *update_var = *ast; + update_var->__data.PlusUpdate.lhs = WrapAST(update.lhs, InlineCCode, .code="(*expr)", .type=lhs_t); // UNSAFE test_code = CORD_all("({", - compile_declaration(Type(PointerType, lhs_t), "expr"), " = &(", compile_lvalue(env, update->lhs), "); ", + compile_declaration(Type(PointerType, lhs_t), "expr"), " = &(", compile_lvalue(env, update.lhs), "); ", compile_statement(env, update_var), "; *expr; })"); expr_t = lhs_t; } else if (expr_t->tag == VoidType || expr_t->tag == AbortType || expr_t->tag == ReturnType) { @@ -939,14 +1166,10 @@ static CORD _compile_statement(env_t *env, ast_t *ast) test_code = compile(env, test->expr); } if (test->expected) { - type_t *expected_type = get_type(env, test->expected); - if (!type_eq(expr_t, expected_type)) - code_err(ast, "The type on the top of this test (", type_to_str(expr_t), - ") is different from the type on the bottom (", type_to_str(expected_type), ")"); return CORD_asprintf( "%rtest(%r, %r, %r, %ld, %ld);", setup, test_code, - compile(env, test->expected), + compile_to_type(env, test->expected, expr_t), compile_type_info(expr_t), (int64_t)(test->expr->start - test->expr->file->text), (int64_t)(test->expr->end - test->expr->file->text)); @@ -965,7 +1188,7 @@ static CORD _compile_statement(env_t *env, ast_t *ast) if (streq(name, "_")) { // Explicit discard return CORD_all("(void)", compile(env, decl->value), ";"); } else { - type_t *t = get_type(env, decl->value); + type_t *t = decl->type ? parse_type_ast(env, decl->type) : get_type(env, decl->value); if (t->tag == AbortType || t->tag == VoidType || t->tag == ReturnType) code_err(ast, "You can't declare a variable with a ", type_to_str(t), " value"); @@ -1011,155 +1234,44 @@ static CORD _compile_statement(env_t *env, ast_t *ast) } return CORD_cat(code, "\n}"); } - case UpdateAssign: { - auto update = Match(ast, UpdateAssign); - - if (update->lhs->tag == Index) { - type_t *indexed = value_type(get_type(env, Match(update->lhs, Index)->indexed)); - if (indexed->tag == TableType && Match(indexed, TableType)->default_value == NULL) - code_err(update->lhs, "Update assignments are not currently supported for tables"); - } - - if (!is_idempotent(update->lhs)) { - type_t *lhs_t = get_type(env, update->lhs); - return CORD_all("{ ", compile_declaration(Type(PointerType, lhs_t), "update_lhs"), " = &", - compile_lvalue(env, update->lhs), ";\n", - "*update_lhs = ", compile(env, WrapAST(ast, BinaryOp, - .lhs=WrapAST(update->lhs, InlineCCode, .code="(*update_lhs)", .type=lhs_t), - .op=update->op, .rhs=update->rhs)), "; }"); - } - + case PlusUpdate: { + auto update = Match(ast, PlusUpdate); type_t *lhs_t = get_type(env, update->lhs); - CORD lhs = compile_lvalue(env, update->lhs); - - if (update->lhs->tag == Index && value_type(get_type(env, Match(update->lhs, Index)->indexed))->tag == TableType) { - ast_t *lhs_placeholder = WrapAST(update->lhs, InlineCCode, .code="(*lhs)", .type=lhs_t); - CORD method_call = compile_math_method(env, update->op, lhs_placeholder, update->rhs, lhs_t); - if (method_call) - return CORD_all("{ ", compile_declaration(Type(PointerType, .pointed=lhs_t), "lhs"), " = &", lhs, "; *lhs = ", method_call, "; }"); - } else { - CORD method_call = compile_math_method(env, update->op, update->lhs, update->rhs, lhs_t); - if (method_call) - return CORD_all(lhs, " = ", method_call, ";"); - } - - CORD rhs = compile(env, update->rhs); - - type_t *rhs_t = get_type(env, update->rhs); - if (update->rhs->tag == Int && is_numeric_type(non_optional(lhs_t))) { - rhs = compile_int_to_type(env, update->rhs, lhs_t); - } else if (!promote(env, update->rhs, &rhs, rhs_t, lhs_t)) { - code_err(ast, "I can't do operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - } - - bool lhs_is_optional_num = (lhs_t->tag == OptionalType && Match(lhs_t, OptionalType)->type && Match(lhs_t, OptionalType)->type->tag == NumType); - switch (update->op) { - case BINOP_MULT: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) - code_err(ast, "I can't do a multiply assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - if (lhs_t->tag == NumType) { // 0*INF -> NaN, needs checking - return CORD_asprintf("%r *= %r;\n" - "if (isnan(%r))\n" - "fail_source(%r, %ld, %ld, \"This update assignment created a NaN value (probably multiplying zero with infinity), but the type is not optional!\");\n", - lhs, rhs, lhs, - CORD_quoted(ast->file->filename), - (long)(ast->start - ast->file->text), - (long)(ast->end - ast->file->text)); - } - return CORD_all(lhs, " *= ", rhs, ";"); - case BINOP_DIVIDE: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) - code_err(ast, "I can't do a divide assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - if (lhs_t->tag == NumType) { // 0/0 or INF/INF -> NaN, needs checking - return CORD_asprintf("%r /= %r;\n" - "if (isnan(%r))\n" - "fail_source(%r, %ld, %ld, \"This update assignment created a NaN value (probably 0/0 or INF/INF), but the type is not optional!\");\n", - lhs, rhs, lhs, - CORD_quoted(ast->file->filename), - (long)(ast->start - ast->file->text), - (long)(ast->end - ast->file->text)); - } - return CORD_all(lhs, " /= ", rhs, ";"); - case BINOP_MOD: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) - code_err(ast, "I can't do a mod assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " = ", lhs, " % ", rhs); - case BINOP_MOD1: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) - code_err(ast, "I can't do a mod assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " = (((", lhs, ") - 1) % ", rhs, ") + 1;"); - case BINOP_PLUS: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) - code_err(ast, "I can't do an addition assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " += ", rhs, ";"); - case BINOP_MINUS: - if (lhs_t->tag != IntType && lhs_t->tag != NumType && lhs_t->tag != ByteType && !lhs_is_optional_num) - code_err(ast, "I can't do a subtraction assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " -= ", rhs, ";"); - case BINOP_POWER: { - if (lhs_t->tag != NumType && !lhs_is_optional_num) - code_err(ast, "'^=' is only supported for Num types"); - if (lhs_t->tag == NumType && Match(lhs_t, NumType)->bits == TYPE_NBITS32) - return CORD_all(lhs, " = powf(", lhs, ", ", rhs, ");"); - else - return CORD_all(lhs, " = pow(", lhs, ", ", rhs, ");"); - } - case BINOP_LSHIFT: - if (lhs_t->tag != IntType && lhs_t->tag != ByteType) - code_err(ast, "I can't do a shift assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " <<= ", rhs, ";"); - case BINOP_RSHIFT: - if (lhs_t->tag != IntType && lhs_t->tag != ByteType) - code_err(ast, "I can't do a shift assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " >>= ", rhs, ";"); - case BINOP_ULSHIFT: - if (lhs_t->tag != IntType && lhs_t->tag != ByteType) - code_err(ast, "I can't do a shift assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("{ ", compile_unsigned_type(lhs_t), " *dest = (void*)&(", lhs, "); *dest <<= ", rhs, "; }"); - case BINOP_URSHIFT: - if (lhs_t->tag != IntType && lhs_t->tag != ByteType) - code_err(ast, "I can't do a shift assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("{ ", compile_unsigned_type(lhs_t), " *dest = (void*)&(", lhs, "); *dest >>= ", rhs, "; }"); - case BINOP_AND: { - if (lhs_t->tag == BoolType) - return CORD_all("if (", lhs, ") ", lhs, " = ", rhs, ";"); - else if (lhs_t->tag == IntType || lhs_t->tag == ByteType) - return CORD_all(lhs, " &= ", rhs, ";"); - else if (lhs_t->tag == OptionalType) - return CORD_all("if (!(", check_none(lhs_t, lhs), ")) ", lhs, " = ", promote_to_optional(rhs_t, rhs), ";"); - else - code_err(ast, "'or=' is not implemented for ", type_to_str(lhs_t), " types"); - } - case BINOP_OR: { - if (lhs_t->tag == BoolType) - return CORD_all("if (!(", lhs, ")) ", lhs, " = ", rhs, ";"); - else if (lhs_t->tag == IntType || lhs_t->tag == ByteType) - return CORD_all(lhs, " |= ", rhs, ";"); - else if (lhs_t->tag == OptionalType) - return CORD_all("if (", check_none(lhs_t, lhs), ") ", lhs, " = ", promote_to_optional(rhs_t, rhs), ";"); - else - code_err(ast, "'or=' is not implemented for ", type_to_str(lhs_t), " types"); - } - case BINOP_XOR: - if (lhs_t->tag != IntType && lhs_t->tag != BoolType && lhs_t->tag != ByteType) - code_err(ast, "I can't do an xor assignment with this operator between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all(lhs, " ^= ", rhs, ";"); - case BINOP_CONCAT: { - if (lhs_t->tag == TextType) { - return CORD_all(lhs, " = Texts(", lhs, ", ", rhs, ");"); - } else if (lhs_t->tag == ArrayType) { - CORD padded_item_size = CORD_all("sizeof(", compile_type(Match(lhs_t, ArrayType)->item_type), ")"); - // arr ++= [...] - if (update->lhs->tag == Var) - return CORD_all("Array$insert_all(&", lhs, ", ", rhs, ", I(0), ", padded_item_size, ");"); - else - return CORD_all(lhs, " = Array$concat(", lhs, ", ", rhs, ", ", padded_item_size, ");"); - } else { - code_err(ast, "'++=' is not implemented for ", type_to_str(lhs_t), " types"); - } - } - default: code_err(ast, "Update assignments are not implemented for this operation"); - } + if (is_idempotent(update->lhs) && (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType)) + return CORD_all(compile_lvalue(env, update->lhs), " += ", compile_to_type(env, update->rhs, lhs_t), ";"); + return compile_update_assignment(env, ast); + } + case MinusUpdate: { + auto update = Match(ast, MinusUpdate); + type_t *lhs_t = get_type(env, update->lhs); + if (is_idempotent(update->lhs) && (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType)) + return CORD_all(compile_lvalue(env, update->lhs), " -= ", compile_to_type(env, update->rhs, lhs_t), ";"); + return compile_update_assignment(env, ast); + } + case MultiplyUpdate: { + auto update = Match(ast, MultiplyUpdate); + type_t *lhs_t = get_type(env, update->lhs); + if (is_idempotent(update->lhs) && (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType)) + return CORD_all(compile_lvalue(env, update->lhs), " *= ", compile_to_type(env, update->rhs, lhs_t), ";"); + return compile_update_assignment(env, ast); + } + case DivideUpdate: { + auto update = Match(ast, DivideUpdate); + type_t *lhs_t = get_type(env, update->lhs); + if (is_idempotent(update->lhs) && (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType)) + return CORD_all(compile_lvalue(env, update->lhs), " /= ", compile_to_type(env, update->rhs, lhs_t), ";"); + return compile_update_assignment(env, ast); + } + case ModUpdate: { + auto update = Match(ast, ModUpdate); + type_t *lhs_t = get_type(env, update->lhs); + if (is_idempotent(update->lhs) && (lhs_t->tag == IntType || lhs_t->tag == NumType || lhs_t->tag == ByteType)) + return CORD_all(compile_lvalue(env, update->lhs), " %= ", compile_to_type(env, update->rhs, lhs_t), ";"); + return compile_update_assignment(env, ast); + } + case PowerUpdate: case Mod1Update: case ConcatUpdate: case LeftShiftUpdate: case UnsignedLeftShiftUpdate: + case RightShiftUpdate: case UnsignedRightShiftUpdate: case AndUpdate: case OrUpdate: case XorUpdate: { + return compile_update_assignment(env, ast); } case StructDef: case EnumDef: case LangDef: case Extend: case FunctionDef: case ConvertDef: { return CORD_EMPTY; @@ -1839,22 +1951,20 @@ CORD compile_to_type(env_t *env, ast_t *ast, type_t *t) case TYPE_NBITS32: return CORD_asprintf("N32(%.10g)", n); default: code_err(ast, "This is not a valid number bit width"); } - } else if (ast->tag == None && Match(ast, None)->type == NULL) { + } else if (ast->tag == None) { return compile_none(t); - } else if (t->tag == ArrayType && ast->tag == Array && !Match(ast, Array)->item_type && !Match(ast, Array)->items) { - return compile(env, ast); + } else if (t->tag == PointerType && (ast->tag == HeapAllocate || ast->tag == StackReference)) { + return compile_typed_allocation(env, ast, t); + } else if (t->tag == ArrayType && ast->tag == Array) { + return compile_typed_array(env, ast, t); } else if (t->tag == TableType && ast->tag == Table) { - auto table = Match(ast, Table); - if (!table->key_type && !table->value_type && !table->default_value && !table->fallback && !table->entries) - return compile(env, ast); + return compile_typed_table(env, ast, t); } else if (t->tag == SetType && ast->tag == Set) { - auto set = Match(ast, Set); - if (!set->item_type && !set->items) - return compile(env, ast); + return compile_typed_set(env, ast, t); } else if (t->tag == SetType && ast->tag == Table) { auto table = Match(ast, Table); - if (!table->key_type && !table->value_type && !table->default_value && !table->fallback && !table->entries) - return compile(env, ast); + if (!table->default_value && !table->fallback && !table->entries) + return compile_to_type(env, WrapAST(ast, Set), t); } type_t *actual = get_type(env, ast); @@ -1869,6 +1979,193 @@ CORD compile_to_type(env_t *env, ast_t *ast, type_t *t) return code; } +CORD compile_typed_array(env_t *env, ast_t *ast, type_t *array_type) +{ + auto array = Match(ast, Array); + if (!array->items) + return "(Array_t){.length=0}"; + + type_t *item_type = Match(array_type, ArrayType)->item_type; + + int64_t n = 0; + for (ast_list_t *item = array->items; item; item = item->next) { + ++n; + if (item->ast->tag == Comprehension) + goto array_comprehension; + } + + { + env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : env; + CORD code = CORD_all("TypedArrayN(", compile_type(item_type), CORD_asprintf(", %ld", n)); + for (ast_list_t *item = array->items; item; item = item->next) { + code = CORD_all(code, ", ", compile_to_type(scope, item->ast, item_type)); + } + return CORD_cat(code, ")"); + } + + array_comprehension: + { + env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : fresh_scope(env); + static int64_t comp_num = 1; + const char *comprehension_name = String("arr$", comp_num++); + ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), + .type=Type(PointerType, .pointed=array_type, .is_stack=true)); + Closure_t comp_action = {.fn=add_to_array_comprehension, .userdata=comprehension_var}; + scope->comprehension_action = &comp_action; + CORD code = CORD_all("({ Array_t ", comprehension_name, " = {};"); + // set_binding(scope, comprehension_name, array_type, comprehension_name); + for (ast_list_t *item = array->items; item; item = item->next) { + if (item->ast->tag == Comprehension) + code = CORD_all(code, "\n", compile_statement(scope, item->ast)); + else + code = CORD_all(code, compile_statement(env, add_to_array_comprehension(item->ast, comprehension_var))); + } + code = CORD_all(code, " ", comprehension_name, "; })"); + return code; + } +} + +CORD compile_typed_set(env_t *env, ast_t *ast, type_t *set_type) +{ + auto set = Match(ast, Set); + if (!set->items) + return "((Table_t){})"; + + type_t *item_type = Match(set_type, SetType)->item_type; + + size_t n = 0; + for (ast_list_t *item = set->items; item; item = item->next) { + ++n; + if (item->ast->tag == Comprehension) + goto set_comprehension; + } + + { // No comprehension: + CORD code = CORD_all("Set(", + compile_type(item_type), ", ", + compile_type_info(item_type)); + CORD_appendf(&code, ", %zu", n); + env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : env; + for (ast_list_t *item = set->items; item; item = item->next) { + code = CORD_all(code, ", ", compile_to_type(scope, item->ast, item_type)); + } + return CORD_cat(code, ")"); + } + + set_comprehension: + { + static int64_t comp_num = 1; + env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : fresh_scope(env); + const char *comprehension_name = String("set$", comp_num++); + ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), + .type=Type(PointerType, .pointed=set_type, .is_stack=true)); + CORD code = CORD_all("({ Table_t ", comprehension_name, " = {};"); + Closure_t comp_action = {.fn=add_to_set_comprehension, .userdata=comprehension_var}; + scope->comprehension_action = &comp_action; + for (ast_list_t *item = set->items; item; item = item->next) { + if (item->ast->tag == Comprehension) + code = CORD_all(code, "\n", compile_statement(scope, item->ast)); + else + code = CORD_all(code, compile_statement(env, add_to_set_comprehension(item->ast, comprehension_var))); + } + code = CORD_all(code, " ", comprehension_name, "; })"); + return code; + } +} + +CORD compile_typed_table(env_t *env, ast_t *ast, type_t *table_type) +{ + auto table = Match(ast, Table); + if (!table->entries) { + CORD code = "((Table_t){"; + if (table->fallback) + code = CORD_all(code, ".fallback=heap(", compile(env, table->fallback),")"); + return CORD_cat(code, "})"); + } + + type_t *key_t = Match(table_type, TableType)->key_type; + type_t *value_t = Match(table_type, TableType)->value_type; + + if (value_t->tag == OptionalType) + code_err(ast, "Tables whose values are optional (", type_to_str(value_t), ") are not currently supported."); + + for (ast_list_t *entry = table->entries; entry; entry = entry->next) { + if (entry->ast->tag == Comprehension) + goto table_comprehension; + } + + { // No comprehension: + env_t *key_scope = key_t->tag == EnumType ? with_enum_scope(env, key_t) : env; + env_t *value_scope = value_t->tag == EnumType ? with_enum_scope(env, value_t) : env; + CORD code = CORD_all("Table(", + compile_type(key_t), ", ", + compile_type(value_t), ", ", + compile_type_info(key_t), ", ", + compile_type_info(value_t)); + if (table->fallback) + code = CORD_all(code, ", /*fallback:*/ heap(", compile(env, table->fallback), ")"); + else + code = CORD_all(code, ", /*fallback:*/ NULL"); + + size_t n = 0; + for (ast_list_t *entry = table->entries; entry; entry = entry->next) + ++n; + CORD_appendf(&code, ", %zu", n); + + for (ast_list_t *entry = table->entries; entry; entry = entry->next) { + auto e = Match(entry->ast, TableEntry); + code = CORD_all(code, ",\n\t{", compile_to_type(key_scope, e->key, key_t), ", ", + compile_to_type(value_scope, e->value, value_t), "}"); + } + return CORD_cat(code, ")"); + } + + table_comprehension: + { + static int64_t comp_num = 1; + env_t *scope = fresh_scope(env); + const char *comprehension_name = String("table$", comp_num++); + ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), + .type=Type(PointerType, .pointed=table_type, .is_stack=true)); + + CORD code = CORD_all("({ Table_t ", comprehension_name, " = {"); + if (table->fallback) + code = CORD_all(code, ".fallback=heap(", compile(env, table->fallback), "), "); + + code = CORD_cat(code, "};"); + + Closure_t comp_action = {.fn=add_to_table_comprehension, .userdata=comprehension_var}; + scope->comprehension_action = &comp_action; + for (ast_list_t *entry = table->entries; entry; entry = entry->next) { + if (entry->ast->tag == Comprehension) + code = CORD_all(code, "\n", compile_statement(scope, entry->ast)); + else + code = CORD_all(code, compile_statement(env, add_to_table_comprehension(entry->ast, comprehension_var))); + } + code = CORD_all(code, " ", comprehension_name, "; })"); + return code; + } +} + +CORD compile_typed_allocation(env_t *env, ast_t *ast, type_t *pointer_type) +{ + // TODO: for constructors, do new(T, ...) instead of heap((T){...}) + type_t *pointed = Match(pointer_type, PointerType)->pointed; + switch (ast->tag) { + case HeapAllocate: { + return CORD_asprintf("heap(%r)", compile_to_type(env, Match(ast, HeapAllocate)->value, pointed)); + } + case StackReference: { + ast_t *subject = Match(ast, StackReference)->value; + if (can_be_mutated(env, subject) && type_eq(pointed, get_type(env, subject))) + return CORD_all("(&", compile_lvalue(env, subject), ")"); + else + return CORD_all("stack(", compile_to_type(env, subject, pointed), ")"); + } + default: code_err(ast, "Not an allocation!"); + } +} + CORD compile_int_to_type(env_t *env, ast_t *ast, type_t *target) { if (ast->tag != Int) { @@ -2027,98 +2324,6 @@ CORD compile_arguments(env_t *env, ast_t *call_ast, arg_t *spec_args, arg_ast_t return code; } -CORD compile_math_method(env_t *env, binop_e op, ast_t *lhs, ast_t *rhs, type_t *required_type) -{ - // Math methods are things like plus(), minus(), etc. If we don't find a - // matching method, return CORD_EMPTY. - const char *method_name = binop_method_names[op]; - if (!method_name) - return CORD_EMPTY; - - type_t *lhs_t = get_type(env, lhs); - type_t *rhs_t = get_type(env, rhs); -#define binding_works(b, lhs_t, rhs_t, ret_t) \ - (b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \ - (type_eq(fn->ret, ret_t) \ - && (fn->args && type_eq(fn->args->type, lhs_t)) \ - && (fn->args->next && can_promote(rhs_t, fn->args->next->type)) \ - && (!required_type || type_eq(required_type, fn->ret))); })) - arg_ast_t *args = new(arg_ast_t, .value=lhs, .next=new(arg_ast_t, .value=rhs)); - switch (op) { - case BINOP_MULT: { - if (type_eq(lhs_t, rhs_t)) { - binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); - if (binding_works(b, lhs_t, rhs_t, lhs_t)) - return CORD_all(b->code, "(", compile_arguments(env, lhs, Match(b->type, FunctionType)->args, args), ")"); - } else if (lhs_t->tag == NumType || lhs_t->tag == IntType || lhs_t->tag == BigIntType) { - binding_t *b = get_namespace_binding(env, rhs, "scaled_by"); - if (binding_works(b, rhs_t, lhs_t, rhs_t)) { - REVERSE_LIST(args); - return CORD_all(b->code, "(", compile_arguments(env, lhs, Match(b->type, FunctionType)->args, args), ")"); - } - } else if (rhs_t->tag == NumType || rhs_t->tag == IntType|| rhs_t->tag == BigIntType) { - binding_t *b = get_namespace_binding(env, lhs, "scaled_by"); - if (binding_works(b, lhs_t, rhs_t, lhs_t)) - return CORD_all(b->code, "(", compile_arguments(env, lhs, Match(b->type, FunctionType)->args, args), ")"); - } - break; - } - case BINOP_OR: case BINOP_CONCAT: { - if (lhs_t->tag == SetType) { - return CORD_all("Table$with(", compile(env, lhs), ", ", compile(env, rhs), ", ", compile_type_info(lhs_t), ")"); - } - goto fallthrough; - } - case BINOP_AND: { - if (lhs_t->tag == SetType) { - return CORD_all("Table$overlap(", compile(env, lhs), ", ", compile(env, rhs), ", ", compile_type_info(lhs_t), ")"); - } - goto fallthrough; - } - case BINOP_MINUS: { - if (lhs_t->tag == SetType) { - return CORD_all("Table$without(", compile(env, lhs), ", ", compile(env, rhs), ", ", compile_type_info(lhs_t), ")"); - } - goto fallthrough; - } - case BINOP_PLUS: case BINOP_XOR: { - fallthrough: - if (type_eq(lhs_t, rhs_t)) { - binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); - if (binding_works(b, lhs_t, rhs_t, lhs_t)) - return CORD_all(b->code, "(", compile(env, lhs), ", ", compile(env, rhs), ")"); - } - break; - } - case BINOP_DIVIDE: case BINOP_MOD: case BINOP_MOD1: { - if (is_numeric_type(rhs_t)) { - binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); - if (binding_works(b, lhs_t, rhs_t, lhs_t)) - return CORD_all(b->code, "(", compile_arguments(env, lhs, Match(b->type, FunctionType)->args, args), ")"); - } - break; - } - case BINOP_LSHIFT: case BINOP_RSHIFT: case BINOP_ULSHIFT: case BINOP_URSHIFT: { - if (rhs_t->tag == IntType || rhs_t->tag == BigIntType) { - binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); - if (binding_works(b, lhs_t, rhs_t, lhs_t)) - return CORD_all(b->code, "(", compile_arguments(env, lhs, Match(b->type, FunctionType)->args, args), ")"); - } - break; - } - case BINOP_POWER: { - if (rhs_t->tag == NumType || rhs_t->tag == IntType || rhs_t->tag == BigIntType) { - binding_t *b = get_namespace_binding(env, lhs, binop_method_names[op]); - if (binding_works(b, lhs_t, rhs_t, lhs_t)) - return CORD_all(b->code, "(", compile_arguments(env, lhs, Match(b->type, FunctionType)->args, args), ")"); - } - break; - } - default: break; - } - return CORD_EMPTY; -} - CORD compile_string_literal(CORD literal) { CORD code = "\""; @@ -2209,19 +2414,19 @@ CORD compile_none(type_t *t) } } -static ast_t *add_to_table_comprehension(ast_t *entry, ast_t *subject) +ast_t *add_to_table_comprehension(ast_t *entry, ast_t *subject) { auto e = Match(entry, TableEntry); return WrapAST(entry, MethodCall, .name="set", .self=subject, .args=new(arg_ast_t, .value=e->key, .next=new(arg_ast_t, .value=e->value))); } -static ast_t *add_to_array_comprehension(ast_t *item, ast_t *subject) +ast_t *add_to_array_comprehension(ast_t *item, ast_t *subject) { return WrapAST(item, MethodCall, .name="insert", .self=subject, .args=new(arg_ast_t, .value=item)); } -static ast_t *add_to_set_comprehension(ast_t *item, ast_t *subject) +ast_t *add_to_set_comprehension(ast_t *item, ast_t *subject) { return WrapAST(item, MethodCall, .name="add", .self=subject, .args=new(arg_ast_t, .value=item)); } @@ -2230,10 +2435,7 @@ CORD compile(env_t *env, ast_t *ast) { switch (ast->tag) { case None: { - if (!Match(ast, None)->type) - code_err(ast, "This 'none' needs to specify what type it is using `none:Type` syntax"); - type_t *t = parse_type_ast(env, Match(ast, None)->type); - return compile_none(t); + code_err(ast, "This 'none' needs to specify what type it is using `none:Type` syntax"); } case Bool: return Match(ast, Bool)->b ? "yes" : "no"; case Var: { @@ -2303,14 +2505,8 @@ CORD compile(env_t *env, ast_t *ast) code_err(ast, "I don't know how to get the negative value of type ", type_to_str(t)); } - // TODO: for constructors, do new(T, ...) instead of heap((T){...}) - case HeapAllocate: return CORD_asprintf("heap(%r)", compile(env, Match(ast, HeapAllocate)->value)); - case StackReference: { - ast_t *subject = Match(ast, StackReference)->value; - if (can_be_mutated(env, subject)) - return CORD_all("(&", compile_lvalue(env, subject), ")"); - else - return CORD_all("stack(", compile(env, subject), ")"); + case HeapAllocate: case StackReference: { + return compile_typed_allocation(env, ast, get_type(env, ast)); } case Optional: { ast_t *value = Match(ast, Optional)->value; @@ -2329,264 +2525,67 @@ CORD compile(env_t *env, ast_t *ast) (long)(value->end - value->file->text)), optional_into_nonnone(t, "opt"), "; })"); } - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - CORD method_call = compile_math_method(env, binop->op, binop->lhs, binop->rhs, NULL); - if (method_call != CORD_EMPTY) - return method_call; - - type_t *lhs_t = get_type(env, binop->lhs); - type_t *rhs_t = get_type(env, binop->rhs); - - if (binop->op == BINOP_OR && lhs_t->tag == OptionalType) { - if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { - return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop->lhs), "; ", - "if (", check_none(lhs_t, "lhs"), ") ", compile_statement(env, binop->rhs), " ", - optional_into_nonnone(lhs_t, "lhs"), "; })"); - } else if (rhs_t->tag == OptionalType && type_eq(lhs_t, rhs_t)) { - return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop->lhs), "; ", - check_none(lhs_t, "lhs"), " ? ", compile(env, binop->rhs), " : lhs; })"); - } else if (rhs_t->tag != OptionalType && type_eq(Match(lhs_t, OptionalType)->type, rhs_t)) { - return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop->lhs), "; ", - check_none(lhs_t, "lhs"), " ? ", compile(env, binop->rhs), " : ", - optional_into_nonnone(lhs_t, "lhs"), "; })"); - } else if (rhs_t->tag == BoolType) { - return CORD_all("((!", check_none(lhs_t, compile(env, binop->lhs)), ") || ", compile(env, binop->rhs), ")"); - } else { - code_err(ast, "I don't know how to do an 'or' operation between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - } - } else if (binop->op == BINOP_AND && lhs_t->tag == OptionalType) { - if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { - return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop->lhs), "; ", - "if (!", check_none(lhs_t, "lhs"), ") ", compile_statement(env, binop->rhs), " ", - optional_into_nonnone(lhs_t, "lhs"), "; })"); - } else if (rhs_t->tag == OptionalType && type_eq(lhs_t, rhs_t)) { - return CORD_all("({ ", compile_declaration(lhs_t, "lhs"), " = ", compile(env, binop->lhs), "; ", - check_none(lhs_t, "lhs"), " ? lhs : ", compile(env, binop->rhs), "; })"); - } else if (rhs_t->tag == BoolType) { - return CORD_all("((!", check_none(lhs_t, compile(env, binop->lhs)), ") && ", compile(env, binop->rhs), ")"); - } else { - code_err(ast, "I don't know how to do an 'or' operation between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - } - } - - type_t *non_optional_lhs = lhs_t; - if (lhs_t->tag == OptionalType) non_optional_lhs = Match(lhs_t, OptionalType)->type; - type_t *non_optional_rhs = rhs_t; - if (rhs_t->tag == OptionalType) non_optional_rhs = Match(rhs_t, OptionalType)->type; - - if (!non_optional_lhs && !non_optional_rhs) - code_err(ast, "Both of these values do not specify a type"); - else if (!non_optional_lhs) - non_optional_lhs = non_optional_rhs; - else if (!non_optional_rhs) - non_optional_rhs = non_optional_lhs; - - bool lhs_is_optional_num = (lhs_t->tag == OptionalType && non_optional_lhs->tag == NumType); - if (lhs_is_optional_num) - lhs_t = Match(lhs_t, OptionalType)->type; - bool rhs_is_optional_num = (rhs_t->tag == OptionalType && non_optional_rhs->tag == NumType); - if (rhs_is_optional_num) - rhs_t = Match(rhs_t, OptionalType)->type; - - CORD lhs, rhs; - if (lhs_t->tag == BigIntType && rhs_t->tag != BigIntType && is_numeric_type(rhs_t) && binop->lhs->tag == Int) { - lhs = compile_int_to_type(env, binop->lhs, rhs_t); - lhs_t = rhs_t; - rhs = compile(env, binop->rhs); - } else if (rhs_t->tag == BigIntType && lhs_t->tag != BigIntType && is_numeric_type(lhs_t) && binop->rhs->tag == Int) { - lhs = compile(env, binop->lhs); - rhs = compile_int_to_type(env, binop->rhs, lhs_t); - rhs_t = lhs_t; - } else { - lhs = compile(env, binop->lhs); - rhs = compile(env, binop->rhs); - } + case Power: case Multiply: case Divide: case Mod: case Mod1: case Plus: case Minus: case Concat: + case LeftShift: case UnsignedLeftShift: case RightShift: case UnsignedRightShift: case And: case Or: case Xor: { + return compile_binary_op(env, ast); + } + case Equals: case NotEquals: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); type_t *operand_t; - if (promote(env, binop->rhs, &rhs, rhs_t, lhs_t)) + CORD lhs, rhs; + if (can_compile_to_type(env, binop.rhs, lhs_t)) { + lhs = compile(env, binop.lhs); + rhs = compile_to_type(env, binop.rhs, lhs_t); operand_t = lhs_t; - else if (promote(env, binop->lhs, &lhs, lhs_t, rhs_t)) + } else if (can_compile_to_type(env, binop.lhs, rhs_t)) { + rhs = compile(env, binop.rhs); + lhs = compile_to_type(env, binop.lhs, rhs_t); operand_t = rhs_t; - else - code_err(ast, "I can't do operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } else { + code_err(ast, "I can't do comparisons between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } - switch (binop->op) { - case BINOP_POWER: { - if (operand_t->tag != NumType) - code_err(ast, "Exponentiation is only supported for Num types, not ", type_to_str(operand_t)); - if (operand_t->tag == NumType && Match(operand_t, NumType)->bits == TYPE_NBITS32) - return CORD_all("powf(", lhs, ", ", rhs, ")"); - else - return CORD_all("pow(", lhs, ", ", rhs, ")"); + switch (operand_t->tag) { + case BigIntType: + return CORD_all(ast->tag == Equals ? CORD_EMPTY : "!", "Int$equal_value(", lhs, ", ", rhs, ")"); + case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: + return CORD_all("(", lhs, ast->tag == Equals ? " == " : " != ", rhs, ")"); + default: + return CORD_asprintf(ast->tag == Equals ? CORD_EMPTY : "!", + "generic_equal(stack(%r), stack(%r), %r)", lhs, rhs, compile_type_info(operand_t)); } - case BINOP_MULT: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " * ", rhs, ")"); + } + case LessThan: case LessThanOrEquals: case GreaterThan: case GreaterThanOrEquals: { + binary_operands_t cmp = BINARY_OPERANDS(ast); + + type_t *lhs_t = get_type(env, cmp.lhs); + type_t *rhs_t = get_type(env, cmp.rhs); + type_t *operand_t; + CORD lhs, rhs; + if (can_compile_to_type(env, cmp.rhs, lhs_t)) { + lhs = compile(env, cmp.lhs); + rhs = compile_to_type(env, cmp.rhs, lhs_t); + operand_t = lhs_t; + } else if (can_compile_to_type(env, cmp.lhs, rhs_t)) { + rhs = compile(env, cmp.rhs); + lhs = compile_to_type(env, cmp.lhs, rhs_t); + operand_t = rhs_t; + } else { + code_err(ast, "I can't do comparisons between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); } - case BINOP_DIVIDE: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " / ", rhs, ")"); + + const char *op = binop_operator(ast->tag); + switch (operand_t->tag) { + case BigIntType: + return CORD_all("(Int$compare_value(", lhs, ", ", rhs, ") ", op, " 0)"); + case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: + return CORD_all("(", lhs, " ", op, " ", rhs, ")"); + default: + return CORD_all("(generic_compare(stack(", lhs, "), stack(", rhs, "), ", compile_type_info(Type(OptionalType, operand_t)), ") ", op, " 0)"); } - case BINOP_MOD: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " % ", rhs, ")"); - } - case BINOP_MOD1: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("((((", lhs, ")-1) % (", rhs, ")) + 1)"); - } - case BINOP_PLUS: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " + ", rhs, ")"); - } - case BINOP_MINUS: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " - ", rhs, ")"); - } - case BINOP_LSHIFT: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " << ", rhs, ")"); - } - case BINOP_RSHIFT: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", lhs, " >> ", rhs, ")"); - } - case BINOP_ULSHIFT: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", compile_type(operand_t), ")((", compile_unsigned_type(lhs_t), ")", lhs, " << ", rhs, ")"); - } - case BINOP_URSHIFT: { - if (operand_t->tag != IntType && operand_t->tag != NumType && operand_t->tag != ByteType) - code_err(ast, "Math operations are only supported for values of the same numeric type, not ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); - return CORD_all("(", compile_type(operand_t), ")((", compile_unsigned_type(lhs_t), ")", lhs, " >> ", rhs, ")"); - } - case BINOP_EQ: { - switch (operand_t->tag) { - case BigIntType: - return CORD_all("Int$equal_value(", lhs, ", ", rhs, ")"); - case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: - if (lhs_is_optional_num || rhs_is_optional_num) - return CORD_asprintf("generic_equal(stack(%r), stack(%r), %r)", lhs, rhs, compile_type_info(Type(OptionalType, operand_t))); - return CORD_all("(", lhs, " == ", rhs, ")"); - default: - return CORD_asprintf("generic_equal(stack(%r), stack(%r), %r)", lhs, rhs, compile_type_info(operand_t)); - } - } - case BINOP_NE: { - switch (operand_t->tag) { - case BigIntType: - return CORD_all("!Int$equal_value(", lhs, ", ", rhs, ")"); - case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: - if (lhs_is_optional_num || rhs_is_optional_num) - return CORD_asprintf("!generic_equal(stack(%r), stack(%r), %r)", lhs, rhs, compile_type_info(Type(OptionalType, operand_t))); - return CORD_all("(", lhs, " != ", rhs, ")"); - default: - return CORD_asprintf("!generic_equal(stack(%r), stack(%r), %r)", lhs, rhs, compile_type_info(operand_t)); - } - } - case BINOP_LT: { - switch (operand_t->tag) { - case BigIntType: - return CORD_all("(Int$compare_value(", lhs, ", ", rhs, ") < 0)"); - case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: - if (lhs_is_optional_num || rhs_is_optional_num) - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) < 0)", lhs, rhs, compile_type_info(Type(OptionalType, operand_t))); - return CORD_all("(", lhs, " < ", rhs, ")"); - default: - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) < 0)", lhs, rhs, compile_type_info(operand_t)); - } - } - case BINOP_LE: { - switch (operand_t->tag) { - case BigIntType: - return CORD_all("(Int$compare_value(", lhs, ", ", rhs, ") <= 0)"); - case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: - if (lhs_is_optional_num || rhs_is_optional_num) - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) <= 0)", lhs, rhs, compile_type_info(Type(OptionalType, operand_t))); - return CORD_all("(", lhs, " <= ", rhs, ")"); - default: - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) <= 0)", lhs, rhs, compile_type_info(operand_t)); - } - } - case BINOP_GT: { - switch (operand_t->tag) { - case BigIntType: - return CORD_all("(Int$compare_value(", lhs, ", ", rhs, ") > 0)"); - case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: - if (lhs_is_optional_num || rhs_is_optional_num) - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) > 0)", lhs, rhs, compile_type_info(Type(OptionalType, operand_t))); - return CORD_all("(", lhs, " > ", rhs, ")"); - default: - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) > 0)", lhs, rhs, compile_type_info(operand_t)); - } - } - case BINOP_GE: { - switch (operand_t->tag) { - case BigIntType: - return CORD_all("(Int$compare_value(", lhs, ", ", rhs, ") >= 0)"); - case BoolType: case ByteType: case IntType: case NumType: case PointerType: case FunctionType: - if (lhs_is_optional_num || rhs_is_optional_num) - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) >= 0)", lhs, rhs, compile_type_info(Type(OptionalType, operand_t))); - return CORD_all("(", lhs, " >= ", rhs, ")"); - default: - return CORD_asprintf("(generic_compare(stack(%r), stack(%r), %r) >= 0)", lhs, rhs, compile_type_info(operand_t)); - } - } - case BINOP_AND: { - if (operand_t->tag == BoolType) - return CORD_all("(", lhs, " && ", rhs, ")"); - else if (operand_t->tag == IntType || operand_t->tag == ByteType) - return CORD_all("(", lhs, " & ", rhs, ")"); - else - code_err(ast, "The 'and' operator isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); - } - case BINOP_CMP: { - if (lhs_is_optional_num || rhs_is_optional_num) - operand_t = Type(OptionalType, operand_t); - return CORD_all("generic_compare(stack(", lhs, "), stack(", rhs, "), ", compile_type_info(operand_t), ")"); - } - case BINOP_OR: { - if (operand_t->tag == BoolType) - return CORD_all("(", lhs, " || ", rhs, ")"); - else if (operand_t->tag == IntType || operand_t->tag == ByteType) - return CORD_all("(", lhs, " | ", rhs, ")"); - else - code_err(ast, "The 'or' operator isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); - } - case BINOP_XOR: { - // TODO: support optional values in `xor` expressions - if (operand_t->tag == BoolType || operand_t->tag == IntType || operand_t->tag == ByteType) - return CORD_all("(", lhs, " ^ ", rhs, ")"); - else - code_err(ast, "The 'xor' operator isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); - } - case BINOP_CONCAT: { - if (operand_t == PATH_TYPE) - return CORD_all("Path$concat(", lhs, ", ", rhs, ")"); - switch (operand_t->tag) { - case TextType: { - return CORD_all("Text$concat(", lhs, ", ", rhs, ")"); - } - case ArrayType: { - return CORD_all("Array$concat(", lhs, ", ", rhs, ", sizeof(", compile_type(Match(operand_t, ArrayType)->item_type), "))"); - } - default: - code_err(ast, "Concatenation isn't supported between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t), " values"); - } - } - default: break; - } - code_err(ast, "unimplemented binop"); } case TextLiteral: { CORD literal = Match(ast, TextLiteral)->cord; @@ -2715,44 +2714,7 @@ CORD compile(env_t *env, ast_t *ast) return "(Array_t){.length=0}"; type_t *array_type = get_type(env, ast); - type_t *item_type = Match(array_type, ArrayType)->item_type; - - int64_t n = 0; - for (ast_list_t *item = array->items; item; item = item->next) { - ++n; - if (item->ast->tag == Comprehension) - goto array_comprehension; - } - - { - env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : env; - CORD code = CORD_all("TypedArrayN(", compile_type(item_type), CORD_asprintf(", %ld", n)); - for (ast_list_t *item = array->items; item; item = item->next) { - code = CORD_all(code, ", ", compile_to_type(scope, item->ast, item_type)); - } - return CORD_cat(code, ")"); - } - - array_comprehension: - { - env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : fresh_scope(env); - static int64_t comp_num = 1; - const char *comprehension_name = String("arr$", comp_num++); - ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), - .type=Type(PointerType, .pointed=array_type, .is_stack=true)); - Closure_t comp_action = {.fn=add_to_array_comprehension, .userdata=comprehension_var}; - scope->comprehension_action = &comp_action; - CORD code = CORD_all("({ Array_t ", comprehension_name, " = {};"); - // set_binding(scope, comprehension_name, array_type, comprehension_name); - for (ast_list_t *item = array->items; item; item = item->next) { - if (item->ast->tag == Comprehension) - code = CORD_all(code, "\n", compile_statement(scope, item->ast)); - else - code = CORD_all(code, compile_statement(env, add_to_array_comprehension(item->ast, comprehension_var))); - } - code = CORD_all(code, " ", comprehension_name, "; })"); - return code; - } + return compile_typed_array(env, ast, array_type); } case Table: { auto table = Match(ast, Table); @@ -2764,69 +2726,7 @@ CORD compile(env_t *env, ast_t *ast) } type_t *table_type = get_type(env, ast); - type_t *key_t = Match(table_type, TableType)->key_type; - type_t *value_t = Match(table_type, TableType)->value_type; - - if (value_t->tag == OptionalType) - code_err(ast, "Tables whose values are optional (", type_to_str(value_t), ") are not currently supported."); - - for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - if (entry->ast->tag == Comprehension) - goto table_comprehension; - } - - { // No comprehension: - env_t *key_scope = key_t->tag == EnumType ? with_enum_scope(env, key_t) : env; - env_t *value_scope = value_t->tag == EnumType ? with_enum_scope(env, value_t) : env; - CORD code = CORD_all("Table(", - compile_type(key_t), ", ", - compile_type(value_t), ", ", - compile_type_info(key_t), ", ", - compile_type_info(value_t)); - if (table->fallback) - code = CORD_all(code, ", /*fallback:*/ heap(", compile(env, table->fallback), ")"); - else - code = CORD_all(code, ", /*fallback:*/ NULL"); - - size_t n = 0; - for (ast_list_t *entry = table->entries; entry; entry = entry->next) - ++n; - CORD_appendf(&code, ", %zu", n); - - for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - auto e = Match(entry->ast, TableEntry); - code = CORD_all(code, ",\n\t{", compile_to_type(key_scope, e->key, key_t), ", ", - compile_to_type(value_scope, e->value, value_t), "}"); - } - return CORD_cat(code, ")"); - } - - table_comprehension: - { - static int64_t comp_num = 1; - env_t *scope = fresh_scope(env); - const char *comprehension_name = String("table$", comp_num++); - ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), - .type=Type(PointerType, .pointed=table_type, .is_stack=true)); - - CORD code = CORD_all("({ Table_t ", comprehension_name, " = {"); - if (table->fallback) - code = CORD_all(code, ".fallback=heap(", compile(env, table->fallback), "), "); - - code = CORD_cat(code, "};"); - - Closure_t comp_action = {.fn=add_to_table_comprehension, .userdata=comprehension_var}; - scope->comprehension_action = &comp_action; - for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - if (entry->ast->tag == Comprehension) - code = CORD_all(code, "\n", compile_statement(scope, entry->ast)); - else - code = CORD_all(code, compile_statement(env, add_to_table_comprehension(entry->ast, comprehension_var))); - } - code = CORD_all(code, " ", comprehension_name, "; })"); - return code; - } - + return compile_typed_table(env, ast, table_type); } case Set: { auto set = Match(ast, Set); @@ -2834,47 +2734,7 @@ CORD compile(env_t *env, ast_t *ast) return "((Table_t){})"; type_t *set_type = get_type(env, ast); - type_t *item_type = Match(set_type, SetType)->item_type; - - size_t n = 0; - for (ast_list_t *item = set->items; item; item = item->next) { - ++n; - if (item->ast->tag == Comprehension) - goto set_comprehension; - } - - { // No comprehension: - CORD code = CORD_all("Set(", - compile_type(item_type), ", ", - compile_type_info(item_type)); - CORD_appendf(&code, ", %zu", n); - env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : env; - for (ast_list_t *item = set->items; item; item = item->next) { - code = CORD_all(code, ", ", compile_to_type(scope, item->ast, item_type)); - } - return CORD_cat(code, ")"); - } - - set_comprehension: - { - static int64_t comp_num = 1; - env_t *scope = item_type->tag == EnumType ? with_enum_scope(env, item_type) : fresh_scope(env); - const char *comprehension_name = String("set$", comp_num++); - ast_t *comprehension_var = FakeAST(InlineCCode, .code=CORD_all("&", comprehension_name), - .type=Type(PointerType, .pointed=set_type, .is_stack=true)); - CORD code = CORD_all("({ Table_t ", comprehension_name, " = {};"); - Closure_t comp_action = {.fn=add_to_set_comprehension, .userdata=comprehension_var}; - scope->comprehension_action = &comp_action; - for (ast_list_t *item = set->items; item; item = item->next) { - if (item->ast->tag == Comprehension) - code = CORD_all(code, "\n", compile_statement(scope, item->ast)); - else - code = CORD_all(code, compile_statement(env, add_to_set_comprehension(item->ast, comprehension_var))); - } - code = CORD_all(code, " ", comprehension_name, "; })"); - return code; - } - + return compile_typed_set(env, ast, set_type); } case Comprehension: { ast_t *base = Match(ast, Comprehension)->expr; @@ -3040,8 +2900,7 @@ CORD compile(env_t *env, ast_t *ast) self = compile_to_pointer_depth(env, call->self, 0, false); arg_t *arg_spec = new(arg_t, .name="count", .type=INT_TYPE, .next=new(arg_t, .name="weights", .type=Type(ArrayType, .item_type=Type(NumType)), - .default_val=FakeAST(None, .type=new(type_ast_t, .tag=ArrayTypeAST, - .__data.ArrayTypeAST.item=new(type_ast_t, .tag=VarTypeAST, .__data.VarTypeAST.name="Num"))), + .default_val=FakeAST(None), .next=new(arg_t, .name="random", .type=random_num_type, .default_val=none_rng))); return CORD_all("Array$sample(", self, ", ", compile_arguments(env, ast, arg_spec, call->args), ", ", padded_item_size, ")"); @@ -3473,7 +3332,7 @@ CORD compile(env_t *env, ast_t *ast) } case Reduction: { auto reduction = Match(ast, Reduction); - binop_e op = reduction->op; + ast_e op = reduction->op; type_t *iter_t = get_type(env, reduction->iter); type_t *item_t = get_iterated_type(iter_t); @@ -3484,7 +3343,7 @@ CORD compile(env_t *env, ast_t *ast) ast_t *body = FakeAST(InlineCCode, .code="{}"); // placeholder ast_t *loop = FakeAST(For, .vars=new(ast_list_t, .ast=item), .iter=reduction->iter, .body=body); env_t *body_scope = for_scope(env, loop); - if (op == BINOP_EQ || op == BINOP_NE || op == BINOP_LT || op == BINOP_LE || op == BINOP_GT || op == BINOP_GE) { + if (op == Equals || op == NotEquals || op == LessThan || op == LessThanOrEquals || op == GreaterThan || op == GreaterThanOrEquals) { // Chained comparisons like ==, <, etc. CORD code = CORD_all( "({ // Reduction:\n", @@ -3492,7 +3351,8 @@ CORD compile(env_t *env, ast_t *ast) "OptionalBool_t result = NONE_BOOL;\n" ); - ast_t *comparison = WrapAST(ast, BinaryOp, .op=op, .lhs=FakeAST(InlineCCode, .code="prev", .type=item_t), .rhs=item); + ast_t *comparison = new(ast_t, .file=ast->file, .start=ast->start, .end=ast->end, + .tag=op, .__data.Plus.lhs=FakeAST(InlineCCode, .code="prev", .type=item_t), .__data.Plus.rhs=item); body->__data.InlineCCode.code = CORD_all( "if (result == NONE_BOOL) {\n" " prev = ", compile(body_scope, item), ";\n" @@ -3507,9 +3367,9 @@ CORD compile(env_t *env, ast_t *ast) "}\n"); code = CORD_all(code, compile_statement(env, loop), "\nresult;})"); return code; - } else if (op == BINOP_MIN || op == BINOP_MAX) { + } else if (op == Min || op == Max) { // Min/max: - const char *superlative = op == BINOP_MIN ? "min" : "max"; + const char *superlative = op == Min ? "min" : "max"; CORD code = CORD_all( "({ // Reduction:\n", compile_declaration(item_t, superlative), ";\n" @@ -3517,17 +3377,18 @@ CORD compile(env_t *env, ast_t *ast) ); CORD item_code = compile(body_scope, item); - binop_e cmp_op = op == BINOP_MIN ? BINOP_LT : BINOP_GT; + ast_e cmp_op = op == Min ? LessThan : GreaterThan; if (reduction->key) { env_t *key_scope = fresh_scope(env); set_binding(key_scope, "$", item_t, item_code); type_t *key_type = get_type(key_scope, reduction->key); - const char *superlative_key = op == BINOP_MIN ? "min_key" : "max_key"; + const char *superlative_key = op == Min ? "min_key" : "max_key"; code = CORD_all(code, compile_declaration(key_type, superlative_key), ";\n"); - ast_t *comparison = WrapAST(ast, BinaryOp, .op=cmp_op, - .lhs=FakeAST(InlineCCode, .code="key", .type=key_type), - .rhs=FakeAST(InlineCCode, .code=superlative_key, .type=key_type)); + ast_t *comparison = new(ast_t, .file=ast->file, .start=ast->start, .end=ast->end, + .tag=cmp_op, .__data.Plus.lhs=FakeAST(InlineCCode, .code="key", .type=key_type), + .__data.Plus.rhs=FakeAST(InlineCCode, .code=superlative_key, .type=key_type)); + body->__data.InlineCCode.code = CORD_all( compile_declaration(key_type, "key"), " = ", compile(key_scope, reduction->key), ";\n", "if (!has_value || ", compile(body_scope, comparison), ") {\n" @@ -3536,7 +3397,9 @@ CORD compile(env_t *env, ast_t *ast) " has_value = yes;\n" "}\n"); } else { - ast_t *comparison = WrapAST(ast, BinaryOp, .op=cmp_op, .lhs=item, .rhs=FakeAST(InlineCCode, .code=superlative, .type=item_t)); + ast_t *comparison = new(ast_t, .file=ast->file, .start=ast->start, .end=ast->end, + .tag=cmp_op, .__data.Plus.lhs=item, + .__data.Plus.rhs=FakeAST(InlineCCode, .code=superlative, .type=item_t)); body->__data.InlineCCode.code = CORD_all( "if (!has_value || ", compile(body_scope, comparison), ") {\n" " ", superlative, " = ", compile(body_scope, item), ";\n" @@ -3558,22 +3421,24 @@ CORD compile(env_t *env, ast_t *ast) // For the special case of (or)/(and), we need to early out if we can: CORD early_out = CORD_EMPTY; - if (op == BINOP_CMP) { + if (op == Compare) { if (item_t->tag != IntType || Match(item_t, IntType)->bits != TYPE_IBITS32) code_err(ast, "<> reductions are only supported for Int32 values"); - } else if (op == BINOP_AND) { + } else if (op == And) { if (item_t->tag == BoolType) early_out = "if (!reduction) break;"; else if (item_t->tag == OptionalType) early_out = CORD_all("if (", check_none(item_t, "reduction"), ") break;"); - } else if (op == BINOP_OR) { + } else if (op == Or) { if (item_t->tag == BoolType) early_out = "if (reduction) break;"; else if (item_t->tag == OptionalType) early_out = CORD_all("if (!", check_none(item_t, "reduction"), ") break;"); } - ast_t *combination = WrapAST(ast, BinaryOp, .op=op, .lhs=FakeAST(InlineCCode, .code="reduction", .type=item_t), .rhs=item); + ast_t *combination = new(ast_t, .file=ast->file, .start=ast->start, .end=ast->end, + .tag=op, .__data.Plus.lhs=FakeAST(InlineCCode, .code="reduction", .type=item_t), + .__data.Plus.rhs=item); body->__data.InlineCCode.code = CORD_all( "if (!has_value) {\n" " reduction = ", compile(body_scope, item), ";\n" @@ -3764,7 +3629,7 @@ CORD compile(env_t *env, ast_t *ast) case Defer: code_err(ast, "Compiling 'defer' as expression!"); case Extern: code_err(ast, "Externs are not supported as expressions"); case TableEntry: code_err(ast, "Table entries should not be compiled directly"); - case Declare: case Assign: case UpdateAssign: case For: case While: case Repeat: case StructDef: case LangDef: case Extend: + case Declare: case Assign: case UPDATE_CASES: case For: case While: case Repeat: case StructDef: case LangDef: case Extend: case EnumDef: case FunctionDef: case ConvertDef: case Skip: case Stop: case Pass: case Return: case DocTest: case PrintStatement: code_err(ast, "This is not a valid expression"); default: case Unknown: code_err(ast, "Unknown AST"); diff --git a/src/environment.c b/src/environment.c index 18818a9..a6a450e 100644 --- a/src/environment.c +++ b/src/environment.c @@ -122,7 +122,7 @@ env_t *global_env(void) {"right_shifted", "Int$right_shifted", "func(x,y:Int -> Int)"}, {"sqrt", "Int$sqrt", "func(x:Int -> Int?)"}, {"times", "Int$times", "func(x,y:Int -> Int)"}, - {"to", "Int$to", "func(first:Int,last:Int,step=none:Int -> func(->Int?))"}, + {"to", "Int$to", "func(first:Int,last:Int,step:Int?=none -> func(->Int?))"}, )}, {"Int64", Type(IntType, .bits=TYPE_IBITS64), "Int64_t", "Int64$info", TypedArray(ns_entry_t, {"abs", "labs", "func(i:Int64 -> Int64)"}, @@ -139,7 +139,7 @@ env_t *global_env(void) {"modulo1", "Int64$modulo1", "func(x,y:Int64 -> Int64)"}, {"octal", "Int64$octal", "func(i:Int64, digits=0, prefix=yes -> Text)"}, {"onward", "Int64$onward", "func(first:Int64,step=Int64(1) -> func(->Int64?))"}, - {"to", "Int64$to", "func(first:Int64,last:Int64,step=none:Int64 -> func(->Int64?))"}, + {"to", "Int64$to", "func(first:Int64,last:Int64,step:Int64?=none -> func(->Int64?))"}, {"unsigned_left_shifted", "Int64$unsigned_left_shifted", "func(x:Int64,y:Int64 -> Int64)"}, {"unsigned_right_shifted", "Int64$unsigned_right_shifted", "func(x:Int64,y:Int64 -> Int64)"}, {"wrapping_minus", "Int64$wrapping_minus", "func(x:Int64,y:Int64 -> Int64)"}, @@ -160,7 +160,7 @@ env_t *global_env(void) {"modulo1", "Int32$modulo1", "func(x,y:Int32 -> Int32)"}, {"octal", "Int32$octal", "func(i:Int32, digits=0, prefix=yes -> Text)"}, {"onward", "Int32$onward", "func(first:Int32,step=Int32(1) -> func(->Int32?))"}, - {"to", "Int32$to", "func(first:Int32,last:Int32,step=none:Int32 -> func(->Int32?))"}, + {"to", "Int32$to", "func(first:Int32,last:Int32,step:Int32?=none -> func(->Int32?))"}, {"unsigned_left_shifted", "Int32$unsigned_left_shifted", "func(x:Int32,y:Int32 -> Int32)"}, {"unsigned_right_shifted", "Int32$unsigned_right_shifted", "func(x:Int32,y:Int32 -> Int32)"}, {"wrapping_minus", "Int32$wrapping_minus", "func(x:Int32,y:Int32 -> Int32)"}, @@ -181,7 +181,7 @@ env_t *global_env(void) {"modulo1", "Int16$modulo1", "func(x,y:Int16 -> Int16)"}, {"octal", "Int16$octal", "func(i:Int16, digits=0, prefix=yes -> Text)"}, {"onward", "Int16$onward", "func(first:Int16,step=Int16(1) -> func(->Int16?))"}, - {"to", "Int16$to", "func(first:Int16,last:Int16,step=none:Int16 -> func(->Int16?))"}, + {"to", "Int16$to", "func(first:Int16,last:Int16,step:Int16?=none -> func(->Int16?))"}, {"unsigned_left_shifted", "Int16$unsigned_left_shifted", "func(x:Int16,y:Int16 -> Int16)"}, {"unsigned_right_shifted", "Int16$unsigned_right_shifted", "func(x:Int16,y:Int16 -> Int16)"}, {"wrapping_minus", "Int16$wrapping_minus", "func(x:Int16,y:Int16 -> Int16)"}, @@ -202,7 +202,7 @@ env_t *global_env(void) {"modulo1", "Int8$modulo1", "func(x,y:Int8 -> Int8)"}, {"octal", "Int8$octal", "func(i:Int8, digits=0, prefix=yes -> Text)"}, {"onward", "Int8$onward", "func(first:Int8,step=Int8(1) -> func(->Int8?))"}, - {"to", "Int8$to", "func(first:Int8,last:Int8,step=none:Int8 -> func(->Int8?))"}, + {"to", "Int8$to", "func(first:Int8,last:Int8,step:Int8?=none -> func(->Int8?))"}, {"unsigned_left_shifted", "Int8$unsigned_left_shifted", "func(x:Int8,y:Int8 -> Int8)"}, {"unsigned_right_shifted", "Int8$unsigned_right_shifted", "func(x:Int8,y:Int8 -> Int8)"}, {"wrapping_minus", "Int8$wrapping_minus", "func(x:Int8,y:Int8 -> Int8)"}, @@ -310,11 +310,11 @@ env_t *global_env(void) {"owner", "Path$owner", "func(path:Path, follow_symlinks=yes -> Text?)"}, {"parent", "Path$parent", "func(path:Path -> Path)"}, {"read", "Path$read", "func(path:Path -> Text?)"}, - {"read_bytes", "Path$read_bytes", "func(path:Path, limit=none:Int -> [Byte]?)"}, + {"read_bytes", "Path$read_bytes", "func(path:Path, limit:Int?=none -> [Byte]?)"}, {"relative_to", "Path$relative_to", "func(path:Path, relative_to:Path -> Path)"}, {"remove", "Path$remove", "func(path:Path, ignore_missing=no)"}, {"resolved", "Path$resolved", "func(path:Path, relative_to=(./) -> Path)"}, - {"set_owner", "Path$set_owner", "func(path:Path, owner=none:Text, group=none:Text, follow_symlinks=yes)"}, + {"set_owner", "Path$set_owner", "func(path:Path, owner:Text?=none, group:Text?=none, follow_symlinks=yes)"}, {"subdirectories", "Path$children", "func(path:Path, include_hidden=no -> [Path])"}, {"unique_directory", "Path$unique_directory", "func(path:Path -> Path)"}, {"write", "Path$write", "func(path:Path, text:Text, permissions=Int32(0o644))"}, @@ -508,7 +508,7 @@ env_t *global_env(void) {"say", "say", "func(text:Text, newline=yes)"}, {"print", "say", "func(text:Text, newline=yes)"}, {"ask", "ask", "func(prompt:Text, bold=yes, force_tty=yes -> Text?)"}, - {"exit", "tomo_exit", "func(message=none:Text, code=Int32(1) -> Abort)"}, + {"exit", "tomo_exit", "func(message:Text?=none, code=Int32(1) -> Abort)"}, {"fail", "fail_text", "func(message:Text -> Abort)"}, {"sleep", "sleep_num", "func(seconds:Num)"}, }; @@ -749,6 +749,18 @@ PUREFUNC binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args) return NULL; } +PUREFUNC binding_t *get_metamethod_binding(env_t *env, ast_e tag, ast_t *lhs, ast_t *rhs, type_t *ret) +{ + const char *method_name = binop_method_name(tag); + if (!method_name) return NULL; + binding_t *b = get_namespace_binding(env, lhs, method_name); + if (!b || b->type->tag != FunctionType) return NULL; + auto fn = Match(b->type, FunctionType); + if (!type_eq(fn->ret, ret)) return NULL; + arg_ast_t *args = new(arg_ast_t, .value=lhs, .next=new(arg_ast_t, .value=rhs)); + return is_valid_call(env, fn->args, args, true) ? b : NULL; +} + void set_binding(env_t *env, const char *name, type_t *type, CORD code) { assert(name); diff --git a/src/environment.h b/src/environment.h index fce6bc9..cbaae09 100644 --- a/src/environment.h +++ b/src/environment.h @@ -85,6 +85,7 @@ env_t *namespace_env(env_t *env, const char *namespace_name); }) binding_t *get_binding(env_t *env, const char *name); binding_t *get_constructor(env_t *env, type_t *t, arg_ast_t *args); +PUREFUNC binding_t *get_metamethod_binding(env_t *env, ast_e tag, ast_t *lhs, ast_t *rhs, type_t *ret); void set_binding(env_t *env, const char *name, type_t *type, CORD code); binding_t *get_namespace_binding(env_t *env, ast_t *self, const char *name); #define code_err(ast, ...) compiler_err((ast)->file, (ast)->start, (ast)->end, __VA_ARGS__) diff --git a/src/parse.c b/src/parse.c index 63f5deb..0aa2600 100644 --- a/src/parse.c +++ b/src/parse.c @@ -48,16 +48,16 @@ typedef struct { #define PARSER(name) ast_t *name(parse_ctx_t *ctx, const char *pos) int op_tightness[] = { - [BINOP_POWER]=9, - [BINOP_MULT]=8, [BINOP_DIVIDE]=8, [BINOP_MOD]=8, [BINOP_MOD1]=8, - [BINOP_PLUS]=7, [BINOP_MINUS]=7, - [BINOP_CONCAT]=6, - [BINOP_LSHIFT]=5, [BINOP_RSHIFT]=5, - [BINOP_MIN]=4, [BINOP_MAX]=4, - [BINOP_EQ]=3, [BINOP_NE]=3, - [BINOP_LT]=2, [BINOP_LE]=2, [BINOP_GT]=2, [BINOP_GE]=2, - [BINOP_CMP]=2, - [BINOP_AND]=1, [BINOP_OR]=1, [BINOP_XOR]=1, + [Power]=9, + [Multiply]=8, [Divide]=8, [Mod]=8, [Mod1]=8, + [Plus]=7, [Minus]=7, + [Concat]=6, + [LeftShift]=5, [RightShift]=5, [UnsignedLeftShift]=5, [UnsignedRightShift]=5, + [Min]=4, [Max]=4, + [Equals]=3, [NotEquals]=3, + [LessThan]=2, [LessThanOrEquals]=2, [GreaterThan]=2, [GreaterThanOrEquals]=2, + [Compare]=2, + [And]=1, [Or]=1, [Xor]=1, }; static const char *keywords[] = { @@ -79,7 +79,7 @@ static INLINE const char* get_word(const char **pos); static INLINE const char* get_id(const char **pos); static INLINE bool comment(const char **pos); static INLINE bool indent(parse_ctx_t *ctx, const char **pos); -static INLINE binop_e match_binary_operator(const char **pos); +static INLINE ast_e match_binary_operator(const char **pos); static ast_t *parse_comprehension_suffix(parse_ctx_t *ctx, ast_t *expr); static ast_t *parse_field_suffix(parse_ctx_t *ctx, ast_t *lhs); static ast_t *parse_fncall_suffix(parse_ctx_t *ctx, ast_t *fn); @@ -685,15 +685,6 @@ PARSER(parse_array) { whitespace(&pos); ast_list_t *items = NULL; - type_ast_t *item_type = NULL; - if (match(&pos, ":")) { - whitespace(&pos); - item_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a type for this array"); - whitespace(&pos); - match(&pos, ","); - whitespace(&pos); - } - for (;;) { ast_t *item = optional(ctx, &pos, parse_extended_expr); if (!item) break; @@ -711,7 +702,7 @@ PARSER(parse_array) { expect_closing(ctx, &pos, "]", "I wasn't able to parse the rest of this array"); REVERSE_LIST(items); - return NewAST(ctx->file, start, pos, Array, .item_type=item_type, .items=items); + return NewAST(ctx->file, start, pos, Array, .items=items); } PARSER(parse_table) { @@ -722,20 +713,6 @@ PARSER(parse_table) { whitespace(&pos); ast_list_t *entries = NULL; - type_ast_t *key_type = NULL, *value_type = NULL; - if (match(&pos, ":")) { - whitespace(&pos); - key_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a key type for this table"); - whitespace(&pos); - if (match(&pos, "=")) { - value_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse the value type for this table"); - } else { - return NULL; - } - whitespace(&pos); - match(&pos, ","); - } - for (;;) { const char *entry_start = pos; ast_t *key = optional(ctx, &pos, parse_extended_expr); @@ -787,8 +764,7 @@ PARSER(parse_table) { whitespace(&pos); expect_closing(ctx, &pos, "}", "I wasn't able to parse the rest of this table"); - return NewAST(ctx->file, start, pos, Table, .key_type=key_type, .value_type=value_type, - .default_value=default_value, .entries=entries, .fallback=fallback); + return NewAST(ctx->file, start, pos, Table, .default_value=default_value, .entries=entries, .fallback=fallback); } PARSER(parse_set) { @@ -801,18 +777,6 @@ PARSER(parse_set) { whitespace(&pos); ast_list_t *items = NULL; - type_ast_t *item_type = NULL; - if (match(&pos, ":")) { - whitespace(&pos); - item_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a key type for this set"); - whitespace(&pos); - if (match(&pos, ",")) - return NULL; - whitespace(&pos); - match(&pos, ","); - whitespace(&pos); - } - for (;;) { ast_t *item = optional(ctx, &pos, parse_extended_expr); if (!item) break; @@ -834,7 +798,7 @@ PARSER(parse_set) { whitespace(&pos); expect_closing(ctx, &pos, "}", "I wasn't able to parse the rest of this set"); - return NewAST(ctx->file, start, pos, Set, .item_type=item_type, .items=items); + return NewAST(ctx->file, start, pos, Set, .items=items); } ast_t *parse_field_suffix(parse_ctx_t *ctx, ast_t *lhs) { @@ -874,11 +838,11 @@ PARSER(parse_reduction) { if (!match(&pos, "(")) return NULL; whitespace(&pos); - binop_e op = match_binary_operator(&pos); - if (op == BINOP_UNKNOWN) return NULL; + ast_e op = match_binary_operator(&pos); + if (op == Unknown) return NULL; ast_t *key = NULL; - if (op == BINOP_MIN || op == BINOP_MAX) { + if (op == Min || op == Max) { key = NewAST(ctx->file, pos, pos, Var, .name="$"); for (bool progress = true; progress; ) { ast_t *new_term; @@ -1425,16 +1389,7 @@ PARSER(parse_none) { const char *start = pos; if (!match_word(&pos, "none")) return NULL; - - const char *none_end = pos; - spaces(&pos); - if (!match(&pos, ":")) - return NewAST(ctx->file, start, none_end, None, .type=NULL); - - spaces(&pos); - type_ast_t *type = parse_type(ctx, pos); - if (!type) return NULL; - return NewAST(ctx->file, start, type->end, None, .type=type); + return NewAST(ctx->file, start, pos, None); } PARSER(parse_deserialize) { @@ -1602,53 +1557,53 @@ ast_t *parse_fncall_suffix(parse_ctx_t *ctx, ast_t *fn) { return NewAST(ctx->file, start, pos, FunctionCall, .fn=fn, .args=args); } -binop_e match_binary_operator(const char **pos) +ast_e match_binary_operator(const char **pos) { switch (**pos) { case '+': { *pos += 1; - return match(pos, "+") ? BINOP_CONCAT : BINOP_PLUS; + return match(pos, "+") ? Concat : Plus; } case '-': { *pos += 1; if ((*pos)[0] != ' ' && (*pos)[-2] == ' ') // looks like `fn -5` - return BINOP_UNKNOWN; - return BINOP_MINUS; + return Unknown; + return Minus; } - case '*': *pos += 1; return BINOP_MULT; - case '/': *pos += 1; return BINOP_DIVIDE; - case '^': *pos += 1; return BINOP_POWER; + case '*': *pos += 1; return Multiply; + case '/': *pos += 1; return Divide; + case '^': *pos += 1; return Power; case '<': { *pos += 1; - if (match(pos, "=")) return BINOP_LE; // "<=" - else if (match(pos, ">")) return BINOP_CMP; // "<>" + if (match(pos, "=")) return LessThanOrEquals; // "<=" + else if (match(pos, ">")) return Compare; // "<>" else if (match(pos, "<")) { if (match(pos, "<")) - return BINOP_ULSHIFT; // "<<<" - return BINOP_LSHIFT; // "<<" - } else return BINOP_LT; + return UnsignedLeftShift; // "<<<" + return LeftShift; // "<<" + } else return LessThan; } case '>': { *pos += 1; - if (match(pos, "=")) return BINOP_GE; // ">=" + if (match(pos, "=")) return GreaterThanOrEquals; // ">=" if (match(pos, ">")) { if (match(pos, ">")) - return BINOP_URSHIFT; // ">>>" - return BINOP_RSHIFT; // ">>" + return UnsignedRightShift; // ">>>" + return RightShift; // ">>" } - return BINOP_GT; + return GreaterThan; } default: { - if (match(pos, "!=")) return BINOP_NE; - else if (match(pos, "==") && **pos != '=') return BINOP_EQ; - else if (match_word(pos, "and")) return BINOP_AND; - else if (match_word(pos, "or")) return BINOP_OR; - else if (match_word(pos, "xor")) return BINOP_XOR; - else if (match_word(pos, "mod1")) return BINOP_MOD1; - else if (match_word(pos, "mod")) return BINOP_MOD; - else if (match_word(pos, "_min_")) return BINOP_MIN; - else if (match_word(pos, "_max_")) return BINOP_MAX; - else return BINOP_UNKNOWN; + if (match(pos, "!=")) return NotEquals; + else if (match(pos, "==") && **pos != '=') return Equals; + else if (match_word(pos, "and")) return And; + else if (match_word(pos, "or")) return Or; + else if (match_word(pos, "xor")) return Xor; + else if (match_word(pos, "mod1")) return Mod1; + else if (match_word(pos, "mod")) return Mod; + else if (match_word(pos, "_min_")) return Min; + else if (match_word(pos, "_max_")) return Max; + else return Unknown; } } } @@ -1660,9 +1615,9 @@ static ast_t *parse_infix_expr(parse_ctx_t *ctx, const char *pos, int min_tightn int64_t starting_line = get_line_number(ctx->file, pos); int64_t starting_indent = get_indent(ctx, pos); spaces(&pos); - for (binop_e op; (op=match_binary_operator(&pos)) != BINOP_UNKNOWN && op_tightness[op] >= min_tightness; spaces(&pos)) { + for (ast_e op; (op=match_binary_operator(&pos)) != Unknown && op_tightness[op] >= min_tightness; spaces(&pos)) { ast_t *key = NULL; - if (op == BINOP_MIN || op == BINOP_MAX) { + if (op == Min || op == Max) { key = NewAST(ctx->file, pos, pos, Var, .name="$"); for (bool progress = true; progress; ) { ast_t *new_term; @@ -1688,12 +1643,12 @@ static ast_t *parse_infix_expr(parse_ctx_t *ctx, const char *pos, int min_tightn if (!rhs) break; pos = rhs->end; - if (op == BINOP_MIN) { + if (op == Min) { return NewAST(ctx->file, lhs->start, rhs->end, Min, .lhs=lhs, .rhs=rhs, .key=key); - } else if (op == BINOP_MAX) { + } else if (op == Max) { return NewAST(ctx->file, lhs->start, rhs->end, Max, .lhs=lhs, .rhs=rhs, .key=key); } else { - lhs = NewAST(ctx->file, lhs->start, rhs->end, BinaryOp, .lhs=lhs, .op=op, .rhs=rhs); + lhs = new(ast_t, .file=ctx->file, .start=lhs->start, .end=rhs->end, .tag=op, .__data.Plus.lhs=lhs, .__data.Plus.rhs=rhs); } } return lhs; @@ -1709,8 +1664,11 @@ PARSER(parse_declaration) { if (!var) return NULL; pos = var->end; spaces(&pos); - if (!match(&pos, ":=")) return NULL; + if (!match(&pos, ":")) return NULL; spaces(&pos); + type_ast_t *type = optional(ctx, &pos, parse_type); + spaces(&pos); + if (!match(&pos, "=")) return NULL; ast_t *val = optional(ctx, &pos, parse_extended_expr); if (!val) { if (optional(ctx, &pos, parse_use)) @@ -1718,7 +1676,7 @@ PARSER(parse_declaration) { else parser_err(ctx, pos, eol(pos), "This is not a valid expression"); } - return NewAST(ctx->file, start, pos, Declare, .var=var, .value=val); + return NewAST(ctx->file, start, pos, Declare, .var=var, .type=type, .value=val); } PARSER(parse_update) { @@ -1726,23 +1684,23 @@ PARSER(parse_update) { ast_t *lhs = optional(ctx, &pos, parse_expr); if (!lhs) return NULL; spaces(&pos); - binop_e op; - if (match(&pos, "+=")) op = BINOP_PLUS; - else if (match(&pos, "++=")) op = BINOP_CONCAT; - else if (match(&pos, "-=")) op = BINOP_MINUS; - else if (match(&pos, "*=")) op = BINOP_MULT; - else if (match(&pos, "/=")) op = BINOP_DIVIDE; - else if (match(&pos, "^=")) op = BINOP_POWER; - else if (match(&pos, "<<=")) op = BINOP_LSHIFT; - else if (match(&pos, "<<<=")) op = BINOP_ULSHIFT; - else if (match(&pos, ">>=")) op = BINOP_RSHIFT; - else if (match(&pos, ">>>=")) op = BINOP_URSHIFT; - else if (match(&pos, "and=")) op = BINOP_AND; - else if (match(&pos, "or=")) op = BINOP_OR; - else if (match(&pos, "xor=")) op = BINOP_XOR; + ast_e op; + if (match(&pos, "+=")) op = Plus; + else if (match(&pos, "++=")) op = Concat; + else if (match(&pos, "-=")) op = Minus; + else if (match(&pos, "*=")) op = Multiply; + else if (match(&pos, "/=")) op = Divide; + else if (match(&pos, "^=")) op = Power; + else if (match(&pos, "<<=")) op = LeftShift; + else if (match(&pos, "<<<=")) op = UnsignedLeftShift; + else if (match(&pos, ">>=")) op = RightShift; + else if (match(&pos, ">>>=")) op = UnsignedRightShift; + else if (match(&pos, "and=")) op = And; + else if (match(&pos, "or=")) op = Or; + else if (match(&pos, "xor=")) op = Xor; else return NULL; ast_t *rhs = expect(ctx, start, &pos, parse_extended_expr, "I expected an expression here"); - return NewAST(ctx->file, start, pos, UpdateAssign, .lhs=lhs, .rhs=rhs, .op=op); + return new(ast_t, .file=ctx->file, .start=start, .end=pos, .tag=op, .__data.PlusUpdate.lhs=lhs, .__data.PlusUpdate.rhs=rhs); } PARSER(parse_assignment) { diff --git a/src/repl.c b/src/repl.c index 2f5c60f..463e7ff 100644 --- a/src/repl.c +++ b/src/repl.c @@ -186,31 +186,31 @@ static Int_t ast_to_int(env_t *env, ast_t *ast) } } -static double ast_to_num(env_t *env, ast_t *ast) -{ - type_t *t = get_type(env, ast); - switch (t->tag) { - case BigIntType: case IntType: { - number_t num; - eval(env, ast, &num); - if (t->tag == BigIntType) - return Num$from_int(num.integer, false); - switch (Match(t, IntType)->bits) { - case TYPE_IBITS64: return Num$from_int64(num.i64, false); - case TYPE_IBITS32: return Num$from_int32(num.i32); - case TYPE_IBITS16: return Num$from_int16(num.i16); - case TYPE_IBITS8: return Num$from_int8(num.i8); - default: print_err("Invalid int bits"); - } - } - case NumType: { - number_t num; - eval(env, ast, &num); - return Match(t, NumType)->bits == TYPE_NBITS32 ? (double)num.n32 : (double)num.n64; - } - default: print_err("Cannot convert to number"); - } -} +// static double ast_to_num(env_t *env, ast_t *ast) +// { +// type_t *t = get_type(env, ast); +// switch (t->tag) { +// case BigIntType: case IntType: { +// number_t num; +// eval(env, ast, &num); +// if (t->tag == BigIntType) +// return Num$from_int(num.integer, false); +// switch (Match(t, IntType)->bits) { +// case TYPE_IBITS64: return Num$from_int64(num.i64, false); +// case TYPE_IBITS32: return Num$from_int32(num.i32); +// case TYPE_IBITS16: return Num$from_int16(num.i16); +// case TYPE_IBITS8: return Num$from_int8(num.i8); +// default: print_err("Invalid int bits"); +// } +// } +// case NumType: { +// number_t num; +// eval(env, ast, &num); +// return Match(t, NumType)->bits == TYPE_NBITS32 ? (double)num.n32 : (double)num.n64; +// } +// default: print_err("Cannot convert to number"); +// } +// } static Text_t obj_to_text(type_t *t, const void *obj, bool use_color) { @@ -386,76 +386,6 @@ void eval(env_t *env, ast_t *ast, void *dest) if (dest) *(CORD*)dest = ret; break; } - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - if (t->tag == IntType || t->tag == BigIntType) { -#define CASE_OP(OP_NAME, method_name) case BINOP_##OP_NAME: {\ - Int_t lhs = ast_to_int(env, binop->lhs); \ - Int_t rhs = ast_to_int(env, binop->rhs); \ - Int_t result = Int$ ## method_name (lhs, rhs); \ - if (t->tag == BigIntType) {\ - *(Int_t*)dest = result; \ - return; \ - } \ - switch (Match(t, IntType)->bits) { \ - case 64: *(int64_t*)dest = Int64$from_int(result, false); return; \ - case 32: *(int32_t*)dest = Int32$from_int(result, false); return; \ - case 16: *(int16_t*)dest = Int16$from_int(result, false); return; \ - case 8: *(int8_t*)dest = Int8$from_int(result, false); return; \ - default: print_err("Invalid int bits"); \ - } \ - break; \ - } - switch (binop->op) { - CASE_OP(MULT, times) CASE_OP(DIVIDE, divided_by) CASE_OP(PLUS, plus) CASE_OP(MINUS, minus) - CASE_OP(RSHIFT, right_shifted) CASE_OP(LSHIFT, left_shifted) - CASE_OP(MOD, modulo) CASE_OP(MOD1, modulo1) - CASE_OP(AND, bit_and) CASE_OP(OR, bit_or) CASE_OP(XOR, bit_xor) - default: break; - } -#undef CASE_OP - } else if (t->tag == NumType) { -#define CASE_OP(OP_NAME, C_OP) case BINOP_##OP_NAME: {\ - double lhs = ast_to_num(env, binop->lhs); \ - double rhs = ast_to_num(env, binop->rhs); \ - if (Match(t, NumType)->bits == 64) \ - *(double*)dest = (double)(lhs C_OP rhs); \ - else \ - *(float*)dest = (float)(lhs C_OP rhs); \ - return; \ - } - switch (binop->op) { - CASE_OP(MULT, *) CASE_OP(DIVIDE, /) CASE_OP(PLUS, +) CASE_OP(MINUS, -) - default: break; - } -#undef CASE_OP - } - switch (binop->op) { - case BINOP_EQ: case BINOP_NE: case BINOP_LT: case BINOP_LE: case BINOP_GT: case BINOP_GE: { - type_t *t_lhs = get_type(env, binop->lhs); - if (!type_eq(t_lhs, get_type(env, binop->rhs))) - print_err("Comparisons between different types aren't supported"); - const TypeInfo_t *info = type_to_type_info(t_lhs); - size_t value_size = type_size(t_lhs); - char lhs[value_size], rhs[value_size]; - eval(env, binop->lhs, lhs); - eval(env, binop->rhs, rhs); - int cmp = generic_compare(lhs, rhs, info); - switch (binop->op) { - case BINOP_EQ: *(bool*)dest = (cmp == 0); break; - case BINOP_NE: *(bool*)dest = (cmp != 0); break; - case BINOP_GT: *(bool*)dest = (cmp > 0); break; - case BINOP_GE: *(bool*)dest = (cmp >= 0); break; - case BINOP_LT: *(bool*)dest = (cmp < 0); break; - case BINOP_LE: *(bool*)dest = (cmp <= 0); break; - default: break; - } - break; - } - default: print_err(1, "Binary op not implemented for ", type_to_str(t), ": ", ast_to_xml_str(ast)); - } - break; - } case Index: { auto index = Match(ast, Index); type_t *indexed_t = get_type(env, index->indexed); diff --git a/src/typecheck.c b/src/typecheck.c index 8a2ee32..cd6ff1c 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -135,27 +135,27 @@ type_t *parse_type_ast(env_t *env, type_ast_t *ast) errx(1, "Unreachable"); } -static PUREFUNC bool risks_zero_or_inf(ast_t *ast) -{ - switch (ast->tag) { - case Int: { - const char *str = Match(ast, Int)->str; - OptionalInt_t int_val = Int$from_str(str); - return (int_val.small == 0x1); // zero - } - case Num: { - return Match(ast, Num)->n == 0.0; - } - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - if (binop->op == BINOP_MULT || binop->op == BINOP_DIVIDE || binop->op == BINOP_MIN || binop->op == BINOP_MAX) - return risks_zero_or_inf(binop->lhs) || risks_zero_or_inf(binop->rhs); - else - return true; - } - default: return true; - } -} +// static PUREFUNC bool risks_zero_or_inf(ast_t *ast) +// { +// switch (ast->tag) { +// case Int: { +// const char *str = Match(ast, Int)->str; +// OptionalInt_t int_val = Int$from_str(str); +// return (int_val.small == 0x1); // zero +// } +// case Num: { +// return Match(ast, Num)->n == 0.0; +// } +// case BINOP_CASES: { +// binary_operands_t binop = BINARY_OPERANDS(ast); +// if (ast->tag == Multiply || ast->tag == Divide || ast->tag == Min || ast->tag == Max) +// return risks_zero_or_inf(binop.lhs) || risks_zero_or_inf(binop.rhs); +// else +// return true; +// } +// default: return true; +// } +// } PUREFUNC type_t *get_math_type(env_t *env, ast_t *ast, type_t *lhs_t, type_t *rhs_t) { @@ -312,7 +312,7 @@ void bind_statement(env_t *env, ast_t *statement) if (get_binding(env, name)) code_err(decl->var, "A ", type_to_str(get_binding(env, name)->type), " called ", quoted(name), " has already been defined"); bind_statement(env, decl->value); - type_t *type = get_type(env, decl->value); + type_t *type = decl->type ? parse_type_ast(env, decl->type) : get_type(env, decl->value); if (!type) code_err(decl->value, "I couldn't figure out the type of this value"); if (type->tag == FunctionType) @@ -617,12 +617,7 @@ type_t *get_type(env_t *env, ast_t *ast) #endif switch (ast->tag) { case None: { - if (!Match(ast, None)->type) - return Type(OptionalType, .type=NULL); - type_t *t = parse_type_ast(env, Match(ast, None)->type); - if (t->tag == OptionalType) - code_err(ast, "Nested optional types are not supported. This should be: `none:", type_to_str(Match(t, OptionalType)->type), "`"); - return Type(OptionalType, .type=t); + return Type(OptionalType, .type=NULL); } case Bool: { return Type(BoolType); @@ -714,115 +709,91 @@ type_t *get_type(env_t *env, ast_t *ast) case Array: { auto array = Match(ast, Array); type_t *item_type = NULL; - if (array->item_type) { - item_type = parse_type_ast(env, array->item_type); - } else if (array->items) { - for (ast_list_t *item = array->items; item; item = item->next) { - ast_t *item_ast = item->ast; - env_t *scope = env; - while (item_ast->tag == Comprehension) { - auto comp = Match(item_ast, Comprehension); - scope = for_scope( - scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); - item_ast = comp->expr; - } - type_t *t2 = get_type(scope, item_ast); - type_t *merged = item_type ? type_or_type(item_type, t2) : t2; - if (!merged) - code_err(item->ast, - "This array item has type ", type_to_str(t2), - ", which is different from earlier array items which have type ", type_to_str(item_type)); - item_type = merged; + for (ast_list_t *item = array->items; item; item = item->next) { + ast_t *item_ast = item->ast; + env_t *scope = env; + while (item_ast->tag == Comprehension) { + auto comp = Match(item_ast, Comprehension); + scope = for_scope( + scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); + item_ast = comp->expr; } - } else { - code_err(ast, "I can't figure out what type this array has because it has no members or explicit type"); + type_t *t2 = get_type(scope, item_ast); + type_t *merged = item_type ? type_or_type(item_type, t2) : t2; + if (!merged) + code_err(item->ast, + "This array item has type ", type_to_str(t2), + ", which is different from earlier array items which have type ", type_to_str(item_type)); + item_type = merged; } - if (has_stack_memory(item_type)) - code_err(ast, "Arrays cannot hold stack references, because the array may outlive the stack frame the reference was created in."); - if (!item_type) - code_err(ast, "I couldn't figure out the item type for this array!"); + if (item_type && has_stack_memory(item_type)) + code_err(ast, "Arrays cannot hold stack references, because the array may outlive the stack frame the reference was created in."); return Type(ArrayType, .item_type=item_type); } case Set: { auto set = Match(ast, Set); type_t *item_type = NULL; - if (set->item_type) { - item_type = parse_type_ast(env, set->item_type); - } else { - for (ast_list_t *item = set->items; item; item = item->next) { - ast_t *item_ast = item->ast; - env_t *scope = env; - while (item_ast->tag == Comprehension) { - auto comp = Match(item_ast, Comprehension); - scope = for_scope( - scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); - item_ast = comp->expr; - } - - type_t *this_item_type = get_type(scope, item_ast); - type_t *item_merged = type_or_type(item_type, this_item_type); - if (!item_merged) - code_err(item_ast, - "This set item has type ", type_to_str(this_item_type), - ", which is different from earlier set items which have type ", type_to_str(item_type)); - item_type = item_merged; + for (ast_list_t *item = set->items; item; item = item->next) { + ast_t *item_ast = item->ast; + env_t *scope = env; + while (item_ast->tag == Comprehension) { + auto comp = Match(item_ast, Comprehension); + scope = for_scope( + scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); + item_ast = comp->expr; } + + type_t *this_item_type = get_type(scope, item_ast); + type_t *item_merged = type_or_type(item_type, this_item_type); + if (!item_merged) + code_err(item_ast, + "This set item has type ", type_to_str(this_item_type), + ", which is different from earlier set items which have type ", type_to_str(item_type)); + item_type = item_merged; } - if (!item_type) - code_err(ast, "I couldn't figure out the item type for this set!"); - - if (has_stack_memory(item_type)) + if (item_type && has_stack_memory(item_type)) code_err(ast, "Sets cannot hold stack references because the set may outlive the reference's stack frame."); + return Type(SetType, .item_type=item_type); } case Table: { auto table = Match(ast, Table); type_t *key_type = NULL, *value_type = NULL; - if (table->key_type && table->value_type) { - key_type = parse_type_ast(env, table->key_type); - value_type = parse_type_ast(env, table->value_type); - } else if (table->key_type && table->default_value) { - key_type = parse_type_ast(env, table->key_type); - value_type = get_type(env, table->default_value); - } else { - for (ast_list_t *entry = table->entries; entry; entry = entry->next) { - ast_t *entry_ast = entry->ast; - env_t *scope = env; - while (entry_ast->tag == Comprehension) { - auto comp = Match(entry_ast, Comprehension); - scope = for_scope( - scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); - entry_ast = comp->expr; - } - - auto e = Match(entry_ast, TableEntry); - type_t *key_t = get_type(scope, e->key); - type_t *value_t = get_type(scope, e->value); - - type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t; - if (!key_merged) - code_err(entry->ast, - "This table entry has type ", type_to_str(key_t), - ", which is different from earlier table entries which have type ", type_to_str(key_type)); - key_type = key_merged; - - type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t; - if (!val_merged) - code_err(entry->ast, - "This table entry has type ", type_to_str(value_t), - ", which is different from earlier table entries which have type ", type_to_str(value_type)); - value_type = val_merged; + for (ast_list_t *entry = table->entries; entry; entry = entry->next) { + ast_t *entry_ast = entry->ast; + env_t *scope = env; + while (entry_ast->tag == Comprehension) { + auto comp = Match(entry_ast, Comprehension); + scope = for_scope( + scope, FakeAST(For, .iter=comp->iter, .vars=comp->vars)); + entry_ast = comp->expr; } + + auto e = Match(entry_ast, TableEntry); + type_t *key_t = get_type(scope, e->key); + type_t *value_t = get_type(scope, e->value); + + type_t *key_merged = key_type ? type_or_type(key_type, key_t) : key_t; + if (!key_merged) + code_err(entry->ast, + "This table entry has type ", type_to_str(key_t), + ", which is different from earlier table entries which have type ", type_to_str(key_type)); + key_type = key_merged; + + type_t *val_merged = value_type ? type_or_type(value_type, value_t) : value_t; + if (!val_merged) + code_err(entry->ast, + "This table entry has type ", type_to_str(value_t), + ", which is different from earlier table entries which have type ", type_to_str(value_type)); + value_type = val_merged; } - if (!key_type || !value_type) - code_err(ast, "I couldn't figure out the key and value types for this table!"); - - if (has_stack_memory(key_type) || has_stack_memory(value_type)) + if ((key_type && has_stack_memory(key_type)) || (value_type && has_stack_memory(value_type))) code_err(ast, "Tables cannot hold stack references because the table may outlive the reference's stack frame."); + return Type(TableType, .key_type=key_type, .value_type=value_type, .default_value=table->default_value, .env=env); } case TableEntry: { @@ -998,7 +969,7 @@ type_t *get_type(env_t *env, ast_t *ast) // Early out if the type is knowable without any context from the block: switch (last->ast->tag) { - case UpdateAssign: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: case Extend: + case UPDATE_CASES: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: case Extend: return Type(VoidType); default: break; } @@ -1022,7 +993,7 @@ type_t *get_type(env_t *env, ast_t *ast) case Extern: { return parse_type_ast(env, Match(ast, Extern)->type); } - case Declare: case Assign: case DocTest: { + case Declare: case Assign: case UPDATE_CASES: case DocTest: { return Type(VoidType); } case Use: { @@ -1078,169 +1049,160 @@ type_t *get_type(env_t *env, ast_t *ast) } code_err(ast, "I only know how to get 'not' of boolean, numeric, and optional pointer types, not ", type_to_str(t)); } - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - type_t *lhs_t = get_type(env, binop->lhs), - *rhs_t = get_type(env, binop->rhs); + case Or: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); - if (lhs_t->tag == BigIntType && rhs_t->tag != BigIntType && is_numeric_type(rhs_t) && binop->lhs->tag == Int) { - lhs_t = rhs_t; - } else if (rhs_t->tag == BigIntType && lhs_t->tag != BigIntType && is_numeric_type(lhs_t) && binop->rhs->tag == Int) { - - rhs_t = lhs_t; - } - -#define binding_works(name, self, lhs_t, rhs_t, ret_t) \ - ({ binding_t *b = get_namespace_binding(env, self, name); \ - (b && b->type->tag == FunctionType && ({ auto fn = Match(b->type, FunctionType); \ - (type_eq(fn->ret, ret_t) \ - && (fn->args && type_eq(fn->args->type, lhs_t)) \ - && (fn->args->next && can_promote(rhs_t, fn->args->next->type))); })); }) - // Check for a binop method like plus() etc: - switch (binop->op) { - case BINOP_MULT: { - if (is_numeric_type(lhs_t) && binding_works("scaled_by", binop->rhs, rhs_t, lhs_t, rhs_t)) - return rhs_t; - else if (is_numeric_type(rhs_t) && binding_works("scaled_by", binop->lhs, lhs_t, rhs_t, lhs_t)) - return lhs_t; - else if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) - return lhs_t; - break; - } - case BINOP_PLUS: case BINOP_MINUS: case BINOP_AND: case BINOP_OR: case BINOP_XOR: case BINOP_CONCAT: { - if (type_eq(lhs_t, rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) - return lhs_t; - break; - } - case BINOP_DIVIDE: case BINOP_MOD: case BINOP_MOD1: { - if (is_numeric_type(rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) - return lhs_t; - break; - } - case BINOP_LSHIFT: case BINOP_RSHIFT: case BINOP_ULSHIFT: case BINOP_URSHIFT: { - return lhs_t; - } - case BINOP_POWER: { - if (is_numeric_type(rhs_t) && binding_works(binop_method_names[binop->op], binop->lhs, lhs_t, rhs_t, lhs_t)) - return lhs_t; - break; - } - default: break; - } -#undef binding_works - - switch (binop->op) { - case BINOP_AND: { - if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) { - return lhs_t; - } else if ((lhs_t->tag == BoolType && rhs_t->tag == OptionalType) || - (lhs_t->tag == OptionalType && rhs_t->tag == BoolType)) { - return Type(BoolType); - } else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) { - return lhs_t; - } else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { - return lhs_t; - } else if (rhs_t->tag == OptionalType) { - if (can_promote(lhs_t, rhs_t)) - return rhs_t; - } else if (lhs_t->tag == PointerType && rhs_t->tag == PointerType) { - auto lhs_ptr = Match(lhs_t, PointerType); - auto rhs_ptr = Match(rhs_t, PointerType); - if (type_eq(lhs_ptr->pointed, rhs_ptr->pointed)) - return Type(PointerType, .pointed=lhs_ptr->pointed); - } else if ((is_int_type(lhs_t) && is_int_type(rhs_t)) - || (lhs_t->tag == ByteType && rhs_t->tag == ByteType)) { - return get_math_type(env, ast, lhs_t, rhs_t); - } - code_err(ast, "I can't figure out the type of this `and` expression between a ", type_to_str(lhs_t), " and a ", type_to_str(rhs_t)); - } - case BINOP_OR: { - if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) { - return lhs_t; - } else if ((lhs_t->tag == BoolType && rhs_t->tag == OptionalType) || - (lhs_t->tag == OptionalType && rhs_t->tag == BoolType)) { - return Type(BoolType); - } else if (lhs_t->tag == BoolType && (rhs_t->tag == AbortType || rhs_t->tag == ReturnType)) { - return lhs_t; - } else if ((is_int_type(lhs_t) && is_int_type(rhs_t)) - || (lhs_t->tag == ByteType && rhs_t->tag == ByteType)) { - return get_math_type(env, ast, lhs_t, rhs_t); - } else if (lhs_t->tag == OptionalType) { - if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) - return Match(lhs_t, OptionalType)->type; - if (can_promote(rhs_t, lhs_t)) - return rhs_t; - } else if (lhs_t->tag == PointerType) { - auto lhs_ptr = Match(lhs_t, PointerType); - if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { - return Type(PointerType, .pointed=lhs_ptr->pointed); - } else if (rhs_t->tag == PointerType) { - auto rhs_ptr = Match(rhs_t, PointerType); - if (type_eq(rhs_ptr->pointed, lhs_ptr->pointed)) - return Type(PointerType, .pointed=lhs_ptr->pointed); - } - } else if (rhs_t->tag == OptionalType) { - return type_or_type(lhs_t, rhs_t); - } - code_err(ast, "I can't figure out the type of this `or` expression between a ", type_to_str(lhs_t), " and a ", type_to_str(rhs_t)); - } - case BINOP_XOR: { - if (lhs_t->tag == BoolType && rhs_t->tag == BoolType) { - return lhs_t; - } else if ((lhs_t->tag == BoolType && rhs_t->tag == OptionalType) || - (lhs_t->tag == OptionalType && rhs_t->tag == BoolType)) { - return Type(BoolType); - } else if ((is_int_type(lhs_t) && is_int_type(rhs_t)) - || (lhs_t->tag == ByteType && rhs_t->tag == ByteType)) { - return get_math_type(env, ast, lhs_t, rhs_t); - } - - code_err(ast, "I can't figure out the type of this `xor` expression between a ", type_to_str(lhs_t), " and a ", type_to_str(rhs_t)); - } - case BINOP_CONCAT: { - if (!type_eq(lhs_t, rhs_t)) - code_err(ast, "The type on the left side of this concatenation doesn't match the right side: ", type_to_str(lhs_t), - " vs. ", type_to_str(rhs_t)); - if (lhs_t->tag == ArrayType || lhs_t->tag == TextType || lhs_t->tag == SetType) - return lhs_t; - - code_err(ast, "Only array/set/text value types support concatenation, not ", type_to_str(lhs_t)); - } - case BINOP_EQ: case BINOP_NE: case BINOP_LT: case BINOP_LE: case BINOP_GT: case BINOP_GE: { - if (!can_promote(lhs_t, rhs_t) && !can_promote(rhs_t, lhs_t)) - code_err(ast, "I can't compare these two different types: ", type_to_str(lhs_t), " vs ", type_to_str(rhs_t)); + // `opt? or (x == y)` / `(x == y) or opt?` is a boolean conditional: + if ((lhs_t->tag == OptionalType && rhs_t->tag == BoolType) + || (lhs_t->tag == BoolType && rhs_t->tag == OptionalType)) { return Type(BoolType); } - case BINOP_CMP: - return Type(IntType, .bits=TYPE_IBITS32); - case BINOP_POWER: { - type_t *result = get_math_type(env, ast, lhs_t, rhs_t); - if (result->tag == NumType) + + if (lhs_t->tag == OptionalType) { + if (rhs_t->tag == OptionalType) { + type_t *result = most_complete_type(lhs_t, rhs_t); + if (result == NULL) + code_err(ast, "I could not determine the type of ", type_to_str(lhs_t), " `or` ", type_to_str(rhs_t)); return result; - return Type(NumType, .bits=TYPE_NBITS64); - } - case BINOP_MULT: case BINOP_DIVIDE: { - type_t *math_type = get_math_type(env, ast, value_type(lhs_t), value_type(rhs_t)); - if (value_type(lhs_t)->tag == NumType || value_type(rhs_t)->tag == NumType) { - if (risks_zero_or_inf(binop->lhs) && risks_zero_or_inf(binop->rhs)) - return Type(OptionalType, math_type); - else - return math_type; + } else if (rhs_t->tag == AbortType || rhs_t->tag == ReturnType) { + return Match(lhs_t, OptionalType)->type; } - return math_type; + type_t *non_opt = Match(lhs_t, OptionalType)->type; + non_opt = most_complete_type(non_opt, rhs_t); + if (non_opt != NULL) + return non_opt; + } else if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) + && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) + && lhs_t->tag != NumType && rhs_t->tag != NumType) { + if (can_promote(rhs_t, lhs_t)) + return lhs_t; + else if (can_promote(lhs_t, rhs_t)) + return rhs_t; + } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) { + return lhs_t; } - default: { - return get_math_type(env, ast, lhs_t, rhs_t); + code_err(ast, "I couldn't figure out how to do `or` between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case And: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + + // `and` between optionals/bools is a boolean expression like `if opt? and opt?:` or `if x > 0 and opt?:` + if ((lhs_t->tag == OptionalType || lhs_t->tag == BoolType) + && (rhs_t->tag == OptionalType || rhs_t->tag == BoolType)) { + return Type(BoolType); } + + // Bitwise AND: + if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) + && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) + && lhs_t->tag != NumType && rhs_t->tag != NumType) { + if (can_promote(rhs_t, lhs_t)) + return lhs_t; + else if (can_promote(lhs_t, rhs_t)) + return rhs_t; + } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) { + return lhs_t; } + code_err(ast, "I couldn't figure out how to do `and` between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Xor: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + + // `xor` between optionals/bools is a boolean expression like `if opt? xor opt?:` or `if x > 0 xor opt?:` + if ((lhs_t->tag == OptionalType || lhs_t->tag == BoolType) + && (rhs_t->tag == OptionalType || rhs_t->tag == BoolType)) { + return Type(BoolType); + } + + // Bitwise XOR: + if ((is_numeric_type(lhs_t) || lhs_t->tag == BoolType) + && (is_numeric_type(rhs_t) || rhs_t->tag == BoolType) + && lhs_t->tag != NumType && rhs_t->tag != NumType) { + if (can_promote(rhs_t, lhs_t)) + return lhs_t; + else if (can_promote(lhs_t, rhs_t)) + return rhs_t; + } else if (lhs_t->tag == SetType && rhs_t->tag == SetType && type_eq(lhs_t, rhs_t)) { + return lhs_t; + } + code_err(ast, "I couldn't figure out how to do `xor` between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Compare: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + + if (can_promote(rhs_t, lhs_t) || can_promote(lhs_t, rhs_t)) + return Type(IntType, .bits=TYPE_IBITS32); + + code_err(ast, "I don't know how to compare ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Equals: case NotEquals: case LessThan: case LessThanOrEquals: case GreaterThan: case GreaterThanOrEquals: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + if (can_promote(rhs_t, lhs_t) || can_promote(lhs_t, rhs_t)) + return Type(BoolType); + + code_err(ast, "I don't know how to compare ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Power: case Multiply: case Divide: case Mod: case Mod1: case Plus: case Minus: case LeftShift: + case UnsignedLeftShift: case RightShift: case UnsignedRightShift: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + + if (ast->tag == LeftShift || ast->tag == UnsignedLeftShift || ast->tag == RightShift || ast->tag == UnsignedRightShift) { + if (!is_int_type(rhs_t)) + code_err(binop.rhs, "I only know how to do bit shifting by integer amounts, not ", type_to_str(rhs_t)); + } + + type_t *overall_t = (can_promote(rhs_t, lhs_t) ? lhs_t : (can_promote(lhs_t, rhs_t) ? rhs_t : NULL)); + if (ast->tag == Multiply || ast->tag == Divide) { + binding_t *b = is_numeric_type(lhs_t) ? get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, lhs_t) + : get_metamethod_binding(env, ast->tag, binop.rhs, binop.lhs, rhs_t); + if (b) return overall_t; + } else { + if (overall_t == NULL) + code_err(ast, "I don't know how to do math operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + + binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, overall_t); + if (b) return overall_t; + } + if (is_numeric_type(lhs_t) && is_numeric_type(rhs_t)) + return overall_t; + code_err(ast, "I don't know how to do math operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + } + case Concat: { + binary_operands_t binop = BINARY_OPERANDS(ast); + type_t *lhs_t = get_type(env, binop.lhs); + type_t *rhs_t = get_type(env, binop.rhs); + + type_t *overall_t = (can_promote(rhs_t, lhs_t) ? lhs_t : (can_promote(lhs_t, rhs_t) ? rhs_t : NULL)); + if (overall_t == NULL) + code_err(ast, "I don't know how to do operations between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); + + binding_t *b = get_metamethod_binding(env, ast->tag, binop.lhs, binop.rhs, overall_t); + if (b) return overall_t; + + if (overall_t->tag == ArrayType || overall_t->tag == SetType || overall_t->tag == TextType) + return overall_t; + + code_err(ast, "I don't know how to do concatenation between ", type_to_str(lhs_t), " and ", type_to_str(rhs_t)); } case Reduction: { auto reduction = Match(ast, Reduction); type_t *iter_t = get_type(env, reduction->iter); - if (reduction->op == BINOP_EQ || reduction->op == BINOP_NE || reduction->op == BINOP_LT - || reduction->op == BINOP_LE || reduction->op == BINOP_GT || reduction->op == BINOP_GE) + if (reduction->op == Equals || reduction->op == NotEquals || reduction->op == LessThan + || reduction->op == LessThanOrEquals || reduction->op == GreaterThan || reduction->op == GreaterThanOrEquals) return Type(OptionalType, .type=Type(BoolType)); type_t *iterated = get_iterated_type(iter_t); @@ -1249,9 +1211,6 @@ type_t *get_type(env_t *env, ast_t *ast) return iterated->tag == OptionalType ? iterated : Type(OptionalType, .type=iterated); } - case UpdateAssign: - return Type(VoidType); - case Min: case Max: { // Unsafe! These types *should* have the same fields and this saves a lot of duplicate code: ast_t *lhs = ast->__data.Min.lhs, *rhs = ast->__data.Min.rhs; @@ -1310,8 +1269,9 @@ type_t *get_type(env_t *env, ast_t *ast) env_t *truthy_scope = env; env_t *falsey_scope = env; if (if_->condition->tag == Declare) { - type_t *condition_type = get_type(env, Match(if_->condition, Declare)->value); - const char *varname = Match(Match(if_->condition, Declare)->var, Var)->name; + auto decl = Match(if_->condition, Declare); + type_t *condition_type = decl->type ? parse_type_ast(env, decl->type) : get_type(env, decl->value); + const char *varname = Match(decl->var, Var)->name; if (streq(varname, "_")) code_err(if_->condition, "To use `if var := ...:`, you must choose a real variable name, not `_`"); @@ -1456,7 +1416,7 @@ type_t *get_type(env_t *env, ast_t *ast) PUREFUNC bool is_discardable(env_t *env, ast_t *ast) { switch (ast->tag) { - case UpdateAssign: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: + case UPDATE_CASES: case Assign: case Declare: case FunctionDef: case ConvertDef: case StructDef: case EnumDef: case LangDef: case Use: case Extend: return true; default: break; @@ -1610,13 +1570,13 @@ PUREFUNC bool is_constant(env_t *env, ast_t *ast) } case Not: return is_constant(env, Match(ast, Not)->value); case Negative: return is_constant(env, Match(ast, Negative)->value); - case BinaryOp: { - auto binop = Match(ast, BinaryOp); - switch (binop->op) { - case BINOP_UNKNOWN: case BINOP_POWER: case BINOP_CONCAT: case BINOP_MIN: case BINOP_MAX: case BINOP_CMP: + case BINOP_CASES: { + binary_operands_t binop = BINARY_OPERANDS(ast); + switch (ast->tag) { + case Power: case Concat: case Min: case Max: case Compare: return false; default: - return is_constant(env, binop->lhs) && is_constant(env, binop->rhs); + return is_constant(env, binop.lhs) && is_constant(env, binop.rhs); } } case Use: return true; @@ -1626,4 +1586,49 @@ PUREFUNC bool is_constant(env_t *env, ast_t *ast) } } +PUREFUNC bool can_compile_to_type(env_t *env, ast_t *ast, type_t *needed) +{ + if (needed->tag == OptionalType && ast->tag == None) { + return true; + } + + needed = non_optional(needed); + if (needed->tag == ArrayType && ast->tag == Array) { + type_t *item_type = Match(needed, ArrayType)->item_type; + for (ast_list_t *item = Match(ast, Array)->items; item; item = item->next) { + if (!can_compile_to_type(env, item->ast, item_type)) + return false; + } + return true; + } else if (needed->tag == SetType && ast->tag == Set) { + type_t *item_type = Match(needed, SetType)->item_type; + for (ast_list_t *item = Match(ast, Set)->items; item; item = item->next) { + if (!can_compile_to_type(env, item->ast, item_type)) + return false; + } + return true; + } else if (needed->tag == TableType && ast->tag == Table) { + type_t *key_type = Match(needed, TableType)->key_type; + type_t *value_type = Match(needed, TableType)->value_type; + for (ast_list_t *entry = Match(ast, Table)->entries; entry; entry = entry->next) { + if (entry->ast->tag != TableEntry) + continue; // TODO: fix this + auto e = Match(entry->ast, TableEntry); + if (!can_compile_to_type(env, e->key, key_type) || !can_compile_to_type(env, e->value, value_type)) + return false; + } + return true; + } else if (needed->tag == PointerType) { + auto ptr = Match(needed, PointerType); + if (ast->tag == HeapAllocate) + return !ptr->is_stack && can_compile_to_type(env, Match(ast, HeapAllocate)->value, ptr->pointed); + else if (ast->tag == StackReference) + return ptr->is_stack && can_compile_to_type(env, Match(ast, StackReference)->value, ptr->pointed); + else + return can_promote(needed, get_type(env, ast)); + } else { + return can_promote(needed, get_type(env, ast)); + } +} + // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/src/typecheck.h b/src/typecheck.h index cc5cb18..4342acc 100644 --- a/src/typecheck.h +++ b/src/typecheck.h @@ -29,5 +29,6 @@ type_t *get_method_type(env_t *env, ast_t *self, const char *name); PUREFUNC bool is_constant(env_t *env, ast_t *ast); Table_t *get_arg_bindings(env_t *env, arg_t *spec_args, arg_ast_t *call_args, bool promotion_allowed); bool is_valid_call(env_t *env, arg_t *spec_args, arg_ast_t *call_args, bool promotion_allowed); +PUREFUNC bool can_compile_to_type(env_t *env, ast_t *ast, type_t *needed); // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/src/types.c b/src/types.c index d8ff377..0b9bc72 100644 --- a/src/types.c +++ b/src/types.c @@ -360,6 +360,16 @@ PUREFUNC bool can_promote(type_t *actual, type_t *needed) return true; } + // Empty literals: + if (actual->tag == ArrayType && needed->tag == ArrayType && Match(actual, ArrayType)->item_type == NULL) + return true; // [] -> [T] + if (actual->tag == SetType && needed->tag == SetType && Match(actual, SetType)->item_type == NULL) + return true; // {/} -> {T} + if (actual->tag == TableType && needed->tag == SetType && Match(actual, TableType)->key_type == NULL && Match(actual, TableType)->value_type == NULL) + return true; // {} -> {T} + if (actual->tag == TableType && needed->tag == TableType && Match(actual, TableType)->key_type == NULL && Match(actual, TableType)->value_type == NULL) + return true; // {} -> {K=V} + // Cross-promotion between tables with default values and without if (needed->tag == TableType && actual->tag == TableType) { auto actual_table = Match(actual, TableType); @@ -708,4 +718,103 @@ PUREFUNC type_t *get_iterated_type(type_t *t) } } +CONSTFUNC bool is_incomplete_type(type_t *t) +{ + if (t == NULL) return true; + switch (t->tag) { + case ReturnType: return is_incomplete_type(Match(t, ReturnType)->ret); + case OptionalType: return is_incomplete_type(Match(t, OptionalType)->type); + case ArrayType: return is_incomplete_type(Match(t, ArrayType)->item_type); + case SetType: return is_incomplete_type(Match(t, SetType)->item_type); + case TableType: { + auto table = Match(t, TableType); + return is_incomplete_type(table->key_type) || is_incomplete_type(table->value_type); + } + case FunctionType: { + auto fn = Match(t, FunctionType); + for (arg_t *arg = fn->args; arg; arg = arg->next) { + if (arg->type == NULL || is_incomplete_type(arg->type)) + return true; + } + return fn->ret ? is_incomplete_type(fn->ret) : false; + } + case ClosureType: return is_incomplete_type(Match(t, ClosureType)->fn); + case PointerType: return is_incomplete_type(Match(t, PointerType)->pointed); + default: return false; + } +} + +CONSTFUNC type_t *most_complete_type(type_t *t1, type_t *t2) +{ + if (!t1) return t2; + if (!t2) return t1; + + if (is_incomplete_type(t1) && is_incomplete_type(t2)) + return NULL; + else if (!is_incomplete_type(t1) && !is_incomplete_type(t2) && type_eq(t1, t2)) + return t1; + + if (t1->tag != t2->tag) + return NULL; + + switch (t1->tag) { + case ReturnType: { + type_t *ret = most_complete_type(Match(t1, ReturnType)->ret, Match(t1, ReturnType)->ret); + return ret ? Type(ReturnType, ret) : NULL; + } + case OptionalType: { + type_t *opt = most_complete_type(Match(t1, OptionalType)->type, Match(t2, OptionalType)->type); + return opt ? Type(OptionalType, opt) : NULL; + } + case ArrayType: { + type_t *item = most_complete_type(Match(t1, ArrayType)->item_type, Match(t2, ArrayType)->item_type); + return item ? Type(ArrayType, item) : NULL; + } + case SetType: { + type_t *item = most_complete_type(Match(t1, SetType)->item_type, Match(t2, SetType)->item_type); + return item ? Type(SetType, item) : NULL; + } + case TableType: { + auto table1 = Match(t1, TableType); + auto table2 = Match(t2, TableType); + type_t *key = most_complete_type(table1->key_type, table2->key_type); + type_t *value = most_complete_type(table1->value_type, table2->value_type); + return (key && value) ? Type(TableType, key, value) : NULL; + } + case FunctionType: { + auto fn1 = Match(t1, FunctionType); + auto fn2 = Match(t2, FunctionType); + arg_t *args = NULL; + for (arg_t *arg1 = fn1->args, *arg2 = fn2->args; arg1 || arg2; arg1 = arg1->next, arg2 = arg2->next) { + if (!arg1 || !arg2) + return NULL; + + type_t *arg_type = most_complete_type(arg1->type, arg2->type); + if (!arg_type) return NULL; + args = new(arg_t, .type=arg_type, .next=args); + } + REVERSE_LIST(args); + type_t *ret = most_complete_type(fn1->ret, fn2->ret); + return ret ? Type(FunctionType, .args=args, .ret=ret) : NULL; + } + case ClosureType: { + type_t *fn = most_complete_type(Match(t1, ClosureType)->fn, Match(t1, ClosureType)->fn); + return fn ? Type(ClosureType, fn) : NULL; + } + case PointerType: { + auto ptr1 = Match(t1, PointerType); + auto ptr2 = Match(t2, PointerType); + if (ptr1->is_stack != ptr2->is_stack) + return NULL; + type_t *pointed = most_complete_type(ptr1->pointed, ptr2->pointed); + return pointed ? Type(PointerType, .is_stack=ptr1->is_stack, .pointed=pointed) : NULL; + } + default: { + if (is_incomplete_type(t1) || is_incomplete_type(t2)) + return NULL; + return type_eq(t1, t2) ? t1 : NULL; + } + } +} + // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 diff --git a/src/types.h b/src/types.h index a5b2ad0..5348858 100644 --- a/src/types.h +++ b/src/types.h @@ -147,6 +147,8 @@ PUREFUNC const char *enum_single_value_tag(type_t *enum_type, type_t *t); PUREFUNC bool is_int_type(type_t *t); PUREFUNC bool is_numeric_type(type_t *t); PUREFUNC bool is_packed_data(type_t *t); +CONSTFUNC bool is_incomplete_type(type_t *t); +CONSTFUNC type_t *most_complete_type(type_t *t1, type_t *t2); PUREFUNC size_t type_size(type_t *t); PUREFUNC size_t type_align(type_t *t); PUREFUNC size_t unpadded_struct_size(type_t *t); diff --git a/test/arrays.tm b/test/arrays.tm index 66819a7..5816c35 100644 --- a/test/arrays.tm +++ b/test/arrays.tm @@ -1,7 +1,7 @@ func main(): do: - >> [:Num32] - = [:Num32] + >> nums : [Num32] = [] + = [] do: >> arr := [10, 20, 30] @@ -104,7 +104,7 @@ func main(): >> heap := @[(i * 1337) mod 37 for i in 10] >> heap:heapify() >> heap - sorted := @[:Int] + sorted : @[Int] = @[] repeat: sorted:insert(heap:heap_pop() or stop) >> sorted == sorted:sorted() @@ -112,7 +112,7 @@ func main(): for i in 10: heap:heap_push((i*13337) mod 37) >> heap - sorted = @[:Int] + sorted = @[] repeat: sorted:insert(heap:heap_pop() or stop) >> sorted == sorted:sorted() @@ -181,6 +181,6 @@ func main(): = &[10, 30, 40] >> nums:clear() >> nums - = &[:Int] + = &[] >> nums:pop() = none:Int diff --git a/test/defer.tm b/test/defer.tm index 911ed67..6657bdc 100644 --- a/test/defer.tm +++ b/test/defer.tm @@ -1,6 +1,6 @@ func main(): x := 123 - nums := @[:Int] + nums : @[Int] = @[] do: defer: nums:insert(x) diff --git a/test/for.tm b/test/for.tm index a67e9d5..e4967e8 100644 --- a/test/for.tm +++ b/test/for.tm @@ -32,18 +32,18 @@ func table_key_str(t:{Text=Text} -> Text): func main(): >> all_nums([10,20,30]) = "10,20,30," - >> all_nums([:Int]) + >> all_nums([]) = "EMPTY" >> labeled_nums([10,20,30]) = "1:10,2:20,3:30," - >> labeled_nums([:Int]) + >> labeled_nums([]) = "EMPTY" >> t := {"key1"="value1", "key2"="value2"} >> table_str(t) = "key1:value1,key2:value2," - >> table_str({:Text=Text}) + >> table_str({}) = "EMPTY" >> table_key_str(t) diff --git a/test/import.tm b/test/import.tm index a7b198b..960bfcb 100644 --- a/test/import.tm +++ b/test/import.tm @@ -8,11 +8,11 @@ func returns_imported_type(->ImportedType): return get_value() # Imported from ./use_import.tm func main(): - >> [:vectors.Vec2] + >> empty : [vectors.Vec2] = [] >> returns_vec() = Vec2(x=1, y=2) - >> [:ImportedType] + >> imported : [ImportedType] = [] >> returns_imported_type() = ImportedType("Hello") diff --git a/test/iterators.tm b/test/iterators.tm index 4a85e6f..0b6c2a8 100644 --- a/test/iterators.tm +++ b/test/iterators.tm @@ -25,7 +25,7 @@ func main(): = ["AB", "BC", "CD"] do: - result := @[:Text] + result : @[Text] = @[] for foo in pairwise(values): result:insert("$(foo.x)$(foo.y)") >> result[] diff --git a/test/optionals.tm b/test/optionals.tm index 0281744..a1b0dcd 100644 --- a/test/optionals.tm +++ b/test/optionals.tm @@ -247,8 +247,8 @@ func main(): = yes >> {none:Int, none:Int} = {none:Int} - >> {:Int? none, none} - = {none:Int} + >> nones : {Int?} = {none, none} + = {none} >> [5?, none:Int, none:Int, 6?]:sorted() = [none:Int, none:Int, 5, 6] diff --git a/test/paths.tm b/test/paths.tm index f224b0f..4f05348 100644 --- a/test/paths.tm +++ b/test/paths.tm @@ -25,7 +25,7 @@ func main(): >> tmpfile:read() = "Hello world!"? >> tmpfile:read_bytes() - = [:Byte, 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x20, 0x77, 0x6F, 0x72, 0x6C, 0x64, 0x21]? + = [0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x20, 0x77, 0x6F, 0x72, 0x6C, 0x64, 0x21]? >> tmpdir:files():has(tmpfile) = yes diff --git a/test/reductions.tm b/test/reductions.tm index 840a3b3..4f33bb2 100644 --- a/test/reductions.tm +++ b/test/reductions.tm @@ -4,13 +4,14 @@ func main(): >> (+: [10, 20, 30]) = 60? - >> (+: [:Int]) + >> empty_ints : [Int] = [] + >> (+: empty_ints) = none : Int >> (+: [10, 20, 30]) or 0 = 60 - >> (+: [:Int]) or 0 + >> (+: empty_ints) or 0 = 0 >> (_max_: [3, 5, 2, 1, 4]) @@ -36,7 +37,7 @@ func main(): >> (<=: [1, 2, 2, 3, 4])! = yes - >> (<=: [:Int]) + >> (<=: empty_ints) = none : Bool >> (<=: [5, 4, 3, 2, 1])! diff --git a/test/tables.tm b/test/tables.tm index 140fd1c..144b93e 100644 --- a/test/tables.tm +++ b/test/tables.tm @@ -96,11 +96,11 @@ func main(): >> {1=1, 2=2} <> {2=2, 1=1} = Int32(0) - >> [{:Int=Int}, {0=0}, {99=99}, {1=1, 2=2, 3=3}, {1=1, 99=99, 3=3}, {1=1, 2=-99, 3=3}, {1=1, 99=-99, 3=4}]:sorted() - = [{:Int=Int}, {0=0}, {1=1, 2=-99, 3=3}, {1=1, 2=2, 3=3}, {1=1, 99=99, 3=3}, {1=1, 99=-99, 3=4}, {99=99}] + >> ints : [{Int=Int}] = [{}, {0=0}, {99=99}, {1=1, 2=2, 3=3}, {1=1, 99=99, 3=3}, {1=1, 2=-99, 3=3}, {1=1, 99=-99, 3=4}]:sorted() + = [{}, {0=0}, {1=1, 2=-99, 3=3}, {1=1, 2=2, 3=3}, {1=1, 99=99, 3=3}, {1=1, 99=-99, 3=4}, {99=99}] - >> [{:Int}, {1}, {2}, {99}, {0, 3}, {1, 2}, {99}]:sorted() - = [{:Int}, {0, 3}, {1}, {1, 2}, {2}, {99}, {99}] + >> other_ints : [{Int}] = [{}, {1}, {2}, {99}, {0, 3}, {1, 2}, {99}]:sorted() + = [{}, {0, 3}, {1}, {1, 2}, {2}, {99}, {99}] do: # Default values: diff --git a/test/text.tm b/test/text.tm index fe295f9..ae91050 100644 --- a/test/text.tm +++ b/test/text.tm @@ -53,10 +53,10 @@ func main(): >> amelie:split() = ["A", "m", "é", "l", "i", "e"] >> amelie:utf32_codepoints() - = [:Int32, 65, 109, 233, 108, 105, 101] + = [65, 109, 233, 108, 105, 101] >> amelie:bytes() - = [:Byte, 0x41, 0x6D, 0xC3, 0xA9, 0x6C, 0x69, 0x65] - >> Text.from_bytes([:Byte 0x41, 0x6D, 0xC3, 0xA9, 0x6C, 0x69, 0x65])! + = [0x41, 0x6D, 0xC3, 0xA9, 0x6C, 0x69, 0x65] + >> Text.from_bytes([0x6D, 0xC3, 0xA9, 0x6C, 0x69, 0x65])! = "Amélie" >> Text.from_bytes([Byte(0xFF)]) = none:Text @@ -65,9 +65,9 @@ func main(): >> amelie2:split() = ["A", "m", "é", "l", "i", "e"] >> amelie2:utf32_codepoints() - = [:Int32, 65, 109, 233, 108, 105, 101] + = [65, 109, 233, 108, 105, 101] >> amelie2:bytes() - = [:Byte, 0x41, 0x6D, 0xC3, 0xA9, 0x6C, 0x69, 0x65] + = [0x41, 0x6D, 0xC3, 0xA9, 0x6C, 0x69, 0x65] >> amelie:codepoint_names() = ["LATIN CAPITAL LETTER A", "LATIN SMALL LETTER M", "LATIN SMALL LETTER E WITH ACUTE", "LATIN SMALL LETTER L", "LATIN SMALL LETTER I", "LATIN SMALL LETTER E"] @@ -136,7 +136,7 @@ func main(): >> "one$(\r\n)two$(\r\n)three$(\r\n)":lines() = ["one", "two", "three"] >> "":lines() - = [:Text] + = [] !! Test splitting and joining text: >> "one,, two,three":split(",") @@ -171,11 +171,11 @@ func main(): >> "+":join(["one"]) = "one" - >> "+":join([:Text]) + >> "+":join([]) = "" >> "":split() - = [:Text] + = [] !! Test text slicing: >> "abcdef":slice() @@ -196,7 +196,7 @@ func main(): >> house:codepoint_names() = ["CJK Unified Ideographs-5BB6"] >> house:utf32_codepoints() - = [:Int32, 23478] + = [23478] >> "🐧":codepoint_names() = ["PENGUIN"]