:- module recombstack.
% Keeps and prunes a stack of scored elements: elements with a low score
% *given the ordering* are dropped, and a newly inserted element may be
% recombined with an existing element stored under the same key.

:- interface.

:- import_module ordered, list, io, string.

:- typeclass stackable(T, Score, K) <= (ordered(Score), (T -> Score), (T -> K))
    where [
        func value(T) = Score,
        func to_key(T) = K,
        pred recombine(T::in, T::in, T::out) is semidet,
            % failure prohibits recombination
        pred write_elem(string::in, T::in, io::di, io::uo) is det
    ].

:- type stack(T, Score, K).
    % <= stackable(T, Score, K).
    % (Intended constraint; not enforced on the abstract type itself.)

:- func new(int) = stack(T, Score, K).
    % new(N): create a new, empty stack with stack limit N.
    % insert_prune/4 lets the stack grow to 2*N elements before it prunes
    % back down to N.

:- func size(stack(T, Score, K)) = int.
    % How many elements does the stack currently hold?
    % Between prunes this may exceed the stack limit.

:- func to_list(stack(T, Score, K)) = list(T).
    % Unsorted list of elements.

:- func unsorted_elems(stack(T, Score, K)) = list(T).
    % Unsorted list of elements (same as to_list/1).

:- pred peek_top(stack(T, Score, K)::in, T::out) is semidet
    <= stackable(T, Score, K).
    % Return the best-scoring element. Does NOT prune.
    % Fails on a stack into which nothing has been inserted yet.

:- pred peek_worst_score(stack(T, Score, K)::in, Score::out) is semidet
    <= stackable(T, Score, K).
    % (Don't prune and) return the worst score recorded by insert_prune/4
    % or prune/1. Fails on a stack into which nothing has been inserted yet.
    % NOTE: the recorded value is not recomputed when an element is
    % replaced by recombination, so it may be stale (too low).

:- type insert_status(T)
    --->    kept
            % H was stored and no pruning happened
            % (also the status for the first insertion).
    ;       too_bad_right_away
            % H was not stored: the stack was already over its limit and H
            % scored below the recorded worst score.
    ;       inserted_but_maybe_pruned
            % H was stored, then the stack was pruned, so H may be gone.
    ;       recombined_with(T)
            % H was recombined with the given old element, which was
            % replaced by the recombination result; no pruning happened.
    ;       recombined_with_but_maybe_pruned(T).
            % As recombined_with, but the stack was pruned afterwards.

:- pred insert_prune(T::in, stack(T, S, K)::in, stack(T, S, K)::out,
    insert_status(T)::out) is det <= stackable(T, S, K).
:- pred insert_prune(T::in, stack(T, S, K)::in, stack(T, S, K)::out)
    is det <= stackable(T, S, K).
:- func insert_prune(T, stack(T, S, K)) = stack(T, S, K)
    <= stackable(T, S, K).

:- func prune(stack(T, S, K)) = stack(T, S, K) <= stackable(T, S, K).
    % Force pruning down to stack_limit.

:- pred fold(pred(T, Aku, Aku), stack(T, S, K), Aku, Aku)
    <= stackable(T, S, K).
:- mode fold(pred(in, in, out) is cc_multi, in, in, out) is cc_multi.
    % Unsorted walk over all (unpruned) elements.

:- pred foldl(pred(T, Aku, Aku), stack(T, S, K), Aku, Aku)
    <= stackable(T, S, K).
:- mode foldl(pred(in, in, out) is det, in, in, out) is det.
:- mode foldl(pred(in, di, uo) is det, in, di, uo) is det.
:- mode foldl(pred(in, di, uo) is cc_multi, in, di, uo) is cc_multi.
    % Sorted walk over all (unpruned) elements.

:- pred foldl2(pred(T, Aku, Aku, IO, IO), stack(T, S, K), Aku, Aku, IO, IO)
    <= stackable(T, S, K).
:- mode foldl2(pred(in, in, out, di, uo) is det, in, in, out, di, uo) is det.
    % Sorted walk over all (unpruned) elements.

:- func sorted_elems(stack(T, S, K)) = list(T) <= stackable(T, S, K).
    % Sorted list of all (unpruned) elements.

:- pred aggregate_to_stack(int, pred(T), stack(T, S, K))
    <= stackable(T, S, K).
:- mode aggregate_to_stack(in, pred(out) is nondet, out) is cc_multi.
    % Given a stack limit and a generator predicate, scan all solutions and
    % fill up the stack. The result is pruned to the stack limit.

:- implementation.

:- import_module int, binheap, require, map, maybe, map_set, set.
:- import_module list_tools.
:- import_module debugstr, type_desc.

    % blank: nothing inserted yet; only the stack and prune limits are known.
    % stack: the recorded worst score, the elements grouped by key, the
    % stack limit, the prune limit and the element count.
:- type stack(T, Score, K)
    --->    blank(
                sl :: int,
                pl :: int
            )
    ;       stack(
                worst       :: Score,
                elems       :: map_set(K, T),
                stack_limit :: int,
                prune_limit :: int,
                cnt         :: int
            ).

new(StackLimit) = blank(StackLimit, StackLimit * 2).

size(Stack) = Size :-
    (
        Stack = blank(_, _),
        Size = 0
    ;
        Stack = stack(_, _, _, _, Size)
    ).

to_list(Stack) = Out :-
    (
        Stack = blank(_, _),
        Out = []
    ;
        Stack = stack(_, Map, _, _, _),
        % Concatenate the per-key sets of elements.
        Out = map.foldl(
            (func(_K, Hyps, Tail) = set.to_sorted_list(Hyps) ++ Tail),
            Map, [])
        % map__values(Map).
    ).

unsorted_elems(Stack) = to_list(Stack).

insert_prune(H, IS) = OS :-
    insert_prune(H, IS, OS).

insert_prune(H, InStack, OutStack) :-
    insert_prune(H, InStack, OutStack, _Status).
insert_prune(H, InStack, OutStack, OutStatus):- ( InStack = blank(StackLimit, PruneLimit), % adding the first hypo OutStack = stack(value(H), map_set.add(map__init, to_key(H), H), StackLimit, PruneLimit, 1), OutStatus = kept ; InStack = stack(Worst, Elems, StackLimit, PruneLimit, Cnt), % extending the stack ( if ordered.ordering(value(H), Worst) = (<) `with_type` comparison_result , Cnt > StackLimit then % not adding, this is too bad anyway % trace [io(!IO)] (debugstr("DROPPING", !IO), write_elem("Dropping: ", H, !IO)), OutStack = InStack, OutStatus = too_bad_right_away else % adding, but might prune NewWorst = ordered.min(Worst, value(H)), K = to_key(H), % trace[io(!IO)]write_elem("Key "++string(K)++" from: ", H, !IO), ( if map__search(Elems, K, OldHyps), list_tools.map_first_matching( (pred(OH::in, NH::out) is semidet:- trace[io(!IO)]debugstr("Recombinable? ", type_of(H), !IO), % trace[io(!IO)]write_elem(" with: ", type_of(OH), !IO), recombine(H, OH, NH) ), set.to_sorted_list(OldHyps), HToReplace, HReplacement) then % just recombine set.insert(set.delete(OldHyps, HToReplace), HReplacement, NewHyps), map__det_update(Elems, K, NewHyps, NewElems), NewCnt = Cnt, Recombined = yes(HToReplace) else % properly add NewElems = map_set.add(Elems, K, H), NewCnt = Cnt+1, Recombined = no ), ( if NewCnt > PruneLimit then % add and prune restrict_stack(StackLimit, NewElems, PrunedElems, PrunedWorst), PrunedCnt = StackLimit, ( Recombined = no, OutStatus = inserted_but_maybe_pruned ; Recombined = yes(Rec), OutStatus = recombined_with_but_maybe_pruned(Rec) ) else % just add PrunedWorst = NewWorst, PrunedElems = NewElems, PrunedCnt = NewCnt, ( Recombined = no, OutStatus = kept ; Recombined = yes(Rec), OutStatus = recombined_with(Rec) ) ), OutStack = stack(PrunedWorst, PrunedElems, StackLimit, PruneLimit, PrunedCnt) ) ). 
prune(InStack) = OutStack :- InStack = blank(_StackLimit, _PruneLimit), % nothing to prune OutStack = InStack ; InStack = stack(_Worst, Elems, StackLimit, PruneLimit, Cnt), ( if Cnt > StackLimit then % prune restrict_stack(StackLimit, Elems, PrunedElems, PrunedWorst), PrunedCnt = StackLimit, OutStack = stack(PrunedWorst, PrunedElems, StackLimit, PruneLimit, PrunedCnt) else % already under limit OutStack = InStack ) . :- type invscore(T) ---> invscore(T). :- instance ordered(invscore(T)) <= ordered(T) where [ ordering(invscore(A), invscore(B)) = ordered.ordering(B, A) ]. :- pred restrict_stack(int::in, map_set(K, T)::in, map_set(K, T)::out, Score::out) is det <= stackable(T, Score, K). restrict_stack(Lim, InList, OutList, WorstScore) :- map_set__foldl((pred(K::in, H::in, InHeap::in, OutHeap::out) is det:- OutHeap = binheap__insert(InHeap, invscore(value(H)), {K,H}) ), InList, binheap__init, Heap), fold_up( (pred(_I::in, {_, InHeap, InMap}::in, {InHeap, OutHeap, OutMap}::out) is det:- if getmin(InHeap, _, {TK,TH}, TOutHeap) then OutHeap = TOutHeap, % trace [io(!IO)] (debugstr("KEEPING", !IO), write_elem("Keeping: ", TH, !IO)), OutMap = map_set__add(InMap, TK, TH) else error("Hit stack ground!") ), 1, Lim, {Heap, Heap, map__init}, {LastHeap, _RestHeap, OutList}), /* trace [io(!IO)] ( DroppedElems = map(func({_K,H})=H, sorted_values(RestHeap)), list__foldl(write_elem("Dropped: "), DroppedElems, !IO) ), */ ( if getmin(LastHeap, invscore(TWorstScore), _, _) then WorstScore = TWorstScore else error("No worst hypo!") ). peek_top(Stack, OutHyp) :- Stack = blank(_,_), fail ; Stack = stack(_, InList, _, _, _), map_set__foldl((pred(__K::in, H::in, InHeap::in, OutHeap::out) is det:- OutHeap = binheap__insert(InHeap, invscore(value(H)), H) ), InList, binheap__init, Heap), ( if getmin(Heap, _K, TH, _TOutHeap) then OutHyp = TH else error("Hit stack ground!") ) . peek_worst_score(Stack, OutScore) :- Stack = blank(_,_), fail ; Stack = stack(OutScore, _InList, _, _, _). 
sorted_elems(InStack) = OutList :- InStack = blank(_, _), OutList = [] ; InStack = stack(_, InList, _, _, _), map_set__foldl((pred(__K::in, H::in, InHeap::in, OutHeap::out) is det:- OutHeap = binheap__insert(InHeap, invscore(value(H)), H) ), InList, binheap__init, Heap), OutList = sorted_values(Heap). foldl(Pred, InStack, InAku, OutAku) :- list__foldl(Pred, sorted_elems(InStack), InAku, OutAku). foldl2(Pred, InStack, !Aku, !IO) :- list__foldl2(Pred, sorted_elems(InStack), !Aku, !IO). fold(Pred, InStack, InAku, OutAku) :- InStack = blank(_, _), InAku = OutAku ; InStack = stack(_, Elems, _, _, _), list__foldl(Pred, unsorted_elems(InStack), InAku, OutAku). :- import_module solutions. aggregate_to_stack(Limit, Generator, prune(Stack)) :- unsorted_aggregate(Generator, insert_prune, new(Limit), Stack).