1- {-# LANGUAGE DataKinds #-}
2- {-# LANGUAGE DerivingStrategies #-}
3- {-# LANGUAGE OverloadedLabels #-}
4- {-# LANGUAGE OverloadedRecordDot #-}
5- {-# LANGUAGE OverloadedStrings #-}
6- {-# LANGUAGE PatternSynonyms #-}
7- {-# LANGUAGE RecordWildCards #-}
8- {-# LANGUAGE TemplateHaskell #-}
9- {-# LANGUAGE TypeFamilies #-}
10- {-# LANGUAGE UnicodeSyntax #-}
1+ {-# LANGUAGE BlockArguments #-}
2+ {-# LANGUAGE DataKinds #-}
3+ {-# LANGUAGE DerivingStrategies #-}
4+ {-# LANGUAGE ImpredicativeTypes #-}
5+ {-# LANGUAGE LiberalTypeSynonyms #-}
6+ {-# LANGUAGE MultiWayIf #-}
7+ {-# LANGUAGE OverloadedLabels #-}
8+ {-# LANGUAGE OverloadedRecordDot #-}
9+ {-# LANGUAGE OverloadedStrings #-}
10+ {-# LANGUAGE PatternSynonyms #-}
11+ {-# LANGUAGE QuantifiedConstraints #-}
12+ {-# LANGUAGE RecordWildCards #-}
13+ {-# LANGUAGE TemplateHaskell #-}
14+ {-# LANGUAGE TypeFamilies #-}
15+ {-# LANGUAGE UnicodeSyntax #-}
16+ {-# LANGUAGE ViewPatterns #-}
1117
1218-- |
1319-- This module provides the core functionality of the plugin.
14- module Ide.Plugin.SemanticTokens.Internal (semanticTokensFull , getSemanticTokensRule , semanticConfigProperties , semanticTokensFullDelta ) where
20+ module Ide.Plugin.SemanticTokens.Internal (semanticTokensFull , getSemanticTokensRule , getSyntacticTokensRule , semanticConfigProperties , semanticTokensFullDelta ) where
1521
1622import Control.Concurrent.STM (stateTVar )
1723import Control.Concurrent.STM.Stats (atomically )
@@ -21,20 +27,28 @@ import Control.Monad.Except (ExceptT, liftEither,
2127import Control.Monad.IO.Class (MonadIO (.. ))
2228import Control.Monad.Trans (lift )
2329import Control.Monad.Trans.Except (runExceptT )
30+ import Control.Monad.Trans.Maybe
31+ import Data.Data (Data (.. ))
32+ import Data.List
2433import qualified Data.Map.Strict as M
34+ import Data.Maybe
35+ import Data.Semigroup (First (.. ))
2536import Data.Text (Text )
2637import qualified Data.Text as T
2738import Development.IDE (Action ,
2839 GetDocMap (GetDocMap ),
2940 GetHieAst (GetHieAst ),
41+ GetParsedModuleWithComments (.. ),
3042 HieAstResult (HAR , hieAst , hieModule , refMap ),
3143 IdeResult , IdeState ,
3244 Priority (.. ),
3345 Recorder , Rules ,
3446 WithPriority ,
3547 cmapWithPrio , define ,
36- fromNormalizedFilePath ,
37- hieKind )
48+ hieKind ,
49+ srcSpanToRange ,
50+ toNormalizedUri ,
51+ useWithStale )
3852import Development.IDE.Core.PluginUtils (runActionE , useE ,
3953 useWithStaleE )
4054import Development.IDE.Core.Rules (toIdeResult )
@@ -44,10 +58,11 @@ import Development.IDE.Core.Shake (ShakeExtras (..),
4458 getVirtualFile )
4559import Development.IDE.GHC.Compat hiding (Warning )
4660import Development.IDE.GHC.Compat.Util (mkFastString )
61+ import GHC.Parser.Annotation
4762import GHC.Iface.Ext.Types (HieASTs (getAsts ),
4863 pattern HiePath )
4964import Ide.Logger (logWith )
50- import Ide.Plugin.Error (PluginError (PluginInternalError ),
65+ import Ide.Plugin.Error (PluginError (PluginInternalError , PluginRuleFailed ),
5166 getNormalizedFilePathE ,
5267 handleMaybe ,
5368 handleMaybeM )
@@ -61,10 +76,17 @@ import qualified Language.LSP.Protocol.Lens as L
6176import Language.LSP.Protocol.Message (MessageResult ,
6277 Method (Method_TextDocumentSemanticTokensFull , Method_TextDocumentSemanticTokensFullDelta ))
6378import Language.LSP.Protocol.Types (NormalizedFilePath ,
79+ Range ,
6480 SemanticTokens ,
81+ fromNormalizedFilePath ,
6582 type (|? ) (InL , InR ))
6683import Prelude hiding (span )
6784import qualified StmContainers.Map as STM
85+ import Type.Reflection (Typeable , eqTypeRep ,
86+ pattern App ,
87+ type (:~~: ) (HRefl ),
88+ typeOf , typeRep ,
89+ withTypeable )
6890
6991
7092$ mkSemanticConfigFunctions
@@ -78,8 +100,17 @@ computeSemanticTokens recorder pid _ nfp = do
78100 config <- lift $ useSemanticConfigAction pid
79101 logWith recorder Debug (LogConfig config)
80102 semanticId <- lift getAndIncreaseSemanticTokensId
81- (RangeHsSemanticTokenTypes {rangeSemanticList}, mapping) <- useWithStaleE GetSemanticTokens nfp
82- withExceptT PluginInternalError $ liftEither $ rangeSemanticsSemanticTokens semanticId config mapping rangeSemanticList
103+
104+ (sortOn fst -> tokenList, First mapping) <- do
105+ rangesyntacticTypes <- lift $ useWithStale GetSyntacticTokens nfp
106+ rangesemanticTypes <- lift $ useWithStale GetSemanticTokens nfp
107+ let mk w u (toks, mapping) = (map (fmap w) $ u toks, First mapping)
108+ maybeToExceptT (PluginRuleFailed " no syntactic nor semantic tokens" ) $ hoistMaybe $
109+ (mk HsSyntacticTokenType rangeSyntacticList <$> rangesyntacticTypes)
110+ <> (mk HsSemanticTokenType rangeSemanticList <$> rangesemanticTypes)
111+
112+ -- NOTE: rangeSemanticsSemanticTokens actually assumes that the tokesn are in order. that means they have to be sorted by position
113+ withExceptT PluginInternalError $ liftEither $ rangeSemanticsSemanticTokens semanticId config mapping tokenList
83114
84115semanticTokensFull :: Recorder (WithPriority SemanticLog ) -> PluginMethodHandler IdeState 'Method_TextDocumentSemanticTokensFull
85116semanticTokensFull recorder state pid param = runActionE " SemanticTokens.semanticTokensFull" state computeSemanticTokensFull
@@ -133,6 +164,87 @@ getSemanticTokensRule recorder =
133164 let hsFinder = idSemantic getTyThingMap (hieKindFunMasksKind hieKind) refMap
134165 return $ computeRangeHsSemanticTokenTypeList hsFinder virtualFile ast
135166
167+ getSyntacticTokensRule :: Recorder (WithPriority SemanticLog ) -> Rules ()
168+ getSyntacticTokensRule recorder =
169+ define (cmapWithPrio LogShake recorder) $ \ GetSyntacticTokens nfp -> handleError recorder $ do
170+ (parsedModule, _) <- withExceptT LogDependencyError $ useWithStaleE GetParsedModuleWithComments nfp
171+ let tokList = computeRangeHsSyntacticTokenTypeList parsedModule
172+ logWith recorder Debug $ LogSyntacticTokens tokList
173+ pure tokList
174+
175+ astTraversalWith :: forall b r . Data b => b -> (forall a . Data a => a -> [r ]) -> [r ]
176+ astTraversalWith ast f = mconcat $ flip gmapQ ast \ y -> f y <> astTraversalWith y f
177+
178+ {-# inline extractTyToTy #-}
179+ extractTyToTy :: forall f a . (Typeable f , Data a ) => a -> Maybe (forall r . (forall b . Typeable b => f b -> r ) -> r )
180+ extractTyToTy node
181+ | App conRep argRep <- typeOf node
182+ , Just HRefl <- eqTypeRep conRep (typeRep @ f )
183+ = Just $ withTypeable argRep $ (\ k -> k node)
184+ | otherwise = Nothing
185+
186+ {-# inline extractTy #-}
187+ extractTy :: forall b a . (Typeable b , Data a ) => a -> Maybe b
188+ extractTy node
189+ | Just HRefl <- eqTypeRep (typeRep @ b ) (typeOf node)
190+ = Just node
191+ | otherwise = Nothing
192+
193+ computeRangeHsSyntacticTokenTypeList :: ParsedModule -> RangeHsSyntacticTokenTypes
194+ computeRangeHsSyntacticTokenTypeList ParsedModule {pm_parsed_source} =
195+ let toks = astTraversalWith pm_parsed_source \ node -> mconcat
196+ [ maybeToList $ mkFromLocatable TKeyword . (\ k -> k \ x k' -> k' x) =<< extractTyToTy @ EpToken node
197+ -- FIXME: probably needs to be commented out for ghc > 9.10
198+ , maybeToList $ mkFromLocatable TKeyword . (\ x k -> k x) =<< extractTy @ AddEpAnn node
199+ , do
200+ EpAnnImportDecl i p s q pkg a <- maybeToList $ extractTy @ EpAnnImportDecl node
201+
202+ mapMaybe (mkFromLocatable TKeyword . (\ x k -> k x)) $ catMaybes $ [Just i, s, q, pkg, a] <> foldMap (\ (l, l') -> [Just l, Just l']) p
203+ , maybeToList $ mkFromLocatable TComment . (\ x k -> k x) =<< extractTy @ LEpaComment node
204+ , do
205+ L loc expr <- maybeToList $ extractTy @ (LHsExpr GhcPs ) node
206+ let fromSimple = maybeToList . flip mkFromLocatable \ k -> k loc
207+ case expr of
208+ HsOverLabel {} -> fromSimple TStringLit
209+ HsOverLit _ (OverLit _ lit) -> fromSimple case lit of
210+ HsIntegral {} -> TNumberLit
211+ HsFractional {} -> TNumberLit
212+
213+ HsIsString {} -> TStringLit
214+ HsLit _ lit -> fromSimple case lit of
215+ HsChar {} -> TCharLit
216+ HsCharPrim {} -> TCharLit
217+
218+ HsInt {} -> TNumberLit
219+ HsInteger {} -> TNumberLit
220+ HsIntPrim {} -> TNumberLit
221+ HsWordPrim {} -> TNumberLit
222+ HsWord8Prim {} -> TNumberLit
223+ HsWord16Prim {} -> TNumberLit
224+ HsWord32Prim {} -> TNumberLit
225+ HsWord64Prim {} -> TNumberLit
226+ HsInt8Prim {} -> TNumberLit
227+ HsInt16Prim {} -> TNumberLit
228+ HsInt32Prim {} -> TNumberLit
229+ HsInt64Prim {} -> TNumberLit
230+ HsFloatPrim {} -> TNumberLit
231+ HsDoublePrim {} -> TNumberLit
232+ HsRat {} -> TNumberLit
233+
234+ HsString {} -> TStringLit
235+ HsStringPrim {} -> TStringLit
236+ HsGetField _ _ field -> maybeToList $ mkFromLocatable TRecordSelector \ k -> k field
237+ HsProjection _ projs -> foldMap (\ proj -> maybeToList $ mkFromLocatable TRecordSelector \ k -> k proj) projs
238+ _ -> []
239+ ]
240+ in RangeHsSyntacticTokenTypes toks
241+
242+ {-# inline mkFromLocatable #-}
243+ mkFromLocatable
244+ :: HsSyntacticTokenType
245+ -> (forall r . (forall a . HasSrcSpan a => a -> r ) -> r )
246+ -> Maybe (Range , HsSyntacticTokenType )
247+ mkFromLocatable tt w = w \ tok -> let mrange = srcSpanToRange $ getLoc tok in fmap (, tt) mrange
136248
137249-- taken from /haskell-language-server/plugins/hls-code-range-plugin/src/Ide/Plugin/CodeRange/Rules.hs
138250
0 commit comments