diff --git a/.github/workflows/verify.yml b/.github/workflows/verify.yml index a3eb74b3..8a82fcc3 100644 --- a/.github/workflows/verify.yml +++ b/.github/workflows/verify.yml @@ -29,4 +29,3 @@ jobs: uses: golangci/golangci-lint-action@v3 with: version: v1.53 - args: --timeout=5m diff --git a/.golangci.yml b/.golangci.yml index 34882139..44cf86a0 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,3 +1,13 @@ run: - skip-dirs: + timeout: "5m" + # will not run golangci-lint against *_test.go + tests: false +issues: + exclude-dirs: - examples/*.go + exclude-rules: + # excluding error checks from all the .go files + - path: ./*.go + linters: + - errcheck + diff --git a/LICENSE b/LICENSE index bb9d80bc..8692af65 100644 --- a/LICENSE +++ b/LICENSE @@ -24,4 +24,4 @@ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/README.md b/README.md index 1fd5e9c4..525a62a9 100644 --- a/README.md +++ b/README.md @@ -13,11 +13,11 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket ### Documentation * [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) -* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) -* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) -* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) -* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) -* [Write buffer pool example](https://github.com/gorilla/websocket/tree/master/examples/bufferpool) +* [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat) +* [Command example](https://github.com/gorilla/websocket/tree/main/examples/command) +* [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo) +* [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch) +* [Write buffer pool example](https://github.com/gorilla/websocket/tree/main/examples/bufferpool) ### Status @@ -33,4 +33,4 @@ package API is stable. The Gorilla WebSocket package passes the server tests in the [Autobahn Test Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn -subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). +subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn). diff --git a/client.go b/client.go index 815b0ca5..7023e117 100644 --- a/client.go +++ b/client.go @@ -11,8 +11,6 @@ import ( "errors" "fmt" "io" - "log" - "net" "net/http" "net/http/httptrace" @@ -228,7 +226,6 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h k == "Connection" || k == "Sec-Websocket-Key" || k == "Sec-Websocket-Version" || - //#nosec G101 (CWE-798): Potential HTTP request smuggling via parameter pollution k == "Sec-Websocket-Extensions" || (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) @@ -294,9 +291,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } err = c.SetDeadline(deadline) if err != nil { - if err := c.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } + c.Close() return nil, err } return c, nil @@ -336,9 +331,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h defer func() { if netConn != nil { - if err := netConn.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } + netConn.Close() } }() @@ -399,7 +392,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } } - if resp.StatusCode != 101 || + if resp.StatusCode != http.StatusSwitchingProtocols || !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || !tokenListContainsValue(resp.Header, "Connection", "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { @@ -429,9 +422,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h resp.Body = io.NopCloser(bytes.NewReader([]byte{})) conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") - if err := netConn.SetDeadline(time.Time{}); err != nil { - return nil, nil, err - } + netConn.SetDeadline(time.Time{}) netConn = nil // to avoid close in defer. return conn, resp, nil } diff --git a/client_server_test.go b/client_server_test.go index 6095c09e..99ab3b0e 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -49,6 +49,7 @@ type cstHandler struct{ *testing.T } type cstServer struct { *httptest.Server URL string + t *testing.T } const ( @@ -92,6 +93,7 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) if err != nil { + t.Logf("Upgrade: %v", err) return } defer ws.Close() @@ -103,16 +105,20 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } op, rd, err := ws.NextReader() if err != nil { + t.Logf("NextReader: %v", err) return } wr, err := ws.NextWriter(op) if err != nil { + t.Logf("NextWriter: %v", err) return } if _, err = io.Copy(wr, rd); err != nil { + t.Logf("NextWriter: %v", err) return } if err := wr.Close(); err != nil { + t.Logf("Close: %v", err) return } } @@ -142,6 +148,7 @@ func sendRecv(t *testing.T, ws *Conn) { } func TestProxyDial(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -180,6 +187,7 @@ func TestProxyDial(t *testing.T) { } func TestProxyAuthorizationDial(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -220,6 +228,7 @@ func TestProxyAuthorizationDial(t *testing.T) { } func TestDial(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -232,6 +241,7 @@ func TestDial(t *testing.T) { } func TestDialCookieJar(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -294,6 +304,7 @@ func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool { } func TestDialTLS(t *testing.T) { + t.Parallel() s := newTLSServer(t) defer s.Close() @@ -308,6 +319,7 @@ func TestDialTLS(t *testing.T) { } func TestDialTimeout(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -364,6 +376,7 @@ func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } func TestHandshakeTimeout(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -380,6 +393,7 @@ func TestHandshakeTimeout(t *testing.T) { } func TestHandshakeTimeoutInContext(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -401,6 +415,7 @@ func TestHandshakeTimeoutInContext(t *testing.T) { } func TestDialBadScheme(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -412,6 +427,7 @@ func TestDialBadScheme(t *testing.T) { } func TestDialBadOrigin(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -429,6 +445,7 @@ func TestDialBadOrigin(t *testing.T) { } func TestDialBadHeader(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -448,6 +465,7 @@ func TestDialBadHeader(t *testing.T) { } func TestBadMethod(t *testing.T) { + t.Parallel() s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ws, err := cstUpgrader.Upgrade(w, r, nil) if err == nil { @@ -476,6 +494,7 @@ func TestBadMethod(t *testing.T) { } func TestDialExtraTokensInRespHeaders(t *testing.T) { + t.Parallel() s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { challengeKey := r.Header.Get("Sec-Websocket-Key") w.Header().Set("Upgrade", "foo, websocket") @@ -493,6 +512,7 @@ func TestDialExtraTokensInRespHeaders(t *testing.T) { } func TestHandshake(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -519,14 +539,13 @@ func TestHandshake(t *testing.T) { } func TestRespOnBadHandshake(t *testing.T) { + t.Parallel() const expectedStatus = http.StatusGone const expectedBody = "This is the response body." s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(expectedStatus) - if _, err := io.WriteString(w, expectedBody); err != nil { - t.Fatalf("WriteString: %v", err) - } + io.WriteString(w, expectedBody) })) defer s.Close() @@ -559,11 +578,13 @@ type testLogWriter struct { } func (w testLogWriter) Write(p []byte) (int, error) { + w.t.Logf("%s", p) return len(p), nil } // TestHost tests handling of host names and confirms that it matches net/http. func TestHost(t *testing.T) { + t.Parallel() upgrader := Upgrader{} handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -745,6 +766,7 @@ func TestHost(t *testing.T) { } func TestDialCompression(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -759,6 +781,7 @@ func TestDialCompression(t *testing.T) { } func TestSocksProxyDial(t *testing.T) { + t.Parallel() s := newServer(t) defer s.Close() @@ -775,10 +798,7 @@ func TestSocksProxyDial(t *testing.T) { } defer c1.Close() - if err := c1.SetDeadline(time.Now().Add(30 * time.Second)); err != nil { - t.Errorf("set deadline failed: %v", err) - return - } + c1.SetDeadline(time.Now().Add(30 * time.Second)) buf := make([]byte, 32) if _, err := io.ReadFull(c1, buf[:3]); err != nil { @@ -817,15 +837,10 @@ func TestSocksProxyDial(t *testing.T) { defer c2.Close() done := make(chan struct{}) go func() { - if _, err := io.Copy(c1, c2); err != nil { - t.Errorf("copy failed: %v", err) - } + io.Copy(c1, c2) close(done) }() - if _, err := io.Copy(c2, c1); err != nil { - t.Errorf("copy failed: %v", err) - return - } + io.Copy(c2, c1) <-done }() @@ -846,6 +861,7 @@ func TestSocksProxyDial(t *testing.T) { } func TestTracingDialWithContext(t *testing.T) { + t.Parallel() var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool trace := &httptrace.ClientTrace{ @@ -905,6 +921,7 @@ func TestTracingDialWithContext(t *testing.T) { } func TestEmptyTracingDialWithContext(t *testing.T) { + t.Parallel() trace := &httptrace.ClientTrace{} ctx := httptrace.WithClientTrace(context.Background(), trace) @@ -926,6 +943,7 @@ func TestEmptyTracingDialWithContext(t *testing.T) { // TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext func TestNetDialConnect(t *testing.T) { + t.Parallel() upgrader := Upgrader{} handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1101,6 +1119,7 @@ func TestNetDialConnect(t *testing.T) { } } func TestNextProtos(t *testing.T) { + t.Parallel() ts := httptest.NewUnstartedServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), ) diff --git a/client_test.go b/client_test.go index 5aa27b37..982b5466 100644 --- a/client_test.go +++ b/client_test.go @@ -20,6 +20,7 @@ var hostPortNoPortTests = []struct { } func TestHostPortNoPort(t *testing.T) { + t.Parallel() for _, tt := range hostPortNoPortTests { hostPort, hostNoPort := hostPortNoPort(tt.u) if hostPort != tt.hostPort { diff --git a/compression.go b/compression.go index 9fed0ef5..813ffb1e 100644 --- a/compression.go +++ b/compression.go @@ -8,7 +8,6 @@ import ( "compress/flate" "errors" "io" - "log" "strings" "sync" ) @@ -34,9 +33,7 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser { "\x01\x00\x00\xff\xff" fr, _ := flateReaderPool.Get().(io.ReadCloser) - if err := fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil); err != nil { - panic(err) - } + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) return &flateReadWrapper{fr} } @@ -135,9 +132,7 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) { // Preemptively place the reader back in the pool. This helps with // scenarios where the application does not call NextReader() soon after // this final read. - if err := r.Close(); err != nil { - log.Printf("websocket: flateReadWrapper.Close() returned error: %v", err) - } + r.Close() } return n, err } diff --git a/compression_test.go b/compression_test.go index 80cbc4eb..88f7c88a 100644 --- a/compression_test.go +++ b/compression_test.go @@ -12,6 +12,7 @@ type nopCloser struct{ io.Writer } func (nopCloser) Close() error { return nil } func TestTruncWriter(t *testing.T) { + t.Parallel() const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321" for n := 1; n <= 10; n++ { var b bytes.Buffer @@ -22,9 +23,7 @@ func TestTruncWriter(t *testing.T) { if m > n { m = n } - if _, err := w.Write(p[:m]); err != nil { - t.Fatal(err) - } + w.Write(p[:m]) p = p[m:] } if b.String() != data[:len(data)-len(w.p)] { @@ -48,9 +47,7 @@ func BenchmarkWriteNoCompression(b *testing.B) { messages := textMessages(100) b.ResetTimer() for i := 0; i < b.N; i++ { - if err := c.WriteMessage(TextMessage, messages[i%len(messages)]); err != nil { - b.Fatal(err) - } + c.WriteMessage(TextMessage, messages[i%len(messages)]) } b.ReportAllocs() } @@ -63,14 +60,13 @@ func BenchmarkWriteWithCompression(b *testing.B) { c.newCompressionWriter = compressNoContextTakeover b.ResetTimer() for i := 0; i < b.N; i++ { - if err := c.WriteMessage(TextMessage, messages[i%len(messages)]); err != nil { - b.Fatal(err) - } + c.WriteMessage(TextMessage, messages[i%len(messages)]) } b.ReportAllocs() } func TestValidCompressionLevel(t *testing.T) { + t.Parallel() c := newTestConn(nil, nil, false) for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { if err := c.SetCompressionLevel(level); err == nil { diff --git a/conn.go b/conn.go index 221e6cf7..49399b12 100644 --- a/conn.go +++ b/conn.go @@ -10,7 +10,6 @@ import ( "encoding/binary" "errors" "io" - "log" "net" "strconv" "strings" @@ -193,13 +192,6 @@ func newMaskKey() [4]byte { return k } -func hideTempErr(err error) error { - if e, ok := err.(net.Error); ok { - err = &netError{msg: e.Error(), timeout: e.Timeout()} - } - return err -} - func isControl(frameType int) bool { return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage } @@ -365,7 +357,6 @@ func (c *Conn) RemoteAddr() net.Addr { // Write methods func (c *Conn) writeFatal(err error) error { - err = hideTempErr(err) c.writeErrMu.Lock() if c.writeErr == nil { c.writeErr = err @@ -379,9 +370,7 @@ func (c *Conn) read(n int) ([]byte, error) { if err == io.EOF { err = errUnexpectedEOF } - if _, err := c.br.Discard(len(p)); err != nil { - return p, err - } + c.br.Discard(len(p)) return p, err } @@ -396,9 +385,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error return err } - if err := c.conn.SetWriteDeadline(deadline); err != nil { - return c.writeFatal(err) - } + c.conn.SetWriteDeadline(deadline) if len(buf1) == 0 { _, err = c.conn.Write(buf0) } else { @@ -408,7 +395,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error return c.writeFatal(err) } if frameType == CloseMessage { - _ = c.writeFatal(ErrCloseSent) + c.writeFatal(ErrCloseSent) } return nil } @@ -447,21 +434,27 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er maskBytes(key, 0, buf[6:]) } - d := 1000 * time.Hour - if !deadline.IsZero() { - d = time.Until(deadline) + if deadline.IsZero() { + // No timeout for zero time. + <-c.mu + } else { + d := time.Until(deadline) if d < 0 { return errWriteTimeout } + select { + case <-c.mu: + default: + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } + } } - timer := time.NewTimer(d) - select { - case <-c.mu: - timer.Stop() - case <-timer.C: - return errWriteTimeout - } defer func() { c.mu <- struct{}{} }() c.writeErrMu.Lock() @@ -471,15 +464,13 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er return err } - if err := c.conn.SetWriteDeadline(deadline); err != nil { - return c.writeFatal(err) - } + c.conn.SetWriteDeadline(deadline) _, err = c.conn.Write(buf) if err != nil { return c.writeFatal(err) } if messageType == CloseMessage { - _ = c.writeFatal(ErrCloseSent) + c.writeFatal(ErrCloseSent) } return err } @@ -490,9 +481,7 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { // probably better to return an error in this situation, but we cannot // change this without breaking existing applications. if c.writer != nil { - if err := c.writer.Close(); err != nil { - log.Printf("websocket: discarding writer close error: %v", err) - } + c.writer.Close() c.writer = nil } @@ -645,7 +634,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { } if final { - _ = w.endMessage(errWriteClosed) + w.endMessage(errWriteClosed) return nil } @@ -832,9 +821,7 @@ func (c *Conn) advanceFrame() (int, error) { rsv2 := p[0]&rsv2Bit != 0 rsv3 := p[0]&rsv3Bit != 0 mask := p[1]&maskBit != 0 - if err := c.setReadRemaining(int64(p[1] & 0x7f)); err != nil { - return noFrame, err - } + c.setReadRemaining(int64(p[1] & 0x7f)) c.readDecompress = false if rsv1 { @@ -939,9 +926,7 @@ func (c *Conn) advanceFrame() (int, error) { } if c.readLimit > 0 && c.readLength > c.readLimit { - if err := c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)); err != nil { - return noFrame, err - } + c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) return noFrame, ErrReadLimit } @@ -953,9 +938,7 @@ func (c *Conn) advanceFrame() (int, error) { var payload []byte if c.readRemaining > 0 { payload, err = c.read(int(c.readRemaining)) - if err := c.setReadRemaining(0); err != nil { - return noFrame, err - } + c.setReadRemaining(0) if err != nil { return noFrame, err } @@ -1002,9 +985,7 @@ func (c *Conn) handleProtocolError(message string) error { if len(data) > maxControlFramePayloadSize { data = data[:maxControlFramePayloadSize] } - if err := c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)); err != nil { - return err - } + c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) return errors.New("websocket: " + message) } @@ -1021,9 +1002,7 @@ func (c *Conn) handleProtocolError(message string) error { func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { // Close previous reader, only relevant for decompression. if c.reader != nil { - if err := c.reader.Close(); err != nil { - log.Printf("websocket: discarding reader close error: %v", err) - } + c.reader.Close() c.reader = nil } @@ -1033,7 +1012,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { for c.readErr == nil { frameType, err := c.advanceFrame() if err != nil { - c.readErr = hideTempErr(err) + c.readErr = err break } @@ -1073,15 +1052,13 @@ func (r *messageReader) Read(b []byte) (int, error) { b = b[:c.readRemaining] } n, err := c.br.Read(b) - c.readErr = hideTempErr(err) + c.readErr = err if c.isServer { c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) } rem := c.readRemaining rem -= int64(n) - if err := c.setReadRemaining(rem); err != nil { - return 0, err - } + c.setReadRemaining(rem) if c.readRemaining > 0 && c.readErr == io.EOF { c.readErr = errUnexpectedEOF } @@ -1096,7 +1073,7 @@ func (r *messageReader) Read(b []byte) (int, error) { frameType, err := c.advanceFrame() switch { case err != nil: - c.readErr = hideTempErr(err) + c.readErr = err case frameType == TextMessage || frameType == BinaryMessage: c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") } @@ -1163,9 +1140,7 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) { if h == nil { h = func(code int, text string) error { message := FormatCloseMessage(code, "") - if err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)); err != nil { - return err - } + c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) return nil } } @@ -1265,3 +1240,15 @@ func FormatCloseMessage(closeCode int, text string) []byte { copy(buf[2:], text) return buf } + +var messageTypes = map[int]string{ + TextMessage: "TextMessage", + BinaryMessage: "BinaryMessage", + CloseMessage: "CloseMessage", + PingMessage: "PingMessage", + PongMessage: "PongMessage", +} + +func FormatMessageType(mt int) string { + return messageTypes[mt] +} diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go index f63f62ff..d8a6492d 100644 --- a/conn_broadcast_test.go +++ b/conn_broadcast_test.go @@ -69,13 +69,9 @@ func (b *broadcastBench) makeConns(numConns int) { select { case msg := <-c.msgCh: if msg.prepared != nil { - if err := c.conn.WritePreparedMessage(msg.prepared); err != nil { - panic(err) - } + c.conn.WritePreparedMessage(msg.prepared) } else { - if err := c.conn.WriteMessage(TextMessage, msg.payload); err != nil { - panic(err) - } + c.conn.WriteMessage(TextMessage, msg.payload) } val := atomic.AddInt32(&b.count, 1) if val%int32(numConns) == 0 { diff --git a/conn_test.go b/conn_test.go index 2b823dd4..564f9645 100644 --- a/conn_test.go +++ b/conn_test.go @@ -54,6 +54,7 @@ func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { } func TestFraming(t *testing.T) { + t.Parallel() frameSizes := []int{ 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, // 65536, 65537 @@ -123,6 +124,7 @@ func TestFraming(t *testing.T) { continue } + t.Logf("frame size: %d", n) rbuf, err := io.ReadAll(r) if err != nil { t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) @@ -147,7 +149,49 @@ func TestFraming(t *testing.T) { } } +func TestWriteControlDeadline(t *testing.T) { + t.Parallel() + message := []byte("hello") + var connBuf bytes.Buffer + c := newTestConn(nil, &connBuf, true) + if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil { + t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err) + } + if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil { + t.Errorf("WriteControl(..., future deadline) = %v, want nil", err) + } + if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil { + t.Errorf("WriteControl(..., past deadline) = nil, want timeout error") + } +} + +func TestConcurrencyWriteControl(t *testing.T) { + const message = "this is a ping/pong messsage" + loop := 10 + workers := 10 + for i := 0; i < loop; i++ { + var connBuf bytes.Buffer + + wg := sync.WaitGroup{} + wc := newTestConn(nil, &connBuf, true) + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil { + t.Errorf("concurrently wc.WriteControl() returned %v", err) + } + }() + } + + wg.Wait() + wc.Close() + } +} + func TestControl(t *testing.T) { + t.Parallel() const message = "this is a ping/pong messsage" for _, isServer := range []bool{true, false} { for _, isWriteControl := range []bool{true, false} { @@ -156,10 +200,7 @@ func TestControl(t *testing.T) { wc := newTestConn(nil, &connBuf, isServer) rc := newTestConn(&connBuf, nil, !isServer) if isWriteControl { - if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil { - t.Errorf("%s: wc.WriteControl() returned %v", name, err) - continue - } + wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) } else { w, err := wc.NextWriter(PongMessage) if err != nil { @@ -176,9 +217,7 @@ func TestControl(t *testing.T) { } var actualMessage string rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) - if _, _, err := rc.NextReader(); err != nil { - continue - } + rc.NextReader() if actualMessage != message { t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) continue @@ -204,6 +243,7 @@ func (p *simpleBufferPool) Put(v interface{}) { } func TestWriteBufferPool(t *testing.T) { + t.Parallel() const message = "Now is the time for all good people to come to the aid of the party." var buf bytes.Buffer @@ -282,6 +322,7 @@ func TestWriteBufferPool(t *testing.T) { // TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. func TestWriteBufferPoolSync(t *testing.T) { + t.Parallel() var buf bytes.Buffer var pool sync.Pool wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) @@ -310,6 +351,7 @@ func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error // TestWriteBufferPoolError ensures that buffer is returned to pool after error // on write. func TestWriteBufferPoolError(t *testing.T) { + t.Parallel() // Part 1: Test NextWriter/Write/Close @@ -353,6 +395,7 @@ func TestWriteBufferPoolError(t *testing.T) { } func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { + t.Parallel() const bufSize = 512 expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} @@ -362,12 +405,8 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) - if _, err := w.Write(make([]byte, bufSize+bufSize/2)); err != nil { - t.Fatalf("w.Write() returned %v", err) - } - if err := wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)); err != nil { - t.Fatalf("wc.WriteControl() returned %v", err) - } + w.Write(make([]byte, bufSize+bufSize/2)) + wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) w.Close() op, r, err := rc.NextReader() @@ -385,6 +424,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { } func TestEOFWithinFrame(t *testing.T) { + t.Parallel() const bufSize = 64 for n := 0; ; n++ { @@ -393,9 +433,7 @@ func TestEOFWithinFrame(t *testing.T) { rc := newTestConn(&b, nil, true) w, _ := wc.NextWriter(BinaryMessage) - if _, err := w.Write(make([]byte, bufSize)); err != nil { - t.Fatalf("%d: w.Write() returned %v", n, err) - } + w.Write(make([]byte, bufSize)) w.Close() if n >= b.Len() { @@ -422,6 +460,7 @@ func TestEOFWithinFrame(t *testing.T) { } func TestEOFBeforeFinalFrame(t *testing.T) { + t.Parallel() const bufSize = 512 var b1, b2 bytes.Buffer @@ -429,9 +468,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) { rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) - if _, err := w.Write(make([]byte, bufSize+bufSize/2)); err != nil { - t.Fatalf("w.Write() returned %v", err) - } + w.Write(make([]byte, bufSize+bufSize/2)) op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { @@ -448,11 +485,10 @@ func TestEOFBeforeFinalFrame(t *testing.T) { } func TestWriteAfterMessageWriterClose(t *testing.T) { + t.Parallel() wc := newTestConn(nil, &bytes.Buffer{}, false) w, _ := wc.NextWriter(BinaryMessage) - if _, err := io.WriteString(w, "hello"); err != nil { - t.Fatalf("unexpected error writing, %v", err) - } + io.WriteString(w, "hello") if err := w.Close(); err != nil { t.Fatalf("unxpected error closing message writer, %v", err) } @@ -462,9 +498,7 @@ func TestWriteAfterMessageWriterClose(t *testing.T) { } w, _ = wc.NextWriter(BinaryMessage) - if _, err := io.WriteString(w, "hello"); err != nil { - t.Fatalf("unexpected error writing after getting new writer, %v", err) - } + io.WriteString(w, "hello") // close w by getting next writer _, err := wc.NextWriter(BinaryMessage) @@ -477,7 +511,29 @@ func TestWriteAfterMessageWriterClose(t *testing.T) { } } +func TestWriteHandlerDoesNotReturnErrCloseSent(t *testing.T) { + t.Parallel() + var b1, b2 bytes.Buffer + + client := newTestConn(&b2, &b1, false) + server := newTestConn(&b1, &b2, true) + + msg := FormatCloseMessage(CloseNormalClosure, "") + if err := client.WriteMessage(CloseMessage, msg); err != nil { + t.Fatalf("unexpected error when writing close message, %v", err) + } + + if _, _, err := server.NextReader(); !IsCloseError(err, 1000) { + t.Fatalf("server expects a close message, %v returned", err) + } + + if _, _, err := client.NextReader(); !IsCloseError(err, 1000) { + t.Fatalf("client expects a close message, %v returned", err) + } +} + func TestReadLimit(t *testing.T) { + t.Parallel() t.Run("Test ReadLimit is enforced", func(t *testing.T) { const readLimit = 512 message := make([]byte, readLimit+1) @@ -489,21 +545,13 @@ func TestReadLimit(t *testing.T) { // Send message at the limit with interleaved pong. w, _ := wc.NextWriter(BinaryMessage) - if _, err := w.Write(message[:readLimit-1]); err != nil { - t.Fatalf("w.WriteMessage() returned %v", err) - } - if err := wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)); err != nil { - t.Fatalf("wc.WriteControl() returned %v", err) - } - if _, err := w.Write(message[:1]); err != nil { - t.Fatalf("w.Write() returned %v", err) - } + w.Write(message[:readLimit-1]) + wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + w.Write(message[:1]) w.Close() // Send message larger than the limit. - if err := wc.WriteMessage(BinaryMessage, message[:readLimit+1]); err != nil { - t.Fatalf("wc.WriteMessage() returned %v", err) - } + wc.WriteMessage(BinaryMessage, message[:readLimit+1]) op, _, err := rc.NextReader() if op != BinaryMessage || err != nil { @@ -576,6 +624,7 @@ func TestReadLimit(t *testing.T) { } func TestAddrs(t *testing.T) { + t.Parallel() c := newTestConn(nil, nil, true) if c.LocalAddr() != localAddr { t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) @@ -586,6 +635,7 @@ func TestAddrs(t *testing.T) { } func TestDeprecatedUnderlyingConn(t *testing.T) { + t.Parallel() var b1, b2 bytes.Buffer fc := fakeNetConn{Reader: &b1, Writer: &b2} c := newConn(fc, true, 1024, 1024, nil, nil, nil) @@ -596,6 +646,7 @@ func TestDeprecatedUnderlyingConn(t *testing.T) { } func TestNetConn(t *testing.T) { + t.Parallel() var b1, b2 bytes.Buffer fc := fakeNetConn{Reader: &b1, Writer: &b2} c := newConn(fc, true, 1024, 1024, nil, nil, nil) @@ -606,6 +657,7 @@ func TestNetConn(t *testing.T) { } func TestBufioReadBytes(t *testing.T) { + t.Parallel() // Test calling bufio.ReadBytes for value longer than read buffer size. m := make([]byte, 512) @@ -616,9 +668,7 @@ func TestBufioReadBytes(t *testing.T) { rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) w, _ := wc.NextWriter(BinaryMessage) - if _, err := w.Write(m); err != nil { - t.Fatalf("w.Write() returned %v", err) - } + w.Write(m) w.Close() op, r, err := rc.NextReader() @@ -648,6 +698,7 @@ var closeErrorTests = []struct { } func TestCloseError(t *testing.T) { + t.Parallel() for _, tt := range closeErrorTests { ok := IsCloseError(tt.err, tt.codes...) if ok != tt.ok { @@ -668,6 +719,7 @@ var unexpectedCloseErrorTests = []struct { } func TestUnexpectedCloseErrors(t *testing.T) { + t.Parallel() for _, tt := range unexpectedCloseErrorTests { ok := IsUnexpectedCloseError(tt.err, tt.codes...) if ok != tt.ok { @@ -689,12 +741,11 @@ func (w blockingWriter) Write(p []byte) (int, error) { } func TestConcurrentWritePanic(t *testing.T) { + t.Parallel() w := blockingWriter{make(chan struct{}), make(chan struct{})} c := newTestConn(nil, w, false) go func() { - if err := c.WriteMessage(TextMessage, []byte{}); err != nil { - t.Error(err) - } + c.WriteMessage(TextMessage, []byte{}) }() // wait for goroutine to block in write. @@ -707,9 +758,7 @@ func TestConcurrentWritePanic(t *testing.T) { } }() - if err := c.WriteMessage(TextMessage, []byte{}); err != nil { - t.Error(err) - } + c.WriteMessage(TextMessage, []byte{}) t.Fatal("should not get here") } @@ -720,6 +769,7 @@ func (r failingReader) Read(p []byte) (int, error) { } func TestFailedConnectionReadPanic(t *testing.T) { + t.Parallel() c := newTestConn(failingReader{}, nil, false) defer func() { @@ -729,7 +779,46 @@ func TestFailedConnectionReadPanic(t *testing.T) { }() for i := 0; i < 20000; i++ { - _, _, _ = c.ReadMessage() + c.ReadMessage() } t.Fatal("should not get here") } + +func TestFormatMessageType(t *testing.T) { + str := FormatMessageType(TextMessage) + if str != messageTypes[TextMessage] { + t.Error("failed to format message type") + } + + str = FormatMessageType(CloseMessage) + if str != messageTypes[CloseMessage] { + t.Error("failed to format message type") + } + + str = FormatMessageType(123) + if str != messageTypes[123] { + t.Error("failed to format message type") + } +} + +type fakeNetClosedReader struct { +} + +func (r fakeNetClosedReader) Read([]byte) (int, error) { + return 0, net.ErrClosed +} + +func TestConnectionClosed(t *testing.T) { + var b1, b2 bytes.Buffer + + client := newTestConn(fakeNetClosedReader{}, &b1, false) + server := newTestConn(fakeNetClosedReader{}, &b2, true) + + if _, _, err := server.NextReader(); !errors.Is(err, net.ErrClosed) { + t.Fatalf("server expects a net.ErrClosed error, %v returned", err) + } + + if _, _, err := client.NextReader(); !errors.Is(err, net.ErrClosed) { + t.Fatalf("client expects a net.ErrClosed error, %v returned", err) + } +} diff --git a/example_test.go b/example_test.go index cd1883b8..cb3a5eb0 100644 --- a/example_test.go +++ b/example_test.go @@ -42,4 +42,4 @@ func processMessage(mt int, p []byte) {} // TestX prevents godoc from showing this entire file in the example. Remove // this function when a second example is added. -func TestX(t *testing.T) {} +func TestX(t *testing.T) { t.Parallel() } diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go index 1cd273f6..2d6d36f8 100644 --- a/examples/autobahn/server.go +++ b/examples/autobahn/server.go @@ -84,7 +84,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) { } // echoReadAll echoes messages from the client by reading the entire message -// with ioutil.ReadAll. +// with io.ReadAll. func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -178,11 +178,7 @@ func main() { http.HandleFunc("/r", echoReadAllWriter) http.HandleFunc("/m", echoReadAllWriteMessage) http.HandleFunc("/p", echoReadAllWritePreparedMessage) - server := &http.Server{ - Addr: *addr, - ReadHeaderTimeout: 3 * time.Second, - } - err := server.ListenAndServe() + err := http.ListenAndServe(*addr, nil) if err != nil { log.Fatal("ListenAndServe: ", err) } diff --git a/examples/chat/README.md b/examples/chat/README.md index 7baf3e32..33fcea71 100644 --- a/examples/chat/README.md +++ b/examples/chat/README.md @@ -38,7 +38,7 @@ sends them to the hub. ### Hub The code for the `Hub` type is in -[hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go). +[hub.go](https://github.com/gorilla/websocket/blob/main/examples/chat/hub.go). The application's `main` function starts the hub's `run` method as a goroutine. Clients send requests to the hub using the `register`, `unregister` and `broadcast` channels. @@ -57,7 +57,7 @@ unregisters the client and closes the websocket. ### Client -The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go). +The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/main/examples/chat/client.go). The `serveWs` function is registered by the application's `main` function as an HTTP handler. The handler upgrades the HTTP connection to the WebSocket @@ -85,7 +85,7 @@ network. ## Frontend -The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html). +The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/main/examples/chat/home.html). On document load, the script checks for websocket functionality in the browser. If websocket functionality is available, then the script opens a connection to diff --git a/examples/chat/main.go b/examples/chat/main.go index 591cb896..474709f6 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -8,7 +8,6 @@ import ( "flag" "log" "net/http" - "time" ) var addr = flag.String("addr", ":8080", "http service address") @@ -34,11 +33,7 @@ func main() { http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { serveWs(hub, w, r) }) - server := &http.Server{ - Addr: *addr, - ReadHeaderTimeout: 3 * time.Second, - } - err := server.ListenAndServe() + err := http.ListenAndServe(*addr, nil) if err != nil { log.Fatal("ListenAndServe: ", err) } diff --git a/examples/command/main.go b/examples/command/main.go index 11b52e50..0e3cd66c 100644 --- a/examples/command/main.go +++ b/examples/command/main.go @@ -57,8 +57,6 @@ func pumpStdin(ws *websocket.Conn, w io.Writer) { } func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) { - defer func() { - }() s := bufio.NewScanner(r) for s.Scan() { ws.SetWriteDeadline(time.Now().Add(writeWait)) @@ -189,9 +187,5 @@ func main() { } http.HandleFunc("/", serveHome) http.HandleFunc("/ws", serveWs) - server := &http.Server{ - Addr: *addr, - ReadHeaderTimeout: 3 * time.Second, - } - log.Fatal(server.ListenAndServe()) + log.Fatal(http.ListenAndServe(*addr, nil)) } diff --git a/examples/echo/client.go b/examples/echo/client.go index 7d870bdf..d53b81c8 100644 --- a/examples/echo/client.go +++ b/examples/echo/client.go @@ -41,12 +41,12 @@ func main() { go func() { defer close(done) for { - _, message, err := c.ReadMessage() + mt, message, err := c.ReadMessage() if err != nil { log.Println("read:", err) return } - log.Printf("recv: %s", message) + log.Printf("recv: %s, type: %s", message, websocket.FormatMessageType(mt)) } }() diff --git a/examples/echo/server.go b/examples/echo/server.go index f9a0b7b5..9804e6b2 100644 --- a/examples/echo/server.go +++ b/examples/echo/server.go @@ -33,7 +33,8 @@ func echo(w http.ResponseWriter, r *http.Request) { log.Println("read:", err) break } - log.Printf("recv: %s", message) + + log.Printf("recv: %s, type: %s", message, websocket.FormatMessageType(mt)) err = c.WriteMessage(mt, message) if err != nil { log.Println("write:", err) diff --git a/examples/filewatch/main.go b/examples/filewatch/main.go index ddb613cf..a3c38d8b 100644 --- a/examples/filewatch/main.go +++ b/examples/filewatch/main.go @@ -7,7 +7,6 @@ package main import ( "flag" "html/template" - "io/ioutil" "log" "net/http" "os" @@ -50,7 +49,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) { if !fi.ModTime().After(lastMod) { return nil, lastMod, nil } - p, err := ioutil.ReadFile(filepath.Clean(filename)) + p, err := os.ReadFile(filepath.Clean(filename)) if err != nil { return nil, fi.ModTime(), err } @@ -164,11 +163,7 @@ func main() { filename = flag.Args()[0] http.HandleFunc("/", serveHome) http.HandleFunc("/ws", serveWs) - server := &http.Server{ - Addr: *addr, - ReadHeaderTimeout: 3 * time.Second, - } - if err := server.ListenAndServe(); err != nil { + if err := http.ListenAndServe(*addr, nil); err != nil { log.Fatal(err) } } diff --git a/go.mod b/go.mod index 4f905673..2d73fad6 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/gorilla/websocket go 1.20 -require golang.org/x/net v0.17.0 +require golang.org/x/net v0.23.0 diff --git a/go.sum b/go.sum index 0d6f4548..199e74ae 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= diff --git a/join_test.go b/join_test.go index f06601da..fcf0980c 100644 --- a/join_test.go +++ b/join_test.go @@ -12,6 +12,7 @@ import ( ) func TestJoinMessages(t *testing.T) { + t.Parallel() messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"} for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} { for _, term := range []string{"", ","} { @@ -19,9 +20,7 @@ func TestJoinMessages(t *testing.T) { wc := newTestConn(nil, &connBuf, true) rc := newTestConn(&connBuf, nil, false) for _, m := range messages { - if err := wc.WriteMessage(BinaryMessage, []byte(m)); err != nil { - t.Fatalf("write %q: %v", m, err) - } + wc.WriteMessage(BinaryMessage, []byte(m)) } var result bytes.Buffer diff --git a/json_test.go b/json_test.go index e4c4bdfe..3e954352 100644 --- a/json_test.go +++ b/json_test.go @@ -13,6 +13,7 @@ import ( ) func TestJSON(t *testing.T) { + t.Parallel() var buf bytes.Buffer wc := newTestConn(nil, &buf, true) rc := newTestConn(&buf, nil, false) @@ -38,6 +39,7 @@ func TestJSON(t *testing.T) { } func TestPartialJSONRead(t *testing.T) { + t.Parallel() var buf0, buf1 bytes.Buffer wc := newTestConn(nil, &buf0, true) rc := newTestConn(&buf0, &buf1, false) @@ -91,6 +93,7 @@ func TestPartialJSONRead(t *testing.T) { } func TestDeprecatedJSON(t *testing.T) { + t.Parallel() var buf bytes.Buffer wc := newTestConn(nil, &buf, true) rc := newTestConn(&buf, nil, false) diff --git a/mask.go b/mask.go index 67d0968b..d0742bf2 100644 --- a/mask.go +++ b/mask.go @@ -9,7 +9,6 @@ package websocket import "unsafe" -// #nosec G103 -- (CWE-242) Has been audited const wordSize = int(unsafe.Sizeof(uintptr(0))) func maskBytes(key [4]byte, pos int, b []byte) int { @@ -23,7 +22,6 @@ func maskBytes(key [4]byte, pos int, b []byte) int { } // Mask one byte at a time to word boundary. - //#nosec G103 -- (CWE-242) Has been audited if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { n = wordSize - n for i := range b[:n] { @@ -38,13 +36,11 @@ func maskBytes(key [4]byte, pos int, b []byte) int { for i := range k { k[i] = key[(pos+i)&3] } - //#nosec G103 -- (CWE-242) Has been audited kw := *(*uintptr)(unsafe.Pointer(&k)) // Mask one word at a time. n := (len(b) / wordSize) * wordSize for i := 0; i < n; i += wordSize { - //#nosec G103 -- (CWE-242) Has been audited *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw } diff --git a/mask_test.go b/mask_test.go index 6389f436..718a6b4e 100644 --- a/mask_test.go +++ b/mask_test.go @@ -29,6 +29,7 @@ func notzero(b []byte) int { } func TestMaskBytes(t *testing.T) { + t.Parallel() key := [4]byte{1, 2, 3, 4} for size := 1; size <= 1024; size++ { for align := 0; align < wordSize; align++ { diff --git a/prepared_test.go b/prepared_test.go index dc77ef0d..536d58d9 100644 --- a/prepared_test.go +++ b/prepared_test.go @@ -45,10 +45,7 @@ func TestPreparedMessage(t *testing.T) { if tt.enableWriteCompression { c.newCompressionWriter = compressNoContextTakeover } - - if err := c.SetCompressionLevel(tt.compressionLevel); err != nil { - t.Fatal(err) - } + c.SetCompressionLevel(tt.compressionLevel) // Seed random number generator for consistent frame mask. testRand.Seed(1234) @@ -76,7 +73,7 @@ func TestPreparedMessage(t *testing.T) { got := buf.String() if got != want { - t.Errorf("write message != prepared message, got %#v, want %#v", got, want) + t.Errorf("write message != prepared message for %+v", tt) } } } diff --git a/proxy.go b/proxy.go index 80f55d1e..3c570c26 100644 --- a/proxy.go +++ b/proxy.go @@ -8,7 +8,6 @@ import ( "bufio" "encoding/base64" "errors" - "log" "net" "net/http" "net/url" @@ -58,9 +57,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) } if err := connectReq.Write(conn); err != nil { - if err := conn.Close(); err != nil { - log.Printf("httpProxyDialer: failed to close connection: %v", err) - } + conn.Close() return nil, err } @@ -69,16 +66,12 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) br := bufio.NewReader(conn) resp, err := http.ReadResponse(br, connectReq) if err != nil { - if err := conn.Close(); err != nil { - log.Printf("httpProxyDialer: failed to close connection: %v", err) - } + conn.Close() return nil, err } - if resp.StatusCode != 200 { - if err := conn.Close(); err != nil { - log.Printf("httpProxyDialer: failed to close connection: %v", err) - } + if resp.StatusCode != http.StatusOK { + conn.Close() f := strings.SplitN(resp.Status, " ", 2) return nil, errors.New(f[1]) } diff --git a/server.go b/server.go index 1e720e1d..fda75ff0 100644 --- a/server.go +++ b/server.go @@ -8,7 +8,6 @@ import ( "bufio" "errors" "io" - "log" "net/http" "net/url" "strings" @@ -34,6 +33,7 @@ type Upgrader struct { // size is zero, then buffers allocated by the HTTP server are used. The // I/O buffer sizes do not limit the size of the messages that can be sent // or received. + // The default value is 4096 bytes, 4kb. ReadBufferSize, WriteBufferSize int // WriteBufferPool is a pool of buffers for write operations. If the value @@ -102,8 +102,8 @@ func checkSameOrigin(r *http.Request) bool { func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { if u.Subprotocols != nil { clientProtocols := Subprotocols(r) - for _, serverProtocol := range u.Subprotocols { - for _, clientProtocol := range clientProtocols { + for _, clientProtocol := range clientProtocols { + for _, serverProtocol := range u.Subprotocols { if clientProtocol == serverProtocol { return clientProtocol } @@ -173,20 +173,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } } - h, ok := w.(http.Hijacker) - if !ok { - return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") - } - var brw *bufio.ReadWriter - netConn, brw, err := h.Hijack() + netConn, brw, err := http.NewResponseController(w).Hijack() if err != nil { return u.returnError(w, r, http.StatusInternalServerError, err.Error()) } if brw.Reader.Buffered() > 0 { - if err := netConn.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } + netConn.Close() return nil, errors.New("websocket: client sent data before handshake is complete") } @@ -251,34 +244,17 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade p = append(p, "\r\n"...) // Clear deadlines set by HTTP server. - if err := netConn.SetDeadline(time.Time{}); err != nil { - if err := netConn.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } - return nil, err - } + netConn.SetDeadline(time.Time{}) if u.HandshakeTimeout > 0 { - if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil { - if err := netConn.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } - return nil, err - } + netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) } if _, err = netConn.Write(p); err != nil { - if err := netConn.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } + netConn.Close() return nil, err } if u.HandshakeTimeout > 0 { - if err := netConn.SetWriteDeadline(time.Time{}); err != nil { - if err := netConn.Close(); err != nil { - log.Printf("websocket: failed to close network connection: %v", err) - } - return nil, err - } + netConn.SetWriteDeadline(time.Time{}) } return c, nil @@ -376,12 +352,8 @@ func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte { // bufio.Writer's underlying writer. var wh writeHook bw.Reset(&wh) - if err := bw.WriteByte(0); err != nil { - panic(err) - } - if err := bw.Flush(); err != nil { - log.Printf("websocket: bufioWriterBuffer: Flush: %v", err) - } + bw.WriteByte(0) + bw.Flush() bw.Reset(originalWriter) diff --git a/server_test.go b/server_test.go index 5804be13..d7eb8806 100644 --- a/server_test.go +++ b/server_test.go @@ -7,8 +7,10 @@ package websocket import ( "bufio" "bytes" + "errors" "net" "net/http" + "net/http/httptest" "reflect" "strings" "testing" @@ -27,6 +29,7 @@ var subprotocolTests = []struct { } func TestSubprotocols(t *testing.T) { + t.Parallel() for _, st := range subprotocolTests { r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}} protocols := Subprotocols(&r) @@ -46,6 +49,7 @@ var isWebSocketUpgradeTests = []struct { } func TestIsWebSocketUpgrade(t *testing.T) { + t.Parallel() for _, tt := range isWebSocketUpgradeTests { ok := IsWebSocketUpgrade(&http.Request{Header: tt.h}) if tt.ok != ok { @@ -54,6 +58,37 @@ func TestIsWebSocketUpgrade(t *testing.T) { } } +func TestSubProtocolSelection(t *testing.T) { + t.Parallel() + upgrader := Upgrader{ + Subprotocols: []string{"foo", "bar", "baz"}, + } + + r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}} + s := upgrader.selectSubprotocol(&r, nil) + if s != "foo" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo") + } + + r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}} + s = upgrader.selectSubprotocol(&r, nil) + if s != "bar" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar") + } + + r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}} + s = upgrader.selectSubprotocol(&r, nil) + if s != "baz" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz") + } + + r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}} + s = upgrader.selectSubprotocol(&r, nil) + if s != "" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string") + } +} + var checkSameOriginTests = []struct { ok bool r *http.Request @@ -64,6 +99,7 @@ var checkSameOriginTests = []struct { } func TestCheckSameOrigin(t *testing.T) { + t.Parallel() for _, tt := range checkSameOriginTests { ok := checkSameOrigin(tt.r) if tt.ok != ok { @@ -90,6 +126,7 @@ var bufioReuseTests = []struct { } func TestBufioReuse(t *testing.T) { + t.Parallel() for i, tt := range bufioReuseTests { br := bufio.NewReaderSize(strings.NewReader(""), tt.n) bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n) @@ -117,3 +154,23 @@ func TestBufioReuse(t *testing.T) { } } } + +func TestHijack_NotSupported(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "upgrade") + req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Sec-Websocket-Version", "13") + + recorder := httptest.NewRecorder() + + upgrader := Upgrader{} + _, err := upgrader.Upgrade(recorder, req, nil) + + if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError { + t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError) + t.Fatalf("got err=%T and status_code=%d", err, recorder.Code) + } +} diff --git a/tls_handshake_116.go b/tls_handshake_116.go new file mode 100644 index 00000000..e1b2b44f --- /dev/null +++ b/tls_handshake_116.go @@ -0,0 +1,21 @@ +//go:build !go1.17 +// +build !go1.17 + +package websocket + +import ( + "context" + "crypto/tls" +) + +func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.Handshake(); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/util.go b/util.go index 9b1a629b..31a5dee6 100644 --- a/util.go +++ b/util.go @@ -6,7 +6,7 @@ package websocket import ( "crypto/rand" - "crypto/sha1" //#nosec G505 -- (CWE-327) https://datatracker.ietf.org/doc/html/rfc6455#page-54 + "crypto/sha1" "encoding/base64" "io" "net/http" @@ -17,7 +17,7 @@ import ( var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") func computeAcceptKey(challengeKey string) string { - h := sha1.New() //#nosec G401 -- (CWE-326) https://datatracker.ietf.org/doc/html/rfc6455#page-54 + h := sha1.New() h.Write([]byte(challengeKey)) h.Write(keyGUID) return base64.StdEncoding.EncodeToString(h.Sum(nil)) diff --git a/util_test.go b/util_test.go index f14d69a1..70621d4c 100644 --- a/util_test.go +++ b/util_test.go @@ -21,6 +21,7 @@ var equalASCIIFoldTests = []struct { } func TestEqualASCIIFold(t *testing.T) { + t.Parallel() for _, tt := range equalASCIIFoldTests { eq := equalASCIIFold(tt.s, tt.t) if eq != tt.eq { @@ -44,6 +45,7 @@ var tokenListContainsValueTests = []struct { } func TestTokenListContainsValue(t *testing.T) { + t.Parallel() for _, tt := range tokenListContainsValueTests { h := http.Header{"Upgrade": {tt.value}} ok := tokenListContainsValue(h, "Upgrade", "websocket") @@ -64,6 +66,7 @@ var isValidChallengeKeyTests = []struct { } func TestIsValidChallengeKey(t *testing.T) { + t.Parallel() for _, tt := range isValidChallengeKeyTests { ok := isValidChallengeKey(tt.key) if ok != tt.ok { @@ -105,6 +108,7 @@ var parseExtensionTests = []struct { } func TestParseExtensions(t *testing.T) { + t.Parallel() for _, tt := range parseExtensionTests { h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}} extensions := parseExtensions(h) diff --git a/vendor/modules.txt b/vendor/modules.txt index 080a92fe..ecf9c031 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,4 +1,4 @@ -# golang.org/x/net v0.17.0 -## explicit; go 1.17 +# golang.org/x/net v0.23.0 +## explicit; go 1.18 golang.org/x/net/internal/socks golang.org/x/net/proxy diff --git a/x_net_proxy.go b/x_net_proxy.go new file mode 100644 index 00000000..2e668f6b --- /dev/null +++ b/x_net_proxy.go @@ -0,0 +1,473 @@ +// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. +//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy + +// Package proxy provides support for a variety of protocols to proxy network +// data. +// + +package websocket + +import ( + "errors" + "io" + "net" + "net/url" + "os" + "strconv" + "strings" + "sync" +) + +type proxy_direct struct{} + +// Direct is a direct proxy: one that makes network connections directly. +var proxy_Direct = proxy_direct{} + +func (proxy_direct) Dial(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) +} + +// A PerHost directs connections to a default Dialer unless the host name +// requested matches one of a number of exceptions. +type proxy_PerHost struct { + def, bypass proxy_Dialer + + bypassNetworks []*net.IPNet + bypassIPs []net.IP + bypassZones []string + bypassHosts []string +} + +// NewPerHost returns a PerHost Dialer that directs connections to either +// defaultDialer or bypass, depending on whether the connection matches one of +// the configured rules. +func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost { + return &proxy_PerHost{ + def: defaultDialer, + bypass: bypass, + } +} + +// Dial connects to the address addr on the given network through either +// defaultDialer or bypass. +func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + return p.dialerForRequest(host).Dial(network, addr) +} + +func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer { + if ip := net.ParseIP(host); ip != nil { + for _, net := range p.bypassNetworks { + if net.Contains(ip) { + return p.bypass + } + } + for _, bypassIP := range p.bypassIPs { + if bypassIP.Equal(ip) { + return p.bypass + } + } + return p.def + } + + for _, zone := range p.bypassZones { + if strings.HasSuffix(host, zone) { + return p.bypass + } + if host == zone[1:] { + // For a zone ".example.com", we match "example.com" + // too. + return p.bypass + } + } + for _, bypassHost := range p.bypassHosts { + if bypassHost == host { + return p.bypass + } + } + return p.def +} + +// AddFromString parses a string that contains comma-separated values +// specifying hosts that should use the bypass proxy. Each value is either an +// IP address, a CIDR range, a zone (*.example.com) or a host name +// (localhost). A best effort is made to parse the string and errors are +// ignored. +func (p *proxy_PerHost) AddFromString(s string) { + hosts := strings.Split(s, ",") + for _, host := range hosts { + host = strings.TrimSpace(host) + if len(host) == 0 { + continue + } + if strings.Contains(host, "/") { + // We assume that it's a CIDR address like 127.0.0.0/8 + if _, net, err := net.ParseCIDR(host); err == nil { + p.AddNetwork(net) + } + continue + } + if ip := net.ParseIP(host); ip != nil { + p.AddIP(ip) + continue + } + if strings.HasPrefix(host, "*.") { + p.AddZone(host[1:]) + continue + } + p.AddHost(host) + } +} + +// AddIP specifies an IP address that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match an IP. +func (p *proxy_PerHost) AddIP(ip net.IP) { + p.bypassIPs = append(p.bypassIPs, ip) +} + +// AddNetwork specifies an IP range that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match. +func (p *proxy_PerHost) AddNetwork(net *net.IPNet) { + p.bypassNetworks = append(p.bypassNetworks, net) +} + +// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of +// "example.com" matches "example.com" and all of its subdomains. +func (p *proxy_PerHost) AddZone(zone string) { + if strings.HasSuffix(zone, ".") { + zone = zone[:len(zone)-1] + } + if !strings.HasPrefix(zone, ".") { + zone = "." + zone + } + p.bypassZones = append(p.bypassZones, zone) +} + +// AddHost specifies a host name that will use the bypass proxy. +func (p *proxy_PerHost) AddHost(host string) { + if strings.HasSuffix(host, ".") { + host = host[:len(host)-1] + } + p.bypassHosts = append(p.bypassHosts, host) +} + +// A Dialer is a means to establish a connection. +type proxy_Dialer interface { + // Dial connects to the given address via the proxy. + Dial(network, addr string) (c net.Conn, err error) +} + +// Auth contains authentication parameters that specific Dialers may require. +type proxy_Auth struct { + User, Password string +} + +// FromEnvironment returns the dialer specified by the proxy related variables in +// the environment. +func proxy_FromEnvironment() proxy_Dialer { + allProxy := proxy_allProxyEnv.Get() + if len(allProxy) == 0 { + return proxy_Direct + } + + proxyURL, err := url.Parse(allProxy) + if err != nil { + return proxy_Direct + } + proxy, err := proxy_FromURL(proxyURL, proxy_Direct) + if err != nil { + return proxy_Direct + } + + noProxy := proxy_noProxyEnv.Get() + if len(noProxy) == 0 { + return proxy + } + + perHost := proxy_NewPerHost(proxy, proxy_Direct) + perHost.AddFromString(noProxy) + return perHost +} + +// proxySchemes is a map from URL schemes to a function that creates a Dialer +// from a URL with such a scheme. +var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error) + +// RegisterDialerType takes a URL scheme and a function to generate Dialers from +// a URL with that scheme and a forwarding Dialer. Registered schemes are used +// by FromURL. +func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) { + if proxy_proxySchemes == nil { + proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) + } + proxy_proxySchemes[scheme] = f +} + +// FromURL returns a Dialer given a URL specification and an underlying +// Dialer for it to make network requests. +func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) { + var auth *proxy_Auth + if u.User != nil { + auth = new(proxy_Auth) + auth.User = u.User.Username() + if p, ok := u.User.Password(); ok { + auth.Password = p + } + } + + switch u.Scheme { + case "socks5": + return proxy_SOCKS5("tcp", u.Host, auth, forward) + } + + // If the scheme doesn't match any of the built-in schemes, see if it + // was registered by another package. + if proxy_proxySchemes != nil { + if f, ok := proxy_proxySchemes[u.Scheme]; ok { + return f(u, forward) + } + } + + return nil, errors.New("proxy: unknown scheme: " + u.Scheme) +} + +var ( + proxy_allProxyEnv = &proxy_envOnce{ + names: []string{"ALL_PROXY", "all_proxy"}, + } + proxy_noProxyEnv = &proxy_envOnce{ + names: []string{"NO_PROXY", "no_proxy"}, + } +) + +// envOnce looks up an environment variable (optionally by multiple +// names) once. It mitigates expensive lookups on some platforms +// (e.g. Windows). +// (Borrowed from net/http/transport.go) +type proxy_envOnce struct { + names []string + once sync.Once + val string +} + +func (e *proxy_envOnce) Get() string { + e.once.Do(e.init) + return e.val +} + +func (e *proxy_envOnce) init() { + for _, n := range e.names { + e.val = os.Getenv(n) + if e.val != "" { + return + } + } +} + +// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address +// with an optional username and password. See RFC 1928 and RFC 1929. +func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) { + s := &proxy_socks5{ + network: network, + addr: addr, + forward: forward, + } + if auth != nil { + s.user = auth.User + s.password = auth.Password + } + + return s, nil +} + +type proxy_socks5 struct { + user, password string + network, addr string + forward proxy_Dialer +} + +const proxy_socks5Version = 5 + +const ( + proxy_socks5AuthNone = 0 + proxy_socks5AuthPassword = 2 +) + +const proxy_socks5Connect = 1 + +const ( + proxy_socks5IP4 = 1 + proxy_socks5Domain = 3 + proxy_socks5IP6 = 4 +) + +var proxy_socks5Errors = []string{ + "", + "general failure", + "connection forbidden", + "network unreachable", + "host unreachable", + "connection refused", + "TTL expired", + "command not supported", + "address type not supported", +} + +// Dial connects to the address addr on the given network via the SOCKS5 proxy. +func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) { + switch network { + case "tcp", "tcp6", "tcp4": + default: + return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network) + } + + conn, err := s.forward.Dial(s.network, s.addr) + if err != nil { + return nil, err + } + if err := s.connect(conn, addr); err != nil { + conn.Close() + return nil, err + } + return conn, nil +} + +// connect takes an existing connection to a socks5 proxy server, +// and commands the server to extend that connection to target, +// which must be a canonical address with a host and port. +func (s *proxy_socks5) connect(conn net.Conn, target string) error { + host, portStr, err := net.SplitHostPort(target) + if err != nil { + return err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return errors.New("proxy: failed to parse port number: " + portStr) + } + if port < 1 || port > 0xffff { + return errors.New("proxy: port number out of range: " + portStr) + } + + // the size here is just an estimate + buf := make([]byte, 0, 6+len(host)) + + buf = append(buf, proxy_socks5Version) + if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 { + buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword) + } else { + buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone) + } + + if _, err := conn.Write(buf); err != nil { + return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + if buf[0] != 5 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0]))) + } + if buf[1] == 0xff { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication") + } + + // See RFC 1929 + if buf[1] == proxy_socks5AuthPassword { + buf = buf[:0] + buf = append(buf, 1 /* password protocol version */) + buf = append(buf, uint8(len(s.user))) + buf = append(buf, s.user...) + buf = append(buf, uint8(len(s.password))) + buf = append(buf, s.password...) + + if _, err := conn.Write(buf); err != nil { + return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if buf[1] != 0 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password") + } + } + + buf = buf[:0] + buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */) + + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + buf = append(buf, proxy_socks5IP4) + ip = ip4 + } else { + buf = append(buf, proxy_socks5IP6) + } + buf = append(buf, ip...) + } else { + if len(host) > 255 { + return errors.New("proxy: destination host name too long: " + host) + } + buf = append(buf, proxy_socks5Domain) + buf = append(buf, byte(len(host))) + buf = append(buf, host...) + } + buf = append(buf, byte(port>>8), byte(port)) + + if _, err := conn.Write(buf); err != nil { + return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:4]); err != nil { + return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + failure := "unknown error" + if int(buf[1]) < len(proxy_socks5Errors) { + failure = proxy_socks5Errors[buf[1]] + } + + if len(failure) > 0 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure) + } + + bytesToDiscard := 0 + switch buf[3] { + case proxy_socks5IP4: + bytesToDiscard = net.IPv4len + case proxy_socks5IP6: + bytesToDiscard = net.IPv6len + case proxy_socks5Domain: + _, err := io.ReadFull(conn, buf[:1]) + if err != nil { + return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + bytesToDiscard = int(buf[0]) + default: + return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr) + } + + if cap(buf) < bytesToDiscard { + buf = make([]byte, bytesToDiscard) + } else { + buf = buf[:bytesToDiscard] + } + if _, err := io.ReadFull(conn, buf); err != nil { + return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + // Also need to discard the port number + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + return nil +}