Skip to content

Commit

Permalink
Merge pull request #83 from AlexStocks/fix/ws-concurrent-read
Browse files Browse the repository at this point in the history
fix:add read mutex in gettyWSConn(websocket) struct to prevent data race in ReadMessage()
  • Loading branch information
AlexStocks authored Jun 5, 2024
2 parents 87ba8ee + a48ffa3 commit da9fedf
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions transport/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,9 @@ func (u *gettyUDPConn) CloseConn(_ int) {

type gettyWSConn struct {
gettyConn
conn *websocket.Conn
lock sync.Mutex
writeLock sync.Mutex
readLock sync.Mutex
conn *websocket.Conn
}

// create websocket connection
Expand Down Expand Up @@ -569,7 +570,7 @@ func (w *gettyWSConn) handlePong(string) error {
func (w *gettyWSConn) recv() ([]byte, error) {
// Pls do not set read deadline when using ReadMessage. AlexStocks 20180310
// gorilla/websocket/conn.go:NextReader will always fail when got a timeout error.
_, b, e := w.conn.ReadMessage() // the first return value is message type.
_, b, e := w.threadSafeReadMessage() // the first return value is message type.
if e == nil {
w.readBytes.Add((uint32)(len(b)))
} else {
Expand Down Expand Up @@ -643,12 +644,23 @@ func (w *gettyWSConn) CloseConn(waitSec int) {
w.conn.Close()
}

// uses a mutex to ensure that only one thread can send a message at a time, preventing race conditions.
// uses a mutex(writeLock) to ensure that only one thread can send a message at a time, preventing race conditions.
func (w *gettyWSConn) threadSafeWriteMessage(messageType int, data []byte) error {
w.lock.Lock()
defer w.lock.Unlock()
w.writeLock.Lock()
defer w.writeLock.Unlock()
if err := w.conn.WriteMessage(messageType, data); err != nil {
return err
}
return nil
}

// uses a mutex(readLock) to ensure that only one thread can read a message at a time, preventing race conditions.
func (w *gettyWSConn) threadSafeReadMessage() (int, []byte, error) {
w.readLock.Lock()
defer w.readLock.Unlock()
messageType, readBytes, err := w.conn.ReadMessage()
if err != nil {
return messageType, nil, err
}
return messageType, readBytes, nil
}

0 comments on commit da9fedf

Please sign in to comment.