|
|
|
|
//
|
|
|
|
|
// jwt.go
|
|
|
|
|
// Copyright (C) 2024 veypi <i@veypi.com>
|
|
|
|
|
// 2024-09-23 18:28
|
|
|
|
|
// Distributed under terms of the MIT license.
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
package auth
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
"strings"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
|
|
|
"github.com/veypi/OneAuth/cfg"
|
|
|
|
|
"github.com/vyes/vigo"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
AuthNotFound = vigo.NewError("auth not found").WithCode(40100)
|
|
|
|
|
AuthFailed = vigo.NewError("auth failed").WithCode(40101)
|
|
|
|
|
AuthExpired = vigo.NewError("auth expired").WithCode(40102)
|
|
|
|
|
AuthInvalid = vigo.NewError("auth invalid").WithCode(40103)
|
|
|
|
|
AuthNoPerm = vigo.NewError("auth no permission").WithCode(40104)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func GenJwt(claim *Claims) (string, error) {
|
|
|
|
|
return GenJwtWithKey(claim, cfg.Config.Key)
|
|
|
|
|
}
|
|
|
|
|
func GenJwtWithKey(claim *Claims, key string) (string, error) {
|
|
|
|
|
if claim.ExpiresAt == nil {
|
|
|
|
|
claim.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour))
|
|
|
|
|
}
|
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claim)
|
|
|
|
|
return token.SignedString([]byte(key))
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func ParseJwt(tokenString string, keys ...string) (*Claims, error) {
|
|
|
|
|
key := cfg.Config.Key
|
|
|
|
|
if len(keys) > 0 {
|
|
|
|
|
key = keys[0]
|
|
|
|
|
}
|
|
|
|
|
claims := &Claims{}
|
|
|
|
|
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) {
|
|
|
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
|
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
|
|
|
}
|
|
|
|
|
return []byte(key), nil
|
|
|
|
|
})
|
|
|
|
|
if errors.Is(err, jwt.ErrTokenExpired) {
|
|
|
|
|
return nil, AuthExpired
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err != nil || !token.Valid {
|
|
|
|
|
return nil, AuthInvalid
|
|
|
|
|
}
|
|
|
|
|
return claims, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func checkJWT(x *vigo.X) (*Claims, error) {
|
|
|
|
|
authHeader := x.Request.Header.Get("Authorization")
|
|
|
|
|
if authHeader == "" {
|
|
|
|
|
authHeader = x.Request.URL.Query().Get("Authorization")
|
|
|
|
|
if authHeader == "" {
|
|
|
|
|
return nil, AuthNotFound
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Token is typically in the format "Bearer <token>"
|
|
|
|
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
|
|
|
|
|
|
|
|
|
// Parse the token
|
|
|
|
|
claims, err := ParseJwt(tokenString)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
x.Set("token", claims)
|
|
|
|
|
x.Set("uid", claims.UID)
|
|
|
|
|
|
|
|
|
|
return claims, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func CheckJWT(x *vigo.X) (any, error) {
|
|
|
|
|
return checkJWT(x)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// CustomCheckFunc is an extra authorization hook for Check. It receives the
// same x and data as the handler built by Check. Returning true grants access
// even when the claims-based permission check failed.
type CustomCheckFunc = func(*vigo.X, any) bool
|
|
|
|
|
|
|
|
|
|
func Check(target string, pid string, l AuthLevel, funcs ...CustomCheckFunc) func(x *vigo.X, data any) (any, error) {
|
|
|
|
|
return func(x *vigo.X, data any) (any, error) {
|
|
|
|
|
var err error
|
|
|
|
|
claims, ok := x.Get("token").(*Claims)
|
|
|
|
|
if !ok {
|
|
|
|
|
claims, err = checkJWT(x)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
tid := ""
|
|
|
|
|
if strings.HasPrefix(pid, "@") {
|
|
|
|
|
tid, _ = x.Get(pid[1:]).(string)
|
|
|
|
|
}
|
|
|
|
|
if strings.HasPrefix(pid, ":") {
|
|
|
|
|
tid = x.Params.Get(pid[1:])
|
|
|
|
|
}
|
|
|
|
|
if !claims.Access.Check(target, tid, l) {
|
|
|
|
|
err = AuthNoPerm
|
|
|
|
|
}
|
|
|
|
|
for _, fn := range funcs {
|
|
|
|
|
if fn(x, data) {
|
|
|
|
|
return data, nil
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return data, err
|
|
|
|
|
}
|
|
|
|
|
}
|