diff --git a/connection.go b/connection.go index c70c59d..a6b27a0 100644 --- a/connection.go +++ b/connection.go @@ -490,6 +490,7 @@ func (u *gettyUDPConn) CloseConn(_ int) { type gettyWSConn struct { gettyConn conn *websocket.Conn + lock sync.Mutex } // create websocket connection @@ -608,7 +609,7 @@ func (w *gettyWSConn) Send(pkg interface{}) (int, error) { } w.updateWriteDeadline() - if err = w.conn.WriteMessage(websocket.BinaryMessage, p); err == nil { + if err = w.threadSafeWriteMessage(websocket.BinaryMessage, p); err == nil { w.writeBytes.Add((uint32)(len(p))) w.writePkgNum.Add(1) } @@ -617,18 +618,18 @@ func (w *gettyWSConn) Send(pkg interface{}) (int, error) { func (w *gettyWSConn) writePing() error { w.updateWriteDeadline() - return perrors.WithStack(w.conn.WriteMessage(websocket.PingMessage, []byte{})) + return perrors.WithStack(w.threadSafeWriteMessage(websocket.PingMessage, []byte{})) } func (w *gettyWSConn) writePong(message []byte) error { w.updateWriteDeadline() - return perrors.WithStack(w.conn.WriteMessage(websocket.PongMessage, message)) + return perrors.WithStack(w.threadSafeWriteMessage(websocket.PongMessage, message)) } // close websocket connection func (w *gettyWSConn) CloseConn(waitSec int) { w.updateWriteDeadline() - w.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bye-bye!!!")) + w.threadSafeWriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bye-bye!!!")) conn := w.conn.UnderlyingConn() if tcpConn, ok := conn.(*net.TCPConn); ok { tcpConn.SetLinger(waitSec) @@ -637,3 +638,13 @@ func (w *gettyWSConn) CloseConn(waitSec int) { } w.conn.Close() } + +// uses a mutex to ensure that only one thread can send a message at a time, preventing race conditions. +func (w *gettyWSConn) threadSafeWriteMessage(messageType int, data []byte) error { + w.lock.Lock() + defer w.lock.Unlock() + if err := w.conn.WriteMessage(messageType, data); err != nil { + return err + } + return nil +} diff --git a/session.go b/session.go index 587f654..7e704b5 100644 --- a/session.go +++ b/session.go @@ -368,7 +368,7 @@ func (s *session) sessionToken() string { s.name, s.EndPoint().EndPointType(), s.ID(), s.LocalAddr(), s.RemoteAddr()) } -func (s *session) WritePkg(pkg interface{}, timeout time.Duration) (int, int, error) { +func (s *session) WritePkg(pkg interface{}, timeout time.Duration) (pkgBytesLenth int, successCount int, err error) { if pkg == nil { return 0, 0, fmt.Errorf("@pkg is nil") } @@ -381,7 +381,8 @@ func (s *session) WritePkg(pkg interface{}, timeout time.Duration) (int, int, er const size = 64 << 10 rBuf := make([]byte, size) rBuf = rBuf[:runtime.Stack(rBuf, false)] - log.Errorf("[session.WritePkg] panic session %s: err=%s\n%s", s.sessionToken(), r, rBuf) + err = perrors.WithStack(fmt.Errorf("[session.WritePkg] panic session %s: err=%v\n%s", s.sessionToken(), r, rBuf)) + log.Error(err) } }() @@ -407,13 +408,12 @@ func (s *session) WritePkg(pkg interface{}, timeout time.Duration) (int, int, er if 0 < timeout { s.Connection.SetWriteTimeout(timeout) } - var succssCount int - succssCount, err = s.Connection.Send(pkg) + successCount, err = s.Connection.Send(pkg) if err != nil { log.Warnf("%s, [session.WritePkg] @s.Connection.Write(pkg:%#v) = err:%+v", s.Stat(), pkg, err) - return len(pkgBytes), succssCount, perrors.WithStack(err) + return len(pkgBytes), successCount, perrors.WithStack(err) } - return len(pkgBytes), succssCount, nil + return len(pkgBytes), successCount, nil } // WriteBytes for codecs