Files
mercury/tests/hard_coded/thread_test_utils.m
Julien Fischer 2a096e69c9 Update programming style in a test.
tests/hard_coded/thread_barrier_test.m:
tests/hard_coded/thread_test_utils.m:
     As above.
2022-04-14 20:33:29 +10:00

177 lines
5.8 KiB
Mathematica

%---------------------------------------------------------------------------%
% vim: ft=mercury ts=4 sw=4 et
%---------------------------------------------------------------------------%
% Copyright (C) 2014 Mission Critical IT.
% Copyright (C) 2014 The Mercury team.
% This file may only be copied under the terms of the GNU Library General
% Public License - see the file COPYING.LIB in the Mercury distribution.
%---------------------------------------------------------------------------%
%
% File: thread_test_utils.m
% Author: Paul Bone
%
% These utilities make it easier to test concurrent code. In particular a
% concurrent program's IO actions may occur in a nondeterministic order. This
% module provides alternative IO operations that provide some order so that
% program output matches test output (when the test passes).
%
%---------------------------------------------------------------------------%
:- module thread_test_utils.
:- interface.
:- import_module io.
:- import_module string.
%---------------------------------------------------------------------------%
% This type represents all the output of all the threads in the program.
% Create one with init_all_thread_output/3, derive one thread_output per
% thread from it with init_thread_output/5, and finally print everything
% with write_all_thread_output/3.
%
:- type all_threads_output.
% This type represents the output of an individual thread.
% Values are created by init_thread_output/5, written to with
% t_write_string/4 and finished with close_thread_output/3.
%
:- type thread_output.
% init_all_thread_output(AllOutput, !IO):
%
% Create the object that will collect the output of all the threads.
%
:- pred init_all_thread_output(all_threads_output::out, io::di, io::uo)
is det.
% init_thread_output(AllOutput, N, Output, !IO):
%
% Create a new thread output object for thread number N. Everything
% written to Output is collected by AllOutput.
%
:- pred init_thread_output(all_threads_output::in, int::in,
thread_output::out, io::di, io::uo) is det.
% t_write_string(Output, String, !IO):
%
% Save some output into the buffer. Nothing is printed at this point;
% the saved strings are printed by write_all_thread_output/3.
%
:- pred t_write_string(thread_output::in, string::in, io::di, io::uo) is det.
% close_thread_output(Output, !IO):
%
% Close this thread's output stream. All streams must be closed as
% write_all_thread_output/3 will use this to ensure it has received all
% the messages it should have.
%
:- pred close_thread_output(thread_output::in, io::di, io::uo) is det.
% write_all_thread_output(AllOutput, !IO):
%
% Write out this set of buffers, grouped by thread number. For each thread
% a "Messages from thread N:" header is printed, followed by that thread's
% strings, one per line and each preceded by a tab.
% This blocks while any thread output that is known to have been opened
% has not yet been closed.
%
:- pred write_all_thread_output(all_threads_output::in, io::di, io::uo) is det.
%---------------------------------------------------------------------------%
%---------------------------------------------------------------------------%
:- implementation.
:- import_module int.
:- import_module list.
:- import_module map.
:- import_module maybe.
:- import_module thread.
:- import_module thread.channel.
:- import_module set.
%---------------------------------------------------------------------------%
% The output of all threads is a single channel of messages. Each
% thread_output created by init_thread_output/5 holds this same channel.
%
:- type all_threads_output
---> all_threads_output(
channel(message)
).
% A thread's output handle: the thread's number and the channel that its
% messages are sent on.
%
:- type thread_output
---> thread_output(
% The thread number that is attached to every message sent
% through this handle.
to_thread :: int,
% The channel taken from the all_threads_output that this
% handle was created from.
to_chan :: channel(message)
).
% The messages sent over the channel. Every message carries the number of
% the thread whose output handle sent it.
%
:- type message
% A string written with t_write_string/4.
---> message_output(
om_thread :: int,
om_string :: string
)
% Sent by init_thread_output/5 to show that a thread's output has been
% opened.
; message_open(
mopen_thread :: int
)
% Sent by close_thread_output/3 to show that a thread's output has been
% closed.
; message_close(
mc_thread :: int
).
%---------------------------------------------------------------------------%
% Set up the channel that every thread's messages will travel over and
% wrap it up as the all_threads_output value.
init_all_thread_output(AllOutput, !IO) :-
    channel.init(SharedChan, !IO),
    AllOutput = all_threads_output(SharedChan).
% Build the handle for thread number Thread on top of the shared channel.
init_thread_output(AllOutput, Thread, Output, !IO) :-
    AllOutput = all_threads_output(SharedChan),
    % Announce on the channel that this output has been opened; the reader
    % uses this later to know that it must wait for the matching close.
    channel.put(SharedChan, message_open(Thread), !IO),
    Output = thread_output(Thread, SharedChan).
% Queue String on the shared channel, tagged with the number of the thread
% that owns Output.
t_write_string(Output, String, !IO) :-
    Output = thread_output(Thread, Chan),
    channel.put(Chan, message_output(Thread, String), !IO).
% Drain the channel into a per-thread map, starting with no open threads
% and no messages, and then print the contents of the map thread by thread.
write_all_thread_output(AllOutput, !IO) :-
    get_all_messages(AllOutput, set.init, map.init, MessagesByThread, !IO),
    map.foldl(write_out_thread_messages, MessagesByThread, !IO).
% Messages indexed by thread number. Each list is stored in reverse order,
% that is, the most recently received string is at the head of the list.
%
:- type messages == map(int, list(string)).
:- pred get_all_messages(all_threads_output::in, set(int)::in,
messages::in, messages::out, io::di, io::uo) is det.
get_all_messages(AllOutput, OpenThreads, !Messages, !IO) :-
AllOutput = all_threads_output(Chan),
( if is_empty(OpenThreads) then
% If this might be the end of the messages then we only try and
% take, so we know if we should exit.
try_take(Chan, MaybeMessage, !IO)
else
% OTOH, if there may be threads that have not finished sending our
% messages, then we use a blocking take to ensure that we don't miss
% their messages.
take(Chan, Message0, !IO),
MaybeMessage = yes(Message0)
),
(
MaybeMessage = yes(Message),
(
Message = message_output(Thread, String),
( if map.search(!.Messages, Thread, TMessages0) then
TMessages = [String | TMessages0]
else
TMessages = [String]
),
map.set(Thread, TMessages, !Messages),
get_all_messages(AllOutput, OpenThreads, !Messages, !IO)
;
Message = message_open(Thread),
get_all_messages(AllOutput, insert(OpenThreads, Thread),
!Messages, !IO)
;
Message = message_close(Thread),
get_all_messages(AllOutput, delete(OpenThreads, Thread),
!Messages, !IO)
)
;
MaybeMessage = no
).
:- pred write_out_thread_messages(int::in, list(string)::in, io::di, io::uo)
is det.
write_out_thread_messages(Thread, Messages, !IO) :-
io.format("Messages from thread %d:\n", [i(Thread)], !IO),
foldr(write_out_message, Messages, !IO).
% Tell the reader, via the shared channel, that the thread owning Output
% will send no more messages.
close_thread_output(Output, !IO) :-
    Output = thread_output(Thread, Chan),
    channel.put(Chan, message_close(Thread), !IO).
:- pred write_out_message(string::in, io::di, io::uo) is det.
write_out_message(String, !IO) :-
io.format("\t%s\n", [s(String)], !IO).