open ExtLib
open Commons
open Ast

(* Counter behind [uid]: the last identifier handed out (0 before the first call). *)
let last_uid = ref 0

(* Return a fresh identifier: 1, 2, 3, ... on successive calls. *)
let uid () =
	last_uid := !last_uid + 1;
	!last_uid

(* Fresh name "<s>_<n>" built from the next [uid]; used for generated labels
   and variables. *)
let mkname s = s ^ "_" ^ string_of_int (uid ())

(* Number of cells an expanded instruction occupies; must agree with the
   lengths of the word arrays inst_c produces.  Labels, jumps and scope
   markers take no space.  Any other (unexpanded) instruction is an error. *)
let inssize ins =
	match ins with
	| Label _ | Jmp _ | BeginScope _ | EndScope -> 0
	| PrChar _ -> 2
	| Arith (_, Not _) -> 3
	| Arith (_, Bin _) | Jmple _ -> 4
	| _ -> failwith (Printf.sprintf "error: getting size of '%s'" (inst_s ins))

(* Parameter names of every function expanded so far, keyed by function name.
   Filled by expand_fun and consulted by Call. *)
let fundefs : (string, string list) Hashtbl.t = Hashtbl.create 11

(* Order instructions by their integer sort keys (ascending) and drop the keys.
   Array.sort is not stable, so the relative order of equal keys is not
   preserved; Print/PrintOnce only rely on the strict key ordering they build. *)
let shuffle_pairs lst =
	let arr = Array.of_list lst in
	Array.sort (fun (k1, _) (k2, _) -> compare (k1 : int) k2) arr;
	Array.fold_right (fun (_, ins) acc -> ins :: acc) arr []

(* Lower high-level instructions into the primitive ones the encoder
   understands: Arith, Jmple, PrChar, Label, Jmp, BeginScope and EndScope.
   Any instruction not matched below is returned unchanged.  Fresh label and
   variable names come from mkname, and Print/PrintOnce also draw from Random,
   so the result depends on how many names/random numbers were used before. *)
let rec expand_inst = function
	(* Jump to the then-branch when Jmple(a, b) fires; otherwise fall through
	   into the else-branch, which then jumps past the then-branch. *)
	| IfLess(a, b, thencode, elsecode) ->
			let then_lab = mkname "then" and end_lab = mkname "endif" in
			Jmple(a, b, then_lab) 
			:: expand_code elsecode
			@ [Jmp end_lab; Label then_lab]
			@ expand_code thencode
			@ [Label end_lab]
	(* A loop is a label followed by an IfLess whose then-branch is the body
	   plus a jump back to the label, and whose else-branch is empty. *)
	| While(a, b, code) ->
			let wlab = mkname "while" in
			[Label wlab] @ expand_inst (IfLess(a, b, code @ [Jmp wlab], []))
	(* Function definitions open a named scope.  The boolean is the "dense"
	   flag that position_code uses to pack the body into consecutive cells. *)
	| Fun(nm, params, code) ->
			BeginScope(Some nm, false) :: expand_fun nm params code
	| DenseFun(nm, params, code) ->
			BeginScope(Some nm, true) :: expand_fun nm params code
	(* Patch the function's last instruction (the one after label fn_end, see
	   expand_fun) so that it continues at call site [cs].  The RetCode operand
	   is resolved by rarg_c from that instruction's opcode word and the
	   offset to [cs]. *)
	| Prepare(fn, cs) -> 
			let fend = fn ^ "_end" in		 	
			[Arith(Var fend, Bin(Add, Untitled2, RetCode(fend, cs)))]			
	(* Jump into the function and mark the point it returns to. *)
	| FastCall(fn, cs) -> [Jmp fn; Label cs]
	(* Assign each argument to the callee's "fn.param" variable, then patch the
	   return target and jump.  The callee must already have been expanded:
	   fundefs is filled by expand_fun only after the body is expanded, so
	   calls that appear earlier in the code (including recursive calls) fail
	   with "calling unknown function". *)
	| Call(fn, args) ->
			let params = try Hashtbl.find fundefs fn with Not_found -> failwith ("calling unknown function " ^ fn) in
			if List.length params <> List.length args then failwith ("bad number of params for " ^ fn) else
			let setargs = List.map2 (fun parname exp -> Arith(Var (fn ^ "." ^ parname), exp)) params args in
			let callsite = mkname "call" in
			expand_code (setargs @ [Prepare(fn, callsite); FastCall(fn, callsite)])
	(* For each character, emit two instructions:
	   - an Arith that computes the character code into a fresh variable;
	   - a PrChar that prints it.
	   The variable is declared by the PrChar's operand and starts with a
	   random junk value.  Untitled2 appears to evaluate to 2 here: both
	   2 * (cc/2) and 2 + (cc-2) give cc.  Confirm this against the VM.
	   Sort keys: the Arith for character j gets a random key below (j+1)*10,
	   and its PrChar gets exactly (j+1)*10.  After shuffle_pairs, every
	   character is therefore computed before it is printed, and the PrChars
	   stay in string order. *)
	| Print s -> String.explode s |> List.mapi (fun j c -> 
			let cc = int_of_char c in
			let exp = if cc mod 2 = 0 then Bin(Mul, Untitled2, Num (cc/2)) else Bin(Add, Untitled2, Num(cc-2)) in
			let var = mkname "cc" in  
			let i1 = Random.int ((j+1)*10) and i2 = (j+1)*10 in
			[i1, Arith(Var var, exp); i2, PrChar (Def(var, Random.int 456))]) |> List.concat	|> shuffle_pairs			
	(* Like Print, but picks one of several encodings at random.  Two of them
	   update the variable in place:
	   - addx: the variable starts at cc - a, and a is added to it;
	   - orx: it starts at cc land 0xA3 and is OR-ed with cc land 0x5D.
	   Re-running the sequence would therefore not print the same text,
	   presumably why it is "Once".
	   NOTE(review): Random.int cc raises Invalid_argument for a NUL
	   character (cc = 0). *)
	| PrintOnce s -> String.explode s |> List.mapi (fun j c -> 
			let cc = int_of_char c in
			let var = mkname "cc" in  
			let mul2 = if cc mod 2 = 0 then Some(Bin(Mul, Untitled2, Num (cc/2)), Random.int 432) else None in
			let add2 = Bin(Add, Untitled2, Num(cc-2)), Random.int 234 in
			let addx = let a = Random.int cc in Bin(Add, Var var, Num a), cc - a in
			let orx = Bin(Or, Var var, Num (cc land 0x5D)), (cc land 0xA3) in
			let arr = Option.map_default (fun x -> [x]) [] mul2 @ [add2; addx; orx] |> Array.of_list in
			let exp, v = arr.(Random.int (Array.length arr)) in	
			let i1 = Random.int ((j+1)*10) and i2 = (j+1)*10 in	
			[i1, Arith(Var var, exp); i2, PrChar (Def(var, v))]) |> List.concat |> shuffle_pairs
	| ins -> [ins]

(* Expand a function body and lay it out as
     Label nm; <body without its last sized instruction>;
     Label nm_end; <last sized instruction and any zero-size ones after it>;
     EndScope
   The matching BeginScope is emitted by expand_inst.  Label nm_end marks the
   instruction that Prepare patches to return to the caller.  The parameter
   list is registered in fundefs only after the body has been expanded.
   Fails with "empty function" if the body contains no sized instruction. *)
and expand_fun nm params code =
			(* Scans the reversed code.  Returns (tail, rest), where:
			   - tail is the last sized instruction followed by any zero-size
			     instructions after it, in program order;
			   - rest is everything before it, still reversed. *)
			let rec cutlast tail = function			
				| last :: prev ->
						if inssize last > 0 then (last::tail, prev) else cutlast (last::tail) prev
				| [] -> failwith ("empty function " ^ nm) in			
			let tail, rc = cutlast [] (List.rev (expand_code code)) in
			Hashtbl.add fundefs nm params;											  
			Label nm :: List.rev rc @ Label (nm ^ "_end") :: tail @ EndScope :: []	

(* Expand a list of instructions, concatenating the expansions in order. *)
and expand_code code = List.map expand_inst code |> List.flatten;;

module StrMap = Map.Make (String)

(* A lexical scope.
   - vars: variables declared directly in it (name -> absolute cell address).
   - subscopes: its child scopes.
   - ancest: the enclosing scopes, innermost first.
   - name: the function name, for function scopes. *)
type scope_t = {
	vars : int StrMap.t;
	subscopes : scope_t list;
	ancest : scope_t list;
	name : string option;
}

(* Name-resolution state.
   - labels: label name -> cell address.
   - scope: the current scope.
   - ip: the address of the instruction being processed. *)
type env_t = {
	labels : int StrMap.t;
	scope : scope_t;
	ip : int;
}

(* Bind [name] to address [addr] in the innermost scope of [env].  Fails if
   [name] is already bound in that same scope; shadowing a binding from an
   enclosing scope is allowed. *)
let addvar env name addr =
	let sc = env.scope in
	match StrMap.mem name sc.vars with
	| true -> failwith (name ^ " is already defined")
	| false -> { env with scope = { sc with vars = StrMap.add name addr sc.vars } }
	
(* Address of label [name]; fails with "unknown label <name>" if it was never
   defined. *)
let getlabel env name =
	if StrMap.mem name env.labels then StrMap.find name env.labels
	else failwith ("unknown label " ^ name)

(* Empty scope with ancestor chain [anc] (innermost first) and optional name [nm]. *)
let new_scope anc nm =
	{ name = nm; ancest = anc; subscopes = []; vars = StrMap.empty }

(* Reverse a scope's child list (children are accumulated newest-first). *)
let rev_scopes scp =
	let children = scp.subscopes in
	{ scp with subscopes = List.rev children }

(* Open an empty, optionally named child scope; the current scope becomes its
   immediate ancestor. *)
let push_new_scope env name_opt =
	let current = env.scope in
	{ env with scope = new_scope (current :: current.ancest) name_opt }

(* Finish the current scope during name learning.  Its own children are put
   back into program order, and it is prepended to its parent's children;
   the parent's list stays newest-first until the parent is closed itself.
   Fails at the root scope. *)
let close_scope env =
	match env.scope.ancest with
	| [] -> failwith "closing a root scope"
	| parent :: _ ->
			let finished = rev_scopes env.scope in
			{ env with scope = { parent with subscopes = finished :: parent.subscopes } }

(* Absolute address of variable [name].  Searches the current scope and then
   each enclosing scope outward.  Past the root it falls back to the labels,
   so a Var may also name a label (as Prepare does with "fn_end").  Fails with
   "unknown variable <name>". *)
let find_var env name =
	let rec lookup sc =
		try StrMap.find name sc.vars with Not_found ->
			(match sc.ancest with
			 | parent :: _ -> lookup parent
			 | [] ->
					 (try StrMap.find name env.labels
					  with Not_found -> failwith ("unknown variable " ^ name))) in
	lookup env.scope

(* Descend into the next not-yet-visited child of the current scope; children
   are consumed in order.  The child's ancestor chain is rebuilt from the
   current path, so leaving it returns to a parent that no longer lists the
   child.  Fails when no child is left. *)
let enter_scope env =
	let current = env.scope in
	match current.subscopes with
	| [] -> failwith "trying to enter an empty scope"
	| next :: remaining ->
			let parent = { current with subscopes = remaining } in
			{ env with scope = { next with ancest = parent :: parent.ancest } }

(* Return to the enclosing scope (head of the ancestor chain); fails at the root. *)
let leave_scope env =
	match env.scope.ancest with
	| [] -> failwith "leaving a root scope"
	| parent :: _ -> { env with scope = parent }

(* Copy the variables of every named (function) scope, at any depth, into
   [scope]'s own vars under the key "fname.var".  This is how Call's
   assignments to "fn.param" are resolved.  A nested named scope is
   qualified by its own name only, and on key collisions the later scope
   wins. *)
let publish_funvars scope =
	let qualify prefix vars acc =
		StrMap.fold (fun var addr acc -> StrMap.add (prefix ^ "." ^ var) addr acc) vars acc in
	let rec visit acc sc =
		List.fold_left
			(fun acc sub ->
				let acc = match sub.name with
					| None -> acc
					| Some nm -> qualify nm sub.vars acc in
				visit acc sub)
			acc sc.subscopes in
	{ scope with vars = visit scope.vars scope }

(* Cell-occupancy map read from the file "intsmap", one character per 4-byte
   cell, where '1' means free.  It is mutated in place as cells are claimed:
   '2' for code, '3' for data, '4' for random filler.  compile writes it
   out as "map1". *)
let freeints = Std.input_file "intsmap";;

(* spans.(i) is the number of consecutive free cells starting at cell i (at
   least 1 for a free cell), or -1 for an occupied cell.  Covers the first
   60000 cells and is computed right to left in a single pass. *)
let calc_spans () =
	let spans = Array.create 60000 (-1) in
	let run = ref 0 in
	for i = 59999 downto 0 do
		if freeints.[i] = '1' then begin
			incr run;
			spans.(i) <- !run
		end else begin
			run := 0;
			spans.(i) <- -1
		end
	done;
	spans
			
(* Assign a cell address to every expanded instruction.

   Returns (placed, lastpos, spans):
   - placed lists (ins, addr) in program order;
   - lastpos is the first address after the last placed instruction;
   - spans is the free-run table from calc_spans.

   Placement starts at cell 32.  Outside dense scopes, each sized instruction
   scans forward for a position with enough free cells.  Below cell 40 the
   first fitting position is taken; beyond it, each fitting position is
   accepted with probability 1/5.  Inside a BeginScope (_, true) scope,
   instructions are packed back-to-back, failing if they hit an occupied
   cell.  Finally, zero-size entries (labels, jumps, scope markers) are
   re-addressed to the next sized instruction, so a label names the
   instruction that follows it. *)
let position_code code =
	let spans = calc_spans () in
	let rec find_pos sz pos =
		if pos >= 60000 then failwith "failed to position code"
		else if spans.(pos) < sz then find_pos sz (pos + 1)
		else if Random.int 5 = 2 then pos
		else if pos < 40 then pos
		else find_pos sz (pos + 1) in
	let place (acc, pos, dense) ins =
		let dense =
			match ins with
			| BeginScope (_, d) -> d
			| EndScope -> false
			| _ -> dense in
		let sz = inssize ins in
		let at = if dense then pos else find_pos sz pos in
		if spans.(at) < sz then failwith "position code spans fail";
		((ins, at) :: acc, at + sz, dense) in
	let placed_rev, lastpos, _ = List.fold_left place ([], 32, false) code in
	(* Walk backwards carrying the address of the next sized instruction. *)
	let readdress (next, acc) (ins, at) =
		if inssize ins = 0 then (next, (ins, next) :: acc)
		else (at, (ins, at) :: acc) in
	let _, placed = List.fold_left readdress (lastpos, []) placed_rev in
	placed, lastpos, spans

(* Name-resolution pass over positioned code [(ins, addr)].

   Records the address of every label, and of every variable declared with
   Def/Def2, in a scope tree that mirrors the BeginScope/EndScope nesting.

   A declared variable lives in the operand cell that declares it.  Its
   address is the instruction's address plus the operand's word index:
   - 1 for the Arith destination and for the PrChar operand;
   - 2 and 3 for the two Bin operands and for the two Jmple arguments;
   - 2 for the Not operand.

   Afterwards the root scope's children are put back in program order, and
   the variables of named scopes are published at the root as "fname.var".
   Fails on duplicate labels, on duplicate variables within one scope, and
   on unexpanded instructions. *)
let rec learn_names pcode =
	let start = { labels = StrMap.empty; scope = new_scope [] None; ip = 32 } in
	let visit env (ins, pos) = learn_names_ins { env with ip = pos } ins in
	let final = List.fold_left visit start pcode in
	{ final with scope = publish_funvars (rev_scopes final.scope) }

(* Declare the variable introduced by a left operand at word [d] of the
   current instruction, if any. *)
and learn_names_larg d env larg =
	match larg with
	| Def2 s -> addvar env s (env.ip + d)
	| Var _ | Untitled2 -> env

(* Learn the names introduced by one instruction located at env.ip. *)
and learn_names_ins env ins =
	match ins with
	| Arith (dst, e) ->
			let env = learn_names_larg 1 env dst in
			learn_names_exp env e
	| Jmple (a1, a2, _) ->
			let env = learn_names_rarg 2 env a1 in
			learn_names_rarg 3 env a2
	| PrChar r -> learn_names_rarg 1 env r
	| Label s ->
			if StrMap.mem s env.labels then failwith (s ^ " label already exists")
			else { env with labels = StrMap.add s env.ip env.labels }
	| Jmp _ -> env
	| BeginScope (nm, _) -> push_new_scope env nm
	| EndScope -> close_scope env
	| _ -> failwith (Printf.sprintf "error: learn_names '%s'" (inst_s ins))

(* Operands of an Arith expression: Bin uses words 2 and 3, Not uses word 2. *)
and learn_names_exp env e =
	match e with
	| Bin (_, l, r) ->
			let env = learn_names_larg 2 env l in
			learn_names_rarg 3 env r
	| Not r -> learn_names_rarg 2 env r

(* Declare the variable introduced by a right operand at word [d], if any. *)
and learn_names_rarg d env r =
	match r with
	| Def (s, _) -> addvar env s (env.ip + d)
	| Num _ | RetCode _ -> env

(* Remove trivial jump chains before encoding.

   A Label lab immediately followed by Jmp trg makes lab an alias of trg.
   Every reference to lab is redirected to the final label of its alias
   chain, and the aliased Label is dropped.  References are jump targets,
   Jmple targets, and the labels inside RetCode operands of Arith (Bin or
   Not) and PrChar.  Of several consecutive jumps only the first is kept;
   the rest are unreachable.

   Fails with "infinite loop at <label>" when an alias chain loops back on
   itself, either directly (Label a; Jmp a) or through several labels
   (Label a; Jmp b; Label b; Jmp a).  Such a chain is an empty infinite loop
   in the source program. *)
let cut_jumps code =
	(* Walks the code backwards, so "Label lab; Jmp trg" shows up as
	   Jmp trg :: Label lab :: _.  Re-pushing the Jmp lets several labels
	   stacked in front of the same jump all alias to it. *)
	let rec loop subs = function
		| Jmp trg :: Label lab :: rest ->
				if trg = lab then failwith ("infinite loop at " ^ lab) else
					loop (StrMap.add lab trg subs) (Jmp trg :: rest)
		| _ :: rest -> loop subs rest
		| [] -> subs in
	let redirs = loop StrMap.empty (List.rev code) in
	(* Follow aliases to the final label, remembering the labels already
	   visited so that a cycle is reported instead of recursing forever. *)
	let endpoint lab =
		let rec follow seen l =
			let next = try Some (StrMap.find l redirs) with Not_found -> None in
			match next with
			| None -> l
			| Some trg ->
					if List.mem trg seen then failwith ("infinite loop at " ^ trg)
					else follow (trg :: seen) trg in
		follow [lab] lab in
	(* Redirect the labels carried by a right operand. *)
	let cutra arg = match arg with
		| Num _ | Def _ -> arg
		| RetCode(fn, cs) -> RetCode(endpoint fn, endpoint cs) in
	let rec cut = function (* goes forward *)
		| Jmp a :: Jmp _ :: rest -> cut (Jmp a :: rest)
		| Jmp lab :: rest -> Jmp (endpoint lab) :: cut rest
		| Jmple(a1, a2, lab) :: rest -> Jmple(cutra a1, cutra a2, endpoint lab) :: cut rest
		| Label lab :: rest -> if StrMap.mem lab redirs then cut rest else Label lab :: cut rest
		| Arith(l, Bin(op, l2, r)) :: rest -> Arith(l, Bin(op, l2, cutra r)) :: cut rest
		| Arith(l, Not r) :: rest -> Arith(l, Not (cutra r)) :: cut rest
		| PrChar r :: rest -> PrChar (cutra r) :: cut rest
		| a :: rest -> a :: cut rest
		| [] -> [] in
	cut code;;

(* Opcode numbers.  An instruction's first word carries its opcode from bit 16
   upward; the low 16 bits are left free for the next-instruction offset that
   encode ORs in. *)
let binop_c op =
	match op with
	| Add -> 1
	| Sub -> 2
	| Mul -> 3
	| Div -> 4
	| And -> 5
	| Or -> 6
	| Shl -> 7
	| Shr -> 8

let op_c op = binop_c op * 0x10000
let opc_Not = 9 * 0x10000
let opc_Jmple = 10 * 0x10000
let opc_PrChar = 11 * 0x10000

(* Encoded word for a left operand at word [d] of the current instruction.
   Returns a cell offset relative to env.ip:
   - Var s: the offset of s's address;
   - Def2 / Untitled2: [d], i.e. the operand refers to its own cell. *)
let larg_c env d larg =
	match larg with
	| Def2 _ | Untitled2 -> d
	| Var s -> find_var env s - env.ip

(* Encoded word for a right operand.
   - Num n is n.
   - Def (_, n) is n, the declared variable's initial contents.
   - RetCode (fend, cs) builds a replacement first word for the instruction
     at label fend.  It keeps that instruction's opcode bits (bits 16-30,
     read from [obj], which must already hold it), and puts (cs - fend - 2)
     in the low 16 bits.  Presumably the "- 2" compensates for the
     Untitled2 that Prepare adds; confirm against the VM. *)
let rarg_c obj env r =
	match r with
	| Num n | Def (_, n) -> n
	| RetCode (fend, cs) ->
			let fe = getlabel env fend in
			let opcode_bits = obj.(fe) land 0x7FFF0000 in
			let offset = (getlabel env cs - fe - 2) land 0xFFFF in
			opcode_bits lor offset

(* Encode one expanded instruction as the words it occupies (operands
   relative to env.ip):
     Arith (dst, Bin (op, l, r)) -> [op_c op; dst; l; r]
     Arith (dst, Not r)          -> [opc_Not; dst; r]
     Jmple (r1, r2, lab)         -> [opc_Jmple; lab - ip; r1; r2]
     PrChar r                    -> [opc_PrChar; r]
   Labels, jumps and scope markers produce no words.  The low 16 bits of the
   first word are zero here; encode ORs the next-instruction offset into
   them.  Fails on any high-level instruction that survived expansion. *)
let inst_c obj env ins =
	let rv = rarg_c obj env in
	match ins with
	| Label _ | Jmp _ | BeginScope _ | EndScope -> [||]
	| PrChar r -> [| opc_PrChar; rv r |]
	| Jmple (r1, r2, lab) ->
			let target = getlabel env lab in
			[| opc_Jmple; target - env.ip; rv r1; rv r2 |]
	| Arith (dst, Not r) -> [| opc_Not; larg_c env 1 dst; rv r |]
	| Arith (dst, Bin (op, src, r)) -> [| op_c op; larg_c env 1 dst; larg_c env 2 src; rv r |]
	| While _ | IfLess _ | Fun _ | DenseFun _ | Prepare _ | FastCall _ | Call _ | Print _ | PrintOnce _ ->
			failwith "unexpanded code in inst_c"

(* Store [v] as a little-endian 32-bit word in cell [ip] of [bmpdata]
   (bytes 4*ip .. 4*ip+3) and tag the cell with [c] in freeints.  Fails
   with "writing to busy cell!" unless the cell is currently free ('1'). *)
let write_int_c bmpdata ip v c =
	if freeints.[ip] <> '1' then failwith "writing to busy cell!";
	freeints.[ip] <- c;
	let base = ip * 4 in
	for k = 0 to 3 do
		bmpdata.[base + k] <- char_of_int ((v lsr (8 * k)) land 0xFF)
	done

(* Write a code word; the cell is tagged '2'. *)
let write_int bmpdata ip v =
	let code_tag = '2' in
	write_int_c bmpdata ip v code_tag

(* Write a data word (image chunk link, header or payload); the cell is tagged '3'. *)
let write_int3 bmpdata ip v =
	let data_tag = '3' in
	write_int_c bmpdata ip v data_tag

(* Place, resolve and emit the expanded program.

   Runs position_code and then learn_names, and walks the placed
   instructions in order.  Each sized instruction is encoded with inst_c.
   The low 16 bits of its first word are set to the offset, relative to its
   own address, of the instruction that runs next.  Its words are then:
   - stored in [obj], which rarg_c reads back for RetCode operands;
   - written into [bmpdata] via write_int;
   - passed to [writer] together with a human-readable comment.
   Returns (lastpos, spans) from position_code, for fit_data. *)
let encode code bmpdata writer =	
	(* Shadow copy of every word written so far, indexed by cell. *)
	let obj = Array.create 60000 0 in
	(* Emit [ins] at env.ip with next-instruction offset [dip]; [jmpo] only
	   feeds the log comment.  Zero-size instructions emit nothing. *)
	let write env ins dip jmpo =
		let cd = inst_c obj env ins in
		if Array.length cd = 0 then env else
			(cd.(0) <- cd.(0) lor (dip land 0xFFFF);
				let sip = Printf.sprintf "ip=%d " env.ip in
				writer cd (sip ^ inst_s ins ^ (Option.map_default (Printf.sprintf "; jmp %s") "" jmpo));
				Array.iteri (fun i v -> obj.(env.ip+i) <- v; write_int bmpdata (env.ip+i) v) cd; env) in
	let rec loop env = function
		(* Scope markers replay the scope tree recorded by learn_names. *)
		| (BeginScope _, _) :: rest -> loop (enter_scope env) rest
		| (EndScope, _) :: rest -> loop (leave_scope env) rest
		(* An instruction directly followed by a Jmp continues at the jump
		   target; the Jmp itself takes no cells and is consumed here.
		   NOTE(review): a Jmp separated from the preceding instruction by a
		   scope marker falls into the next case instead and is not linked. *)
		| (ins,pos) :: (Jmp lab, _) :: rest ->
				let dip = try StrMap.find lab env.labels - pos
					with Not_found -> failwith ("jmp: unknown label " ^ lab) in
				let e = {env with ip = pos} in
				loop (write e ins dip (Some (Printf.sprintf "%s (%d, %d)" lab dip (pos+dip)))) rest
		(* Otherwise execution continues at the next entry's address; zero-size
		   entries carry the address of the next sized instruction. *)
		| (ins,pos) :: (ins2,pos2) :: rest -> 
				let e = {env with ip = pos} in
				let dip = pos2 - pos in
				loop (write e ins dip None) ((ins2,pos2) :: rest)
		(* The final instruction continues at the cell right after itself. *)
		| (ins,pos) :: [] -> 
				let e = {env with ip = pos} in
				let dip = inssize ins in
				ignore(write e ins dip None)
		| [] -> () in
	let pcode, lastpos, spans = position_code code in	
	let env = learn_names pcode in
	loop { env with ip = 32 } pcode;
	lastpos, spans

(* Debug helper: print one "name: address" line per binding. *)
let show_map m = StrMap.iter (fun k v -> Printf.printf "%s: %d\n" k v) m;;

(* Debug helper: print a scope tree, with each scope's variables and then its
   children between braces. *)
let rec show_scope sc =
	print_endline "{";
	show_map sc.vars;
	List.iter show_scope sc.subscopes;
	print_endline "}";;

(* Scatter the whole 4-byte words of "img.dat" over free runs beyond
   [datapos] as a linked list of chunks.  A trailing partial word is dropped.

   Layout:
   - A link cell holds (next chunk header - link cell - 1).  The first link
     cell is [datapos]; later ones are the cell right after each non-final
     chunk.
   - A chunk header holds its word count, followed by the words themselves.

   Chunk headers are chosen at random (probability 1/15 per candidate) among
   positions whose free run exceeds 2 cells.  Link and header cells are
   tagged '3' through write_int3, which requires them to be free.  Payload
   cells are tagged '3' directly.
   [spans] was computed before any code was written; it is still accurate
   here only because all code cells lie below [datapos].
   NOTE(review): fails with "writing to busy cell!" if cell [datapos] itself
   is occupied. *)
let fit_data datapos bmpdata spans =
	let imgdata = Std.input_file ~bin:true "img.dat" in
	let nints = String.length imgdata / 4 in
	Printf.printf "lastpos=%d spans[pos]=%d\n" datapos spans.(datapos);
	let rec next_chunk pos =
		if pos >= 60000 then failwith "failed to position data"
		else if spans.(pos) <= 2 then next_chunk (pos + 1)
		else if Random.int 15 = 1 then pos
		else next_chunk (pos + 1) in
	let left = ref nints and link = ref datapos in
	while !left > 0 do
		let p = next_chunk (!link + 1) in
		write_int3 bmpdata !link (p - !link - 1);
		let sz = min (spans.(p) - 2) !left in
		write_int3 bmpdata p sz;
		String.blit imgdata ((nints - !left) * 4) bmpdata (p * 4 + 4) (sz * 4);
		for i = p + 1 to p + sz do
			freeints.[i] <- '3'
		done;
		left := !left - sz;
		link := p + sz + 1
	done;;
		
(* Random filler word for unused cells, computed as a * b + c with
   0 <= a, c < 0x3FFFFFFF and 0 <= b < 17.  It can exceed 32 bits;
   write_int_c stores only the low 32 bits.
   NOTE: OCaml does not specify the order in which the three Random.int calls
   are evaluated.  Keep this expression as-is if the generated image must stay
   reproducible with the current compiler. *)
let rnd () = Random.int 0x3FFFFFFF * Random.int 17 + Random.int 0x3FFFFFFF		
		
(* Fill every still-free cell, in address order, with a random word from rnd,
   and tag it '4'. *)
let fill_empty_space bmpdata =
	for i = 0 to String.length freeints - 1 do
		if freeints.[i] = '1' then write_int_c bmpdata i (rnd ()) '4'
	done;;

(* Whole pipeline, in order:
   1. expand and jump-optimise Dsl.code;
   2. encode it into the "bmpdata" image, printing each instruction's words
      and a comment to stdout;
   3. append the image data;
   4. fill the remaining free cells with noise;
   5. write the image to "bmpdata1" and the final occupancy map to "map1". *)
let compile () =
	let ecode = cut_jumps (expand_code Dsl.code) in
	let bmpdata = Std.input_file ~bin:true "bmpdata" in
	let dump words comment =
		Array.iter (fun w -> Printf.printf "%d, " w) words;
		Printf.printf "// %s\n" comment in
	let lastpos, spans = encode ecode bmpdata dump in
	fit_data lastpos bmpdata spans;
	fill_empty_space bmpdata;
	let save opener path contents =
		let oc = opener path in
		output_string oc contents;
		close_out oc in
	save open_out_bin "bmpdata1" bmpdata;
	save open_out "map1" freeints;;

compile ();;
