You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

206 lines
5.5 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package middleware
import (
"fmt"
"net/http"
"runtime"
"runtime/debug"
"sort"
"strings"
"time"
"github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
)
// 用户信息类作为生成token的参数
type UserClaims struct {
ID string `json:"userId"`
Name string `json:"name"`
Phone string `json:"phone"`
//jwt-go提供的标准claim
jwt.StandardClaims
}
var (
//自定义的token秘钥
secret = []byte("16849841325189456f487")
//该路由下不校验token
noVerify = []string{
"/",
"/welcome",
"/wstest",
"/ldapi/v1/websocket/ccuwebsocket",
"/user/login",
"/ldapi/v1/web/api/service/auth/login",
"/ldapi/v1/web/api/service/heartbeat/ping",
"/ldapi/v1/app/api/service/auth/login",
"/ldapi/v1/app/api/service/heartbeat/ping",
"/ldapi/v1/web/api/service/ping",
}
//token有效时间纳秒
effectTime = 48 * time.Hour
)
// 生成token
func GenerateToken(claims *UserClaims) string {
//设置token有效期也可不设置有效期采用redis的方式
// 1)将token存储在redis中设置过期时间token如没过期则自动刷新redis过期时间
// 2)通过这种方式可以很方便的为token续期而且也可以实现长时间不登录的话强制登录
//本例只是简单采用 设置token有效期的方式只是提供了刷新token的方法并没有做续期处理的逻辑
claims.ExpiresAt = time.Now().Add(effectTime).Unix()
//生成token
sign, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(secret)
if err != nil {
//这里因为项目接入了统一异常处理所以使用panic并不会使程序终止如不接入可使用原始方式处理错误
//接入统一异常可参考 https://blog.csdn.net/u014155085/article/details/106733391
panic(err)
}
return sign
}
// 验证token
func JwtVerify(c *gin.Context) {
//过滤是否验证token
log.Print(c.Request.RequestURI)
if UriinArrayWithSort(c.Request.RequestURI, noVerify) {
return
}
token := c.GetHeader(viper.GetString("app.tokenKey"))
if token == "" {
panic("token not exist !")
}
//验证token并存储在请求中
//c.Set("user", parseToken(token))
}
// 解析Token
func parseToken(tokenString string) *UserClaims {
//解析token
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
return secret, nil
})
if err != nil {
panic(err)
}
claims, ok := token.Claims.(*UserClaims)
if !ok {
panic("token is valid")
}
return claims
}
// 获取用户ing
func GetUsernameFromToken(tokenString string) string {
//解析token
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
return secret, nil
})
if err != nil {
// panic(err)
fmt.Println(err)
return ""
}
claims, ok := token.Claims.(*UserClaims)
if !ok {
panic("token is valid")
}
return claims.Name
}
// 更新token
func Refresh(tokenString string) string {
jwt.TimeFunc = func() time.Time {
return time.Unix(0, 0)
}
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
return secret, nil
})
if err != nil {
panic(err)
}
claims, ok := token.Claims.(*UserClaims)
if !ok {
panic("token is valid")
}
jwt.TimeFunc = time.Now
claims.StandardClaims.ExpiresAt = time.Now().Add(2 * time.Hour).Unix()
return GenerateToken(claims)
}
func UriinArrayWithSort(target string, str_array []string) bool {
sort.Strings(str_array)
index := sort.SearchStrings(str_array, target)
log.Printf("index: %v\n", index)
//index的取值[0,len(str_array)]
if index < len(str_array) && (str_array[index] == target || strings.HasPrefix(target, "/user/login?redirect=")) {
//需要注意此处的判断,先判断 &&左侧的条件,如果不满足则结束此处判断,不会再进行右侧的判断
return true
}
return false
}
func UriinArray(target string, str_array []string) bool {
for _, element := range str_array {
if target == element {
return true
}
}
return false
}
func Recover(c *gin.Context) {
// 先声明map
var resultmap map[string]interface{}
// 再使用make函数创建一个非nil的mapnil map不能赋值
resultmap = make(map[string]interface{})
defer func() {
if r := recover(); r != nil {
// 获取堆栈信息
stackTrace := debug.Stack()
// 获取调用信息
pc, file, line, ok := runtime.Caller(1)
if ok {
fn := runtime.FuncForPC(pc)
log.Printf("panic: %v\n", r)
log.Printf("occurred in %s, %s:%d\nStack trace:\n%s", fn.Name(), file, line, stackTrace)
} else {
log.Printf("panic: %v\nStack trace:\n%s", r, stackTrace)
}
//打印错误堆栈信息
log.Printf("panic: %v\n", r)
debug.PrintStack()
//封装通用json返回
//c.JSON(http.StatusOK, Result.Fail(errorToString(r)))
//Result.Fail不是本例的重点因此用下面代码代替
resultmap["status"] = http.StatusUnauthorized
resultmap["success"] = true
resultmap["message"] = errorToString(r)
c.JSON(http.StatusUnauthorized, resultmap)
// c.Redirect(http.StatusFound, viper.GetString("app.contextPath"))
//终止后续接口调用不加的话recover到异常后还会继续执行接口里后续代码
c.Abort()
}
}()
//加载完 defer recover继续后续接口调用
c.Next()
}
// recover错误转string
func errorToString(r interface{}) string {
switch v := r.(type) {
case error:
return v.Error()
default:
return r.(string)
}
}