Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions go/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ func (ws *WsService) readMsg() {
default:
_, rawMsg, err := ws.Client.ReadMessage()
if err != nil {
ws.Logger.Printf("websocket err: %s", err.Error())
ws.Logger.Printf("websocket err: %s %s", err.Error(), ws.conf.Key)
if e := ws.reconnect(); e != nil {
ws.Logger.Printf("reconnect err:%s", err.Error())
return
}
ws.Logger.Println("reconnect success, continue read message")
ws.Logger.Println("reconnect success, continue read message ", ws.conf.Key)
continue
}

ws.Logger.Println("received message:", string(rawMsg), ws.conf.Key)
var msg UpdateMsg
if err := json.Unmarshal(rawMsg, &msg); err != nil {
continue
Expand Down Expand Up @@ -212,9 +212,7 @@ func (ws *WsService) receiveCallMsg(channel string, msgCh chan *UpdateMsg) {

func (ws *WsService) APIRequest(channel string, payload any, keyVals map[string]any) error {
var err error
ws.loginOnce.Do(func() {
err = ws.login()
})
err = ws.Login()

if err != nil {
return err
Expand All @@ -239,6 +237,14 @@ func (ws *WsService) APIRequest(channel string, payload any, keyVals map[string]
return ws.apiRequest(channel, payload, keyVals)
}

func (ws *WsService) Login() error {
var err error
ws.loginOnce.Do(func() {
err = ws.login()
})
return err
}

func (ws *WsService) login() error {
if ws.conf.Key == "" || ws.conf.Secret == "" {
return newAuthEmptyErr()
Expand Down
65 changes: 27 additions & 38 deletions go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type WsService struct {
// ConnConf default URL is spot websocket
type ConnConf struct {
App string
subscribeMsg *sync.Map
subscribeMsg sync.Map
URL string
Key string
Secret string
Expand Down Expand Up @@ -116,15 +116,14 @@ func NewWsService(ctx context.Context, logger *log.Logger, conf *ConnConf) (*WsS
clientMu: new(sync.Mutex),
}

go ws.activePing()
ws.keepAlive()

return ws, nil
}

func getInitConnConf() *ConnConf {
return &ConnConf{
App: "spot",
subscribeMsg: new(sync.Map),
MaxRetryConn: MaxRetryConn,
Key: "",
Secret: "",
Expand Down Expand Up @@ -165,7 +164,6 @@ func NewConnConfFromOption(op *ConfOptions) *ConnConf {
}
return &ConnConf{
App: op.App,
subscribeMsg: new(sync.Map),
MaxRetryConn: op.MaxRetryConn,
Key: op.Key,
Secret: op.Secret,
Expand Down Expand Up @@ -216,6 +214,11 @@ func (ws *WsService) reconnect() error {

ws.status = connected

// should login when reconnect
if err := ws.login(); err != nil {
ws.Logger.Println("reconnect login err:%s", err.Error())
}

// resubscribe after reconnect
ws.conf.subscribeMsg.Range(func(key, value interface{}) bool {
// key is channel, value is []requestHistory
Expand Down Expand Up @@ -313,45 +316,31 @@ func (ws *WsService) GetConnection() *websocket.Conn {
return ws.Client
}

func (ws *WsService) activePing() {
du, err := time.ParseDuration(ws.conf.PingInterval)
if err != nil {
ws.Logger.Printf("failed to parse ping interval: %s, use default ping interval 10s instead", ws.conf.PingInterval)
du, err = time.ParseDuration(DefaultPingInterval)
if err != nil {
du = time.Second * 10
}
}
func (ws *WsService) keepAlive() {
var timeout = 10 * time.Second
ticker := time.NewTicker(timeout)

ticker := time.NewTicker(du)
defer ticker.Stop()

for {
select {
case <-ws.Ctx.Done():
return
case <-ticker.C:
subscribeMap := map[string]int{}
ws.conf.subscribeMsg.Range(func(key, value interface{}) bool {
splits := strings.Split(key.(string), ".")
if len(splits) == 2 {
subscribeMap[splits[0]] = 1
}
return true
})
lastResponse := time.Now()
ws.Client.SetPongHandler(func(msg string) error {
lastResponse = time.Now()
return nil
})

if ws.status != connected {
continue
go func() {
defer ticker.Stop()
for {
ws.mu.Lock()
err := ws.Client.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second))
ws.mu.Unlock()
if err != nil {
ws.Logger.Printf("send ping err:%s", err.Error())
}

for app := range subscribeMap {
channel := app + ".ping"
if err := ws.Subscribe(channel, nil); err != nil {
ws.Logger.Printf("subscribe channel[%s] failed: %v", channel, err)
}
<-ticker.C
if time.Since(lastResponse) > 30*time.Second {
ws.Logger.Printf("ping timeout, should reconnect")
}
}
}
}()
}

var statusString = map[status]string{
Expand Down
20 changes: 15 additions & 5 deletions go/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,24 @@ type UpdateMsg struct {
Message string `json:"message"`
} `json:"errs"`
} `json:"data"`
RequestID string `json:"request_id"`

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里只有 API 请求的时候才有RequestID 和Ack 字段

Ack bool `json:"ack"`
}

type ResponseHeader struct {
ResponseTime string `json:"response_time"`
Status string `json:"status"`
Channel string `json:"channel"`
Event string `json:"event"`
ClientID string `json:"client_id"`
ResponseTime string `json:"response_time"`
Status string `json:"status"`
Channel string `json:"channel"`
Event string `json:"event"`
ClientID string `json:"client_id"`
ConnID string `json:"conn_id"`
ConnTraceID string `json:"conn_trace_id"`
TraceID string `json:"trace_id"`
XInTime int64 `json:"x_in_time"`
XOutTime int64 `json:"x_out_time"`
XGateRatelimitRequestsRemain int `json:"x_gate_ratelimit_requests_remain"`
XGateRatelimitLimit int `json:"x_gate_ratelimit_limit"`
XGateRatelimitResetTimestamp int64 `json:"x_gate_ratelimit_reset_timestamp"`
}

func (u *UpdateMsg) GetChannel() string {
Expand Down