// // jwt.go // Copyright (C) 2024 veypi // 2024-09-23 18:28 // Distributed under terms of the MIT license. // package auth import ( "errors" "fmt" "strings" "time" "github.com/golang-jwt/jwt/v5" "github.com/veypi/OneAuth/cfg" "github.com/vyes/vigo" ) var ( AuthNotFound = vigo.NewError("auth not found").WithCode(40100) AuthFailed = vigo.NewError("auth failed").WithCode(40101) AuthExpired = vigo.NewError("auth expired").WithCode(40102) AuthInvalid = vigo.NewError("auth invalid").WithCode(40103) AuthNoPerm = vigo.NewError("auth no permission").WithCode(40104) ) func GenJwt(claim *Claims) (string, error) { return GenJwtWithKey(claim, cfg.Config.Key) } func GenJwtWithKey(claim *Claims, key string) (string, error) { if claim.ExpiresAt == nil { claim.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour)) } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claim) return token.SignedString([]byte(key)) } func ParseJwt(tokenString string, keys ...string) (*Claims, error) { key := cfg.Config.Key if len(keys) > 0 { key = keys[0] } claims := &Claims{} token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(key), nil }) if errors.Is(err, jwt.ErrTokenExpired) { return nil, AuthExpired } if err != nil || !token.Valid { return nil, AuthInvalid } return claims, nil } func checkJWT(x *vigo.X) (*Claims, error) { authHeader := x.Request.Header.Get("Authorization") if authHeader == "" { authHeader = x.Request.URL.Query().Get("Authorization") if authHeader == "" { return nil, AuthNotFound } } // Token is typically in the format "Bearer " tokenString := strings.TrimPrefix(authHeader, "Bearer ") // Parse the token claims, err := ParseJwt(tokenString) if err != nil { return nil, err } x.Set("token", claims) x.Set("uid", claims.UID) return claims, nil } func CheckJWT(x *vigo.X) (any, error) { return checkJWT(x) } func Check(target string, pid string, l AuthLevel) func(x *vigo.X) (any, error) { return func(x *vigo.X) (any, error) { claims, err := checkJWT(x) if err != nil { return nil, err // return nil, err } tid := "" if pid != "" { tid = x.Params.Get(pid) } if !claims.Access.Check(target, tid, l) { return nil, AuthNoPerm } return claims, nil } }