Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/Traq/Primitives.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ import Traq.Prelude
import qualified Traq.ProtoLang as P
import qualified Traq.Utils.Printing as PP

import Traq.Primitives.Search.DetSearch
import Traq.Primitives.Search.Prelude
import Traq.Primitives.Search.QSearchCFNW
import Traq.Primitives.Search.RandomSearch

data DefaultPrims = QAny QSearchCFNW | RAny RandomSearch
data DefaultPrims
= QAny QSearchCFNW
| RAny RandomSearch
| DAny DetSearch
deriving (Eq, Show, Read)

instance HasPrimAny DefaultPrims where
Expand All @@ -40,17 +44,20 @@ instance HasPrimSearch DefaultPrims where
instance PP.ToCodeString DefaultPrims where
build (QAny prim) = PP.build prim
build (RAny RandomSearch{predicate}) = PP.putWord $ printf "@any_rand[%s]" predicate
build (DAny DetSearch{predicate}) = PP.putWord $ printf "@any_det[%s]" predicate

-- Parsing
instance P.CanParsePrimitive DefaultPrims where
primitiveParser tp =
(QAny <$> P.primitiveParser tp)
<|> (RAny <$> parsePrimAny "any_rand" tp)
<|> (DAny <$> parsePrimAny "any_det" tp)

-- Type Checking
instance P.TypeCheckablePrimitive DefaultPrims sizeT where
typeCheckPrimitive (QAny prim) = P.typeCheckPrimitive prim
typeCheckPrimitive (RAny prim) = P.typeCheckPrimitive prim
typeCheckPrimitive (DAny prim) = P.typeCheckPrimitive prim

-- Evaluation
instance
Expand All @@ -59,6 +66,7 @@ instance
where
evalPrimitive (QAny prim) = P.evalPrimitive prim
evalPrimitive (RAny prim) = P.evalPrimitive prim
evalPrimitive (DAny prim) = P.evalPrimitive prim

-- Costs
instance
Expand All @@ -71,6 +79,7 @@ instance
where
unitaryQueryCostPrimitive delta (QAny prim) = P.unitaryQueryCostPrimitive delta prim
unitaryQueryCostPrimitive delta (RAny prim) = P.unitaryQueryCostPrimitive delta prim
unitaryQueryCostPrimitive delta (DAny prim) = P.unitaryQueryCostPrimitive delta prim

instance
( Integral sizeT
Expand All @@ -82,6 +91,7 @@ instance
where
quantumMaxQueryCostPrimitive delta (QAny prim) = P.quantumMaxQueryCostPrimitive delta prim
quantumMaxQueryCostPrimitive delta (RAny prim) = P.quantumMaxQueryCostPrimitive delta prim
quantumMaxQueryCostPrimitive delta (DAny prim) = P.quantumMaxQueryCostPrimitive delta prim

instance
( Integral sizeT
Expand All @@ -95,6 +105,7 @@ instance
where
quantumQueryCostPrimitive delta (QAny prim) = P.quantumQueryCostPrimitive delta prim
quantumQueryCostPrimitive delta (RAny prim) = P.quantumQueryCostPrimitive delta prim
quantumQueryCostPrimitive delta (DAny prim) = P.quantumQueryCostPrimitive delta prim

-- Lowering
instance
Expand Down
142 changes: 142 additions & 0 deletions src/Traq/Primitives/Search/DetSearch.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

module Traq.Primitives.Search.DetSearch (
DetSearch (..),
) where

import Control.Monad (forM)
import Control.Monad.Reader (runReaderT)
import Lens.Micro.GHC
import Lens.Micro.Mtl
import Text.Printf (printf)

import Traq.Control.Monad
import qualified Traq.Data.Context as Ctx
import qualified Traq.Data.Tree as Tree

import Traq.Prelude
import qualified Traq.ProtoLang as P
import qualified Traq.Utils.Printing as PP

import Traq.Primitives.Search.Prelude

-- ================================================================================
-- Primitive Implementation
-- ================================================================================

-- | Primitive implementing brute-force classical search.
newtype DetSearch = DetSearch {predicate :: Ident}
deriving (Eq, Show, Read)

instance HasPrimAny DetSearch where
mkAny = DetSearch
getPredicateOfAny = predicate

instance PP.ToCodeString DetSearch where
build DetSearch{predicate} = PP.putWord $ printf "@any[%s]" predicate

instance P.CanParsePrimitive DetSearch where
primitiveParser = parsePrimAny "any"

instance P.TypeCheckablePrimitive DetSearch sizeT where
typeCheckPrimitive = typeCheckPrimAny

instance (P.EvaluatablePrimitive primsT primsT) => P.EvaluatablePrimitive primsT DetSearch where
evalPrimitive = evaluatePrimAny

-- ================================================================================
-- Abstract Costs
-- ================================================================================

instance
( Integral sizeT
, Floating costT
, Show costT
, P.UnitaryCostablePrimitive primsT primsT sizeT costT
) =>
P.UnitaryCostablePrimitive primsT DetSearch sizeT costT
where
unitaryQueryCostPrimitive delta DetSearch{predicate} _ = do
P.FunDef{P.param_types} <- view $ P._funCtx . Ctx.at predicate . singular _Just
let P.Fin n = last param_types

-- precision per predicate call
let delta_per_pred_call = delta / fromIntegral n

-- cost of each predicate call
cost_pred <-
P.unitaryQueryCostE delta_per_pred_call $
P.FunCallE{P.fun_kind = P.FunctionCall predicate, P.args = undefined}

return $ fromIntegral n * cost_pred

instance
( Integral sizeT
, Floating costT
, Ord costT
, P.QuantumMaxCostablePrimitive primsT primsT sizeT costT
) =>
P.QuantumMaxCostablePrimitive primsT DetSearch sizeT costT
where
quantumMaxQueryCostPrimitive eps DetSearch{predicate} = do
P.FunDef{P.param_types} <- view $ P._funCtx . Ctx.at predicate . singular _Just
let P.Fin n = last param_types

-- fail prob per predicate call
let eps_per_pred_call = eps / fromIntegral n

-- cost of each predicate call
cost_pred_call <-
P.quantumMaxQueryCostE eps_per_pred_call $
P.FunCallE{P.fun_kind = P.FunctionCall predicate, P.args = undefined}

return $ fromIntegral n * cost_pred_call

instance
( Integral sizeT
, Floating costT
, Ord costT
, P.QuantumCostablePrimitive primsT primsT sizeT costT
, sizeT ~ SizeT
) =>
P.QuantumCostablePrimitive primsT DetSearch sizeT costT
where
quantumQueryCostPrimitive eps DetSearch{predicate} args = do
P.FunDef{P.param_types} <- view $ P._funCtx . Ctx.at predicate . singular _Just
let ty@(P.Fin n) = last param_types

-- fail prob per predicate call
let eps_per_pred_call = eps / fromIntegral n

let sigma = Ctx.fromList $ zip ["in" ++ show i | i <- [1 .. length args]] args

costs <- forM (P.range ty) $ \v -> do
let sigma' = sigma & Ctx.ins "x_s" .~ v
let pred_call_expr =
P.FunCallE
{ P.fun_kind = P.FunctionCall predicate
, P.args = Ctx.keys sigma'
}

-- cost of predicate on input `v`
cost_v <- P.quantumQueryCostE eps_per_pred_call sigma' pred_call_expr

-- evaluate predicate on `v` to check if it is a solution
eval_env <- view P._evaluationEnv
let is_sol =
P.evalExpr pred_call_expr sigma'
& (runReaderT ?? eval_env)
& Tree.detExtract
& head
& P.valueToBool

return (is_sol, cost_v)

-- average costs of a solution and a non-solution respectively
let (non_sols, sol_and_rest) = span fst costs & (each %~ map snd)
let sol_cost = case sol_and_rest of [] -> 0; (c : _) -> c

return $ sum non_sols + sol_cost
75 changes: 56 additions & 19 deletions tools/matrixsearchqcost.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -Wno-unused-top-binds #-}

module Main (main) where

Expand All @@ -20,10 +22,41 @@ import Traq.Primitives (DefaultPrims)
import qualified Traq.ProtoLang as P

import Traq.Examples.MatrixSearch
import Traq.Primitives.Search.DetSearch (DetSearch)
import Traq.Primitives.Search.Prelude (HasPrimAny)
import Traq.Primitives.Search.QSearchCFNW (QSearchCFNW)
import Traq.Primitives.Search.RandomSearch (RandomSearch)

printDivider :: IO ()
printDivider = putStrLn $ replicate 80 '='

-- | Data to box a type @p@
data Phantom p = Phantom

class
( P.CanParsePrimitive p
, P.QuantumCostablePrimitive p p SizeT Double
, HasPrimAny p
) =>
MyPrim p

instance MyPrim DefaultPrims
instance MyPrim RandomSearch
instance MyPrim QSearchCFNW
instance MyPrim DetSearch

defPrims :: Phantom DefaultPrims
defPrims = Phantom

randSearchP :: Phantom RandomSearch
randSearchP = Phantom

qSearchP :: Phantom QSearchCFNW
qSearchP = Phantom

detSearchP :: Phantom DetSearch
detSearchP = Phantom

type Value = P.Value SizeT

class MatrixType t where
Expand All @@ -39,26 +72,28 @@ instance MatrixType (Value, Value, Value -> Value -> Bool) where
toValueFun _ _ = error "unsupported"

instance MatrixType [[Value]] where
nRows mat = length mat
nCols mat = length $ head mat
nRows = length
nCols = length . head

toValueFun mat [P.FinV i, P.FinV j] = [mat !! fromIntegral i !! fromIntegral j]
toValueFun _ _ = error "unsupported"

-- | Get the input-dependent quantum query cost.
qcost ::
(MatrixType matT) =>
forall primsT matT.
(MatrixType matT, MyPrim primsT) =>
Phantom primsT ->
-- | eps (max. fail probability)
Double ->
-- | matrix
matT ->
Double
qcost eps mat = cost
qcost _ eps mat = cost
where
n = nRows mat
m = nCols mat

ex = matrixExampleS n m
ex = matrixExample @primsT n m P.tbool

dataCtx = Ctx.singleton "Oracle" (toValueFun mat)
ticks = mempty & at "Oracle" ?~ 1.0
Expand Down Expand Up @@ -88,24 +123,24 @@ randomMatrixWith (n, m) g z = do
let rows = bad_rows ++ replicate g (replicate m (P.FinV 1))
shuffleM rows

randomStat :: SizeT -> Double -> Int -> IO [Double]
randomStat nruns eps n =
randomStat :: (MyPrim primsT) => Phantom primsT -> SizeT -> Double -> Int -> IO [Double]
randomStat phantom nruns eps n =
replicateM nruns $ do
mat <- randomMatrix n n
return $ qcost eps mat
return $ qcost phantom eps mat

computeStatsForRandomMatrices :: IO ()
computeStatsForRandomMatrices =
computeStatsForRandomMatrices :: (MyPrim primsT) => Phantom primsT -> IO ()
computeStatsForRandomMatrices phantom =
withFile "examples/matrix_search/stats/qcost.csv" WriteMode $ \h -> do
hPutStrLn h "eps,n,cost"
forM_ [0.001, 0.0005, 0.0001] $ \eps -> do
forM_ [10, 20 .. 100] $ \n -> do
cs <- randomStat 20 eps n
cs <- randomStat phantom 20 eps n
forM cs $ \c -> do
hPutStrLn h $ printf "%f,%d,%.2f" eps n c

computeStatsForPlantedRandomMatrices :: IO ()
computeStatsForPlantedRandomMatrices =
computeStatsForPlantedRandomMatrices :: (MyPrim primsT) => Phantom primsT -> IO ()
computeStatsForPlantedRandomMatrices phantom =
withFile "examples/matrix_search/stats/datadep.csv" WriteMode $ \h -> do
hPutStrLn h "eps,n,m,good,zeros,cost"
let eps = 0.001
Expand All @@ -114,18 +149,19 @@ computeStatsForPlantedRandomMatrices =
forM_ [0 .. n `div` 4 + 1] $ \g -> do
forM_ [1] $ \z -> do
mat <- randomMatrixWith (n, m) g z
let c = qcost eps mat
let c = qcost phantom eps mat
hPutStrLn h $ printf "%f,%d,%d,%d,%d,%.2f" eps n m g z c

computeStatsForWorstCaseMatrices :: IO ()
computeStatsForWorstCaseMatrices =
computeStatsForWorstCaseMatrices :: (MyPrim primsT) => Phantom primsT -> IO ()
computeStatsForWorstCaseMatrices phantom =
withFile "examples/matrix_search/stats/worstcase.csv" WriteMode $ \h -> do
hPutStrLn h "n,cost"
let eps = 0.001
forM_ ((10 :: Int) : [500, 1000 .. 4000]) $ \n -> do
-- forM_ ((10 :: Int) : [500, 1000 .. 4000]) $ \n -> do
forM_ ((10 :: Int) : [20, 40, 100]) $ \n -> do
putStrLn $ ">> n = " <> show n
let m = n
let c = qcost eps (P.FinV n, P.FinV m, matfun)
let c = qcost phantom eps (P.FinV n, P.FinV m, matfun)
hPutStrLn h $ printf "%d,%.2f" n c
where
matfun :: Value -> Value -> Bool
Expand Down Expand Up @@ -176,7 +212,8 @@ main = do

-- computeStatsForRandomMatrices
-- computeStatsForPlantedRandomMatrices
timeIt computeStatsForWorstCaseMatrices
timeIt $ computeStatsForWorstCaseMatrices detSearchP
timeIt $ computeStatsForWorstCaseMatrices qSearchP
-- timeIt computeStatsForWorstCaseExample
-- timeIt triangular
putStrLn "done"
1 change: 1 addition & 0 deletions traq.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ library
Traq.Examples.Search
Traq.Prelude
Traq.Primitives
Traq.Primitives.Search.DetSearch
Traq.Primitives.Search.Prelude
Traq.Primitives.Search.QCount
Traq.Primitives.Search.QMax
Expand Down
Loading