Refactor cache.HTTPClient

This commit is contained in:
Deluan
2024-05-11 19:36:33 -04:00
parent 955a9b43af
commit ec68d69d56
5 changed files with 20 additions and 20 deletions
+107
View File
@@ -0,0 +1,107 @@
package cache
import (
"bufio"
"bytes"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/jellydator/ttlcache/v2"
"github.com/navidrome/navidrome/log"
)
const cacheSizeLimit = 100
type HTTPClient struct {
cache *ttlcache.Cache
hc httpDoer
}
type httpDoer interface {
Do(req *http.Request) (*http.Response, error)
}
type requestData struct {
Method string
Header http.Header
URL string
Body *string
}
func NewHTTPClient(wrapped httpDoer, ttl time.Duration) *HTTPClient {
c := &HTTPClient{hc: wrapped}
c.cache = ttlcache.NewCache()
c.cache.SetCacheSizeLimit(cacheSizeLimit)
c.cache.SkipTTLExtensionOnHit(true)
c.cache.SetLoaderFunction(func(key string) (interface{}, time.Duration, error) {
req, err := c.deserializeReq(key)
if err != nil {
return nil, 0, err
}
resp, err := c.hc.Do(req)
if err != nil {
return nil, 0, err
}
defer resp.Body.Close()
return c.serializeResponse(resp), ttl, nil
})
c.cache.SetNewItemCallback(func(key string, value interface{}) {
log.Trace("New request cached", "req", key, "resp", value)
})
return c
}
func (c *HTTPClient) Do(req *http.Request) (*http.Response, error) {
key := c.serializeReq(req)
respStr, err := c.cache.Get(key)
if err != nil {
return nil, err
}
return c.deserializeResponse(req, respStr.(string))
}
func (c *HTTPClient) serializeReq(req *http.Request) string {
data := requestData{
Method: req.Method,
Header: req.Header,
URL: req.URL.String(),
}
if req.Body != nil {
bodyData, _ := io.ReadAll(req.Body)
bodyStr := base64.StdEncoding.EncodeToString(bodyData)
data.Body = &bodyStr
}
j, _ := json.Marshal(&data)
return string(j)
}
func (c *HTTPClient) deserializeReq(reqStr string) (*http.Request, error) {
var data requestData
_ = json.Unmarshal([]byte(reqStr), &data)
var body io.Reader
if data.Body != nil {
bodyStr, _ := base64.StdEncoding.DecodeString(*data.Body)
body = strings.NewReader(string(bodyStr))
}
req, err := http.NewRequest(data.Method, data.URL, body)
if err != nil {
return nil, err
}
req.Header = data.Header
return req, nil
}
func (c *HTTPClient) serializeResponse(resp *http.Response) string {
var b = &bytes.Buffer{}
_ = resp.Write(b)
return b.String()
}
func (c *HTTPClient) deserializeResponse(req *http.Request, respStr string) (*http.Response, error) {
r := bufio.NewReader(strings.NewReader(respStr))
return http.ReadResponse(r, req)
}
+93
View File
@@ -0,0 +1,93 @@
package cache
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"time"
"github.com/navidrome/navidrome/consts"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("HTTPClient", func() {
Context("GET", func() {
var chc *HTTPClient
var ts *httptest.Server
var requestsReceived int
var header string
BeforeEach(func() {
ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestsReceived++
header = r.Header.Get("head")
_, _ = fmt.Fprintf(w, "Hello, %s", r.URL.Query()["name"])
}))
chc = NewHTTPClient(http.DefaultClient, consts.DefaultHttpClientTimeOut)
})
AfterEach(func() {
defer ts.Close()
})
It("caches repeated requests", func() {
r, _ := http.NewRequest("GET", ts.URL+"?name=doe", nil)
resp, err := chc.Do(r)
Expect(err).To(BeNil())
body, err := io.ReadAll(resp.Body)
Expect(err).To(BeNil())
Expect(string(body)).To(Equal("Hello, [doe]"))
Expect(requestsReceived).To(Equal(1))
// Same request
r, _ = http.NewRequest("GET", ts.URL+"?name=doe", nil)
resp, err = chc.Do(r)
Expect(err).To(BeNil())
body, err = io.ReadAll(resp.Body)
Expect(err).To(BeNil())
Expect(string(body)).To(Equal("Hello, [doe]"))
Expect(requestsReceived).To(Equal(1))
// Different request
r, _ = http.NewRequest("GET", ts.URL, nil)
resp, err = chc.Do(r)
Expect(err).To(BeNil())
body, err = io.ReadAll(resp.Body)
Expect(err).To(BeNil())
Expect(string(body)).To(Equal("Hello, []"))
Expect(requestsReceived).To(Equal(2))
// Different again (same as before, but with header)
r, _ = http.NewRequest("GET", ts.URL, nil)
r.Header.Add("head", "this is a header")
resp, err = chc.Do(r)
Expect(err).To(BeNil())
body, err = io.ReadAll(resp.Body)
Expect(err).To(BeNil())
Expect(string(body)).To(Equal("Hello, []"))
Expect(header).To(Equal("this is a header"))
Expect(requestsReceived).To(Equal(3))
})
It("expires responses after TTL", func() {
requestsReceived = 0
chc = NewHTTPClient(http.DefaultClient, 10*time.Millisecond)
r, _ := http.NewRequest("GET", ts.URL+"?name=doe", nil)
_, err := chc.Do(r)
Expect(err).To(BeNil())
Expect(requestsReceived).To(Equal(1))
// Wait more than the TTL
time.Sleep(50 * time.Millisecond)
// Same request
r, _ = http.NewRequest("GET", ts.URL+"?name=doe", nil)
_, err = chc.Do(r)
Expect(err).To(BeNil())
Expect(requestsReceived).To(Equal(2))
})
})
})