diff options
| -rw-r--r-- | ast.c | 2 | ||||
| -rw-r--r-- | ast.h | 1 | ||||
| -rw-r--r-- | builtins/channel.c | 20 | ||||
| -rw-r--r-- | builtins/channel.h | 2 | ||||
| -rw-r--r-- | builtins/datatypes.h | 1 | ||||
| -rw-r--r-- | compile.c | 12 | ||||
| -rw-r--r-- | parse.c | 11 | ||||
| -rw-r--r-- | test/threads.tm | 20 |
8 files changed, 59 insertions, 10 deletions
@@ -121,7 +121,7 @@ CORD ast_to_xml(ast_t *ast) optional_tagged_type("key-type", data.key_type), optional_tagged_type("value-type", data.value_type), ast_list_to_xml(data.entries), optional_tagged("fallback", data.fallback)) T(TableEntry, "<TableEntry>%r%r</TableEntry>", ast_to_xml(data.key), ast_to_xml(data.value)) - T(Channel, "<Channel>%r</Channel>", type_ast_to_xml(data.item_type)) + T(Channel, "<Channel>%r%r</Channel>", type_ast_to_xml(data.item_type), optional_tagged("max-size", data.max_size)) T(Comprehension, "<Comprehension>%r%r%r%r%r</Comprehension>", optional_tagged("expr", data.expr), ast_list_to_xml(data.vars), optional_tagged("iter", data.iter), optional_tagged("filter", data.filter)) @@ -184,6 +184,7 @@ struct ast_s { } Array; struct { type_ast_t *item_type; + ast_t *max_size; } Channel; struct { type_ast_t *item_type; diff --git a/builtins/channel.c b/builtins/channel.c index 0b5f7411..cfb398b0 100644 --- a/builtins/channel.c +++ b/builtins/channel.c @@ -17,18 +17,23 @@ #include "types.h" #include "util.h" -public channel_t *Channel$new(void) +public channel_t *Channel$new(int64_t max_size) { + if (max_size <= 0) + fail("Cannot create a channel with a size less than one: %ld", max_size); channel_t *channel = new(channel_t); channel->items = (array_t){}; channel->mutex = (pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER; channel->cond = (pthread_cond_t)PTHREAD_COND_INITIALIZER; + channel->max_size = max_size; return channel; } public void Channel$push(channel_t *channel, const void *item, int64_t padded_item_size) { (void)pthread_mutex_lock(&channel->mutex); + while (channel->items.length >= channel->max_size) + pthread_cond_wait(&channel->cond, &channel->mutex); Array$insert(&channel->items, item, 0, padded_item_size); (void)pthread_mutex_unlock(&channel->mutex); (void)pthread_cond_signal(&channel->cond); @@ -36,8 +41,17 @@ public void Channel$push(channel_t *channel, const void *item, int64_t padded_it public void Channel$push_all(channel_t *channel, array_t to_push, int64_t padded_item_size) { + if (to_push.length == 0) return; (void)pthread_mutex_lock(&channel->mutex); - Array$insert_all(&channel->items, to_push, 0, padded_item_size); + if (channel->items.length + to_push.length >= channel->max_size) { + for (int64_t i = 0; i < to_push.length; i++) { + while (channel->items.length >= channel->max_size) + pthread_cond_wait(&channel->cond, &channel->mutex); + Array$insert(&channel->items, to_push.data + i*to_push.stride, 0, padded_item_size); + } + } else { + Array$insert_all(&channel->items, to_push, 0, padded_item_size); + } (void)pthread_mutex_unlock(&channel->mutex); (void)pthread_cond_signal(&channel->cond); } @@ -50,6 +64,7 @@ public void Channel$pop(channel_t *channel, void *out, int64_t item_size, int64_ memcpy(out, channel->items.data, item_size); Array$remove(&channel->items, 1, 1, padded_item_size); (void)pthread_mutex_unlock(&channel->mutex); + (void)pthread_cond_signal(&channel->cond); } public array_t Channel$view(channel_t *channel) @@ -66,6 +81,7 @@ public void Channel$clear(channel_t *channel) (void)pthread_mutex_lock(&channel->mutex); Array$clear(&channel->items); (void)pthread_mutex_unlock(&channel->mutex); + (void)pthread_cond_signal(&channel->cond); } public uint32_t Channel$hash(const channel_t **channel, const TypeInfo *type) diff --git a/builtins/channel.h b/builtins/channel.h index 2ce71833..6f628726 100644 --- a/builtins/channel.h +++ b/builtins/channel.h @@ -9,7 +9,7 @@ #include "types.h" #include "util.h" -channel_t *Channel$new(void); +channel_t *Channel$new(int64_t max_size); void Channel$push(channel_t *channel, const void *item, int64_t padded_item_size); #define Channel$push_value(channel, item, padded_item_size) ({ __typeof(item) _item = item; Channel$push(channel, &_item, padded_item_size); }) void Channel$push_all(channel_t *channel, array_t to_push, int64_t padded_item_size); diff --git a/builtins/datatypes.h b/builtins/datatypes.h index 41a67a5e..a9f28dc1 100644 --- a/builtins/datatypes.h +++ b/builtins/datatypes.h @@ -62,6 +62,7 @@ typedef struct { array_t items; pthread_mutex_t mutex; pthread_cond_t cond; + int64_t max_size; } channel_t; // vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0 @@ -1747,10 +1747,18 @@ CORD compile(env_t *env, ast_t *ast) } } case Channel: { - type_t *item_t = parse_type_ast(env, Match(ast, Channel)->item_type); + auto chan = Match(ast, Channel); + type_t *item_t = parse_type_ast(env, chan->item_type); if (!can_send_over_channel(item_t)) code_err(ast, "This item type can't be sent over a channel because it contains reference to memory that may not be thread-safe."); - return "Channel$new()"; + if (chan->max_size) { + CORD max_size = compile(env, chan->max_size); + if (!promote(env, &max_size, get_type(env, chan->max_size), INT_TYPE)) + code_err(chan->max_size, "This value must be an integer, not %T", get_type(env, chan->max_size)); + return CORD_all("Channel$new(", max_size, ")"); + } else { + return "Channel$new(INT64_MAX)"; + } } case Table: { auto table = Match(ast, Table); @@ -680,8 +680,15 @@ PARSER(parse_channel) { const char *start = pos; if (!match(&pos, "|:")) return NULL; type_ast_t *item_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a type for this channel"); + ast_t *max_size = NULL; + if (match(&pos, ";")) { + whitespace(&pos); + const char *attr_start = pos; + if (match_word(&pos, "max_size") && match(&pos, "=")) + max_size = expect(ctx, attr_start, &pos, parse_int, "I expected a maximum size for this channel"); + } expect_closing(ctx, &pos, "|", "I wasn't able to parse the rest of this channel"); - return NewAST(ctx->file, start, pos, Channel, .item_type=item_type); + return NewAST(ctx->file, start, pos, Channel, .item_type=item_type, .max_size=max_size); } PARSER(parse_array) { @@ -770,7 +777,7 @@ PARSER(parse_table) { for (;;) { whitespace(&pos); const char *attr_start = pos; - if (match(&pos, "fallback")) { + if (match_word(&pos, "fallback")) { whitespace(&pos); if (!match(&pos, "=")) parser_err(ctx, attr_start, pos, "I expected an '=' after 'fallback'"); if (fallback) diff --git a/test/threads.tm b/test/threads.tm index 33685103..7024460c 100644 --- a/test/threads.tm +++ b/test/threads.tm @@ -1,8 +1,8 @@ enum Job(Increment(x:Int), Decrement(x:Int)) func main(): - jobs := |:Job| - results := |:Int| + jobs := |:Job; max_size=1| + results := |:Int; max_size=2| >> thread := Thread.new(func(): //! In another thread! while yes: @@ -14,6 +14,11 @@ func main(): >> jobs:push(Increment(5)) >> jobs:push(Decrement(100)) + >> jobs:push(Decrement(100)) + >> jobs:push(Decrement(100)) + >> jobs:push(Decrement(100)) + >> jobs:push(Decrement(100)) + >> jobs:push(Decrement(100)) >> results:pop() = 6 @@ -23,6 +28,17 @@ func main(): = 99 >> results:pop() + = 99 + >> results:pop() + = 99 + >> results:pop() + = 99 + >> results:pop() + = 99 + >> results:pop() + = 99 + + >> results:pop() = 1001 //! Canceling... |
