diff --git a/internal/mux/response.go b/internal/mux/response.go index 6b1ce906..6a10df81 100644 --- a/internal/mux/response.go +++ b/internal/mux/response.go @@ -15,6 +15,12 @@ type ResponseWriter interface { Size() int64 } +func NewResponseWriter(w http.ResponseWriter) ResponseWriter { //nolint:ireturn // Return a pointer + return &response{ + ResponseWriter: w, + } +} + var ( _ http.ResponseWriter = &response{} _ http.Flusher = &response{} diff --git a/pkg/oci/containerd.go b/pkg/oci/containerd.go index f4cf27f5..9ab226f3 100644 --- a/pkg/oci/containerd.go +++ b/pkg/oci/containerd.go @@ -247,7 +247,7 @@ func (c *Containerd) GetManifest(ctx context.Context, dgst digest.Digest) ([]byt return b, mt, nil } -func (c *Containerd) GetBlob(ctx context.Context, dgst digest.Digest) (io.ReadCloser, error) { +func (c *Containerd) GetBlob(ctx context.Context, dgst digest.Digest) (io.ReadSeekCloser, error) { if c.contentPath != "" { path := filepath.Join(c.contentPath, "blobs", dgst.Algorithm().String(), dgst.Encoded()) file, err := os.Open(path) @@ -270,13 +270,11 @@ func (c *Containerd) GetBlob(ctx context.Context, dgst digest.Digest) (io.ReadCl if err != nil { return nil, err } - return struct { - io.Reader - io.Closer - }{ - Reader: content.NewReader(ra), - Closer: ra, - }, nil + rs, err := newHTTPReadSeeker(logr.FromContextOrDiscard(ctx), ra) + if err != nil { + return nil, err + } + return rs, nil } func getEventImage(e typeurl.Any) (string, EventType, error) { diff --git a/pkg/oci/httpreadseeker.go b/pkg/oci/httpreadseeker.go new file mode 100644 index 00000000..e4ae378d --- /dev/null +++ b/pkg/oci/httpreadseeker.go @@ -0,0 +1,124 @@ +package oci + +import ( + "io" + + "github.com/containerd/containerd/content" + "github.com/containerd/errdefs" + "github.com/go-logr/logr" + "github.com/pkg/errors" +) + +const maxRetry = 3 + +type httpReadSeeker struct { + rc content.ReaderAt + log logr.Logger + size int64 + offset int64 + errsWithNoProgress int + closed bool +} + +var _ io.ReadSeekCloser = (*httpReadSeeker)(nil) + +func newHTTPReadSeeker(log logr.Logger, reader content.ReaderAt) (io.ReadSeekCloser, error) { + if reader == nil { + return nil, errors.Errorf("httpReadSeeker: reader cannot be nil") + } + return &httpReadSeeker{ + log: log, + rc: reader, + size: reader.Size(), + }, nil +} + +func (hrs *httpReadSeeker) Read(p []byte) (n int, err error) { + if hrs.closed { + return 0, io.EOF + } + + n, err = hrs.rc.ReadAt(p, hrs.offset) + hrs.offset += int64(n) + if n > 0 || err == nil { + hrs.errsWithNoProgress = 0 + } + if err == io.ErrUnexpectedEOF { + // connection closed unexpectedly. try reconnecting. + if n == 0 { + hrs.errsWithNoProgress++ + if hrs.errsWithNoProgress > maxRetry { + return // too many retries for this offset with no progress + } + } + if hrs.rc != nil { + if clsErr := hrs.rc.Close(); clsErr != nil { + hrs.log.Error(clsErr, "httpReadSeeker: failed to close ReadCloser") + } + hrs.rc = nil + } + + } else if err == io.EOF { + // The CRI's imagePullProgressTimeout relies on responseBody.Close to + // update the process monitor's status. If the err is io.EOF, close + // the connection since there is no more available data. + if hrs.rc != nil { + if clsErr := hrs.rc.Close(); clsErr != nil { + hrs.log.Error(clsErr, "httpReadSeeker: failed to close ReadCloser after io.EOF") + } + hrs.rc = nil + } + } + return +} + +func (hrs *httpReadSeeker) Close() error { + if hrs.closed { + return nil + } + hrs.closed = true + if hrs.rc != nil { + return hrs.rc.Close() + } + + return nil +} + +func (hrs *httpReadSeeker) Seek(offset int64, whence int) (int64, error) { + if hrs.closed { + return 0, errors.Errorf("Fetcher.Seek: closed: %v", errdefs.ErrUnavailable) + } + + abs := hrs.offset + switch whence { + case io.SeekStart: + abs = offset + case io.SeekCurrent: + abs += offset + case io.SeekEnd: + if hrs.size == -1 { + return 0, errors.Errorf("Fetcher.Seek: unknown size, cannot seek from end: %v", errdefs.ErrUnavailable) + } + abs = hrs.size + offset + default: + return 0, errors.Errorf("Fetcher.Seek: invalid whence: %v", errdefs.ErrInvalidArgument) + } + + if abs < 0 { + return 0, errors.Errorf("Fetcher.Seek: negative offset: %v", errdefs.ErrInvalidArgument) + } + + if abs != hrs.offset { + if hrs.rc != nil { + if err := hrs.rc.Close(); err != nil { + hrs.log.Error(err, "Fetcher.Seek: failed to close ReadCloser") + } + + hrs.rc = nil + } + + hrs.offset = abs + } + + return hrs.offset, nil +} diff --git a/pkg/oci/memory.go b/pkg/oci/memory.go index 7bfd4dcc..43ab6e54 100644 --- a/pkg/oci/memory.go +++ b/pkg/oci/memory.go @@ -84,7 +84,7 @@ func (m *Memory) GetManifest(ctx context.Context, dgst digest.Digest) ([]byte, s return b, mt, nil } -func (m *Memory) GetBlob(ctx context.Context, dgst digest.Digest) (io.ReadCloser, error) { +func (m *Memory) GetBlob(ctx context.Context, dgst digest.Digest) (io.ReadSeekCloser, error) { m.mx.RLock() defer m.mx.RUnlock() @@ -92,8 +92,14 @@ func (m *Memory) GetBlob(ctx context.Context, dgst digest.Digest) (io.ReadCloser if !ok { return nil, errors.Join(ErrNotFound, fmt.Errorf("blob with digest %s not found", dgst)) } - rc := io.NopCloser(bytes.NewBuffer(b)) - return rc, nil + rc := bytes.NewReader(b) + return struct { + io.ReadSeeker + io.Closer + }{ + ReadSeeker: rc, + Closer: io.NopCloser(rc), + }, nil } func (m *Memory) AddImage(img Image) { diff --git a/pkg/oci/oci.go b/pkg/oci/oci.go index ab702e15..70767834 100644 --- a/pkg/oci/oci.go +++ b/pkg/oci/oci.go @@ -44,7 +44,7 @@ type Client interface { // GetBlob returns a stream of the blob content for the given digest. // Will return ErrNotFound if the digest cannot be found. - GetBlob(ctx context.Context, dgst digest.Digest) (io.ReadCloser, error) + GetBlob(ctx context.Context, dgst digest.Digest) (io.ReadSeekCloser, error) } type UnknownDocument struct { diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index 256b51ab..95c30752 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net" "net/http" "net/http/httputil" @@ -345,7 +344,7 @@ func (r *Registry) handleBlob(rw mux.ResponseWriter, req *http.Request, ref refe if req.Method == http.MethodHead { return } - var w io.Writer = rw + var w mux.ResponseWriter = rw if r.throttler != nil { w = r.throttler.Writer(rw) } @@ -355,11 +354,8 @@ func (r *Registry) handleBlob(rw mux.ResponseWriter, req *http.Request, ref refe return } defer rc.Close() - _, err = io.Copy(w, rc) - if err != nil { - r.log.Error(err, "error occurred when copying blob") - return - } + + http.ServeContent(w, req, ref.dgst.String(), time.Time{}, rc) } func (r *Registry) isExternalRequest(req *http.Request) bool { diff --git a/pkg/throttle/throttle.go b/pkg/throttle/throttle.go index 9ab6d41b..e4af63b9 100644 --- a/pkg/throttle/throttle.go +++ b/pkg/throttle/throttle.go @@ -2,9 +2,9 @@ package throttle import ( "fmt" - "io" "time" + "github.com/spegel-org/spegel/internal/mux" "golang.org/x/time/rate" ) @@ -22,20 +22,20 @@ func NewThrottler(br Byterate) *Throttler { } } -func (t *Throttler) Writer(w io.Writer) io.Writer { +func (t *Throttler) Writer(w mux.ResponseWriter) mux.ResponseWriter { //nolint:ireturn // Retrun is a pointer return &writer{ - limiter: t.limiter, - writer: w, + limiter: t.limiter, + ResponseWriter: w, } } type writer struct { + mux.ResponseWriter limiter *rate.Limiter - writer io.Writer } func (w *writer) Write(p []byte) (int, error) { - n, err := w.writer.Write(p) + n, err := w.ResponseWriter.Write(p) if err != nil { return 0, err } diff --git a/pkg/throttle/throttle_test.go b/pkg/throttle/throttle_test.go index 5aaf4d97..4e35a0e4 100644 --- a/pkg/throttle/throttle_test.go +++ b/pkg/throttle/throttle_test.go @@ -1,10 +1,11 @@ package throttle import ( - "bytes" + "net/http/httptest" "testing" "time" + "github.com/spegel-org/spegel/internal/mux" "github.com/stretchr/testify/require" ) @@ -13,7 +14,7 @@ func TestThrottler(t *testing.T) { br := 500 * Bps throttler := NewThrottler(br) - w := throttler.Writer(bytes.NewBuffer([]byte{})) + w := throttler.Writer(mux.NewResponseWriter(httptest.NewRecorder())) chunkSize := 100 start := time.Now() for i := 0; i < 10; i++ {