Support channels with maximum size

This commit is contained in:
Bruce Hill 2024-08-11 15:04:22 -04:00
parent 2ecb5fe885
commit d2f4d07585
8 changed files with 59 additions and 10 deletions

2
ast.c
View File

@ -121,7 +121,7 @@ CORD ast_to_xml(ast_t *ast)
optional_tagged_type("key-type", data.key_type), optional_tagged_type("value-type", data.value_type),
ast_list_to_xml(data.entries), optional_tagged("fallback", data.fallback))
T(TableEntry, "<TableEntry>%r%r</TableEntry>", ast_to_xml(data.key), ast_to_xml(data.value))
T(Channel, "<Channel>%r</Channel>", type_ast_to_xml(data.item_type))
T(Channel, "<Channel>%r%r</Channel>", type_ast_to_xml(data.item_type), optional_tagged("max-size", data.max_size))
T(Comprehension, "<Comprehension>%r%r%r%r%r</Comprehension>", optional_tagged("expr", data.expr),
ast_list_to_xml(data.vars), optional_tagged("iter", data.iter),
optional_tagged("filter", data.filter))

1
ast.h
View File

@ -184,6 +184,7 @@ struct ast_s {
} Array;
struct {
type_ast_t *item_type;
ast_t *max_size;
} Channel;
struct {
type_ast_t *item_type;

View File

@ -17,18 +17,23 @@
#include "types.h"
#include "util.h"
public channel_t *Channel$new(void)
public channel_t *Channel$new(int64_t max_size)
{
if (max_size <= 0)
fail("Cannot create a channel with a size less than one: %ld", max_size);
channel_t *channel = new(channel_t);
channel->items = (array_t){};
channel->mutex = (pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER;
channel->cond = (pthread_cond_t)PTHREAD_COND_INITIALIZER;
channel->max_size = max_size;
return channel;
}
public void Channel$push(channel_t *channel, const void *item, int64_t padded_item_size)
{
(void)pthread_mutex_lock(&channel->mutex);
while (channel->items.length >= channel->max_size)
pthread_cond_wait(&channel->cond, &channel->mutex);
Array$insert(&channel->items, item, 0, padded_item_size);
(void)pthread_mutex_unlock(&channel->mutex);
(void)pthread_cond_signal(&channel->cond);
@ -36,8 +41,17 @@ public void Channel$push(channel_t *channel, const void *item, int64_t padded_it
public void Channel$push_all(channel_t *channel, array_t to_push, int64_t padded_item_size)
{
if (to_push.length == 0) return;
(void)pthread_mutex_lock(&channel->mutex);
Array$insert_all(&channel->items, to_push, 0, padded_item_size);
if (channel->items.length + to_push.length >= channel->max_size) {
for (int64_t i = 0; i < to_push.length; i++) {
while (channel->items.length >= channel->max_size)
pthread_cond_wait(&channel->cond, &channel->mutex);
Array$insert(&channel->items, to_push.data + i*to_push.stride, 0, padded_item_size);
}
} else {
Array$insert_all(&channel->items, to_push, 0, padded_item_size);
}
(void)pthread_mutex_unlock(&channel->mutex);
(void)pthread_cond_signal(&channel->cond);
}
@ -50,6 +64,7 @@ public void Channel$pop(channel_t *channel, void *out, int64_t item_size, int64_
memcpy(out, channel->items.data, item_size);
Array$remove(&channel->items, 1, 1, padded_item_size);
(void)pthread_mutex_unlock(&channel->mutex);
(void)pthread_cond_signal(&channel->cond);
}
public array_t Channel$view(channel_t *channel)
@ -66,6 +81,7 @@ public void Channel$clear(channel_t *channel)
(void)pthread_mutex_lock(&channel->mutex);
Array$clear(&channel->items);
(void)pthread_mutex_unlock(&channel->mutex);
(void)pthread_cond_signal(&channel->cond);
}
public uint32_t Channel$hash(const channel_t **channel, const TypeInfo *type)

View File

@ -9,7 +9,7 @@
#include "types.h"
#include "util.h"
channel_t *Channel$new(void);
channel_t *Channel$new(int64_t max_size);
void Channel$push(channel_t *channel, const void *item, int64_t padded_item_size);
#define Channel$push_value(channel, item, padded_item_size) ({ __typeof(item) _item = item; Channel$push(channel, &_item, padded_item_size); })
void Channel$push_all(channel_t *channel, array_t to_push, int64_t padded_item_size);

View File

@ -62,6 +62,7 @@ typedef struct {
array_t items;
pthread_mutex_t mutex;
pthread_cond_t cond;
int64_t max_size;
} channel_t;
// vim: ts=4 sw=0 et cino=L2,l1,(0,W4,m1,\:0

View File

@ -1747,10 +1747,18 @@ CORD compile(env_t *env, ast_t *ast)
}
}
case Channel: {
type_t *item_t = parse_type_ast(env, Match(ast, Channel)->item_type);
auto chan = Match(ast, Channel);
type_t *item_t = parse_type_ast(env, chan->item_type);
if (!can_send_over_channel(item_t))
code_err(ast, "This item type can't be sent over a channel because it contains reference to memory that may not be thread-safe.");
return "Channel$new()";
if (chan->max_size) {
CORD max_size = compile(env, chan->max_size);
if (!promote(env, &max_size, get_type(env, chan->max_size), INT_TYPE))
code_err(chan->max_size, "This value must be an integer, not %T", get_type(env, chan->max_size));
return CORD_all("Channel$new(", max_size, ")");
} else {
return "Channel$new(INT64_MAX)";
}
}
case Table: {
auto table = Match(ast, Table);

11
parse.c
View File

@ -680,8 +680,15 @@ PARSER(parse_channel) {
const char *start = pos;
if (!match(&pos, "|:")) return NULL;
type_ast_t *item_type = expect(ctx, pos-1, &pos, parse_type, "I couldn't parse a type for this channel");
ast_t *max_size = NULL;
if (match(&pos, ";")) {
whitespace(&pos);
const char *attr_start = pos;
if (match_word(&pos, "max_size") && match(&pos, "="))
max_size = expect(ctx, attr_start, &pos, parse_int, "I expected a maximum size for this channel");
}
expect_closing(ctx, &pos, "|", "I wasn't able to parse the rest of this channel");
return NewAST(ctx->file, start, pos, Channel, .item_type=item_type);
return NewAST(ctx->file, start, pos, Channel, .item_type=item_type, .max_size=max_size);
}
PARSER(parse_array) {
@ -770,7 +777,7 @@ PARSER(parse_table) {
for (;;) {
whitespace(&pos);
const char *attr_start = pos;
if (match(&pos, "fallback")) {
if (match_word(&pos, "fallback")) {
whitespace(&pos);
if (!match(&pos, "=")) parser_err(ctx, attr_start, pos, "I expected an '=' after 'fallback'");
if (fallback)

View File

@ -1,8 +1,8 @@
enum Job(Increment(x:Int), Decrement(x:Int))
func main():
jobs := |:Job|
results := |:Int|
jobs := |:Job; max_size=1|
results := |:Int; max_size=2|
>> thread := Thread.new(func():
//! In another thread!
while yes:
@ -14,6 +14,11 @@ func main():
>> jobs:push(Increment(5))
>> jobs:push(Decrement(100))
>> jobs:push(Decrement(100))
>> jobs:push(Decrement(100))
>> jobs:push(Decrement(100))
>> jobs:push(Decrement(100))
>> jobs:push(Decrement(100))
>> results:pop()
= 6
@ -22,6 +27,17 @@ func main():
>> results:pop()
= 99
>> results:pop()
= 99
>> results:pop()
= 99
>> results:pop()
= 99
>> results:pop()
= 99
>> results:pop()
= 99
>> results:pop()
= 1001