From adf49ce4bbfd4af852b751cc5336ad010b729827 Mon Sep 17 00:00:00 2001 From: Evan Fiordeliso Date: Fri, 8 Mar 2024 13:05:54 -0500 Subject: [PATCH] feat: Add context to GetToken endpoint --- auth/callback.go | 2 +- auth/client.go | 10 ++++++---- auth/token.go | 11 +++++++++-- auth/token_source.go | 3 ++- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/auth/callback.go b/auth/callback.go index 99a7795..93142d1 100644 --- a/auth/callback.go +++ b/auth/callback.go @@ -56,7 +56,7 @@ func (c *CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { scope := q.Get("scope") _ = scope - token, err := c.client.GetToken(code) + token, err := c.client.GetToken(r.Context(), code) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/auth/client.go b/auth/client.go index 014adf3..086a18f 100644 --- a/auth/client.go +++ b/auth/client.go @@ -1,5 +1,7 @@ package auth +import "context" + type Client struct { clientId string clientSecret string @@ -21,8 +23,8 @@ func NewClient(clientId string, clientSecret string, redirectUri string) *Client // GetToken exchanges an authorization code for an access token. // // https://dev.twitch.tv/docs/authentication/getting-tokens-oidc/#oidc-authorization-code-grant-flow -func (c *Client) GetToken(code string) (*Token, error) { - return GetToken(&GetTokenParams{ +func (c *Client) GetToken(ctx context.Context, code string) (*Token, error) { + return GetToken(ctx, &GetTokenParams{ ClientId: c.clientId, ClientSecret: c.clientSecret, Code: code, @@ -34,8 +36,8 @@ func (c *Client) GetToken(code string) (*Token, error) { // RefreshToken exchanges a refresh token for an access token. // // https://dev.twitch.tv/docs/authentication/refresh-tokens/ -func (c *Client) RefreshToken(token *Token) (*Token, error) { - return GetToken(&GetTokenParams{ +func (c *Client) RefreshToken(ctx context.Context, token *Token) (*Token, error) { + return GetToken(ctx, &GetTokenParams{ ClientId: c.clientId, ClientSecret: c.clientSecret, Code: token.RefreshToken, diff --git a/auth/token.go b/auth/token.go index 4f47e77..505fa11 100644 --- a/auth/token.go +++ b/auth/token.go @@ -1,6 +1,7 @@ package auth import ( + "context" "encoding/json" "fmt" "net/http" @@ -47,13 +48,19 @@ type GetTokenParams struct { RedirectUri string `url:"redirect_uri"` } -func GetToken(params *GetTokenParams) (*Token, error) { +func GetToken(ctx context.Context, params *GetTokenParams) (*Token, error) { v, err := query.Values(params) if err != nil { return nil, err } - res, err := http.Post(TokenUrl, "application/x-www-form-urlencoded", strings.NewReader(v.Encode())) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenUrl, strings.NewReader(v.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + res, err := http.DefaultClient.Do(req) if err != nil { return nil, err } diff --git a/auth/token_source.go b/auth/token_source.go index 9368d6c..afaaf2a 100644 --- a/auth/token_source.go +++ b/auth/token_source.go @@ -1,6 +1,7 @@ package auth import ( + "context" "sync" "golang.org/x/oauth2" @@ -28,7 +29,7 @@ func (ts *TokenSource) Token() (*oauth2.Token, error) { return ts.token.Underlying(), nil } - token, err := ts.client.RefreshToken(ts.token) + token, err := ts.client.RefreshToken(context.Background(), ts.token) if err != nil { return nil, err }