(* BigO equational reasoner *)
(* Kevin Donnelly, August 2004 *)

(* basic structure for function symbols, equations, and symbol sets *)
structure Equation =
struct
  (* type of function symbols and its comparator structure *)
  type symbol = int

  structure SymbolKey =
  struct
    type ord_key = symbol
    val compare = Int.compare
  end

  structure SymbolSet = SplaySetFn(SymbolKey)

  (* Big O sets are just sets of symbols.
   * The set {0, 2, 3} is used to represent the big O expression
   *   O(f_0 + f_2 + f_3)
   * Since O(f_0 + f_0) is equivalent to O(f_0), representing
   * big O expressions with sets of symbols is sufficient. *)
  type oset = SymbolSet.set

  (* An equation is represented by a vector of integers.
   * The vector [3, 0, 4, ~8] represents the equation
   *   3x_0 + 0x_1 + 4x_2 + (~8)x_3 = 0
   * or, equivalently,
   *   3x_0 + 4x_2 = 8x_3 *)
  type equation = int Vector.vector

  (* Putting these together, we represent a big O equation using an
   * equation and an oset.  For example, the equation [0,2,0,1] with the
   * oset {0,2} represents the big O equation
   *   2f_1 + f_3 = 0 + O(f_0 + f_2) *)

  (* Structural equality of equations: true iff x and y have the same
   * length and agree at every index.  The length test is needed because
   * the fold only walks x: without it a proper prefix would compare
   * equal to a longer vector, and a longer x would raise Subscript. *)
  fun eq_equal (x, y) =
    Vector.length x = Vector.length y andalso
    Vector.foldli (fn (i, n, res) => res andalso n = Vector.sub (y, i))
                  true (x, 0, NONE)

  (* vectors of equations *)
  type eqVector = equation Vector.vector

  (* true iff every coefficient of t is zero *)
  fun is_zero t = Vector.foldl (fn (n, res) => res andalso n = 0) true t

  (* printing functions *)
  val print = TextIO.print

  (* print t as  c_0x_0 + c_1x_1 + ...  without a trailing newline *)
  fun print_eq t =
    Vector.appi (fn (i, x) =>
                   (if i <> 0 then print " + " else ();
                    print ((Int.toString x) ^ "x_" ^ (Int.toString i))))
                (t, 0, NONE)

  (* print each equation of eqv on its own line, each prefixed by a bar
   * and a tab *)
  fun print_eq_vector eqv : unit =
    Vector.app (fn eq => (print "|\t"; print_eq eq; print "\n")) eqv

  exception EqError
end

(* Gaussian Elimination module *)
structure GaussElim =
struct
  open Equation

  (* greatest common divisor of |m| and |n|; gcd (0, 0) = 0 *)
  fun gcd (m, n) =
    let
      fun gcd' (a, b) =
        if a = 0 then b
        else if b = 0 then a
        else gcd' (b mod a, a)
    in
      gcd' (Int.abs m, Int.abs n)
    end

  (* least common multiple; its sign is the sign of n * m.
   * Raises Div when both arguments are 0. *)
  fun lcm (n, m) = (n * m) div (gcd (n, m))

  (* gcd of all entries of v (0 for an empty or all-zero vector);
   * stops scanning as soon as the running gcd reaches 1 *)
  fun gcd_vector v =
    let
      val len = Vector.length v
      fun go (k, acc) =
        if acc = 1 then 1
        else if k < len then go (k + 1, gcd (acc, Vector.sub (v, k)))
        else acc
    in
      go (0, 0)
    end

  (* divide out the common factor of a vector of coefficients, e.g.
   * [2, 0, ~4] becomes [1, 0, ~2]; signs are left unchanged *)
  fun normalize_eq eq =
    let val divisor = gcd_vector eq
    in if divisor > 1 then Vector.map (fn x => x div divisor) eq else eq
    end

  (* functional update: a copy of v with index i replaced by a.
   * Vector.update is not in this version of the Vector module. *)
  fun update (v : 'a Vector.vector, i : int, a : 'a) =
    Vector.mapi (fn (j, x) => if i = j then a else x) (v, 0, NONE)

  (* Perform Gaussian elimination on a vector of equations (rows),
   * eliminating the variables (columns) 0 .. max-1 in order.  Each pivot
   * variable is removed from every other row and every modified row is
   * divided by its gcd, so the result is in reduced row-echelon form up to
   * row scaling.  Only rows that received a pivot are returned, hence the
   * number of rows in the result is the rank of E restricted to its first
   * 'max' columns. *)
  fun gauss_elim (E : eqVector) max =
    let
      (* index of the first row in [lo, hi) whose coefficient for variable
       * i is non-zero, or ~1 if there is none *)
      fun find_eq eqv i lo hi =
        if lo >= hi then ~1
        else if Vector.sub (Vector.sub (eqv, lo), i) <> 0 then lo
        else find_eq eqv i (lo + 1) hi

      (* swap the positions of rows n and m *)
      fun swap_rows eqv n m =
        let
          val row1 = Vector.sub (eqv, n)
          val row2 = Vector.sub (eqv, m)
        in
          update (update (eqv, n, row2), m, row1)
        end

      (* eliminate variable 'var' and every later one; rows 0 .. min-1
       * already hold pivots for earlier variables *)
      fun elim eqv var min =
        if var >= max then
          (* all variables processed: only the first 'min' (pivot) rows
           * are needed *)
          Vector.tabulate (min, fn i => Vector.sub (eqv, i))
        else if min >= Vector.length eqv then
          (* every row already holds a pivot *)
          eqv
        else
          let
            (* a row at or after 'min' with non-zero coefficient for 'var' *)
            val index = find_eq eqv var min (Vector.length eqv)

            (* use row 'index' to zero the 'var' coefficient of every other
             * row, normalizing each row that changes *)
            fun elim_var cnt eqv =
              if cnt >= Vector.length eqv then eqv
              else if cnt = index then elim_var (cnt + 1) eqv
              else
                let
                  val this_coef = Vector.sub (Vector.sub (eqv, cnt), var)
                in
                  if this_coef = 0 then elim_var (cnt + 1) eqv
                  else
                    let
                      val elim_row = Vector.sub (eqv, index)
                      val elim_coef = Vector.sub (elim_row, var)
                      val l = lcm (elim_coef, this_coef)
                      val this_mult = l div elim_coef
                      val elim_mult = l div this_coef
                      (* this_coef * elim_mult = elim_coef * this_mult = l,
                       * so this combination zeroes column 'var' *)
                      fun zero_it (i, x) =
                        (x * elim_mult) - (Vector.sub (elim_row, i) * this_mult)
                      val new_row =
                        normalize_eq
                          (Vector.mapi zero_it (Vector.sub (eqv, cnt), 0, NONE))
                    in
                      elim_var (cnt + 1) (update (eqv, cnt, new_row))
                    end
                end
          in
            if index = ~1 then
              (* column 'var' is all zeros below the pivots *)
              elim eqv (var + 1) min
            else if index > min then
              (* bring the row with a non-zero 'var' coefficient up to 'min' *)
              elim (swap_rows eqv min index) var min
            else
              (* row 'min' becomes the pivot row for 'var' *)
              elim (elim_var 0 eqv) (var + 1) (min + 1)
          end
    in
      elim E 0 0
    end

  (* check if 'eq' is a linear combination of the equations in eqv.
   * eq lies in the row space of eqv exactly when appending it does not
   * increase the rank, and gauss_elim returns one row per pivot, so the
   * row counts of the two eliminations are compared.  eqv does not need
   * to be in eliminated form. *)
  fun check_eq eqv eq =
    let
      val cols = Vector.length eq
      val rank = Vector.length (gauss_elim eqv cols)
      val extended = Vector.concat [eqv, Vector.fromList [eq]]
    in
      Vector.length (gauss_elim extended cols) = rank
    end
end

(* a test for Gaussian Elimination
structure test =
struct
  open Equation
  val eq1 = Vector.tabulate(6,fn i => case i of 0 => 2 | 1 => 1 | 2 => 3 | 3 => 0 | 4 => 2 | 5 => 3 | x => raise EqError)
  val eq2 = Vector.tabulate(6,fn i => case i of 0 => 0 | 1 => 3 | 2 => 8 | 3 => ~3 | 4 => 2 | 5 => 3 | x => raise EqError)
  val eq3 = Vector.tabulate(6,fn i => case i of 0 => 1 | 1 => 12 | 2 => 4 | 3 => ~7 | 4 => 2 | 5 => 3 | x => raise EqError)
  val eq4 = Vector.tabulate(6,fn i => case i of 0 => 5 | 1 => 3 | 2 => 5 | 3 => ~2 | 4 => 2 | 5 => 3 | x => raise EqError)
  val eqv1 = Vector.tabulate(4,fn i => case i of 0 => eq1 | 1 => eq2 | 2 => eq3 | 3 => eq4 | x => raise EqError)
  val _ = print_eq_vector eqv1
  val _ = print "Performing Gaussian Elimination\n\n"
  val _ = print_eq_vector (GaussElim.gauss_elim eqv1 4)
end
*)

(* Structure for Fourier-Motzkin elimination procedure.
This is used to check the satisfiability of a system of linear
 * inequalities with no constant terms. *)
structure fourier_motzkin =
struct
  open Equation

  (* the all-zero vector of length n *)
  fun zero n = Vector.tabulate (n, fn _ => 0)

  (* coefficient of variable sym in eq *)
  fun get_coef (sym, eq) = Vector.sub (eq, sym)

  fun lcm (x, y) = GaussElim.lcm (x, y)

  (* which type of inequality *)
  datatype ineq_type = LT | LE

  (* (t1, it, t2) stands for  t1 . x  it  t2 . x  where x is the vector of
   * variables; e.g. ([1,0], LT, [0,2]) is  x_0 < 2x_1 *)
  type ineq = equation * ineq_type * equation

  exception BadComposition

  (* multiply every coefficient of t by n *)
  fun eq_mult_n t n = Vector.map (fn x => n * x) t

  (* print an inequality without a trailing newline *)
  fun print_ineq (t1, it, t2) =
    (print_eq t1;
     print (case it of LE => " <= " | LT => " < ");
     print_eq t2)

  (* print each inequality on its own line *)
  fun print_ineqlist ineqs =
    List.app (fn ineq => (print_ineq ineq; print "\n")) ineqs

  (* Given  t1 LTE mx  (a lessx) and  nx LTE' t2'  (an xless), compose
   * returns
   *   (L div m) t1  LTE''  (L div n) t2'       where L = lcm(m, n)
   * and LTE'' is < if either LTE or LTE' is <, and <= otherwise.
   * ineq_split produces m, n > 0.  Prints a trace of both inputs. *)
  fun compose (t1, it, t2) (t1', it', t2') x =
    let
      val _ = (print "\n---\nComposing "; print_ineq (t1, it, t2);
               print " with\n"; print_ineq (t1', it', t2'); print "\n")
      val n = lcm (get_coef (x, t2), get_coef (x, t1'))
      val m = n div get_coef (x, t2)
      val m' = n div get_coef (x, t1')
      val it'' = if it = LT orelse it' = LT then LT else LE
    in
      (eq_mult_n t1 m, it'', eq_mult_n t2' m')
    end

  exception Unsatisfiable

  (* Intended for List.mapPartial: drops the trivially true 0 <= 0 (NONE),
   * raises Unsatisfiable on the contradiction 0 < 0, and keeps every
   * other inequality unchanged. *)
  fun check (ineq as (t1, it, t2)) =
    if is_zero t1 andalso is_zero t2 then
      (case it of LT => raise Unsatisfiable | LE => NONE)
    else SOME ineq

  (* ineq_split x ineqs partitions ineqs into a triple (lessx, xless, nonx):
   *   lessx : rewritten as  t LTE nx  with n > 0
   *   xless : rewritten as  nx LTE t  with n > 0
   *   nonx  : x has coefficient 0; rewritten as  0 LTE t
   * In every rewritten inequality t has coefficient 0 for x.  Each list is
   * in reverse order relative to ineqs. *)
  fun ineq_split x ineqlist =
    let
      fun split (lessx, xless, nonx) [] = (lessx, xless, nonx)
        | split (lessx, xless, nonx) ((t1, it, t2) :: rest) =
          let
            (* move x to the left-hand side and everything else to the
             * right-hand side:  t1 LTE t2  becomes
             *   (t1_x - t2_x) x  LTE  t2 - t1   with entry x of the rhs zeroed *)
            val x_coef = get_coef (x, t2)
            val t2' = Vector.mapi (fn (i, n) => if i = x then 0 else n - get_coef (i, t1))
                                  (t2, 0, NONE)
            val t1' = Vector.mapi (fn (i, n) => if i = x then n - x_coef else 0)
                                  (t1, 0, NONE)
            val c = get_coef (x, t1')
            val acc =
              if c > 0 then
                (lessx, (t1', it, t2') :: xless, nonx)
              else if c < 0 then
                (* multiply through by ~1 and swap sides to get a lessx *)
                ((Vector.map Int.~ t2', it, Vector.map Int.~ t1') :: lessx, xless, nonx)
              else
                (lessx, xless, (t1', it, t2') :: nonx)
          in
            split acc rest
          end
    in
      split ([], [], []) ineqlist
    end

  (* One Fourier-Motzkin step: drop trivial inequalities, then remove x by
   * composing every lessx with every xless; inequalities without x are
   * kept.  Raises Unsatisfiable (via check) on 0 < 0. *)
  fun eliminate x ineqlist =
    let
      val (lessx, xless, nonx) = ineq_split x (List.mapPartial check ineqlist)
      val comp =
        List.foldl (fn (lo, acc) =>
                      List.foldl (fn (hi, acc') => (compose lo hi x) :: acc') acc xless)
                   [] lessx
    in
      comp @ nonx
    end

  (* Eliminate each variable in turn, printing the current inequalities
   * before every step.  Returns true (and prints "Satisfiable") when only
   * trivial inequalities remain, and false (printing "Unsatisfiable") when
   * the contradiction 0 < 0 is derived. *)
  fun fourier_motzkin [] = (print "Satisfiable\n\n"; true)
    | fourier_motzkin (ineqlist as (t1, _, _) :: _) =
      (let
         val vars = List.tabulate (Vector.length t1, fn v => v)
         fun step (var, ineqs) =
           (print ("--- eliminating " ^ (Int.toString var) ^ " ---\n");
            print_ineqlist ineqs;
            eliminate var ineqs)
         val final = List.mapPartial check (List.foldl step ineqlist vars)
       in
         if null final then (print "Satisfiable\n\n"; true)
         else
           (* this should not happen: once every variable is eliminated
            * each remaining inequality compares two zero vectors *)
           (print "Strange! Left with non-trivial inequalities after Fourier-Motzkin elimination\n";
            print_ineqlist final;
            false)
       end
       handle Unsatisfiable => (print "Unsatisfiable\n\n"; false))
end

(* tests for the Fourier-Motzkin procedure
structure test2 =
struct
  open fourier_motzkin
  val eq1 = Vector.tabulate(6,fn i => case i of 0 => 2 | 1 => 1 | 2 => 3 | 3 => 0 | 4 => 2 | 5 => 3 | x => raise EqError)
  val eq2 = Vector.tabulate(6,fn i => case i of 0 => 0 | 1 => 3 | 2 => 8 | 3 => ~3 | 4 => 2 | 5 => 3 | x => raise EqError)
  val eq3 = Vector.tabulate(6,fn i => case i of 0 => 1 | 1 => 12 | 2 => 4 | 3 => ~7 | 4 => 2 | 5 => 3 | x => raise EqError)
  val eq4 = Vector.tabulate(6,fn i => case i of 0 => 5 | 1 => 3 | 2 => 5 | 3 => ~2 | 4 => 2 | 5 => 3 | x => raise EqError)
  val eq5 = Vector.tabulate(6,fn i => case i of 0 => ~2 | 1 => 5 | 2 => 0 | 3 => 1 | 4 => ~3 | 5 => 3 | x => raise EqError)
  val eq6 = Vector.tabulate(6,fn i => case i of 0 => 3 | 1 => 9 | 2 => 2 | 3 => ~3 | 4 => 2 | 5 => 8 | x => raise EqError)
  val eq7 = Vector.tabulate(6,fn i => case i of 0 => 1 | 1 => ~5 | 2 => 3 | 3 => 12 | 4 => ~2 | 5 => 3 | x => raise EqError)
  val eq8 = Vector.tabulate(6,fn i => case i of 0 => 2 | 1 => ~3 | 2 => 0 | 3 => ~2 | 4 => 4 | 5 => 0 | x => raise EqError)
  val zero = zero 6
  val ineq1 = (eq1,LE,zero)
  val ineq2 = (eq2,LT,zero)
  val ineq3 = (zero,LE,eq3)
  val ineq4 = (zero,LT,eq4)
  val ineq5 = (eq5,LE,zero)
  val ineq6 = (eq6, LE, zero)
  val ineq7 = (eq4, LT, zero)
  val ineq8 = (zero, LE, eq8)
  val ineqlist = [ineq1,ineq2,ineq3,ineq4,ineq5,ineq6,ineq7,ineq8]
  val eqv1 = Vector.tabulate(4,fn i => case i of 0 => eq1 | 1 => eq2 | 2 => eq3 | 3 => eq4 | x => raise EqError)
  val _ = print_ineqlist ineqlist
  val _ = print "Performing Fourier-Motzkin Elimination\n\n"
  val _ = (fourier_motzkin
ineqlist) end *)

structure BigOAlg =
struct
  open Equation
  open GaussElim

  (* a big O equation is an equation and a big O set *)
  type oEquation = equation * oset
  type oEqList = oEquation list

  (* a search is a triple (contos, goal, oeql):
   *   contos : the largest big O set determined so far to be equivalent
   *            to the goal's big O set
   *   goal   : the big O equation being proved
   *   oeql   : the premises (big O equations) that may be used *)
  type search = oset * oEquation * oEqList

  fun getEQ (eq, _) = eq
  fun getOS (_, os) = os

  exception negExpError

  (* x raised to the power y for y >= 0; raises negExpError for y < 0 *)
  fun power x y =
    let fun pow b e = if e = 0 then 1 else b * (pow b (e - 1))
    in if y >= 0 then pow x y else raise negExpError
    end

  (* create initial search: start from the goal's own big O set *)
  fun init (oeql : oEquation list) (goal : oEquation) =
    (getOS goal, goal, oeql) : search

  (* equations of the premises whose big O set is a subset of os,
   * in premise order *)
  fun current_eqs os oeql =
    List.map getEQ (List.filter (fn (_, os') => SymbolSet.isSubset (os', os)) oeql)

  (* covers eq os is true iff every symbol with a non-zero coefficient in eq
   * belongs to os, i.e. eq is trivial modulo O(os).  A coefficient's
   * symbol is its index in the vector, not its value. *)
  fun covers eq os =
    Vector.foldli (fn (i, n, res) => res andalso (n = 0 orelse SymbolSet.member (os, i)))
                  true (eq, 0, NONE)

  fun check_goal_trivial (contos, goal, _) = covers (getEQ goal) contos

  (* remove from eq the coefficients of symbols contained in os, then
   * divide out common factors *)
  fun elim_normalize os eq =
    normalize_eq (Vector.mapi (fn (i, n) => if SymbolSet.member (os, i) then 0 else n)
                              (eq, 0, NONE))

  (* simplify the usable equations modulo contos and check whether the
   * simplified goal is a linear combination of them *)
  fun check_goal (contos, goal, iniv) =
    let
      val eqv = Vector.fromList (List.map (elim_normalize contos) (current_eqs contos iniv))
      val elimv = gauss_elim eqv (Vector.length (getEQ goal))
    in
      check_eq elimv (elim_normalize contos (getEQ goal))
    end

  exception noNewContainments of search

  (* Take a step towards the goal by finding new symbols to add to the
   * current big O set; raises noNewContainments if no progress can be
   * made. *)
  fun generate_containments (contos, goal, iniv) =
    let
      open fourier_motzkin
      val eqlen = Vector.length (getEQ goal)
      val eqs = Vector.fromList (List.map (elim_normalize contos) (current_eqs contos iniv))
      val neqs = Vector.length eqs
      (* column i holds the coefficient of symbol i in each usable equation *)
      val equationl =
        List.tabulate (eqlen, fn i =>
          Vector.tabulate (neqs, fn j => Vector.sub (Vector.sub (eqs, j), i)))

      (* Attempt to derive a big O equation in which symbol x is positive:
       * ask Fourier-Motzkin whether some linear combination of the usable
       * equations has coefficient > 0 for x and >= 0 for every other
       * symbol.  The FM variables are the combination's multipliers. *)
      fun make_pos x =
        let
          val z = zero neqs
          fun build _ [] = []
            | build i (col :: cols) =
                (z, if i = x then LT else LE, col) :: build (i + 1) cols
        in
          fourier_motzkin (build 0 equationl)
        end

      (* generate new containments by deriving new equations, e.g. if we
       * can determine that x_0 + 3x_2 = 0 + O(x_1) then we know
       *   O(x_0 + x_1 + x_2) = O(x_1)
       * and we expand the current big O set appropriately
       * (by adding x_0 and x_2 for the example) *)
      val newcontos =
        List.foldl (fn (x, acc) =>
                      if SymbolSet.member (acc, x) then acc
                      else if make_pos x then SymbolSet.add (acc, x)
                      else acc)
                   contos (List.tabulate (eqlen, fn x => x))
    in
      if SymbolSet.equal (contos, newcontos)
      then raise noNewContainments (contos, goal, iniv)
      else (newcontos, goal, iniv)
    end

  (* print os as a 0/1 string of length max (assuming every member is
   * below max), e.g. {0,2} with max 4 prints [1010] *)
  fun print_oset os max =
    let
      (* print n zeros; a non-positive n prints nothing *)
      fun print0 n = if n <= 0 then () else (print "0"; print0 (n - 1))
      (* print the bits up to and including each member and return the
       * last member printed (or the initial value) *)
      fun printItems (n :: rest) last = (print0 ((n - last) - 1); print "1"; printItems rest n)
        | printItems [] last = last
    in
      print "[";
      print0 ((max - (printItems (SymbolSet.listItems os) (~1))) - 1);
      print "]"
    end

  fun print_search (contos, goal, _) =
    (print "Current State:\n";
     print "Goal: "; print_eq (getEQ goal);
     print " + O("; print_oset (getOS goal) (Vector.length (getEQ goal)); print ")\n";
     print "Containments: "; print_oset contos (Vector.length (getEQ goal)); print "\n")

  (* grow the big O set until no progress is possible, then check whether
   * the goal is trivial or a linear combination of the usable equations *)
  fun solve goal oeql =
    let
      fun loop search : bool = (print_search search; loop (generate_containments search))
    in
      (print "generating bigo containments\n"; loop (init oeql goal))
      handle noNewContainments search =>
        (print "All containments generated\n";
         if check_goal_trivial search orelse check_goal search
         then (print "Success\n"; true)
         else (print "Failure\n"; false))
    end
end

structure test3 =
struct
  open BigOAlg

  (* f = g + O(k)
   * h = g + O(k)
   * ------------
   * f = h + O(k) *)
  val prems1 = [([1,~1,0,0],[3]),([0,~1,1,0],[3])]
  val goal1 = ([1,0,~1,0],[3])

  (* f + g = h + O(k)
   * g = l + O(k)
   * ----------------
   * f + l = h + O(k) *)
  val prems2 = [([1,1,~1,0,0],[3]),([0,1,0,~1,0],[3])]
  val goal2 = ([1,0,~1,1,0],[3])

  (* f + g = h + O(m)
   * g = l + O(k)
   * --------------------
   * f + l = h + O(k + m) *)
  val prems3 = [([1,1,~1,0,0,0],[5]),([0,1,0,~1,0,0],[4])]
  val goal3 = ([1,0,~1,1,0,0],[4,5])

  (* f + f = g + g + O(k)
   * --------------------
   * f = g + O(k) *)
  val prems4 = [([2,~2,0],[2])]
  val goal4 = ([1,~1,0],[2])

  (* f + f + g = 0 + O(k)
   * ---------------------
   * f = 0 + O(k) *)
  val prems5 = [([2,1,0],[2])]
  val goal5 = ([1,0,0],[2])

  (* f + g = h + O(k)
   * g + l = h + O(k)
   * --------------------
   * f = l + O(k) *)
  val prems6 = [([1,1,~1,0,0],[3]), ([0,1,~1,0,1],[3])]
  val goal6 = ([1,0,0,0,~1],[3])

  (* f + g = h + O(k)
   * g = 0 + O(l)
   * k = 0 + O(l)
   * ------------------
   * f = h + O(l) *)
  val prems7 = [([1,1,~1,0,0],[3]), ([0,1,0,0,0],[4]), ([0,0,0,1,0],[4])]
  val goal7 = ([1,0,~1,0,0],[4])

  exception InconsistentVectors
  exception TestFailed

  (* convert list-based premises and goal to vectors and sets, run the
   * solver, and report the elapsed time; raises InconsistentVectors if a
   * premise does not have as many coefficients as the goal *)
  fun run_test prems (geq, gos) =
    let
      val start_t = Timer.startRealTimer ()
      val width = List.length geq
      fun to_oeq (eq, os) =
        if List.length eq <> width then raise InconsistentVectors
        else (Vector.fromList eq, SymbolSet.addList (SymbolSet.empty, os))
      val oeql = List.map to_oeq prems
      val goal = (Vector.fromList geq, SymbolSet.addList (SymbolSet.empty, gos))
      val res = solve goal oeql
      val run_t = Timer.checkRealTimer start_t
    in
      print ("Completed test in: " ^ (Time.fmt 3 run_t) ^ "\n");
      res
    end

  fun succ_test prems goal = if run_test prems goal then () else raise TestFailed
  fun fail_test prems goal = if run_test prems goal then raise TestFailed else ()

  val _ = succ_test prems1 goal1
  val _ = succ_test prems2 goal2
  val _ = succ_test prems3 goal3
  val _ = succ_test prems4 goal4
  val _ = succ_test prems5 goal5
  val _ = succ_test prems6 goal6
  val _ = succ_test prems7 goal7
  val _ = fail_test prems4 goal5
end