package redis import ( "context" "encoding/json" "fmt" "git.thequux.com/thequux/ipasso/backend" "github.com/redis/go-redis/v9" "net/url" "time" ) var redisHelp = ` Usage: redis://:@:/ redisu://:@?db= rediss://:@:/?addr=&addr=... The redis scheme connects to a normal Redis server over TCP the redisu scheme connects to a normal Redis server over a Unix domain socket the rediss scheme connects to a redis cluster ` func init() { backend.RegisterBackendFactory("redis", "Redis TCP", redisHelp, DialRedis) backend.RegisterBackendFactory("redisu", "Redis Unix", redisHelp, DialRedis) backend.RegisterBackendFactory("rediss", "Redis Cluster", redisHelp, DialRedis) } func DialRedis(url *url.URL) (be backend.Backend, err error) { var urlString = url.String() var client redis.UniversalClient if url.Scheme == "redis" || url.Scheme == "redisu" { var options *redis.Options options, err = redis.ParseURL(urlString) client = redis.NewClient(options) } else if url.Scheme == "rediss" { var options *redis.ClusterOptions options, err = redis.ParseClusterURL(urlString) client = redis.NewClusterClient(options) } if err == nil { be = &RedisBackend{ rdb: client, } } return } func sessionKey(id string) string { return "session:" + id } func scacheKey(id string) string { return "scache:" + id } type RedisBackend struct { rdb redis.UniversalClient } var putSessionScript = redis.NewScript(` local skey = KEYS[1] local ckey = KEYS[2] local session = ARGV[1] local sexp = ARGV[2] local cache = ARGV[3] local clifetime = ARGV[4] if redis.call("GET", skey) ~= "" then return false else redis.call("SET", skey, session, "EXAT", sexp) redis.call("SET", ckey, cache, "EX", 30) return true end `) func (r *RedisBackend) PutSession(ctx context.Context, session backend.Session, cache backend.SessionCache) error { jsonSession, err := json.Marshal(session) if err != nil { return err } jsonCache, err := json.Marshal(cache) if err != nil { return err } result, err := putSessionScript.Run(ctx, r.rdb, []string{sessionKey(session.SessionID), scacheKey(session.SessionID)}, jsonSession, session.Expiration.Unix(), jsonCache, 30, ).Bool() if err != nil { return err } else if result { return nil } else { return backend.ErrReservationSniped } } func (r *RedisBackend) GetSession(ctx context.Context, id string) (backend.Session, *backend.SessionCache, error) { var session backend.Session var cache backend.SessionCache var cachep = &cache result, err := r.rdb.MGet(ctx, sessionKey(id), scacheKey(id)).Result() if err != nil { return session, nil, err } fmt.Printf("Result: %#v\n", result) v, ok := result[0].(string) if !ok { return backend.Session{}, nil, backend.ErrBackendData } if err = json.Unmarshal([]byte(v), &session); err != nil { return backend.Session{}, nil, err } v, ok = result[1].(string) if ok { if err = json.Unmarshal([]byte(v), &cache); err != nil { return backend.Session{}, nil, err } } else { cachep = nil } return session, cachep, nil } func (r *RedisBackend) EndSession(ctx context.Context, id string) { r.rdb.Del(ctx, sessionKey(id), scacheKey(id)) } func (r *RedisBackend) ReserveSessionID(ctx context.Context, id string) (bool, error) { return r.rdb.SetNX(ctx, sessionKey(id), "", time.Second*60).Result() } func (r *RedisBackend) DoMaintenance(ctx context.Context) { // Redis handles cleaning up expired keys itself } func (r *RedisBackend) Ping(ctx context.Context) error { return r.rdb.Ping(ctx).Err() }