// // jwt.go // Copyright (C) 2024 veypi // 2024-09-23 18:28 // Distributed under terms of the MIT license. // package auth import ( "errors" "fmt" "strings" "time" "github.com/golang-jwt/jwt/v5" "github.com/veypi/OneAuth/cfg" "github.com/vyes-ai/vigo" ) var ( AuthNotFound = vigo.NewError("auth not found").WithCode(40100) AuthFailed = vigo.NewError("auth failed").WithCode(40101) AuthExpired = vigo.NewError("auth expired").WithCode(40102) AuthInvalid = vigo.NewError("auth invalid").WithCode(40103) AuthNoPerm = vigo.NewError("auth no permission").WithCode(40104) ) func GenJwt(claim *Claims) (string, error) { return GenJwtWithKey(claim, cfg.Config.Key) } func GenJwtWithKey(claim *Claims, key string) (string, error) { if claim.ExpiresAt == nil { claim.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour)) } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claim) return token.SignedString([]byte(key)) } func ParseJwt(tokenString string, keys ...string) (*Claims, error) { key := cfg.Config.Key if len(keys) > 0 { key = keys[0] } claims := &Claims{} token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(key), nil }) if errors.Is(err, jwt.ErrTokenExpired) { return nil, AuthExpired } if err != nil || !token.Valid { return nil, AuthInvalid } return claims, nil } func checkJWT(x *vigo.X) (*Claims, error) { authHeader := x.Request.Header.Get("Authorization") if authHeader == "" { authHeader = x.Request.URL.Query().Get("Authorization") if authHeader == "" { return nil, AuthNotFound } } // Token is typically in the format "Bearer " tokenString := strings.TrimPrefix(authHeader, "Bearer ") // Parse the token claims, err := ParseJwt(tokenString) if err != nil { return nil, err } x.Set("token", claims) x.Set("uid", claims.UID) return claims, nil } func CheckJWT(x *vigo.X) (any, error) { return checkJWT(x) } type CustomCheckFunc = func(x *vigo.X, data any) bool func Check(target string, pid string, l AuthLevel, funcs ...CustomCheckFunc) func(x *vigo.X, data any) (any, error) { return func(x *vigo.X, data any) (any, error) { var err error claims, ok := x.Get("token").(*Claims) if !ok { claims, err = checkJWT(x) if err != nil { return nil, err } } tid := "" if strings.HasPrefix(pid, "@") { tid, _ = x.Get(pid[1:]).(string) } if strings.HasPrefix(pid, ":") { tid = x.Params.Get(pid[1:]) } if !claims.Access.Check(target, tid, l) { err = AuthNoPerm } for _, fn := range funcs { if fn(x, data) { return data, nil } } return data, err } }