Rust - Simple threading

Previously published

This article was previously published on len-learns-rust.com. A full index of these articles can be found here.

The simplest form of threading is already covered by most Rust books: starting a thread, passing data to it, letting it run and waiting for it to finish.

Something like this is the basic thread example in Rust:

use std::thread;
use std::time::Duration;

fn simple_thread() {
    println!("Simple thread");

    // The spawned thread prints a few messages, pausing briefly between each one.
    let handle = thread::spawn(|| {
        println!("Spawned thread has started");

        for i in 1..10 {
            println!("hi #{} from the spawned thread!", i);
            thread::sleep(Duration::from_millis(1));
        }
        println!("Spawned thread is done");
    });

    // Meanwhile, the main thread does its own work.
    for i in 1..15 {
        println!("hi #{} from the main thread!", i);
        thread::sleep(Duration::from_millis(1));
    }

    // Wait for the spawned thread to finish before carrying on.
    handle.join().unwrap();

    println!("thread has ended...");
}

This works, and it is easy to understand and reason about. The spawned thread clearly finishes before the main thread does, because the main thread joins with it before completing. However, we rely on the spawned thread to decide when to shut down. That isn’t important here, since the spawned thread has a finite amount of work to do, but for threads that do a potentially infinite amount of work I will need a way to ask the thread to stop…

Generally I use one of two ways to shut a thread down in a controlled manner. If the thread is being fed work from a queue, then a special piece of work can ask the thread to shut down. If the thread is just ‘doing stuff’ then, on Windows, I’d possibly use an Event to signal the thread; the thread would periodically check the event, or wait on multiple events, one of which is the one that requests it to shut down.
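As a rough Rust analogue of periodically checking an event, a thread can poll a shared AtomicBool flag between units of work. This is a minimal sketch under my own assumptions, not necessarily how signalling will be handled later: the flag name stop and the function run_until_stopped are illustrative, and because the worker only checks between units of work it polls rather than blocks.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

/// Runs a worker until it is asked to stop and returns how many units of
/// work it completed. The shared `stop` flag (an illustrative name) plays
/// the role of the shutdown event.
fn run_until_stopped() -> u64 {
    let stop = Arc::new(AtomicBool::new(false));
    let worker_stop = Arc::clone(&stop);

    let handle = thread::spawn(move || {
        let mut iterations = 0u64;
        loop {
            // Stand-in for one unit of real work.
            iterations += 1;
            thread::sleep(Duration::from_millis(1));

            // Check the flag between units of work, much like polling an event.
            if worker_stop.load(Ordering::Acquire) {
                break;
            }
        }
        iterations
    });

    // Let the worker run briefly, then ask it to stop.
    thread::sleep(Duration::from_millis(20));
    stop.store(true, Ordering::Release);

    handle.join().expect("worker panicked")
}
```

The worker always completes at least one unit of work, and it can take up to one extra unit to notice the request.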

In Rust we might use a channel to send messages to a thread. We can then either send a special message to trigger the thread’s shutdown, or simply close the sending end of the channel and have the thread shut itself down when it detects that the channel has closed.
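A minimal sketch of the ‘special message’ option, assuming a hypothetical Message enum whose Shutdown variant is the request to stop. The worker also treats a closed channel as a stop request, so either route ends its loop:

```rust
use std::sync::mpsc;
use std::thread;

/// Messages the worker understands; `Shutdown` is the "special piece of
/// work" that asks it to stop. The enum and its variants are illustrative.
enum Message {
    Work(String),
    Shutdown,
}

/// Sends some work followed by a shutdown request and returns everything
/// the worker processed.
fn run_with_sentinel() -> Vec<String> {
    let (tx, rx) = mpsc::channel::<Message>();

    let handle = thread::spawn(move || {
        let mut processed = Vec::new();
        // `recv()` blocks until a message arrives; an `Err` means every
        // sender has been dropped, which we also treat as "stop".
        while let Ok(message) = rx.recv() {
            match message {
                Message::Work(text) => processed.push(text),
                Message::Shutdown => break,
            }
        }
        processed
    });

    for i in 1..4 {
        tx.send(Message::Work(i.to_string())).expect("failed to send");
    }
    // Anything sent after `Shutdown` would never be processed.
    tx.send(Message::Shutdown).expect("failed to send");

    handle.join().expect("worker panicked")
}
```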

Since this is the more idiomatic way to interact with threads in Rust, I’ll start with this approach and then look at signalling a thread next.

Code that uses a multi-producer, single-consumer (mpsc) channel to send strings to a thread is as simple as this:

use std::{mem, thread};
use std::sync::mpsc;

fn main() {
    let (to_thread, from_main) = mpsc::channel::<String>();

    let thread_handle = thread::spawn(move || {
        println!("spawned thread has started");

        let mut done = false;

        while !done {
            match from_main.recv() {
                Ok(message) => {
                    println!("got {} from main", message);
                }
                Err(reason) => {
                    println!("thread recv error {}", reason);
                    done = true;
                }
            }
        }
        println!("spawned thread is done");
    });

    for i in 1..15 {
        println!("sending {} to thread", i);
        to_thread.send(i.to_string()).expect("failed to send");
    }

    println!("close channel, signal thread we're done");

    mem::drop(to_thread);

    println!("wait for thread to end");

    thread_handle.join().expect("failed to join with thread");

    println!("thread has ended");

    println!("all done...");
}

Creating a channel gives us two halves: the sender and the receiver. We give the receiver half to the thread and keep the sender in our main function. We spawn the thread, and it blocks on the channel until messages arrive or the channel is closed. We then send some messages and drop the sender to close the channel. The thread processes the queued messages and exits when recv() returns an error, which is how it detects that the channel has closed.
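Two properties of std::sync::mpsc are worth seeing in code. The Receiver can be iterated (for message in receiver, or receiver.iter()), and that loop ends once the channel closes, which can stand in for the explicit done flag. And because the channel is multi-producer, it only closes once every Sender, including clones, has been dropped. The function name below is illustrative:

```rust
use std::sync::mpsc;
use std::thread;

/// Two producer threads share one channel via a cloned `Sender`; the
/// consumer's iteration ends only after *all* senders have been dropped.
fn collect_from_two_producers() -> Vec<i32> {
    let (tx, rx) = mpsc::channel::<i32>();
    let tx2 = tx.clone();

    let first = thread::spawn(move || {
        for i in 0..3 {
            tx.send(i).expect("failed to send");
        }
        // `tx` is dropped here when the closure returns.
    });
    let second = thread::spawn(move || {
        for i in 10..13 {
            tx2.send(i).expect("failed to send");
        }
        // `tx2` is dropped here too.
    });

    // Iterating the receiver blocks on each message and stops once the
    // channel is closed, so no `done` flag is needed.
    let mut received: Vec<i32> = rx.iter().collect();

    first.join().expect("producer panicked");
    second.join().expect("producer panicked");

    // Messages from different producers may interleave, so sort them.
    received.sort();
    received
}
```

If the calling function had kept its own sender alive, the collect would block forever, which is the same reason the example above has to drop to_thread before joining.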

As is my way, I like to wrap up code I’m likely to want to use into something reusable. In this case I may be a little premature but here’s a generic ChannelThread struct that starts a thread which uses a closure to process messages, lets me send messages to the thread, and shuts the thread down cleanly…

use std::sync::mpsc::{self, Sender};
use std::thread;

struct ChannelThread<T> {
    channel: Option<Sender<T>>,
    handle: Option<thread::JoinHandle<()>>,
}

impl<T: Send + 'static> ChannelThread<T> {
    fn new<F>(mut f: F) -> Self
    where
        F: FnMut(T) -> bool + Send + 'static,
    {
        let (to_thread, from_controller) = mpsc::channel::<T>();

        let handle = Some(thread::spawn(move || {
            println!("Spawned thread has started");

            println!("Spawned is running");

            loop {
                match from_controller.recv() {
                    Ok(message) => {
                        if !f(message) {
                            break;
                        }
                    }
                    Err(reason) => {
                        println!("thread recv error {}", reason);
                        break;
                    }
                }
            }
            println!("Spawned thread is done");
        }));

        ChannelThread {
            channel: Some(to_thread),
            handle,
        }
    }

    fn send(&self, message: T) {
        self.channel
            .as_ref()
            .expect("Too late to send")
            .send(message)
            .expect("failed to send");
    }

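    // Closes the channel by dropping the sender; the thread exits once it
    // has drained any queued messages and recv() returns an error.
    fn shutdown(&mut self) {
        // take() leaves None behind, so a second call is a no-op.
        if let Some(sender) = self.channel.take() {
            drop(sender);
        }
    }

    // Waits for the thread to finish; panics if called twice.
    fn join(&mut self) {
        // take() leaves None behind, so no explicit reset is needed.
        if let Some(handle) = self.handle.take() {
            handle.join().expect("failed to join with thread");
        } else {
            panic!("already joined");
        }
    }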
}
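One possible refinement, which is my suggestion rather than part of the design above: implement Drop so that a ChannelThread going out of scope closes its channel and joins its thread automatically. The sketch repeats a trimmed-down ChannelThread (iterating the receiver instead of the explicit match) so that it compiles on its own; sum_with_auto_shutdown and its second result channel are hypothetical, there only to make the behaviour observable.

```rust
use std::sync::mpsc::{self, Sender};
use std::thread::{self, JoinHandle};

struct ChannelThread<T> {
    channel: Option<Sender<T>>,
    handle: Option<JoinHandle<()>>,
}

impl<T: Send + 'static> ChannelThread<T> {
    fn new<F>(mut f: F) -> Self
    where
        F: FnMut(T) -> bool + Send + 'static,
    {
        let (to_thread, from_controller) = mpsc::channel::<T>();
        let handle = thread::spawn(move || {
            // Stops when `f` returns false or the channel is closed.
            for message in from_controller {
                if !f(message) {
                    break;
                }
            }
        });
        ChannelThread {
            channel: Some(to_thread),
            handle: Some(handle),
        }
    }

    fn send(&self, message: T) {
        self.channel
            .as_ref()
            .expect("Too late to send")
            .send(message)
            .expect("failed to send");
    }
}

impl<T> Drop for ChannelThread<T> {
    fn drop(&mut self) {
        // Close the channel first; otherwise the thread would block waiting
        // for messages forever and the join below would never return.
        drop(self.channel.take());
        if let Some(handle) = self.handle.take() {
            // Avoid panicking inside drop; just report a worker panic.
            if handle.join().is_err() {
                eprintln!("channel thread panicked");
            }
        }
    }
}

/// Hypothetical usage: sums messages on the worker and reports each running
/// total back over a second channel.
fn sum_with_auto_shutdown() -> i32 {
    let (result_tx, result_rx) = mpsc::channel::<i32>();
    {
        let mut total = 0;
        let worker = ChannelThread::new(move |n: i32| {
            total += n;
            // Keep going as long as the result channel is still open.
            result_tx.send(total).is_ok()
        });
        for i in 1..5 {
            worker.send(i);
        }
        // `worker` is dropped here: the channel closes and the thread is joined.
    }
    // The last running total is the final sum.
    result_rx.iter().last().unwrap_or(0)
}
```

The order inside drop matters: joining before closing the channel would wait on a thread that is itself waiting for messages that can never arrive.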

It can be used like this:

    #[test]
    fn test_channel_thread_with_closure() {
        let mut thread = ChannelThread::new(|_message| {
            // Returning true keeps the thread processing messages.
            true
        });

        for i in 1..15 {
            println!("sending {} to thread", i);
            thread.send(i);
        }

        println!("close channel, signal thread we're done");

        thread.shutdown();

        println!("wait for thread to end");

        thread.join();

        println!("all done...");
    }
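The test’s closure always returns true, so the thread only stops when the channel is closed. If the closure returns false, the thread exits early and its Receiver is dropped; after that, sending fails with a SendError, which ChannelThread::send turns into a panic through expect. The sketch below shows this with a bare channel so that it stands alone; the "stop" message and the function name are hypothetical.

```rust
use std::sync::mpsc;
use std::thread;

/// Returns the messages the worker handled and whether a send made after
/// the worker stopped early succeeded.
fn send_after_early_exit() -> (Vec<String>, bool) {
    let (tx, rx) = mpsc::channel::<String>();

    let handle = thread::spawn(move || {
        let mut handled = Vec::new();
        // Same shape as the loop in ChannelThread::new: stop as soon as the
        // message handler says not to keep going.
        for message in rx {
            let keep_going = message != "stop";
            handled.push(message);
            if !keep_going {
                break;
            }
        }
        // The receiver was consumed by the loop and has been dropped by now.
        handled
    });

    tx.send("a".to_string()).expect("failed to send");
    tx.send("stop".to_string()).expect("failed to send");
    let handled = handle.join().expect("worker panicked");

    // The receiver is gone, so this send returns Err(SendError(..)).
    let late_send_ok = tx.send("too late".to_string()).is_ok();
    (handled, late_send_ok)
}
```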
Join in

The code can be found here on GitHub. Each step on the journey will have one or more separate directories of code, so this article’s code is here and here, which allows for easy comparison of changes at each stage.

Of course, there may be a better way; leave comments if you’d like to help me learn.