add external redis username config to support redis6 ACL (#18364)

add external redis username o support redis6 ACL

Signed-off-by: yminer <yminer@vmware.com>
This commit is contained in:
MinerYang 2023-03-17 14:16:19 +08:00 committed by GitHub
parent 53d86f872e
commit e76aff6a0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 914 additions and 310 deletions

View File

@ -178,6 +178,8 @@ _version: 2.7.0
# # <host_sentinel1>:<port_sentinel1>,<host_sentinel2>:<port_sentinel2>,<host_sentinel3>:<port_sentinel3>
# host: redis:6379
# password:
# # Redis AUTH command was extended in Redis 6, it is possible to use it in the two-arguments AUTH <username> <password> form.
# # username:
# # sentinel_master_set must be set to support redis+sentinel
# #sentinel_master_set:
# # db_index 0 is for core, it's unchangeable

View File

@ -384,8 +384,9 @@ def get_redis_url(db, redis=None):
kwargs['db_part'] = db and ("/%s" % db) or ""
kwargs['sentinel_part'] = kwargs.get('sentinel_master_set', None) and ("/" + kwargs['sentinel_master_set']) or ''
kwargs['password_part'] = kwargs.get('password', None) and (':%s@' % kwargs['password']) or ''
kwargs['username_part'] = kwargs.get('username', None) or ''
return "{scheme}://{password_part}{host}{sentinel_part}{db_part}".format(**kwargs) + get_redis_url_param(kwargs)
return "{scheme}://{username_part}{password_part}{host}{sentinel_part}{db_part}".format(**kwargs) + get_redis_url_param(kwargs)
def get_redis_url_param(redis=None):

View File

@ -201,5 +201,6 @@ replace (
github.com/Azure/go-autorest => github.com/Azure/go-autorest v14.2.0+incompatible
github.com/docker/distribution => github.com/distribution/distribution v2.8.1+incompatible
github.com/goharbor/harbor => ../
github.com/gomodule/redigo => github.com/gomodule/redigo v1.8.8
google.golang.org/api => google.golang.org/api v0.0.0-20160322025152-9bf6e6e569ff
)

View File

@ -652,8 +652,8 @@ github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8l
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0=
github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
github.com/gomodule/redigo v1.8.8 h1:f6cXq6RRfiyrOJEV7p3JhLDlmawGBVBBP1MggY8Mo4E=
github.com/gomodule/redigo v1.8.8/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/certificate-transparency-go v1.0.21 h1:Yf1aXowfZ2nuboBsg7iYGLmwsOARdV86pfH3g95wXmE=

View File

@ -173,3 +173,5 @@
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS

View File

@ -12,32 +12,33 @@
// License for the specific language governing permissions and limitations
// under the License.
package internal // import "github.com/gomodule/redigo/internal"
package redis
import (
"strings"
)
const (
WatchState = 1 << iota
MultiState
SubscribeState
MonitorState
connectionWatchState = 1 << iota
connectionMultiState
connectionSubscribeState
connectionMonitorState
)
type CommandInfo struct {
type commandInfo struct {
// Set or Clear these states on connection.
Set, Clear int
}
var commandInfos = map[string]CommandInfo{
"WATCH": {Set: WatchState},
"UNWATCH": {Clear: WatchState},
"MULTI": {Set: MultiState},
"EXEC": {Clear: WatchState | MultiState},
"DISCARD": {Clear: WatchState | MultiState},
"PSUBSCRIBE": {Set: SubscribeState},
"SUBSCRIBE": {Set: SubscribeState},
"MONITOR": {Set: MonitorState},
var commandInfos = map[string]commandInfo{
"WATCH": {Set: connectionWatchState},
"UNWATCH": {Clear: connectionWatchState},
"MULTI": {Set: connectionMultiState},
"EXEC": {Clear: connectionWatchState | connectionMultiState},
"DISCARD": {Clear: connectionWatchState | connectionMultiState},
"PSUBSCRIBE": {Set: connectionSubscribeState},
"SUBSCRIBE": {Set: connectionSubscribeState},
"MONITOR": {Set: connectionMonitorState},
}
func init() {
@ -46,7 +47,7 @@ func init() {
}
}
func LookupCommandInfo(commandName string) CommandInfo {
func lookupCommandInfo(commandName string) commandInfo {
if ci, ok := commandInfos[commandName]; ok {
return ci
}

View File

@ -17,6 +17,7 @@ package redis
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
@ -74,15 +75,27 @@ type DialOption struct {
}
type dialOptions struct {
readTimeout time.Duration
writeTimeout time.Duration
dialer *net.Dialer
dial func(network, addr string) (net.Conn, error)
db int
password string
useTLS bool
skipVerify bool
tlsConfig *tls.Config
readTimeout time.Duration
writeTimeout time.Duration
tlsHandshakeTimeout time.Duration
dialer *net.Dialer
dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
db int
username string
password string
clientName string
useTLS bool
skipVerify bool
tlsConfig *tls.Config
}
// DialTLSHandshakeTimeout specifies the maximum amount of time waiting to
// wait for a TLS handshake. Zero means no timeout.
// If no DialTLSHandshakeTimeout option is specified then the default is 30 seconds.
func DialTLSHandshakeTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.tlsHandshakeTimeout = d
}}
}
// DialReadTimeout specifies the timeout for reading a single command reply.
@ -101,6 +114,7 @@ func DialWriteTimeout(d time.Duration) DialOption {
// DialConnectTimeout specifies the timeout for connecting to the Redis server when
// no DialNetDial option is specified.
// If no DialConnectTimeout option is specified then the default is 30 seconds.
func DialConnectTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.dialer.Timeout = d
@ -122,7 +136,18 @@ func DialKeepAlive(d time.Duration) DialOption {
// DialNetDial overrides DialConnectTimeout and DialKeepAlive.
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) {
do.dial = dial
do.dialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dial(network, addr)
}
}}
}
// DialContextFunc specifies a custom dial function with context for creating TCP
// connections, otherwise a net.Dialer customized via the other options is used.
// DialContextFunc overrides DialConnectTimeout and DialKeepAlive.
func DialContextFunc(f func(ctx context.Context, network, addr string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) {
do.dialContext = f
}}
}
@ -141,6 +166,23 @@ func DialPassword(password string) DialOption {
}}
}
// DialUsername specifies the username to use when connecting to
// the Redis server when Redis ACLs are used.
// A DialPassword must also be passed otherwise this option will have no effect.
func DialUsername(username string) DialOption {
return DialOption{func(do *dialOptions) {
do.username = username
}}
}
// DialClientName specifies a client name to be used
// by the Redis server connection.
func DialClientName(name string) DialOption {
return DialOption{func(do *dialOptions) {
do.clientName = name
}}
}
// DialTLSConfig specifies the config to use when a TLS connection is dialed.
// Has no effect when not dialing a TLS connection.
func DialTLSConfig(c *tls.Config) DialOption {
@ -168,19 +210,33 @@ func DialUseTLS(useTLS bool) DialOption {
// Dial connects to the Redis server at the given network and
// address using the specified options.
func Dial(network, address string, options ...DialOption) (Conn, error) {
return DialContext(context.Background(), network, address, options...)
}
type tlsHandshakeTimeoutError struct{}
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "TLS handshake timeout" }
// DialContext connects to the Redis server at the given network and
// address using the specified options and context.
func DialContext(ctx context.Context, network, address string, options ...DialOption) (Conn, error) {
do := dialOptions{
dialer: &net.Dialer{
Timeout: time.Second * 30,
KeepAlive: time.Minute * 5,
},
tlsHandshakeTimeout: time.Second * 10,
}
for _, option := range options {
option.f(&do)
}
if do.dial == nil {
do.dial = do.dialer.Dial
if do.dialContext == nil {
do.dialContext = do.dialer.DialContext
}
netConn, err := do.dial(network, address)
netConn, err := do.dialContext(ctx, network, address)
if err != nil {
return nil, err
}
@ -202,10 +258,22 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
}
tlsConn := tls.Client(netConn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
netConn.Close()
errc := make(chan error, 2) // buffered so we don't block timeout or Handshake
if d := do.tlsHandshakeTimeout; d != 0 {
timer := time.AfterFunc(d, func() {
errc <- tlsHandshakeTimeoutError{}
})
defer timer.Stop()
}
go func() {
errc <- tlsConn.Handshake()
}()
if err := <-errc; err != nil {
// Timeout or Handshake error.
netConn.Close() // nolint: errcheck
return nil, err
}
netConn = tlsConn
}
@ -218,7 +286,19 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
}
if do.password != "" {
if _, err := c.Do("AUTH", do.password); err != nil {
authArgs := make([]interface{}, 0, 2)
if do.username != "" {
authArgs = append(authArgs, do.username)
}
authArgs = append(authArgs, do.password)
if _, err := c.Do("AUTH", authArgs...); err != nil {
netConn.Close()
return nil, err
}
}
if do.clientName != "" {
if _, err := c.Do("CLIENT", "SETNAME", do.clientName); err != nil {
netConn.Close()
return nil, err
}
@ -236,10 +316,17 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`)
// DialURL connects to a Redis server at the given URL using the Redis
// DialURL wraps DialURLContext using context.Background.
func DialURL(rawurl string, options ...DialOption) (Conn, error) {
ctx := context.Background()
return DialURLContext(ctx, rawurl, options...)
}
// DialURLContext connects to a Redis server at the given URL using the Redis
// URI scheme. URLs should follow the draft IANA specification for the
// scheme (https://www.iana.org/assignments/uri-schemes/prov/redis).
func DialURL(rawurl string, options ...DialOption) (Conn, error) {
func DialURLContext(ctx context.Context, rawurl string, options ...DialOption) (Conn, error) {
u, err := url.Parse(rawurl)
if err != nil {
return nil, err
@ -249,6 +336,10 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme)
}
if u.Opaque != "" {
return nil, fmt.Errorf("invalid redis URL, url is opaque: %s", rawurl)
}
// As per the IANA draft spec, the host defaults to localhost and
// the port defaults to 6379.
host, port, err := net.SplitHostPort(u.Host)
@ -264,8 +355,18 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
if u.User != nil {
password, isSet := u.User.Password()
username := u.User.Username()
if isSet {
options = append(options, DialPassword(password))
if username != "" {
// ACL
options = append(options, DialUsername(username), DialPassword(password))
} else {
// requirepass - user-info username:password with blank username
options = append(options, DialPassword(password))
}
} else if username != "" {
// requirepass - redis-cli compatibility which treats as single arg in user-info as a password
options = append(options, DialPassword(username))
}
}
@ -287,7 +388,7 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
options = append(options, DialUseTLS(u.Scheme == "rediss"))
return Dial("tcp", address, options...)
return DialContext(ctx, "tcp", address, options...)
}
// NewConn returns a new Redigo connection for the given net connection.
@ -349,15 +450,23 @@ func (c *conn) writeLen(prefix byte, n int) error {
}
func (c *conn) writeString(s string) error {
c.writeLen('$', len(s))
c.bw.WriteString(s)
if err := c.writeLen('$', len(s)); err != nil {
return err
}
if _, err := c.bw.WriteString(s); err != nil {
return err
}
_, err := c.bw.WriteString("\r\n")
return err
}
func (c *conn) writeBytes(p []byte) error {
c.writeLen('$', len(p))
c.bw.Write(p)
if err := c.writeLen('$', len(p)); err != nil {
return err
}
if _, err := c.bw.Write(p); err != nil {
return err
}
_, err := c.bw.WriteString("\r\n")
return err
}
@ -371,7 +480,9 @@ func (c *conn) writeFloat64(n float64) error {
}
func (c *conn) writeCommand(cmd string, args []interface{}) error {
c.writeLen('*', 1+len(args))
if err := c.writeLen('*', 1+len(args)); err != nil {
return err
}
if err := c.writeString(cmd); err != nil {
return err
}
@ -427,10 +538,21 @@ func (pe protocolError) Error() string {
return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", string(pe))
}
// readLine reads a line of input from the RESP stream.
func (c *conn) readLine() ([]byte, error) {
// To avoid allocations, attempt to read the line using ReadSlice. This
// call typically succeeds. The known case where the call fails is when
// reading the output from the MONITOR command.
p, err := c.br.ReadSlice('\n')
if err == bufio.ErrBufferFull {
return nil, protocolError("long response line")
// The line does not fit in the bufio.Reader's buffer. Fall back to
// allocating a buffer for the line.
buf := append([]byte{}, p...)
for err == bufio.ErrBufferFull {
p, err = c.br.ReadSlice('\n')
buf = append(buf, p...)
}
p = buf
}
if err != nil {
return nil, err
@ -510,18 +632,18 @@ func (c *conn) readReply() (interface{}, error) {
}
switch line[0] {
case '+':
switch {
case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
switch string(line[1:]) {
case "OK":
// Avoid allocation for frequent "+OK" response.
return okReply, nil
case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
case "PONG":
// Avoid allocation in PING command benchmarks :)
return pongReply, nil
default:
return string(line[1:]), nil
}
case '-':
return Error(string(line[1:])), nil
return Error(line[1:]), nil
case ':':
return parseInt(line[1:])
case '$':
@ -562,7 +684,9 @@ func (c *conn) Send(cmd string, args ...interface{}) error {
c.pending += 1
c.mu.Unlock()
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
return c.fatal(err)
}
}
if err := c.writeCommand(cmd, args); err != nil {
return c.fatal(err)
@ -572,7 +696,9 @@ func (c *conn) Send(cmd string, args ...interface{}) error {
func (c *conn) Flush() error {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
return c.fatal(err)
}
}
if err := c.bw.Flush(); err != nil {
return c.fatal(err)
@ -584,12 +710,44 @@ func (c *conn) Receive() (interface{}, error) {
return c.ReceiveWithTimeout(c.readTimeout)
}
func (c *conn) ReceiveContext(ctx context.Context) (interface{}, error) {
var realTimeout time.Duration
if dl, ok := ctx.Deadline(); ok {
timeout := time.Until(dl)
if timeout >= c.readTimeout && c.readTimeout != 0 {
realTimeout = c.readTimeout
} else if timeout <= 0 {
return nil, c.fatal(context.DeadlineExceeded)
} else {
realTimeout = timeout
}
} else {
realTimeout = c.readTimeout
}
endch := make(chan struct{})
var r interface{}
var e error
go func() {
defer close(endch)
r, e = c.ReceiveWithTimeout(realTimeout)
}()
select {
case <-ctx.Done():
return nil, c.fatal(ctx.Err())
case <-endch:
return r, e
}
}
func (c *conn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
var deadline time.Time
if timeout != 0 {
deadline = time.Now().Add(timeout)
}
c.conn.SetReadDeadline(deadline)
if err := c.conn.SetReadDeadline(deadline); err != nil {
return nil, c.fatal(err)
}
if reply, err = c.readReply(); err != nil {
return nil, c.fatal(err)
@ -616,6 +774,36 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
return c.DoWithTimeout(c.readTimeout, cmd, args...)
}
func (c *conn) DoContext(ctx context.Context, cmd string, args ...interface{}) (interface{}, error) {
var realTimeout time.Duration
if dl, ok := ctx.Deadline(); ok {
timeout := time.Until(dl)
if timeout >= c.readTimeout && c.readTimeout != 0 {
realTimeout = c.readTimeout
} else if timeout <= 0 {
return nil, c.fatal(context.DeadlineExceeded)
} else {
realTimeout = timeout
}
} else {
realTimeout = c.readTimeout
}
endch := make(chan struct{})
var r interface{}
var e error
go func() {
defer close(endch)
r, e = c.DoWithTimeout(realTimeout, cmd, args...)
}()
select {
case <-ctx.Done():
return nil, c.fatal(ctx.Err())
case <-endch:
return r, e
}
}
func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
c.mu.Lock()
pending := c.pending
@ -627,7 +815,9 @@ func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...inte
}
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
return nil, c.fatal(err)
}
}
if cmd != "" {
@ -644,7 +834,9 @@ func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...inte
if readTimeout != 0 {
deadline = time.Now().Add(readTimeout)
}
c.conn.SetReadDeadline(deadline)
if err := c.conn.SetReadDeadline(deadline); err != nil {
return nil, c.fatal(err)
}
if cmd == "" {
reply := make([]interface{}, pending)

View File

@ -101,7 +101,7 @@
//
// Connections support one concurrent caller to the Receive method and one
// concurrent caller to the Send and Flush methods. No other concurrency is
// supported including concurrent calls to the Do method.
// supported including concurrent calls to the Do and Close methods.
//
// For full concurrent access to Redis, use the thread-safe Pool to get, use
// and release a connection from within a goroutine. Connections returned from
@ -174,4 +174,4 @@
// non-recoverable error such as a network error or protocol parsing error. If
// Err() returns a non-nil value, then the connection is not usable and should
// be closed.
package redis // import "github.com/gomodule/redigo/redis"
package redis

View File

@ -1,27 +0,0 @@
// +build !go1.7
package redis
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View File

@ -16,6 +16,7 @@ package redis
import (
"bytes"
"context"
"fmt"
"log"
"time"
@ -30,20 +31,29 @@ func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn {
if prefix != "" {
prefix = prefix + "."
}
return &loggingConn{conn, logger, prefix}
return &loggingConn{conn, logger, prefix, nil}
}
//NewLoggingConnFilter returns a logging wrapper around a connection and a filter function.
func NewLoggingConnFilter(conn Conn, logger *log.Logger, prefix string, skip func(cmdName string) bool) Conn {
if prefix != "" {
prefix = prefix + "."
}
return &loggingConn{conn, logger, prefix, skip}
}
type loggingConn struct {
Conn
logger *log.Logger
prefix string
skip func(cmdName string) bool
}
func (c *loggingConn) Close() error {
err := c.Conn.Close()
var buf bytes.Buffer
fmt.Fprintf(&buf, "%sClose() -> (%v)", c.prefix, err)
c.logger.Output(2, buf.String())
c.logger.Output(2, buf.String()) // nolint: errcheck
return err
}
@ -85,6 +95,9 @@ func (c *loggingConn) printValue(buf *bytes.Buffer, v interface{}) {
}
func (c *loggingConn) print(method, commandName string, args []interface{}, reply interface{}, err error) {
if c.skip != nil && c.skip(commandName) {
return
}
var buf bytes.Buffer
fmt.Fprintf(&buf, "%s%s(", c.prefix, method)
if method != "Receive" {
@ -100,7 +113,7 @@ func (c *loggingConn) print(method, commandName string, args []interface{}, repl
buf.WriteString(", ")
}
fmt.Fprintf(&buf, "%v)", err)
c.logger.Output(3, buf.String())
c.logger.Output(3, buf.String()) // nolint: errcheck
}
func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, error) {
@ -109,6 +122,12 @@ func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{},
return reply, err
}
func (c *loggingConn) DoContext(ctx context.Context, commandName string, args ...interface{}) (interface{}, error) {
reply, err := DoContext(c.Conn, ctx, commandName, args...)
c.print("DoContext", commandName, args, reply, err)
return reply, err
}
func (c *loggingConn) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (interface{}, error) {
reply, err := DoWithTimeout(c.Conn, timeout, commandName, args...)
c.print("DoWithTimeout", commandName, args, reply, err)
@ -127,6 +146,12 @@ func (c *loggingConn) Receive() (interface{}, error) {
return reply, err
}
func (c *loggingConn) ReceiveContext(ctx context.Context) (interface{}, error) {
reply, err := ReceiveContext(c.Conn, ctx)
c.print("ReceiveContext", "", nil, reply, err)
return reply, err
}
func (c *loggingConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) {
reply, err := ReceiveWithTimeout(c.Conn, timeout)
c.print("ReceiveWithTimeout", "", nil, reply, err)

View File

@ -16,16 +16,14 @@ package redis
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha1"
"errors"
"io"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/gomodule/redigo/internal"
)
var (
@ -41,7 +39,6 @@ var nowFunc = time.Now // for testing
var ErrPoolExhausted = errors.New("redigo: connection pool exhausted")
var (
errPoolClosed = errors.New("redigo: connection pool closed")
errConnClosed = errors.New("redigo: connection closed")
)
@ -58,6 +55,7 @@ var (
// return &redis.Pool{
// MaxIdle: 3,
// IdleTimeout: 240 * time.Second,
// // Dial or DialContext must be set. When both are set, DialContext takes precedence over Dial.
// Dial: func () (redis.Conn, error) { return redis.Dial("tcp", addr) },
// }
// }
@ -127,6 +125,13 @@ type Pool struct {
// (subscribed to pubsub channel, transaction started, ...).
Dial func() (Conn, error)
// DialContext is an application supplied function for creating and configuring a
// connection with the given context.
//
// The connection returned from Dial must not be in a special state
// (subscribed to pubsub channel, transaction started, ...).
DialContext func(ctx context.Context) (Conn, error)
// TestOnBorrow is an optional application supplied function for checking
// the health of an idle connection before the connection is used again by
// the application. Argument t is the time that the connection was returned
@ -154,18 +159,19 @@ type Pool struct {
// the pool does not close connections based on age.
MaxConnLifetime time.Duration
chInitialized uint32 // set to 1 when field ch is initialized
mu sync.Mutex // mu protects the following fields
closed bool // set to true when the pool is closed.
active int // the number of open connections in the pool
ch chan struct{} // limits open connections when p.Wait is true
idle idleList // idle connections
mu sync.Mutex // mu protects the following fields
closed bool // set to true when the pool is closed.
active int // the number of open connections in the pool
initOnce sync.Once // the init ch once func
ch chan struct{} // limits open connections when p.Wait is true
idle idleList // idle connections
waitCount int64 // total number of connections waited for.
waitDuration time.Duration // total time waited for new connections.
}
// NewPool creates a new pool.
//
// Deprecated: Initialize the Pool directory as shown in the example.
// Deprecated: Initialize the Pool directly as shown in the example.
func NewPool(newFn func() (Conn, error), maxIdle int) *Pool {
return &Pool{Dial: newFn, MaxIdle: maxIdle}
}
@ -176,11 +182,86 @@ func NewPool(newFn func() (Conn, error), maxIdle int) *Pool {
// getting an underlying connection, then the connection Err, Do, Send, Flush
// and Receive methods return that error.
func (p *Pool) Get() Conn {
pc, err := p.get(nil)
// GetContext returns errorConn in the first argument when an error occurs.
c, _ := p.GetContext(context.Background())
return c
}
// GetContext gets a connection using the provided context.
//
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Any expiration on the context
// will not affect the returned connection.
//
// If the function completes without error, then the application must close the
// returned connection.
func (p *Pool) GetContext(ctx context.Context) (Conn, error) {
// Wait until there is a vacant connection in the pool.
waited, err := p.waitVacantConn(ctx)
if err != nil {
return errorConn{err}
return errorConn{err}, err
}
return &activeConn{p: p, pc: pc}
p.mu.Lock()
if waited > 0 {
p.waitCount++
p.waitDuration += waited
}
// Prune stale connections at the back of the idle list.
if p.IdleTimeout > 0 {
n := p.idle.count
for i := 0; i < n && p.idle.back != nil && p.idle.back.t.Add(p.IdleTimeout).Before(nowFunc()); i++ {
pc := p.idle.back
p.idle.popBack()
p.mu.Unlock()
pc.c.Close()
p.mu.Lock()
p.active--
}
}
// Get idle connection from the front of idle list.
for p.idle.front != nil {
pc := p.idle.front
p.idle.popFront()
p.mu.Unlock()
if (p.TestOnBorrow == nil || p.TestOnBorrow(pc.c, pc.t) == nil) &&
(p.MaxConnLifetime == 0 || nowFunc().Sub(pc.created) < p.MaxConnLifetime) {
return &activeConn{p: p, pc: pc}, nil
}
pc.c.Close()
p.mu.Lock()
p.active--
}
// Check for pool closed before dialing a new connection.
if p.closed {
p.mu.Unlock()
err := errors.New("redigo: get on closed pool")
return errorConn{err}, err
}
// Handle limit for p.Wait == false.
if !p.Wait && p.MaxActive > 0 && p.active >= p.MaxActive {
p.mu.Unlock()
return errorConn{ErrPoolExhausted}, ErrPoolExhausted
}
p.active++
p.mu.Unlock()
c, err := p.dial(ctx)
if err != nil {
p.mu.Lock()
p.active--
if p.ch != nil && !p.closed {
p.ch <- struct{}{}
}
p.mu.Unlock()
return errorConn{err}, err
}
return &activeConn{p: p, pc: &poolConn{c: c, created: nowFunc()}}, nil
}
// PoolStats contains pool statistics.
@ -190,14 +271,24 @@ type PoolStats struct {
ActiveCount int
// IdleCount is the number of idle connections in the pool.
IdleCount int
// WaitCount is the total number of connections waited for.
// This value is currently not guaranteed to be 100% accurate.
WaitCount int64
// WaitDuration is the total time blocked waiting for a new connection.
// This value is currently not guaranteed to be 100% accurate.
WaitDuration time.Duration
}
// Stats returns pool's statistics.
func (p *Pool) Stats() PoolStats {
p.mu.Lock()
stats := PoolStats{
ActiveCount: p.active,
IdleCount: p.idle.count,
ActiveCount: p.active,
IdleCount: p.idle.count,
WaitCount: p.waitCount,
WaitDuration: p.waitDuration,
}
p.mu.Unlock()
@ -244,13 +335,7 @@ func (p *Pool) Close() error {
}
func (p *Pool) lazyInit() {
// Fast path.
if atomic.LoadUint32(&p.chInitialized) == 1 {
return
}
// Slow path.
p.mu.Lock()
if p.chInitialized == 0 {
p.initOnce.Do(func() {
p.ch = make(chan struct{}, p.MaxActive)
if p.closed {
close(p.ch)
@ -259,86 +344,59 @@ func (p *Pool) lazyInit() {
p.ch <- struct{}{}
}
}
atomic.StoreUint32(&p.chInitialized, 1)
}
p.mu.Unlock()
})
}
// get prunes stale connections and returns a connection from the idle list or
// creates a new connection.
func (p *Pool) get(ctx interface {
Done() <-chan struct{}
Err() error
}) (*poolConn, error) {
// Handle limit for p.Wait == true.
if p.Wait && p.MaxActive > 0 {
p.lazyInit()
if ctx == nil {
<-p.ch
} else {
select {
case <-p.ch:
case <-ctx.Done():
return nil, ctx.Err()
}
}
// waitVacantConn waits for a vacant connection in pool if waiting
// is enabled and pool size is limited, otherwise returns instantly.
// If ctx expires before that, an error is returned.
//
// If there were no vacant connection in the pool right away it returns the time spent waiting
// for that connection to appear in the pool.
func (p *Pool) waitVacantConn(ctx context.Context) (waited time.Duration, err error) {
if !p.Wait || p.MaxActive <= 0 {
// No wait or no connection limit.
return 0, nil
}
p.mu.Lock()
p.lazyInit()
// Prune stale connections at the back of the idle list.
if p.IdleTimeout > 0 {
n := p.idle.count
for i := 0; i < n && p.idle.back != nil && p.idle.back.t.Add(p.IdleTimeout).Before(nowFunc()); i++ {
pc := p.idle.back
p.idle.popBack()
p.mu.Unlock()
pc.c.Close()
p.mu.Lock()
p.active--
}
// wait indicates if we believe it will block so its not 100% accurate
// however for stats it should be good enough.
wait := len(p.ch) == 0
var start time.Time
if wait {
start = time.Now()
}
// Get idle connection from the front of idle list.
for p.idle.front != nil {
pc := p.idle.front
p.idle.popFront()
p.mu.Unlock()
if (p.TestOnBorrow == nil || p.TestOnBorrow(pc.c, pc.t) == nil) &&
(p.MaxConnLifetime == 0 || nowFunc().Sub(pc.created) < p.MaxConnLifetime) {
return pc, nil
}
pc.c.Close()
p.mu.Lock()
p.active--
}
// Check for pool closed before dialing a new connection.
if p.closed {
p.mu.Unlock()
return nil, errors.New("redigo: get on closed pool")
}
// Handle limit for p.Wait == false.
if !p.Wait && p.MaxActive > 0 && p.active >= p.MaxActive {
p.mu.Unlock()
return nil, ErrPoolExhausted
}
p.active++
p.mu.Unlock()
c, err := p.Dial()
if err != nil {
c = nil
p.mu.Lock()
p.active--
if p.ch != nil && !p.closed {
select {
case <-p.ch:
// Additionally check that context hasn't expired while we were waiting,
// because `select` picks a random `case` if several of them are "ready".
select {
case <-ctx.Done():
p.ch <- struct{}{}
return 0, ctx.Err()
default:
}
p.mu.Unlock()
case <-ctx.Done():
return 0, ctx.Err()
}
return &poolConn{c: c, created: nowFunc()}, err
if wait {
return time.Since(start), nil
}
return 0, nil
}
func (p *Pool) dial(ctx context.Context) (Conn, error) {
if p.DialContext != nil {
return p.DialContext(ctx)
}
if p.Dial != nil {
return p.Dial()
}
return nil, errors.New("redigo: must pass Dial or DialContext to pool")
}
func (p *Pool) put(pc *poolConn, forceClose bool) error {
@ -385,48 +443,65 @@ func initSentinel() {
sentinel = p
} else {
h := sha1.New()
io.WriteString(h, "Oops, rand failed. Use time instead.")
io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10))
io.WriteString(h, "Oops, rand failed. Use time instead.") // nolint: errcheck
io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10)) // nolint: errcheck
sentinel = h.Sum(nil)
}
}
func (ac *activeConn) Close() error {
func (ac *activeConn) firstError(errs ...error) error {
for _, err := range errs[:len(errs)-1] {
if err != nil {
return err
}
}
return errs[len(errs)-1]
}
func (ac *activeConn) Close() (err error) {
pc := ac.pc
if pc == nil {
return nil
}
ac.pc = nil
if ac.state&internal.MultiState != 0 {
pc.c.Send("DISCARD")
ac.state &^= (internal.MultiState | internal.WatchState)
} else if ac.state&internal.WatchState != 0 {
pc.c.Send("UNWATCH")
ac.state &^= internal.WatchState
if ac.state&connectionMultiState != 0 {
err = pc.c.Send("DISCARD")
ac.state &^= (connectionMultiState | connectionWatchState)
} else if ac.state&connectionWatchState != 0 {
err = pc.c.Send("UNWATCH")
ac.state &^= connectionWatchState
}
if ac.state&internal.SubscribeState != 0 {
pc.c.Send("UNSUBSCRIBE")
pc.c.Send("PUNSUBSCRIBE")
if ac.state&connectionSubscribeState != 0 {
err = ac.firstError(err,
pc.c.Send("UNSUBSCRIBE"),
pc.c.Send("PUNSUBSCRIBE"),
)
// To detect the end of the message stream, ask the server to echo
// a sentinel value and read until we see that value.
sentinelOnce.Do(initSentinel)
pc.c.Send("ECHO", sentinel)
pc.c.Flush()
err = ac.firstError(err,
pc.c.Send("ECHO", sentinel),
pc.c.Flush(),
)
for {
p, err := pc.c.Receive()
if err != nil {
p, err2 := pc.c.Receive()
if err2 != nil {
err = ac.firstError(err, err2)
break
}
if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
ac.state &^= internal.SubscribeState
ac.state &^= connectionSubscribeState
break
}
}
}
pc.c.Do("")
ac.p.put(pc, ac.state != 0 || pc.c.Err() != nil)
return nil
_, err2 := pc.c.Do("")
return ac.firstError(
err,
err2,
ac.p.put(pc, ac.state != 0 || pc.c.Err() != nil),
)
}
func (ac *activeConn) Err() error {
@ -437,12 +512,26 @@ func (ac *activeConn) Err() error {
return pc.c.Err()
}
func (ac *activeConn) DoContext(ctx context.Context, commandName string, args ...interface{}) (reply interface{}, err error) {
pc := ac.pc
if pc == nil {
return nil, errConnClosed
}
cwt, ok := pc.c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
ci := lookupCommandInfo(commandName)
ac.state = (ac.state | ci.Set) &^ ci.Clear
return cwt.DoContext(ctx, commandName, args...)
}
func (ac *activeConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
pc := ac.pc
if pc == nil {
return nil, errConnClosed
}
ci := internal.LookupCommandInfo(commandName)
ci := lookupCommandInfo(commandName)
ac.state = (ac.state | ci.Set) &^ ci.Clear
return pc.c.Do(commandName, args...)
}
@ -456,7 +545,7 @@ func (ac *activeConn) DoWithTimeout(timeout time.Duration, commandName string, a
if !ok {
return nil, errTimeoutNotSupported
}
ci := internal.LookupCommandInfo(commandName)
ci := lookupCommandInfo(commandName)
ac.state = (ac.state | ci.Set) &^ ci.Clear
return cwt.DoWithTimeout(timeout, commandName, args...)
}
@ -466,7 +555,7 @@ func (ac *activeConn) Send(commandName string, args ...interface{}) error {
if pc == nil {
return errConnClosed
}
ci := internal.LookupCommandInfo(commandName)
ci := lookupCommandInfo(commandName)
ac.state = (ac.state | ci.Set) &^ ci.Clear
return pc.c.Send(commandName, args...)
}
@ -487,6 +576,18 @@ func (ac *activeConn) Receive() (reply interface{}, err error) {
return pc.c.Receive()
}
func (ac *activeConn) ReceiveContext(ctx context.Context) (reply interface{}, err error) {
pc := ac.pc
if pc == nil {
return nil, errConnClosed
}
cwt, ok := pc.c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
return cwt.ReceiveContext(ctx)
}
func (ac *activeConn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
pc := ac.pc
if pc == nil {
@ -502,6 +603,9 @@ func (ac *activeConn) ReceiveWithTimeout(timeout time.Duration) (reply interface
type errorConn struct{ err error }
func (ec errorConn) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err }
func (ec errorConn) DoContext(context.Context, string, ...interface{}) (interface{}, error) {
return nil, ec.err
}
func (ec errorConn) DoWithTimeout(time.Duration, string, ...interface{}) (interface{}, error) {
return nil, ec.err
}
@ -510,6 +614,7 @@ func (ec errorConn) Err() error { ret
func (ec errorConn) Close() error { return nil }
func (ec errorConn) Flush() error { return ec.err }
func (ec errorConn) Receive() (interface{}, error) { return nil, ec.err }
func (ec errorConn) ReceiveContext(context.Context) (interface{}, error) { return nil, ec.err }
func (ec errorConn) ReceiveWithTimeout(time.Duration) (interface{}, error) { return nil, ec.err }
type idleList struct {
@ -534,7 +639,6 @@ func (l *idleList) pushFront(pc *poolConn) {
}
l.front = pc
l.count++
return
}
func (l *idleList) popFront() {

View File

@ -1,35 +0,0 @@
// Copyright 2018 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// +build go1.7
package redis
import "context"
// GetContext gets a connection using the provided context.
//
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Any expiration on the context
// will not affect the returned connection.
//
// If the function completes without error, then the application must close the
// returned connection.
func (p *Pool) GetContext(ctx context.Context) (Conn, error) {
pc, err := p.get(ctx)
if err != nil {
return errorConn{err}, err
}
return &activeConn{p: p, pc: pc}, nil
}

View File

@ -60,27 +60,35 @@ func (c PubSubConn) Close() error {
// Subscribe subscribes the connection to the specified channels.
func (c PubSubConn) Subscribe(channel ...interface{}) error {
c.Conn.Send("SUBSCRIBE", channel...)
if err := c.Conn.Send("SUBSCRIBE", channel...); err != nil {
return err
}
return c.Conn.Flush()
}
// PSubscribe subscribes the connection to the given patterns.
func (c PubSubConn) PSubscribe(channel ...interface{}) error {
c.Conn.Send("PSUBSCRIBE", channel...)
if err := c.Conn.Send("PSUBSCRIBE", channel...); err != nil {
return err
}
return c.Conn.Flush()
}
// Unsubscribe unsubscribes the connection from the given channels, or from all
// of them if none is given.
func (c PubSubConn) Unsubscribe(channel ...interface{}) error {
c.Conn.Send("UNSUBSCRIBE", channel...)
if err := c.Conn.Send("UNSUBSCRIBE", channel...); err != nil {
return err
}
return c.Conn.Flush()
}
// PUnsubscribe unsubscribes the connection from the given patterns, or from all
// of them if none is given.
func (c PubSubConn) PUnsubscribe(channel ...interface{}) error {
c.Conn.Send("PUNSUBSCRIBE", channel...)
if err := c.Conn.Send("PUNSUBSCRIBE", channel...); err != nil {
return err
}
return c.Conn.Flush()
}
@ -89,7 +97,9 @@ func (c PubSubConn) PUnsubscribe(channel ...interface{}) error {
// The connection must be subscribed to at least one channel or pattern when
// calling this method.
func (c PubSubConn) Ping(data string) error {
c.Conn.Send("PING", data)
if err := c.Conn.Send("PING", data); err != nil {
return err
}
return c.Conn.Flush()
}

View File

@ -15,6 +15,7 @@
package redis
import (
"context"
"errors"
"time"
)
@ -33,6 +34,7 @@ type Conn interface {
Err() error
// Do sends a command to the server and returns the received reply.
// This function will use the timeout which was set when the connection is created
Do(commandName string, args ...interface{}) (reply interface{}, err error)
// Send writes the command to the client's output buffer.
@ -82,17 +84,52 @@ type Scanner interface {
type ConnWithTimeout interface {
Conn
// Do sends a command to the server and returns the received reply.
// The timeout overrides the read timeout set when dialing the
// connection.
// DoWithTimeout sends a command to the server and returns the received reply.
// The timeout overrides the readtimeout set when dialing the connection.
DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error)
// Receive receives a single reply from the Redis server. The timeout
// overrides the read timeout set when dialing the connection.
// ReceiveWithTimeout receives a single reply from the Redis server.
// The timeout overrides the readtimeout set when dialing the connection.
ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error)
}
// ConnWithContext is an optional interface that allows the caller to control the command's life with context.
type ConnWithContext interface {
Conn
// DoContext sends a command to server and returns the received reply.
// min(ctx,DialReadTimeout()) will be used as the deadline.
// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running.
// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout").
// ctx timeout return err context.DeadlineExceeded.
// ctx canceled return err context.Canceled.
DoContext(ctx context.Context, commandName string, args ...interface{}) (reply interface{}, err error)
// ReceiveContext receives a single reply from the Redis server.
// min(ctx,DialReadTimeout()) will be used as the deadline.
// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running.
// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout").
// ctx timeout return err context.DeadlineExceeded.
// ctx canceled return err context.Canceled.
ReceiveContext(ctx context.Context) (reply interface{}, err error)
}
var errTimeoutNotSupported = errors.New("redis: connection does not support ConnWithTimeout")
var errContextNotSupported = errors.New("redis: connection does not support ConnWithContext")
// DoContext sends a command to server and returns the received reply.
// min(ctx,DialReadTimeout()) will be used as the deadline.
// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running.
// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout").
// ctx timeout return err context.DeadlineExceeded.
// ctx canceled return err context.Canceled.
func DoContext(c Conn, ctx context.Context, cmd string, args ...interface{}) (interface{}, error) {
cwt, ok := c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
return cwt.DoContext(ctx, cmd, args...)
}
// DoWithTimeout executes a Redis command with the specified read timeout. If
// the connection does not satisfy the ConnWithTimeout interface, then an error
@ -105,6 +142,20 @@ func DoWithTimeout(c Conn, timeout time.Duration, cmd string, args ...interface{
return cwt.DoWithTimeout(timeout, cmd, args...)
}
// ReceiveContext receives a single reply from the Redis server.
// min(ctx,DialReadTimeout()) will be used as the deadline.
// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running.
// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout").
// ctx timeout return err context.DeadlineExceeded.
// ctx canceled return err context.Canceled.
func ReceiveContext(c Conn, ctx context.Context) (interface{}, error) {
cwt, ok := c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
return cwt.ReceiveContext(ctx)
}
// ReceiveWithTimeout receives a reply with the specified read timeout. If the
// connection does not satisfy the ConnWithTimeout interface, then an error is
// returned.
@ -115,3 +166,24 @@ func ReceiveWithTimeout(c Conn, timeout time.Duration) (interface{}, error) {
}
return cwt.ReceiveWithTimeout(timeout)
}
// SlowLog represents a redis SlowLog
type SlowLog struct {
// ID is a unique progressive identifier for every slow log entry.
ID int64
// Time is the unix timestamp at which the logged command was processed.
Time time.Time
// ExecutationTime is the amount of time needed for the command execution.
ExecutionTime time.Duration
// Args is the command name and arguments
Args []string
// ClientAddr is the client IP address (4.0 only).
ClientAddr string
// ClientName is the name set via the CLIENT SETNAME command (4.0 only).
ClientName string
}

View File

@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"strconv"
"time"
)
// ErrNil indicates that a reply value is nil.
@ -55,7 +56,7 @@ func Int(reply interface{}, err error) (int, error) {
}
// Int64 is a helper that converts a command reply to 64 bit integer. If err is
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
// not equal to nil, then Int64 returns 0, err. Otherwise, Int64 converts the
// reply to an int64 as follows:
//
// Reply type Result
@ -81,14 +82,16 @@ func Int64(reply interface{}, err error) (int64, error) {
return 0, fmt.Errorf("redigo: unexpected type for Int64, got type %T", reply)
}
var errNegativeInt = errors.New("redigo: unexpected value for Uint64")
func errNegativeInt(v int64) error {
return fmt.Errorf("redigo: unexpected negative value %v for Uint64", v)
}
// Uint64 is a helper that converts a command reply to 64 bit integer. If err is
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
// reply to an int64 as follows:
// Uint64 is a helper that converts a command reply to 64 bit unsigned integer.
// If err is not equal to nil, then Uint64 returns 0, err. Otherwise, Uint64 converts the
// reply to an uint64 as follows:
//
// Reply type Result
// integer reply, nil
// +integer reply, nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
@ -99,7 +102,7 @@ func Uint64(reply interface{}, err error) (uint64, error) {
switch reply := reply.(type) {
case int64:
if reply < 0 {
return 0, errNegativeInt
return 0, errNegativeInt(reply)
}
return uint64(reply), nil
case []byte:
@ -115,7 +118,7 @@ func Uint64(reply interface{}, err error) (uint64, error) {
// Float64 is a helper that converts a command reply to 64 bit float. If err is
// not equal to nil, then Float64 returns 0, err. Otherwise, Float64 converts
// the reply to an int as follows:
// the reply to a float64 as follows:
//
// Reply type Result
// bulk string parsed reply, nil
@ -274,13 +277,16 @@ func sliceHelper(reply interface{}, err error, name string, makeSlice func(int),
func Float64s(reply interface{}, err error) ([]float64, error) {
var result []float64
err = sliceHelper(reply, err, "Float64s", func(n int) { result = make([]float64, n) }, func(i int, v interface{}) error {
p, ok := v.([]byte)
if !ok {
return fmt.Errorf("redigo: unexpected element type for Floats64, got type %T", v)
switch v := v.(type) {
case []byte:
f, err := strconv.ParseFloat(string(v), 64)
result[i] = f
return err
case Error:
return v
default:
return fmt.Errorf("redigo: unexpected element type for Float64s, got type %T", v)
}
f, err := strconv.ParseFloat(string(p), 64)
result[i] = f
return err
})
return result, err
}
@ -299,6 +305,8 @@ func Strings(reply interface{}, err error) ([]string, error) {
case []byte:
result[i] = string(v)
return nil
case Error:
return v
default:
return fmt.Errorf("redigo: unexpected element type for Strings, got type %T", v)
}
@ -313,12 +321,15 @@ func Strings(reply interface{}, err error) ([]string, error) {
func ByteSlices(reply interface{}, err error) ([][]byte, error) {
var result [][]byte
err = sliceHelper(reply, err, "ByteSlices", func(n int) { result = make([][]byte, n) }, func(i int, v interface{}) error {
p, ok := v.([]byte)
if !ok {
switch v := v.(type) {
case []byte:
result[i] = v
return nil
case Error:
return v
default:
return fmt.Errorf("redigo: unexpected element type for ByteSlices, got type %T", v)
}
result[i] = p
return nil
})
return result, err
}
@ -338,6 +349,8 @@ func Int64s(reply interface{}, err error) ([]int64, error) {
n, err := strconv.ParseInt(string(v), 10, 64)
result[i] = n
return err
case Error:
return v
default:
return fmt.Errorf("redigo: unexpected element type for Int64s, got type %T", v)
}
@ -345,7 +358,7 @@ func Int64s(reply interface{}, err error) ([]int64, error) {
return result, err
}
// Ints is a helper that converts an array command reply to a []in.
// Ints is a helper that converts an array command reply to a []int.
// If err is not equal to nil, then Ints returns nil, err. Nil array
// items are stay nil. Ints returns an error if an array item is not a
// bulk string or nil.
@ -364,6 +377,8 @@ func Ints(reply interface{}, err error) ([]int, error) {
n, err := strconv.Atoi(string(v))
result[i] = n
return err
case Error:
return v
default:
return fmt.Errorf("redigo: unexpected element type for Ints, got type %T", v)
}
@ -379,16 +394,23 @@ func StringMap(result interface{}, err error) (map[string]string, error) {
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, errors.New("redigo: StringMap expects even number of values result")
return nil, fmt.Errorf("redigo: StringMap expects even number of values result, got %d", len(values))
}
m := make(map[string]string, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, okKey := values[i].([]byte)
value, okValue := values[i+1].([]byte)
if !okKey || !okValue {
return nil, errors.New("redigo: StringMap key not a bulk string value")
key, ok := values[i].([]byte)
if !ok {
return nil, fmt.Errorf("redigo: StringMap key[%d] not a bulk string value, got %T", i, values[i])
}
value, ok := values[i+1].([]byte)
if !ok {
return nil, fmt.Errorf("redigo: StringMap value[%d] not a bulk string value, got %T", i+1, values[i+1])
}
m[string(key)] = string(value)
}
return m, nil
@ -402,19 +424,23 @@ func IntMap(result interface{}, err error) (map[string]int, error) {
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, errors.New("redigo: IntMap expects even number of values result")
return nil, fmt.Errorf("redigo: IntMap expects even number of values result, got %d", len(values))
}
m := make(map[string]int, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, ok := values[i].([]byte)
if !ok {
return nil, errors.New("redigo: IntMap key not a bulk string value")
return nil, fmt.Errorf("redigo: IntMap key[%d] not a bulk string value, got %T", i, values[i])
}
value, err := Int(values[i+1], nil)
if err != nil {
return nil, err
}
m[string(key)] = value
}
return m, nil
@ -428,19 +454,23 @@ func Int64Map(result interface{}, err error) (map[string]int64, error) {
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, errors.New("redigo: Int64Map expects even number of values result")
return nil, fmt.Errorf("redigo: Int64Map expects even number of values result, got %d", len(values))
}
m := make(map[string]int64, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, ok := values[i].([]byte)
if !ok {
return nil, errors.New("redigo: Int64Map key not a bulk string value")
return nil, fmt.Errorf("redigo: Int64Map key[%d] not a bulk string value, got %T", i, values[i])
}
value, err := Int64(values[i+1], nil)
if err != nil {
return nil, err
}
m[string(key)] = value
}
return m, nil
@ -458,22 +488,137 @@ func Positions(result interface{}, err error) ([]*[2]float64, error) {
if values[i] == nil {
continue
}
p, ok := values[i].([]interface{})
if !ok {
return nil, fmt.Errorf("redigo: unexpected element type for interface slice, got type %T", values[i])
}
if len(p) != 2 {
return nil, fmt.Errorf("redigo: unexpected number of values for a member position, got %d", len(p))
}
lat, err := Float64(p[0], nil)
if err != nil {
return nil, err
}
long, err := Float64(p[1], nil)
if err != nil {
return nil, err
}
positions[i] = &[2]float64{lat, long}
}
return positions, nil
}
// Uint64s is a helper that converts an array command reply to a []uint64.
// If err is not equal to nil, then Uint64s returns nil, err. Nil array
// items are stay nil. Uint64s returns an error if an array item is not a
// bulk string or nil.
func Uint64s(reply interface{}, err error) ([]uint64, error) {
var result []uint64
err = sliceHelper(reply, err, "Uint64s", func(n int) { result = make([]uint64, n) }, func(i int, v interface{}) error {
switch v := v.(type) {
case uint64:
result[i] = v
return nil
case []byte:
n, err := strconv.ParseUint(string(v), 10, 64)
result[i] = n
return err
case Error:
return v
default:
return fmt.Errorf("redigo: unexpected element type for Uint64s, got type %T", v)
}
})
return result, err
}
// Uint64Map is a helper that converts an array of strings (alternating key, value)
// into a map[string]uint64. The HGETALL commands return replies in this format.
// Requires an even number of values in result.
func Uint64Map(result interface{}, err error) (map[string]uint64, error) {
values, err := Values(result, err)
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, fmt.Errorf("redigo: Uint64Map expects even number of values result, got %d", len(values))
}
m := make(map[string]uint64, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, ok := values[i].([]byte)
if !ok {
return nil, fmt.Errorf("redigo: Uint64Map key[%d] not a bulk string value, got %T", i, values[i])
}
value, err := Uint64(values[i+1], nil)
if err != nil {
return nil, err
}
m[string(key)] = value
}
return m, nil
}
// SlowLogs is a helper that parse the SLOWLOG GET command output and
// return the array of SlowLog
func SlowLogs(result interface{}, err error) ([]SlowLog, error) {
rawLogs, err := Values(result, err)
if err != nil {
return nil, err
}
logs := make([]SlowLog, len(rawLogs))
for i, e := range rawLogs {
rawLog, ok := e.([]interface{})
if !ok {
return nil, fmt.Errorf("redigo: slowlog element is not an array, got %T", e)
}
var log SlowLog
if len(rawLog) < 4 {
return nil, fmt.Errorf("redigo: slowlog element has %d elements, expected at least 4", len(rawLog))
}
log.ID, ok = rawLog[0].(int64)
if !ok {
return nil, fmt.Errorf("redigo: slowlog element[0] not an int64, got %T", rawLog[0])
}
timestamp, ok := rawLog[1].(int64)
if !ok {
return nil, fmt.Errorf("redigo: slowlog element[1] not an int64, got %T", rawLog[0])
}
log.Time = time.Unix(timestamp, 0)
duration, ok := rawLog[2].(int64)
if !ok {
return nil, fmt.Errorf("redigo: slowlog element[2] not an int64, got %T", rawLog[0])
}
log.ExecutionTime = time.Duration(duration) * time.Microsecond
log.Args, err = Strings(rawLog[3], nil)
if err != nil {
return nil, fmt.Errorf("redigo: slowlog element[3] is not array of strings: %w", err)
}
if len(rawLog) >= 6 {
log.ClientAddr, err = String(rawLog[4], nil)
if err != nil {
return nil, fmt.Errorf("redigo: slowlog element[4] is not a string: %w", err)
}
log.ClientName, err = String(rawLog[5], nil)
if err != nil {
return nil, fmt.Errorf("redigo: slowlog element[5] is not a string: %w", err)
}
}
logs[i] = log
}
return logs, nil
}

View File

@ -23,6 +23,10 @@ import (
"sync"
)
var (
scannerType = reflect.TypeOf((*Scanner)(nil)).Elem()
)
func ensureLen(d reflect.Value, n int) {
if n > d.Cap() {
d.Set(reflect.MakeSlice(d.Type(), n, n))
@ -44,44 +48,105 @@ func cannotConvert(d reflect.Value, s interface{}) error {
sname = "Redis bulk string"
case []interface{}:
sname = "Redis array"
case nil:
sname = "Redis nil"
default:
sname = reflect.TypeOf(s).String()
}
return fmt.Errorf("cannot convert from %s to %s", sname, d.Type())
}
func convertAssignBulkString(d reflect.Value, s []byte) (err error) {
func convertAssignNil(d reflect.Value) (err error) {
switch d.Type().Kind() {
case reflect.Slice, reflect.Interface:
d.Set(reflect.Zero(d.Type()))
default:
err = cannotConvert(d, nil)
}
return err
}
func convertAssignError(d reflect.Value, s Error) (err error) {
if d.Kind() == reflect.String {
d.SetString(string(s))
} else if d.Kind() == reflect.Slice && d.Type().Elem().Kind() == reflect.Uint8 {
d.SetBytes([]byte(s))
} else {
err = cannotConvert(d, s)
}
return
}
func convertAssignString(d reflect.Value, s string) (err error) {
switch d.Type().Kind() {
case reflect.Float32, reflect.Float64:
var x float64
x, err = strconv.ParseFloat(string(s), d.Type().Bits())
x, err = strconv.ParseFloat(s, d.Type().Bits())
d.SetFloat(x)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var x int64
x, err = strconv.ParseInt(string(s), 10, d.Type().Bits())
x, err = strconv.ParseInt(s, 10, d.Type().Bits())
d.SetInt(x)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
var x uint64
x, err = strconv.ParseUint(string(s), 10, d.Type().Bits())
x, err = strconv.ParseUint(s, 10, d.Type().Bits())
d.SetUint(x)
case reflect.Bool:
var x bool
x, err = strconv.ParseBool(string(s))
x, err = strconv.ParseBool(s)
d.SetBool(x)
case reflect.String:
d.SetString(string(s))
d.SetString(s)
case reflect.Slice:
if d.Type().Elem().Kind() != reflect.Uint8 {
err = cannotConvert(d, s)
if d.Type().Elem().Kind() == reflect.Uint8 {
d.SetBytes([]byte(s))
} else {
d.SetBytes(s)
err = cannotConvert(d, s)
}
case reflect.Ptr:
err = convertAssignString(d.Elem(), s)
default:
err = cannotConvert(d, s)
}
return
}
func convertAssignBulkString(d reflect.Value, s []byte) (err error) {
switch d.Type().Kind() {
case reflect.Slice:
// Handle []byte destination here to avoid unnecessary
// []byte -> string -> []byte converion.
if d.Type().Elem().Kind() == reflect.Uint8 {
d.SetBytes(s)
} else {
err = cannotConvert(d, s)
}
case reflect.Ptr:
if d.CanInterface() && d.CanSet() {
if s == nil {
if d.IsNil() {
return nil
}
d.Set(reflect.Zero(d.Type()))
return nil
}
if d.IsNil() {
d.Set(reflect.New(d.Type().Elem()))
}
if sc, ok := d.Interface().(Scanner); ok {
return sc.RedisScan(s)
}
}
err = convertAssignString(d, string(s))
default:
err = convertAssignString(d, string(s))
}
return err
}
func convertAssignInt(d reflect.Value, s int64) (err error) {
switch d.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@ -130,10 +195,16 @@ func convertAssignValue(d reflect.Value, s interface{}) (err error) {
}
switch s := s.(type) {
case nil:
err = convertAssignNil(d)
case []byte:
err = convertAssignBulkString(d, s)
case int64:
err = convertAssignInt(d, s)
case string:
err = convertAssignString(d, s)
case Error:
err = convertAssignError(d, s)
default:
err = cannotConvert(d, s)
}
@ -285,34 +356,49 @@ func (ss *structSpec) fieldSpec(name []byte) *fieldSpec {
}
func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) {
LOOP:
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
switch {
case f.PkgPath != "" && !f.Anonymous:
// Ignore unexported fields.
case f.Anonymous:
// TODO: Handle pointers. Requires change to decoder and
// protection against infinite recursion.
if f.Type.Kind() == reflect.Struct {
switch f.Type.Kind() {
case reflect.Struct:
compileStructSpec(f.Type, depth, append(index, i), ss)
case reflect.Ptr:
// TODO(steve): Protect against infinite recursion.
if f.Type.Elem().Kind() == reflect.Struct {
compileStructSpec(f.Type.Elem(), depth, append(index, i), ss)
}
}
default:
fs := &fieldSpec{name: f.Name}
tag := f.Tag.Get("redis")
p := strings.Split(tag, ",")
if len(p) > 0 {
if p[0] == "-" {
continue
var (
p string
)
first := true
for len(tag) > 0 {
i := strings.IndexByte(tag, ',')
if i < 0 {
p, tag = tag, ""
} else {
p, tag = tag[:i], tag[i+1:]
}
if len(p[0]) > 0 {
fs.name = p[0]
if p == "-" {
continue LOOP
}
for _, s := range p[1:] {
switch s {
if first && len(p) > 0 {
fs.name = p
first = false
} else {
switch p {
case "omitempty":
fs.omitEmpty = true
default:
panic(fmt.Errorf("redigo: unknown field tag %s for type %s", s, t.Name()))
panic(fmt.Errorf("redigo: unknown field tag %s for type %s", p, t.Name()))
}
}
}
@ -345,9 +431,8 @@ func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *st
}
var (
structSpecMutex sync.RWMutex
structSpecCache = make(map[reflect.Type]*structSpec)
defaultFieldSpec = &fieldSpec{}
structSpecMutex sync.RWMutex
structSpecCache = make(map[reflect.Type]*structSpec)
)
func structSpecForType(t reflect.Type) *structSpec {
@ -429,9 +514,13 @@ var (
errScanSliceValue = errors.New("redigo.ScanSlice: dest must be non-nil pointer to a struct")
)
// ScanSlice scans src to the slice pointed to by dest. The elements the dest
// slice must be integer, float, boolean, string, struct or pointer to struct
// values.
// ScanSlice scans src to the slice pointed to by dest.
//
// If the target is a slice of types which implement Scanner then the custom
// RedisScan method is used otherwise the following rules apply:
//
// The elements in the dest slice must be integer, float, boolean, string, struct
// or pointer to struct values.
//
// Struct fields must be integer, float, boolean or string values. All struct
// fields are used unless a subset is specified using fieldNames.
@ -447,12 +536,13 @@ func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error
isPtr := false
t := d.Type().Elem()
st := t
if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
isPtr = true
t = t.Elem()
}
if t.Kind() != reflect.Struct {
if t.Kind() != reflect.Struct || st.Implements(scannerType) {
ensureLen(d, len(src))
for i, s := range src {
if s == nil {
@ -579,7 +669,15 @@ func flattenStruct(args Args, v reflect.Value) Args {
continue
}
}
args = append(args, fs.name, fv.Interface())
if arg, ok := fv.Interface().(Argument); ok {
args = append(args, fs.name, arg.RedisArg())
} else if fv.Kind() == reflect.Ptr {
if !fv.IsNil() {
args = append(args, fs.name, fv.Elem().Interface())
}
} else {
args = append(args, fs.name, fv.Interface())
}
}
return args
}

View File

@ -15,6 +15,7 @@
package redis
import (
"context"
"crypto/sha1"
"encoding/hex"
"io"
@ -36,7 +37,7 @@ type Script struct {
// SendHash methods.
func NewScript(keyCount int, src string) *Script {
h := sha1.New()
io.WriteString(h, src)
io.WriteString(h, src) // nolint: errcheck
return &Script{keyCount, src, hex.EncodeToString(h.Sum(nil))}
}
@ -60,6 +61,18 @@ func (s *Script) Hash() string {
return s.hash
}
func (s *Script) DoContext(ctx context.Context, c Conn, keysAndArgs ...interface{}) (interface{}, error) {
cwt, ok := c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
v, err := cwt.DoContext(ctx, "EVALSHA", s.args(s.hash, keysAndArgs)...)
if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") {
v, err = cwt.DoContext(ctx, "EVAL", s.args(s.src, keysAndArgs)...)
}
return v, err
}
// Do evaluates the script. Under the covers, Do optimistically evaluates the
// script using the EVALSHA command. If the command fails because the script is
// not loaded, then Do evaluates the script using the EVAL command (thus

View File

@ -361,9 +361,8 @@ github.com/golang/protobuf/ptypes/any
github.com/golang/protobuf/ptypes/duration
github.com/golang/protobuf/ptypes/timestamp
github.com/golang/protobuf/ptypes/wrappers
# github.com/gomodule/redigo v2.0.0+incompatible
## explicit
github.com/gomodule/redigo/internal
# github.com/gomodule/redigo v2.0.0+incompatible => github.com/gomodule/redigo v1.8.8
## explicit; go 1.16
github.com/gomodule/redigo/redis
# github.com/google/certificate-transparency-go v1.0.21
## explicit
@ -1043,4 +1042,5 @@ sigs.k8s.io/yaml
# github.com/Azure/go-autorest => github.com/Azure/go-autorest v14.2.0+incompatible
# github.com/docker/distribution => github.com/distribution/distribution v2.8.1+incompatible
# github.com/goharbor/harbor => ../
# github.com/gomodule/redigo => github.com/gomodule/redigo v1.8.8
# google.golang.org/api => google.golang.org/api v0.0.0-20160322025152-9bf6e6e569ff