Building Channels in C++: Part 5 - Writing a select statement

So far in this series, we've created a channel with blocking and non-blocking methods for sending and receiving messages. See Building Channels in C++: Part 4 - Going non-blocking for the last post in the series. With the reader exercises included, our channel now looks like the following (it's getting big enough that this will probably be the last time I start a post with the channel code):
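#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <variant>
// RingBuffer<T, Capacity> comes from the earlier posts in this series and isn't repeated here
template<typename E>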
struct Error {
E err;
};
template<typename E>
Error<E> error(const E &err) { return Error<E>{err}; }
template<typename S, typename E>
struct Result {
std::variant<S, Error<E>> val;
Result(const S &s) : val(s) {}
Result(const Error<E> &e) : val(e) {}
bool is_success() const noexcept { return std::holds_alternative<S>(val); }
bool is_error() const noexcept { return std::holds_alternative<Error<E>>(val); }
operator bool() const noexcept { return is_success(); }
S get_value() const noexcept { return std::get<S>(val); }
E get_error() const noexcept { return std::get<Error<E>>(val).err; }
bool copy_value_if_present(S& out) const {
if (is_success()) {
out = std::get<S>(val);
return true;
}
return false;
}
};
template<typename E>
struct Result<void, E> {
std::variant<std::monostate, Error<E>> val = {};
Result() = default;
Result(const Error<E> &e) : val(e) {}
bool is_success() const noexcept { return std::holds_alternative<std::monostate>(val); }
bool is_error() const noexcept { return std::holds_alternative<Error<E>>(val); }
operator bool() const noexcept { return is_success(); }
E get_error() const noexcept { return std::get<Error<E>>(val).err; }
};
enum class ChannelBlockError {
CHANNEL_CLOSED
};
enum class ChannelReceiveError {
CHANNEL_CLOSED,
NO_MESSAGE_READY,
COULD_NOT_LOCK,
};
enum class ChannelSendError {
CHANNEL_CLOSED,
QUEUE_FULL,
COULD_NOT_LOCK,
};
template<typename T, size_t Capacity>
class Channel {
RingBuffer<T, Capacity> buffer = {};
std::mutex mux;
std::condition_variable read_signal;
std::condition_variable write_signal;
bool closed = false;
public:
Result<void, ChannelBlockError> send(const T &elem) {
auto lock = std::unique_lock{mux};
// loop (not if) so a spurious wakeup can't push into a full buffer
while (!closed && buffer.is_full()) {
write_signal.wait(lock);
}
if (closed) {
return error(ChannelBlockError::CHANNEL_CLOSED);
}
buffer.push_back(elem);
// notify we wrote a message
read_signal.notify_one();
return {};
}
Result<void, ChannelSendError> try_send(const T &elem) {
auto lock = std::unique_lock{mux, std::try_to_lock};
if (!lock.owns_lock()) {
return error(ChannelSendError::COULD_NOT_LOCK);
}
if (closed) {
return error(ChannelSendError::CHANNEL_CLOSED);
}
if (buffer.is_full()) {
return error(ChannelSendError::QUEUE_FULL);
}
buffer.push_back(elem);
// wake a receiver that may be blocked in receive()
read_signal.notify_one();
return {};
}
Result<T, ChannelReceiveError> try_receive() {
auto lock = std::unique_lock{mux, std::try_to_lock};
if (!lock.owns_lock()) {
return error(ChannelReceiveError::COULD_NOT_LOCK);
}
auto v = buffer.pop_front();
if (v) {
write_signal.notify_one();
return *v;
}
if (closed) {
return error(ChannelReceiveError::CHANNEL_CLOSED);
}
return error(ChannelReceiveError::NO_MESSAGE_READY);
}
Result<T, ChannelBlockError> receive() {
auto lock = std::unique_lock{mux};
// wait for a new message (loop to guard against spurious wakeups)
while (!closed && buffer.is_empty()) {
read_signal.wait(lock);
}
auto v = buffer.pop_front();
if (v) {
// notify we read a message
write_signal.notify_one();
return *v;
}
return error(ChannelBlockError::CHANNEL_CLOSED);
}
void close() {
auto lock = std::unique_lock{mux};
closed = true;
read_signal.notify_all();
write_signal.notify_all();
}
};
Yeah, we've been busy.
So, we have our channels, and they work. What's next?
Select statements.
Intro to select
Select statements in Go let you list multiple channel operations (sends and receives) and wait until one of them can proceed. Once an operation succeeds, Go runs the block of code associated with that case (kind of like a callback). If several operations are ready at the same time, Go picks one of them at random. Here's an example:
package main
import "fmt"
func main() {
c := make(chan int)
quit := make(chan int)
go func() {
x := 1
for {
// This is where we're focused
select {
// try send
case c <- x:
x = x * (x + 1)
// try receive
case y := <- quit:
fmt.Println("Ending with", y)
return
}
}
}()
for i := 0; i < 4; i++ {
fmt.Println(<-c);
}
quit <- 13
}
The select statement above tries to send on one channel and receive from another. It keeps waiting until either the send or the receive can proceed. Once one does, it runs the code associated with that case (there is no fallthrough, so there's no need for a break). Selects are often wrapped in a loop that keeps retrying the select until told to exit, which is perfect for workers that process multiple work queues.
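If all we have are non-blocking operations, a select boils down to a polling loop: try each case in order, run the body of the first one that succeeds, and otherwise go around again. Here's a minimal sketch of that loop, assuming each case can be described by a hypothetical PollCase holding a non-blocking try_op and a body. It uses std::function for brevity, which is exactly the runtime type erasure the template-based version in this post avoids. Unlike Go, it always prefers earlier cases instead of choosing randomly among ready ones.

```cpp
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Hypothetical polling "case": try_op makes one non-blocking attempt and
// returns true on success; body runs only if that attempt succeeded.
struct PollCase {
    std::function<bool()> try_op;
    std::function<void()> body;
};

// Try every case in order and run the body of the first one that
// succeeds (no fallthrough). If none succeed, yield and try again.
// Returns the index of the case that fired.
// Note: with no cases at all this loops forever, much like Go's empty select.
std::size_t poll_select(std::vector<PollCase>& cases) {
    while (true) {
        for (std::size_t i = 0; i < cases.size(); ++i) {
            if (cases[i].try_op()) {
                cases[i].body();
                return i;
            }
        }
        std::this_thread::yield();
    }
}
```

Wrapping poll_select in an outer loop gives the "worker that keeps selecting until told to exit" pattern from the Go example.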
Our goal is to build a similar C++ feature. Since we can't make an actual statement, we'll have to make do with functions, structs, and templates.
Creating a Send Case
Before we make our select statement, we're going to build the case statements. There are two of them: send and receive.
The send case holds a channel, a reference to the message, and a callback. Each time it's attempted, it tries to send the message once, and if the send succeeds it calls the callback. Here is our case:
template<typename T, size_t Capacity, typename Callback>
struct SendCase {
Channel<T, Capacity>& channel;
const T& message;
Callback callback;
Result<void, ChannelSendError> attempt() {
auto res = channel.try_send(message);
if (res.is_success()) {
Result<void, ChannelBlockError> pass{};
callback(pass);
}
else if (res.get_error() == ChannelSendError::CHANNEL_CLOSED) {
Result<void, ChannelBlockError> pass{error(ChannelBlockError::CHANNEL_CLOSED)};
callback(pass);
}
return res;
}
};
template<typename T, size_t Capacity, typename Callback>
SendCase<T, Capacity, Callback> send_case(Channel<T, Capacity>& chan, const T& message, const Callback& callback) {
return SendCase<T, Capacity, Callback>{chan, message, callback};
}
Super simple!
I can hear some people say "wait, but there's no base class! How are we going to get polymorphism!" To which I say, don't worry. There is a way to do it, and we'll get there. But first, let me cook.
All we need is a simple templated struct with a single method. That method calls try_send, and if the send works (or the channel is closed) it calls our callback. It then returns the result. One thing to note is that we're narrowing the error space when we call our callback: the callback only needs to handle the channel-closed case, not all of the channel send errors.
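To see the narrowing in isolation, here's a minimal sketch of the same attempt logic run against a scripted fake channel, so it needs no threads. FakeSendChannel, SendError, BlockError, and attempt_send are hypothetical stand-ins, and std::optional takes the place of the post's Result type so the block stands alone (an empty optional means success).

```cpp
#include <optional>

// Hypothetical stand-ins for ChannelSendError and ChannelBlockError.
enum class SendError { CHANNEL_CLOSED, QUEUE_FULL, COULD_NOT_LOCK };
enum class BlockError { CHANNEL_CLOSED };

// Hypothetical fake channel: try_send fails with whatever error is scripted
// in next_error, or succeeds (and records the value) when it is empty.
struct FakeSendChannel {
    std::optional<SendError> next_error;
    int last_sent = 0;

    std::optional<SendError> try_send(int value) {
        if (!next_error) {
            last_sent = value;
        }
        return next_error;
    }
};

// Same shape as SendCase::attempt: one try_send, a callback on success or
// on a closed channel, and the full (wide) error returned to the caller.
template<typename Callback>
std::optional<SendError> attempt_send(FakeSendChannel& channel, int message, Callback&& callback) {
    auto err = channel.try_send(message);
    if (!err) {
        callback(std::optional<BlockError>{});  // success: no error
    } else if (*err == SendError::CHANNEL_CLOSED) {
        // narrowed: the callback only ever sees "closed"
        callback(std::optional<BlockError>{BlockError::CHANNEL_CLOSED});
    }
    // QUEUE_FULL and COULD_NOT_LOCK are transient, so no callback here
    return err;
}
```

The caller still gets the wide error back, which is what lets the select decide whether to keep trying.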
So, let's move on to the receive case, which will be surprisingly similar.
Receiving Case
Here's our receive case:
template<typename T, size_t Capacity, typename Callback>
struct ReceiveCase {
Channel<T, Capacity>& channel;
Callback callback;
Result<void, ChannelReceiveError> attempt() {
auto res = channel.try_receive();
if (res.is_success()) {
Result<T, ChannelBlockError> pass{res.get_value()};
callback(pass);
// return here: falling through would call get_error() on a success
return {};
}
else if (res.get_error() == ChannelReceiveError::CHANNEL_CLOSED) {
Result<T, ChannelBlockError> pass{error(ChannelBlockError::CHANNEL_CLOSED)};
callback(pass);
return {};
}
return error(res.get_error());
}
};
template<typename T, size_t Capacity, typename Callback>
ReceiveCase<T, Capacity, Callback> receive_case(Channel<T, Capacity>& chan, const Callback& callback) {
return ReceiveCase<T, Capacity, Callback>{chan, callback};
}
It's almost identical. We try to receive a message and, if we get one, pass it to our callback. Otherwise, we just return the error, with one exception: a closed channel. When the channel is closed, we pass a "closed" signal to the callback and return a success. Receiving from a closed channel can be perfectly valid, since in a sense a value does arrive, just a "no value" value (those are weird edge cases, but they do exist; Go, for example, lets a receive on a closed channel complete immediately with the zero value). Sending to a closed channel, on the other hand, isn't valid since nothing gets sent, so the send case keeps reporting its error. We call the callback either way to give the code a chance to quit, but because the two cases return different results, the select statement will end up handling them differently.
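The rules above fit in a small table. Here's a sketch that encodes them as constexpr functions; CaseKind, Outcome, runs_callback, and ends_select are hypothetical names, and Outcome::NOT_READY lumps together QUEUE_FULL, NO_MESSAGE_READY, and COULD_NOT_LOCK. "Ends the select" here means "attempt() reports success", since that's what stops the select loop.

```cpp
// Hypothetical summary of the two kinds of case and what one attempt can see.
enum class CaseKind { SEND, RECEIVE };
enum class Outcome { SUCCESS, CHANNEL_CLOSED, NOT_READY };

// Both kinds call their callback on success and on a closed channel.
constexpr bool runs_callback(Outcome outcome) {
    return outcome != Outcome::NOT_READY;
}

// Whether attempt() reports success (which ends the select):
// a receive turns "closed" into success, a send keeps it as an error.
constexpr bool ends_select(CaseKind kind, Outcome outcome) {
    switch (outcome) {
        case Outcome::SUCCESS:
            return true;
        case Outcome::CHANNEL_CLOSED:
            return kind == CaseKind::RECEIVE;
        case Outcome::NOT_READY:
            return false;
    }
    return false;
}

static_assert(ends_select(CaseKind::RECEIVE, Outcome::CHANNEL_CLOSED));
static_assert(!ends_select(CaseKind::SEND, Outcome::CHANNEL_CLOSED));
```

One consequence of this design: since a closed send case never ends the select, a select that keeps looping will call that case's callback again on every pass.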
Speaking of the select statement - it's now time.
Select Statement
Now is the time we've been waiting for - the actual select statement. We're going to allow multiple channel types, sends and receives, and everything without inheritance - or really any runtime polymorphism. Instead, we're going to use static polymorphism. This will require the use of variadic templates (and template specializations). Don't fret, we're not going to go too nuts, and it should still be pretty straightforward to understand (plus I'll explain things step by step). So, let's get started.
Declaring the Select
The very first thing we're going to do is forward declare our select and define our errors. We will not be setting up every possible template specialization, just the ones we're going to use. This means that our select statement will be forward declared without an implementation - and that's okay. Here's our declaration:
enum class SelectErrors {
NOT_READY_YET,
};
template<typename ...Args> struct Select;
First part done!
We're using variadic templates to take in all of the cases. For instance, we could have a Select<SendCase<int, 5, F1>, ReceiveCase<double, 10, F2>, ReceiveCase<std::string, 2, F3>>, where F1, F2, and F3 are the callback types. By having those cases come in as template parameters, we can support multiple types of cases and channels without having to use inheritance.
However, writing a specialization for every possible combination of cases is practically impossible. So instead, we'll create specializations on the first case (send or receive) and forward everything else to the "next" select, which will in turn reuse our specializations. The chain ends with a special, empty terminator specialization. With this setup, the compiler builds a chain of nested selects that we can walk through at runtime.
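Stripped of channels, the recursive specialization pattern looks like this minimal sketch. Chain and step_all are hypothetical names, and the cases here are plain callables returning bool rather than types with an attempt() method. The primary template is only declared, the recursive specialization peels off one case and stores the rest, and the empty specialization ends the recursion.

```cpp
#include <utility>

// Primary template: declared but never defined, just like Select.
template<typename ...Cases> struct Chain;

// Recursive specialization: keep the first case, forward the rest.
template<typename First, typename ...Rest>
struct Chain<First, Rest...> {
    First first;
    Chain<Rest...> next;

    explicit Chain(First f, Rest... rest)
        : first(std::move(f)), next(std::move(rest)...) {}

    // Try this case; only if it fails, move on down the chain.
    bool step() { return first() || next.step(); }
};

// Terminator: reaching the end means no case succeeded.
template<>
struct Chain<> {
    bool step() { return false; }
};

// Hypothetical helper mirroring select(): deduce the case types for us.
template<typename ...Cases>
bool step_all(Cases... cases) {
    return Chain<Cases...>{std::move(cases)...}.step();
}
```

Because || short-circuits, cases after the first success are never tried, which matches the "no fallthrough" behavior.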
While there are compiler limits on variadic template length and recursion depth, we generally aren't going to hit them without some truly awful code, since we're typing out every case by hand. Forcing developers to type their cases is totally fine since we're trying to emulate a statement (i.e. a language construct), not a scheduler or runtime capability. And letting the compiler enforce some sort of constraint on how crazy developers get is often a good thing (we don't want 10,000-line select statements). Let's move on to actually defining logic for our select.
Selecting sends
Our select should be able to handle our send case. It should call the send case's attempt method, and if the attempt works, it should stop. Otherwise, it should try the next case. The type of the next select is determined by the variadic template arguments. We, the developers, don't need to know what that type will be; that's the compiler's job. We just have to ensure all of our specializations have a step method with the same signature. So, here's our send specialization:
#include <thread> // std::this_thread::yield
template<typename T, size_t Capacity, typename Callback, typename ...Args>
struct Select<SendCase<T, Capacity, Callback>, Args...> {
using Case = SendCase<T, Capacity, Callback>;
using Next = Select<Args...>;
Case sendCase;
Next next;
Select(Case sendCase, Args... nextVals)
: sendCase(sendCase), next({nextVals...}) {}
Result<void, SelectErrors> step() {
auto res = sendCase.attempt();
if (res) {
return {};
}
return next.step();
}
void operator()() {
while (!step()) {
std::this_thread::yield();
}
}
};
In addition to step, we have an operator overload that makes our select callable. It's simply a loop that keeps stepping from that select onwards until one of the cases succeeds. After each failure we yield the thread so other threads get a chance to run (it's still a busy-wait, though: a waiting select keeps retrying rather than sleeping). We define the operator overload at every level of the select since no level can tell whether it's the first one. That's okay, though, since our outer code will only ever call the first level anyway.
Now on to receiving messages.
Select receives
The code is going to look almost identical; the only difference is a few names. Which raises the question: why do we specialize on the type of the case rather than use a generic type, since the code structure is the same anyway? The answer is simple: at some point we're going to add timeouts, and those aren't going to work the same way as sends and receives. Since that's our very next feature, there's no point in refactoring until we write it and see what's actually refactorable. Until then, we'll just duplicate the code.
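For comparison, here's roughly what a generic version could look like if we didn't specialize on case types. This is a different technique from the post's recursive chain: a C++17 fold expression over a std::tuple. FoldSelect and CountdownCase are hypothetical names; any type with an attempt() whose result converts to bool works as a case (the post's Result has a non-explicit operator bool, so it should fit too).

```cpp
#include <thread>
#include <tuple>
#include <utility>

// One generic template instead of per-case specializations: every case
// just needs an attempt() method whose result converts to bool.
template<typename ...Cases>
struct FoldSelect {
    std::tuple<Cases...> cases;

    explicit FoldSelect(Cases... cs) : cases(std::move(cs)...) {}

    // Left-to-right, short-circuiting fold: stops at the first success.
    // An empty pack folds to false ("nothing succeeded").
    bool step() {
        return std::apply(
            [](auto&... c) { return (static_cast<bool>(c.attempt()) || ...); },
            cases);
    }

    void operator()() {
        while (!step()) {
            std::this_thread::yield();
        }
    }
};

// Hypothetical toy case: succeeds on its Nth attempt and counts successes.
struct CountdownCase {
    int remaining;
    int* fired;

    bool attempt() {
        if (--remaining == 0) {
            ++*fired;
            return true;
        }
        return false;
    }
};
```

The catch is that every case must be handled the same way. Once a case type needs different treatment, like the timeouts coming next, per-type specializations give us a natural place to put it, which is the reasoning for keeping them for now.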
Here's our receive handler:
template<typename T, size_t Capacity, typename Callback, typename ...Args>
struct Select<ReceiveCase<T, Capacity, Callback>, Args...> {
using Case = ReceiveCase<T, Capacity, Callback>;
using Next = Select<Args...>;
Case receiveCase;
Next next;
Select(Case receiveCase, Args... nextVals)
: receiveCase(receiveCase), next({nextVals...}) {}
Result<void, SelectErrors> step() {
auto res = receiveCase.attempt();
if (res) {
return {};
}
return next.step();
}
void operator()() {
while (!step()) {
std::this_thread::yield();
}
}
};
Ending a select
We've got our cases handled. Now we just need to add our terminating select. Even though the last select doesn't do anything special, it still needs the same function signatures, including the operator overload (just in case someone creates an empty select, of all things). Its step always reports NOT_READY_YET, because reaching the end of the chain means no case succeeded. Here's our ending select:
template<>
struct Select<> {
Result<void, SelectErrors> step() { return error(SelectErrors::NOT_READY_YET); }
void operator()() {}
};
Almost done! We're just going to add a convenience wrapper function that deduces the template arguments for us and immediately runs our select statement. Here it is:
template<typename ...Args>
void select(Args... args) {
Select<Args...>{args...}();
}
Using the select
Now let's see what the developer experience is like! Here's a sample program using the select. In fact, it's a C++ reproduction of the example Go program I showed earlier!
#include <iostream>
#include <thread>
int main() {
Channel<int, 1> c;
Channel<int, 1> quit;
auto t1 = std::thread([&]{
int x = 1;
bool done = false;
while (!done) {
// our select statement
select(
send_case(c, x, [&](const auto&){
x *= x + 1;
}),
receive_case(quit, [&](const Result<int, ChannelBlockError>& v){
auto val = v.get_value();
std::cout << "Ending with " << val << "\n";
done = true;
})
);
}
});
for (int i = 0; i < 4; ++i) {
std::cout << c.receive().get_value() << "\n";
}
quit.send(13);
t1.join();
return 0;
}
Whew! That's a long post!
Fortunately, this post lays a lot of the groundwork for the rest of the series, and the next post will be shorter. Up next, we're going to add a feature that Go's select doesn't have built in: timeouts (in Go you'd typically add a case that receives from time.After instead). I cover the details in Building Channels in C++: Part 6 - Adding timeouts.