%----------------------------------------------------------------------------%
% vim: ft=mercury ts=4 sw=4 et
%----------------------------------------------------------------------------%
% Copyright (C) 2023-2024 The Mercury team.
% This file may only be copied under the terms of the GNU General
% Public License - see the file COPYING in the Mercury distribution.
%----------------------------------------------------------------------------%
%
% File: split_switch_arms.m.
%
% This module implements a program transformation that replaces code like this:
%
%   (
%       ( X = f1
%       ; X = f2
%       ; X = f3
%       ; X = f4
%       ),
%       <initial code>,
%       (
%           X = f1,
%           <f1 code>
%       ;
%           ( X = f2
%           ; X = f3
%           ),
%           <f2/f3 code>
%       ;
%           X = f4,
%           <f4 code>
%       ),
%       <final code>
%   ;
%       X = f5,
%       <f5 code>
%   )
%
% with code like this:
%
%   (
%       X = f1,
%       <initial code>,
%       <f1 code>,
%       <final code>
%   ;
%       ( X = f2
%       ; X = f3
%       ),
%       <initial code>,
%       <f2/f3 code>,
%       <final code>
%   ;
%       X = f4,
%       <initial code>,
%       <f4 code>,
%       <final code>
%   ;
%       X = f5,
%       <f5 code>
%   )
%
% The idea is that if inside a switch on a variable, there are other switches
% on that same variable that subdivide the set of cons_ids even further, then
% we do the following.
%
% - We partition the cons_ids of the top-level switch arm to make all the
%   distinctions between its cons_ids that any of the (one or more) switches
%   on the same variable inside the top-level switch arm make.
%
%   In this case, we would partition {f1,f2,f3,f4} into three sets:
%   {f1}, {f2,f3} and {f4}.
%
% - We replace the original switch arm with a separate switch arm for each
%   partition.
%
%   In this case, this would replace the arm for {f1,f2,f3,f4}
%   with three arms for {f1}, {f2,f3} and {f4} respectively.
%   The three arms would initially contain the same code as the original
%   switch arm. This would mean that each would contain the whole of the
%   nested switch on X with the three arms for {f1}, {f2,f3} and {f4}.
%
% - We restrict the goal in each of the resulting switch arms to refer to
%   only the cons_ids of its own partition in any nested switches on the
%   original switched-on variable, and we then simplify away any trivial
%   switches generated by this process.
%
%   In this case, in the new {f1}-only arm of the top switch, we would replace
%   the inner switch on X with the arms for {f1}, {f2,f3} and {f4}
%   with just the arm for {f1}, and then optimize away the now-unnecessary
%   switch wrapper around the goal inside that arm. We would do the same
%   with the other two arms.
%
% The code resulting from the above steps will include some code duplication
% (most of the pieces of code denoted by <initial code> and <final code>
% in the example above would be duplicated), but it will need to execute
% fewer transfers of control. This is worthwhile because
%
% - the branch instructions used to implement switches are hard to predict
%   unless most paths through the nested switches are rarely if ever taken, and
%
% - the pipeline breaks caused by branches that are not correctly predicted
%   are one of the two major contributors to the runtime of Mercury programs.
%   (The other major contributor is data cache misses.)
%
% Nevertheless, the main point of this transformation is not performance,
% since the transformation performed by this module could also be done
% manually by humans. Instead, the main point is that having the transformation
% done by a machine is better from a software engineering point of view.
% One software engineering advantage is that the compiler can do the
% transformation more quickly, more cheaply, and more reliably than
% programmers can. Another advantage is that it allows programmers to maintain
% the pre-transform version of the code. The pre-transform version may be
% clearer and easier to maintain than the post-transform version, due to
% the code duplication required by the transform.
%
%----------------------------------------------------------------------------%

:- module check_hlds.simplify.split_switch_arms.
:- interface.

:- import_module check_hlds.simplify.simplify_info.
:- import_module hlds.
:- import_module hlds.hlds_goal.

:- import_module set.

    % split_switch_arms_in_goal(ToSplitArms, Goal0, Goal):
    %
    % Traverse Goal0 looking for the switch arms named in ToSplitArms.
    % Each such arm is replaced in Goal by one arm per partition of its
    % cons_ids, where the partitions reflect every distinction made by
    % switches on the same variable nested inside that arm. Arms not named
    % in ToSplitArms are kept, with their bodies traversed recursively.
    %
:- pred split_switch_arms_in_goal(set(switch_arm)::in,
    hlds_goal::in, hlds_goal::out) is det.

%----------------------------------------------------------------------------%

:- implementation.

:- import_module hlds.make_goal.
:- import_module parse_tree.
:- import_module parse_tree.prog_data.

:- import_module cord.
:- import_module list.
:- import_module require.

%----------------------------------------------------------------------------%

split_switch_arms_in_goal(ToSplitArms, Goal0, Goal) :-
    Goal0 = hlds_goal(GoalExpr0, GoalInfo),
    (
        % Atomic goals contain no switches, so they are returned unchanged.
        ( GoalExpr0 = unify(_, _, _, _, _)
        ; GoalExpr0 = plain_call(_, _, _, _, _, _)
        ; GoalExpr0 = generic_call(_, _, _, _, _)
        ; GoalExpr0 = call_foreign_proc(_, _, _, _, _, _, _)
        ),
        GoalExpr = GoalExpr0
    ;
        GoalExpr0 = switch(Var, CanFail, Cases0),
        % One input case may yield several output cases, hence the cord.
        split_switch_arms_in_cases(ToSplitArms, Var, Cases0,
            cord.init, CasesCord),
        Cases = cord.list(CasesCord),
        GoalExpr = switch(Var, CanFail, Cases)
    ;
        GoalExpr0 = conj(ConjType, Conjuncts0),
        split_switch_arms_in_goals(ToSplitArms, Conjuncts0, Conjuncts),
        GoalExpr = conj(ConjType, Conjuncts)
    ;
        GoalExpr0 = disj(Disjuncts0),
        split_switch_arms_in_goals(ToSplitArms, Disjuncts0, Disjuncts),
        GoalExpr = disj(Disjuncts)
    ;
        GoalExpr0 = if_then_else(Vars, CondGoal0, ThenGoal0, ElseGoal0),
        split_switch_arms_in_goal(ToSplitArms, CondGoal0, CondGoal),
        split_switch_arms_in_goal(ToSplitArms, ThenGoal0, ThenGoal),
        split_switch_arms_in_goal(ToSplitArms, ElseGoal0, ElseGoal),
        GoalExpr = if_then_else(Vars, CondGoal, ThenGoal, ElseGoal)
    ;
        GoalExpr0 = negation(NegatedGoal0),
        split_switch_arms_in_goal(ToSplitArms, NegatedGoal0, NegatedGoal),
        GoalExpr = negation(NegatedGoal)
    ;
        GoalExpr0 = scope(Reason, ScopedGoal0),
        split_switch_arms_in_goal(ToSplitArms, ScopedGoal0, ScopedGoal),
        GoalExpr = scope(Reason, ScopedGoal)
    ;
        GoalExpr0 = shorthand(ShortHand0),
        (
            ShortHand0 = atomic_goal(GoalType, Outer, Inner,
                MaybeOutputVars, MainGoal0, OrElseGoals0, OrElseInners),
            split_switch_arms_in_goal(ToSplitArms, MainGoal0, MainGoal),
            split_switch_arms_in_goals(ToSplitArms,
                OrElseGoals0, OrElseGoals),
            ShortHand = atomic_goal(GoalType, Outer, Inner,
                MaybeOutputVars, MainGoal, OrElseGoals, OrElseInners)
        ;
            ShortHand0 = try_goal(_, _, _),
            % try goals are expected to have been expanded before this pass.
            unexpected($pred, "try_goal")
        ;
            ShortHand0 = bi_implication(_, _),
            % Bi-implications are expected to have been expanded before
            % this pass.
            unexpected($pred, "bi_implication")
        ),
        GoalExpr = shorthand(ShortHand)
    ),
    % The goal_info of the original goal is reused unchanged.
    Goal = hlds_goal(GoalExpr, GoalInfo).

    % Apply split_switch_arms_in_goal to every goal in a list.
    %
:- pred split_switch_arms_in_goals(set(switch_arm)::in,
    list(hlds_goal)::in, list(hlds_goal)::out) is det.

split_switch_arms_in_goals(ToSplitArms, Goals0, Goals) :-
    list.map(split_switch_arms_in_goal(ToSplitArms), Goals0, Goals).

    % Apply split_switch_arms_in_case to every case of a switch on Var,
    % in order, appending the resulting case(s) to !CasesCord.
    %
:- pred split_switch_arms_in_cases(set(switch_arm)::in, prog_var::in,
    list(case)::in, cord(case)::in, cord(case)::out) is det.

split_switch_arms_in_cases(ToSplitArms, Var, Cases0, !CasesCord) :-
    list.foldl(split_switch_arms_in_case(ToSplitArms, Var), Cases0,
        !CasesCord).

%---------------------------------------------------------------------------%

    % split_switch_arms_in_case(ToSplitArms, Var, Case0, !CasesCord):
    %
    % This is where the splitting actually happens; the other
    % split_switch_arms_in_X predicates exist only to reach this one.
    %
    % If Case0 is an arm (of a switch on Var) that is listed in ToSplitArms,
    % then
    %
    % - work out how its cons_ids should be partitioned
    %   (partition_cons_ids_for_var_in_goal), and
    % - append one case per partition to !CasesCord, with the goal of each
    %   such case restricted to the cons_ids of its own partition
    %   (gather_partition_restricted_cases).
    %
    % Otherwise, append a single case whose goal is the result of
    % processing the goal of Case0 recursively.
    %
:- pred split_switch_arms_in_case(set(switch_arm)::in, prog_var::in,
    case::in, cord(case)::in, cord(case)::out) is det.

split_switch_arms_in_case(ToSplitArms, Var, Case0, !CasesCord) :-
    Case0 = case(MainConsId, OtherConsIds, ArmGoal0),
    set.list_to_set([MainConsId | OtherConsIds], ArmConsIds),
    ThisArm = switch_arm(Var, ArmConsIds),
    ( if set.contains(ToSplitArms, ThisArm) then
        % First split the arms that are on variables *other* than Var.
        % Doing this before splitting this arm means the work is done once
        % and its result copied, instead of being repeated for each of the
        % partitions computed below.
        %
        % Nested switch arms on Var itself need no such treatment; they are
        % broken up as finely as possible by
        % gather_partition_restricted_cases below.
        set.filter(switch_arm_is_on_var(Var), ToSplitArms,
            _ArmsOnVar, ArmsNotOnVar),
        ( if set.is_empty(ArmsNotOnVar) then
            ArmGoal = ArmGoal0
        else
            split_switch_arms_in_goal(ArmsNotOnVar, ArmGoal0, ArmGoal)
        ),

        % Begin with one partition holding all the cons_ids of this arm,
        % and refine it using every switch on Var found inside the arm.
        Partitions0 = set.make_singleton_set(ArmConsIds),
        partition_cons_ids_for_var_in_goal(Var, ArmGoal,
            Partitions0, Partitions),
        ( if set.count(Partitions, 1) then
            % The nested switches on Var make no distinction that the
            % outer switch does not already make, so there is nothing
            % to split this arm on.
            %
            % The else branch would compute the same result; this is
            % just the cheaper way to get it.
            UnsplitCase = case(MainConsId, OtherConsIds, ArmGoal),
            cord.snoc(UnsplitCase, !CasesCord)
        else
            % Emit one case per partition in place of the original case.
            set.fold(gather_partition_restricted_cases(Var, ArmGoal),
                Partitions, !CasesCord)
        )
    else
        split_switch_arms_in_goal(ToSplitArms, ArmGoal0, ArmGoal),
        KeptCase = case(MainConsId, OtherConsIds, ArmGoal),
        cord.snoc(KeptCase, !CasesCord)
    ).

    % Succeed iff the given switch arm belongs to a switch on Var.
    %
:- pred switch_arm_is_on_var(prog_var::in, switch_arm::in) is semidet.

switch_arm_is_on_var(Var, switch_arm(Var, _)).

%---------------------------------------------------------------------------%

    % partition_cons_ids_for_var_in_goal(Var, Goal, !Partitions):
    %
    % Refine !Partitions so that it reflects every distinction that
    % the switches on Var inside Goal make between its cons_ids.
    %
:- pred partition_cons_ids_for_var_in_goal(prog_var::in, hlds_goal::in,
    set(set(cons_id))::in, set(set(cons_id))::out) is det.

partition_cons_ids_for_var_in_goal(Var, Goal, !Partitions) :-
    Goal = hlds_goal(GoalExpr, _GoalInfo),
    (
        % Atomic goals contain no switches, so they refine nothing.
        ( GoalExpr = unify(_, _, _, _, _)
        ; GoalExpr = plain_call(_, _, _, _, _, _)
        ; GoalExpr = generic_call(_, _, _, _, _)
        ; GoalExpr = call_foreign_proc(_, _, _, _, _, _, _)
        )
    ;
        GoalExpr = switch(SwitchVar, _CanFail, Cases),
        ( if Var = SwitchVar then
            partition_cons_ids_of_cases(Cases, !Partitions)
        else
            true
        ),
        % Whatever variable this switch is on, its arms may contain
        % further switches on Var.
        partition_cons_ids_for_var_in_cases(Var, Cases, !Partitions)
    ;
        GoalExpr = conj(_ConjType, Conjuncts),
        partition_cons_ids_for_var_in_goals(Var, Conjuncts, !Partitions)
    ;
        GoalExpr = disj(Disjuncts),
        partition_cons_ids_for_var_in_goals(Var, Disjuncts, !Partitions)
    ;
        GoalExpr = if_then_else(_Vars, CondGoal, ThenGoal, ElseGoal),
        partition_cons_ids_for_var_in_goal(Var, CondGoal, !Partitions),
        partition_cons_ids_for_var_in_goal(Var, ThenGoal, !Partitions),
        partition_cons_ids_for_var_in_goal(Var, ElseGoal, !Partitions)
    ;
        GoalExpr = negation(NegatedGoal),
        partition_cons_ids_for_var_in_goal(Var, NegatedGoal, !Partitions)
    ;
        GoalExpr = scope(_Reason, ScopedGoal),
        partition_cons_ids_for_var_in_goal(Var, ScopedGoal, !Partitions)
    ;
        GoalExpr = shorthand(ShortHand),
        (
            ShortHand = atomic_goal(_GoalType, _Outer, _Inner,
                _MaybeOutputVars, MainGoal, OrElseGoals, _OrElseInners),
            partition_cons_ids_for_var_in_goal(Var, MainGoal, !Partitions),
            partition_cons_ids_for_var_in_goals(Var, OrElseGoals,
                !Partitions)
        ;
            ShortHand = try_goal(_, _, _),
            % try goals are expected to have been expanded before this pass.
            unexpected($pred, "try_goal")
        ;
            ShortHand = bi_implication(_, _),
            % Bi-implications are expected to have been expanded before
            % this pass.
            unexpected($pred, "bi_implication")
        )
    ).

    % Apply partition_cons_ids_for_var_in_goal to each goal in turn.
    %
:- pred partition_cons_ids_for_var_in_goals(prog_var::in,
    list(hlds_goal)::in,
    set(set(cons_id))::in, set(set(cons_id))::out) is det.

partition_cons_ids_for_var_in_goals(Var, Goals, !Partitions) :-
    list.foldl(partition_cons_ids_for_var_in_goal(Var), Goals, !Partitions).

    % Apply partition_cons_ids_for_var_in_goal to the goal of each case
    % in turn. The cons_ids of the cases themselves are ignored here.
    %
:- pred partition_cons_ids_for_var_in_cases(prog_var::in, list(case)::in,
    set(set(cons_id))::in, set(set(cons_id))::out) is det.

partition_cons_ids_for_var_in_cases(Var, Cases, !Partitions) :-
    list.foldl(partition_cons_ids_for_var_in_case(Var), Cases, !Partitions).

:- pred partition_cons_ids_for_var_in_case(prog_var::in, case::in,
    set(set(cons_id))::in, set(set(cons_id))::out) is det.

partition_cons_ids_for_var_in_case(Var, Case, !Partitions) :-
    Case = case(_MainConsId, _OtherConsIds, ArmGoal),
    partition_cons_ids_for_var_in_goal(Var, ArmGoal, !Partitions).

%---------------------------------------------------------------------------%

    % partition_cons_ids_of_cases(Cases, !Partitions):
    %
    % Refine !Partitions so that any two cons_ids that occur in different
    % cases of Cases end up in different partitions in !:Partitions,
    % even if they shared a partition in !.Partitions.
    %
:- pred partition_cons_ids_of_cases(list(case)::in,
    set(set(cons_id))::in, set(set(cons_id))::out) is det.

partition_cons_ids_of_cases(Cases, !Partitions) :-
    list.foldl(partition_cons_ids_of_case, Cases, !Partitions).

    % Split every existing partition into the part inside this case's
    % cons_ids and the part outside them.
    %
:- pred partition_cons_ids_of_case(case::in,
    set(set(cons_id))::in, set(set(cons_id))::out) is det.

partition_cons_ids_of_case(Case, !Partitions) :-
    Case = case(MainConsId, OtherConsIds, _ArmGoal),
    set.list_to_set([MainConsId | OtherConsIds], ArmConsIds),
    % The new set of partitions is rebuilt from scratch, starting empty.
    set.fold(add_in_and_or_out_cons_ids(ArmConsIds), !.Partitions,
        set.init, !:Partitions).

    % add_in_and_or_out_cons_ids(ArmConsIds, Partition0, !Partitions):
    %
    % Divide Partition0 into the cons_ids that are in ArmConsIds and those
    % that are not, and add each nonempty part to !Partitions.
    %
:- pred add_in_and_or_out_cons_ids(set(cons_id)::in, set(cons_id)::in,
    set(set(cons_id))::in, set(set(cons_id))::out) is det.

add_in_and_or_out_cons_ids(ArmConsIds, Partition0, !Partitions) :-
    set.divide_by_set(ArmConsIds, Partition0, InsideArm, OutsideArm),
    % Partition0 is nonempty, so at least one of the two parts is nonempty.
    % Both may be; that is the only case in which the partitioning
    % becomes finer.
    add_partition_if_nonempty(InsideArm, !Partitions),
    add_partition_if_nonempty(OutsideArm, !Partitions).

:- pred add_partition_if_nonempty(set(cons_id)::in,
    set(set(cons_id))::in, set(set(cons_id))::out) is det.

add_partition_if_nonempty(Partition, !Partitions) :-
    ( if set.is_empty(Partition) then
        true
    else
        set.insert(Partition, !Partitions)
    ).

%---------------------------------------------------------------------------%

    % gather_partition_restricted_cases(Var, Goal0, Partition, !CasesCord):
    %
    % Build a switch arm for the cons_ids in Partition whose goal is Goal0
    % with its switches on Var restricted to those cons_ids, and append
    % that arm to !CasesCord. Aborts if Partition is empty.
    %
:- pred gather_partition_restricted_cases(prog_var::in, hlds_goal::in,
    set(cons_id)::in, cord(case)::in, cord(case)::out) is det.

gather_partition_restricted_cases(Var, Goal0, Partition, !CasesCord) :-
    set.to_sorted_list(Partition, PartitionConsIds),
    (
        PartitionConsIds = [],
        unexpected($pred, "Partition is empty")
    ;
        PartitionConsIds = [MainConsId | OtherConsIds],
        restrict_switches_on_var_in_goal(Var, Partition, Goal0, Goal),
        cord.snoc(case(MainConsId, OtherConsIds, Goal), !CasesCord)
    ).

    % restrict_switches_on_var_in_goal(Var, Partition, Goal0, Goal):
    %
    % Goal is Goal0 with every switch on Var restricted to the cons_ids
    % in Partition. Each such switch is replaced either by the goal of its
    % one surviving arm, or by failure if no arm survives.
    %
:- pred restrict_switches_on_var_in_goal(prog_var::in, set(cons_id)::in,
    hlds_goal::in, hlds_goal::out) is det.

restrict_switches_on_var_in_goal(Var, Partition, Goal0, Goal) :-
    Goal0 = hlds_goal(GoalExpr0, GoalInfo),
    (
        % Atomic goals contain no switches, so they are returned unchanged.
        ( GoalExpr0 = unify(_, _, _, _, _)
        ; GoalExpr0 = plain_call(_, _, _, _, _, _)
        ; GoalExpr0 = generic_call(_, _, _, _, _)
        ; GoalExpr0 = call_foreign_proc(_, _, _, _, _, _, _)
        ),
        GoalExpr = GoalExpr0
    ;
        GoalExpr0 = switch(SwitchVar, CanFail, Cases0),
        ( if Var = SwitchVar then
            restrict_switches_on_var_in_restricted_cases(Var, Partition,
                Cases0, cord.init, SurvivingCasesCord),
            SurvivingCases = cord.list(SurvivingCasesCord),
            % The switch construct itself disappears: what is left is
            % either the goal of its single surviving arm, or failure.
            (
                SurvivingCases = [],
                GoalExpr = fail_goal_expr
            ;
                SurvivingCases = [OnlyCase],
                OnlyCase = case(_, _, OnlyGoal),
                OnlyGoal = hlds_goal(GoalExpr, _)
            ;
                SurvivingCases = [_, _ | _],
                % Partition was computed by refining the cons_ids of the
                % top-level arm against every switch on Var inside it,
                % so Partition should fit wholly within the cons_ids of
                % ONE case of this switch. Two or more surviving cases
                % mean that something went wrong in that process.
                unexpected($pred, "Cases = [_, _ | _]")
            )
        else
            restrict_switches_on_var_in_cases(Var, Partition, Cases0, Cases),
            GoalExpr = switch(SwitchVar, CanFail, Cases)
        )
    ;
        GoalExpr0 = conj(ConjType, Conjuncts0),
        (
            ConjType = plain_conj,
            % Only plain conjunctions have nested plain conjunctions
            % spliced into them.
            restrict_switches_on_var_in_conj(Var, Partition,
                Conjuncts0, Conjuncts)
        ;
            ConjType = parallel_conj,
            restrict_switches_on_var_in_goals(Var, Partition,
                Conjuncts0, Conjuncts)
        ),
        GoalExpr = conj(ConjType, Conjuncts)
    ;
        GoalExpr0 = disj(Disjuncts0),
        restrict_switches_on_var_in_goals(Var, Partition,
            Disjuncts0, Disjuncts),
        GoalExpr = disj(Disjuncts)
    ;
        GoalExpr0 = if_then_else(Vars, CondGoal0, ThenGoal0, ElseGoal0),
        restrict_switches_on_var_in_goal(Var, Partition,
            CondGoal0, CondGoal),
        restrict_switches_on_var_in_goal(Var, Partition,
            ThenGoal0, ThenGoal),
        restrict_switches_on_var_in_goal(Var, Partition,
            ElseGoal0, ElseGoal),
        GoalExpr = if_then_else(Vars, CondGoal, ThenGoal, ElseGoal)
    ;
        GoalExpr0 = negation(NegatedGoal0),
        restrict_switches_on_var_in_goal(Var, Partition,
            NegatedGoal0, NegatedGoal),
        GoalExpr = negation(NegatedGoal)
    ;
        GoalExpr0 = scope(Reason, ScopedGoal0),
        restrict_switches_on_var_in_goal(Var, Partition,
            ScopedGoal0, ScopedGoal),
        GoalExpr = scope(Reason, ScopedGoal)
    ;
        GoalExpr0 = shorthand(ShortHand0),
        (
            ShortHand0 = atomic_goal(GoalType, Outer, Inner,
                MaybeOutputVars, MainGoal0, OrElseGoals0, OrElseInners),
            restrict_switches_on_var_in_goal(Var, Partition,
                MainGoal0, MainGoal),
            restrict_switches_on_var_in_goals(Var, Partition,
                OrElseGoals0, OrElseGoals),
            ShortHand = atomic_goal(GoalType, Outer, Inner,
                MaybeOutputVars, MainGoal, OrElseGoals, OrElseInners)
        ;
            ShortHand0 = try_goal(_, _, _),
            % try goals are expected to have been expanded before this pass.
            unexpected($pred, "try_goal")
        ;
            ShortHand0 = bi_implication(_, _),
            % Bi-implications are expected to have been expanded before
            % this pass.
            unexpected($pred, "bi_implication")
        ),
        GoalExpr = shorthand(ShortHand)
    ),
    % The goal_info of the original goal is reused unchanged.
    Goal = hlds_goal(GoalExpr, GoalInfo).

    % Restrict each conjunct of a plain conjunction, and splice the
    % conjuncts of any resulting plain conjunction directly into the
    % output list in place of that conjunction.
    %
:- pred restrict_switches_on_var_in_conj(prog_var::in, set(cons_id)::in,
    list(hlds_goal)::in, list(hlds_goal)::out) is det.

restrict_switches_on_var_in_conj(Var, Partition, Conjuncts0, Conjuncts) :-
    list.map(restrict_switches_on_var_in_goal(Var, Partition),
        Conjuncts0, RestrictedConjuncts),
    list.foldr(splice_plain_conj, RestrictedConjuncts, [], Conjuncts).

    % splice_plain_conj(Goal, LaterGoals, Goals):
    %
    % Put Goal in front of LaterGoals; if Goal is itself a plain
    % conjunction, put its conjuncts there instead of Goal.
    %
:- pred splice_plain_conj(hlds_goal::in,
    list(hlds_goal)::in, list(hlds_goal)::out) is det.

splice_plain_conj(Goal, LaterGoals, Goals) :-
    ( if Goal = hlds_goal(conj(plain_conj, SubConjuncts), _GoalInfo) then
        Goals = SubConjuncts ++ LaterGoals
    else
        Goals = [Goal | LaterGoals]
    ).

    % Apply restrict_switches_on_var_in_goal to every goal in a list.
    %
:- pred restrict_switches_on_var_in_goals(prog_var::in, set(cons_id)::in,
    list(hlds_goal)::in, list(hlds_goal)::out) is det.

restrict_switches_on_var_in_goals(Var, Partition, Goals0, Goals) :-
    list.map(restrict_switches_on_var_in_goal(Var, Partition),
        Goals0, Goals).

    % Restrict the goal of every case of a switch that is NOT on Var.
    % The cons_ids of the cases are left as they are.
    %
:- pred restrict_switches_on_var_in_cases(prog_var::in, set(cons_id)::in,
    list(case)::in, list(case)::out) is det.

restrict_switches_on_var_in_cases(Var, Partition, Cases0, Cases) :-
    list.map(restrict_switches_on_var_in_case(Var, Partition),
        Cases0, Cases).

:- pred restrict_switches_on_var_in_case(prog_var::in, set(cons_id)::in,
    case::in, case::out) is det.

restrict_switches_on_var_in_case(Var, Partition, Case0, Case) :-
    Case0 = case(MainConsId, OtherConsIds, ArmGoal0),
    restrict_switches_on_var_in_goal(Var, Partition, ArmGoal0, ArmGoal),
    Case = case(MainConsId, OtherConsIds, ArmGoal).

    % Restrict the cases of a switch that IS on Var: each case keeps only
    % those of its cons_ids that are in Partition, and a case that keeps
    % none of them is dropped. The surviving cases are appended to
    % !CasesCord in their original order.
    %
:- pred restrict_switches_on_var_in_restricted_cases(prog_var::in,
    set(cons_id)::in, list(case)::in,
    cord(case)::in, cord(case)::out) is det.

restrict_switches_on_var_in_restricted_cases(Var, Partition, Cases0,
        !CasesCord) :-
    list.foldl(restrict_switches_on_var_in_restricted_case(Var, Partition),
        Cases0, !CasesCord).

:- pred restrict_switches_on_var_in_restricted_case(prog_var::in,
    set(cons_id)::in, case::in, cord(case)::in, cord(case)::out) is det.

restrict_switches_on_var_in_restricted_case(Var, Partition, Case0,
        !CasesCord) :-
    Case0 = case(MainConsId0, OtherConsIds0, ArmGoal0),
    set.list_to_set([MainConsId0 | OtherConsIds0], ArmConsIds0),
    set.intersect(Partition, ArmConsIds0, AllowedConsIds),
    set.to_sorted_list(AllowedConsIds, AllowedConsIdsList),
    (
        AllowedConsIdsList = []
        % The partition of the enclosing top-level arm allows none of
        % this case's cons_ids, so the case is dropped by not adding
        % anything to !CasesCord.
    ;
        AllowedConsIdsList = [MainConsId | OtherConsIds],
        restrict_switches_on_var_in_goal(Var, Partition, ArmGoal0, ArmGoal),
        cord.snoc(case(MainConsId, OtherConsIds, ArmGoal), !CasesCord)
    ).

%---------------------------------------------------------------------------%
:- end_module check_hlds.simplify.split_switch_arms.
%---------------------------------------------------------------------------%