You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
OneAuth/libs/auth/jwt.go

119 lines
2.7 KiB
Go

//
// jwt.go
// Copyright (C) 2024 veypi <i@veypi.com>
// 2024-09-23 18:28
// Distributed under terms of the MIT license.
//
package auth
import (
"errors"
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/veypi/OneAuth/cfg"
"github.com/vyes/vigo"
)
// Sentinel errors returned by the JWT helpers in this package. All codes share
// the 401xx prefix (presumably mirroring HTTP 401 Unauthorized — confirm
// against vigo's error-code-to-status mapping).
var (
// AuthNotFound is returned when the request carries no token at all
// (neither an Authorization header nor an Authorization query parameter).
AuthNotFound = vigo.NewError("auth not found").WithCode(40100)
// AuthFailed is not returned by any function shown here; presumably it is
// used elsewhere in the package for failed credential checks — verify.
AuthFailed = vigo.NewError("auth failed").WithCode(40101)
// AuthExpired is returned by ParseJwt when the token's expiry has passed.
AuthExpired = vigo.NewError("auth expired").WithCode(40102)
// AuthInvalid is returned by ParseJwt for any other parse, signature or
// validation failure.
AuthInvalid = vigo.NewError("auth invalid").WithCode(40103)
// AuthNoPerm is returned by Check when the token is valid but its access
// list does not grant the requested level and no custom check allows it.
AuthNoPerm = vigo.NewError("auth no permission").WithCode(40104)
)
// GenJwt signs claim with the service-wide key from cfg.Config.Key and
// returns the serialized token. Expiry defaulting and the signing algorithm
// are delegated to GenJwtWithKey.
func GenJwt(claim *Claims) (string, error) {
	serviceKey := cfg.Config.Key
	return GenJwtWithKey(claim, serviceKey)
}
// GenJwtWithKey signs claim as an HS256 JWT using key and returns the compact
// serialized token.
//
// If claim.ExpiresAt is nil it is set to one hour from now before signing.
// This mutates the caller's claim, so the chosen expiry is visible to the
// caller after the call returns.
//
// A nil claim is rejected with an error rather than causing a nil-pointer
// panic.
func GenJwtWithKey(claim *Claims, key string) (string, error) {
	if claim == nil {
		return "", errors.New("auth: cannot sign nil claims")
	}
	if claim.ExpiresAt == nil {
		claim.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour))
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claim)
	return token.SignedString([]byte(key))
}
// ParseJwt verifies tokenString and returns its decoded claims.
//
// The signature is checked against keys[0] when provided, otherwise against
// cfg.Config.Key; any further entries in keys are ignored. Only HMAC signing
// methods are accepted, so tokens declaring "none" or an asymmetric algorithm
// are refused before any key is handed out.
//
// Failures are mapped onto the package sentinels:
//   - AuthExpired when the library reports jwt.ErrTokenExpired,
//   - AuthInvalid for every other parse, signature or validation failure.
func ParseJwt(tokenString string, keys ...string) (*Claims, error) {
	secret := cfg.Config.Key
	if len(keys) != 0 {
		secret = keys[0]
	}

	// Keyfunc: refuse non-HMAC algorithms to prevent algorithm confusion.
	onlyHMAC := func(t *jwt.Token) (any, error) {
		if _, isHMAC := t.Method.(*jwt.SigningMethodHMAC); !isHMAC {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return []byte(secret), nil
	}

	claims := new(Claims)
	parsed, err := jwt.ParseWithClaims(tokenString, claims, onlyHMAC)
	switch {
	case errors.Is(err, jwt.ErrTokenExpired):
		return nil, AuthExpired
	case err != nil, !parsed.Valid:
		// The err != nil case is matched first, so parsed is never
		// dereferenced when parsing failed.
		return nil, AuthInvalid
	}
	return claims, nil
}
// checkJWT authenticates the request carried by x.
//
// The token is read from the Authorization header, falling back to the
// "Authorization" URL query parameter (presumably for clients that cannot set
// headers; note that query strings are commonly written to access logs).
// A leading "Bearer" scheme is stripped case-insensitively; a bare token is
// accepted as-is.
//
// On success the claims are stored on x under "token", the user ID under
// "uid", and the claims are returned. It returns AuthNotFound when neither
// source is present; otherwise it propagates ParseJwt's error.
func checkJWT(x *vigo.X) (*Claims, error) {
	raw := x.Request.Header.Get("Authorization")
	if raw == "" {
		raw = x.Request.URL.Query().Get("Authorization")
	}
	if raw == "" {
		return nil, AuthNotFound
	}

	claims, err := ParseJwt(bearerToken(raw))
	if err != nil {
		return nil, err
	}
	x.Set("token", claims)
	x.Set("uid", claims.UID)
	return claims, nil
}

// bearerToken extracts the credential from an Authorization value of the form
// "Bearer <token>". The scheme is matched case-insensitively (auth-schemes are
// case-insensitive per RFC 7235 §2.1) and surrounding whitespace is ignored.
// Values without the scheme are returned trimmed but otherwise unchanged, so
// raw tokens keep working.
func bearerToken(value string) string {
	const scheme = "bearer "
	value = strings.TrimSpace(value)
	if len(value) >= len(scheme) && strings.EqualFold(value[:len(scheme)], scheme) {
		return strings.TrimSpace(value[len(scheme):])
	}
	return value
}
// CheckJWT is the vigo-handler form of checkJWT: it authenticates the request,
// stores the claims on x, and returns them as data.
//
// On failure it returns a literal nil data value. Returning checkJWT's nil
// *Claims directly would wrap it in a non-nil interface, so callers testing
// `data != nil` would wrongly see a value.
func CheckJWT(x *vigo.X) (any, error) {
	claims, err := checkJWT(x)
	if err != nil {
		return nil, err
	}
	return claims, nil
}
// CustomCheckFunc is an extra authorization predicate for Check. It receives
// the request context and the handler's data; returning true grants access
// even when the claims-based permission check failed.
type CustomCheckFunc = func(x *vigo.X, data any) bool
// Check returns a vigo handler that authorizes the current request for level l
// on target.
//
// It reuses claims already stored on x under "token" (for example by an
// earlier CheckJWT) and otherwise authenticates the request itself, returning
// the authentication error if that fails.
//
// pid selects the resource ID handed to the access check:
//   - "@name" uses the string value stored on x under "name" (empty if absent
//     or not a string),
//   - ":name" uses the path parameter "name",
//   - anything else yields an empty ID.
//
// Access is granted when claims.Access.Check(target, id, l) reports true or
// when any of funcs returns true. funcs are consulted in order even when the
// claims check already passed. The handler always passes data through; on
// denial it pairs it with AuthNoPerm.
func Check(target string, pid string, l AuthLevel, funcs ...CustomCheckFunc) func(x *vigo.X, data any) (any, error) {
	return func(x *vigo.X, data any) (any, error) {
		claims, ok := x.Get("token").(*Claims)
		if !ok || claims == nil {
			// Not authenticated yet, or a nil *Claims was stored (which would
			// satisfy the type assertion and then panic below): authenticate now.
			fresh, err := checkJWT(x)
			if err != nil {
				return nil, err
			}
			claims = fresh
		}

		var tid string
		switch {
		case strings.HasPrefix(pid, "@"):
			// Comma-ok on purpose: a missing or non-string value leaves tid empty.
			tid, _ = x.Get(pid[1:]).(string)
		case strings.HasPrefix(pid, ":"):
			tid = x.Params.Get(pid[1:])
		}

		var denied error
		if !claims.Access.Check(target, tid, l) {
			denied = AuthNoPerm
		}
		for _, fn := range funcs {
			if fn(x, data) {
				return data, nil
			}
		}
		return data, denied
	}
}