diff --git a/api/_responses/errors.go b/api/_responses/errors.go index 1cee502..c075815 100644 --- a/api/_responses/errors.go +++ b/api/_responses/errors.go @@ -55,3 +55,11 @@ func BadRequest(message string) *ErrorResponse { func QuotaExceeded() *ErrorResponse { return &ErrorResponse{common.ErrCodeForbidden, "Quota Exceeded", common.ErrCodeQuotaExceeded} } + +func CannotOverwriteMedia() *ErrorResponse { + return &ErrorResponse{common.ErrCodeCannotOverwriteMedia, "Cannot overwrite media", common.ErrCodeCannotOverwriteMedia} +} + +func NotYetUploaded() *ErrorResponse { + return &ErrorResponse{common.ErrCodeNotYetUploaded, "Media not yet uploaded", common.ErrCodeNotYetUploaded} +} diff --git a/api/_routers/98-use-rcontext.go b/api/_routers/98-use-rcontext.go index 63fc7d0..a5843aa 100644 --- a/api/_routers/98-use-rcontext.go +++ b/api/_routers/98-use-rcontext.go @@ -166,7 +166,7 @@ beforeParseDownload: case common.ErrCodeUnknownToken: proposedStatusCode = http.StatusUnauthorized break - case common.ErrCodeNotFound: + case common.ErrCodeNotYetUploaded, common.ErrCodeNotFound: proposedStatusCode = http.StatusNotFound break case common.ErrCodeMediaTooLarge: @@ -181,6 +181,9 @@ beforeParseDownload: case common.ErrCodeForbidden: proposedStatusCode = http.StatusForbidden break + case common.ErrCodeCannotOverwriteMedia: + proposedStatusCode = http.StatusConflict + break default: // Treat as unknown (a generic server error) proposedStatusCode = http.StatusInternalServerError break diff --git a/api/r0/download.go b/api/r0/download.go index 83b9232..8e0dd1a 100644 --- a/api/r0/download.go +++ b/api/r0/download.go @@ -46,6 +46,21 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta. 
downloadRemote = parsedFlag } + var asyncWaitMs *int = nil + if rctx.Config.Features.MSC2246Async.Enabled { + // request default wait time if feature enabled + var parsedInt int = -1 + maxStallMs := r.URL.Query().Get("fi.mau.msc2246.max_stall_ms") + if maxStallMs != "" { + var err error + parsedInt, err = strconv.Atoi(maxStallMs) + if err != nil { + return _responses.InternalServerError("fi.mau.msc2246.max_stall_ms does not appear to be a number") + } + } + asyncWaitMs = &parsedInt + } + rctx = rctx.LogWithFields(logrus.Fields{ "mediaId": mediaId, "server": server, @@ -58,7 +73,7 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta. return _responses.MediaBlocked() } - streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, false, rctx) + streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, false, asyncWaitMs, rctx) if err != nil { if err == common.ErrMediaNotFound { return _responses.NotFoundError() @@ -66,6 +81,8 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta. 
return _responses.RequestTooLarge() } else if err == common.ErrMediaQuarantined { return _responses.NotFoundError() // We lie for security + } else if err == common.ErrNotYetUploaded { + return _responses.NotYetUploaded() } rctx.Log.Error("Unexpected error locating media: " + err.Error()) sentry.CaptureException(err) diff --git a/api/r0/thumbnail.go b/api/r0/thumbnail.go index 8669431..93af427 100644 --- a/api/r0/thumbnail.go +++ b/api/r0/thumbnail.go @@ -34,6 +34,21 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta downloadRemote = parsedFlag } + var asyncWaitMs *int = nil + if rctx.Config.Features.MSC2246Async.Enabled { + // request default wait time if feature enabled + var parsedInt int = -1 + maxStallMs := r.URL.Query().Get("fi.mau.msc2246.max_stall_ms") + if maxStallMs != "" { + var err error + parsedInt, err = strconv.Atoi(maxStallMs) + if err != nil { + return _responses.InternalServerError("fi.mau.msc2246.max_stall_ms does not appear to be a number") + } + } + asyncWaitMs = &parsedInt + } + rctx = rctx.LogWithFields(logrus.Fields{ "mediaId": mediaId, "server": server, @@ -97,12 +112,14 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta return _responses.BadRequest("Width and height must be greater than zero") } - streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, rctx) + streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, asyncWaitMs, rctx) if err != nil { if err == common.ErrMediaNotFound { return _responses.NotFoundError() } else if err == common.ErrMediaTooLarge { return _responses.RequestTooLarge() + } else if err == common.ErrNotYetUploaded { + return _responses.NotYetUploaded() } rctx.Log.Error("Unexpected error locating media: " + err.Error()) sentry.CaptureException(err) diff --git a/api/r0/upload.go b/api/r0/upload.go index 
245b15e..a933713 100644 --- a/api/r0/upload.go +++ b/api/r0/upload.go @@ -4,11 +4,13 @@ import ( "github.com/getsentry/sentry-go" "github.com/turt2live/matrix-media-repo/api/_apimeta" "github.com/turt2live/matrix-media-repo/api/_responses" + "github.com/turt2live/matrix-media-repo/api/_routers" "github.com/turt2live/matrix-media-repo/util/stream_util" "io" "net/http" "path/filepath" + "time" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" @@ -16,6 +18,7 @@ import ( "github.com/turt2live/matrix-media-repo/controllers/info_controller" "github.com/turt2live/matrix-media-repo/controllers/upload_controller" "github.com/turt2live/matrix-media-repo/quota" + "github.com/turt2live/matrix-media-repo/util" ) type MediaUploadedResponse struct { @@ -23,14 +26,51 @@ type MediaUploadedResponse struct { Blurhash string `json:"xyz.amorgan.blurhash,omitempty"` } +type MediaCreatedResponse struct { + ContentUri string `json:"content_uri"` + UnusedExpiresAt int64 `json:"unused_expires_at"` +} + +func CreateMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + media, _, err := upload_controller.CreateMedia(r.Host, rctx) + if err != nil { + rctx.Log.Error("Unexpected error creating media reference: " + err.Error()) + return _responses.InternalServerError("Unexpected Error") + } + + if err = upload_controller.PersistMedia(media, user.UserId, rctx); err != nil { + rctx.Log.Error("Unexpected error persisting media reference: " + err.Error()) + return _responses.InternalServerError("Unexpected Error") + } + + return &MediaCreatedResponse{ + ContentUri: media.MxcUri(), + UnusedExpiresAt: time.Now().Unix() + int64(rctx.Config.Features.MSC2246Async.AsyncUploadExpirySecs), + } +} + func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + var server = "" + var mediaId = "" + filename := filepath.Base(r.URL.Query().Get("filename")) defer stream_util.DumpAndCloseStream(r.Body) + if 
rctx.Config.Features.MSC2246Async.Enabled { + server = _routers.GetParam("server", r) + mediaId = _routers.GetParam("mediaId", r) + } + rctx = rctx.LogWithFields(logrus.Fields{ + "server": server, + "mediaId": mediaId, "filename": filename, }) + if server != "" && (!util.IsServerOurs(server) || server != r.Host) { + return _responses.NotFoundError() + } + contentType := r.Header.Get("Content-Type") if contentType == "" { contentType = "application/octet-stream" // binary @@ -60,13 +100,17 @@ func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.Us contentLength := upload_controller.EstimateContentLength(r.ContentLength, r.Header.Get("Content-Length")) - media, err := upload_controller.UploadMedia(r.Body, contentLength, contentType, filename, user.UserId, r.Host, rctx) + media, err := upload_controller.UploadMedia(r.Body, contentLength, contentType, filename, user.UserId, r.Host, mediaId, rctx) if err != nil { io.Copy(io.Discard, r.Body) // Ditch the entire request if err == common.ErrMediaQuarantined { return _responses.BadRequest("This file is not permitted on this server") - } + } else if err == common.ErrCannotOverwriteMedia { + return _responses.CannotOverwriteMedia() + } else if err == common.ErrMediaNotFound { + return _responses.NotFoundError() + } rctx.Log.Error("Unexpected error storing media: " + err.Error()) sentry.CaptureException(err) diff --git a/api/routes.go b/api/routes.go index ff6378d..4e1c7e3 100644 --- a/api/routes.go +++ b/api/routes.go @@ -13,6 +13,7 @@ import ( "github.com/turt2live/matrix-media-repo/api/custom" "github.com/turt2live/matrix-media-repo/api/r0" "github.com/turt2live/matrix-media-repo/api/unstable" + "github.com/turt2live/matrix-media-repo/common/config" ) const PrefixMedia = "/_matrix/media" @@ -29,7 +30,8 @@ func buildRoutes() http.Handler { } // Standard (spec) features - register([]string{"POST"}, PrefixMedia, "upload", false, router, makeRoute(_routers.RequireAccessToken(r0.UploadMedia), "upload", 
false, counter)) + uploadRoute := makeRoute(_routers.RequireAccessToken(r0.UploadMedia), "upload", false, counter) + register([]string{"POST"}, PrefixMedia, "upload", false, router, uploadRoute) downloadRoute := makeRoute(_routers.OptionalAccessToken(r0.DownloadMedia), "download", false, counter) register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId/:filename", false, router, downloadRoute) register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId", false, router, downloadRoute) @@ -52,6 +54,12 @@ func buildRoutes() http.Handler { router.Handler("GET", "/healthz", healthzRoute) router.Handler("HEAD", "/healthz", healthzRoute) + if config.Get().Features.MSC2246Async.Enabled { + logrus.Info("Asynchronous uploads (MSC2246) enabled") + register([]string{"POST"}, PrefixMedia, "create", true, router, makeRoute(_routers.RequireAccessToken(r0.CreateMedia), "create", false, counter)) + register([]string{"PUT"}, PrefixMedia, "upload/{server:[a-zA-Z0-9.:\\-_]+}/{mediaId:[^/]+}", true, router, uploadRoute) + } + // All admin routes are unstable only purgeRemoteRoute := makeRoute(_routers.RequireRepoAdmin(custom.PurgeRemoteMedia), "purge_remote_media", false, counter) register([]string{"POST"}, PrefixMedia, "admin/purge_remote", true, router, purgeRemoteRoute) @@ -106,7 +114,7 @@ func makeRoute(generator _routers.GeneratorFn, name string, ignoreHost bool, cou )) } -var versions = []string{"r0", "v1", "v3", "unstable", "unstable/io.t2bot.media"} +var versions = []string{"r0", "v1", "v3", "unstable", "unstable/io.t2bot.media", "unstable/fi.mamu.msc2246"} func register(methods []string, prefix string, postfix string, unstableOnly bool, router *httprouter.Router, handler http.Handler) { for _, method := range methods { diff --git a/api/unstable/info.go b/api/unstable/info.go index 8ea40d5..3e5edf3 100644 --- a/api/unstable/info.go +++ b/api/unstable/info.go @@ -79,7 +79,7 @@ func MediaInfo(r *http.Request, rctx rcontext.RequestContext, user _apimeta.User return 
_responses.MediaBlocked() } - streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, true, rctx) + streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, true, nil, rctx) if err != nil { if err == common.ErrMediaNotFound { return _responses.NotFoundError() diff --git a/api/unstable/local_copy.go b/api/unstable/local_copy.go index fa5a1f0..98a21f2 100644 --- a/api/unstable/local_copy.go +++ b/api/unstable/local_copy.go @@ -50,7 +50,7 @@ func LocalCopy(r *http.Request, rctx rcontext.RequestContext, user _apimeta.User // TODO: There's a lot of room for improvement here. Instead of re-uploading media, we should just update the DB. - streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, true, rctx) + streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, true, nil, rctx) if err != nil { if err == common.ErrMediaNotFound { return _responses.NotFoundError() @@ -70,7 +70,7 @@ func LocalCopy(r *http.Request, rctx rcontext.RequestContext, user _apimeta.User return &r0.MediaUploadedResponse{ContentUri: streamedMedia.KnownMedia.MxcUri()} } - newMedia, err := upload_controller.UploadMedia(streamedMedia.Stream, streamedMedia.KnownMedia.SizeBytes, streamedMedia.KnownMedia.ContentType, streamedMedia.KnownMedia.UploadName, user.UserId, r.Host, rctx) + newMedia, err := upload_controller.UploadMedia(streamedMedia.Stream, streamedMedia.KnownMedia.SizeBytes, streamedMedia.KnownMedia.ContentType, streamedMedia.KnownMedia.UploadName, user.UserId, r.Host, "", rctx) if err != nil { rctx.Log.Error("Unexpected error storing media: " + err.Error()) sentry.CaptureException(err) diff --git a/common/config/conf_min_shared.go b/common/config/conf_min_shared.go index 09eb821..dec051e 100644 --- a/common/config/conf_min_shared.go +++ b/common/config/conf_min_shared.go @@ -53,6 +53,11 @@ func NewDefaultMinimumRepoConfig() MinimumRepoConfig { YComponents: 3, Punch: 1, }, + 
MSC2246Async: MSC2246Config{ + Enabled: false, + AsyncUploadExpirySecs: 60, + AsyncDownloadDefaultWaitSecs: 20, + }, }, AccessTokens: AccessTokenConfig{ MaxCacheTimeSeconds: 0, diff --git a/common/config/models_domain.go b/common/config/models_domain.go index 0c3fc55..6060763 100644 --- a/common/config/models_domain.go +++ b/common/config/models_domain.go @@ -88,6 +88,7 @@ type TimeoutsConfig struct { type FeatureConfig struct { MSC2448Blurhash MSC2448Config `yaml:"MSC2448"` + MSC2246Async MSC2246Config `yaml:"MSC2246"` Redis RedisConfig `yaml:"redis"` } @@ -102,6 +103,12 @@ type MSC2448Config struct { Punch int `yaml:"punch"` } +type MSC2246Config struct { + Enabled bool `yaml:"enabled"` + AsyncUploadExpirySecs int `yaml:"asyncUploadExpirySecs"` + AsyncDownloadDefaultWaitSecs int `yaml:"asyncDownloadDefaultWaitSecs"` +} + type AccessTokenConfig struct { MaxCacheTimeSeconds int `yaml:"maxCacheTimeSeconds"` UseAppservices bool `yaml:"useLocalAppserviceConfig"` diff --git a/common/errorcodes.go b/common/errorcodes.go index 599b32c..b9a3274 100644 --- a/common/errorcodes.go +++ b/common/errorcodes.go @@ -16,3 +16,5 @@ const ErrCodeRateLimitExceeded = "M_LIMIT_EXCEEDED" const ErrCodeUnknown = "M_UNKNOWN" const ErrCodeForbidden = "M_FORBIDDEN" const ErrCodeQuotaExceeded = "M_QUOTA_EXCEEDED" +const ErrCodeCannotOverwriteMedia = "FI.MAU.MSC2246_CANNOT_OVEWRITE_MEDIA" +const ErrCodeNotYetUploaded = "FI.MAU.MSC2246_NOT_YET_UPLOADED" diff --git a/common/errors.go b/common/errors.go index 5fc0e69..09f5310 100644 --- a/common/errors.go +++ b/common/errors.go @@ -10,3 +10,5 @@ var ErrInvalidHost = errors.New("invalid host") var ErrHostNotFound = errors.New("host not found") var ErrHostBlacklisted = errors.New("host not allowed") var ErrMediaQuarantined = errors.New("media quarantined") +var ErrCannotOverwriteMedia = errors.New("cannot overwrite media") +var ErrNotYetUploaded = errors.New("not yet uploaded") diff --git a/config.sample.yaml b/config.sample.yaml index 
e3570bb..0b003ec 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -542,6 +542,20 @@ featureSupport: # make the effect more subtle, larger values make it stronger. punch: 1 + # MSC2246 - Asynchronous uploads + MSC2246: + # Whether or not this MSC is enabled for use in the media repo + enabled: false + + # The number of seconds after a media ID is requested during which the asynchronous upload may + # be started. Once this window expires, the upload endpoint will return an error to the client. + asyncUploadExpirySecs: 60 + + # The number of seconds a download request for an asynchronous upload will stall before + # returning an error. This affects clients that do not support async uploads by making them + # wait by default. Setting to zero will disable this behavior unless the client requests it. + asyncDownloadDefaultWaitSecs: 20 + # Support for redis as a cache mechanism # # Note: Enabling Redis support will mean that the existing cache mechanism will do nothing. diff --git a/controllers/download_controller/download_controller.go b/controllers/download_controller/download_controller.go index 5519d36..437ff1a 100644 --- a/controllers/download_controller/download_controller.go +++ b/controllers/download_controller/download_controller.go @@ -26,14 +26,14 @@ import ( var localCache = cache.New(30*time.Second, 60*time.Second) -func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia bool, ctx rcontext.RequestContext) (*types.MinimalMedia, error) { +func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia bool, asyncWaitMs *int, ctx rcontext.RequestContext) (*types.MinimalMedia, error) { cacheKey := fmt.Sprintf("%s/%s?r=%t&b=%t", origin, mediaId, downloadRemote, blockForMedia) v, _, err := globals.DefaultRequestGroup.Do(cacheKey, func() (interface{}, error) { var media *types.Media var minMedia *types.MinimalMedia var err error if blockForMedia { - media, err = FindMediaRecord(origin, mediaId, downloadRemote, ctx) + media, err = 
FindMediaRecord(origin, mediaId, downloadRemote, asyncWaitMs, ctx) if media != nil { minMedia = &types.MinimalMedia{ Origin: media.Origin, @@ -46,7 +46,7 @@ func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia } } } else { - minMedia, err = FindMinimalMediaRecord(origin, mediaId, downloadRemote, ctx) + minMedia, err = FindMinimalMediaRecord(origin, mediaId, downloadRemote, asyncWaitMs, ctx) if minMedia != nil { media = minMedia.KnownMedia } @@ -159,7 +159,54 @@ func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia return value, err } -func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, ctx rcontext.RequestContext) (*types.MinimalMedia, error) { +func waitForUpload(media *types.Media, asyncWaitMs *int, ctx rcontext.RequestContext) (*types.Media, error) { + if media == nil { + return nil, errors.New("waited for nil media") + } + + if util.IsServerOurs(media.Origin) && media.SizeBytes == 0 { + // we're not allowed to wait by requester + if asyncWaitMs == nil { + return nil, common.ErrMediaNotFound + } + + waitMs := *asyncWaitMs + + // max wait one minute + if waitMs > 60_000 { + waitMs = 60_000 + } + + // use default wait if negative + if waitMs < 0 { + waitMs = ctx.Config.Features.MSC2246Async.AsyncDownloadDefaultWaitSecs * 1000 + + // if the default is zero and client didn't request any then waiting is disabled + if waitMs == 0 { + return nil, common.ErrMediaNotFound + } + } + + // if the upload did not complete in 6 hours, consider it never will + if util.NowMillis()-media.CreationTs > 3600*6*1000 { + ctx.Log.Info("Tried to download expired asynchronous upload") + return nil, common.ErrMediaNotFound + } + + ctx.Log.Info("Asynchronous upload not complete, waiting") + if ok := util.WaitForUpload(media.Origin, media.MediaId, time.Millisecond*time.Duration(waitMs)); !ok { + return nil, common.ErrNotYetUploaded + } + + // fetch the entry from database again after we're notified it should be 
complete + db := storage.GetDatabase().GetMediaStore(ctx) + return db.Get(media.Origin, media.MediaId) + } + + return media, nil +} + +func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, asyncWaitMs *int, ctx rcontext.RequestContext) (*types.MinimalMedia, error) { db := storage.GetDatabase().GetMediaStore(ctx) var media *types.Media @@ -217,7 +264,10 @@ func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, KnownMedia: nil, // unknown }, nil } else { - media = dbMedia + media, err = waitForUpload(dbMedia, asyncWaitMs, ctx) + if err != nil { + return nil, err + } } } @@ -252,7 +302,7 @@ func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, }, nil } -func FindMediaRecord(origin string, mediaId string, downloadRemote bool, ctx rcontext.RequestContext) (*types.Media, error) { +func FindMediaRecord(origin string, mediaId string, downloadRemote bool, asyncWaitMs *int, ctx rcontext.RequestContext) (*types.Media, error) { cacheKey := origin + "/" + mediaId v, _, err := globals.DefaultRequestGroup.DoWithoutPost(cacheKey, func() (interface{}, error) { db := storage.GetDatabase().GetMediaStore(ctx) @@ -290,7 +340,10 @@ func FindMediaRecord(origin string, mediaId string, downloadRemote bool, ctx rco } media = result.media } else { - media = dbMedia + media, err = waitForUpload(dbMedia, asyncWaitMs, ctx) + if err != nil { + return nil, err + } } } diff --git a/controllers/info_controller/info_controller.go b/controllers/info_controller/info_controller.go index 98eb096..cb6a0dc 100644 --- a/controllers/info_controller/info_controller.go +++ b/controllers/info_controller/info_controller.go @@ -27,7 +27,7 @@ func GetOrCalculateBlurhash(media *types.Media, rctx rcontext.RequestContext) (s } rctx.Log.Info("Getting minimal media record to calculate blurhash") - minMedia, err := download_controller.FindMinimalMediaRecord(media.Origin, media.MediaId, true, rctx) + minMedia, err := 
download_controller.FindMinimalMediaRecord(media.Origin, media.MediaId, true, nil, rctx) if err != nil { return "", err } diff --git a/controllers/maintenance_controller/maintainance_controller.go b/controllers/maintenance_controller/maintainance_controller.go index 1df60a3..dfdd1f8 100644 --- a/controllers/maintenance_controller/maintainance_controller.go +++ b/controllers/maintenance_controller/maintainance_controller.go @@ -387,7 +387,7 @@ func PurgeDomainMedia(serverName string, beforeTs int64, ctx rcontext.RequestCon } func PurgeMedia(origin string, mediaId string, ctx rcontext.RequestContext) error { - media, err := download_controller.FindMediaRecord(origin, mediaId, false, ctx) + media, err := download_controller.FindMediaRecord(origin, mediaId, false, nil, ctx) if err != nil { return err } diff --git a/controllers/preview_controller/preview_resource_handler.go b/controllers/preview_controller/preview_resource_handler.go index 3230caa..20556c8 100644 --- a/controllers/preview_controller/preview_resource_handler.go +++ b/controllers/preview_controller/preview_resource_handler.go @@ -134,7 +134,7 @@ func urlPreviewWorkFn(request *resource_handler.WorkRequest) (resp *urlPreviewRe contentLength := upload_controller.EstimateContentLength(preview.Image.ContentLength, preview.Image.ContentLengthHeader) // UploadMedia will close the read stream for the thumbnail and dedupe the image - media, err := upload_controller.UploadMedia(preview.Image.Data, contentLength, preview.Image.ContentType, preview.Image.Filename, info.forUserId, info.onHost, ctx) + media, err := upload_controller.UploadMedia(preview.Image.Data, contentLength, preview.Image.ContentType, preview.Image.Filename, info.forUserId, info.onHost, "", ctx) if err != nil { ctx.Log.Warn("Non-fatal error storing preview thumbnail: " + err.Error()) sentry.CaptureException(err) diff --git a/controllers/thumbnail_controller/thumbnail_controller.go b/controllers/thumbnail_controller/thumbnail_controller.go index 
12d9cf9..0c21bd2 100644 --- a/controllers/thumbnail_controller/thumbnail_controller.go +++ b/controllers/thumbnail_controller/thumbnail_controller.go @@ -28,8 +28,8 @@ import ( var localCache = cache.New(30*time.Second, 60*time.Second) -func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight int, animated bool, method string, downloadRemote bool, ctx rcontext.RequestContext) (*types.StreamedThumbnail, error) { - media, err := download_controller.FindMediaRecord(origin, mediaId, downloadRemote, ctx) +func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight int, animated bool, method string, downloadRemote bool, asyncWaitMs *int, ctx rcontext.RequestContext) (*types.StreamedThumbnail, error) { + media, err := download_controller.FindMediaRecord(origin, mediaId, downloadRemote, asyncWaitMs, ctx) if err != nil { return nil, err } diff --git a/controllers/upload_controller/upload_controller.go b/controllers/upload_controller/upload_controller.go index 99c8c58..0d6d54b 100644 --- a/controllers/upload_controller/upload_controller.go +++ b/controllers/upload_controller/upload_controller.go @@ -1,12 +1,12 @@ package upload_controller import ( + "database/sql" "io" "strconv" "time" "github.com/getsentry/sentry-go" - "github.com/turt2live/matrix-media-repo/util/ids" "github.com/turt2live/matrix-media-repo/util/stream_util" "github.com/patrickmn/go-cache" @@ -92,7 +92,63 @@ func EstimateContentLength(contentLength int64, contentLengthHeader string) int6 return -1 // unknown } -func UploadMedia(contents io.ReadCloser, contentLength int64, contentType string, filename string, userId string, origin string, ctx rcontext.RequestContext) (*types.Media, error) { +func CreateMedia(origin string, ctx rcontext.RequestContext) (*types.Media, *datastore.DatastoreRef, error) { + metadataDb := storage.GetDatabase().GetMetadataStore(ctx) + + mediaTaken := true + var mediaId string + var err error + attempts := 0 + for mediaTaken { + attempts 
+= 1 + if attempts > 10 { + return nil, nil, errors.New("failed to generate a media ID after 10 rounds") + } + + mediaId, err = util.GenerateRandomString(64) + if err != nil { + return nil, nil, err + } + mediaId, err = util.GetSha1OfString(mediaId + strconv.FormatInt(util.NowMillis(), 10)) + if err != nil { + return nil, nil, err + } + + // Because we use the current time in the media ID, we don't need to worry about + // collisions from the database. + if _, present := recentMediaIds.Get(mediaId); present { + mediaTaken = true + continue + } + + mediaTaken, err = metadataDb.IsReserved(origin, mediaId) + if err != nil { + return nil, nil, err + } + } + + _ = recentMediaIds.Add(mediaId, true, cache.DefaultExpiration) + + ds, err := datastore.PickDatastore(common.KindLocalMedia, ctx) + if err != nil { + return nil, nil, err + } + + return &types.Media{MediaId: mediaId, Origin: origin, DatastoreId: ds.DatastoreId}, ds, nil +} + +func PersistMedia(media *types.Media, userId string, ctx rcontext.RequestContext) error { + db := storage.GetDatabase().GetMediaStore(ctx) + + ctx.Log.Info("Persisting async media record") + + media.UserId = userId + media.CreationTs = util.NowMillis() + + return db.Insert(media) +} + +func UploadMedia(contents io.ReadCloser, contentLength int64, contentType string, filename string, userId string, origin string, asyncMediaId string, ctx rcontext.RequestContext) (*types.Media, error) { defer stream_util.DumpAndCloseStream(contents) var data io.ReadCloser @@ -107,46 +163,61 @@ func UploadMedia(contents io.ReadCloser, contentLength int64, contentType string return nil, err } - metadataDb := storage.GetDatabase().GetMetadataStore(ctx) - - mediaTaken := true var mediaId string - attempts := 0 - for mediaTaken { - attempts += 1 - if attempts > 10 { - return nil, errors.New("failed to generate a media ID after 10 rounds") + if asyncMediaId == "" { + media, _, err := CreateMedia(origin, ctx) + if err != nil { + return nil, err } - mediaId, err = 
ids.NewUniqueId() + mediaId = media.MediaId + } else { + db := storage.GetDatabase().GetMediaStore(ctx) + + media, err := db.Get(origin, asyncMediaId) if err != nil { return nil, err } - // Because we use the current time in the media ID, we don't need to worry about - // collisions from the database. - if _, present := recentMediaIds.Get(mediaId); present { - mediaTaken = true - continue + if media == nil { + return nil, common.ErrMediaNotFound } - mediaTaken, err = metadataDb.IsReserved(origin, mediaId) + if media.UserId != userId { + return nil, common.ErrMediaNotFound + } + + if media.SizeBytes > 0 { + return nil, common.ErrCannotOverwriteMedia + } + + if util.NowMillis()-media.CreationTs > int64(ctx.Config.Features.MSC2246Async.AsyncUploadExpirySecs*1000) { + return nil, common.ErrMediaNotFound + } + + mediaId = asyncMediaId + _, err = datastore.LocateDatastore(ctx, media.DatastoreId) if err != nil { return nil, err } } - _ = recentMediaIds.Add(mediaId, true, cache.DefaultExpiration) - - m, err := StoreDirect(nil, util_byte_seeker.NewByteSeeker(dataBytes), contentLength, contentType, filename, userId, origin, mediaId, common.KindLocalMedia, ctx, true) + m, err := StoreDirect(nil, util_byte_seeker.NewByteSeeker(dataBytes), contentLength, contentType, filename, userId, origin, mediaId, common.KindLocalMedia, ctx, asyncMediaId == "") if err != nil { return m, err } if m != nil { - err = internal_cache.Get().UploadMedia(m.Sha256Hash, util_byte_seeker.NewByteSeeker(dataBytes), ctx) - if err != nil { + util.NotifyUpload(origin, mediaId) + + cache := internal_cache.Get() + if err := cache.UploadMedia(m.Sha256Hash, util_byte_seeker.NewByteSeeker(dataBytes), ctx); err != nil { ctx.Log.Warn("Unexpected error trying to cache media: " + err.Error()) } + if asyncMediaId != "" { + if err := cache.NotifyUpload(origin, mediaId, ctx); err != nil { + ctx.Log.Warn("Unexpected error trying to notify cache about media: " + err.Error()) + } + } } return m, err } @@ -171,8 +242,7 @@ 
func checkSpam(contents []byte, filename string, contentType string, userId stri return nil } -func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize int64, contentType string, filename string, userId string, origin string, mediaId string, kind string, ctx rcontext.RequestContext, filterUserDuplicates bool) (*types.Media, error) { - var err error +func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize int64, contentType string, filename string, userId string, origin string, mediaId string, kind string, ctx rcontext.RequestContext, filterUserDuplicates bool) (ret *types.Media, err error) { var ds *datastore.DatastoreRef var info *types.ObjectInfo var contentBytes []byte @@ -207,15 +277,16 @@ func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize in return nil, err } } + defer func() { + // always delete temp object if we return an error + if err != nil { + ds.DeleteObject(info.Location) + } + }() db := storage.GetDatabase().GetMediaStore(ctx) records, err := db.GetByHash(info.Sha256Hash) if err != nil { - err2 := ds.DeleteObject(info.Location) // delete temp object - if err2 != nil { - ctx.Log.Warn("Error deleting temporary upload", err2) - sentry.CaptureException(err2) - } return nil, err } @@ -248,22 +319,12 @@ func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize in err = checkSpam(contentBytes, filename, contentType, userId, origin, mediaId) if err != nil { - err2 := ds.DeleteObject(info.Location) // delete temp object - if err2 != nil { - ctx.Log.Warn("Error deleting temporary upload", err2) - sentry.CaptureException(err2) - } return nil, err } // We'll use the location from the first record record := records[0] if record.Quarantined { - err2 := ds.DeleteObject(info.Location) // delete temp object - if err2 != nil { - ctx.Log.Warn("Error deleting temporary upload", err2) - sentry.CaptureException(err2) - } ctx.Log.Warn("User attempted to upload quarantined content - 
rejecting") return nil, common.ErrMediaQuarantined } @@ -282,21 +343,37 @@ func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize in } } - media := record - media.Origin = origin - media.MediaId = mediaId - media.UserId = userId - media.UploadName = filename - media.ContentType = contentType - media.CreationTs = util.NowMillis() + // Check if we have reserved the metadata already + media, err := db.Get(origin, mediaId) + if err == sql.ErrNoRows { + media = record + media.Origin = origin + media.MediaId = mediaId + media.UserId = userId + media.UploadName = filename + media.ContentType = contentType + media.CreationTs = util.NowMillis() + + if err = db.Insert(media); err != nil { + return nil, err + } + } else if err == nil { + // last minute check if the file was already uploaded + if media.SizeBytes > 0 { + return nil, common.ErrCannotOverwriteMedia + } - err = db.Insert(media) - if err != nil { - err2 := ds.DeleteObject(info.Location) // delete temp object - if err2 != nil { - ctx.Log.Warn("Error deleting temporary upload", err2) - sentry.CaptureException(err2) + media.UploadName = filename + media.ContentType = contentType + media.Sha256Hash = info.Sha256Hash + media.SizeBytes = info.SizeBytes + media.DatastoreId = ds.DatastoreId + media.Location = info.Location + + if err = db.Update(media); err != nil { + return nil, err } + } else { return nil, err } @@ -305,11 +382,6 @@ func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize in if media.DatastoreId != ds.DatastoreId && media.Location != info.Location { ds2, err := datastore.LocateDatastore(ctx, media.DatastoreId) if err != nil { - err2 := ds.DeleteObject(info.Location) // delete temp object - if err2 != nil { - ctx.Log.Warn("Error deleting temporary upload", err2) - sentry.CaptureException(err2) - } return nil, err } if !ds2.ObjectExists(media.Location) { @@ -344,46 +416,54 @@ func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize in // The 
media doesn't already exist - save it as new if info.SizeBytes <= 0 { - err2 := ds.DeleteObject(info.Location) // delete temp object - if err2 != nil { - ctx.Log.Warn("Error deleting temporary upload", err2) - sentry.CaptureException(err2) - } return nil, errors.New("file has no contents") } err = checkSpam(contentBytes, filename, contentType, userId, origin, mediaId) if err != nil { - err2 := ds.DeleteObject(info.Location) // delete temp object - if err2 != nil { - ctx.Log.Warn("Error deleting temporary upload", err2) - sentry.CaptureException(err2) - } return nil, err } - ctx.Log.Info("Persisting new media record") - - media := &types.Media{ - Origin: origin, - MediaId: mediaId, - UploadName: filename, - ContentType: contentType, - UserId: userId, - Sha256Hash: info.Sha256Hash, - SizeBytes: info.SizeBytes, - DatastoreId: ds.DatastoreId, - Location: info.Location, - CreationTs: util.NowMillis(), - } + // Check if we have reserved the metadata already, validate uploader + media, err := db.Get(origin, mediaId) + if err == sql.ErrNoRows { + ctx.Log.Info("Persisting new media record") + + media = &types.Media{ + Origin: origin, + MediaId: mediaId, + UploadName: filename, + ContentType: contentType, + UserId: userId, + Sha256Hash: info.Sha256Hash, + SizeBytes: info.SizeBytes, + DatastoreId: ds.DatastoreId, + Location: info.Location, + CreationTs: util.NowMillis(), + } - err = db.Insert(media) - if err != nil { - err2 := ds.DeleteObject(info.Location) // delete temp object - if err2 != nil { - ctx.Log.Warn("Error deleting temporary upload", err2) - sentry.CaptureException(err2) + if err = db.Insert(media); err != nil { + return nil, err } + } else if err == nil { + ctx.Log.Info("Updating existing media record") + + // last minute check if the file was already uploaded + if media.SizeBytes > 0 { + return nil, common.ErrCannotOverwriteMedia + } + + media.UploadName = filename + media.ContentType = contentType + media.Sha256Hash = info.Sha256Hash + media.SizeBytes = 
info.SizeBytes + media.DatastoreId = ds.DatastoreId + media.Location = info.Location + + if err = db.Update(media); err != nil { + return nil, err + } + } else { return nil, err } diff --git a/internal_cache/cache.go b/internal_cache/cache.go index b871a07..0b0a5af 100644 --- a/internal_cache/cache.go +++ b/internal_cache/cache.go @@ -18,4 +18,5 @@ type ContentCache interface { MarkDownload(fileHash string) GetMedia(sha256hash string, contents FetchFunction, ctx rcontext.RequestContext) (*CachedContent, error) UploadMedia(sha256hash string, content io.ReadCloser, ctx rcontext.RequestContext) error + NotifyUpload(origin string, mediaId string, ctx rcontext.RequestContext) error } diff --git a/internal_cache/noop.go b/internal_cache/noop.go index 8a4f936..9d4a264 100644 --- a/internal_cache/noop.go +++ b/internal_cache/noop.go @@ -35,3 +35,8 @@ func (n *NoopCache) UploadMedia(sha256hash string, content io.ReadCloser, ctx rc // do nothing return nil } + +func (n *NoopCache) NotifyUpload(origin string, mediaId string, ctx rcontext.RequestContext) error { + // do nothing + return nil +} diff --git a/internal_cache/redis.go b/internal_cache/redis.go index c83a863..9e3199a 100644 --- a/internal_cache/redis.go +++ b/internal_cache/redis.go @@ -66,3 +66,7 @@ func (c *RedisCache) UploadMedia(sha256hash string, content io.ReadCloser, ctx r defer content.Close() return c.redis.SetStream(ctx, sha256hash, content) } + +func (c *RedisCache) NotifyUpload(origin string, mediaId string, ctx rcontext.RequestContext) error { + return c.redis.NotifyUpload(ctx, origin, mediaId) +} diff --git a/redis_cache/redis.go b/redis_cache/redis.go index dadbb87..8a76e47 100644 --- a/redis_cache/redis.go +++ b/redis_cache/redis.go @@ -11,6 +11,8 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common/config" "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/types" + "github.com/turt2live/matrix-media-repo/util" ) var 
ErrCacheMiss = errors.New("missed cache") @@ -31,14 +33,45 @@ func NewCache(conf config.RedisConfig) *RedisCache { DB: conf.DbNum, }) + ctx := context.Background() + logrus.Info("Contacting Redis shards...") - _ = ring.ForEachShard(context.Background(), func(ctx context.Context, client *redis.Client) error { + _ = ring.ForEachShard(ctx, func(ctx context.Context, client *redis.Client) error { logrus.Infof("Pinging %s", client.String()) r, err := client.Ping(ctx).Result() if err != nil { return err } logrus.Infof("%s replied with: %s", client.String(), r) + + psub := client.Subscribe(ctx, "upload") + go func() { + logrus.Infof("Client %s going to subscribe to uploads", client.String()) + for { + for { + msg, err := psub.ReceiveMessage(ctx) + if err != nil { + break + } + + ref := types.MediaRef{} + if err := ref.UnmarshalBinary([]byte(msg.Payload)); err != nil { + logrus.Warn("Failed to unmarshal published upload, ignoring") + continue + } + + logrus.Infof("Client %s notified about %s/%s being uploaded", client.String(), ref.Origin, ref.MediaId) + util.NotifyUpload(ref.Origin, ref.MediaId) + } + + if ctx.Done() != nil { + return + } + + time.Sleep(time.Second * 1) + } + }() + return nil }) @@ -96,3 +129,21 @@ func (c *RedisCache) GetBytes(ctx rcontext.RequestContext, key string) ([]byte, b, err := r.Bytes() return b, err } + +func (c *RedisCache) NotifyUpload(ctx rcontext.RequestContext, origin string, mediaId string) error { + if c.ring.PoolStats().TotalConns == 0 { + return ErrCacheDown + } + r := c.ring.Publish(ctx, "upload", types.MediaRef{Origin: origin, MediaId: mediaId}) + if r.Err() != nil { + if r.Err() == redis.Nil { + return ErrCacheMiss + } + if c.ring.PoolStats().TotalConns == 0 { + ctx.Log.Error(r.Err()) + return ErrCacheDown + } + return r.Err() + } + return nil +} diff --git a/storage/stores/media_store.go b/storage/stores/media_store.go index 958f1b0..697c58d 100644 --- a/storage/stores/media_store.go +++ b/storage/stores/media_store.go @@ -16,6 
+16,7 @@ import ( const selectMedia = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 and media_id = $2;" const selectMediaByHash = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE sha256_hash = $1;" const insertMedia = "INSERT INTO media (origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11);" +const updateMedia = "UPDATE media SET upload_name = $3, content_type = $4, sha256_hash = $5, size_bytes = $6, datastore_id = $7, location = $8 WHERE origin = $1 AND media_id = $2;" const selectOldMedia = "SELECT m.origin, m.media_id, m.upload_name, m.content_type, m.user_id, m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, quarantined FROM media AS m WHERE m.origin <> ANY($1) AND m.creation_ts < $2 AND (SELECT COUNT(*) FROM media AS d WHERE d.sha256_hash = m.sha256_hash AND d.creation_ts >= $2) = 0 AND (SELECT COUNT(*) FROM media AS d WHERE d.sha256_hash = m.sha256_hash AND d.origin = ANY($1)) = 0;" const selectOrigins = "SELECT DISTINCT origin FROM media;" const deleteMedia = "DELETE FROM media WHERE origin = $1 AND media_id = $2;" @@ -47,6 +48,7 @@ type mediaStoreStatements struct { selectMedia *sql.Stmt selectMediaByHash *sql.Stmt insertMedia *sql.Stmt + updateMedia *sql.Stmt selectOldMedia *sql.Stmt selectOrigins *sql.Stmt deleteMedia *sql.Stmt @@ -97,6 +99,9 @@ func InitMediaStore(sqlDb *sql.DB) (*MediaStoreFactory, error) { if store.stmts.insertMedia, err = store.sqlDb.Prepare(insertMedia); err != nil { return nil, err } + if store.stmts.updateMedia, err = store.sqlDb.Prepare(updateMedia); err != nil { + return nil, err + } if store.stmts.selectOldMedia, err = store.sqlDb.Prepare(selectOldMedia); err 
!= nil { return nil, err } @@ -190,6 +195,22 @@ func (s *MediaStore) Insert(media *types.Media) error { return err } +func (s *MediaStore) Update(media *types.Media) error { + _, err := s.statements.updateMedia.ExecContext( + s.ctx, + media.Origin, + media.MediaId, + media.UploadName, + media.ContentType, + media.Sha256Hash, + media.SizeBytes, + media.DatastoreId, + media.Location, + ) + + return err +} + func (s *MediaStore) GetByHash(hash string) ([]*types.Media, error) { rows, err := s.statements.selectMediaByHash.QueryContext(s.ctx, hash) if err != nil { diff --git a/types/media.go b/types/media.go index 873da95..9172ae2 100644 --- a/types/media.go +++ b/types/media.go @@ -1,6 +1,22 @@ package types -import "io" +import ( + "encoding/json" + "io" +) + +type MediaRef struct { + Origin string + MediaId string +} + +func (ref MediaRef) MarshalBinary() ([]byte, error) { + return json.Marshal(ref) +} + +func (ref *MediaRef) UnmarshalBinary(data []byte) error { + return json.Unmarshal(data, ref) +} type Media struct { Origin string diff --git a/util/upload_notifier.go b/util/upload_notifier.go new file mode 100644 index 0000000..84a9bd4 --- /dev/null +++ b/util/upload_notifier.go @@ -0,0 +1,61 @@ +package util + +import ( + "sync" + "time" +) + +type mediaSet map[chan struct{}]struct{} + +var waiterLock = &sync.Mutex{} +var waiters = map[string]mediaSet{} + +func WaitForUpload(origin string, mediaId string, timeout time.Duration) bool { + key := origin + mediaId + ch := make(chan struct{}, 1) + + waiterLock.Lock() + var set mediaSet + var ok bool + if set, ok = waiters[key]; !ok { + set = make(mediaSet) + waiters[key] = set + } + set[ch] = struct{}{} + waiterLock.Unlock() + + defer func() { + waiterLock.Lock() + + delete(set, ch) + close(ch) + + if len(set) == 0 { + delete(waiters, key) + } + + waiterLock.Unlock() + }() + + select { + case <-ch: + return true + case <-time.After(timeout): + return false + } +} + +func NotifyUpload(origin string, mediaId string) { + 
waiterLock.Lock() + defer waiterLock.Unlock() + + set := waiters[origin+mediaId] + + if set == nil { + return + } + + for channel := range set { + channel <- struct{}{} + } +}