1 | {-# LANGUAGE TypeFamilies, MultiParamTypeClasses, FlexibleInstances, |
---|
2 | IncoherentInstances, RankNTypes, ScopedTypeVariables, |
---|
3 | FlexibleContexts,UndecidableInstances #-} |
---|
4 | module Constraints where |
---|
5 | |
---|
6 | import Lambda |
---|
7 | import qualified Data.List as List |
---|
8 | import qualified Data.Set as Set |
---|
9 | import qualified Data.Map as Map |
---|
10 | import qualified Data.Ord as Ord |
---|
11 | import qualified Data.SBV as SBV |
---|
12 | import Data.SBV ( (.==), (.<), (.>=)) |
---|
13 | import Control.Monad |
---|
14 | import Data.IORef |
---|
15 | import Data.Supply as S |
---|
16 | |
---|
-- | Return the first argument unchanged; the second argument is never
-- inspected and only serves to unify its type with the first
-- (a type-restricted 'const').
asType :: a -> a -> a
asType x _ = x
---|
19 | |
---|
-- | A record pairing an 'Int' with a list of 'Condition's.
-- NOTE(review): the meaning of 'var' (presumably a variable index) is not
-- established anywhere in this file, and 'D' is not referenced here —
-- confirm against callers.
data D = D
  { var  :: Int
  , cond :: [Condition]
  }
---|
21 | |
---|
-- | Coarse shape of a Haskell type, as reflected by the 'tk' method.
data TypeKind
  = V                    -- ^ a type with no more specific 'TS' instance
  | L TypeKind           -- ^ a list with the given element shape
  | F TypeKind TypeKind  -- ^ a function from the first shape to the second
  | U                    -- ^ 'Int'
  deriving (Show)
---|
24 | |
---|
25 | |
---|
-- | Type-directed reflection of a type's shape into a 'TypeKind'.
-- No instance ever evaluates its argument, so callers may pass
-- @undefined :: T@ to query the shape of @T@.
class TS a where
  tk :: a -> TypeKind

-- | Functions: reflect the domain and codomain separately.
instance (TS a, TS b) => TS (a -> b) where
  tk _ = F domain codomain
    where
      domain = tk (undefined :: a)
      codomain = tk (undefined :: b)

-- | Lists: reflect the element type.
instance TS a => TS [a] where
  tk _ = L element
    where
      element = tk (undefined :: a)

-- | Catch-all for every other type. Coexists with the more specific
-- instances only because the module enables IncoherentInstances.
instance TS a where
  tk = const V

-- | 'Int' is the only concretely recognised base type.
instance TS Int where
  tk = const U
---|
36 | |
---|
-- | Build a left-nested application spine headed by a variable:
-- @appall f [x, y]@ is @App (App (Var f) x) y@; with no arguments it is
-- just @Var f@.
appall :: Int -> [L] -> L
appall fn args = foldl App (Var fn) args
---|
39 | |
---|
40 | |
---|
-- | A proof obligation @ctx |- lhs = rhs@: under the constraints in the
-- first field, the two expressions are to be shown equal.
data Condition
  = Condition
      [Constraint]  -- ^ context (side conditions)
      L             -- ^ left-hand side
      L             -- ^ right-hand side
---|
42 | |
---|
-- | Renders a condition as @ctx|- lhs = rhs@ (precedence is ignored).
-- Lists of conditions are printed one per line, without brackets.
instance Show Condition where
  showsPrec _ (Condition ctx lhs rhs) =
    shows ctx . showString "|- " . shows lhs . showString " = " . shows rhs
  showList conds = case conds of
    [] -> showString ""
    [c] -> shows c
    (c : rest) -> shows c . showString "\n" . shows rest
---|
48 | |
---|
-- | Arithmetic side conditions over 'L' expressions.
data Constraint
  = Zero L   -- ^ the expression equals zero
  | LTC L L  -- ^ first expression is strictly less than the second
  | GEC L L  -- ^ first expression is greater than or equal to the second
  deriving (Eq)
---|
51 | |
---|
-- | Infix rendering of constraints (precedence is ignored). Lists are
-- comma-separated, and a non-empty list ends with a single space.
instance Show Constraint where
  showsPrec _ constraint = case constraint of
    Zero e -> shows e . showString " = 0"
    LTC l r -> shows l . showString " < " . shows r
    GEC l r -> shows l . showString " >= " . shows r
  showList cs = case cs of
    [] -> id
    [c] -> shows c . showChar ' '
    (c : rest) -> shows c . showString ", " . shows rest
---|
59 | |
---|
-- | Rewrite an expression into a canonical left-nested sum built from the
-- decomposition computed by 'norm'.
--
-- The constant sits innermost (leftmost). Around it are wrapped:
--
-- * the variable terms, in an order determined by variable index;
-- * then the added opaque sub-terms;
-- * then the subtracted opaque sub-terms.
--
-- Finally, a leading @0 +@ is removed by 'delzero'.
--
-- A variable with coefficient @k@ is emitted as @+ v@ / @- v@ for
-- @k = ±1@, as @+ k * v@ / @- |k| * v@ otherwise, and is dropped for
-- @k = 0@. Multiplication by a literal on the left is what 'norm' parses
-- back, so the result re-normalises to itself.
normalize :: L -> L
normalize l = delzero $ foldl (\acc t -> Op acc '-' t) withPositives negatives
  where
    (coeffs, constant, positives, negatives) = norm l
    withPositives = foldl (\acc t -> Op acc '+' t) linear positives
    linear = linearTerms (List.sortBy (Ord.comparing snd) coeffs)

    linearTerms [] = Num constant
    linearTerms ((k, v) : rest)
      | k == 0 = inner
      | k == 1 = Op inner '+' (Var v)
      | k == -1 = Op inner '-' (Var v)
      -- Coefficients other than 0/±1 arise from scaling by a literal
      -- (e.g. @Num 2 * x@) or from merged subtractions (@0 - (x + x)@).
      -- They previously fell through a non-exhaustive pattern.
      | k > 0 = Op inner '+' (Op (Num k) '*' (Var v))
      | otherwise = Op inner '-' (Op (Num (negate k)) '*' (Var v))
      where
        inner = linearTerms rest

-- | Remove a @0 +@ found at the bottom of the left spine of nested 'Op's,
-- i.e. @((0 + x) - y)@ becomes @(x - y)@. Any other term is unchanged.
delzero :: L -> L
delzero (Op (Num 0) '+' b) = b
delzero (Op a c b) = Op (delzero a) c b
delzero l = l
---|
72 | |
---|
-- | Split a purely linear expression into its variable part and its
-- constant: @gnormalize e = (vars, k)@ with @e == vars + k@.
--
-- Variables with a zero coefficient are dropped. If no variables remain,
-- the variable part is @Num 0@. The first term is rendered as @v@ or
-- @k * v@; subsequent ones as @+ v@, @- v@, @+ k * v@ or @- |k| * v@.
--
-- Calls 'error' if 'norm' reports any non-linear (opaque) sub-term.
gnormalize :: L -> (L, Int)
gnormalize l = case (nonLinearPos, nonLinearNeg) of
  ([], []) -> case nonzero of
    [] -> (Num 0, constant)
    (first : rest) -> (foldl addTerm (leadTerm first) rest, constant)
  _ -> error $ "Expression in condition: " ++ show (nonLinearPos ++ nonLinearNeg)
  where
    (coeffs, constant, nonLinearPos, nonLinearNeg) = norm l
    nonzero = List.sortBy (Ord.comparing snd) (filter ((/= 0) . fst) coeffs)

    -- The first term: bare variable for coefficient 1, else an explicit
    -- product (this also covers -1, rendered as @-1 * v@ as before).
    leadTerm (1, v) = Var v
    leadTerm (k, v) = Op (Num k) '*' (Var v)

    -- Append a further term. Previously only coefficients ±1 were handled,
    -- so e.g. @x - (-x)@ crashed with a pattern-match failure.
    addTerm acc (k, v)
      | k == 1 = Op acc '+' (Var v)
      | k == -1 = Op acc '-' (Var v)
      | k > 0 = Op acc '+' (Op (Num k) '*' (Var v))
      | otherwise = Op acc '-' (Op (Num (negate k)) '*' (Var v))
---|
89 | |
---|
-- | Apply 'normalize' to every expression inside a constraint, keeping
-- the constraint's kind and the side each expression is on.
normalizec :: Constraint -> Constraint
normalizec c = case c of
  Zero e -> Zero (normalize e)
  LTC x y -> LTC (normalize x) (normalize y)
  GEC x y -> GEC (normalize x) (normalize y)
---|
94 | --normalizec (LTC a b) = LTC x (Num (-y)) |
---|
95 | -- where (x,y) = gnormalize $ Op a '-' b |
---|
96 | --normalizec (GEC a b) = GEC x (Num (-y)) |
---|
97 | -- where (x,y) = gnormalize $ Op a '-' b |
---|
98 | |
---|
-- | Normalise each constraint, then drop duplicates. The first occurrence
-- of each constraint is kept and the original order is preserved.
-- The dedup is quadratic, since 'Constraint' has no 'Ord' instance.
normalizecs :: [Constraint] -> [Constraint]
normalizecs = List.nub . map normalizec
---|
104 | |
---|
-- | Decompose an extended lambda expression into linear-arithmetic parts.
--
-- The result @(coeffs, k, pos, neg)@ stands for
-- @k + sum [c * Var v | (c, v) <- coeffs] + sum pos - sum neg@, where:
--
-- * @coeffs@ pairs an integer coefficient with a variable index.
--   Subtracted coefficients are folded into the first matching entry.
--   Added ones are simply concatenated, so a variable may occur twice.
-- * @k@ is the constant offset.
-- * @pos@ / @neg@ are opaque (non-arithmetic) sub-terms that are added /
--   subtracted; their children have already been 'normalize'd.
--
-- Only @+@, @-@ and multiplication by a literal on the left (@Num k * e@)
-- are interpreted; every other 'Op' is kept whole as an opaque term.
norm :: L -> ([(Int, Int)], Int, [L], [L])
norm expr = case expr of
  App a b -> opaque (App (normalize a) (normalize b))
  List a b -> opaque (List (normalize a) (normalize b))
  AAbs a b e -> opaque (AAbs a b (normalize e))
  Shift a b c -> opaque (Shift (normalize a) (normalize b) (normalize c))
  Aggr a b c -> opaque (Aggr a (normalize b) (normalize c))
  Unsized -> opaque Unsized
  Bottom -> opaque Bottom
  Abs i body -> opaque (Abs i (normalize body))
  Var i -> ([(1, i)], 0, [], [])
  Num i -> ([], i, [], [])
  Op lhs op rhs ->
    let (lc, lk, lpos, lneg) = norm lhs
        (rc, rk, rpos, rneg) = norm rhs
        scaleBy k (c, v) = (k * c, v)
    in case op of
         '+' -> (lc ++ rc, lk + rk, lpos ++ rpos, lneg ++ rneg)
         -- Subtracting swaps which opaque list the right operand feeds.
         '-' -> (subtractCoeffs lc rc, lk - rk, lpos ++ rneg, lneg ++ rpos)
         '*' | Num k <- lhs ->
               ( map (scaleBy k) rc
               , k * rk
               , map (Op (Num k) '*') rpos
               , map (Op (Num k) '*') rneg
               )
         _ -> opaque expr
  where
    -- A term the linear decomposition does not look inside.
    opaque t = ([], 0, [t], [])

    -- Subtract each (coefficient, variable) on the right from the list on
    -- the left, left to right.
    subtractCoeffs acc0 xs = foldl (\acc (c, v) -> subtractOne acc c v) acc0 xs

    -- Adjust the first entry for variable v, or append @(-c, v)@ if absent.
    subtractOne [] c v = [(negate c, v)]
    subtractOne ((c', v') : rest) c v
      | v == v' = (c' - c, v') : rest
      | otherwise = (c', v') : subtractOne rest c v
---|
135 | |
---|
-- | Sample expression @1 + (v0 - 1)@, convenient for trying 'normalize'
-- interactively. It is not referenced elsewhere in this module.
tnorm :: L
tnorm = Op (Num 1) '+' (Op (Var 0) '-' (Num 1))
---|
137 | |
---|
-- | Decompose each condition into simpler conditions and concatenate the
-- results. The same supply is handed to every top-level condition.
--
-- Rules, tried in order:
--
-- * @List a b = List p q@: emit @a = p@, and @b i = q i@ for a fresh
--   index @i@ with @0 <= i < a@ and @i < p@. NOTE(review): @a@/@p@ are
--   presumably lengths and @b@/@q@ element functions — confirm.
-- * @(Shift e f g) h = x@: split on @h < f@ (use @e h@) versus @h >= f@
--   (use @g (h - f)@); both sub-results are printed.
-- * @v x = v y@ with the same head variable: reduce to @x = y@. This is a
--   sufficient, not necessary, condition.
-- * @x = (Shift ...) h@: swap the sides.
-- * Otherwise: normalise both sides and the context and return it.
checkCond :: Supply Int -> [Condition] -> IO [Condition]
checkCond v l = concat <$> mapM (expand v) l
  where
    expand s (Condition d (List a b) (List p q)) = do
      let (s1, s2, s3) = split3 s
          idx = Var (supplyValue s1)
          b' = rall (App b idx)
          q' = rall (App q idx)
          nonNeg = idx `GEC` Num 0
      lengths <- expand s2 (Condition (nonNeg : d) a p)
      elems <- expand s3 (Condition (nonNeg : (idx `LTC` a) : (idx `LTC` p) : d) b' q')
      return (lengths ++ elems)
    expand s (Condition d (App (Shift e f g) h) x) = do
      let (s1, s2, _) = split3 s
          e' = rall (App e h)
          g' = rall (App g (Op h '-' f))
      below <- expand s1 (Condition ((h `LTC` f) : d) e' x)
      above <- expand s2 (Condition ((h `GEC` f) : d) g' x)
      putStrLn (" -> " ++ show below)
      putStrLn (" -> " ++ show above)
      return (below ++ above)
    expand s (Condition d (App (Var a) x) (App (Var b) y))
      | a == b = expand s (Condition d x y)
    expand s (Condition d x rhs@(App (Shift _ _ _) _)) =
      expand s (Condition d rhs x)
    expand _ (Condition d a b) =
      return [Condition (normalizecs d) (normalize a) (normalize b)]
---|
171 | |
---|
-- | @subst n e t@ replaces occurrences of @Var n@ in @t@ with @e@.
--
-- An 'Abs' whose bound variable is @n@ is returned untouched, so
-- occurrences under it are not replaced. The first field(s) of 'AAbs' and
-- 'Aggr' are copied unchanged.
--
-- NOTE(review): substitution is not capture-avoiding. If @e@ mentions a
-- variable bound by an enclosing 'Abs', that variable will be captured.
-- NOTE(review): if 'AAbs'/'Aggr' bind variables, they are not respected.
subst :: Int -> L -> L -> L
subst ndl hst = go
  where
    go term = case term of
      App a b -> App (go a) (go b)
      List a b -> List (go a) (go b)
      AAbs a b e -> AAbs a b (go e)
      Shift a b c -> Shift (go a) (go b) (go c)
      Aggr a b c -> Aggr a (go b) (go c)
      Unsized -> Unsized
      Bottom -> Bottom
      Abs i body
        | i == ndl -> term
        | otherwise -> Abs i (go body)
      Var i
        | i == ndl -> hst
        | otherwise -> term
      Num i -> Num i
      Op a c b -> Op (go a) c (go b)
---|
183 | |
---|
-- | Apply @subst n e@ to every expression inside a constraint.
substc :: Int -> L -> Constraint -> Constraint
substc ndl hst constraint = case constraint of
  Zero e -> Zero (sub e)
  LTC a b -> LTC (sub a) (sub b)
  GEC a b -> GEC (sub a) (sub b)
  where
    sub = subst ndl hst
---|
187 | |
---|
-- | Move relational constraints into the shape @e < 0@ / @e >= 0@.
--
-- Returns 'Nothing' when every constraint already has that shape, i.e. it
-- is a 'Zero' or a comparison whose right side is the literal @0@.
-- Otherwise every @a < b@ / @a >= b@ (including already-canonical ones)
-- becomes @normalize (a - b) < 0@ / @>= 0@. 'Zero' constraints are kept.
reorder :: [Constraint] -> Maybe [Constraint]
reorder cs
  | all canonical cs = Nothing
  | otherwise = Just (map toCanonical cs)
  where
    canonical (Zero _) = True
    canonical (LTC _ (Num 0)) = True
    canonical (GEC _ (Num 0)) = True
    canonical _ = False

    toCanonical (LTC a b) = LTC (normalize (Op a '-' b)) (Num 0)
    toCanonical (GEC a b) = GEC (normalize (Op a '-' b)) (Num 0)
    toCanonical other = other
---|
199 | |
---|
-- | Try to discharge every condition and return those that could not be
-- proved. An empty result means each condition was proved, or was shown
-- vacuous by contradictory preconditions.
--
-- Each condition is handled with its own sub-supply from 'split'.
-- Progress is logged to stdout.
solve :: [Condition] -> Supply Int -> IO [Condition]
solve l supply = concat <$> forM (zip l (split supply)) solveOne
  where
    solveOne (c, s) = do
      putStrLn $ "\nSOLVING " ++ show c
      solve1 s c

    -- First @Zero (Var v)@ constraint: return @v@ and the remaining
    -- constraints in their original order.
    searchzero (Zero (Var a) : xs) = Just (a, xs)
    searchzero (x : xs) = do
      (a, rest) <- searchzero xs
      return (a, x : rest)
    searchzero [] = Nothing

    -- Find a pair @v >= e@ and @v < e + 1@, which together pin @v = e@.
    -- Returns @(v, e, others)@, where @others@ holds every remaining
    -- constraint. @prev@ holds, reversed, the constraints already scanned.
    searcheq (q@(GEC (Var v) bound) : xs) prev =
      case findeq (List.reverse prev ++ xs) [] of
        Nothing -> searcheq xs (q : prev)
        found -> found
      where
        boundPlusOne = normalize $ Op bound '+' (Num 1)
        findeq [] _ = Nothing
        findeq (LTC (Var v2) bound2 : ys) prev2
          | v == v2 && bound2 == boundPlusOne =
              Just (v, bound, List.reverse prev2 ++ ys)
        findeq (y : ys) prev2 = findeq ys (y : prev2)
    searcheq (x : xs) prev = searcheq xs (x : prev)
    searcheq [] _ = Nothing

    -- Syntactic contradictions:
    -- * @a < b@ together with a later @a >= b@ (in either order);
    -- * a literal bound on @a@ followed later by @a = 0@ that violates it.
    checkConds [] = False
    checkConds (LTC a b : xs) | GEC a b `elem` xs = True
    checkConds (GEC a b : xs) | LTC a b `elem` xs = True
    checkConds (LTC a (Num b) : xs) | b <= 0 && Zero a `elem` xs = True
    checkConds (GEC a (Num b) : xs) | b > 0 && Zero a `elem` xs = True
    checkConds (_ : xs) = checkConds xs

    -- Evaluate literal-vs-literal comparisons. Returns 'Nothing' if one is
    -- false; otherwise the constraints with the true ones removed.
    checkConds2 (LTC (Num a) (Num b) : xs)
      | a >= b = Nothing
      | otherwise = checkConds2 xs
    checkConds2 (GEC (Num a) (Num b) : xs)
      | a < b = Nothing
      | otherwise = checkConds2 xs
    checkConds2 (x : xs) = (x :) <$> checkConds2 xs
    checkConds2 [] = Just []

    -- applyList a b d supp = do
    --   let (s1,s2,s3) = split3 supp
    --   let t = fresh (L V) s1
    --   let dd = (Condition d (rall $ App a t) (rall $ App b t))
    --   putStrLn $ "Applying a fresh variable: "++(show t) ++"\n"++(show dd)
    --   x <- checkCond s2 [dd] >>= mapM (solve1 s3)
    --   return $ concat x

    -- Drop literal comparisons that are trivially true, or give up early
    -- on one that is false. The filtered list d' is what continues; it
    -- used to be computed and then discarded in favour of d.
    solve1 supp (Condition d a b) = case checkConds2 d of
      Nothing -> do
        putStrLn "Contradiction in preconditions"
        return []
      Just d' -> solve1' supp (Condition d' a b)

    solve1' supp c@(Condition d a b)
      | checkConds d = do
          putStrLn "Contradiction in preconditions"
          return []
      | a == b = do
          putStrLn "Equals"
          return []
      | Just (v, rest) <- searchzero d = do
          let x = Condition (map (normalizec . substc v (Num 0)) rest)
                            (normalize $ subst v (Num 0) a)
                            (normalize $ subst v (Num 0) b)
          putStrLn $ show (Var v) ++ " is zero:\n" ++ show x
          solve1 supp x
      | Just (v, bound, rest) <- searcheq d [] = do
          putStrLn $ "Found equation " ++ show (Var v) ++ " = " ++ show bound
          let x = Condition (map (normalizec . substc v bound) rest)
                            (normalize $ subst v bound a)
                            (normalize $ subst v bound b)
          putStrLn $ "New equations:\n" ++ show x
          solve1 supp x
      | App p q <- a, App r s <- b = do
          putStrLn "Branching!"
          -- checkCond draws fresh index variables from its supply, and the
          -- sub-conditions it yields are solved (and may branch again,
          -- drawing more). Giving both the same supply let split3 (inside
          -- checkCond) and split (inside solve) hand out overlapping nodes,
          -- so a nested fresh index could reuse an outer one's id.
          let (checkSupp, solveSupp) = split2 supp
          nc <- checkCond checkSupp [Condition d p r, Condition d q s]
          solve nc solveSupp
      -- | Abs _ _ <- a, Abs _ _ <- b = do
      --     let (s1, s2) = split2 supp
      --     let t = fresh (tk a) s1
      --     let dd= (Condition d (rall $ App a t) (rall $ App b t))
      --     putStrLn $ "Applying a fresh variable: "++(show t) ++"\n"++(show dd)
      --     solve1 s2 dd
      -- | AAbs _ _ _ <- a, Abs _ _ <- b = applyList a b d supp
      -- | Abs _ _ <- a, AAbs _ _ _ <- b = applyList a b d supp
      -- | AAbs _ _ _ <- a, AAbs _ _ _ <- b = applyList a b d supp
      | Just dd <- reorder d = do
          putStrLn $ "Reorder " ++ show dd
          solve1 supp $ Condition dd a b
      | otherwise = do
          putStrLn "Tying to call solver"
          y <- SBV.prove (compiletosolver a b d)
          print y
          case y of
            SBV.ThmResult (SBV.Unsatisfiable _) -> return []
            _ -> return [c]
---|
304 | |
---|
-- | Variables of a constraint: the union of 'fv' over the expressions it
-- contains.
fvc :: Constraint -> Set.Set Int
fvc c = case c of
  Zero e -> fv e
  LTC x y -> fv x `Set.union` fv y
  GEC x y -> fv x `Set.union` fv y
---|
308 | |
---|
-- | Build the SBV predicate @d ==> a == b@.
--
-- Every variable occurring in @a@, @b@ or @d@ becomes a free symbolic
-- integer, named with 'showVar'. Each constraint in @d@ is asserted with
-- 'SBV.constrain'. The returned boolean is the equality of the two
-- compiled expressions.
compiletosolver :: L -> L -> [Constraint] -> SBV.Symbolic SBV.SBool
compiletosolver a b d = do
  let fvs = Set.toList (Set.unions (fv a : fv b : map fvc d))
  vars <- mapM (SBV.free . flip showVar "") fvs :: SBV.Symbolic [SBV.SInteger]
  let env = Map.fromList (zip fvs vars)
  mapM_ (SBV.constrain . compilec env) d
  return (compilel env a .== compilel env b)
---|
316 | |
---|
-- | Translate a constraint into an SBV boolean, using 'compilel' for its
-- expressions. Variables are looked up in the given environment.
compilec :: Map.Map Int SBV.SInteger -> Constraint -> SBV.SBool
compilec env c = case c of
  Zero e -> compilel env e .== (0 :: SBV.SInteger)
  LTC x y -> compilel env x .< compilel env y
  GEC x y -> compilel env x .>= compilel env y
---|
320 | |
---|
-- | Translate an arithmetic expression into an SBV integer.
--
-- Supported forms:
--
-- * 'Op' with @+@, @-@ or @*@;
-- * 'Var' present in the environment;
-- * 'Num' literals.
--
-- Anything else, including an 'Op' with another operator, fails with
-- @"Cannot compile <term>"@. Previously an unknown operator fell through
-- an incomplete @case@ and produced an uninformative pattern-match error.
compilel :: Map.Map Int SBV.SInteger -> L -> SBV.SInteger
compilel env term = case term of
  Op a '+' b -> go a + go b
  Op a '-' b -> go a - go b
  Op a '*' b -> go a * go b
  Var a | Just x <- Map.lookup a env -> x
  Num a -> SBV.literal (fromIntegral a)
  _ -> error $ "Cannot compile " ++ show term
  where
    go = compilel env
---|