Files
mercury/compiler/par_loop_control.m
Zoltan Somogyi 6d1bc24d0b Make vartypes an abstract data type, in preparation for exploring
Estimated hours taken: 4
Branches: main

compiler/prog_data.m:
	Make vartypes an abstract data type, in preparation for exploring
	better representations for it.

compiler/mode_util.m:
	Provide two different versions of a predicate. The generic version
	continues to use map lookups. The other version knows it works on
	prog_vars, so it can use the abstract operations on them provided
	by prog_data.m.

compiler/accumulator.m:
compiler/add_class.m:
compiler/add_heap_ops.m:
compiler/add_pragma.m:
compiler/add_pred.m:
compiler/add_trail_ops.m:
compiler/arg_info.m:
compiler/builtin_lib_types.m:
compiler/bytecode_gen.m:
compiler/call_gen.m:
compiler/clause_to_proc.m:
compiler/closure_analysis.m:
compiler/code_info.m:
compiler/common.m:
compiler/complexity.m:
compiler/const_prop.m:
compiler/constraint.m:
compiler/continuation_info.m:
compiler/cse_detection.m:
compiler/ctgc.datastruct.m:
compiler/ctgc.util.m:
compiler/deep_profiling.m:
compiler/deforest.m:
compiler/dep_par_conj.m:
compiler/det_analysis.m:
compiler/det_report.m:
compiler/det_util.m:
compiler/disj_gen.m:
compiler/equiv_type_hlds.m:
compiler/erl_call_gen.m:
compiler/erl_code_gen.m:
compiler/erl_code_util.m:
compiler/exception_analysis.m:
compiler/float_regs.m:
compiler/follow_vars.m:
compiler/format_call.m:
compiler/goal_path.m:
compiler/goal_util.m:
compiler/hhf.m:
compiler/higher_order.m:
compiler/hlds_clauses.m:
compiler/hlds_goal.m:
compiler/hlds_out_goal.m:
compiler/hlds_out_pred.m:
compiler/hlds_pred.m:
compiler/hlds_rtti.m:
compiler/inlining.m:
compiler/instmap.m:
compiler/intermod.m:
compiler/interval.m:
compiler/lambda.m:
compiler/lco.m:
compiler/live_vars.m:
compiler/liveness.m:
compiler/lookup_switch.m:
compiler/mercury_to_mercury.m:
compiler/ml_accurate_gc.m:
compiler/ml_closure_gen.m:
compiler/ml_code_gen.m:
compiler/ml_code_util.m:
compiler/ml_disj_gen.m:
compiler/ml_lookup_switch.m:
compiler/ml_proc_gen.m:
compiler/ml_unify_gen.m:
compiler/mode_info.m:
compiler/modecheck_call.m:
compiler/modecheck_conj.m:
compiler/modecheck_goal.m:
compiler/modecheck_unify.m:
compiler/modecheck_util.m:
compiler/modes.m:
compiler/par_loop_control.m:
compiler/pd_info.m:
compiler/pd_util.m:
compiler/polymorphism.m:
compiler/post_typecheck.m:
compiler/prog_type_subst.m:
compiler/prop_mode_constraints.m:
compiler/purity.m:
compiler/qual_info.m:
compiler/rbmm.points_to_info.m:
compiler/rbmm.region_liveness_info.m:
compiler/rbmm.region_transformation.m:
compiler/saved_vars.m:
compiler/simplify.m:
compiler/size_prof.m:
compiler/ssdebug.m:
compiler/stack_alloc.m:
compiler/stack_opt.m:
compiler/store_alloc.m:
compiler/structure_reuse.analysis.m:
compiler/structure_reuse.direct.choose_reuse.m:
compiler/structure_reuse.direct.detect_garbage.m:
compiler/structure_reuse.indirect.m:
compiler/structure_sharing.analysis.m:
compiler/structure_sharing.domain.m:
compiler/switch_detection.m:
compiler/table_gen.m:
compiler/term_constr_build.m:
compiler/term_constr_util.m:
compiler/term_traversal.m:
compiler/term_util.m:
compiler/trace_gen.m:
compiler/trailing_analysis.m:
compiler/try_expand.m:
compiler/tupling.m:
compiler/type_constraints.m:
compiler/type_util.m:
compiler/typecheck.m:
compiler/typecheck_errors.m:
compiler/typecheck_info.m:
compiler/unify_gen.m:
compiler/unify_proc.m:
compiler/unique_modes.m:
compiler/untupling.m:
compiler/unused_args.m:
compiler/var_locn.m:
	Conform to the above.

compiler/prog_type.m:
compiler/rbmm.points_to_graph.m:
	Conform to the above.

	Move some comments where they belong.

compiler/stm_expand.m:
	Conform to the above.

	Do not export a predicate that is not used outside this module.

	Disable some debugging output unless it is asked for.

	Remove unnecessary prefixes on variable names.

library/version_array.m:
	Instead of writing code for field access lookalike functions and
	defining lookup, set etc in terms of them, write code for lookup, set
	etc, and define the field access lookalike functions in terms of them.

	Change argument orders of some internal predicates to be
	more state variable friendly.

	Fix typos in comments.

tests/hard_coded/version_array_test.exp:
	Conform to the change to version_array.m.
2012-07-02 01:16:39 +00:00

1550 lines
64 KiB
Mathematica

%-----------------------------------------------------------------------------%
% vim: ft=mercury ts=4 sw=4 et
%-----------------------------------------------------------------------------%
% Copyright (C) 2011-2012 The University of Melbourne.
% This file may only be copied under the terms of the GNU General
% Public License - see the file COPYING in the Mercury distribution.
%-----------------------------------------------------------------------------%
%
% File: par_loop_control.m.
% Author: pbone.
%
% This module implements the parallel loop control transformation.
% This transformation operates on procedure bodies that contain a single
% recursive call on each recursive execution path, where that call occurs
% in the last conjunct of a parallel conjunction (see the rules below).
%
% Normally, this parallel conjunction would spawn off its second conjunct,
% execute the first conjunct, and then block waiting for the completion of
% the second. When the second conjunct contains a recursive call, each
% blocked first conjunct will occupy a context, and contexts have a high
% memory footprint due to their stacks. The objective of our transformation
% is to reduce this memory footprint.
%
% The way we do this is by spawning off the first parallel conjunct, and
% continuing execution of the second. Since the second conjunct is a recursive
% call, we will continue spawning off first conjuncts until we reach a limit
% that is imposed by a loop control structure. This limit prevents us from
% swamping the available CPUs with too much work. It also allows us to
% use one barrier for ALL the loop iterations, rather than one barrier
% for EACH loop iteration.
%
% Consider this loop:
%
% map(M, [], []).
% map(M, [X | Xs], [Y | Ys]) :-
% (
% M(X, Y)
% &
% map(M, Xs, Ys)
% ).
%
% It would be transformed to:
%
% map(M, Xs, Ys) :-
%     create_loop_control(LC, P),   % P is the number of contexts to use.
%     map_lc(LC, M, Xs, Ys).
%
% map_lc(LC, _, [], []) :-
%     finish_loop_control(LC).
% map_lc(LC, M, [X | Xs], [Y | Ys]) :-
%     wait_free_slot(LC, LCS),
%     spawn_off(LCS, (
%         M(X, Y),
%         join_and_terminate(LC, LCS)
%     )),
%     map_lc(LC, M, Xs, Ys).        % May not use tail recursion.
%
% The parallel conjunction is replaced with a wait_free_slot goal and a
% spawn_off goal for each conjunct except for the last. The last is re-written
% to call the loop control version of the predicate.
%
% Rules:
%
% 1. This transformation works when there are multiple parallel conjunctions
%    in different branches. It also works when a parallel conjunction has
%    more than two conjuncts, in which case all but the rightmost conjunct
%    are replaced with calls to spawn_off.
%
% 2. There may be code _after_ the recursive call that consumes variables
% produced in the first conjunct. This is safe because we get to such code
% only *after* the barrier in the base case has been executed. Any
% consumption before the recursive call will already be using a future,
% and is therefore safe.
%
% 3. The predicate must be singly recursive, i.e. its body cannot have
% more than one recursive call along any execution path. We need this
% to ensure that the base case (and its barrier) is executed exactly once.
%
% 4. Multiple parallel conjunctions may exist within the body, but due to rule
% 3, only one of them may contain a recursive call.
%
%----------------------------------------------------------------------------%
:- module transform_hlds.par_loop_control.
:- interface.
:- import_module hlds.hlds_module.
%----------------------------------------------------------------------------%
:- pred maybe_par_loop_control_module(module_info::in, module_info::out)
is det.
%----------------------------------------------------------------------------%
%----------------------------------------------------------------------------%
:- implementation.
:- import_module hlds.goal_path.
:- import_module hlds.goal_util.
:- import_module hlds.hlds_goal.
:- import_module hlds.hlds_pred.
:- import_module hlds.hlds_rtti.
:- import_module hlds.instmap.
:- import_module hlds.passes_aux.
:- import_module hlds.pred_table.
:- import_module libs.globals.
:- import_module libs.options.
:- import_module mdbcomp.goal_path.
:- import_module mdbcomp.prim_data.
:- import_module parse_tree.prog_data.
:- import_module parse_tree.prog_util.
:- import_module parse_tree.set_of_var.
:- import_module transform_hlds.dependency_graph.
:- import_module bool.
:- import_module digraph.
:- import_module list.
:- import_module map.
:- import_module maybe.
:- import_module pair.
:- import_module require.
:- import_module set.
:- import_module string.
:- import_module varset.
%----------------------------------------------------------------------------%
maybe_par_loop_control_module(!ModuleInfo) :-
module_info_rebuild_dependency_info(!ModuleInfo, DepInfo),
process_all_nonimported_procs(
update_module(maybe_par_loop_control_proc(DepInfo)),
!ModuleInfo).
:- pred maybe_par_loop_control_proc(dependency_info::in, pred_proc_id::in,
proc_info::in, proc_info::out, module_info::in, module_info::out) is det.
maybe_par_loop_control_proc(DepInfo, PredProcId, !ProcInfo, !ModuleInfo) :-
( loop_control_is_applicable(DepInfo, PredProcId, !.ProcInfo) ->
proc_info_get_goal(!.ProcInfo, Body0),
% Re-calculate goal ids.
proc_info_get_vartypes(!.ProcInfo, VarTypes),
fill_goal_id_slots_in_proc_body(!.ModuleInfo, VarTypes,
ContainingGoalMap, Body0, Body),
proc_info_set_goal(Body, !ProcInfo),
goal_get_loop_control_par_conjs(Body, PredProcId,
RecursiveParConjIds),
(
( RecursiveParConjIds = have_not_seen_recursive_call
; RecursiveParConjIds = seen_one_recursive_call_on_every_branch
; RecursiveParConjIds = seen_unusable_recursion
)
;
RecursiveParConjIds = seen_usable_recursion_in_par_conj(GoalIds),
% Go ahead and perform the transformation.
create_inner_proc(GoalIds, PredProcId, !.ProcInfo,
ContainingGoalMap, InnerPredProcId, InnerPredName,
!ModuleInfo),
update_outer_proc(PredProcId, InnerPredProcId, InnerPredName,
!.ModuleInfo, !ProcInfo)
)
;
true
).
%----------------------------------------------------------------------------%
% Loop control is applicable if the procedure contains a parallel
% conjunction with exactly two conjuncts whose right conjunct contains a
% recursive call.
%
:- pred loop_control_is_applicable(dependency_info::in, pred_proc_id::in,
proc_info::in) is semidet.
loop_control_is_applicable(DepInfo, PredProcId, ProcInfo) :-
proc_info_get_has_parallel_conj(ProcInfo, yes),
proc_info_get_inferred_determinism(ProcInfo, Detism),
% If the predicate itself is not deterministic then its recursive call
% will not be deterministic and therefore will not be found in a parallel
% conjunction.
( Detism = detism_det
; Detism = detism_cc_multi
),
proc_is_self_recursive(DepInfo, PredProcId).
:- pred proc_is_self_recursive(dependency_info::in, pred_proc_id::in)
is semidet.
proc_is_self_recursive(DepInfo, PredProcId) :-
hlds_dependency_info_get_dependency_graph(DepInfo, DepGraph),
% There must be a directly recursive call.
digraph.lookup_key(DepGraph, PredProcId, SelfKey),
digraph.is_edge(DepGraph, SelfKey, SelfKey),
% There must not be a indirectly recursive call.
% Note: we could handle this in the future by inlining one call within
% another, but recursion analysis in the deep profiler should support this
% first.
digraph.delete_edge(SelfKey, SelfKey, DepGraph, DepGraphWOSelfEdge),
digraph.tc(DepGraphWOSelfEdge, TCDepGraphWOSelfEdge),
not digraph.is_edge(TCDepGraphWOSelfEdge, SelfKey, SelfKey).
%----------------------------------------------------------------------------%
    % seen_usable_recursion summarizes what goal_get_loop_control_par_conjs
    % has found out about the recursive calls in a goal, from the point of
    % view of the loop control transformation.
    %
:- type seen_usable_recursion
    --->    have_not_seen_recursive_call
            % There is no reachable recursive call in this goal.

    ;       seen_one_recursive_call_on_every_branch
            % There is exactly one recursive call on every reachable branch.
            % Therefore this single recursion can be used if it is within
            % a parallel conjunction.

    ;       seen_unusable_recursion
            % There is recursion, but we cannot use it. The reasons why we
            % may be unable to use it include:
            % + multiple recursion;
            % + recursion on some but not all branches, or in code that
            %   is not det or cc_multi;
            % + usable recursion inside a parallel conjunction that is
            %   itself inside _another_ parallel conjunction.

    ;       seen_usable_recursion_in_par_conj(list(goal_id)).
            % There is recursion within the right-most conjunct of a
            % parallel conjunction. There may be multiple cases of this
            % (different parallel conjunctions in different branches);
            % the list holds the goal ids of all those parallel
            % conjunctions.

    % The subtype of seen_usable_recursion consisting of the values for
    % which we should keep searching the rest of a plain conjunction.
    % Once recursion has been found to be unusable, nothing that follows
    % it in the conjunction can make it usable again.
    %
:- inst seen_usable_recursion_continue
    --->    have_not_seen_recursive_call
    ;       seen_one_recursive_call_on_every_branch
    ;       seen_usable_recursion_in_par_conj(ground).
    % goal_get_loop_control_par_conjs(Goal, SelfPredProcId,
    %   SeenUsableRecursion):
    %
    % Classify the calls to SelfPredProcId within Goal, and collect the
    % goal ids of the parallel conjunctions (if any) whose last conjunct
    % contains a recursive call that the loop control transformation can
    % use. A goal whose instmap delta is unreachable is treated as having
    % no recursive call. Any recursion found in a goal whose determinism is
    % neither det nor cc_multi is classified as unusable.
    %
:- pred goal_get_loop_control_par_conjs(hlds_goal::in, pred_proc_id::in,
    seen_usable_recursion::out) is det.

goal_get_loop_control_par_conjs(Goal, SelfPredProcId, SeenUsableRecursion) :-
    Goal = hlds_goal(GoalExpr, GoalInfo),
    Detism = goal_info_get_determinism(GoalInfo),
    InstmapDelta = goal_info_get_instmap_delta(GoalInfo),
    ( instmap_delta_is_reachable(InstmapDelta) ->
        (
            GoalExpr = unify(_, _, _, _, _),
            SeenUsableRecursion0 = have_not_seen_recursive_call
        ;
            GoalExpr = plain_call(PredId, ProcId, _, _, _, _),
            % Only a call to exactly this procedure counts as recursive;
            % proc_is_self_recursive has already ruled out mutual recursion.
            ( SelfPredProcId = proc(PredId, ProcId) ->
                SeenUsableRecursion0 =
                    seen_one_recursive_call_on_every_branch
            ;
                SeenUsableRecursion0 = have_not_seen_recursive_call
            )
        ;
            GoalExpr = generic_call(_, _, _, _, _),
            % We cannot determine if a generic call is recursive or not,
            % however it most likely is not. In either case, we cannot perform
            % the loop control transformation.
            SeenUsableRecursion0 = have_not_seen_recursive_call
        ;
            GoalExpr = call_foreign_proc(_, _, _, _, _, _, _),
            SeenUsableRecursion0 = have_not_seen_recursive_call
        ;
            GoalExpr = conj(ConjType, Conjs),
            (
                ConjType = plain_conj,
                conj_get_loop_control_par_conjs(Conjs, SelfPredProcId,
                    have_not_seen_recursive_call, SeenUsableRecursion0)
            ;
                ConjType = parallel_conj,
                % This is the only place where usable recursion inside
                % a parallel conjunction can first be recognized; the goal
                % id identifies the parallel conjunction to transform.
                GoalId = goal_info_get_goal_id(GoalInfo),
                par_conj_get_loop_control_par_conjs(Conjs, SelfPredProcId,
                    GoalId, SeenUsableRecursion0)
            )
        ;
            GoalExpr = disj(_),
            % If the disjunction contains a recursive call at all, then the
            % recursive call is in an unusable context.
            ( goal_calls(Goal, SelfPredProcId) ->
                SeenUsableRecursion0 = seen_unusable_recursion
            ;
                SeenUsableRecursion0 = have_not_seen_recursive_call
            )
        ;
            GoalExpr = switch(_, _CanFail, Cases),
            list.map(case_get_loop_control_par_conjs(SelfPredProcId), Cases,
                SeenUsableRecursionCases),
            % If the switch can fail, then there is effectively another branch
            % that has no recursive call. However, we do not need to test for
            % this here, as checking the determinism of the goal will detect
            % such cases.
            merge_loop_control_par_conjs_between_branches_list(
                SeenUsableRecursionCases, SeenUsableRecursion0)
        ;
            GoalExpr = negation(SubGoal),
            goal_get_loop_control_par_conjs(SubGoal, SelfPredProcId,
                SeenUsableRecursion0)
            % A negation can succeed only if its subgoal can fail, and any
            % recursion in a subgoal that is not det or cc_multi has
            % already been classified as unusable by the determinism check
            % below (applied to SubGoal). If the negation itself can fail,
            % the same check applied to this goal makes any recursion found
            % here unusable (as for can-fail switches).
        ;
            GoalExpr = scope(_, SubGoal),
            goal_get_loop_control_par_conjs(SubGoal, SelfPredProcId,
                SeenUsableRecursion0)
            % If the scope does a cut, then any recursion inside SubGoal
            % is unusable, but the determinism check below will catch that.
        ;
            GoalExpr = if_then_else(_, Cond, Then, Else),
            goal_get_loop_control_par_conjs(Cond, SelfPredProcId,
                SeenUsableRecursionCond),
            (
                SeenUsableRecursionCond = have_not_seen_recursive_call,
                % The then and else parts are alternative branches.
                goal_get_loop_control_par_conjs(Then, SelfPredProcId,
                    SeenUsableRecursionThen),
                goal_get_loop_control_par_conjs(Else, SelfPredProcId,
                    SeenUsableRecursionElse),
                merge_loop_control_par_conjs_between_branches(
                    SeenUsableRecursionThen, SeenUsableRecursionElse,
                    SeenUsableRecursion0)
            ;
                % We cannot make use of any recursion found in the condition
                % of an if-then-else.
                ( SeenUsableRecursionCond =
                    seen_one_recursive_call_on_every_branch
                ; SeenUsableRecursionCond = seen_unusable_recursion
                ; SeenUsableRecursionCond =
                    seen_usable_recursion_in_par_conj(_)
                ),
                SeenUsableRecursion0 = seen_unusable_recursion
            )
        ;
            GoalExpr = shorthand(_),
            % Shorthand goals are not expected to reach this pass.
            unexpected($module, $pred, "shorthand")
        ),
        % If the goal might fail or might succeed more than once, then any
        % recursion in it is unusable for loop control.
        (
            ( SeenUsableRecursion0 = have_not_seen_recursive_call
            ; SeenUsableRecursion0 = seen_unusable_recursion
            ),
            SeenUsableRecursion = SeenUsableRecursion0
        ;
            ( SeenUsableRecursion0 = seen_one_recursive_call_on_every_branch
            ; SeenUsableRecursion0 = seen_usable_recursion_in_par_conj(_)
            ),
            (
                ( Detism = detism_det
                ; Detism = detism_cc_multi
                ),
                SeenUsableRecursion = SeenUsableRecursion0
            ;
                ( Detism = detism_semi
                ; Detism = detism_multi
                ; Detism = detism_non
                ; Detism = detism_cc_non
                ; Detism = detism_erroneous
                ; Detism = detism_failure
                ),
                SeenUsableRecursion = seen_unusable_recursion
            )
        )
    ;
        % InstmapDelta is unreachable.
        SeenUsableRecursion = have_not_seen_recursive_call
    ).
% Analyze the parallel conjunction for a usable recursive call.
%
% If any but the last conjunct contains a recursive call, then that
% recursive call is unusable. If only the last conjunct contains
% a recursive call, then that recursion is usable.
%
:- pred par_conj_get_loop_control_par_conjs(list(hlds_goal)::in,
pred_proc_id::in, goal_id::in, seen_usable_recursion::out) is det.
par_conj_get_loop_control_par_conjs(Conjs, SelfPredProcId,
GoalId, SeenUsableRecursion) :-
(
Conjs = [],
unexpected($module, $pred, "Empty parallel conjunction")
;
Conjs = [Head | Tail],
par_conj_get_loop_control_par_conjs_lag(Head, Tail, SelfPredProcId,
SeenUsableRecursion0),
(
SeenUsableRecursion0 = have_not_seen_recursive_call,
SeenUsableRecursion = SeenUsableRecursion0
;
SeenUsableRecursion0 = seen_one_recursive_call_on_every_branch,
SeenUsableRecursion = seen_usable_recursion_in_par_conj([GoalId])
;
( SeenUsableRecursion0 = seen_unusable_recursion
; SeenUsableRecursion0 = seen_usable_recursion_in_par_conj(_)
),
SeenUsableRecursion = seen_unusable_recursion
)
).
:- pred par_conj_get_loop_control_par_conjs_lag(hlds_goal::in,
list(hlds_goal)::in, pred_proc_id::in, seen_usable_recursion::out) is det.
par_conj_get_loop_control_par_conjs_lag(Conj, Conjs, SelfPredProcId,
SeenUsableRecursion) :-
goal_get_loop_control_par_conjs(Conj, SelfPredProcId,
SeenUsableRecursion0),
(
% This is the last conjunct. Therefore, if it contains a recursive
% call it is a the recursion we're looking for.
Conjs = [],
SeenUsableRecursion = SeenUsableRecursion0
;
Conjs = [Head | Tail],
% This is not the last conjunct. Therefore any recursion it contains
% is unusable.
(
( SeenUsableRecursion0 = seen_one_recursive_call_on_every_branch
; SeenUsableRecursion0 = seen_unusable_recursion
; SeenUsableRecursion0 = seen_usable_recursion_in_par_conj(_)
),
SeenUsableRecursion = seen_unusable_recursion
;
SeenUsableRecursion0 = have_not_seen_recursive_call,
% Analyze the rest of the conjunction.
par_conj_get_loop_control_par_conjs_lag(Head, Tail, SelfPredProcId,
SeenUsableRecursion)
)
).
:- pred conj_get_loop_control_par_conjs(hlds_goals::in, pred_proc_id::in,
seen_usable_recursion::in(seen_usable_recursion_continue),
seen_usable_recursion::out) is det.
conj_get_loop_control_par_conjs([], _, !SeenUsableRecursion).
conj_get_loop_control_par_conjs([Conj | Conjs], SelfPredProcId,
!SeenUsableRecursion) :-
goal_get_loop_control_par_conjs(Conj, SelfPredProcId,
SeenUsableRecursionConj),
merge_loop_control_par_conjs_sequential(SeenUsableRecursionConj,
!SeenUsableRecursion),
(
!.SeenUsableRecursion = seen_unusable_recursion
;
( !.SeenUsableRecursion = seen_one_recursive_call_on_every_branch
; !.SeenUsableRecursion = seen_usable_recursion_in_par_conj(_)
; !.SeenUsableRecursion = have_not_seen_recursive_call
),
conj_get_loop_control_par_conjs(Conjs, SelfPredProcId,
!SeenUsableRecursion)
).
:- pred case_get_loop_control_par_conjs(pred_proc_id::in, case::in,
seen_usable_recursion::out) is det.
case_get_loop_control_par_conjs(SelfPredProcId, case(_, _, Goal),
SeenUsableRecursion) :-
goal_get_loop_control_par_conjs(Goal, SelfPredProcId,
SeenUsableRecursion).
:- pred merge_loop_control_par_conjs_sequential(seen_usable_recursion::in,
seen_usable_recursion::in, seen_usable_recursion::out) is det.
merge_loop_control_par_conjs_sequential(have_not_seen_recursive_call,
Seen, Seen).
merge_loop_control_par_conjs_sequential(seen_unusable_recursion,
_, seen_unusable_recursion).
merge_loop_control_par_conjs_sequential(
seen_one_recursive_call_on_every_branch, Seen0, Seen) :-
(
Seen0 = have_not_seen_recursive_call,
Seen = seen_one_recursive_call_on_every_branch
;
( Seen0 = seen_one_recursive_call_on_every_branch
; Seen0 = seen_unusable_recursion
; Seen0 = seen_usable_recursion_in_par_conj(_)
),
Seen = seen_unusable_recursion
).
merge_loop_control_par_conjs_sequential(
seen_usable_recursion_in_par_conj(GoalIds), Seen0, Seen) :-
(
Seen0 = have_not_seen_recursive_call,
Seen = seen_usable_recursion_in_par_conj(GoalIds)
;
( Seen0 = seen_one_recursive_call_on_every_branch
; Seen0 = seen_unusable_recursion
; Seen0 = seen_usable_recursion_in_par_conj(_)
),
Seen = seen_unusable_recursion
).
:- pred merge_loop_control_par_conjs_between_branches_list(
list(seen_usable_recursion)::in, seen_usable_recursion::out) is det.
merge_loop_control_par_conjs_between_branches_list([],
have_not_seen_recursive_call).
merge_loop_control_par_conjs_between_branches_list([Seen | Seens], Result) :-
list.foldl(merge_loop_control_par_conjs_between_branches, Seens,
Seen, Result).
:- pred merge_loop_control_par_conjs_between_branches(
seen_usable_recursion::in, seen_usable_recursion::in,
seen_usable_recursion::out) is det.
merge_loop_control_par_conjs_between_branches(have_not_seen_recursive_call,
Seen0, Seen) :-
(
Seen0 = have_not_seen_recursive_call,
Seen = have_not_seen_recursive_call
;
( Seen0 = seen_one_recursive_call_on_every_branch
; Seen0 = seen_unusable_recursion
),
Seen = seen_unusable_recursion
;
Seen0 = seen_usable_recursion_in_par_conj(_),
Seen = Seen0
).
merge_loop_control_par_conjs_between_branches(
seen_one_recursive_call_on_every_branch, Seen0, Seen) :-
(
Seen0 = seen_one_recursive_call_on_every_branch,
Seen = Seen0
;
( Seen0 = have_not_seen_recursive_call
; Seen0 = seen_unusable_recursion
; Seen0 = seen_usable_recursion_in_par_conj(_)
),
Seen = seen_unusable_recursion
).
merge_loop_control_par_conjs_between_branches(seen_unusable_recursion, _,
seen_unusable_recursion).
merge_loop_control_par_conjs_between_branches(
seen_usable_recursion_in_par_conj(GoalIdsA), Seen0, Seen) :-
(
Seen0 = have_not_seen_recursive_call,
Seen = seen_usable_recursion_in_par_conj(GoalIdsA)
;
( Seen0 = seen_one_recursive_call_on_every_branch
; Seen0 = seen_unusable_recursion
),
Seen = seen_unusable_recursion
;
Seen0 = seen_usable_recursion_in_par_conj(GoalIdsB),
% We do the concatenation in this order so that it is not quadratic
% when called from merge_loop_control_par_conjs_between_branches_list.
GoalIds = GoalIdsA ++ GoalIdsB,
Seen = seen_usable_recursion_in_par_conj(GoalIds)
).
%----------------------------------------------------------------------------%
    % create_inner_proc(RecParConjIds, OldPredProcId, OldProcInfo,
    %   ContainingGoalMap, PredProcId, PredSym, !ModuleInfo):
    %
    % Create the inner (loop control) version of the procedure
    % OldPredProcId, whose proc_info is OldProcInfo, and add it to the
    % module as a new predicate. The new procedure takes the loop control
    % variable as an extra first argument, and the parallel conjunctions
    % identified by RecParConjIds are rewritten to use loop control.
    % Returns the id and the name of the new procedure.
    %
:- pred create_inner_proc(list(goal_id)::in, pred_proc_id::in, proc_info::in,
    containing_goal_map::in, pred_proc_id::out, sym_name::out,
    module_info::in, module_info::out) is det.

create_inner_proc(RecParConjIds, OldPredProcId, OldProcInfo,
        ContainingGoalMap, PredProcId, PredSym, !ModuleInfo) :-
    proc(OldPredId, OldProcId) = OldPredProcId,
    module_info_pred_info(!.ModuleInfo, OldPredId, OldPredInfo),

    % Gather data to build the new pred/proc.
    module_info_get_name(!.ModuleInfo, ModuleName),
    PredOrFunc = pred_info_is_pred_or_func(OldPredInfo),
    make_pred_name(ModuleName, "LoopControl", yes(PredOrFunc),
        pred_info_name(OldPredInfo), newpred_parallel_loop_control, PredSym0),
    % The mode number is included because we want to avoid the creation of
    % more than one predicate with the same name if more than one mode of
    % a predicate is parallelised. Since the names of e.g. deep profiling
    % proc_static structures are derived from the names of predicates,
    % duplicate predicate names lead to duplicate global variable names
    % and hence to link errors.
    proc_id_to_int(OldProcId, OldProcInt),
    add_sym_name_suffix(PredSym0, "_" ++ int_to_string(OldProcInt), PredSym),
    pred_info_get_context(OldPredInfo, Context),
    pred_info_get_origin(OldPredInfo, OldOrigin),
    Origin = origin_transformed(transform_parallel_loop_control, OldOrigin,
        OldPredId),
    % The new predicate is marked as impure, and as having only fully
    % qualified calls.
    some [!Markers] (
        init_markers(!:Markers),
        add_marker(marker_is_impure, !Markers),
        add_marker(marker_calls_are_fully_qualified, !Markers),
        Markers = !.Markers
    ),
    pred_info_get_typevarset(OldPredInfo, TypeVarSet),
    pred_info_get_exist_quant_tvars(OldPredInfo, ExistQVars),
    pred_info_get_class_context(OldPredInfo, ClassConstraints),
    pred_info_get_arg_types(OldPredInfo, ArgTypes0),

    some [!PredInfo, !Body, !VarSet, !VarTypes] (
        % Construct the pred info structure. We initially construct it with
        % the old proc info which will be replaced below.
        pred_info_create(ModuleName, PredSym, PredOrFunc, Context, Origin,
            status_local, Markers, ArgTypes0, TypeVarSet, ExistQVars,
            ClassConstraints, set.init, map.init, OldProcInfo, ProcId,
            !:PredInfo),

        % Add the new predicate to the module.
        some [!PredTable] (
            module_info_get_predicate_table(!.ModuleInfo, !:PredTable),
            predicate_table_insert(!.PredInfo, PredId, !PredTable),
            module_info_set_predicate_table(!.PredTable, !ModuleInfo)
        ),
        PredProcId = proc(PredId, ProcId),

        % Now transform the predicate. This could not be done earlier because
        % we needed to know the new PredProcId to re-write the recursive calls
        % in the body.
        proc_info_get_argmodes(OldProcInfo, ArgModes0),
        proc_info_get_headvars(OldProcInfo, HeadVars0),
        proc_info_get_varset(OldProcInfo, !:VarSet),
        proc_info_get_vartypes(OldProcInfo, !:VarTypes),
        proc_info_get_goal(OldProcInfo, !:Body),

        % Create the loop control variable; it becomes the new first
        % argument of the inner procedure (see HeadVars below).
        varset.new_named_var("LC", LCVar, !VarSet),
        add_var_type(LCVar, loop_control_var_type, !VarTypes),

        should_preserve_tail_recursion(!.ModuleInfo, PreserveTailRecursion),
        get_lc_wait_free_slot_proc(!.ModuleInfo, WaitFreeSlotProc),
        get_lc_join_and_terminate_proc(!.ModuleInfo, JoinAndTerminateProc),
        Info = loop_control_info(!.ModuleInfo, LCVar, OldPredProcId,
            PredProcId, PredSym, PreserveTailRecursion, WaitFreeSlotProc,
            lc_wait_free_slot_name, JoinAndTerminateProc,
            lc_join_and_terminate_name),

        % Rewrite the parallel conjunctions that contain usable recursion.
        goal_loop_control_all_recursive_paths(Info, RecParConjIds,
            ContainingGoalMap, !Body, !VarSet, !VarTypes),

        % Fixup the remaining recursive calls, and add barriers in the base
        % cases.
        goal_update_non_loop_control_paths(Info, RecParConjIds, _, !Body),

        % Now create the new proc_info structure.
        HeadVars = [LCVar | HeadVars0],
        ArgTypes = [loop_control_var_type | ArgTypes0],
        % The loop control argument has mode ground >> ground.
        Ground = ground(shared, none),
        In = (Ground -> Ground),
        ArgModes = [In | ArgModes0],
        proc_info_get_inst_varset(OldProcInfo, InstVarSet),
        proc_info_get_rtti_varmaps(OldProcInfo, RttiVarMaps),
        proc_info_get_inferred_determinism(OldProcInfo, Detism),
        proc_info_create(Context, !.VarSet, !.VarTypes, HeadVars, InstVarSet,
            ArgModes, detism_decl_none, Detism, !.Body, RttiVarMaps,
            address_is_not_taken, map.init, ProcInfo),

        % Update the other structures
        pred_info_set_arg_types(TypeVarSet, ExistQVars, ArgTypes, !PredInfo),
        pred_info_set_proc_info(ProcId, ProcInfo, !PredInfo),
        module_info_set_pred_info(PredId, !.PredInfo, !ModuleInfo)
    ).
:- pred should_preserve_tail_recursion(module_info::in,
preserve_tail_recursion::out) is det.
should_preserve_tail_recursion(ModuleInfo, PreserveTailRecursion) :-
module_info_get_globals(ModuleInfo, Globals),
globals.lookup_bool_option(Globals,
par_loop_control_preserve_tail_recursion, PreserveTailRecursionBool),
(
PreserveTailRecursionBool = yes,
PreserveTailRecursion = preserve_tail_recursion
;
PreserveTailRecursionBool = no,
PreserveTailRecursion = do_not_preserve_tail_recursion
).
    % The information needed while transforming the body of the inner
    % procedure. It is built by create_inner_proc.
    %
:- type loop_control_info
    --->    loop_control_info(
                % The module, as it was when create_inner_proc built this
                % structure (after the inner predicate had been added).
                lci_module_info                     :: module_info,

                % The loop control variable, which is the extra first
                % argument of the inner procedure.
                lci_lc_var                          :: prog_var,

                % The original recursive procedure, and the id and name
                % of the inner procedure created for it.
                lci_rec_pred_proc_id                :: pred_proc_id,
                lci_inner_pred_proc_id              :: pred_proc_id,
                lci_inner_pred_name                 :: sym_name,

                % The setting of the par_loop_control_preserve_tail_recursion
                % option.
                lci_preserve_tail_recursion         :: preserve_tail_recursion,

                % The procedures (and their names) used to wait for a free
                % loop control slot, and to join and terminate.
                lci_wait_free_slot_proc             :: pred_proc_id,
                lci_wait_free_slot_proc_name        :: sym_name,
                lci_join_and_terminate_proc         :: pred_proc_id,
                lci_join_and_terminate_proc_name    :: sym_name
            ).

    % Whether the transformation should try to preserve tail recursion;
    % see should_preserve_tail_recursion.
    %
:- type preserve_tail_recursion
    --->    preserve_tail_recursion
    ;       do_not_preserve_tail_recursion.
    % Says whether the current goal is the last goal on an execution path
    % through the procedure. In other words, can the last goal within the
    % current goal use a tail call?
    %
:- type goal_is_last_goal_on_path
    --->    goal_is_last_goal_on_path
    ;       goal_is_not_last_goal_on_path.
:- pred goal_loop_control_all_recursive_paths(loop_control_info::in,
list(goal_id)::in, containing_goal_map::in, hlds_goal::in, hlds_goal::out,
prog_varset::in, prog_varset::out, vartypes::in, vartypes::out) is det.
goal_loop_control_all_recursive_paths(Info, GoalIds, ContainingGoalMap, !Goal,
!VarSet, !VarTypes) :-
GoalPaths = list.map(goal_id_to_forward_path(ContainingGoalMap), GoalIds),
list.foldl3(goal_loop_control_one_recursive_path(Info,
goal_is_last_goal_on_path),
GoalPaths, !Goal, !VarSet, !VarTypes).
:- pred goal_loop_control_one_recursive_path(loop_control_info::in,
goal_is_last_goal_on_path::in, forward_goal_path::in,
hlds_goal::in, hlds_goal::out, prog_varset::in, prog_varset::out,
vartypes::in, vartypes::out) is det.
goal_loop_control_one_recursive_path(Info, IsLastGoal, GoalPath0, !Goal,
!VarSet, !VarTypes) :-
!.Goal = hlds_goal(GoalExpr0, GoalInfo),
( goal_path_remove_first(GoalPath0, GoalPath, Step) ->
format("Couldn't follow goal path step: \"%s\"", [s(string(Step))],
ErrorString),
(
Step = step_conj(N),
(
GoalExpr0 = conj(plain_conj, Conjs0),
list.index1(Conjs0, N, Conj0)
->
(
IsLastGoal = goal_is_last_goal_on_path,
( N = length(Conjs0) ->
IsLastGoalConj = goal_is_last_goal_on_path
;
IsLastGoalConj = goal_is_not_last_goal_on_path
)
;
IsLastGoal = goal_is_not_last_goal_on_path,
IsLastGoalConj = IsLastGoal
),
goal_loop_control_one_recursive_path(Info, IsLastGoalConj,
GoalPath, Conj0, Conj, !VarSet, !VarTypes),
det_replace_nth(Conjs0, N, Conj, Conjs),
GoalExpr = conj(plain_conj, Conjs)
;
unexpected($module, $pred, ErrorString)
)
;
Step = step_switch(N, _),
(
GoalExpr0 = switch(Var, CanFail, Cases0),
list.index1(Cases0, N, Case0)
->
Goal0 = Case0 ^ case_goal,
goal_loop_control_one_recursive_path(Info, IsLastGoal,
GoalPath, Goal0, Goal, !VarSet, !VarTypes),
Case = Case0 ^ case_goal := Goal,
det_replace_nth(Cases0, N, Case, Cases),
GoalExpr = switch(Var, CanFail, Cases)
;
unexpected($module, $pred, ErrorString)
)
;
Step = step_ite_then,
( GoalExpr0 = if_then_else(Vars, Cond, Then0, Else) ->
goal_loop_control_one_recursive_path(Info, IsLastGoal,
GoalPath, Then0, Then, !VarSet, !VarTypes),
GoalExpr = if_then_else(Vars, Cond, Then, Else)
;
unexpected($module, $pred, ErrorString)
)
;
Step = step_ite_else,
( GoalExpr0 = if_then_else(Vars, Cond, Then, Else0) ->
goal_loop_control_one_recursive_path(Info, IsLastGoal,
GoalPath, Else0, Else, !VarSet, !VarTypes),
GoalExpr = if_then_else(Vars, Cond, Then, Else)
;
unexpected($module, $pred, ErrorString)
)
;
Step = step_scope(_),
( GoalExpr0 = scope(Reason, SubGoal0) ->
goal_loop_control_one_recursive_path(Info, IsLastGoal,
GoalPath, SubGoal0, SubGoal, !VarSet, !VarTypes),
GoalExpr = scope(Reason, SubGoal)
;
unexpected($module, $pred, ErrorString)
)
;
( Step = step_ite_cond
; Step = step_disj(_)
; Step = step_neg
; Step = step_lambda
; Step = step_try
; Step = step_atomic_main
; Step = step_atomic_orelse(_)
),
unexpected($module, $pred,
format("Unexpected step in goal path \"%s\"",
[s(string(Step))]))
),
!:Goal = hlds_goal(GoalExpr, GoalInfo),
fixup_goal_info(Info, !Goal)
;
( GoalExpr0 = conj(parallel_conj, Conjs) ->
par_conj_loop_control(Info, Conjs, IsLastGoal, GoalInfo, !:Goal,
!VarSet, !VarTypes)
;
unexpected($module, $pred, "expected parallel conjunction")
)
).
% par_conj_loop_control(Info, Conjuncts0, IsLastGoal, GoalInfo, Goal,
%   !VarSet, !VarTypes):
%
% Transform the conjuncts of a recursive parallel conjunction into the
% loop control form, producing a plain conjunction.
%
% The last conjunct is the one whose recursive call gets redirected to the
% inner procedure. Each earlier conjunct is spawned off: it is preceded by
% a "get free slot" call and wrapped in a loop control scope.
% The resulting conjunction uses GoalInfo, updated by fixup_goal_info/3.
%
:- pred par_conj_loop_control(loop_control_info::in, list(hlds_goal)::in,
    goal_is_last_goal_on_path::in, hlds_goal_info::in, hlds_goal::out,
    prog_varset::in, prog_varset::out, vartypes::in, vartypes::out) is det.

par_conj_loop_control(Info, Conjuncts0, IsLastGoal, GoalInfo, Goal, !VarSet,
        !VarTypes) :-
    list.det_split_last(Conjuncts0, SpawnedConjuncts0, RecConjunct0),

    % Redirect the recursive call in the final conjunct; this also tells us
    % whether the spawned-off conjuncts may use the parent's stack frame.
    goal_rewrite_recursive_call(Info, IsLastGoal, RecConjunct0, RecConjunct,
        UseParentStack, _),
    goal_to_conj_list(RecConjunct, RecConjunctGoals),

    % All the other conjuncts are spawned off under loop control.
    rewrite_nonrecursive_par_conjuncts(Info, UseParentStack,
        SpawnedConjuncts0, SpawnedConjuncts, !VarSet, !VarTypes),

    % XXX The point of calling create_conj_from_list is that it sets up
    % the goal_info of the new goal appropriately. Why call it if we then
    % immediately overwrite the goal_info?
    create_conj_from_list(SpawnedConjuncts ++ RecConjunctGoals, plain_conj,
        ConjGoal),
    ConjGoal = hlds_goal(ConjGoalExpr, _),
    fixup_goal_info(Info, hlds_goal(ConjGoalExpr, GoalInfo), Goal).
% rewrite_nonrecursive_par_conjuncts(Info, UseParentStack, !Conjuncts,
%   !VarSet, !VarTypes):
%
% Rewrite every non-recursive conjunct of a parallel conjunction.
% Each input conjunct becomes two output goals:
%   - a "get free slot" call that binds a fresh slot variable (LCSVar), and
%   - a loop_control scope around the original conjunct, with a
%     join_and_terminate call appended to the end of that conjunct.
% A fresh slot variable is allocated per conjunct, in list order.
%
:- pred rewrite_nonrecursive_par_conjuncts(loop_control_info::in,
    lc_use_parent_stack::in, list(hlds_goal)::in, list(hlds_goal)::out,
    prog_varset::in, prog_varset::out, vartypes::in, vartypes::out) is det.

rewrite_nonrecursive_par_conjuncts(_, _, [], [], !VarSet, !VarTypes).
rewrite_nonrecursive_par_conjuncts(Info, UseParentStack,
        [Conjunct0 | Conjuncts0], Goals, !VarSet, !VarTypes) :-
    LCVar = Info ^ lci_lc_var,

    % The "get free slot" call comes first; it creates LCSVar.
    create_get_free_slot_goal(Info, LCSVar, GetFreeSlotGoal,
        !VarSet, !VarTypes),

    % Append the join_and_terminate call to the conjunct. That call refers
    % to both loop control variables, so both become nonlocal to it.
    create_join_and_terminate_goal(Info, LCVar, LCSVar, JoinAndTerminateGoal),
    Conjunct0 = hlds_goal(_, Conjunct0GoalInfo),
    NonLocals0 = goal_info_get_nonlocals(Conjunct0GoalInfo),
    set_of_var.insert(LCSVar, NonLocals0, NonLocals1),
    set_of_var.insert(LCVar, NonLocals1, NonLocals),
    goal_info_set_nonlocals(NonLocals, Conjunct0GoalInfo, ConjunctGoalInfo),
    goal_to_conj_list(Conjunct0, Conjunct0Goals),
    conj_list_to_goal(Conjunct0Goals ++ [JoinAndTerminateGoal],
        ConjunctGoalInfo, Conjunct),

    % The loop_control scope reuses the goal_info of the conjunct it wraps.
    ScopeReason = loop_control(LCVar, LCSVar, UseParentStack),
    ScopeGoal = hlds_goal(scope(ScopeReason, Conjunct), ConjunctGoalInfo),

    rewrite_nonrecursive_par_conjuncts(Info, UseParentStack, Conjuncts0,
        TailGoals, !VarSet, !VarTypes),
    Goals = [GetFreeSlotGoal, ScopeGoal | TailGoals].
% goal_rewrite_recursive_call(Info, IsLastGoal, !Goal, UseParentStack,
%   FixupGoalInfo):
%
% Re-write any recursive calls in this goal so that they call the inner
% procedure, passing the loop control variable as an extra first argument.
%
% This predicate's argument order does not conform to the Mercury coding
% standards; this is deliberate, as it makes it easier to call from
% list.map3.
%
% IsLastGoal says whether !.Goal is the last goal on its execution path,
% i.e. whether a recursive call inside it could be a tail call.
%
% UseParentStack is lc_use_parent_stack_frame if, from this goal's
% perspective, it is safe to use the parent stack in any spawned off code
% running in parallel with this goal. Otherwise it is
% lc_create_frame_on_child_stack.
%
% FixupGoalInfo is fixup_goal_info if a recursive call was rewritten inside
% !.Goal, in which case the goal_infos on the way to it have been updated
% by fixup_goal_info/3.
%
:- pred goal_rewrite_recursive_call(loop_control_info::in,
    goal_is_last_goal_on_path::in, hlds_goal::in, hlds_goal::out,
    lc_use_parent_stack::out, fixup_goal_info::out) is det.

goal_rewrite_recursive_call(Info, IsLastGoal, !Goal, UseParentStack,
        FixupGoalInfo) :-
    !.Goal = hlds_goal(GoalExpr0, GoalInfo),
    (
        GoalExpr0 = plain_call(CallPredId0, CallProcId0, Args0, Builtin,
            MaybeUnify, _Name0),
        RecPredProcId = Info ^ lci_rec_pred_proc_id,
        ( RecPredProcId = proc(CallPredId0, CallProcId0) ->
            NewPredProcId = Info ^ lci_inner_pred_proc_id,
            proc(CallPredId, CallProcId) = NewPredProcId,
            LCVar = Info ^ lci_lc_var,
            Args = [LCVar | Args0],
            Name = Info ^ lci_inner_pred_name,
            GoalExpr = plain_call(CallPredId, CallProcId, Args, Builtin,
                MaybeUnify, Name),
            PreserveTailRecursion = Info ^ lci_preserve_tail_recursion,
            !:Goal = hlds_goal(GoalExpr, GoalInfo),
            (
                IsLastGoal = goal_is_last_goal_on_path,
                PreserveTailRecursion = preserve_tail_recursion
            ->
                % Create a frame on the child's stack so that the parent can
                % tail-call.
                UseParentStack = lc_create_frame_on_child_stack
            ;
                UseParentStack = lc_use_parent_stack_frame,
                % Inform the code generator that this call may not be a tail
                % call.
                goal_add_feature(feature_do_not_tailcall, !Goal)
            ),
            fixup_goal_info(Info, !Goal),
            FixupGoalInfo = fixup_goal_info
        ;
            UseParentStack = lc_use_parent_stack_frame,
            FixupGoalInfo = do_not_fixup_goal_info
        )
    ;
        ( GoalExpr0 = unify(_, _, _, _, _)
        ; GoalExpr0 = generic_call(_, _, _, _, _)
        ; GoalExpr0 = call_foreign_proc(_, _, _, _, _, _, _)
        ; GoalExpr0 = conj(_, _)
        ; GoalExpr0 = disj(_)
        ; GoalExpr0 = switch(_, _, _)
        ; GoalExpr0 = negation(_)
        ; GoalExpr0 = scope(_, _)
        ; GoalExpr0 = if_then_else(_, _, _, _)
        ),
        (
            ( GoalExpr0 = unify(_, _, _, _, _)
            ; GoalExpr0 = generic_call(_, _, _, _, _)
            ; GoalExpr0 = call_foreign_proc(_, _, _, _, _, _, _)
            ),
            GoalExpr = GoalExpr0,
            % lc_use_parent_stack_frame is the most indifferent option.
            UseParentStack = lc_use_parent_stack_frame,
            FixupGoalInfo = do_not_fixup_goal_info
        ;
            GoalExpr0 = conj(ConjType, Conjs0),
            ( list.split_last(Conjs0, EarlierConjs0, LastConj0) ->
                goal_rewrite_recursive_call(Info, IsLastGoal,
                    LastConj0, LastConj,
                    UseParentStackLastConj, FixupGoalInfoLastConj),
                list.map3(goal_rewrite_recursive_call(Info,
                        goal_is_not_last_goal_on_path),
                    EarlierConjs0, EarlierConjs, UseParentStackEarlierConjs,
                    FixupGoalInfoEarlierConjs),
                FixupGoalInfoConjs =
                    [FixupGoalInfoLastConj | FixupGoalInfoEarlierConjs],
                goals_fixup_goal_info(FixupGoalInfoConjs, FixupGoalInfo),
                goals_use_parent_stack(UseParentStackEarlierConjs,
                    UseParentStack0),
                combine_use_parent_stack(UseParentStackLastConj,
                    UseParentStack0, UseParentStack),
                Conjs = EarlierConjs ++ [LastConj],
                GoalExpr = conj(ConjType, Conjs)
            ;
                % An empty conjunction is the goal `true' (for example, an
                % empty else branch). It contains no calls at all, so there
                % is nothing to rewrite; list.det_split_last would throw
                % an exception here.
                GoalExpr = GoalExpr0,
                UseParentStack = lc_use_parent_stack_frame,
                FixupGoalInfo = do_not_fixup_goal_info
            )
        ;
            GoalExpr0 = disj(Disjs0),
            % I don't care about disjunctions enough to try to preserve tail
            % calls in them.
            list.map3(goal_rewrite_recursive_call(Info,
                    goal_is_not_last_goal_on_path),
                Disjs0, Disjs, UseParentStackDisjs, FixupGoalInfoDisjs),
            goals_use_parent_stack(UseParentStackDisjs, UseParentStack),
            goals_fixup_goal_info(FixupGoalInfoDisjs, FixupGoalInfo),
            GoalExpr = disj(Disjs)
        ;
            GoalExpr0 = switch(Var, CanFail, Cases0),
            list.map3(case_rewrite_recursive_call(Info, IsLastGoal),
                Cases0, Cases, UseParentStackCases, FixupGoalInfoCases),
            goals_use_parent_stack(UseParentStackCases, UseParentStack),
            goals_fixup_goal_info(FixupGoalInfoCases, FixupGoalInfo),
            GoalExpr = switch(Var, CanFail, Cases)
        ;
            GoalExpr0 = negation(SubGoal0),
            % Control returns to the negation after its sub-goal completes,
            % so no call inside the sub-goal can be a tail call, whatever
            % the position of the negation itself.
            goal_rewrite_recursive_call(Info, goal_is_not_last_goal_on_path,
                SubGoal0, SubGoal, UseParentStack, FixupGoalInfo),
            GoalExpr = negation(SubGoal)
        ;
            GoalExpr0 = scope(Reason, SubGoal0),
            goal_rewrite_recursive_call(Info, IsLastGoal, SubGoal0, SubGoal,
                UseParentStack, FixupGoalInfo),
            GoalExpr = scope(Reason, SubGoal)
        ;
            GoalExpr0 = if_then_else(Vars, Cond0, Then0, Else0),
            % Either the then branch or the else branch always executes
            % after the condition, so nothing in the condition is ever the
            % last goal on the path.
            goal_rewrite_recursive_call(Info, goal_is_not_last_goal_on_path,
                Cond0, Cond, UseParentStackCond, FixupGoalInfoCond),
            goal_rewrite_recursive_call(Info, IsLastGoal, Then0, Then,
                UseParentStackThen, FixupGoalInfoThen),
            goal_rewrite_recursive_call(Info, IsLastGoal, Else0, Else,
                UseParentStackElse, FixupGoalInfoElse),
            goals_fixup_goal_info([FixupGoalInfoCond, FixupGoalInfoThen,
                FixupGoalInfoElse], FixupGoalInfo),
            goals_use_parent_stack([UseParentStackCond, UseParentStackThen,
                UseParentStackElse], UseParentStack),
            GoalExpr = if_then_else(Vars, Cond, Then, Else)
        ),
        !:Goal = hlds_goal(GoalExpr, GoalInfo),
        (
            FixupGoalInfo = fixup_goal_info,
            fixup_goal_info(Info, !Goal)
        ;
            FixupGoalInfo = do_not_fixup_goal_info
        )
    ;
        GoalExpr0 = shorthand(_),
        unexpected($module, $pred, "shorthand")
    ).
% case_rewrite_recursive_call(Info, IsLastGoal, !Case, UseParentStack,
%   FixupGoalInfo):
%
% Apply goal_rewrite_recursive_call to the goal of one switch arm, leaving
% the rest of the case unchanged. The argument order suits list.map3.
%
:- pred case_rewrite_recursive_call(loop_control_info::in,
    goal_is_last_goal_on_path::in, case::in, case::out,
    lc_use_parent_stack::out, fixup_goal_info::out) is det.

case_rewrite_recursive_call(Info, IsLastGoal, Case0, Case, UseParentStack,
        FixupGoalInfo) :-
    ArmGoal0 = Case0 ^ case_goal,
    goal_rewrite_recursive_call(Info, IsLastGoal, ArmGoal0, ArmGoal,
        UseParentStack, FixupGoalInfo),
    Case = Case0 ^ case_goal := ArmGoal.
% goals_fixup_goal_info(SubGoalFixups, Fixup):
%
% A compound goal needs its goal_info fixed up if at least one of its
% sub-goals did.
%
:- pred goals_fixup_goal_info(list(fixup_goal_info)::in, fixup_goal_info::out)
    is det.

goals_fixup_goal_info(SubGoalFixups, Fixup) :-
    ( if list.member(fixup_goal_info, SubGoalFixups) then
        Fixup = fixup_goal_info
    else
        Fixup = do_not_fixup_goal_info
    ).
% goals_use_parent_stack(SubGoalChoices, UseParentStack):
%
% Merge the stack frame choices of several sub-goals with
% combine_use_parent_stack/3. An empty list yields lc_use_parent_stack_frame.
%
:- pred goals_use_parent_stack(list(lc_use_parent_stack)::in,
    lc_use_parent_stack::out) is det.

goals_use_parent_stack(SubGoalChoices, UseParentStack) :-
    % combine_use_parent_stack is commutative and associative, so the order
    % in which the fold visits the elements does not affect the result.
    list.foldl(combine_use_parent_stack, SubGoalChoices,
        lc_use_parent_stack_frame, UseParentStack).
% combine_use_parent_stack(A, B, Combined):
%
% The parent's stack frame may be used only if both inputs allow it;
% otherwise a frame must be created on the child's stack.
%
:- pred combine_use_parent_stack(lc_use_parent_stack::in,
    lc_use_parent_stack::in, lc_use_parent_stack::out) is det.

combine_use_parent_stack(A, B, Combined) :-
    ( if
        A = lc_use_parent_stack_frame,
        B = lc_use_parent_stack_frame
    then
        Combined = lc_use_parent_stack_frame
    else
        Combined = lc_create_frame_on_child_stack
    ).
%----------------------------------------------------------------------------%
% goal_update_non_loop_control_paths(Info, RecParConjIds, FixupGoalInfo,
%   !Goal):
%
% This predicate does two things:
% + It inserts a loop control barrier (a call to lc_finish) into the base
%   case(s) of the predicate.
% + It re-writes the recursive calls that aren't part of parallel
%   conjunctions so that they call the inner procedure and pass the loop
%   control variable.
%
% RecParConjIds are the goal ids of the parallel conjunctions that have
% already been transformed; those goals are left untouched.
% FixupGoalInfo is fixup_goal_info if !.Goal was changed, in which case its
% goal_info has already been updated by fixup_goal_info/3.
%
:- pred goal_update_non_loop_control_paths(loop_control_info::in,
    list(goal_id)::in, fixup_goal_info::out,
    hlds_goal::in, hlds_goal::out) is det.

goal_update_non_loop_control_paths(Info, RecParConjIds, FixupGoalInfo,
        !Goal) :-
    GoalInfo0 = !.Goal ^ hlds_goal_info,
    GoalId = goal_info_get_goal_id(GoalInfo0),
    (
        % This goal is one of the transformed parallel conjunctions,
        % so nothing needs to be done.
        % The last conjunct always recurses; this is enforced by
        % merge_loop_control_par_conjs_between_branches, but we should check
        % to see how often this happens and if we should handle it.
        % XXX This may not work; I don't know if the goal id is maintained.
        list.member(GoalId, RecParConjIds)
    ->
        FixupGoalInfo = do_not_fixup_goal_info
    ;
        % This goal is a base case, so insert the barrier.
        not ( some [Callee] (
            goal_calls(!.Goal, Callee),
            (
                Callee = Info ^ lci_rec_pred_proc_id
            ;
                Callee = Info ^ lci_inner_pred_proc_id
            )
        ) )
    ->
        goal_to_conj_list(!.Goal, Conjs0),
        create_finish_loop_control_goal(Info, FinishLCGoal),
        Conjs = Conjs0 ++ [FinishLCGoal],
        conj_list_to_goal(Conjs, GoalInfo0, !:Goal),
        fixup_goal_info(Info, !Goal),
        FixupGoalInfo = fixup_goal_info
    ;
        !.Goal = hlds_goal(GoalExpr0, _),
        (
            ( GoalExpr0 = unify(_, _, _, _, _)
            ; GoalExpr0 = generic_call(_, _, _, _, _)
            ; GoalExpr0 = call_foreign_proc(_, _, _, _, _, _, _)
            ),
            % These cannot be a recursive call, and they cannot be a base
            % case, since base cases are detected above.
            unexpected($module, $pred, "Non-recursive atomic goal")
        ;
            GoalExpr0 = plain_call(PredId, ProcId, Args0, Builtin,
                MaybeContext, _SymName0),
            % This can only be a recursive call; it must be re-written.
            RecPredProcId = Info ^ lci_rec_pred_proc_id,
            expect(unify(RecPredProcId, proc(PredId, ProcId)), $module, $pred,
                "Expected recursive call"),
            proc(InnerPredId, InnerProcId) = Info ^ lci_inner_pred_proc_id,
            LCVar = Info ^ lci_lc_var,
            Args = [LCVar | Args0],
            SymName = Info ^ lci_inner_pred_name,
            GoalExpr = plain_call(InnerPredId, InnerProcId, Args, Builtin,
                MaybeContext, SymName),
            FixupGoalInfo = fixup_goal_info
        ;
            GoalExpr0 = conj(ConjType, Conjs0),
            expect(unify(ConjType, plain_conj), $module, $pred,
                "parallel conjunction"),
            conj_update_non_loop_control_paths(Info, RecParConjIds,
                FixupGoalInfo, Conjs0, Conjs),
            GoalExpr = conj(ConjType, Conjs)
        ;
            GoalExpr0 = disj(_),
            sorry($module, $pred, "disjunction")
        ;
            GoalExpr0 = switch(Var, CanFail, Cases0),
            list.map2(case_update_non_loop_control_paths(Info, RecParConjIds),
                Cases0, Cases, FixupGoalInfos),
            goals_fixup_goal_info(FixupGoalInfos, FixupGoalInfo),
            GoalExpr = switch(Var, CanFail, Cases)
        ;
            GoalExpr0 = negation(_),
            sorry($module, $pred, "negation")
        ;
            GoalExpr0 = scope(Reason, SubGoal0),
            goal_update_non_loop_control_paths(Info, RecParConjIds,
                FixupGoalInfo, SubGoal0, SubGoal),
            GoalExpr = scope(Reason, SubGoal)
        ;
            GoalExpr0 = if_then_else(ExistVars, Cond, Then0, Else0),
            % There may not be any recursive calls in Cond; if there are,
            % we do not apply the transformation.
            goal_update_non_loop_control_paths(Info, RecParConjIds,
                FixupGoalInfoThen, Then0, Then),
            goal_update_non_loop_control_paths(Info, RecParConjIds,
                FixupGoalInfoElse, Else0, Else),
            goals_fixup_goal_info([FixupGoalInfoThen, FixupGoalInfoElse],
                FixupGoalInfo),
            GoalExpr = if_then_else(ExistVars, Cond, Then, Else)
        ;
            GoalExpr0 = shorthand(_),
            unexpected($module, $pred, "shorthand")
        ),
        !Goal ^ hlds_goal_expr := GoalExpr,
        (
            FixupGoalInfo = fixup_goal_info,
            % Adds the loop control variable to the nonlocals and marks
            % the goal impure.
            fixup_goal_info(Info, !Goal)
        ;
            FixupGoalInfo = do_not_fixup_goal_info
        )
    ).
% conj_update_non_loop_control_paths(Info, RecGoalIds, FixupGoalInfo,
%   !Conjs):
%
% As goal_update_non_loop_control_paths, but for the conjuncts of a plain
% conjunction. Leading conjuncts that call neither the original recursive
% procedure nor the inner procedure are kept unchanged. The first conjunct
% that calls either of them is transformed by
% goal_update_non_loop_control_paths, and the conjuncts after it are
% returned unchanged.
%
:- pred conj_update_non_loop_control_paths(loop_control_info::in,
    list(goal_id)::in, fixup_goal_info::out,
    list(hlds_goal)::in, list(hlds_goal)::out) is det.

conj_update_non_loop_control_paths(_Info, _RecGoalIds, do_not_fixup_goal_info,
        [], []).
conj_update_non_loop_control_paths(Info, RecGoalIds, FixupGoalInfo,
        [Conj0 | Conjs0], [Conj | Conjs]) :-
    (
        % Conj0 may be skipped only if it calls NEITHER procedure.
        % The explicit `some' keeps Callee local to the negation. Without
        % it, Callee is quantified outside the negation, and the test
        % succeeds whenever Conj0 fails to call EITHER one of the two
        % procedures. That would wrongly skip conjuncts that call just one
        % of them. (This is the same test as the base case test in
        % goal_update_non_loop_control_paths.)
        not ( some [Callee] (
            goal_calls(Conj0, Callee),
            (
                % XXX At the moment, we require that all recursive calls be
                % inside parallel conjunctions, and all those recursive
                % calls have by now been transformed to call the inner
                % procedure instead. So the first part of this disjunction
                % should not succeed.
                % NOTE(review): goal_update_non_loop_control_paths still
                % rewrites plain calls to lci_rec_pred_proc_id, so this
                % disjunct is kept.
                Callee = Info ^ lci_rec_pred_proc_id
            ;
                Callee = Info ^ lci_inner_pred_proc_id
            )
        ) )
    ->
        % Conj0 does not make a recursive call or contain a recursive
        % parallel conjunction. We don't need to transform it.
        Conj = Conj0,
        conj_update_non_loop_control_paths(Info, RecGoalIds, FixupGoalInfo,
            Conjs0, Conjs)
    ;
        % This Conj has something that needs to be transformed.
        goal_update_non_loop_control_paths(Info, RecGoalIds, FixupGoalInfo,
            Conj0, Conj),
        % There is not going to be anything else in this conjunction
        % that needs to be transformed; we don't make a recursive call.
        Conjs = Conjs0
    ).
% case_update_non_loop_control_paths(Info, RecParConjIds, !Case,
%   FixupGoalInfo):
%
% As goal_update_non_loop_control_paths, but for the goal of one switch
% arm. FixupGoalInfo comes last because the list.map2 call in
% goal_update_non_loop_control_paths needs that argument order.
%
:- pred case_update_non_loop_control_paths(loop_control_info::in,
    list(goal_id)::in, case::in, case::out, fixup_goal_info::out) is det.

case_update_non_loop_control_paths(Info, RecParConjIds, Case0, Case,
        FixupGoalInfo) :-
    ArmGoal0 = Case0 ^ case_goal,
    goal_update_non_loop_control_paths(Info, RecParConjIds,
        FixupGoalInfo, ArmGoal0, ArmGoal),
    Case = Case0 ^ case_goal := ArmGoal.
%----------------------------------------------------------------------------%
% create_get_free_slot_goal(Info, LCSVar, Goal, !VarSet, !VarTypes):
%
% Create a fresh slot variable named "LCS", and build an impure, det call
% to the wait-free-slot procedure recorded in Info. The call takes the
% loop control variable and binds LCSVar.
%
:- pred create_get_free_slot_goal(loop_control_info::in, prog_var::out,
    hlds_goal::out, prog_varset::in, prog_varset::out,
    vartypes::in, vartypes::out) is det.

create_get_free_slot_goal(Info, LCSVar, Goal, !VarSet, !VarTypes) :-
    % The new slot variable.
    varset.new_named_var("LCS", LCSVar, !VarSet),
    add_var_type(LCSVar, loop_control_slot_var_type, !VarTypes),

    % The call that fills it in.
    LCVar = Info ^ lci_lc_var,
    Info ^ lci_wait_free_slot_proc = proc(PredId, ProcId),
    CallArgs = [LCVar, LCSVar],
    GoalExpr = plain_call(PredId, ProcId, CallArgs, not_builtin, no,
        Info ^ lci_wait_free_slot_proc_name),
    GoalInfo = impure_init_goal_info(list_to_set(CallArgs),
        instmap_delta_bind_var(LCSVar), detism_det),
    Goal = hlds_goal(GoalExpr, GoalInfo).
%----------------------------------------------------------------------------%
% create_create_loop_control_goal(ModuleInfo, NumContextsVar, LCVar, Goal,
%   !VarSet, !VarTypes):
%
% Create a fresh loop control variable named "LC", and build a det call
% to par_builtin.lc_create that takes NumContextsVar and binds LCVar.
%
:- pred create_create_loop_control_goal(module_info::in, prog_var::in,
    prog_var::out, hlds_goal::out, prog_varset::in, prog_varset::out,
    vartypes::in, vartypes::out) is det.

create_create_loop_control_goal(ModuleInfo, NumContextsVar, LCVar, Goal,
        !VarSet, !VarTypes) :-
    varset.new_named_var("LC", LCVar, !VarSet),
    add_var_type(LCVar, loop_control_var_type, !VarTypes),
    get_lc_create_proc(ModuleInfo, PredId, ProcId),
    CallArgs = [NumContextsVar, LCVar],
    NonLocals = set_of_var.list_to_set(CallArgs),
    InstmapDelta = instmap_delta_bind_var(LCVar),
    goal_info_init(NonLocals, InstmapDelta, detism_det, purity_pure,
        GoalInfo),
    GoalExpr = plain_call(PredId, ProcId, CallArgs, not_builtin, no,
        lc_create_name),
    Goal = hlds_goal(GoalExpr, GoalInfo).
%----------------------------------------------------------------------------%
% create_join_and_terminate_goal(Info, LCVar, LCSVar, Goal):
%
% Build an impure, det call to the join_and_terminate procedure recorded in
% Info, passing the loop control variable and the slot variable. The call
% binds nothing, so its instmap delta is the empty reachable one.
%
:- pred create_join_and_terminate_goal(loop_control_info::in, prog_var::in,
    prog_var::in, hlds_goal::out) is det.

create_join_and_terminate_goal(Info, LCVar, LCSVar, Goal) :-
    Info ^ lci_join_and_terminate_proc = proc(PredId, ProcId),
    CallArgs = [LCVar, LCSVar],
    instmap_delta_init_reachable(InstmapDelta),
    GoalInfo = impure_init_goal_info(list_to_set(CallArgs), InstmapDelta,
        detism_det),
    GoalExpr = plain_call(PredId, ProcId, CallArgs, not_builtin, no,
        Info ^ lci_join_and_terminate_proc_name),
    Goal = hlds_goal(GoalExpr, GoalInfo).
%----------------------------------------------------------------------------%
% create_finish_loop_control_goal(Info, Goal):
%
% Build an impure, det call to par_builtin.lc_finish on the loop control
% variable. The call binds nothing.
%
:- pred create_finish_loop_control_goal(loop_control_info::in, hlds_goal::out)
    is det.

create_finish_loop_control_goal(Info, Goal) :-
    ModuleInfo = Info ^ lci_module_info,
    get_lc_finish_loop_control_proc(ModuleInfo, PredId, ProcId),
    LCVar = Info ^ lci_lc_var,
    CallArgs = [LCVar],
    instmap_delta_init_reachable(InstmapDelta),
    GoalInfo = impure_init_goal_info(list_to_set(CallArgs), InstmapDelta,
        detism_det),
    Goal = hlds_goal(plain_call(PredId, ProcId, CallArgs, not_builtin, no,
        lc_finish_loop_control_name), GoalInfo).
%----------------------------------------------------------------------------%
% Says whether a goal's goal_info must be updated with fixup_goal_info/3
% because something inside the goal was changed by this transformation.
% Compound goals combine the values of their sub-goals with
% goals_fixup_goal_info/2.
%
:- type fixup_goal_info
---> fixup_goal_info
; do_not_fixup_goal_info.
% fixup_goal_info(Info, Goal0, Goal):
%
% Fix up a goal_info after the loop control transformation has changed the
% goal. The loop control variable is added to the goal's nonlocals, and the
% goal is marked impure. The goal expression itself is left unchanged.
%
:- pred fixup_goal_info(loop_control_info::in, hlds_goal::in, hlds_goal::out)
    is det.

fixup_goal_info(Info, Goal0, Goal) :-
    Goal0 = hlds_goal(GoalExpr, GoalInfo0),
    NonLocals0 = goal_info_get_nonlocals(GoalInfo0),
    set_of_var.insert(Info ^ lci_lc_var, NonLocals0, NonLocals),
    goal_info_set_nonlocals(NonLocals, GoalInfo0, GoalInfo1),
    goal_info_set_purity(purity_impure, GoalInfo1, GoalInfo),
    Goal = hlds_goal(GoalExpr, GoalInfo).
%----------------------------------------------------------------------------%
% update_outer_proc(PredProcId, InnerPredProcId, InnerPredName, ModuleInfo,
%   !ProcInfo):
%
% Replace the body of the original (outer) procedure with a conjunction
% that:
%   1. gets the default number of worker contexts,
%   2. creates a loop control structure with that many contexts, and
%   3. calls the inner procedure, passing the loop control variable
%      followed by the original head variables.
% The procedure's varset and vartypes are rebuilt from scratch, so that they
% contain only the renamed head variables and the variables created here.
%
:- pred update_outer_proc(pred_proc_id::in, pred_proc_id::in, sym_name::in,
    module_info::in, proc_info::in, proc_info::out) is det.

update_outer_proc(PredProcId, InnerPredProcId, InnerPredName, ModuleInfo,
        !ProcInfo) :-
    proc(PredId, _) = PredProcId,
    module_info_pred_info(ModuleInfo, PredId, PredInfo),
    pred_info_get_arg_types(PredInfo, HeadVarTypes),
    proc_info_get_headvars(!.ProcInfo, HeadVars0),
    proc_info_get_inferred_determinism(!.ProcInfo, Detism),
    proc_info_get_goal(!.ProcInfo, OrigGoal),
    OrigInstmapDelta = goal_info_get_instmap_delta(OrigGoal ^ hlds_goal_info),
    some [!VarSet, !VarTypes] (
        % Re-build the variables in the procedure with smaller sets.
        varset.init(!:VarSet),
        init_vartypes(!:VarTypes),
        proc_info_get_varset(!.ProcInfo, OldVarSet),
        foldl3_corresponding(add_old_var_to_sets(OldVarSet), HeadVars0,
            HeadVarTypes, !VarSet, !VarTypes, map.init, Remap),
        map(map.lookup(Remap), HeadVars0, HeadVars),
        proc_info_set_headvars(HeadVars, !ProcInfo),

        % Fix rtti varmaps.
        proc_info_get_rtti_varmaps(!.ProcInfo, RttiVarmaps0),
        apply_substitutions_to_rtti_varmaps(map.init, map.init, Remap,
            RttiVarmaps0, RttiVarmaps),
        proc_info_set_rtti_varmaps(RttiVarmaps, !ProcInfo),

        % Create a variable for the number of worker contexts. We control
        % this in the compiler so that it can be adjusted using profiler
        % feedback (for auto-parallelisation), but for now we just set it
        % using a runtime call so that it can be tuned.
        varset.new_named_var("NumContexts", NumContextsVar, !VarSet),
        add_var_type(NumContextsVar, builtin_type(builtin_type_int),
            !VarTypes),
        get_lc_default_num_contexts_proc(ModuleInfo,
            LCDefaultNumContextsPredId, LCDefaultNumContextsProcId),
        goal_info_init(list_to_set([NumContextsVar]),
            instmap_delta_bind_var(NumContextsVar),
            detism_det, purity_pure, GetNumContextsGoalInfo),
        GetNumContextsGoal = hlds_goal(plain_call(LCDefaultNumContextsPredId,
                LCDefaultNumContextsProcId, [NumContextsVar],
                not_builtin, no, lc_default_num_contexts_name),
            GetNumContextsGoalInfo),

        % Create the call to lc_create.
        create_create_loop_control_goal(ModuleInfo, NumContextsVar, LCVar,
            LCCreateGoal, !VarSet, !VarTypes),

        % Create the inner call.
        InnerCallArgs = [LCVar | HeadVars],
        NonLocals = list_to_set(InnerCallArgs),
        % The call to the transformed body has the same instmap delta
        % as the original body.
        remap_instmap(Remap, OrigInstmapDelta, InstmapDelta),
        goal_info_init(NonLocals, InstmapDelta, Detism, purity_impure,
            InnerProcCallGoalInfo),
        proc(InnerPredId, InnerProcId) = InnerPredProcId,
        InnerProcCallGoal = hlds_goal(plain_call(InnerPredId, InnerProcId,
                InnerCallArgs, not_builtin, no, InnerPredName),
            InnerProcCallGoalInfo),

        % Build a conjunction of these goals.
        goal_info_init(list_to_set(HeadVars), InstmapDelta, Detism,
            purity_impure, ConjGoalInfo),
        ConjGoal = hlds_goal(conj(plain_conj,
                [GetNumContextsGoal, LCCreateGoal, InnerProcCallGoal]),
            ConjGoalInfo),
        OrigPurity = goal_info_get_purity(OrigGoal ^ hlds_goal_info),
        (
            OrigPurity = purity_impure,
            % The impurity introduced by this transformation does not need
            % to be promised away.
            Body = ConjGoal
        ;
            ( OrigPurity = purity_pure
            ; OrigPurity = purity_semipure
            ),
            % Wrap the body in a scope to promise away the impurity.
            % The scope has exactly the purity it promises, i.e. the purity
            % of the original body. Recording purity_pure here would be
            % wrong for a semipure procedure.
            goal_info_set_purity(OrigPurity, ConjGoalInfo, ScopeGoalInfo),
            Body = hlds_goal(scope(promise_purity(OrigPurity), ConjGoal),
                ScopeGoalInfo)
        ),
        proc_info_set_goal(Body, !ProcInfo),
        proc_info_set_varset(!.VarSet, !ProcInfo),
        proc_info_set_vartypes(!.VarTypes, !ProcInfo)
    ).
% add_old_var_to_sets(OldVarSet, OldVar, VarType, !VarSet, !VarTypes,
%   !Remap):
%
% Allocate a new variable in !VarSet corresponding to OldVar. The new
% variable takes OldVar's name from OldVarSet if it has one, and is given
% type VarType. The OldVar -> new variable mapping is recorded in !Remap;
% map.det_insert throws if OldVar is already mapped.
%
:- pred add_old_var_to_sets(prog_varset::in, prog_var::in, mer_type::in,
    prog_varset::in, prog_varset::out, vartypes::in, vartypes::out,
    prog_var_renaming::in, prog_var_renaming::out) is det.

add_old_var_to_sets(OldVarSet, OldVar, VarType, !VarSet, !VarTypes,
        !Remap) :-
    ( if varset.search_name(OldVarSet, OldVar, OldName) then
        varset.new_named_var(OldName, NewVar, !VarSet)
    else
        varset.new_var(NewVar, !VarSet)
    ),
    add_var_type(NewVar, VarType, !VarTypes),
    map.det_insert(OldVar, NewVar, !Remap).
% remap_instmap(Remap, OldInstmapDelta, InstmapDelta):
%
% Build a reachable instmap delta that gives each variable in
% OldInstmapDelta, renamed through Remap, the same inst it had before.
% map.lookup throws if a variable in OldInstmapDelta is missing from Remap.
%
:- pred remap_instmap(map(prog_var, prog_var)::in,
    instmap_delta::in, instmap_delta::out) is det.

remap_instmap(Remap, OldInstmapDelta, InstmapDelta) :-
    instmap_delta_to_assoc_list(OldInstmapDelta, OldVarInsts),
    instmap_delta_init_reachable(InstmapDelta0),
    RemapVarInst =
        ( pred((OldVar - Inst)::in, IMD0::in, IMD::out) is det :-
            map.lookup(Remap, OldVar, NewVar),
            instmap_delta_set_var(NewVar, Inst, IMD0, IMD)
        ),
    list.foldl(RemapVarInst, OldVarInsts, InstmapDelta0, InstmapDelta).
%--------------------------------------------------------------------%
% The type of loop control variables: par_builtin.loop_control.
%
:- func loop_control_var_type = mer_type.

loop_control_var_type = Type :-
    TypeCtorSym = qualified(par_builtin_module_sym, "loop_control"),
    Type = defined_type(TypeCtorSym, [], kind_star).

% The type of loop control slot variables: the builtin int type.
%
:- func loop_control_slot_var_type = mer_type.

loop_control_slot_var_type = Type :-
    Type = builtin_type(builtin_type_int).
%----------------------------------------------------------------------------%
% The unqualified and module-qualified names of par_builtin's
% lc_wait_free_slot/2, and a lookup of its pred_proc_id.
%
:- func lc_wait_free_slot_name_unqualified = string.

lc_wait_free_slot_name_unqualified = "lc_wait_free_slot".

:- func lc_wait_free_slot_name = sym_name.

lc_wait_free_slot_name = SymName :-
    SymName = qualified(par_builtin_module_sym,
        lc_wait_free_slot_name_unqualified).

:- pred get_lc_wait_free_slot_proc(module_info::in, pred_proc_id::out) is det.

get_lc_wait_free_slot_proc(ModuleInfo, PredProcId) :-
    Arity = 2,
    lookup_lc_pred_proc(ModuleInfo, lc_wait_free_slot_name_unqualified,
        Arity, PredId, ProcId),
    PredProcId = proc(PredId, ProcId).
% The unqualified and module-qualified names of par_builtin's
% lc_default_num_contexts/1, and a lookup of its pred and proc ids.
%
:- func lc_default_num_contexts_name_unqualified = string.

lc_default_num_contexts_name_unqualified = "lc_default_num_contexts".

:- func lc_default_num_contexts_name = sym_name.

lc_default_num_contexts_name = SymName :-
    SymName = qualified(par_builtin_module_sym,
        lc_default_num_contexts_name_unqualified).

:- pred get_lc_default_num_contexts_proc(module_info::in, pred_id::out,
    proc_id::out) is det.

get_lc_default_num_contexts_proc(ModuleInfo, PredId, ProcId) :-
    Arity = 1,
    lookup_lc_pred_proc(ModuleInfo, lc_default_num_contexts_name_unqualified,
        Arity, PredId, ProcId).
% The unqualified and module-qualified names of par_builtin's
% lc_create/2, and a lookup of its pred and proc ids.
%
:- func lc_create_name_unqualified = string.

lc_create_name_unqualified = "lc_create".

:- func lc_create_name = sym_name.

lc_create_name = SymName :-
    SymName = qualified(par_builtin_module_sym, lc_create_name_unqualified).

:- pred get_lc_create_proc(module_info::in, pred_id::out, proc_id::out) is det.

get_lc_create_proc(ModuleInfo, PredId, ProcId) :-
    Arity = 2,
    lookup_lc_pred_proc(ModuleInfo, lc_create_name_unqualified, Arity,
        PredId, ProcId).
% The unqualified and module-qualified names of par_builtin's
% lc_join_and_terminate/2, and a lookup of its pred_proc_id.
%
:- func lc_join_and_terminate_name_unqualified = string.

lc_join_and_terminate_name_unqualified = "lc_join_and_terminate".

:- func lc_join_and_terminate_name = sym_name.

lc_join_and_terminate_name = SymName :-
    SymName = qualified(par_builtin_module_sym,
        lc_join_and_terminate_name_unqualified).

:- pred get_lc_join_and_terminate_proc(module_info::in, pred_proc_id::out)
    is det.

get_lc_join_and_terminate_proc(ModuleInfo, PredProcId) :-
    Arity = 2,
    lookup_lc_pred_proc(ModuleInfo, lc_join_and_terminate_name_unqualified,
        Arity, PredId, ProcId),
    PredProcId = proc(PredId, ProcId).
% The unqualified and module-qualified names of par_builtin's
% lc_finish/1, and a lookup of its pred and proc ids.
%
:- func lc_finish_loop_control_name_unqualified = string.

lc_finish_loop_control_name_unqualified = "lc_finish".

:- func lc_finish_loop_control_name = sym_name.

lc_finish_loop_control_name = SymName :-
    SymName = qualified(par_builtin_module_sym,
        lc_finish_loop_control_name_unqualified).

:- pred get_lc_finish_loop_control_proc(module_info::in,
    pred_id::out, proc_id::out) is det.

get_lc_finish_loop_control_proc(ModuleInfo, PredId, ProcId) :-
    Arity = 1,
    lookup_lc_pred_proc(ModuleInfo, lc_finish_loop_control_name_unqualified,
        Arity, PredId, ProcId).
%----------------------------------------------------------------------------%
% lookup_lc_pred_proc(ModuleInfo, Name, Arity, PredId, ProcId):
%
% Look up the only mode of the predicate par_builtin.Name/Arity
% via lookup_builtin_pred_proc_id.
%
:- pred lookup_lc_pred_proc(module_info::in, string::in, arity::in,
    pred_id::out, proc_id::out) is det.

lookup_lc_pred_proc(ModuleInfo, Name, Arity, PredId, ProcId) :-
    ModuleSym = par_builtin_module_sym,
    lookup_builtin_pred_proc_id(ModuleInfo, ModuleSym, Name, pf_predicate,
        Arity, only_mode, PredId, ProcId).
%----------------------------------------------------------------------------%
% The (unqualified) name of the module defining the loop control
% primitives.
%
:- func par_builtin_module_sym = sym_name.

par_builtin_module_sym = ModuleSym :-
    ModuleSym = unqualified("par_builtin").
%----------------------------------------------------------------------------%
:- end_module transform_hlds.par_loop_control.
%----------------------------------------------------------------------------%