Implemented some form of LM optimization, but it's proper fucked somehow.
This commit is contained in:
@@ -1,12 +1,25 @@
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||
module Petzval.Optimization where
|
||||
|
||||
import Linear
|
||||
import Control.Monad
|
||||
import Control.Monad.Loops
|
||||
import Control.Monad.State
|
||||
import Control.Monad.Writer (MonadWriter, tell)
|
||||
import Control.Lens
|
||||
import Control.Lens.Unsound (adjoin)
|
||||
import Numeric.AD.Mode
|
||||
import Numeric.AD.Mode (auto)
|
||||
import Numeric.AD.Mode.Reverse.Double
|
||||
import Petzval.Optics
|
||||
import Petzval.Types
|
||||
import Control.Monad.Writer.Class
|
||||
|
||||
import Numeric.LinearAlgebra hiding (Element, dot)
|
||||
import qualified Numeric.LinearAlgebra as L
|
||||
|
||||
|
||||
-- | A set of modifiable parts of a lens system, expressed as a traversal.
|
||||
-- The recommended way of generating a variable set is with the following construction:
|
||||
@@ -28,6 +41,73 @@ gradientAt :: VariableSet -- ^ The set of independent variables
|
||||
-> (Double, [Double]) -- ^ The gradient
|
||||
gradientAt vars merit system = grad' (merit . setVars vars (system & each.liftFp %~ auto)) (extractVars vars system)
|
||||
|
||||
data DLSState = DLSState { _damping :: Double
|
||||
, _lastSum :: Double
|
||||
, _curPt :: Vector Double
|
||||
, _dampScale :: Double
|
||||
, _found :: Bool
|
||||
, _curIter :: Integer
|
||||
}
|
||||
|
||||
data DLSCfg = DLSCfg { eps1 :: Double -- ^ Cutoff for the derivative of the metric function
|
||||
, eps2 :: Double -- ^ Cutoff for no longer making progress
|
||||
, maxSteps :: Integer -- ^ Maximum number of steps to iterate
|
||||
}
|
||||
makeLenses ''DLSState
|
||||
|
||||
optimizeDLS :: (MonadWriter [[Double]] m -- ^ This yields a list of intermediate values of the merit function
|
||||
)
|
||||
=> DLSCfg
|
||||
-> VariableSet -- ^ The set of independent variables
|
||||
-> (forall a. Calcuable a => [Element mat a] -> [a]) -- ^ merit function
|
||||
-> [Element mat Double] -- ^ The system
|
||||
-> m [Element mat Double]
|
||||
|
||||
optimizeDLS cfg vars merit system = fmap (setVars vars system . toList) $ evalStateT doOptimize initialState
|
||||
where
|
||||
initialState = let pt = fromList $ extractVars vars system
|
||||
(y,j) = jacobianAt pt
|
||||
a :: Matrix Double = tr j L.<> j
|
||||
damping :: Double = 1e-3 * (maximum . toList . takeDiag) a
|
||||
lastSum = sum . map (^2) . merit $ system
|
||||
in DLSState { _damping=damping
|
||||
, _lastSum=lastSum
|
||||
, _curPt=pt
|
||||
, _dampScale = 2
|
||||
, _found = False
|
||||
, _curIter = maxSteps cfg
|
||||
}
|
||||
isDone = orM [use found, uses curIter (<0)]
|
||||
doOptimize = (untilM_ (curIter -= 1 >> optimizeStep) isDone) >> use curPt
|
||||
-- optimizeStep :: m Double, where the return value is the change in merit
|
||||
optimizeStep = do
|
||||
mu <- use damping
|
||||
lastPt <- use curPt
|
||||
let (y, a) = jacobianAt lastPt
|
||||
let g = tr a #> y
|
||||
let dPt :: Vector Double = -(inv $ tr a L.<> a + (scalar mu * ident (cols a)) ) #> g
|
||||
let newPt = lastPt + dPt
|
||||
let oldMerit = sumSq y
|
||||
let newMerit = sumSq . fromList . merit . setVars vars system . toList $ newPt
|
||||
let dL = (dPt `L.dot` (scalar mu * dPt - g)) / 2
|
||||
let gain = (oldMerit - newMerit) / dL
|
||||
if gain > 0
|
||||
then do curPt .= lastPt `add` dPt
|
||||
damping .= mu * max (1/3) (1 - (2*gain - 1) ^ 3)
|
||||
dampScale .= 2
|
||||
found ||= (norm_2 g <= eps1 cfg || norm_2 dPt <= eps2 cfg * (norm_2 lastPt + eps2 cfg))
|
||||
curPt .= newPt
|
||||
else do scale <- use dampScale
|
||||
dampScale *= 2
|
||||
damping *= scale
|
||||
|
||||
tell [toList y]
|
||||
|
||||
return $ newMerit - oldMerit
|
||||
jacobianAt :: Vector Double -> (Vector Double, Matrix Double)
|
||||
jacobianAt pt = let (y,a) = unzip . jacobian' (merit . setVars vars (system & each.liftFp %~ auto)) $ toList pt
|
||||
in (fromList y, fromLists a)
|
||||
|
||||
sumSq :: Vector Double -> Double
|
||||
sumSq x = L.dot x x
|
||||
-- instances
|
||||
|
||||
Reference in New Issue
Block a user