Handle cancelled proxy downloads using context

This commit is contained in:
Evan Fiordeliso 2025-11-10 16:15:59 -05:00
parent c281da0eb0
commit 0e35f73f06
2 changed files with 62 additions and 14 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt"
"io"
"log/slog"
"mime"
"net/http"
"net/url"
"strconv"
@ -14,6 +15,27 @@ import (
"go.fifitido.net/ytdl-web/pkg/ytdl/metadata"
)
func init() {
mime.AddExtensionType(".mp4", "video/mp4")
mime.AddExtensionType(".m4v", "video/x-m4v")
mime.AddExtensionType(".mkv", "video/x-matroska")
mime.AddExtensionType(".webm", "video/webm")
mime.AddExtensionType(".mov", "video/quicktime")
mime.AddExtensionType(".avi", "video/x-msvideo")
mime.AddExtensionType(".wmv", "video/x-ms-wmv")
mime.AddExtensionType(".mpg", "video/mpeg")
mime.AddExtensionType(".flv", "video/x-flv")
mime.AddExtensionType(".3gp", "video/3gpp")
mime.AddExtensionType(".m3u8", "application/x-mpegURL")
mime.AddExtensionType(".ts", "video/mp2t")
mime.AddExtensionType(".m4a", "audio/mp4")
mime.AddExtensionType(".mp3", "audio/mpeg")
mime.AddExtensionType(".aac", "audio/aac")
mime.AddExtensionType(".ogg", "audio/ogg")
mime.AddExtensionType(".wav", "audio/wav")
mime.AddExtensionType(".opus", "audio/opus")
}
func getUrlParam(r *http.Request) (string, error) {
urlRaw := r.URL.Query().Get("url")
if urlRaw == "" {
@ -38,7 +60,7 @@ func download(w http.ResponseWriter, r *http.Request) {
return
}
meta, err := ytdl.GetMetadata(videoUrl)
meta, err := ytdl.GetMetadata(r.Context(), videoUrl)
if err != nil {
views.Render(w, r, views.Home(&views.Error{Message: "Could not find a video at that url", RetryUrl: &videoUrl}))
return
@ -48,6 +70,12 @@ func download(w http.ResponseWriter, r *http.Request) {
}
func proxyDownload(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusBadRequest)
return
}
ytdl := ytdl.Default()
videoUrl, err := getUrlParam(r)
if err != nil {
@ -62,7 +90,7 @@ func proxyDownload(w http.ResponseWriter, r *http.Request) {
return
}
meta, err := ytdl.GetMetadata(videoUrl)
meta, err := ytdl.GetMetadata(r.Context(), videoUrl)
if err != nil {
slog.Error("Failed to get metadata", slog.String("error", err.Error()))
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
@ -103,6 +131,8 @@ func proxyDownload(w http.ResponseWriter, r *http.Request) {
}
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s.%s\"", meta.ID, format.Ext))
w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("Content-Type", mime.TypeByExtension("."+format.Ext))
if format.Filesize != nil {
w.Header().Set("Content-Length", fmt.Sprint(*format.Filesize))
}
@ -115,14 +145,14 @@ func proxyDownload(w http.ResponseWriter, r *http.Request) {
defer write.Close()
go func() {
_, err := io.Copy(w, read)
if err != nil {
slog.Error("Failed to copy", slog.String("error", err.Error()))
}
defer write.Close()
ytdl.Download(r.Context(), write, videoUrl, format.FormatID, index)
}()
if err := ytdl.Download(write, videoUrl, format.FormatID, index); err != nil {
slog.Error("Failed to download", slog.String("error", err.Error()))
if _, err := io.Copy(w, read); err != nil {
slog.Error("Failed to copy", slog.String("error", err.Error()))
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
flusher.Flush()
}

View File

@ -2,6 +2,7 @@ package ytdl
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@ -16,8 +17,8 @@ import (
)
type Ytdl interface {
GetMetadata(url string) (*metadata.Metadata, error)
Download(w io.Writer, url, format string, index int) error
GetMetadata(ctx context.Context, url string) (*metadata.Metadata, error)
Download(ctx context.Context, w io.Writer, url, format string, index int) error
Version() string
}
@ -100,7 +101,7 @@ func (y *ytdlImpl) Version() string {
}
// GetMetadata implements Ytdl
func (y *ytdlImpl) GetMetadata(url string) (*metadata.Metadata, error) {
func (y *ytdlImpl) GetMetadata(ctx context.Context, url string) (*metadata.Metadata, error) {
meta, err := y.cache.Get(url)
if err == nil {
return meta, nil
@ -115,7 +116,7 @@ func (y *ytdlImpl) GetMetadata(url string) (*metadata.Metadata, error) {
fmt.Printf("ytdlp args: %#v\n", args)
cmd := exec.Command(y.cfg.BinaryPath, args...)
cmd := exec.CommandContext(ctx, y.cfg.BinaryPath, args...)
out, err := cmd.Output()
if err != nil {
@ -150,7 +151,7 @@ func (y *ytdlImpl) GetMetadata(url string) (*metadata.Metadata, error) {
}
// Download implements Ytdl
func (y *ytdlImpl) Download(w io.Writer, url, format string, index int) error {
func (y *ytdlImpl) Download(ctx context.Context, w io.Writer, url, format string, index int) error {
args := []string{
url,
"--format", format,
@ -168,7 +169,7 @@ func (y *ytdlImpl) Download(w io.Writer, url, format string, index int) error {
args = append(args, "--load-info-json", "-")
}
cmd := exec.Command(y.cfg.BinaryPath, args...)
cmd := exec.CommandContext(ctx, y.cfg.BinaryPath, args...)
cmd.Stdout = w
if err == nil {
@ -181,6 +182,23 @@ func (y *ytdlImpl) Download(w io.Writer, url, format string, index int) error {
}
if err := cmd.Run(); err != nil {
exitErr := &exec.ExitError{}
if errors.As(err, &exitErr) {
if exitErr.ExitCode() == -1 {
// Handle the case where the process was terminated by a signal
return nil
}
attrs := []any{
slog.Int("code", exitErr.ExitCode()),
slog.String("stderr", string(exitErr.Stderr)),
slog.String("error", exitErr.Error()),
}
y.logger.Error("failed to download", attrs...)
return err
}
y.logger.Error("failed to download", slog.String("url", url), slog.String("error", err.Error()))
return err
}