staload _ = "prelude/DATS/array.dats"

// LEON(n,x) is a proof that the n-th Leonardo number equals x, following
//   L(0) = L(1) = 1   and   L(n) = L(n-2) + L(n-1) + 1   for n > 1.
// LEONbase covers n <= 1; LEONind combines proofs for n-2 and n-1 (both of
// whose values must be positive) into a proof for n.
dataprop LEON(int, int) = //LEON(n,x) means L[n] = x
| {n:nat | n <= 1} LEONbase(n, 1)
| {n:nat | n > 1; a,b:nat | a > 0; b > 0} LEONind(n, a+b+1) of (LEON(n-2, a), LEON(n-1,b))

// Lemma: every Leonardo number is strictly positive.
// The inductive case yields a+b+1 with a,b natural; the base case yields 1.
prfn leon_positive {n,x:nat} (pf : LEON(n,x)): [x > 0] void =
  case+ pf of
  | LEONind(pf_lo, pf_hi) => ()
  | LEONbase () => ()
    
// Lemma: LEON is functional -- two proofs for the same index n carry the same
// value. Proved by induction on n (termination metric n). Only the matching
// constructor pairs are listed; mixed pairs are excluded by the constructors'
// contradictory guards (n <= 1 versus n > 1).
prfun leon_isfun {n,x1,x2:nat} .<n>. (pf1 : LEON(n,x1), pf2 : LEON(n,x2)): [x1==x2] void =
  case+ (pf1, pf2) of
  | (LEONind(lo1, hi1), LEONind(lo2, hi2)) => let
      prval () = leon_isfun(hi1, hi2) // the values at n-1 agree
      prval () = leon_isfun(lo1, lo2) // the values at n-2 agree
    in () end
  | (LEONbase (), LEONbase ()) => ()

// Lemma: for k1 > 0, L(k1) < L(k1+1) - 1.
// L(k1+1) must come from LEONind (its index exceeds 1), i.e. it is
// a + b + 1 with a > 0 and b the value at k1; leon_isfun identifies b with sz1.
prfn leon_mono {k1,sz1,sz2:nat | k1 > 0}
  (pf1 : LEON(k1, sz1), pf2 : LEON(k1 + 1,sz2)): [sz1 < sz2-1] void =
  case+ pf2 of
  | LEONind(pf_below, pf_same) => let
      prval () = leon_isfun(pf1, pf_same)
    in () end

// Lemma: for k < 2 the Leonardo number is exactly 1.
// With k < 2 only the LEONbase constructor can have produced the proof.
prfn leon_base_is1 {k,sz:nat | k < 2} (pf : LEON(k,sz)): [sz==1] void = let
  prval LEONbase () = pf
in () end

// Lemma: given L(k) = sz and L(k-1) = sz1 with k > 1, produce the proof that
// L(k-2) = sz - 1 - sz1. The proof for k is taken apart into its two
// components; the k-1 component is identified with pfsz1 through leon_isfun,
// which pins down the value carried by the k-2 component.
prfn leon_size_dif {k,sz, sz1, sz2 : nat |  k > 1; sz2 == sz - 1 - sz1}
  (pfsz : LEON(k,sz), pfsz1 : LEON(k-1,sz1)) : LEON(k-2, sz2) = let
    prval LEONind(pf_km2, pf_km1) = pfsz
    prval () = leon_isfun(pf_km1, pfsz1)
  in pf_km2 end

// L(n): a positive integer x paired with the proof that x is the n-th
// Leonardo number.
typedef L(n:int) = [x:nat | x > 0] (LEON(n,x) | int x)
// L0: an L(n) for some natural n that is not tracked in the type.
typedef L0 = [n:nat] L(n)

// Cache interface with precise types: the entry for index n is an L(n).
// Both functions are bound to external names ("leon_cache_get"/"leon_cache_set").
// NOTE(review): the static variable x of leon_cache_set does not occur in its type.
extern fun leon_cache_get {n:nat} (n: int n): Option(L(n)) = "leon_cache_get"
extern fun leon_cache_set {n:nat; x:int} (n: int n, v: L(n)): void = "leon_cache_set"

// The same two external names declared a second time, now with the
// index-erased element type L0. These *0 variants are the ones given an
// implementation in this file; the L(n)-typed declarations above presumably
// resolve to that same code through the shared external name -- confirm.
// NOTE(review): the static variable x of leon_cache_set0 does not occur in its type.
extern fun leon_cache_get0 {n:nat} (n: int n): Option(L0) = "leon_cache_get"
extern fun leon_cache_set0 {n,x:nat} (n: int n, v: L0): void = "leon_cache_set"

// Backing store of the Leonardo-number cache: LSIZE slots, all initially None.
#define LSIZE 43
val Lcache : array(Option(L0), LSIZE) = array_make_elt(LSIZE, None ()) 
// Lookup: indices at or beyond LSIZE always report a miss.
implement leon_cache_get0(n) = if n < LSIZE then Lcache[n] else None 
// Store: indices at or beyond LSIZE are ignored (nothing is cached for them).
implement leon_cache_set0(n, v) = if n < LSIZE then Lcache[n] := Some v else ()

// leon(n): computes the n-th Leonardo number together with its LEON proof.
// Indices below 2 are answered directly with 1. Larger indices are looked up
// in the cache first; on a miss the value is built from leon(n-2) and
// leon(n-1), offered to the cache, and returned.
extern fun leon {n:nat} (n: int n): L(n)
implement leon(n) =
  if n < 2 then (LEONbase | 1)
  else
    case+ leon_cache_get n of
    | None () => let
        val (pf_lo | lo) = leon(n-2)
        val (pf_hi | hi) = leon(n-1)
        val fresh = (LEONind(pf_lo, pf_hi) | lo + hi + 1)
        val () = leon_cache_set(n, fresh)
      in
        fresh
      end
    | Some cached => cached

// LH(r,sz) records the arithmetic facts relating a heap's root index r to its
// size sz. The single constructor LHpf can only be formed when the guards
// below hold, so holding an LH proof makes them available again (see lh_use).
dataprop LH(int, int) = //LH(r,sz) means r >= sz - 1, r>=0, sz>=1
  {r,sz:nat | sz >= 1; r >= sz - 1} LHpf(r,sz)

// lh_use: brings the facts stored in an LH(r,sz) proof (r >= sz-1 and
// sz >= 1) into the constraint context of the caller by matching on the only
// constructor, LHpf.
// Uses case+ like the other proof functions in this file, so that the
// exhaustiveness of the match is enforced by the type-checker.
prfn lh_use {r,sz:int} (pf : LH(r, sz)): [r >= sz-1; sz >= 1] void =
  case+ pf of LHpf () => ()

// LTE(i,j) means A[i] <= A[j]
// LTEcompared has no premises: the code that constructs it is responsible for
// having actually compared (or ordered) the two elements.
// LTErefl is the reflexive case A[k] <= A[k].
dataprop LTE(int, int) = 
  | {i,j:nat} LTEcompared(i,j)
  | {k:nat} LTErefl(k,k)
  
// Transitivity of LTE, assumed as an axiom (praxi): it is declared, not proved.
extern praxi lte_trans {a,b,c:nat} (ab : LTE(a,b), bc : LTE(b,c)): LTE(a,c)

// MAXRIGHT(i,j) means A[j] is the max of A[i..j]
// MRsingle: a one-element range is trivially dominated by its only element.
// MRgrow_l: extends a proof for A[i..j] one position to the left, given that
//           the new leftmost element A[i-1] is <= A[j].
dataprop MAXRIGHT(int, int) =
  | {k:nat} MRsingle(k,k)
  | {i,j:nat | i <= j; i > 0} MRgrow_l(i-1,j) of (LTE(i-1,j), MAXRIGHT(i,j))

// Projects a single comparison out of a MAXRIGHT proof: if A[j] is the max of
// A[i..j] and i <= k <= j, then A[k] <= A[j].
// The proof is walked from its left end (index i) towards k; the int
// arguments carry the same indices as the static variables. Terminates
// because the range length j-i shrinks on every recursive step.
prfun lte_from_maxright {i,j,k:nat | i <= k; k <= j} .<j-i>.
  (mr: MAXRIGHT(i,j), k: int k, i: int i, j: int j): LTE(k,j) =
    case+ mr of
    | MRgrow_l(lte_head, mr_tail) =>
        // lte_head speaks about index i itself, mr_tail about i+1..j
        if k > i then lte_from_maxright(mr_tail, k, i+1, j) else lte_head
    | MRsingle () => LTErefl

// Joins two adjacent MAXRIGHT ranges: given that A[b] is the max of A[a..b],
// that A[c] is the max of A[b+1..c], and that A[b] <= A[c], concludes that
// A[c] is the max of A[a..c].
// The right-hand proof is extended leftwards one index at a time, starting at
// b+1 and stopping at a; each added index lo-1 is bounded by A[b] (taken from
// ab) and hence by A[c] (transitivity with lte_bc).
prfn mr_join {a,b,c:nat | a <= b; b < c}
  (ab: MAXRIGHT(a,b), bc: MAXRIGHT(b+1, c), lte_bc: LTE(b,c),
    a: int a, b: int b): MAXRIGHT(a,c) = let
  prfun extend {lo:nat | lo >= a; lo <= b+1} .<lo-a>.
    (lo: int lo, acc: MAXRIGHT(lo,c)): MAXRIGHT(a,c) =
    if lo > a then let
      prval lte_to_b = lte_from_maxright(ab, lo-1, a, b)
      prval lte_to_c = lte_trans(lte_to_b, lte_bc)
    in
      extend(lo-1, MRgrow_l(lte_to_c, acc))
    end else acc
in
  extend(b+1, bc)
end

// Extends a MAXRIGHT range one position to the right: if A[b] is the max of
// A[a..b] and A[b] <= A[b+1], then A[b+1] is the max of A[a..b+1].
// Implemented as mr_join with the one-element range [b+1..b+1].
prfn mr_grow_r {a,b:nat | a <= b}
  (ab: MAXRIGHT(a,b), pf_lte: LTE(b,b+1), a: int a, b: int b): MAXRIGHT(a,b+1) = let
  prval last_only : MAXRIGHT(b+1,b+1) = MRsingle
in
  mr_join(ab, last_only, pf_lte, a, b)
end

//SORTED(i,j) means A[i..j] is sorted
// SORTEDsingle: a one-element range is sorted.
// SORTEDjoin:   prepends index i-1 to a sorted A[i..j], given A[i-1] <= A[i].
dataprop SORTED(int, int) =
  | {k:nat} SORTEDsingle(k,k)
  | {i,j:nat | i > 0; i <= j} SORTEDjoin(i-1, j) of (LTE(i-1,i), SORTED(i,j))

// a heap is either a single element or contains proper sub-heaps
// GOODCHILDREN(k) is indexed by the heap order k:
//   GCsmall: k is of sort two (order 0 or 1), no premises;
//   GCbig:   k > 1, built from MAXRIGHT proofs for two adjacent ranges
//            [a..b] and [b+1..c] (the two sub-heaps).
dataprop GOODCHILDREN(int) =
  | {k:two} GCsmall(k)
  | {k:nat | k > 1; a,b,c:nat} GCbig(k) of (MAXRIGHT(a,b), MAXRIGHT(b+1,c))

// Axiom (praxi): from GOODCHILDREN(k) with k > 1, obtain MAXRIGHT proofs for
// two ranges [a..b] and [c..d].
// NOTE(review): a, b, c, d are universally quantified and not tied to k or to
// any heap position, so the caller's expected types decide which ranges are
// claimed -- soundness rests entirely on the call sites.
extern praxi good_children_mr {a,b,c,d:nat} {k:nat | k > 1} 
  (pf : GOODCHILDREN(k)): (MAXRIGHT(a,b), MAXRIGHT(c,d))

// Axiom (praxi): from GOODCHILDREN(k) with k > 1, obtain GOODCHILDREN(k-d)
// for a chosen d < 3 (the file instantiates d with 1 and 2 for the two children).
extern praxi good_child {d:nat | d < 3} {k:nat | k > 1} (pf : GOODCHILDREN(k)): GOODCHILDREN(k-d)

// Comparator type used throughout the sort: gt(x, y) is true when x > y.
typedef gt (a:t@ype) = (a,a) -> bool // x > y

// Puts A[i] and A[j] (i <= j) in order: when A[i] > A[j] according to gt the
// two elements are exchanged. Either way A[i] <= A[j] holds afterwards, which
// is what the returned LTE(i,j) proof records.
// Returns true if a swap occurred, false otherwise.
fn {a:t@ype} order_elements {n,i,j:nat | i < n; j < n; i <= j}
  (A: array(a,n), i : int i, j : int j, gt : gt(a)): (LTE(i,j) | bool) = let
    val left = A[i]
    val right = A[j]
  in
    if gt(left, right) then (
      A[i] := right;
      A[j] := left;
      (LTEcompared | true)
    ) else (LTEcompared | false)
  end
    
// Result of comparing A[i] with A[j]; each constructor carries the LTE proof
// matching the outcome plus an int tag:
//   LeftGr(i,j):  carries LTE(j,i), i.e. A[j] <= A[i];
//   RightGr(i,j): carries LTE(i,j), i.e. A[i] <= A[j].
datatype compare_res(int,int) =    
  | {i,j:nat} LeftGr(i,j) of (LTE(j,i) | int)
  | {i,j:nat} RightGr(i,j) of (LTE(i,j) | int)
  
// Compares A[i] with A[j] without modifying the array.
// Returns LeftGr (tag 0) when A[i] > A[j] according to gt, otherwise
// RightGr (tag 1); the LTE proof inside reflects the outcome.
fn {a:t@ype} compare_elements {n,i,j:nat | j < n; i <= j}
  (A: array(a,n), i : int i, j : int j, gt : gt(a)) : compare_res(i,j) = let
    val left_is_greater = gt(A[i], A[j])
  in
    if left_is_greater then LeftGr (LTEcompared | 0)
    else RightGr (LTEcompared | 1)
  end

// heap(r,k,sz): descriptor of one Leonardo heap inside the array.
//   root  - index r of the heap's root element
//   k     - order of the heap
//   sz    - number of elements, with pf_sz proving sz = L(k)
//   pf_r  - LH(r,sz): r >= sz-1 and sz >= 1
//   pf_gc - GOODCHILDREN(k) proof for the sub-heaps
typedef heap(r:int, k:int, sz:int) = 
  @{root = int r, k = int k, sz = int sz, pf_sz = LEON(k,sz), pf_r = LH(r,sz), pf_gc = GOODCHILDREN(k)}

// heap1: a heap plus the proof that its root A[r] is the max of the heap's
// own span A[r-sz+1..r].
typedef heap1(r:int, k:int, sz:int) = 
  @{ hp = heap(r,k,sz),  pf_mr = MAXRIGHT(r-sz+1, r) }

// heap2: a heap1 plus the proof that its root A[r] is the max of the whole
// prefix A[0..r].
typedef heap2(r:int, k:int, sz:int) = 
  @{ hp = heap(r,k,sz),  pf_mr = MAXRIGHT(r-sz+1, r), pf_totalmr = MAXRIGHT(0,r) }

// heaps(m): a list of heap2 values, indexed by the number m of array elements
// it accounts for. The head is the rightmost heap: a heap rooted at r with
// size sz is consed onto a list indexed by m only when m + sz - 1 == r, and
// the result is indexed by r+1.
datatype heaps(int) = 
  | heaps_nil (0)
  | {m,r,k,sz:nat | m + sz - 1 == r} heaps_cons (r+1) of (heap2(r,k,sz), heaps(m))

// List-like shorthands for the two constructors.
#define :: heaps_cons
#define nil heaps_nil

// Outcome of ordering a heap's root against an element to its right:
//   Swap   - the elements were exchanged; carries the plain heap descriptor
//            of the heap whose root changed;
//   Noswap - nothing was exchanged.
datatype order_res(r1:int,k1:int,sz1:int) =
  | Swap(r1,k1,sz1) of heap(r1,k1,sz1)
  | Noswap(r1,k1,sz1) 
  
// smoothsort: sorts A[0..n-1] in place with comparator gt and returns a proof
// term of type SORTED(0,n-1).
// The body (at the bottom) works in two phases: grow_loop builds a list of
// heaps covering one more array element per step, then shrink/shrink_loop
// remove the rightmost element one at a time while extending the SORTED proof
// leftwards. All helpers below are local; those marked <cloref1> are closures
// that use A and gt from this scope.
fn {a:t@ype} smoothsort {n:nat | n > 0} 
 (A : array(a,n), n : int n, gt : gt(a)): (SORTED(0,n-1) | void) = let
  // Head (rightmost heap) of a non-empty heap list; m > 0 excludes nil.
  fn top {m:int | m > 0}(hs : heaps(m)):<> [r,k,sz:nat | r + 1 == m] heap2(r,k,sz) =  
    case+ hs of h :: rest => h

  // One-element heap rooted at r: order 0 when prev_k is 1, order 1 otherwise.
  fn small_heap {r:nat} (r:int r, prev_k : int):<> [k:two] heap(r,k,1) =
    if prev_k = 1 then #[.. | @{root=r, k=0, sz=1, pf_sz=LEONbase, pf_r = LHpf, pf_gc = GCsmall}]
    else #[.. | @{root=r, k=1, sz=1, pf_sz=LEONbase, pf_r = LHpf, pf_gc = GCsmall}]

  // Forgets the whole-prefix proof pf_totalmr, keeping heap and pf_mr.
  fn heap2to1 {r,k,sz:nat} (h : heap2(r,k,sz)):<> heap1(r,k,sz) =
    @{hp = h.hp, pf_mr = h.pf_mr}
  
  // Splits a heap of order k > 1 into its two children:
  //   left  child: order k-1, size L(k-1) obtained from leon, root at
  //                h.root - h.sz + child_size1;
  //   right child: order k-2, size h.sz - 1 - child_size1, root at h.root - 1.
  // The children's MAXRIGHT proofs are taken from the good_children_mr axiom.
  fn split_heap {r,k,sz:nat | k > 1} (h: heap(r,k,sz)):
      [cr1,csz1, cr2,csz2:nat | cr1 - csz1 == r - sz; cr2 == r-1; cr2 == cr1 + csz2; csz1 > 0; csz2 > 0 ] 
      (heap1(cr1, k-1, csz1), heap1(cr2, k-2, csz2)) = let
    val (pfsz1 | child_size1) = leon(h.k - 1)
    prval () = lh_use h.pf_r
    prval () = leon_mono(pfsz1, h.pf_sz)    
    
    val child_size2 = h.sz - 1 - child_size1
    prval pfsz2 = leon_size_dif(h.pf_sz, pfsz1)
    
    prval pfgc1 = good_child {1} h.pf_gc
    prval pfgc2 = good_child {2} h.pf_gc    
    
    val left_h =  @{root = h.root - h.sz + child_size1, k = h.k - 1, sz = child_size1, 
              pf_sz = pfsz1, pf_r = LHpf, pf_gc = pfgc1}
    
    val right_h = @{root = h.root - 1, k = h.k - 2, sz = child_size2,
              pf_sz = pfsz2, pf_r = LHpf, pf_gc = pfgc2}   
    
    prval (pfmr1, pfmr2) = good_children_mr h.pf_gc
  in #[.. | (@{hp=left_h, pf_mr=pfmr1}, @{hp=right_h, pf_mr=pfmr2})]
  end
  
  // Orders the root of heap lh against the element at index rroot (to its
  // right) using order_elements. Reports Swap with lh's descriptor when the
  // elements were exchanged, Noswap otherwise.
  fn order_roots {r1,k1,sz1, r2:nat | r1 < r2; r2 < n}
    (lh : heap2(r1,k1,sz1), rroot : int r2):<cloref1> (LTE(r1,r2) | order_res(r1,k1,sz1)) = let
      val (pf_lte | swapped) = order_elements(A, lh.hp.root, rroot, gt)
  in 
    if swapped then (pf_lte | Swap(lh.hp))
    else (pf_lte | Noswap)
  end
  
  // Re-establishes "root is the max of the heap's span" for heap h.
  // Orders below 2 are single elements (leon_base_is1) and need no work;
  // larger heaps are split and handed to restore_big_heap_prop.
  fun restore_heap_prop {r,k,sz:nat | r < n} (h: heap(r,k,sz)):<cloref1> heap1(r,k,sz) = 
    if h.k < 2 then 
      let prval () = leon_base_is1 h.pf_sz in @{hp = h, pf_mr = MRsingle} end
    else let
      val (left, right) = split_heap h in
      restore_big_heap_prop(left, right, h)
    end
  
  // Same as restore_heap_prop for a heap h of order k > 1 whose children
  // left/right are already split out. The child with the greater root is
  // ordered against h's root; if that swapped elements, the child is restored
  // recursively. The MAXRIGHT proof for all of h is assembled with
  // mr_join / mr_grow_r.
  and restore_big_heap_prop  {r,k,sz, cr1,csz1, cr2,csz2:nat | r < n; k > 1; 
    cr1 - csz1 == r - sz; cr2 == r-1; cr2 == cr1 + csz2} 
    (left: heap1(cr1, k-1, csz1), right: heap1(cr2, k-2, csz2), h: heap(r,k,sz)):<cloref1> heap1(r,k,sz) = let       
      val start = left.hp.root - left.hp.sz + 1
      prval () = lh_use left.hp.pf_r
      prval () = leon_positive right.hp.pf_sz 
    in
      case+ compare_elements(A, left.hp.root, right.hp.root, gt) of
      // left child's root is greater than the right child's root
      | LeftGr(pf_lte_rl | _) => let
          val (pf_lte_lh | swapped) = order_elements(A, left.hp.root, h.root, gt)          
          val rstart = right.hp.root - right.hp.sz + 1
          prval pf_lte_rh = lte_trans(pf_lte_rl, pf_lte_lh)          
          prval pf_mrr = mr_grow_r(right.pf_mr, pf_lte_rh, rstart, right.hp.root)  
        in
          if swapped then let
              val left1 = restore_heap_prop left.hp
              prval pf_mrheap = mr_join(left1.pf_mr, pf_mrr, pf_lte_lh, start, left1.hp.root)
            in @{hp = h, pf_mr = pf_mrheap}
          end else let
              prval pf_mrheap = mr_join(left.pf_mr, pf_mrr, pf_lte_lh, start, left.hp.root)
            in @{hp = h, pf_mr = pf_mrheap}
          end
        end
      // right child's root is at least as large as the left child's root
      | RightGr(pf_lte_lr | _) => let
          val (pf_lte_rh | swapped) = order_elements(A, right.hp.root, h.root, gt)
        in        
          if swapped then let
              val right1 = restore_heap_prop right.hp 
              prval pf_mr_leftright = mr_join(left.pf_mr, right1.pf_mr, pf_lte_lr, start, left.hp.root)
              prval pf_mrheap = mr_grow_r( pf_mr_leftright, pf_lte_rh, start, right1.hp.root)
            in @{hp = h, pf_mr = pf_mrheap} 
          end else let
              prval pf_mr_leftright = mr_join(left.pf_mr, right.pf_mr, pf_lte_lr, start, left.hp.root)
              prval pf_mrheap = mr_grow_r( pf_mr_leftright, pf_lte_rh, start, right.hp.root)
            in @{hp = h, pf_mr = pf_mrheap}
          end          
        end      
    end // end of restore_big_heap_prop

  // Places heap h immediately to the right of the heap list hs and restores
  // the list's properties, producing a list that covers indices 0..r.
  // With an empty list only h itself is restored; otherwise the work is
  // dispatched on h's order to the _small (order < 2) or _big variant.
  fun restore_heapstring_prop {m,r,k,sz:nat | m + sz - 1 == r; r < n} 
    (hs : heaps(m), h:heap(r,k,sz)):<cloref1> heaps(r+1) =
    case+ hs of
    | nil () => let val h1 = restore_heap_prop h in
        @{hp = h1.hp, pf_mr = h1.pf_mr, pf_totalmr = h1.pf_mr} :: nil 
      end
    | ph :: rest =>
      if h.k < 2 then let        
        prval () = leon_base_is1 h.pf_sz in //prove that h.root = ph.root+1
        restore_heapstring_prop_small(rest, ph, h)
      end else let
        val (left, right) = split_heap h in
        restore_heapstring_prop_big(rest, ph, left, right, h)
      end
   
  // Variant for a one-element heap `small` sitting directly after heap h.
  // h's root is ordered against the new element; on a swap, h (whose root
  // changed) is re-inserted into hs recursively before `small` is pushed.
  and restore_heapstring_prop_small {m,r,k,sz,r1,k1,sz1:nat | m + sz - 1 == r; r1 < n; r1 == r+1; sz1 == 1}
        (hs: heaps(m), h: heap2(r,k,sz), small: heap(r1,k1,sz1)):<cloref1> heaps(r+2) = let 
    val (pf_lte | orderres) = order_roots(h, small.root) 
  in
    case+ orderres of
    | Noswap () => @{hp = small, pf_mr = MRsingle, 
        pf_totalmr = mr_grow_r(h.pf_totalmr, pf_lte, 0, h.hp.root)} :: h :: hs
    | Swap(hp) => let
        val hsr = restore_heapstring_prop(hs, hp)
        val htop = top hsr
      in @{hp = small, pf_mr = MRsingle, 
           pf_totalmr = mr_grow_r(htop.pf_totalmr, pf_lte, 0, htop.hp.root)} :: hsr
      end  
  end
  
  // Variant for a heap h of order > 1 (children left/right) sitting directly
  // after heap prev. The root of prev is compared with both child roots:
  //  - if it is greater than both, it is ordered against h's root, and on a
  //    swap prev is re-inserted into hs recursively;
  //  - otherwise only h itself is restored with restore_big_heap_prop.
  and restore_heapstring_prop_big {m,phr,phk,phsz,lr,lsz,rr,rsz,hr,hk,hsz:nat | 
        m + phsz - 1 == phr; hr == phr + hsz;  hsz == lsz + rsz + 1; rr + 1 == hr; rr == lr + rsz; 
        lsz > 0; rsz > 0; hr < n;  hk > 1} 
        (hs: heaps(m), prev: heap2(phr,phk,phsz), 
         left: heap1(lr,hk-1,lsz), right: heap1(rr,hk-2,rsz), h: heap(hr,hk,hsz)):<cloref1>  heaps(hr+1) = let  
    val cmp_prev_left = compare_elements(A, prev.hp.root, left.hp.root, gt)
    val cmp_prev_right = compare_elements(A, prev.hp.root, right.hp.root, gt)
  in  
    case (cmp_prev_left, cmp_prev_right) of
    | (LeftGr(lte_left_prev |_), LeftGr(lte_right_prev |_)) => let
        val (lte_prevbig | orderres) = order_roots(prev, h.root) 
        prval lte_leftbig = lte_trans(lte_left_prev, lte_prevbig)
        prval lte_rightbig = lte_trans(lte_right_prev, lte_prevbig)        
        prval pf_mrr = mr_grow_r(right.pf_mr, lte_rightbig, right.hp.root-right.hp.sz+1, right.hp.root)
        prval pf_mrbig = mr_join(left.pf_mr, pf_mrr, lte_leftbig, left.hp.root-left.hp.sz+1, left.hp.root)
        in
        case+ orderres of
        | Noswap () => let //previous root is bigger than both children but not the new root
            prval pf_mrtotal = mr_join(prev.pf_totalmr, pf_mrbig, lte_prevbig, 0, prev.hp.root)
            in @{hp = h, pf_mr = pf_mrbig, pf_totalmr = pf_mrtotal} :: prev :: hs
            end              
        | Swap(hp) => let //previous root was bigger than all others, hp is now prev as heap
            val hsr = restore_heapstring_prop(hs, hp)
            val prev_r = top hsr // prev_r is prev with restored properties
            prval pf_mrtotal = mr_join(prev_r.pf_totalmr, pf_mrbig, lte_prevbig, 0, prev_r.hp.root)
            in @{hp = h, pf_mr = pf_mrbig, pf_totalmr = pf_mrtotal} :: hsr
            end
        end
    // one of children is bigger than previous root, just restore the heap property    
    | (_, RightGr(lte_prev_right |_)) => let 
          val h1 = restore_big_heap_prop(left, right, h)
          prval lte_right_big = lte_from_maxright(h1.pf_mr, right.hp.root, h1.hp.root-h1.hp.sz+1, h1.hp.root)
          prval lte_prevbig = lte_trans(lte_prev_right, lte_right_big)
          prval pf_mrtotal = mr_join(prev.pf_totalmr, h1.pf_mr, lte_prevbig, 0, prev.hp.root)
        in @{hp = h, pf_mr = h1.pf_mr, pf_totalmr = pf_mrtotal} :: prev :: hs
        end
    | (RightGr(lte_prev_left |_), _) => let 
          val h1 = restore_big_heap_prop(left, right, h)
          prval lte_left_big = lte_from_maxright(h1.pf_mr, left.hp.root, h1.hp.root-h1.hp.sz+1, h1.hp.root)
          prval lte_prevbig = lte_trans(lte_prev_left, lte_left_big)
          prval pf_mrtotal = mr_join(prev.pf_totalmr, h1.pf_mr, lte_prevbig, 0, prev.hp.root)
        in @{hp = h, pf_mr = h1.pf_mr, pf_totalmr = pf_mrtotal} :: prev :: hs
        end
  end // of restore_heapstring_prop_big

  // Extends a heap list covering m elements so that it covers m+1:
  //  - empty list: a single one-element heap at index 0;
  //  - one heap: a one-element heap is appended after it;
  //  - two or more: when the top heap's order plus one equals the order of
  //    the heap below it, both are merged under a new root at h0.root + 1
  //    into a heap of order h1.k + 1; otherwise a one-element heap is appended.
  fun grow {m:nat; m < n} (hs : heaps(m)):<cloref1> heaps(m+1) =
    case+ hs of
    | nil () => @{hp = small_heap(0, ~1), pf_mr = MRsingle, pf_totalmr = MRsingle} :: nil        
    | h :: nil () => let 
        val small = small_heap(h.hp.root + 1, h.hp.k)   
      in restore_heapstring_prop_small(nil, h, small)
      end      
    | h0 :: (h0rest as h1 :: rest) => 
        if h0.hp.k + 1 = h1.hp.k then let //join two top heaps into a bigger one
          prval () = lh_use h1.hp.pf_r 
          prval () = leon_positive h0.hp.pf_sz  //prove that child sizes > 0 to use LEONind
          prval pfgc = GCbig(h1.pf_mr, h0.pf_mr)
          val bighp = @{root = h0.hp.root + 1, k = h1.hp.k + 1, sz = h1.hp.sz + h0.hp.sz + 1, 
              pf_sz = LEONind(h0.hp.pf_sz, h1.hp.pf_sz), pf_r = LHpf, pf_gc = pfgc }
          in
          case+ rest of
          | prev :: hps => restore_heapstring_prop_big(hps, prev, heap2to1 h1, heap2to1 h0, bighp)
          | nil () => let 
              val bighp1 = restore_big_heap_prop (heap2to1 h1, heap2to1 h0, bighp) 
            in @{hp = bighp, pf_mr = bighp1.pf_mr, pf_totalmr = bighp1.pf_mr} :: nil
            end          
        end else let //add a small heap
          val small = small_heap(h0.hp.root + 1, h0.hp.k)
          in restore_heapstring_prop_small(h0rest, h0, small)
        end
  
  // Applies grow repeatedly, starting from a list covering i elements, until
  // all n elements are covered.
  fun grow_loop {i:nat | i <= n} (i : int i, hs : heaps(i)) :<cloref1> heaps(n) =
    if i = n then hs else grow_loop(i+1, grow hs)    
  
  // Removes the rightmost element (root of the top heap) from the list.
  // A top heap of order < 2 is a single element and is simply dropped;
  // otherwise its two children are re-inserted, left child first.
  fun shrink {m:nat | m > 0; m <= n} (hs : heaps(m)) :<cloref1> heaps(m-1) =
    case+ hs of h :: rest => 
      if h.hp.k < 2 then let prval () = leon_base_is1 h.hp.pf_sz in rest end
      else let   
        val (left,right) = split_heap h.hp
        val hs1 = restore_heapstring_prop(rest, left.hp)          
        in        restore_heapstring_prop(hs1, right.hp)
      end

  // Second phase. Invariant: hs covers A[0..i-1] and A[i-1..n-1] is sorted.
  // Each step derives A[i-2] <= A[i-1] from the top heap's pf_totalmr,
  // extends the SORTED proof by one index to the left and shrinks the list.
  // Stops at i = 1, where the proof already spans 0..n-1.
  fun shrink_loop{i:nat | i < n; i > 0} 
    (pf_sorted : SORTED(i-1,n-1) | i : int i, hs : heaps(i)) :<cloref1> (SORTED(0,n-1) | void) =
    if i = 1 then (pf_sorted | ()) 
    else let
      val htop = top hs
      prval pf_lte = lte_from_maxright(htop.pf_totalmr, i-2, 0, i-1)
    in     
      shrink_loop(SORTEDjoin(pf_lte, pf_sorted) | i-1, shrink hs)
    end  
in
  // A single element is already sorted. For n > 1: build the heap list over
  // the whole array, seed the SORTED proof with the last two positions,
  // remove the last element once, then let shrink_loop finish.
  if n > 1 then let
    val hps = grow_loop(0, nil) //hps : heaps(n)
    val htop = top hps
    prval pf_lte21 = lte_from_maxright(htop.pf_totalmr, n-2, 0, n-1) // LTE(n-2, n-1)
    prval pf_sorted0 : SORTED(n-1, n-1) = SORTEDsingle
    prval pf_sorted1 : SORTED(n-2, n-1) = SORTEDjoin(pf_lte21, pf_sorted0)
    val hps1 = shrink hps
    in shrink_loop(pf_sorted1 | n-1, hps1)
  end else (SORTEDsingle | ())
end // of smoothsort

// Prints the sz elements of A on one line, separated by ", " and followed by
// a newline; each element is rendered by the caller-supplied printer pr.
fun{a:t@ype} prarr {n:nat}
  (pr: a -> void, A: array (a, n), sz: int n): void = let
  // emit(len, k): prints elements k .. len-1
  fun emit {k:nat | k <= n} (len: int n, k: int k):<cloptr1> void =
    if k >= len then ()
    else (
      if k > 0 then print ", "; // separator before every element but the first
      pr A[k];
      emit (len, k+1)
    )
in
  emit (sz, 0); print_newline ()
end 

// Element printer for int arrays, passed to prarr in main.
fn pr_int (x: int): void = print x

// Demo entry point: prints a fixed 10-element int array, sorts it with
// smoothsort (using gt_int_int as the comparator) and prints it again.
implement main () = let
  val A = array $arrsz{int}( 1, 10, 3, 2, 9, 7, 4, 5, 8, 6)
  val () = prarr(pr_int, A, 10)
  // the SORTED proof is static evidence only; it is not used further
  val (pf | ()) = smoothsort(A, 10, gt_int_int)
  val () = prarr(pr_int, A, 10)
in
  ()
end