session.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. package client
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "sort"
  6. "strings"
  7. "sync"
  8. "time"
  9. "github.com/jcmturner/gokrb5/v8/iana/nametype"
  10. "github.com/jcmturner/gokrb5/v8/krberror"
  11. "github.com/jcmturner/gokrb5/v8/messages"
  12. "github.com/jcmturner/gokrb5/v8/types"
  13. )
  14. // sessions hold TGTs and are keyed on the realm name
  15. type sessions struct {
  16. Entries map[string]*session
  17. mux sync.RWMutex
  18. }
  19. // destroy erases all sessions
  20. func (s *sessions) destroy() {
  21. s.mux.Lock()
  22. defer s.mux.Unlock()
  23. for k, e := range s.Entries {
  24. e.destroy()
  25. delete(s.Entries, k)
  26. }
  27. }
  28. // update replaces a session with the one provided or adds it as a new one
  29. func (s *sessions) update(sess *session) {
  30. s.mux.Lock()
  31. defer s.mux.Unlock()
  32. // if a session already exists for this, cancel its auto renew.
  33. if i, ok := s.Entries[sess.realm]; ok {
  34. if i != sess {
  35. // Session in the sessions cache is not the same as one provided.
  36. // Cancel the one in the cache and add this one.
  37. i.mux.Lock()
  38. defer i.mux.Unlock()
  39. if i.cancel != nil {
  40. i.cancel <- true
  41. }
  42. s.Entries[sess.realm] = sess
  43. return
  44. }
  45. }
  46. // No session for this realm was found so just add it
  47. s.Entries[sess.realm] = sess
  48. }
  49. // get returns the session for the realm specified
  50. func (s *sessions) get(realm string) (*session, bool) {
  51. s.mux.RLock()
  52. defer s.mux.RUnlock()
  53. sess, ok := s.Entries[realm]
  54. return sess, ok
  55. }
  56. // session holds the TGT details for a realm
  57. type session struct {
  58. realm string
  59. authTime time.Time
  60. endTime time.Time
  61. renewTill time.Time
  62. tgt messages.Ticket
  63. sessionKey types.EncryptionKey
  64. sessionKeyExpiration time.Time
  65. cancel chan bool
  66. mux sync.RWMutex
  67. }
  // jsonSession is the subset of a session's details exposed by sessions.JSON.
  // Its fields carry no json tags, so they are encoded under their Go names.
  type jsonSession struct {
      Realm                string
      AuthTime             time.Time
      EndTime              time.Time
      RenewTill            time.Time
      SessionKeyExpiration time.Time
  }
  76. // AddSession adds a session for a realm with a TGT to the client's session cache.
  77. // A goroutine is started to automatically renew the TGT before expiry.
  78. func (cl *Client) addSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
  79. if strings.ToLower(tgt.SName.NameString[0]) != "krbtgt" {
  80. // Not a TGT
  81. return
  82. }
  83. realm := tgt.SName.NameString[len(tgt.SName.NameString)-1]
  84. s := &session{
  85. realm: realm,
  86. authTime: dep.AuthTime,
  87. endTime: dep.EndTime,
  88. renewTill: dep.RenewTill,
  89. tgt: tgt,
  90. sessionKey: dep.Key,
  91. sessionKeyExpiration: dep.KeyExpiration,
  92. }
  93. cl.sessions.update(s)
  94. cl.enableAutoSessionRenewal(s)
  95. cl.Log("TGT session added for %s (EndTime: %v)", realm, dep.EndTime)
  96. }
  97. // update overwrites the session details with those from the TGT and decrypted encPart
  98. func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
  99. s.mux.Lock()
  100. defer s.mux.Unlock()
  101. s.authTime = dep.AuthTime
  102. s.endTime = dep.EndTime
  103. s.renewTill = dep.RenewTill
  104. s.tgt = tgt
  105. s.sessionKey = dep.Key
  106. s.sessionKeyExpiration = dep.KeyExpiration
  107. }
  108. // destroy will cancel any auto renewal of the session and set the expiration times to the current time
  109. func (s *session) destroy() {
  110. s.mux.Lock()
  111. defer s.mux.Unlock()
  112. if s.cancel != nil {
  113. s.cancel <- true
  114. }
  115. s.endTime = time.Now().UTC()
  116. s.renewTill = s.endTime
  117. s.sessionKeyExpiration = s.endTime
  118. }
  119. // valid informs if the TGT is still within the valid time window
  120. func (s *session) valid() bool {
  121. s.mux.RLock()
  122. defer s.mux.RUnlock()
  123. t := time.Now().UTC()
  124. if t.Before(s.endTime) && s.authTime.Before(t) {
  125. return true
  126. }
  127. return false
  128. }
  129. // tgtDetails is a thread safe way to get the session's realm, TGT and session key values
  130. func (s *session) tgtDetails() (string, messages.Ticket, types.EncryptionKey) {
  131. s.mux.RLock()
  132. defer s.mux.RUnlock()
  133. return s.realm, s.tgt, s.sessionKey
  134. }
  135. // timeDetails is a thread safe way to get the session's validity time values
  136. func (s *session) timeDetails() (string, time.Time, time.Time, time.Time, time.Time) {
  137. s.mux.RLock()
  138. defer s.mux.RUnlock()
  139. return s.realm, s.authTime, s.endTime, s.renewTill, s.sessionKeyExpiration
  140. }
  141. // JSON return information about the held sessions in a JSON format.
  142. func (s *sessions) JSON() (string, error) {
  143. s.mux.RLock()
  144. defer s.mux.RUnlock()
  145. var js []jsonSession
  146. keys := make([]string, 0, len(s.Entries))
  147. for k := range s.Entries {
  148. keys = append(keys, k)
  149. }
  150. sort.Strings(keys)
  151. for _, k := range keys {
  152. r, at, et, rt, kt := s.Entries[k].timeDetails()
  153. j := jsonSession{
  154. Realm: r,
  155. AuthTime: at,
  156. EndTime: et,
  157. RenewTill: rt,
  158. SessionKeyExpiration: kt,
  159. }
  160. js = append(js, j)
  161. }
  162. b, err := json.MarshalIndent(js, "", " ")
  163. if err != nil {
  164. return "", err
  165. }
  166. return string(b), nil
  167. }
  168. // enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
  169. func (cl *Client) enableAutoSessionRenewal(s *session) {
  170. var timer *time.Timer
  171. s.mux.Lock()
  172. s.cancel = make(chan bool, 1)
  173. s.mux.Unlock()
  174. go func(s *session) {
  175. for {
  176. s.mux.RLock()
  177. w := (s.endTime.Sub(time.Now().UTC()) * 5) / 6
  178. s.mux.RUnlock()
  179. if w < 0 {
  180. return
  181. }
  182. timer = time.NewTimer(w)
  183. select {
  184. case <-timer.C:
  185. renewal, err := cl.refreshSession(s)
  186. if err != nil {
  187. cl.Log("error refreshing session: %v", err)
  188. }
  189. if !renewal && err == nil {
  190. // end this goroutine as there will have been a new login and new auto renewal goroutine created.
  191. return
  192. }
  193. case <-s.cancel:
  194. // cancel has been called. Stop the timer and exit.
  195. timer.Stop()
  196. return
  197. }
  198. }
  199. }(s)
  200. }
  201. // renewTGT renews the client's TGT session.
  202. func (cl *Client) renewTGT(s *session) error {
  203. realm, tgt, skey := s.tgtDetails()
  204. spn := types.PrincipalName{
  205. NameType: nametype.KRB_NT_SRV_INST,
  206. NameString: []string{"krbtgt", realm},
  207. }
  208. _, tgsRep, err := cl.TGSREQGenerateAndExchange(spn, cl.Credentials.Domain(), tgt, skey, true)
  209. if err != nil {
  210. return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT for %s", realm)
  211. }
  212. s.update(tgsRep.Ticket, tgsRep.DecryptedEncPart)
  213. cl.sessions.update(s)
  214. cl.Log("TGT session renewed for %s (EndTime: %v)", realm, tgsRep.DecryptedEncPart.EndTime)
  215. return nil
  216. }
  217. // refreshSession updates either through renewal or creating a new login.
  218. // The boolean indicates if the update was a renewal.
  219. func (cl *Client) refreshSession(s *session) (bool, error) {
  220. s.mux.RLock()
  221. realm := s.realm
  222. renewTill := s.renewTill
  223. s.mux.RUnlock()
  224. cl.Log("refreshing TGT session for %s", realm)
  225. if time.Now().UTC().Before(renewTill) {
  226. err := cl.renewTGT(s)
  227. return true, err
  228. }
  229. err := cl.realmLogin(realm)
  230. return false, err
  231. }
  232. // ensureValidSession makes sure there is a valid session for the realm
  233. func (cl *Client) ensureValidSession(realm string) error {
  234. s, ok := cl.sessions.get(realm)
  235. if ok {
  236. s.mux.RLock()
  237. d := s.endTime.Sub(s.authTime) / 6
  238. if s.endTime.Sub(time.Now().UTC()) > d {
  239. s.mux.RUnlock()
  240. return nil
  241. }
  242. s.mux.RUnlock()
  243. _, err := cl.refreshSession(s)
  244. return err
  245. }
  246. return cl.realmLogin(realm)
  247. }
  248. // sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
  249. func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
  250. err = cl.ensureValidSession(realm)
  251. if err != nil {
  252. return
  253. }
  254. s, ok := cl.sessions.get(realm)
  255. if !ok {
  256. err = fmt.Errorf("could not find TGT session for %s", realm)
  257. return
  258. }
  259. _, tgt, sessionKey = s.tgtDetails()
  260. return
  261. }
  262. // sessionTimes provides the timing information with regards to a session for the realm specified.
  263. func (cl *Client) sessionTimes(realm string) (authTime, endTime, renewTime, sessionExp time.Time, err error) {
  264. s, ok := cl.sessions.get(realm)
  265. if !ok {
  266. err = fmt.Errorf("could not find TGT session for %s", realm)
  267. return
  268. }
  269. _, authTime, endTime, renewTime, sessionExp = s.timeDetails()
  270. return
  271. }
  272. // spnRealm resolves the realm name of a service principal name
  273. func (cl *Client) spnRealm(spn types.PrincipalName) string {
  274. return cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
  275. }