aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruce Hill <bruce@bruce-hill.com>2024-08-11 15:04:22 -0400
committerBruce Hill <bruce@bruce-hill.com>2024-08-11 15:04:22 -0400
commitd2f4d07585d1e915365f3aaea6fc696e00a9e26d (patch)
tree12cf2a0f978835bc55db572df8e0d114c9b89494
parent2ecb5fe885042ca6c25ee0a3e3da070ddec9e07e (diff)
Support channels with maximum size
-rw-r--r--ast.c2
-rw-r--r--ast.h1
-rw-r--r--builtins/channel.c20
-rw-r--r--builtins/channel.h2
-rw-r--r--builtins/datatypes.h1
-rw-r--r--compile.c12
-rw-r--r--parse.c11
-rw-r--r--test/threads.tm20
8 files changed, 59 insertions, 10 deletions
diff --git a/ast.c b/ast.c
index 000c5f18..d3b5b3c4 100644
--- a/ast.c
+++ b/ast.c
@@ -121,7 +121,7 @@ CORD ast_to_xml(ast_t *ast)
optional_tagged_type("key-type", data.key_type), optional_tagged_type("value-type", data.value_type),
ast_list_to_xml(data.entries), optional_tagged("fallback", data.fallback))
T(TableEntry, "<TableEntry>%r%r</TableEntry>", ast_to_xml(data.key), ast_to_xml(data.value))
- T(Channel, "<Channel>%r</Channel>", type_ast_to_xml(data.item_type))
+ T(Channel, "<Channel>%r%r</Channel>", type_ast_to_xml(data.item_type), optional_tagged("max-size", data.max_size))
T(Comprehension, "<Comprehension>%r%r%r%r%r</Comprehension>", optional_tagged("expr", data.expr),
ast_list_to_xml(data.vars), optional_tagged("iter", data.iter),
optional_tagged("filter", data.filter))
diff --git a/ast.h b/ast.h
index 15c41d81..b7f715b8 100644
--- a/ast.h
+++ b/ast.h
@@ -184,6 +184,7 @@ struct ast_s {
} Array;
struct {
type_ast_t *item_type;
+ ast_t *max_size;
} Channel;
struct {
type_ast_t *item_type;
diff --git a/builtins/channel.c b/builtins/channel.c
index 0b5f7411..cfb398b0 100644
--- a/builtins/channel.c
+++ b/builtins/channel.c
@@ -17,18 +17,23 @@
#include "types.h"
#include "util.h"
-public channel_t *Channel$new(void)
+public channel_t *Channel$new(int64_t max_size)
{
+ if (max_size <= 0)
+ fail("Cannot create a channel with a size less than one: %ld", max_size);
channel_t *channel = new(channel_t);
channel->items = (array_t){};
channel->mutex = (pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER;
channel->cond = (pthread_cond_t)PTHREAD_COND_INITIALIZER;
+ channel->max_size = max_size;
return channel;
}
public void Channel$push(channel_t *channel, const void *item, int64_t padded_item_size)
{
(void)pthread_mutex_lock(&channel->mutex);
+ while (channel->items.length >= channel->max_size)
+ pthread_cond_wait(&channel->cond, &channel->mutex);
Array$insert(&channel->items, item, 0, padded_item_size);
(void)pthread_mutex_unlock(&channel->mutex);
(void)pthread_cond_signal(&channel->cond);
@@ -36,8 +41,17 @@ public void Channel$push(channel_t *channel, const void *item, int64_t padded_it
public void Channel$push_all(channel_t *channel, array_t to_push, int64_t padded_item_size)
{
+ if (to_push.length == 0) return;
(void)pthread_mutex_lock(&channel->mutex);
- Array$insert_all(&channel->items, to_push, 0, padded_item_size);
+ if (channel->items.length + to_push.length >= channel->max_size) {
+ for (int64_t i = 0; i < to_push.length; i++) {
+ while (channel->items.length >= channel->max_size)
+ pthread_cond_wait(&channel->cond, &channel->mutex);
+ Array$insert(&channel->items, to_push.data + i*to_push.stride, 0, padded_item_size);
+ }
+ } else {
+ Array$insert_all(&channel->items, to_push, 0, padded_item_size);
+ }
(void)pthread_mutex_unlock(&channel->mutex);
(void)pthread_cond_signal(&channel->cond);
}
@@ -50,6 +64,7 @@ public void Channel$pop(channel_t *channel, void *out, int64_t item_size, int64_
memcpy(out, channel->items.data, item_size);
Array$remove(&channel->items, 1, 1, padded_item_size);
(void)pthread_mutex_unlock(&channel->mutex);
+ (void)pthread_cond_signal(&channel->cond);
}
public array_t Channel$view(channel_t *channel)
@@ -66,6 +81,7 @@ public void Channel$clear(channel_t *channel)
(void)pthread_mutex_lock(&channel->mutex);
Array$clear(&channel->items);
(void)pthread_mutex_unlock(&channel->mutex);
+ (void)pthread_cond_signal(&channel->cond);
}
public uint32_t Channel$hash(const channel_t **channel, const TypeInfo *type)
diff --git a/builtins/channel.h b/builtins/channel.h
index 2ce71833..6f628726 100644
--- a/builtins/channel.h
+++ b/builtins/channel.h
@@ -9,7 +9,7 @@
#include "types.h"
#include "util.h"
-channel_t *Channel$new(void);
+channel_t *Channel$new(int64_t max_size);
void Channel$push(channel_t *channel, const void *item, int64_t padded_item_size);
#define Channel$push_value(channel, item, padded_item_size) ({ __typeof(item) _item = item; Channel$push(channel, &_item, padded_item_size); })
void Channel$push_all(channel_t *channel, array_t to_push, int64_t padded_item_size);
diff --git a/builtins/datatypes.h b/builtins/datatypes.h
index 41a67a5e..a9f28dc1 100644
--- a/builtins/datatypes.h
+++ b/builtins/datatypes.h
@@ -62,6 +62,7 @@ typedef struct {
array_t items;
pthread_mutex_t mutex;
pthread_cond_t cond;
+ int64_t max_size;
} channel_t;
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0
diff --git a/compile.c b/compile.c
index 8e8f7c6e..df830a9b 100644
--- a/compile.c
+++ b/compile.c
@@ -1747,10 +1747,18 @@ CORD compile(env_t *env, ast_t *ast)
}
}
case Channel: {
- type_t *item_t = parse_type_ast(env, Match(ast, Channel)->item_type);
+ auto chan = Match(ast, Channel);
+ type_t *item_t = parse_type_ast(env, chan->item_type);
if (!can_send_over_channel(item_t))
code_err(ast, "This item type can't be sent over a channel because it contains reference to memory that may not be thread-safe.");
- return "Channel$new()";
+ if (chan->max_size) {
+ CORD max_size = compile(env, chan->max_size);
+ if (!promote(env, &max_size, get_type(env, chan->max_size), INT_TYPE))
+ code_err(chan->max_size, "This value must be an integer, not %T", get_type(env, chan->max_size));
+ return CORD_all("Channel$new(", max_size, ")");
+ } else {
+ return "Channel$new(INT64_MAX)";
+ }
}
case Table: {
auto table = Match(ast, Table);
diff --git a/parse.c b/parse.c
index ac24297e..f58a0928 100644
--- a/parse.c
+++ b/parse.c
@@ -680,8 +680,15 @@ PARSER(parse_channel) {
const char *start = pos;
if (!match(&pos, "|:")) return NULL;
type_ast_t *item_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a type for this channel");
+ ast_t *max_size = NULL;
+ if (match(&pos, ";")) {
+ whitespace(&pos);
+ const char *attr_start = pos;
+ if (match_word(&pos, "max_size") && match(&pos, "="))
+ max_size = expect(ctx, attr_start, &pos, parse_int, "I expected a maximum size for this channel");
+ }
expect_closing(ctx, &pos, "|", "I wasn't able to parse the rest of this channel");
- return NewAST(ctx->file, start, pos, Channel, .item_type=item_type);
+ return NewAST(ctx->file, start, pos, Channel, .item_type=item_type, .max_size=max_size);
}
PARSER(parse_array) {
@@ -770,7 +777,7 @@ PARSER(parse_table) {
for (;;) {
whitespace(&pos);
const char *attr_start = pos;
- if (match(&pos, "fallback")) {
+ if (match_word(&pos, "fallback")) {
whitespace(&pos);
if (!match(&pos, "=")) parser_err(ctx, attr_start, pos, "I expected an '=' after 'fallback'");
if (fallback)
diff --git a/test/threads.tm b/test/threads.tm
index 33685103..7024460c 100644
--- a/test/threads.tm
+++ b/test/threads.tm
@@ -1,8 +1,8 @@
enum Job(Increment(x:Int), Decrement(x:Int))
func main():
- jobs := |:Job|
- results := |:Int|
+ jobs := |:Job; max_size=1|
+ results := |:Int; max_size=2|
>> thread := Thread.new(func():
//! In another thread!
while yes:
@@ -14,6 +14,11 @@ func main():
>> jobs:push(Increment(5))
>> jobs:push(Decrement(100))
+ >> jobs:push(Decrement(100))
+ >> jobs:push(Decrement(100))
+ >> jobs:push(Decrement(100))
+ >> jobs:push(Decrement(100))
+ >> jobs:push(Decrement(100))
>> results:pop()
= 6
@@ -23,6 +28,17 @@ func main():
= 99
>> results:pop()
+ = 99
+ >> results:pop()
+ = 99
+ >> results:pop()
+ = 99
+ >> results:pop()
+ = 99
+ >> results:pop()
+ = 99
+
+ >> results:pop()
= 1001
//! Canceling...