package cache import ( "bytes" "context" "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "io" "net/http" "net/url" "os" "path" "strings" "sync" "golang.org/x/sync/errgroup" ) var UploadConcurrency = 4 var UploadChunkSize = 32 * 1024 * 1024 type Client struct { base string http *http.Client } type auth struct { transport http.RoundTripper token string } func (t *auth) RoundTrip(req *http.Request) (*http.Response, error) { req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.token)) return t.transport.RoundTrip(req) } func New(token, url string) *Client { t := &auth{transport: &retry{transport: &http.Transport{}}, token: token} return &Client{ base: url, http: &http.Client{Transport: t}, } } func (c *Client) url(p string) string { return path.Join(c.base, "_apis/artifactcache", p) } func (c *Client) version(k string) string { h := sha256.New() h.Write([]byte("|go-actionscache-1.0")) return hex.EncodeToString(h.Sum(nil)) } type ApiError struct { Message string `json:"message"` TypeName string `json:"typeName"` TypeKey string `json:"typeKey"` ErrorCode int `json:"errorCode"` } func (e ApiError) Error() string { return e.Message } func (e ApiError) Is(err error) bool { if err == os.ErrExist { if strings.Contains(e.TypeKey, "AlreadyExists") { return true } } return false } func checkApiError(res *http.Response) error { if res.StatusCode >= 200 && res.StatusCode < 300 { return nil } dec := json.NewDecoder(io.LimitReader(res.Body, 32*1024)) var details ApiError if err := dec.Decode(&details); err != nil { return err } if details.Message != "" { return details } else { return fmt.Errorf("unknown error %s", res.Status) } } func (c *Client) Load(ctx context.Context, keys ...string) (*Entry, error) { u, err := url.Parse(c.url("cache")) if err != nil { return nil, err } q := u.Query() q.Set("keys", strings.Join(keys, ",")) q.Set("version", c.version(keys[0])) u.RawQuery = q.Encode() req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { return nil, err } req.Header.Add("Accept", "application/json;api-version=6.0-preview.1") res, err := c.http.Do(req) if err != nil { return nil, err } defer res.Body.Close() err = checkApiError(res) if err != nil { return nil, err } dec := json.NewDecoder(io.LimitReader(res.Body, 32*1024)) var ce Entry if err = dec.Decode(&ce); err != nil { return nil, err } ce.http = c.http return &ce, nil } func (c *Client) Save(ctx context.Context, key string, b Blob) error { id, err := c.reserve(ctx, key) if err != nil { return err } err = c.upload(ctx, id, b) if err != nil { return err } return c.commit(ctx, id, b.Size()) } type ReserveCacheReq struct { Key string `json:"key"` Version string `json:"version"` } type ReserveCacheRes struct { CacheID int `json:"cacheID"` } func (c *Client) reserve(ctx context.Context, key string) (int, error) { payload := ReserveCacheReq{Key: key, Version: c.version(key)} buf := new(bytes.Buffer) if err := json.NewEncoder(buf).Encode(payload); err != nil { return 0, err } url := c.url("caches") req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, buf) if err != nil { return 0, err } req.Header.Add("Content-Type", "application/json") res, err := c.http.Do(req) if err != nil { return 0, err } defer res.Body.Close() err = checkApiError(res) if err != nil { return 0, err } dec := json.NewDecoder(io.LimitReader(res.Body, 32*1024)) var cr ReserveCacheRes if err = dec.Decode(&cr); err != nil { return 0, err } if cr.CacheID == 0 { return 0, fmt.Errorf("invalid response (cache id is 0)") } return cr.CacheID, nil } type CommitCacheReq struct { Size int64 `json:"size"` } func (c *Client) commit(ctx context.Context, id int, size int64) error { payload := CommitCacheReq{Size: size} buf := new(bytes.Buffer) if err := json.NewEncoder(buf).Encode(payload); err != nil { return err } url := c.url(fmt.Sprintf("caches/%d", id)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, buf) if err != nil { return err } req.Header.Add("Content-Type", "application/json") res, err := c.http.Do(req) if err != nil { return err } defer res.Body.Close() err = checkApiError(res) if err != nil { return err } return nil } func (c *Client) upload(ctx context.Context, id int, b Blob) error { var mu sync.Mutex grp, ctx := errgroup.WithContext(ctx) offset := int64(0) for i := 0; i < UploadConcurrency; i++ { grp.Go(func() error { for { mu.Lock() start := offset if start >= b.Size() { mu.Unlock() return nil } end := start + int64(UploadChunkSize) if end > b.Size() { end = b.Size() } offset = end mu.Unlock() if err := c.create(ctx, id, b, start, end-start); err != nil { return err } } }) } return grp.Wait() } func (c *Client) create(ctx context.Context, id int, ra io.ReaderAt, off, n int64) error { url := c.url(fmt.Sprintf("caches/%d", id)) req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, io.NewSectionReader(ra, off, n)) if err != nil { return err } req.Header.Add("Content-Type", "application/octet-stream") req.Header.Add("Content-Range", fmt.Sprintf("bytes %d-%d/*", off, off+n-1)) res, err := c.http.Do(req) if err != nil { return err } defer res.Body.Close() err = checkApiError(res) if err != nil { return err } return nil } type Entry struct { Key string `json:"cacheKey"` Scope string `json:"scope"` URL string `json:"archiveLocation"` http *http.Client } // Download returns a ReaderAtCloser for pulling the data. Concurrent reads are not allowed func (ce *Entry) Download(ctx context.Context) ReaderAtCloser { return NewReaderAtCloser(func(offset int64) (io.ReadCloser, error) { req, err := http.NewRequestWithContext(ctx, "GET", ce.URL, nil) if err != nil { return nil, err } if offset != 0 { req.Header.Set("Range", fmt.Sprintf("bytes=%d-", offset)) } client := ce.http if client == nil { client = http.DefaultClient } res, err := client.Do(req) if err != nil { return nil, err } if res.StatusCode < 200 || res.StatusCode >= 300 { if res.StatusCode == http.StatusRequestedRangeNotSatisfiable { return nil, fmt.Errorf("invalid status response %v for %s, range: %v", res.Status, ce.URL, req.Header.Get("Range")) } return nil, fmt.Errorf("invalid status response %v for %s", res.Status, ce.URL) } if offset != 0 { cr := res.Header.Get("content-range") if !strings.HasPrefix(cr, fmt.Sprintf("bytes %d-", offset)) { res.Body.Close() return nil, fmt.Errorf("unhandled content range in response: %v", cr) } } return res.Body, nil }) } func (ce *Entry) WriteTo(ctx context.Context, w io.Writer) error { rac := ce.Download(ctx) if _, err := io.Copy(w, &rc{ReaderAt: rac}); err != nil { return err } return rac.Close() }