From 3d43ff47587dddbf74a1877b1e9cdbac88083427 Mon Sep 17 00:00:00 2001 From: Louis Seubert Date: Mon, 22 Jul 2024 21:13:54 +0200 Subject: [PATCH] feat: add cache helpers --- cache.go | 9 ++ cache/blob.go | 56 +++++++++ cache/cache.go | 329 ++++++++++++++++++++++++++++++++++++++++++++++++ cache/reader.go | 89 +++++++++++++ cache/retry.go | 42 +++++++ cmd/main.go | 4 +- go.mod | 2 + go.sum | 2 + 8 files changed, 532 insertions(+), 1 deletion(-) create mode 100644 cache.go create mode 100644 cache/blob.go create mode 100644 cache/cache.go create mode 100644 cache/reader.go create mode 100644 cache/retry.go create mode 100644 go.sum diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..6fca1f5 --- /dev/null +++ b/cache.go @@ -0,0 +1,9 @@ +package sdk + +import "git.geekeey.de/actions/sdk/cache" + +func (c *Action) Cache() *cache.Client { + c.env("ACTIONS_CACHE_URL") + c.env("ACTIONS_RUNTIME_TOKEN") + return cache.New("", "") +} diff --git a/cache/blob.go b/cache/blob.go new file mode 100644 index 0000000..d2cbca3 --- /dev/null +++ b/cache/blob.go @@ -0,0 +1,56 @@ +package cache + +import ( + "bytes" + "io" + "os" +) + +type Blob interface { + io.ReaderAt + io.Closer + Size() int64 +} + +type byteBlob struct { + buf *bytes.Reader +} + +func NewByteBlob(b []byte) Blob { + return &byteBlob{buf: bytes.NewReader(b)} +} + +func (blob *byteBlob) ReadAt(p []byte, off int64) (n int, err error) { + return blob.buf.ReadAt(p, off) +} + +func (blob *byteBlob) Size() int64 { + return blob.buf.Size() +} + +func (blob *byteBlob) Close() error { + return nil +} + +type fileBlob struct { + buf *os.File +} + +func NewFileBlob(f *os.File) Blob { + return &fileBlob{buf: f} +} + +func (blob *fileBlob) ReadAt(p []byte, off int64) (n int, err error) { + return blob.buf.ReadAt(p, off) +} + +func (blob *fileBlob) Size() int64 { + if i, err := blob.buf.Stat(); err != nil { + return i.Size() + } + return 0 +} + +func (blob *fileBlob) Close() error { + return nil +} diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..bc7ad6b --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,329 @@ +package cache + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path" + "strings" + "sync" + + "golang.org/x/sync/errgroup" +) + +var UploadConcurrency = 4 +var UploadChunkSize = 32 * 1024 * 1024 + +type Client struct { + base string + http *http.Client +} + +type auth struct { + transport http.RoundTripper + token string +} + +func (t *auth) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.token)) + return t.transport.RoundTrip(req) +} + +func New(token, url string) *Client { + t := &auth{transport: &retry{transport: &http.Transport{}}, token: token} + return &Client{ + base: url, + http: &http.Client{Transport: t}, + } +} + +func (c *Client) url(p string) string { + return path.Join(c.base, "_apis/artifactcache", p) +} + +func (c *Client) version(k string) string { + h := sha256.New() + h.Write([]byte("|go-actionscache-1.0")) + return hex.EncodeToString(h.Sum(nil)) +} + +type ApiError struct { + Message string `json:"message"` + TypeName string `json:"typeName"` + TypeKey string `json:"typeKey"` + ErrorCode int `json:"errorCode"` +} + +func (e ApiError) Error() string { + return e.Message +} + +func (e ApiError) Is(err error) bool { + if err == os.ErrExist { + if strings.Contains(e.TypeKey, "AlreadyExists") { + return true + } + } + return false +} + +func checkApiError(res *http.Response) error { + if res.StatusCode >= 200 && res.StatusCode < 300 { + return nil + } + dec := json.NewDecoder(io.LimitReader(res.Body, 32*1024)) + + var details ApiError + if err := dec.Decode(&details); err != nil { + return err + } + + if details.Message != "" { + return details + } else { + return fmt.Errorf("unknown error %s", res.Status) + } +} + +func (c *Client) Load(ctx context.Context, keys ...string) (*Entry, error) { + u, err := url.Parse(c.url("cache")) + if err != nil { + return nil, err + } + q := u.Query() + q.Set("keys", strings.Join(keys, ",")) + q.Set("version", c.version(keys[0])) + u.RawQuery = q.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, err + } + req.Header.Add("Accept", "application/json;api-version=6.0-preview.1") + + res, err := c.http.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + err = checkApiError(res) + if err != nil { + return nil, err + } + + dec := json.NewDecoder(io.LimitReader(res.Body, 32*1024)) + + var ce Entry + if err = dec.Decode(&ce); err != nil { + return nil, err + } + + ce.http = c.http + return &ce, nil +} + +func (c *Client) Save(ctx context.Context, key string, b Blob) error { + id, err := c.reserve(ctx, key) + if err != nil { + return err + } + err = c.upload(ctx, id, b) + if err != nil { + return err + } + return c.commit(ctx, id, b.Size()) +} + +type ReserveCacheReq struct { + Key string `json:"key"` + Version string `json:"version"` +} + +type ReserveCacheRes struct { + CacheID int `json:"cacheID"` +} + +func (c *Client) reserve(ctx context.Context, key string) (int, error) { + payload := ReserveCacheReq{Key: key, Version: c.version(key)} + + buf := new(bytes.Buffer) + if err := json.NewEncoder(buf).Encode(payload); err != nil { + return 0, err + } + + url := c.url("caches") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, buf) + if err != nil { + return 0, err + } + req.Header.Add("Content-Type", "application/json") + + res, err := c.http.Do(req) + if err != nil { + return 0, err + } + defer res.Body.Close() + + err = checkApiError(res) + if err != nil { + return 0, err + } + + dec := json.NewDecoder(io.LimitReader(res.Body, 32*1024)) + + var cr ReserveCacheRes + if err = dec.Decode(&cr); err != nil { + return 0, err + } + + if cr.CacheID == 0 { + return 0, fmt.Errorf("invalid response (cache id is 0)") + } + return cr.CacheID, nil +} + +type CommitCacheReq struct { + Size int64 `json:"size"` +} + +func (c *Client) commit(ctx context.Context, id int, size int64) error { + payload := CommitCacheReq{Size: size} + + buf := new(bytes.Buffer) + if err := json.NewEncoder(buf).Encode(payload); err != nil { + return err + } + + url := c.url(fmt.Sprintf("caches/%d", id)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, buf) + if err != nil { + return err + } + req.Header.Add("Content-Type", "application/json") + + res, err := c.http.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + + err = checkApiError(res) + if err != nil { + return err + } + + return nil +} + +func (c *Client) upload(ctx context.Context, id int, b Blob) error { + var mu sync.Mutex + grp, ctx := errgroup.WithContext(ctx) + offset := int64(0) + for i := 0; i < UploadConcurrency; i++ { + grp.Go(func() error { + for { + mu.Lock() + start := offset + if start >= b.Size() { + mu.Unlock() + return nil + } + end := start + int64(UploadChunkSize) + if end > b.Size() { + end = b.Size() + } + offset = end + mu.Unlock() + + if err := c.create(ctx, id, b, start, end-start); err != nil { + return err + } + } + }) + } + return grp.Wait() +} + +func (c *Client) create(ctx context.Context, id int, ra io.ReaderAt, off, n int64) error { + url := c.url(fmt.Sprintf("caches/%d", id)) + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, io.NewSectionReader(ra, off, n)) + if err != nil { + return err + } + req.Header.Add("Content-Type", "application/octet-stream") + req.Header.Add("Content-Range", fmt.Sprintf("bytes %d-%d/*", off, off+n-1)) + + res, err := c.http.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + + err = checkApiError(res) + if err != nil { + return err + } + + return nil +} + +type Entry struct { + Key string `json:"cacheKey"` + Scope string `json:"scope"` + URL string `json:"archiveLocation"` + + http *http.Client +} + +// Download returns a ReaderAtCloser for pulling the data. Concurrent reads are not allowed +func (ce *Entry) Download(ctx context.Context) ReaderAtCloser { + return NewReaderAtCloser(func(offset int64) (io.ReadCloser, error) { + req, err := http.NewRequestWithContext(ctx, "GET", ce.URL, nil) + if err != nil { + return nil, err + } + if offset != 0 { + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", offset)) + } + client := ce.http + if client == nil { + client = http.DefaultClient + } + + res, err := client.Do(req) + if err != nil { + return nil, err + } + + if res.StatusCode < 200 || res.StatusCode >= 300 { + if res.StatusCode == http.StatusRequestedRangeNotSatisfiable { + return nil, fmt.Errorf("invalid status response %v for %s, range: %v", res.Status, ce.URL, req.Header.Get("Range")) + } + return nil, fmt.Errorf("invalid status response %v for %s", res.Status, ce.URL) + } + if offset != 0 { + cr := res.Header.Get("content-range") + if !strings.HasPrefix(cr, fmt.Sprintf("bytes %d-", offset)) { + res.Body.Close() + return nil, fmt.Errorf("unhandled content range in response: %v", cr) + } + } + return res.Body, nil + }) +} + +func (ce *Entry) WriteTo(ctx context.Context, w io.Writer) error { + rac := ce.Download(ctx) + if _, err := io.Copy(w, &rc{ReaderAt: rac}); err != nil { + return err + } + return rac.Close() +} diff --git a/cache/reader.go b/cache/reader.go new file mode 100644 index 0000000..9225a76 --- /dev/null +++ b/cache/reader.go @@ -0,0 +1,89 @@ +package cache + +import ( + "io" +) + +type ReaderAtCloser interface { + io.ReaderAt + io.Closer +} + +type readerAtCloser struct { + offset int64 + rc io.ReadCloser + ra io.ReaderAt + open func(offset int64) (io.ReadCloser, error) + closed bool +} + +func NewReaderAtCloser(open func(offset int64) (io.ReadCloser, error)) ReaderAtCloser { + return &readerAtCloser{ + open: open, + } +} + +func (hrs *readerAtCloser) ReadAt(p []byte, off int64) (n int, err error) { + if hrs.closed { + return 0, io.EOF + } + + if hrs.ra != nil { + return hrs.ra.ReadAt(p, off) + } + + if hrs.rc == nil || off != hrs.offset { + if hrs.rc != nil { + hrs.rc.Close() + hrs.rc = nil + } + rc, err := hrs.open(off) + if err != nil { + return 0, err + } + hrs.rc = rc + } + if ra, ok := hrs.rc.(io.ReaderAt); ok { + hrs.ra = ra + n, err = ra.ReadAt(p, off) + } else { + for { + var nn int + nn, err = hrs.rc.Read(p) + n += nn + p = p[nn:] + if nn == len(p) || err != nil { + break + } + } + } + + hrs.offset += int64(n) + return +} + +func (hrs *readerAtCloser) Close() error { + if hrs.closed { + return nil + } + hrs.closed = true + if hrs.rc != nil { + return hrs.rc.Close() + } + + return nil +} + +type rc struct { + io.ReaderAt + offset int +} + +func (r *rc) Read(b []byte) (int, error) { + n, err := r.ReadAt(b, int64(r.offset)) + r.offset += n + if n > 0 && err == io.EOF { + err = nil + } + return n, err +} diff --git a/cache/retry.go b/cache/retry.go new file mode 100644 index 0000000..ce93305 --- /dev/null +++ b/cache/retry.go @@ -0,0 +1,42 @@ +package cache + +import ( + "bytes" + "fmt" + "io" + "net/http" +) + +type retry struct { + transport http.RoundTripper + retry int +} + +func (t *retry) RoundTrip(req *http.Request) (*http.Response, error) { + var body []byte + if req.Body != nil { + body, _ = io.ReadAll(req.Body) + } + + for count := 0; count < t.retry; count++ { + req.Body = io.NopCloser(bytes.NewBuffer(body)) + res, err := t.transport.RoundTrip(req) + if err != nil { + return nil, err + } + if t.check(res) { + if res.Body != nil { + io.Copy(io.Discard, res.Body) + res.Body.Close() + } + continue + } + return res, err + } + + return nil, fmt.Errorf("too many retries") +} + +func (t *retry) check(res *http.Response) bool { + return res.StatusCode > 399 +} diff --git a/cmd/main.go b/cmd/main.go index 2a9c54b..ebb54e0 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,6 +1,8 @@ package main -import "git.geekeey.de/actions/sdk" +import ( + "git.geekeey.de/actions/sdk" +) func main() { a := sdk.New() diff --git a/go.mod b/go.mod index 5f045b8..1348d41 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module git.geekeey.de/actions/sdk go 1.22.5 + +require golang.org/x/sync v0.7.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e8ef4a3 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=