%-----------------------------------------------------------------------------%
% Copyright (C) 1999-2003 The University of Melbourne.
% This file may only be copied under the terms of the GNU General
% Public License - see the file COPYING in the Mercury distribution.
%-----------------------------------------------------------------------------%

% File: ml_tailcall.m
% Main author: fjh

% This module is an MLDS-to-MLDS transformation
% that marks function calls as tail calls whenever
% it is safe to do so, based on the assumptions described below.

% This module also contains a pass over the MLDS that detects functions
% which are directly recursive, but not tail-recursive,
% and warns about them.

% A function call can safely be marked as a tail call if
%   (1) it occurs in a position which would fall through into the
%       end of the function body or to a `return' statement,
%   and
%   (2) the lvalues in which the return value(s) from the `call'
%       will be placed are the same as the value(s) returned
%       by the `return', and these lvalues are all local variables,
%   and
%   (3) the function's local variables do not need to be live
%       for that call.
%
% For (2), we just assume (rather than checking) that
% any variables returned by the `return' statement are
% local variables.  This assumption is true for the MLDS code
% generated by ml_code_gen.m.
%
% For (3), we assume that the addresses of local variables
% and nested functions are only ever passed down to other functions
% (and used to assign to the local variable or to call the nested
% function), so that here we only need to check if the potential
% tail call uses such addresses, not whether such addresses were
% taken in earlier calls.  That is, if the addresses
% of locals were taken in earlier calls from the same function,
% we assume that these addresses will not be saved (on the heap,
% or in global variables, etc.) and used after those earlier calls
% have returned.  This assumption is true for the MLDS code generated
% by ml_code_gen.m.
%
% We just mark tailcalls in this module here.  The actual tailcall
% optimization (turn self-tailcalls into loops) is done in ml_optimize.
% Individual backends may wish to treat tailcalls separately if there is
% any backend support for them.
%
% Note that ml_call_gen.m will also mark calls to procedures with determinism
% `erroneous' as `no_return_call's (a special case of tail calls)
% when it generates them.

%-----------------------------------------------------------------------------%

:- module ml_backend__ml_tailcall.
:- interface.
:- import_module ml_backend__mlds.
:- import_module io.

    % Traverse the MLDS, marking all optimizable tail calls
    % as tail calls.
    %
:- pred ml_mark_tailcalls(mlds, mlds, io__state, io__state).
:- mode ml_mark_tailcalls(in, out, di, uo) is det.

    % Traverse the MLDS, warning about all directly recursive calls
    % that are not marked as tail calls.
    %
:- pred ml_warn_tailcalls(mlds, io__state, io__state).
:- mode ml_warn_tailcalls(in, di, uo) is det.

%-----------------------------------------------------------------------------%
%-----------------------------------------------------------------------------%

:- implementation.

:- import_module hlds__error_util.
:- import_module hlds__hlds_out.
:- import_module hlds__hlds_pred.
:- import_module ml_backend__ml_util.
:- import_module parse_tree__prog_data.

:- import_module string, int, list, std_util.

    % Rebuild the MLDS with every top-level definition passed through
    % mark_tailcalls_in_defns; the module name, foreign code and imports
    % are copied across untouched.  No I/O is performed: the I/O state
    % is returned as it was received.
ml_mark_tailcalls(OldMLDS, NewMLDS, IO0, IO) :-
    OldMLDS = mlds(Module, Foreign, ImportList, OldDefns),
    NewMLDS = mlds(Module, Foreign, ImportList,
        mark_tailcalls_in_defns(OldDefns)),
    IO = IO0.

%-----------------------------------------------------------------------------%

    % The `at_tail' type indicates whether or not a subgoal
    % is at a tail position, i.e. is followed by a return
    % statement or the end of the function, and if so,
    % specifies the return values (if any) in the return statement.
:- type at_tail == maybe(list(mlds__rval)).

    % The `locals' type contains a list of local definitions
    % which are in scope.
:- type locals == list(local_defns).
:- type local_defns
    --->    params(mlds__arguments)
    ;       defns(mlds__defns)
    .

%-----------------------------------------------------------------------------%

%
% mark_tailcalls_in_defns:
% mark_tailcalls_in_defn:
%   Recursively process the definition(s),
%   marking each optimizable tail call in them as a tail call.
%
% mark_tailcalls_in_maybe_statement:
% mark_tailcalls_in_statements:
% mark_tailcalls_in_statement:
% mark_tailcalls_in_stmt:
% mark_tailcalls_in_case:
% mark_tailcalls_in_default:
%   Recursively process the statement(s),
%   marking each optimizable tail call in them as a tail call.
%   The `AtTail' argument indicates whether or not this
%   construct is in a tail call position.
%   The `Locals' argument contains a list of the
%   local definitions which are in scope at this point.
%

:- func mark_tailcalls_in_defns(mlds__defns) = mlds__defns.

mark_tailcalls_in_defns(Defns) = list__map(mark_tailcalls_in_defn, Defns).

:- func mark_tailcalls_in_defn(mlds__defn) = mlds__defn.

mark_tailcalls_in_defn(Defn0) = Defn :-
    Defn0 = mlds__defn(Name, Context, Flags, DefnBody0),
    (
        DefnBody0 = mlds__function(PredProcId, Params, FuncBody0,
            Attributes),
        %
        % Compute the initial value of the `Locals' and
        % `AtTail' arguments.
        % The parameters are the only locals in scope at the start
        % of the body.  The body as a whole counts as being in tail
        % position only if the function has no return values
        % (an empty return list); otherwise a call is in tail position
        % only if a `return' statement follows it.
        %
        Params = mlds__func_params(Args, RetTypes),
        Locals = [params(Args)],
        ( RetTypes = [] ->
            AtTail = yes([])
        ;
            AtTail = no
        ),
        FuncBody = mark_tailcalls_in_function_body(FuncBody0,
            AtTail, Locals),
        DefnBody = mlds__function(PredProcId, Params, FuncBody,
            Attributes),
        Defn = mlds__defn(Name, Context, Flags, DefnBody)
    ;
        % Data definitions are returned unchanged.
        DefnBody0 = mlds__data(_, _, _),
        Defn = Defn0
    ;
        % For classes, process both the constructors and the members.
        DefnBody0 = mlds__class(ClassDefn0),
        ClassDefn0 = class_defn(Kind, Imports, BaseClasses, Implements,
            CtorDefns0, MemberDefns0),
        CtorDefns = mark_tailcalls_in_defns(CtorDefns0),
        MemberDefns = mark_tailcalls_in_defns(MemberDefns0),
        ClassDefn = class_defn(Kind, Imports, BaseClasses, Implements,
            CtorDefns, MemberDefns),
        DefnBody = mlds__class(ClassDefn),
        Defn = mlds__defn(Name, Context, Flags, DefnBody)
    ).

:- func mark_tailcalls_in_function_body(function_body, at_tail, locals)
    = function_body.

    % An `external' body has no statements to process.
mark_tailcalls_in_function_body(external, _, _) = external.
mark_tailcalls_in_function_body(defined_here(Statement0), AtTail, Locals) =
        defined_here(Statement) :-
    Statement = mark_tailcalls_in_statement(Statement0, AtTail, Locals).

:- func mark_tailcalls_in_maybe_statement(maybe(mlds__statement),
    at_tail, locals) = maybe(mlds__statement).

mark_tailcalls_in_maybe_statement(no, _, _) = no.
mark_tailcalls_in_maybe_statement(yes(Statement0), AtTail, Locals) =
        yes(Statement) :-
    Statement = mark_tailcalls_in_statement(Statement0, AtTail, Locals).

:- func mark_tailcalls_in_statements(mlds__statements, at_tail, locals)
    = mlds__statements.

mark_tailcalls_in_statements([], _, _) = [].
mark_tailcalls_in_statements([First0 | Rest0], AtTail, Locals) =
        [First | Rest] :-
    %
    % If the First statement is followed by a `return'
    % statement, then it is in a tailcall position.
    % If there are no statements after the first, then
    % the first statement is in a tail call position
    % iff the statement list is in a tail call position.
    % Otherwise, i.e. if the first statement is followed
    % by anything other than a `return' statement, then
    % the first statement is not in a tail call position.
    %
    % We inspect the *input* tail `Rest0' here, not the processed
    % tail `Rest'.  The two always agree, since `return' statements
    % are passed through unchanged by mark_tailcalls_in_stmt and an
    % empty list maps to an empty list.  Testing the input avoids
    % relying on the compiler reordering the conjunction, and on
    % that invariant.
    %
    ( Rest0 = [mlds__statement(return(ReturnVals), _) | _] ->
        FirstAtTail = yes(ReturnVals)
    ; Rest0 = [] ->
        FirstAtTail = AtTail
    ;
        FirstAtTail = no
    ),
    First = mark_tailcalls_in_statement(First0, FirstAtTail, Locals),
    Rest = mark_tailcalls_in_statements(Rest0, AtTail, Locals).

:- func mark_tailcalls_in_statement(mlds__statement, at_tail, locals)
    = mlds__statement.

    % Process the stmt inside the statement, keeping its context.
mark_tailcalls_in_statement(Statement0, AtTail, Locals) = Statement :-
    Statement0 = mlds__statement(Stmt0, Context),
    Stmt = mark_tailcalls_in_stmt(Stmt0, AtTail, Locals),
    Statement = mlds__statement(Stmt, Context).

:- func mark_tailcalls_in_stmt(mlds__stmt, at_tail, locals) = mlds__stmt.

mark_tailcalls_in_stmt(Stmt0, AtTail, Locals) = Stmt :-
    (
        %
        % Whenever we encounter a block statement,
        % we recursively mark tailcalls in any nested
        % functions defined in that block.
        % We also need to add any local definitions in that
        % block to the list of currently visible local
        % declarations before processing the statements
        % in that block.  The statement list will be in a
        % tail position iff the block is in a tail position.
        %
        Stmt0 = block(Defns0, Statements0),
        Defns = mark_tailcalls_in_defns(Defns0),
        NewLocals = [defns(Defns) | Locals],
        Statements = mark_tailcalls_in_statements(Statements0,
            AtTail, NewLocals),
        Stmt = block(Defns, Statements)
    ;
        %
        % The statement in the body of a while loop is never
        % in a tail position.
        %
        Stmt0 = while(Rval, Statement0, Once),
        Statement = mark_tailcalls_in_statement(Statement0, no, Locals),
        Stmt = while(Rval, Statement, Once)
    ;
        %
        % Both the `then' and the `else' parts of an if-then-else
        % are in a tail position iff the if-then-else is in a
        % tail position.
        %
        Stmt0 = if_then_else(Cond, Then0, MaybeElse0),
        Then = mark_tailcalls_in_statement(Then0, AtTail, Locals),
        MaybeElse = mark_tailcalls_in_maybe_statement(MaybeElse0,
            AtTail, Locals),
        Stmt = if_then_else(Cond, Then, MaybeElse)
    ;
        %
        % All of the cases of a switch (including the default)
        % are in a tail position iff the switch is in a
        % tail position.
        %
        Stmt0 = switch(Type, Val, Range, Cases0, Default0),
        Cases = mark_tailcalls_in_cases(Cases0, AtTail, Locals),
        Default = mark_tailcalls_in_default(Default0, AtTail, Locals),
        Stmt = switch(Type, Val, Range, Cases, Default)
    ;
        Stmt0 = label(_),
        Stmt = Stmt0
    ;
        Stmt0 = goto(_),
        Stmt = Stmt0
    ;
        Stmt0 = computed_goto(_, _),
        Stmt = Stmt0
    ;
        Stmt0 = call(Sig, Func, Obj, Args, ReturnLvals, CallKind0),
        %
        % check if we can mark this call as a tail call
        %
        (
            % only calls still marked as ordinary calls are candidates
            CallKind0 = ordinary_call,

            %
            % we must be in a tail position
            %
            AtTail = yes(ReturnRvals),

            %
            % the values returned in this call must match
            % those returned by the `return' statement that
            % follows
            %
            match_return_vals(ReturnRvals, ReturnLvals),

            %
            % the call must not take the address of any
            % local variables or nested functions
            %
            check_maybe_rval(Obj, Locals),
            check_rvals(Args, Locals),

            %
            % the call must not be to a function nested within
            % this function
            %
            check_rval(Func, Locals)
        ->
            % mark this call as a tail call
            CallKind = tail_call,
            Stmt = call(Sig, Func, Obj, Args, ReturnLvals, CallKind)
        ;
            % leave this call unchanged
            Stmt = Stmt0
        )
    ;
        Stmt0 = return(_Rvals),
        Stmt = Stmt0
    ;
        Stmt0 = do_commit(_Ref),
        Stmt = Stmt0
    ;
        Stmt0 = try_commit(Ref, Statement0, Handler0),
        %
        % Both the statement inside a `try_commit' and the
        % handler are in tail call position iff the
        % `try_commit' statement is in a tail call position.
        %
        Statement = mark_tailcalls_in_statement(Statement0,
            AtTail, Locals),
        Handler = mark_tailcalls_in_statement(Handler0,
            AtTail, Locals),
        Stmt = try_commit(Ref, Statement, Handler)
    ;
        Stmt0 = atomic(_),
        Stmt = Stmt0
    ).
:- func mark_tailcalls_in_cases(list(mlds__switch_case), at_tail, locals) =
    list(mlds__switch_case).

    % Apply mark_tailcalls_in_case to every case of a switch,
    % passing the same `AtTail' and `Locals' to each of them.
mark_tailcalls_in_cases(Cases0, AtTail, Locals) =
    list__map(
        (func(Case0) = mark_tailcalls_in_case(Case0, AtTail, Locals)),
        Cases0).

:- func mark_tailcalls_in_case(mlds__switch_case, at_tail, locals) =
    mlds__switch_case.

    % A case pairs its match condition(s) with a statement;
    % only the statement is processed.
mark_tailcalls_in_case(Case0, AtTail, Locals) = Case :-
    Case0 = MatchCond - Body0,
    Body = mark_tailcalls_in_statement(Body0, AtTail, Locals),
    Case = MatchCond - Body.

:- func mark_tailcalls_in_default(mlds__switch_default, at_tail, locals) =
    mlds__switch_default.

    % Only a `default_case' carries a statement to process;
    % the other two kinds of default are returned as they are.
mark_tailcalls_in_default(Default0, AtTail, Locals) = Default :-
    (
        Default0 = default_do_nothing,
        Default = Default0
    ;
        Default0 = default_is_unreachable,
        Default = Default0
    ;
        Default0 = default_case(Body0),
        Body = mark_tailcalls_in_statement(Body0, AtTail, Locals),
        Default = default_case(Body)
    ).

%-----------------------------------------------------------------------------%

%
% match_return_vals(Rvals, Lvals):
% match_return_val(Rval, Lval):
%   Succeed iff the Lval(s) in which a call stores its results
%   are exactly the Rval(s) returned by the `return' statement
%   that follows the call, and each of those Lvals is a local
%   variable (so that assigning to it has no side effects).
%   Only then can the call be optimized into a tailcall.
%

:- pred match_return_vals(list(mlds__rval), list(mlds__lval)).
:- mode match_return_vals(in, in) is semidet.

    % The two lists must have the same length and match pairwise.
match_return_vals([], []).
match_return_vals([ReturnRval | ReturnRvals], [CallLval | CallLvals]) :-
    match_return_val(ReturnRval, CallLval),
    match_return_vals(ReturnRvals, CallLvals).

:- pred match_return_val(mlds__rval, mlds__lval).
:- mode match_return_val(in, in) is semidet.

    % The returned rval must be precisely the value of the call's lval.
match_return_val(ReturnRval, CallLval) :-
    ReturnRval = lval(CallLval),
    lval_is_local(CallLval).

:- pred lval_is_local(mlds__lval).
:- mode lval_is_local(in) is semidet.

    % A variable is simply assumed to be local.  (This assumption
    % holds for the code generated by ml_code_gen.m.)
lval_is_local(var(_, _)).
lval_is_local(field(_Tag, BaseRval, _Field, _, _)) :-
    % A field is local only if it is a field of a local lval,
    % i.e. its base is the address of an lval that is itself local.
    BaseRval = mem_addr(BaseLval),
    lval_is_local(BaseLval).
lval_is_local(mem_ref(_Rval, _Type)) :-
    % A memory reference is never treated as local.
    fail.

%-----------------------------------------------------------------------------%

%
% check_rvals:
% check_maybe_rval:
% check_rval:
%   Succeed only if the given rval(s) cannot evaluate to the address
%   of a local variable (or of a field of a local variable)
%   or of a nested function.
%

:- pred check_rvals(list(mlds__rval), locals).
:- mode check_rvals(in, in) is semidet.

    % Every rval in the list must pass check_rval.
check_rvals([], _).
check_rvals([Head | Tail], Locals) :-
    check_rval(Head, Locals),
    check_rvals(Tail, Locals).

:- pred check_maybe_rval(maybe(mlds__rval), locals).
:- mode check_maybe_rval(in, in) is semidet.

    % An absent rval is trivially acceptable.
check_maybe_rval(no, _).
check_maybe_rval(yes(Rval), Locals) :-
    check_rval(Rval, Locals).

:- pred check_rval(mlds__rval, locals).
:- mode check_rval(in, in) is semidet.

    % Using the _value_ of an lval is always acceptable.
check_rval(lval(_), _).
check_rval(mkword(_Tag, Arg), Locals) :-
    check_rval(Arg, Locals).
check_rval(const(Const), Locals) :-
    check_const(Const, Locals).
check_rval(unop(_Op, Arg), Locals) :-
    check_rval(Arg, Locals).
check_rval(binop(_Op, Left, Right), Locals) :-
    check_rval(Left, Locals),
    check_rval(Right, Locals).
check_rval(mem_addr(Lval), Locals) :-
    % Taking the address of an lval is a problem
    % if that lval names a local variable.
    check_lval(Lval, Locals).

%
% check_lval:
%   Succeed only if the given lval cannot be a local variable
%   (or a field of a local variable).
%

:- pred check_lval(mlds__lval, locals).
:- mode check_lval(in, in) is semidet.

check_lval(field(_MaybeTag, BaseRval, _FieldId, _, _), Locals) :-
    check_rval(BaseRval, Locals).
check_lval(mem_ref(_, _), _).
    % We assume that the addresses of local variables are only
    % ever passed down to other functions, or assigned to,
    % so a mem_ref lval can never refer to a local variable.
check_lval(var(Var, _), Locals) :-
    not var_is_local(Var, Locals).
%
% check_const:
%   Succeed only if the given const cannot be the address of a
%   local variable or of a nested function.
%
%   The addresses of local variables are probably
%   not consts, at least not unless those variables are
%   declared as static (i.e. `one_copy'),
%   so it might be safe to allow all data_addr_consts here,
%   but currently we just take a conservative approach.
%

:- pred check_const(mlds__rval_const, locals).
:- mode check_const(in, in) is semidet.

check_const(Const, Locals) :-
    ( Const = code_addr_const(CodeAddr) ->
        % a code address must not name a nested function
        not function_is_local(CodeAddr, Locals)
    ; Const = data_addr_const(data_addr(ModuleName, DataName)) ->
        % a data address is only a problem if it names a variable
        % that is one of the locals in scope
        ( DataName = var(VarName) ->
            not var_is_local(qual(ModuleName, VarName), Locals)
        ;
            true
        )
    ;
        % every other kind of constant is acceptable
        true
    ).

%
% var_is_local:
%   Succeed iff the given variable is defined locally,
%   i.e. in storage that might no longer exist when the function
%   returns or does a tail call.
%
%   It would be safe to fail for variables declared static
%   (i.e. `one_copy'), but currently we just take a conservative
%   approach.
%

:- pred var_is_local(mlds__var::in, locals::in) is semidet.

var_is_local(qual(_ModuleName, VarName), Locals) :-
    % XXX we ignore the ModuleName --
    % that is safe, but overly conservative
    %
    % The variable is local if some entity in scope is a data
    % definition of a variable with the same name.
    locals_member(data(var(VarName)), Locals).

%
% function_is_local:
%   Succeed iff the given function is defined locally
%   (i.e. as a nested function).
%

:- pred function_is_local(mlds__code_addr::in, locals::in) is semidet.
function_is_local(CodeAddr, Locals) :-
    % Pick the proc label out of the code address; only an
    % `internal' address carries a sequence number.
    (
        CodeAddr = proc(QualProcLabel, _),
        MaybeSeqNum = no
    ;
        CodeAddr = internal(QualProcLabel, SeqNum, _),
        MaybeSeqNum = yes(SeqNum)
    ),
    % XXX we ignore the ModuleName --
    % that is safe, but might be overly conservative
    QualProcLabel = qual(_ModuleName, PredLabel - ProcId),
    % The function is local if some entity in scope is a function
    % with the same pred label, proc id and sequence number.
    locals_member(function(PredLabel, ProcId, MaybeSeqNum, _PredId),
        Locals).

%
% locals_member(Name, Locals):
%   Nondeterministically enumerates the names of all the entities
%   in Locals.
%

:- pred locals_member(mlds__entity_name, locals).
:- mode locals_member(out, in) is nondet.

locals_member(Name, LocalsList) :-
    list__member(Scope, LocalsList),
    (
        % names of the definitions in a block
        Scope = defns(Defns),
        list__member(mlds__defn(Name, _, _, _), Defns)
    ;
        % names of the function parameters
        Scope = params(Params),
        list__member(mlds__argument(Name, _, _), Params)
    ).

%-----------------------------------------------------------------------------%

    % Collect all the warnings found by nontailcall_in_mlds
    % and report each of them in turn.
ml_warn_tailcalls(MLDS, IO0, IO) :-
    solutions(nontailcall_in_mlds(MLDS), Warnings),
    list__foldl(report_nontailcall_warning, Warnings, IO0, IO).

    % A warning records the pred label and proc id of the
    % recursive procedure, and the context of the offending call.
:- type tailcall_warning
    --->    tailcall_warning(
                mlds__pred_label,
                proc_id,
                mlds__context
            ).

:- pred nontailcall_in_mlds(mlds::in, tailcall_warning::out) is nondet.

    % Enumerate the warnings for all the top-level definitions in the MLDS.
nontailcall_in_mlds(mlds(ModuleName, _ForeignCode, _Imports, Defns),
        Warning) :-
    nontailcall_in_defns(mercury_module_name_to_mlds(ModuleName),
        Defns, Warning).

:- pred nontailcall_in_defns(mlds_module_name::in, mlds__defns::in,
    tailcall_warning::out) is nondet.

    % Enumerate the warnings for each definition in the list.
nontailcall_in_defns(ModuleName, Defns, Warning) :-
    list__member(OneDefn, Defns),
    nontailcall_in_defn(ModuleName, OneDefn, Warning).

:- pred nontailcall_in_defn(mlds_module_name::in, mlds__defn::in,
    tailcall_warning::out) is nondet.
nontailcall_in_defn(ModuleName, Defn, Warning) :- Defn = mlds__defn(Name, _Context, _Flags, DefnBody), ( DefnBody = mlds__function(_PredProcId, _Params, FuncBody, _Attributes), FuncBody = defined_here(Body), nontailcall_in_statement(ModuleName, Name, Body, Warning) ; DefnBody = mlds__class(ClassDefn), ClassDefn = class_defn(_Kind, _Imports, _BaseClasses, _Implements, CtorDefns, MemberDefns), ( nontailcall_in_defns(ModuleName, CtorDefns, Warning) ; nontailcall_in_defns(ModuleName, MemberDefns, Warning) ) ). :- pred nontailcall_in_statement(mlds_module_name::in, mlds__entity_name::in, mlds__statement::in, tailcall_warning::out) is nondet. nontailcall_in_statement(CallerModule, CallerFuncName, Statement, Warning) :- % nondeterministically find a non-tail call statement_contains_statement(Statement, SubStatement), SubStatement = mlds__statement(SubStmt, Context), SubStmt = call(_CallSig, Func, _This, _Args, _RetVals, CallKind), CallKind = ordinary_call, % check if this call is a directly recursive call Func = const(code_addr_const(CodeAddr)), ( CodeAddr = proc(QualProcLabel, _Sig), MaybeSeqNum = no ; CodeAddr = internal(QualProcLabel, SeqNum, _Sig), MaybeSeqNum = yes(SeqNum) ), QualProcLabel = qual(CallerModule, PredLabel - ProcId), CallerFuncName = function(PredLabel, ProcId, MaybeSeqNum, _PredId), % if so, construct an appropriate warning Warning = tailcall_warning(PredLabel, ProcId, Context). :- pred report_nontailcall_warning(tailcall_warning::in, io__state::di, io__state::uo) is det. 
report_nontailcall_warning(tailcall_warning(PredLabel, ProcId, Context)) --> ( { PredLabel = pred(PredOrFunc, _MaybeModule, Name, Arity, _CodeModel, _NonOutputFunc) }, { hlds_out__simple_call_id_to_string(PredOrFunc - unqualified(Name) / Arity, CallId) }, { proc_id_to_int(ProcId, ProcNumber0) }, { ProcNumber = ProcNumber0 + 1 }, { ProcNumberStr = string__int_to_string(ProcNumber) }, report_warning(mlds__get_prog_context(Context), 0, [ words("In mode number"), words(ProcNumberStr), words("of"), fixed(CallId ++ ":"), nl, words(" warning: recursive call is not tail recursive.") ]) ; { PredLabel = special_pred(_, _, _, _) } % don't warn about these ). %-----------------------------------------------------------------------------%