114 lines
4.6 KiB
Haskell
114 lines
4.6 KiB
Haskell
{-# LANGUAGE TemplateHaskell #-}
|
|
{-# LANGUAGE FlexibleContexts #-}
|
|
{-# LANGUAGE NamedFieldPuns #-}
|
|
{-# LANGUAGE NoMonomorphismRestriction #-}
|
|
module Petzval.Optimization where
|
|
|
|
import Linear
|
|
import Control.Monad
|
|
import Control.Monad.Loops
|
|
import Control.Monad.State
|
|
import Control.Monad.Writer (MonadWriter, tell)
|
|
import Control.Lens
|
|
import Control.Lens.Unsound (adjoin)
|
|
import Numeric.AD.Mode (auto)
|
|
import Numeric.AD.Mode.Reverse.Double
|
|
import Petzval.Optics
|
|
import Petzval.Types
|
|
import Control.Monad.Writer.Class
|
|
|
|
import Numeric.LinearAlgebra hiding (Element, dot)
|
|
import qualified Numeric.LinearAlgebra as L
|
|
|
|
|
|
-- | A set of modifiable parts of a lens system, expressed as a traversal.
|
|
-- The recommended way of generating a variable set is with the following construction:
|
|
-- @
|
|
-- vars = ix 1 . roc `adjoin` ix 2 . thickness
|
|
-- @
|
|
-- The alias is universally quantified over both the material type @mat@ and the
-- scalar type @a@, so one variable set can be applied to systems of different
-- scalar types (it is used below both on 'Double' systems and on systems lifted
-- with 'auto' for automatic differentiation).
type VariableSet = forall mat a. Traversal' [Element mat a] a
|
|
-- | The automatic-differentiation scalar type: an alias for 'ReverseDouble'
-- (imported from "Numeric.AD.Mode.Reverse.Double"), parameterised by the type-level tag @s@.
type AdMode s = ReverseDouble s
|
|
|
|
-- | Read the current values of all variables selected by the variable set,
-- in traversal order.
extractVars :: VariableSet -> [Element mat a] -> [a]
extractVars vars system = view (partsOf vars) system
|
|
|
|
-- | Write a list of values back into the positions selected by the variable
-- set, returning the updated system. The values are consumed in traversal
-- order, i.e. the same order in which 'extractVars' produces them.
setVars :: VariableSet -> [Element mat a] -> [a] -> [Element mat a]
setVars vars system vals = set (partsOf vars) vals system
|
|
|
|
-- | Evaluate a scalar merit function together with its gradient with respect
-- to the selected variables, using reverse-mode automatic differentiation
-- ('grad''). The differentiation point is the current value of the variables
-- in the given system.
gradientAt
  :: VariableSet                                      -- ^ The set of independent variables
  -> (forall a. Calcuable a => [Element mat a] -> a)  -- ^ Merit function
  -> [Element mat Double]                             -- ^ The system
  -> (Double, [Double])                               -- ^ Result of 'grad'': the merit value paired with its gradient
gradientAt vars merit system =
  -- The objective is kept inline (not bound in the where clause) so that it
  -- stays polymorphic in the AD tape type that 'grad'' requires.
  grad' (\xs -> merit (setVars vars (over (each . liftFp) auto system) xs)) start
  where
    -- Current variable values: the point at which to differentiate.
    start = extractVars vars system
|
|
|
|
-- | State threaded through the damped-least-squares iteration in 'optimizeDLS'.
data DLSState = DLSState
  { _damping   :: Double         -- ^ Current damping factor added to the diagonal of the normal matrix
  , _lastSum   :: Double         -- ^ Sum of squared merit terms, recorded when the state is initialised
  , _curPt     :: Vector Double  -- ^ Current values of the independent variables
  , _dampScale :: Double         -- ^ Factor the damping is multiplied by on a rejected step; doubles on each rejection, reset to 2 on acceptance
  , _found     :: Bool           -- ^ Set once a convergence criterion has been met
  , _curIter   :: Integer        -- ^ Iteration counter, initialised from 'maxSteps' and decremented once per step
  }
|
|
|
|
-- | Termination settings for 'optimizeDLS'.
data DLSCfg = DLSCfg
  { eps1     :: Double   -- ^ Cutoff for the derivative of the merit function (compared against the 2-norm of the gradient)
  , eps2     :: Double   -- ^ Cutoff for no longer making progress (compared against the step length relative to the current point)
  , maxSteps :: Integer  -- ^ Maximum number of steps to iterate
  }
|
|
-- Template Haskell splice: generates the lenses for the underscore-prefixed
-- fields of 'DLSState' (damping, lastSum, curPt, dampScale, found, curIter)
-- that the optimizer uses to read and update its state.
makeLenses ''DLSState
|
|
|
|
-- | Optimise a system by damped least squares (Levenberg–Marquardt).
--
-- The merit function returns a list of residuals; the optimizer adjusts the
-- variables selected by the 'VariableSet' so as to reduce the sum of their
-- squares. Iteration stops when the gradient norm drops to 'eps1', when the
-- step becomes small relative to the current point ('eps2'), or after
-- 'maxSteps' steps. The loop condition is checked after each step, so at
-- least one step is always performed.
--
-- Through the 'MonadWriter' constraint, each step emits the residual vector
-- evaluated at the point the step started from, giving a trace of
-- intermediate merit values.
--
-- The result is the input system with the optimised variable values written
-- back into it.
optimizeDLS :: (MonadWriter [[Double]] m)
            => DLSCfg
            -> VariableSet -- ^ The set of independent variables
            -> (forall a. Calcuable a => [Element mat a] -> [a]) -- ^ merit function
            -> [Element mat Double] -- ^ The system
            -> m [Element mat Double]
optimizeDLS cfg vars merit system = fmap (setVars vars system . toList) $ evalStateT doOptimize initialState
  where
    -- Start at the current variable values; the initial damping is a small
    -- multiple of the largest diagonal entry of J^T J.
    initialState = let pt = fromList $ extractVars vars system
                       (_, j) = jacobianAt pt
                       a = tr j L.<> j :: Matrix Double
                       damping0 = 1e-3 * (maximum . toList . takeDiag) a :: Double
                       lastSum0 = sum . map (^2) . merit $ system
                   in DLSState { _damping = damping0
                               , _lastSum = lastSum0
                               , _curPt = pt
                               , _dampScale = 2
                               , _found = False
                               , _curIter = maxSteps cfg
                               }

    -- The counter is decremented before each step and this test runs after
    -- it, so stopping at @<= 0@ performs exactly 'maxSteps' steps (the
    -- previous @< 0@ test allowed one extra step).
    isDone = orM [use found, uses curIter (<= 0)]

    doOptimize = (untilM_ (curIter -= 1 >> optimizeStep) isDone) >> use curPt

    -- optimizeStep :: m Double, where the return value is the change in merit
    optimizeStep = do
      mu <- use damping
      lastPt <- use curPt
      let (y, a) = jacobianAt lastPt
      -- Gradient of (1/2) * sum of squares: J^T y.
      let g = tr a #> y
      -- Damped normal equations: (J^T J + mu I) dPt = -g.
      let dPt = negate (inv (tr a L.<> a + (scalar mu * ident (cols a))) #> g) :: Vector Double
      let newPt = lastPt + dPt
      let oldMerit = sumSq y
      let newMerit = sumSq . fromList . merit . setVars vars system . toList $ newPt
      -- Reduction predicted by the linearised model. The merits above are
      -- plain sums of squares (no factor 1/2), so the prediction must not
      -- carry the 1/2 either; otherwise the gain ratio is off by a factor
      -- of two and the damping update below is skewed.
      let predicted = dPt `L.dot` (scalar mu * dPt - g)
      let gain = (oldMerit - newMerit) / predicted
      if gain > 0
        then do -- Step accepted: relax the damping and test for convergence.
                damping .= mu * max (1/3) (1 - (2*gain - 1) ^ 3)
                dampScale .= 2
                found ||= (norm_2 g <= eps1 cfg || norm_2 dPt <= eps2 cfg * (norm_2 lastPt + eps2 cfg))
                curPt .= newPt
        else do -- Step rejected: increase the damping, more aggressively on
                -- each consecutive rejection.
                scale <- use dampScale
                dampScale *= 2
                damping *= scale

      -- Record the residuals at the point this step started from.
      tell [toList y]

      return $ newMerit - oldMerit

    -- Residual vector and Jacobian of the merit function with respect to the
    -- variables, evaluated at the given point by automatic differentiation.
    jacobianAt :: Vector Double -> (Vector Double, Matrix Double)
    jacobianAt pt = let (y,a) = unzip . jacobian' (merit . setVars vars (system & each.liftFp %~ auto)) $ toList pt
                    in (fromList y, fromLists a)

    -- Sum of squares of a vector's components.
    sumSq :: Vector Double -> Double
    sumSq x = L.dot x x
|
|
-- instances
|