Files
mercury/compiler/ml_tailcall.m
Zoltan Somogyi 9551640f55 Import only one compiler module per line. Sort the blocks of imports.
Estimated hours taken: 2
Branches: main

compiler/*.m:
	Import only one compiler module per line. Sort the blocks of imports.
	This makes it easier to merge in changes.

	In a couple of places, remove unnecessary imports.
2003-03-15 03:09:14 +00:00

659 lines
21 KiB
Mathematica

%-----------------------------------------------------------------------------%
% Copyright (C) 1999-2003 The University of Melbourne.
% This file may only be copied under the terms of the GNU General
% Public License - see the file COPYING in the Mercury distribution.
%-----------------------------------------------------------------------------%
% File: ml_tailcall.m
% Main author: fjh
% This module is an MLDS-to-MLDS transformation
% that marks function calls as tail calls whenever
% it is safe to do so, based on the assumptions described below.
% This module also contains a pass over the MLDS that detects functions
% which are directly recursive, but not tail-recursive,
% and warns about them.
% A function call can safely be marked as a tail call if
% (1) it occurs in a position which would fall through into the
% end of the function body or to a `return' statement,
% and
% (2) the lvalues in which the return value(s) from the `call'
% will be placed are the same as the value(s) returned
% by the `return', and these lvalues are all local variables.
% and
% (3) the function's local variables do not need to be live
% for that call.
%
% For (2), we just assume (rather than checking) that
% any variables returned by the `return' statement are
% local variables. This assumption is true for the MLDS code
% generated by ml_code_gen.m.
%
% For (3), we assume that the addresses of local variables
% and nested functions are only ever passed down to other functions
% (and used to assign to the local variable or to call the nested
% function), so that here we only need to check if the potential
% tail call uses such addresses, not whether such addresses were
% taken in earlier calls. That is, if the addresses
% of locals were taken in earlier calls from the same function,
% we assume that these addresses will not be saved (on the heap,
% or in global variables, etc.) and used after those earlier calls
% have returned. This assumption is true for the MLDS code generated
% by ml_code_gen.m.
%
% This module only marks tail calls. The actual tail call
% optimization (turning self-tail-calls into loops) is done in ml_optimize.
% Individual backends may wish to treat tailcalls separately if there is
% any backend support for them.
%
% Note that ml_call_gen.m will also mark calls to procedures with determinism
% `erroneous' as `no_return_call's (a special case of tail calls)
% when it generates them.
%-----------------------------------------------------------------------------%
:- module ml_backend__ml_tailcall.
:- interface.
:- import_module ml_backend__mlds.
:- import_module io.
% ml_mark_tailcalls(MLDS0, MLDS, IO0, IO):
% Traverse the MLDS, marking all optimizable tail calls
% as tail calls.
% MLDS is MLDS0 with its definitions rewritten; the module name,
% foreign code and imports are copied across unchanged.
% The I/O state pair is part of the interface, but the
% implementation in this module does not perform any I/O with it.
%
:- pred ml_mark_tailcalls(mlds, mlds, io__state, io__state).
:- mode ml_mark_tailcalls(in, out, di, uo) is det.
% ml_warn_tailcalls(MLDS, IO0, IO):
% Traverse the MLDS, warning about all directly recursive calls
% that are not marked as tail calls.
% A call counts as directly recursive when the code address it
% calls names the function definition that contains the call.
% No warning is printed for functions whose label is a
% `special_pred' label.
%
:- pred ml_warn_tailcalls(mlds, io__state, io__state).
:- mode ml_warn_tailcalls(in, di, uo) is det.
%-----------------------------------------------------------------------------%
%-----------------------------------------------------------------------------%
:- implementation.
:- import_module hlds__error_util.
:- import_module hlds__hlds_out.
:- import_module hlds__hlds_pred.
:- import_module ml_backend__ml_util.
:- import_module parse_tree__prog_data.
:- import_module string, int, list, std_util.
ml_mark_tailcalls(MLDS0, MLDS) -->
	% Only the list of definitions is rewritten; the other three
	% fields of the MLDS are carried over as they are.
	% The I/O state is threaded through untouched.
	{
		MLDS0 = mlds(Module, Foreign, ImportList, OldDefns),
		NewDefns = mark_tailcalls_in_defns(OldDefns),
		MLDS = mlds(Module, Foreign, ImportList, NewDefns)
	}.
%-----------------------------------------------------------------------------%
% The `at_tail' type indicates whether or not a subgoal
% is at a tail position, i.e. is followed by a return
% statement or the end of the function, and if so,
% specifies the return values (if any) in the return statement.
%
% `no' means the construct is not at a tail position.
% `yes(Rvals)' means that it is, and Rvals are the values returned
% by the `return' statement that follows it. `yes([])' is the
% initial value used for the body of a function that has no
% return types.
:- type at_tail == maybe(list(mlds__rval)).
% The `locals' type contains a list of local definitions
% which are in scope.
%
% Each element is one group of definitions: either the parameters
% of a function (`params'), or the definitions introduced by a
% block statement (`defns'). The list starts out holding just the
% parameters of the function being processed, and the definitions
% of each block are pushed onto the front while the statements of
% that block are processed.
:- type locals == list(local_defns).
:- type local_defns
---> params(mlds__arguments)
; defns(mlds__defns)
.
%-----------------------------------------------------------------------------%
%
% mark_tailcalls_in_defns:
% mark_tailcalls_in_defn:
% Recursively process the definition(s),
% marking each optimizable tail call in them as a tail call.
%
% mark_tailcalls_in_function_body:
% mark_tailcalls_in_maybe_statement:
% mark_tailcalls_in_statements:
% mark_tailcalls_in_statement:
% mark_tailcalls_in_stmt:
% mark_tailcalls_in_cases:
% mark_tailcalls_in_case:
% mark_tailcalls_in_default:
% Recursively process the statement(s),
% marking each optimizable tail call in them as a tail call.
% The `AtTail' argument indicates whether or not this
% construct is in a tail call position.
% The `Locals' argument contains a list of the
% local definitions which are in scope at this point.
%
:- func mark_tailcalls_in_defns(mlds__defns) = mlds__defns.
mark_tailcalls_in_defns(Defns) = list__map(mark_tailcalls_in_defn, Defns).
:- func mark_tailcalls_in_defn(mlds__defn) = mlds__defn.
mark_tailcalls_in_defn(Defn0) = Defn :-
Defn0 = mlds__defn(Name, Context, Flags, DefnBody0),
(
DefnBody0 = mlds__function(PredProcId, Params, FuncBody0,
Attributes),
%
% Compute the initial value of the `Locals' and
% `AtTail' arguments.
%
Params = mlds__func_params(Args, RetTypes),
Locals = [params(Args)],
( RetTypes = [] ->
AtTail = yes([])
;
AtTail = no
),
FuncBody = mark_tailcalls_in_function_body(FuncBody0,
AtTail, Locals),
DefnBody = mlds__function(PredProcId, Params, FuncBody,
Attributes),
Defn = mlds__defn(Name, Context, Flags, DefnBody)
;
DefnBody0 = mlds__data(_, _, _),
Defn = Defn0
;
DefnBody0 = mlds__class(ClassDefn0),
ClassDefn0 = class_defn(Kind, Imports, BaseClasses, Implements,
CtorDefns0, MemberDefns0),
CtorDefns = mark_tailcalls_in_defns(CtorDefns0),
MemberDefns = mark_tailcalls_in_defns(MemberDefns0),
ClassDefn = class_defn(Kind, Imports, BaseClasses, Implements,
CtorDefns, MemberDefns),
DefnBody = mlds__class(ClassDefn),
Defn = mlds__defn(Name, Context, Flags, DefnBody)
).
:- func mark_tailcalls_in_function_body(function_body, at_tail, locals)
= function_body.
mark_tailcalls_in_function_body(external, _, _) = external.
mark_tailcalls_in_function_body(defined_here(Statement0), AtTail, Locals) =
defined_here(Statement) :-
Statement = mark_tailcalls_in_statement(Statement0, AtTail, Locals).
:- func mark_tailcalls_in_maybe_statement(maybe(mlds__statement),
at_tail, locals) = maybe(mlds__statement).
mark_tailcalls_in_maybe_statement(no, _, _) = no.
mark_tailcalls_in_maybe_statement(yes(Statement0), AtTail, Locals) =
yes(Statement) :-
Statement = mark_tailcalls_in_statement(Statement0, AtTail, Locals).
:- func mark_tailcalls_in_statements(mlds__statements, at_tail, locals) =
mlds__statements.
mark_tailcalls_in_statements([], _, _) = [].
mark_tailcalls_in_statements([First0 | Rest0], AtTail, Locals) =
[First | Rest] :-
%
% If the First statement is followed by a `return'
% statement, then it is in a tailcall position.
% If there are no statements after the first, then
% the first statement is in a tail call position
% iff the statement list is in a tail call position.
% Otherwise, i.e. if the first statement is followed
% by anything other than a `return' statement, then
% the first statement is not in a tail call position.
%
(
Rest = [mlds__statement(return(ReturnVals), _) | _]
->
FirstAtTail = yes(ReturnVals)
;
Rest = []
->
FirstAtTail = AtTail
;
FirstAtTail = no
),
First = mark_tailcalls_in_statement(First0, FirstAtTail, Locals),
Rest = mark_tailcalls_in_statements(Rest0, AtTail, Locals).
:- func mark_tailcalls_in_statement(mlds__statement, at_tail, locals) =
mlds__statement.
mark_tailcalls_in_statement(Statement0, AtTail, Locals) = Statement :-
Statement0 = mlds__statement(Stmt0, Context),
Stmt = mark_tailcalls_in_stmt(Stmt0, AtTail, Locals),
Statement = mlds__statement(Stmt, Context).
:- func mark_tailcalls_in_stmt(mlds__stmt, at_tail, locals) = mlds__stmt.
mark_tailcalls_in_stmt(Stmt0, AtTail, Locals) = Stmt :-
(
%
% Whenever we encounter a block statement,
% we recursively mark tailcalls in any nested
% functions defined in that block.
% We also need to add any local definitions in that
% block to the list of currently visible local
% declarations before processing the statements
% in that block. The statement list will be in a
% tail position iff the block is in a tail position.
%
Stmt0 = block(Defns0, Statements0),
Defns = mark_tailcalls_in_defns(Defns0),
NewLocals = [defns(Defns) | Locals],
Statements = mark_tailcalls_in_statements(Statements0,
AtTail, NewLocals),
Stmt = block(Defns, Statements)
;
%
% The statement in the body of a while loop is never
% in a tail position.
%
Stmt0 = while(Rval, Statement0, Once),
Statement = mark_tailcalls_in_statement(Statement0, no, Locals),
Stmt = while(Rval, Statement, Once)
;
%
% Both the `then' and the `else' parts of an if-then-else
% are in a tail position iff the if-then-else is in a
% tail position.
%
Stmt0 = if_then_else(Cond, Then0, MaybeElse0),
Then = mark_tailcalls_in_statement(Then0, AtTail, Locals),
MaybeElse = mark_tailcalls_in_maybe_statement(MaybeElse0,
AtTail, Locals),
Stmt = if_then_else(Cond, Then, MaybeElse)
;
%
% All of the cases of a switch (including the default)
% are in a tail position iff the switch is in a
% tail position.
%
Stmt0 = switch(Type, Val, Range, Cases0, Default0),
Cases = mark_tailcalls_in_cases(Cases0, AtTail, Locals),
Default = mark_tailcalls_in_default(Default0, AtTail, Locals),
Stmt = switch(Type, Val, Range, Cases, Default)
;
Stmt0 = label(_),
Stmt = Stmt0
;
Stmt0 = goto(_),
Stmt = Stmt0
;
Stmt0 = computed_goto(_, _),
Stmt = Stmt0
;
Stmt0 = call(Sig, Func, Obj, Args, ReturnLvals, CallKind0),
%
% check if we can mark this call as a tail call
%
(
CallKind0 = ordinary_call,
%
% we must be in a tail position
%
AtTail = yes(ReturnRvals),
%
% the values returned in this call must match
% those returned by the `return' statement that
% follows
%
match_return_vals(ReturnRvals, ReturnLvals),
%
% the call must not take the address of any
% local variables or nested functions
%
check_maybe_rval(Obj, Locals),
check_rvals(Args, Locals),
%
% the call must not be to a function nested within
% this function
%
check_rval(Func, Locals)
->
% mark this call as a tail call
CallKind = tail_call,
Stmt = call(Sig, Func, Obj, Args, ReturnLvals,
CallKind)
;
% leave this call unchanged
Stmt = Stmt0
)
;
Stmt0 = return(_Rvals),
Stmt = Stmt0
;
Stmt0 = do_commit(_Ref),
Stmt = Stmt0
;
Stmt0 = try_commit(Ref, Statement0, Handler0),
%
% Both the statement inside a `try_commit' and the
% handler are in tail call position iff the
% `try_commit' statement is in a tail call position.
%
Statement = mark_tailcalls_in_statement(Statement0, AtTail,
Locals),
Handler = mark_tailcalls_in_statement(Handler0, AtTail, Locals),
Stmt = try_commit(Ref, Statement, Handler)
;
Stmt0 = atomic(_),
Stmt = Stmt0
).
:- func mark_tailcalls_in_cases(list(mlds__switch_case), at_tail, locals) =
list(mlds__switch_case).
mark_tailcalls_in_cases([], _, _) = [].
mark_tailcalls_in_cases([Case0 | Cases0], AtTail, Locals) =
[Case | Cases] :-
Case = mark_tailcalls_in_case(Case0, AtTail, Locals),
Cases = mark_tailcalls_in_cases(Cases0, AtTail, Locals).
:- func mark_tailcalls_in_case(mlds__switch_case, at_tail, locals) =
mlds__switch_case.
mark_tailcalls_in_case(Cond - Statement0, AtTail, Locals) =
Cond - Statement :-
Statement = mark_tailcalls_in_statement(Statement0, AtTail, Locals).
:- func mark_tailcalls_in_default(mlds__switch_default, at_tail, locals) =
mlds__switch_default.
mark_tailcalls_in_default(default_do_nothing, _, _) = default_do_nothing.
mark_tailcalls_in_default(default_is_unreachable, _, _) =
default_is_unreachable.
mark_tailcalls_in_default(default_case(Statement0), AtTail, Locals) =
default_case(Statement) :-
Statement = mark_tailcalls_in_statement(Statement0, AtTail, Locals).
%-----------------------------------------------------------------------------%
%
% match_return_vals(Rvals, Lvals):
% match_return_val(Rval, Lval):
% Check that the Lval(s) returned by a call match
% the Rval(s) in the `return' statement that follows,
% and those Lvals are local variables
% (so that assignments to them won't have any side effects),
% so that we can optimize the call into a tailcall.
%
:- pred match_return_vals(list(mlds__rval), list(mlds__lval)).
:- mode match_return_vals(in, in) is semidet.
match_return_vals([], []).
match_return_vals([Rval|Rvals], [Lval|Lvals]) :-
match_return_val(Rval, Lval),
match_return_vals(Rvals, Lvals).
:- pred match_return_val(mlds__rval, mlds__lval).
:- mode match_return_val(in, in) is semidet.
match_return_val(lval(Lval), Lval) :-
lval_is_local(Lval).
:- pred lval_is_local(mlds__lval).
:- mode lval_is_local(in) is semidet.
lval_is_local(var(_, _)) :-
% We just assume it is local. (This assumption is
% true for the code generated by ml_code_gen.m.)
true.
lval_is_local(field(_Tag, Rval, _Field, _, _)) :-
% a field of a local variable is local
( Rval = mem_addr(Lval) ->
lval_is_local(Lval)
;
fail
).
lval_is_local(mem_ref(_Rval, _Type)) :-
fail.
%-----------------------------------------------------------------------------%
%
% check_rvals:
% check_maybe_rval:
% check_rval:
% Fail if the specified rval(s) might evaluate to the addresses of
% local variables (or fields of local variables) or nested functions.
%
:- pred check_rvals(list(mlds__rval), locals).
:- mode check_rvals(in, in) is semidet.
check_rvals([], _).
check_rvals([Rval|Rvals], Locals) :-
check_rval(Rval, Locals),
check_rvals(Rvals, Locals).
:- pred check_maybe_rval(maybe(mlds__rval), locals).
:- mode check_maybe_rval(in, in) is semidet.
check_maybe_rval(no, _).
check_maybe_rval(yes(Rval), Locals) :-
check_rval(Rval, Locals).
:- pred check_rval(mlds__rval, locals).
:- mode check_rval(in, in) is semidet.
check_rval(lval(_Lval), _) :-
% Passing the _value_ of an lval is fine.
true.
check_rval(mkword(_Tag, Rval), Locals) :-
check_rval(Rval, Locals).
check_rval(const(Const), Locals) :-
check_const(Const, Locals).
check_rval(unop(_Op, Rval), Locals) :-
check_rval(Rval, Locals).
check_rval(binop(_Op, X, Y), Locals) :-
check_rval(X, Locals),
check_rval(Y, Locals).
check_rval(mem_addr(Lval), Locals) :-
% Passing the address of an lval is a problem,
% if that lval names a local variable.
check_lval(Lval, Locals).
%
% check_lval:
% Fail if the specified lval might be a local variable
% (or a field of a local variable).
%
:- pred check_lval(mlds__lval, locals).
:- mode check_lval(in, in) is semidet.
check_lval(field(_MaybeTag, Rval, _FieldId, _, _), Locals) :-
check_rval(Rval, Locals).
check_lval(mem_ref(_, _), _) :-
% We assume that the addresses of local variables are only
% ever passed down to other functions, or assigned to,
% so a mem_ref lval can never refer to a local variable.
true.
check_lval(var(Var0, _), Locals) :-
\+ var_is_local(Var0, Locals).
%
% check_const:
% Fail if the specified const might be the address of a
% local variable or nested function.
%
% The addresses of local variables are probably
% not consts, at least not unless those variables are
% declared as static (i.e. `one_copy'),
% so it might be safe to allow all data_addr_consts here,
% but currently we just take a conservative approach.
%
:- pred check_const(mlds__rval_const, locals).
:- mode check_const(in, in) is semidet.
check_const(Const, Locals) :-
( Const = code_addr_const(CodeAddr) ->
\+ function_is_local(CodeAddr, Locals)
; Const = data_addr_const(DataAddr) ->
DataAddr = data_addr(ModuleName, DataName),
( DataName = var(VarName) ->
\+ var_is_local(qual(ModuleName, VarName), Locals)
;
true
)
;
true
).
%
% var_is_local:
% Check whether the specified variable is defined locally,
% i.e. in storage that might no longer exist when the function
% returns or does a tail call.
%
% It would be safe to fail for variables declared static
% (i.e. `one_copy'), but currently we just take a conservative
% approach.
%
:- pred var_is_local(mlds__var, locals).
:- mode var_is_local(in, in) is semidet.
var_is_local(Var, Locals) :-
% XXX we ignore the ModuleName --
% that is safe, but overly conservative
Var = qual(_ModuleName, VarName),
some [Local] (
locals_member(Local, Locals),
Local = data(var(VarName))
).
%
% function_is_local:
% Check whether the specified function is defined locally
% (i.e. as a nested function).
%
:- pred function_is_local(mlds__code_addr, locals).
:- mode function_is_local(in, in) is semidet.
function_is_local(CodeAddr, Locals) :-
(
CodeAddr = proc(QualifiedProcLabel, _Sig),
MaybeSeqNum = no
;
CodeAddr = internal(QualifiedProcLabel, SeqNum, _Sig),
MaybeSeqNum = yes(SeqNum)
),
% XXX we ignore the ModuleName --
% that is safe, but might be overly conservative
QualifiedProcLabel = qual(_ModuleName, ProcLabel),
ProcLabel = PredLabel - ProcId,
some [Local] (
locals_member(Local, Locals),
Local = function(PredLabel, ProcId, MaybeSeqNum, _PredId)
).
%
% locals_member(Name, Locals):
% Nondeterministically enumerates the names of all the entities
% in Locals.
%
:- pred locals_member(mlds__entity_name, locals).
:- mode locals_member(out, in) is nondet.
locals_member(Name, LocalsList) :-
list__member(Locals, LocalsList),
(
Locals = defns(Defns),
list__member(Defn, Defns),
Defn = mlds__defn(Name, _, _, _)
;
Locals = params(Params),
list__member(Param, Params),
Param = mlds__argument(Name, _, _)
).
%-----------------------------------------------------------------------------%
ml_warn_tailcalls(MLDS) -->
	% Collect all the warnings first, then report them one by one.
	{ FindWarning = (pred(Warning::out) is nondet :-
		nontailcall_in_mlds(MLDS, Warning)
	) },
	{ solutions(FindWarning, Warnings) },
	list__foldl(report_nontailcall_warning, Warnings).
% A warning about one directly recursive call that is not marked
% as a tail call. The fields are the pred label and the proc id
% of the function that contains the call (which is also the
% function being called), and the context of the call statement.
:- type tailcall_warning ---> tailcall_warning(
mlds__pred_label,
proc_id,
mlds__context
).
:- pred nontailcall_in_mlds(mlds::in, tailcall_warning::out) is nondet.
nontailcall_in_mlds(MLDS, Warning) :-
MLDS = mlds(ModuleName, _ForeignCode, _Imports, Defns),
MLDS_ModuleName = mercury_module_name_to_mlds(ModuleName),
nontailcall_in_defns(MLDS_ModuleName, Defns, Warning).
:- pred nontailcall_in_defns(mlds_module_name::in, mlds__defns::in,
tailcall_warning::out) is nondet.
nontailcall_in_defns(ModuleName, Defns, Warning) :-
list__member(Defn, Defns),
nontailcall_in_defn(ModuleName, Defn, Warning).
:- pred nontailcall_in_defn(mlds_module_name::in, mlds__defn::in,
tailcall_warning::out) is nondet.
nontailcall_in_defn(ModuleName, Defn, Warning) :-
Defn = mlds__defn(Name, _Context, _Flags, DefnBody),
(
DefnBody = mlds__function(_PredProcId, _Params, FuncBody,
_Attributes),
FuncBody = defined_here(Body),
nontailcall_in_statement(ModuleName, Name, Body, Warning)
;
DefnBody = mlds__class(ClassDefn),
ClassDefn = class_defn(_Kind, _Imports, _BaseClasses,
_Implements, CtorDefns, MemberDefns),
( nontailcall_in_defns(ModuleName, CtorDefns, Warning)
; nontailcall_in_defns(ModuleName, MemberDefns, Warning)
)
).
:- pred nontailcall_in_statement(mlds_module_name::in, mlds__entity_name::in,
mlds__statement::in, tailcall_warning::out) is nondet.
nontailcall_in_statement(CallerModule, CallerFuncName, Statement, Warning) :-
% nondeterministically find a non-tail call
statement_contains_statement(Statement, SubStatement),
SubStatement = mlds__statement(SubStmt, Context),
SubStmt = call(_CallSig, Func, _This, _Args, _RetVals, CallKind),
CallKind = ordinary_call,
% check if this call is a directly recursive call
Func = const(code_addr_const(CodeAddr)),
( CodeAddr = proc(QualProcLabel, _Sig), MaybeSeqNum = no
; CodeAddr = internal(QualProcLabel, SeqNum, _Sig),
MaybeSeqNum = yes(SeqNum)
),
QualProcLabel = qual(CallerModule, PredLabel - ProcId),
CallerFuncName = function(PredLabel, ProcId, MaybeSeqNum, _PredId),
% if so, construct an appropriate warning
Warning = tailcall_warning(PredLabel, ProcId, Context).
:- pred report_nontailcall_warning(tailcall_warning::in,
io__state::di, io__state::uo) is det.
report_nontailcall_warning(tailcall_warning(PredLabel, ProcId, Context)) -->
(
{ PredLabel = pred(PredOrFunc, _MaybeModule, Name, Arity,
_CodeModel, _NonOutputFunc) },
{ hlds_out__simple_call_id_to_string(PredOrFunc -
unqualified(Name) / Arity, CallId) },
{ proc_id_to_int(ProcId, ProcNumber0) },
{ ProcNumber = ProcNumber0 + 1 },
{ ProcNumberStr = string__int_to_string(ProcNumber) },
report_warning(mlds__get_prog_context(Context), 0, [
words("In mode number"), words(ProcNumberStr),
words("of"), fixed(CallId ++ ":"), nl,
words(" warning: recursive call is not tail recursive.")
])
;
{ PredLabel = special_pred(_, _, _, _) }
% don't warn about these
).
%-----------------------------------------------------------------------------%