download_manager_mtproto.cpp 29 KB


  1. /*
  2. This file is part of Telegram Desktop,
  3. the official desktop application for the Telegram messaging service.
  4. For license and copyright information please follow this link:
  5. https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
  6. */
  7. #include "storage/download_manager_mtproto.h"
  8. #include "mtproto/facade.h"
  9. #include "mtproto/mtproto_auth_key.h"
  10. #include "mtproto/mtproto_response.h"
  11. #include "main/main_session.h"
  12. #include "data/data_session.h"
  13. #include "data/data_document.h"
  14. #include "apiwrap.h"
  15. #include "base/openssl_help.h"
  16. namespace Storage {
  namespace {

  // Delay after which the download sessions of a dc with nothing requested
  // are stopped (see killSessionsSchedule() / killSessions()).
  constexpr auto kKillSessionTimeout = 15 * crl::time(1000);

  // Initial and largest value of the per-session 'maxWaitedAmount' limit.
  constexpr auto kStartWaitedInSession = 4 * kDownloadPartSize;
  constexpr auto kMaxWaitedInSession = 16 * kDownloadPartSize;

  // Sessions count a dc starts with and the count after which no more
  // sessions are added.
  constexpr auto kStartSessionsCount = 1;
  constexpr auto kMaxSessionsCount = 8;

  // Upper bound for the tracked 'sessionRemoveTimes' counter.
  constexpr auto kMaxTrackedSessionRemoves = 64;

  // Base values for the delay and for the per-session successes count
  // required before a session is added again after a removal.
  constexpr auto kRetryAddSessionTimeout = 8 * crl::time(1000);
  constexpr auto kRetryAddSessionSuccesses = 3;

  // Upper bound for the tracked per-session 'successes' counter.
  constexpr auto kMaxTrackedSuccesses = kRetryAddSessionSuccesses
      * kMaxTrackedSessionRemoves;

  // Count of accumulated timeouts after which the last session is removed.
  constexpr auto kRemoveSessionAfterTimeouts = 4;

  // Delay used for the timer that calls resetGeneration().
  constexpr auto kResetDownloadPrioritiesTimeout = crl::time(200);

  // A successful request that took at least this long is reported
  // as a session time out.
  constexpr auto kBadRequestDurationThreshold = 8 * crl::time(1000);

  // After sessions were removed by timeouts, a new session is added only
  // when enough time has passed since the last removal:
  // kRetryAddSessionTimeout
  //     * (min(removesCount, kMaxTrackedSessionRemoves) + 1)
  // and when every remaining session has this many successes:
  // kRetryAddSessionSuccesses
  //     * (min(removesCount, kMaxTrackedSessionRemoves) + 1)

  } // namespace
  36. void DownloadManagerMtproto::Queue::enqueue(
  37. not_null<Task*> task,
  38. int priority) {
  39. const auto position = ranges::find_if(_tasks, [&](const Enqueued &task) {
  40. return task.priority <= priority;
  41. }) - begin(_tasks);
  42. const auto now = ranges::find(_tasks, task, &Enqueued::task);
  43. const auto i = [&] {
  44. if (now != end(_tasks)) {
  45. (now->priority = priority);
  46. return now;
  47. }
  48. _tasks.push_back({ task, priority });
  49. return end(_tasks) - 1;
  50. }();
  51. const auto j = begin(_tasks) + position;
  52. if (j < i) {
  53. std::rotate(j, i, i + 1);
  54. } else if (j > i + 1) {
  55. std::rotate(i, i + 1, j);
  56. }
  57. }
  58. void DownloadManagerMtproto::Queue::remove(not_null<Task*> task) {
  59. _tasks.erase(ranges::remove(_tasks, task, &Enqueued::task), end(_tasks));
  60. }
  61. void DownloadManagerMtproto::Queue::resetGeneration() {
  62. const auto from = ranges::find(_tasks, 0, &Enqueued::priority);
  63. for (auto &task : ranges::make_subrange(from, end(_tasks))) {
  64. if (task.priority) {
  65. Assert(task.priority == -1);
  66. break;
  67. }
  68. task.priority = -1;
  69. }
  70. }
  // Returns true when there are no enqueued tasks.
  bool DownloadManagerMtproto::Queue::empty() const {
      return _tasks.empty();
  }
  74. auto DownloadManagerMtproto::Queue::nextTask(bool onlyHighestPriority) const
  75. -> Task* {
  76. if (_tasks.empty()) {
  77. return nullptr;
  78. }
  79. const auto highestPriority = _tasks.front().priority;
  80. const auto notHighestPriority = [&](const Enqueued &enqueued) {
  81. return (enqueued.priority != highestPriority);
  82. };
  83. const auto till = (onlyHighestPriority && highestPriority > 0)
  84. ? ranges::find_if(_tasks, notHighestPriority)
  85. : end(_tasks);
  86. const auto readyToRequest = [&](const Enqueued &enqueued) {
  87. return enqueued.task->readyToRequest();
  88. };
  89. const auto first = ranges::find_if(
  90. ranges::make_subrange(begin(_tasks), till),
  91. readyToRequest);
  92. return (first != till) ? first->task.get() : nullptr;
  93. }
  94. void DownloadManagerMtproto::Queue::removeSession(int index) {
  95. for (const auto &enqueued : _tasks) {
  96. enqueued.task->removeSession(index);
  97. }
  98. }
  // A session starts with the kStartWaitedInSession limit.
  DownloadManagerMtproto::DcSessionBalanceData::DcSessionBalanceData()
  : maxWaitedAmount(kStartWaitedInSession) {
  }
  // Initializes 'sessions' with kStartSessionsCount.
  DownloadManagerMtproto::DcBalanceData::DcBalanceData()
  : sessions(kStartSessionsCount) {
  }
  105. DownloadManagerMtproto::DownloadManagerMtproto(not_null<ApiWrap*> api)
  106. : _api(api)
  107. , _resetGenerationTimer([=] { resetGeneration(); })
  108. , _killSessionsTimer([=] { killSessions(); }) {
  109. _api->instance().restartsByTimeout(
  110. ) | rpl::filter([](MTP::ShiftedDcId shiftedDcId) {
  111. return MTP::isDownloadDcId(shiftedDcId);
  112. }) | rpl::start_with_next([=](MTP::ShiftedDcId shiftedDcId) {
  113. sessionTimedOut(
  114. MTP::BareDcId(shiftedDcId),
  115. MTP::GetDcIdShift(shiftedDcId));
  116. }, _lifetime);
  117. }
  // Runs the kill-sessions pass one more time on destruction.
  DownloadManagerMtproto::~DownloadManagerMtproto() {
      killSessions();
  }
  121. void DownloadManagerMtproto::enqueue(not_null<Task*> task, int priority) {
  122. const auto dcId = task->dcId();
  123. auto &queue = _queues[dcId];
  124. queue.enqueue(task, priority);
  125. if (!_resetGenerationTimer.isActive()) {
  126. _resetGenerationTimer.callOnce(kResetDownloadPrioritiesTimeout);
  127. }
  128. checkSendNext(dcId, queue);
  129. }
  130. void DownloadManagerMtproto::remove(not_null<Task*> task) {
  131. const auto dcId = task->dcId();
  132. auto &queue = _queues[dcId];
  133. queue.remove(task);
  134. checkSendNext(dcId, queue);
  135. }
  136. void DownloadManagerMtproto::resetGeneration() {
  137. _resetGenerationTimer.cancel();
  138. for (auto &[dcId, queue] : _queues) {
  139. queue.resetGeneration();
  140. }
  141. }
  142. void DownloadManagerMtproto::checkSendNext() {
  143. for (auto &[dcId, queue] : _queues) {
  144. if (queue.empty()) {
  145. continue;
  146. }
  147. checkSendNext(dcId, queue);
  148. }
  149. }
  150. void DownloadManagerMtproto::checkSendNext(MTP::DcId dcId, Queue &queue) {
  151. while (trySendNextPart(dcId, queue)) {
  152. }
  153. }
  154. void DownloadManagerMtproto::checkSendNextAfterSuccess(MTP::DcId dcId) {
  155. checkSendNext(dcId, _queues[dcId]);
  156. }
  157. bool DownloadManagerMtproto::trySendNextPart(MTP::DcId dcId, Queue &queue) {
  158. auto &balanceData = _balanceData[dcId];
  159. const auto &sessions = balanceData.sessions;
  160. const auto bestIndex = [&] {
  161. const auto proj = [](const DcSessionBalanceData &data) {
  162. return (data.requested < data.maxWaitedAmount)
  163. ? data.requested
  164. : kMaxWaitedInSession;
  165. };
  166. const auto j = ranges::min_element(sessions, ranges::less(), proj);
  167. return (j->requested + kDownloadPartSize <= j->maxWaitedAmount)
  168. ? (j - begin(sessions))
  169. : -1;
  170. }();
  171. if (bestIndex < 0) {
  172. return false;
  173. }
  174. const auto onlyHighestPriority = (balanceData.totalRequested > 0);
  175. if (const auto task = queue.nextTask(onlyHighestPriority)) {
  176. task->loadPart(bestIndex);
  177. return true;
  178. }
  179. return false;
  180. }
  181. int DownloadManagerMtproto::changeRequestedAmount(
  182. MTP::DcId dcId,
  183. int index,
  184. int delta) {
  185. const auto i = _balanceData.find(dcId);
  186. Assert(i != _balanceData.end());
  187. Assert(index < i->second.sessions.size());
  188. const auto result = (i->second.sessions[index].requested += delta);
  189. i->second.totalRequested += delta;
  190. const auto findNonEmptySession = [](const DcBalanceData &data) {
  191. using namespace rpl::mappers;
  192. return ranges::find_if(
  193. data.sessions,
  194. _1 > 0,
  195. &DcSessionBalanceData::requested);
  196. };
  197. if (delta > 0) {
  198. killSessionsCancel(dcId);
  199. } else if (findNonEmptySession(i->second) == end(i->second.sessions)) {
  200. killSessionsSchedule(dcId);
  201. }
  202. return result;
  203. }
  204. void DownloadManagerMtproto::requestSucceeded(
  205. MTP::DcId dcId,
  206. int index,
  207. int amountAtRequestStart,
  208. crl::time timeAtRequestStart) {
  209. using namespace rpl::mappers;
  210. const auto i = _balanceData.find(dcId);
  211. Assert(i != end(_balanceData));
  212. auto &dc = i->second;
  213. Assert(index < dc.sessions.size());
  214. auto &data = dc.sessions[index];
  215. const auto overloaded = (timeAtRequestStart <= dc.lastSessionRemove)
  216. || (amountAtRequestStart > data.maxWaitedAmount);
  217. const auto parts = amountAtRequestStart / kDownloadPartSize;
  218. const auto duration = (crl::now() - timeAtRequestStart);
  219. DEBUG_LOG(("Download (%1,%2) request done, duration: %3, parts: %4%5"
  220. ).arg(dcId
  221. ).arg(index
  222. ).arg(duration
  223. ).arg(parts
  224. ).arg(overloaded ? " (overloaded)" : ""));
  225. if (overloaded) {
  226. return;
  227. }
  228. if (duration >= kBadRequestDurationThreshold) {
  229. DEBUG_LOG(("Duration too large, signaling time out."));
  230. crl::on_main(this, [=] {
  231. sessionTimedOut(dcId, index);
  232. });
  233. return;
  234. }
  235. if (amountAtRequestStart == data.maxWaitedAmount
  236. && data.maxWaitedAmount < kMaxWaitedInSession) {
  237. data.maxWaitedAmount = std::min(
  238. data.maxWaitedAmount + kDownloadPartSize,
  239. kMaxWaitedInSession);
  240. DEBUG_LOG(("Download (%1,%2) increased max waited amount %3."
  241. ).arg(dcId
  242. ).arg(index
  243. ).arg(data.maxWaitedAmount));
  244. }
  245. data.successes = std::min(data.successes + 1, kMaxTrackedSuccesses);
  246. const auto notEnough = ranges::any_of(
  247. dc.sessions,
  248. _1 < (dc.sessionRemoveTimes + 1) * kRetryAddSessionSuccesses,
  249. &DcSessionBalanceData::successes);
  250. if (notEnough) {
  251. return;
  252. }
  253. for (auto &session : dc.sessions) {
  254. session.successes = 0;
  255. }
  256. if (dc.timeouts > 0) {
  257. --dc.timeouts;
  258. return;
  259. } else if (dc.sessions.size() == kMaxSessionsCount) {
  260. return;
  261. }
  262. const auto now = crl::now();
  263. const auto delay = (dc.sessionRemoveTimes + 1) * kRetryAddSessionTimeout;
  264. if (dc.lastSessionRemove && now < dc.lastSessionRemove + delay) {
  265. return;
  266. }
  267. dc.sessions.emplace_back();
  268. DEBUG_LOG(("Download (%1,%2) adding, now sessions: %3"
  269. ).arg(dcId
  270. ).arg(dc.sessions.size() - 1
  271. ).arg(dc.sessions.size()));
  272. }
  273. int DownloadManagerMtproto::chooseSessionIndex(MTP::DcId dcId) const {
  274. const auto i = _balanceData.find(dcId);
  275. Assert(i != end(_balanceData));
  276. const auto &sessions = i->second.sessions;
  277. const auto j = ranges::min_element(
  278. sessions,
  279. ranges::less(),
  280. &DcSessionBalanceData::requested);
  281. return (j - begin(sessions));
  282. }
  283. void DownloadManagerMtproto::sessionTimedOut(MTP::DcId dcId, int index) {
  284. const auto i = _balanceData.find(dcId);
  285. if (i == end(_balanceData)) {
  286. return;
  287. }
  288. auto &dc = i->second;
  289. if (index >= dc.sessions.size()) {
  290. return;
  291. }
  292. DEBUG_LOG(("Download (%1,%2) session timed-out.").arg(dcId).arg(index));
  293. for (auto &session : dc.sessions) {
  294. session.successes = 0;
  295. }
  296. if (dc.sessions.size() == kStartSessionsCount
  297. || ++dc.timeouts < kRemoveSessionAfterTimeouts) {
  298. return;
  299. }
  300. dc.timeouts = 0;
  301. removeSession(dcId);
  302. }
  303. void DownloadManagerMtproto::removeSession(MTP::DcId dcId) {
  304. auto &dc = _balanceData[dcId];
  305. Assert(dc.sessions.size() > kStartSessionsCount);
  306. const auto index = int(dc.sessions.size() - 1);
  307. DEBUG_LOG(("Download (%1,%2) removing, now sessions: %3"
  308. ).arg(dcId
  309. ).arg(index
  310. ).arg(index));
  311. auto &queue = _queues[dcId];
  312. if (dc.sessionRemoveIndex == index) {
  313. dc.sessionRemoveTimes = std::min(
  314. dc.sessionRemoveTimes + 1,
  315. kMaxTrackedSessionRemoves);
  316. } else {
  317. dc.sessionRemoveIndex = index;
  318. dc.sessionRemoveTimes = 1;
  319. }
  320. auto &session = dc.sessions.back();
  321. // Make sure we don't send anything to that session while redirecting.
  322. session.requested += kMaxWaitedInSession * kMaxSessionsCount;
  323. queue.removeSession(index);
  324. Assert(session.requested == kMaxWaitedInSession * kMaxSessionsCount);
  325. dc.sessions.pop_back();
  326. api().instance().killSession(MTP::downloadDcId(dcId, index));
  327. dc.lastSessionRemove = crl::now();
  328. }
  329. void DownloadManagerMtproto::killSessionsSchedule(MTP::DcId dcId) {
  330. if (!_killSessionsWhen.contains(dcId)) {
  331. _killSessionsWhen.emplace(dcId, crl::now() + kKillSessionTimeout);
  332. }
  333. if (!_killSessionsTimer.isActive()) {
  334. _killSessionsTimer.callOnce(kKillSessionTimeout + 5);
  335. }
  336. }
  337. void DownloadManagerMtproto::killSessionsCancel(MTP::DcId dcId) {
  338. _killSessionsWhen.erase(dcId);
  339. if (_killSessionsWhen.empty()) {
  340. _killSessionsTimer.cancel();
  341. }
  342. }
  343. void DownloadManagerMtproto::killSessions() {
  344. const auto now = crl::now();
  345. auto left = kKillSessionTimeout;
  346. for (auto i = begin(_killSessionsWhen); i != end(_killSessionsWhen); ) {
  347. if (i->second <= now) {
  348. killSessions(i->first);
  349. i = _killSessionsWhen.erase(i);
  350. } else {
  351. if (i->second - now < left) {
  352. left = i->second - now;
  353. }
  354. ++i;
  355. }
  356. }
  357. if (!_killSessionsWhen.empty()) {
  358. _killSessionsTimer.callOnce(left);
  359. }
  360. }
  361. void DownloadManagerMtproto::killSessions(MTP::DcId dcId) {
  362. const auto i = _balanceData.find(dcId);
  363. if (i != end(_balanceData)) {
  364. auto &dc = i->second;
  365. Assert(dc.totalRequested == 0);
  366. auto sessions = base::take(dc.sessions);
  367. dc = DcBalanceData();
  368. for (auto j = 0; j != int(sessions.size()); ++j) {
  369. Assert(sessions[j].requested == 0);
  370. sessions[j] = DcSessionBalanceData();
  371. api().instance().stopSession(MTP::downloadDcId(dcId, j));
  372. }
  373. dc.sessions = base::take(sessions);
  374. }
  375. }
  // Creates a task for a storage file location. The dc id is taken from
  // |location|, and |origin| is stored for later file reference refreshes.
  DownloadMtprotoTask::DownloadMtprotoTask(
      not_null<DownloadManagerMtproto*> owner,
      const StorageFileLocation &location,
      Data::FileOrigin origin)
  : _owner(owner)
  , _dcId(location.dcId())
  , _location({ location })
  , _origin(origin) {
  }
  // Creates a task for an arbitrary |location| in the explicitly given
  // |dcId|. The file origin is left default-initialized.
  DownloadMtprotoTask::DownloadMtprotoTask(
      not_null<DownloadManagerMtproto*> owner,
      MTP::DcId dcId,
      const Location &location)
  : _owner(owner)
  , _dcId(dcId)
  , _location(location) {
  }
  // Cancels everything that was sent and removes the task from
  // the owner's queue.
  DownloadMtprotoTask::~DownloadMtprotoTask() {
      cancelAllRequests();
      _owner->remove(this);
  }
  // Returns the dc id this task was created with.
  MTP::DcId DownloadMtprotoTask::dcId() const {
      return _dcId;
  }
  // Returns the stored file origin.
  Data::FileOrigin DownloadMtprotoTask::fileOrigin() const {
      return _origin;
  }
  403. uint64 DownloadMtprotoTask::objectId() const {
  404. if (const auto v = std::get_if<StorageFileLocation>(&_location.data)) {
  405. return v->objectId();
  406. }
  407. return 0;
  408. }
  // Returns the location this task downloads from.
  const DownloadMtprotoTask::Location &DownloadMtprotoTask::location() const {
      return _location;
  }
  412. void DownloadMtprotoTask::refreshFileReferenceFrom(
  413. const Data::UpdatedFileReferences &updates,
  414. int requestId,
  415. const QByteArray &current) {
  416. if (const auto v = std::get_if<StorageFileLocation>(&_location.data)) {
  417. v->refreshFileReference(updates);
  418. if (v->fileReference() == current) {
  419. cancelOnFail();
  420. return;
  421. }
  422. } else {
  423. cancelOnFail();
  424. return;
  425. }
  426. if (_sentRequests.contains(requestId)) {
  427. makeRequest(finishSentRequest(
  428. requestId,
  429. FinishRequestReason::Redirect));
  430. }
  431. }
  432. void DownloadMtprotoTask::loadPart(int sessionIndex) {
  433. makeRequest({ takeNextRequestOffset(), sessionIndex });
  434. }
  435. void DownloadMtprotoTask::removeSession(int sessionIndex) {
  436. struct Redirect {
  437. mtpRequestId requestId = 0;
  438. int64 offset = 0;
  439. };
  440. auto redirect = std::vector<Redirect>();
  441. for (const auto &[requestId, requestData] : _sentRequests) {
  442. if (requestData.sessionIndex == sessionIndex) {
  443. redirect.reserve(_sentRequests.size());
  444. redirect.push_back({ requestId, requestData.offset });
  445. }
  446. }
  447. for (auto &[requestData, bytes] : _cdnUncheckedParts) {
  448. if (requestData.sessionIndex == sessionIndex) {
  449. const auto newIndex = _owner->chooseSessionIndex(dcId());
  450. Assert(newIndex < sessionIndex);
  451. requestData.sessionIndex = newIndex;
  452. }
  453. }
  454. for (const auto &[requestId, offset] : redirect) {
  455. const auto needMakeRequest = (requestId != _cdnHashesRequestId);
  456. cancelRequest(requestId);
  457. if (needMakeRequest) {
  458. const auto newIndex = _owner->chooseSessionIndex(dcId());
  459. Assert(newIndex < sessionIndex);
  460. makeRequest({ offset, newIndex });
  461. }
  462. }
  463. }
  // Sends the request for one part described by |requestData| and returns
  // the id of the sent request.
  //
  // While a CDN dc is set, the part is requested from it with
  // upload.getCdnFile. Otherwise the request depends on the location kind:
  // upload.getWebFile for web file, geo point and audio album thumbnail
  // locations, upload.getFile for storage file locations.
  mtpRequestId DownloadMtprotoTask::sendRequest(
          const RequestData &requestData) {
      const auto offset = requestData.offset;
      const auto limit = Storage::kDownloadPartSize;
      // The session index selects one of the download sessions of the dc.
      const auto shiftedDcId = MTP::downloadDcId(
          _cdnDcId ? _cdnDcId : dcId(),
          requestData.sessionIndex);
      if (_cdnDcId) {
          return api().request(MTPupload_GetCdnFile(
              MTP_bytes(_cdnToken),
              MTP_long(offset),
              MTP_int(limit)
          )).done([=](const MTPupload_CdnFile &result, mtpRequestId id) {
              cdnPartLoaded(result, id);
          }).fail([=](const MTP::Error &error, mtpRequestId id) {
              cdnPartFailed(error, id);
          }).toDC(shiftedDcId).send();
      }
      return v::match(_location.data, [&](const WebFileLocation &location) {
          return api().request(MTPupload_GetWebFile(
              MTP_inputWebFileLocation(
                  MTP_bytes(location.url()),
                  MTP_long(location.accessHash())),
              MTP_int(offset),
              MTP_int(limit)
          )).done([=](const MTPupload_WebFile &result, mtpRequestId id) {
              webPartLoaded(result, id);
          }).fail([=](const MTP::Error &error, mtpRequestId id) {
              partFailed(error, id);
          }).toDC(shiftedDcId).send();
      }, [&](const GeoPointLocation &location) {
          return api().request(MTPupload_GetWebFile(
              MTP_inputWebFileGeoPointLocation(
                  MTP_inputGeoPoint(
                      MTP_flags(0),
                      MTP_double(location.lat),
                      MTP_double(location.lon),
                      MTP_int(0)), // accuracy_radius
                  MTP_long(location.access),
                  MTP_int(location.width),
                  MTP_int(location.height),
                  MTP_int(location.zoom),
                  MTP_int(location.scale)),
              MTP_int(offset),
              MTP_int(limit)
          )).done([=](const MTPupload_WebFile &result, mtpRequestId id) {
              webPartLoaded(result, id);
          }).fail([=](const MTP::Error &error, mtpRequestId id) {
              partFailed(error, id);
          }).toDC(shiftedDcId).send();
      }, [&](const AudioAlbumThumbLocation &location) {
          using Flag = MTPDinputWebFileAudioAlbumThumbLocation::Flag;
          const auto owner = &api().session().data();
          return api().request(MTPupload_GetWebFile(
              MTP_inputWebFileAudioAlbumThumbLocation(
                  MTP_flags(Flag::f_document | Flag::f_small),
                  owner->document(location.documentId)->mtpInput(),
                  MTPstring(),
                  MTPstring()),
              MTP_int(offset),
              MTP_int(limit)
          )).done([=](const MTPupload_WebFile &result, mtpRequestId id) {
              webPartLoaded(result, id);
          }).fail([=](const MTP::Error &error, mtpRequestId id) {
              partFailed(error, id);
          }).toDC(shiftedDcId).send();
      }, [&](const StorageFileLocation &location) {
          // The reference used for this request is passed to the fail
          // handler, which forwards it to the file reference refresh.
          const auto reference = location.fileReference();
          return api().request(MTPupload_GetFile(
              MTP_flags(MTPupload_GetFile::Flag::f_cdn_supported),
              location.tl(api().session().userId()),
              MTP_long(offset),
              MTP_int(limit)
          )).done([=](const MTPupload_File &result, mtpRequestId id) {
              normalPartLoaded(result, id);
          }).fail([=](const MTP::Error &error, mtpRequestId id) {
              normalPartFailed(reference, error, id);
          }).toDC(shiftedDcId).send();
      });
  }
  // Called with the size reported in a loaded web file part. This
  // implementation ignores |size| and returns true, which lets
  // webPartLoaded() pass the part on.
  bool DownloadMtprotoTask::setWebFileSizeHook(int64 size) {
      return true;
  }
  547. void DownloadMtprotoTask::makeRequest(const RequestData &requestData) {
  548. placeSentRequest(sendRequest(requestData), requestData);
  549. }
  550. void DownloadMtprotoTask::requestMoreCdnFileHashes() {
  551. if (_cdnHashesRequestId || _cdnUncheckedParts.empty()) {
  552. return;
  553. }
  554. const auto requestData = _cdnUncheckedParts.cbegin()->first;
  555. const auto shiftedDcId = MTP::downloadDcId(
  556. dcId(),
  557. requestData.sessionIndex);
  558. _cdnHashesRequestId = api().request(MTPupload_GetCdnFileHashes(
  559. MTP_bytes(_cdnToken),
  560. MTP_long(requestData.offset)
  561. )).done([=](const MTPVector<MTPFileHash> &result, mtpRequestId id) {
  562. getCdnFileHashesDone(result, id);
  563. }).fail([=](const MTP::Error &error, mtpRequestId id) {
  564. cdnPartFailed(error, id);
  565. }).toDC(shiftedDcId).send();
  566. placeSentRequest(_cdnHashesRequestId, requestData);
  567. }
  568. void DownloadMtprotoTask::normalPartLoaded(
  569. const MTPupload_File &result,
  570. mtpRequestId requestId) {
  571. const auto requestData = finishSentRequest(
  572. requestId,
  573. FinishRequestReason::Success);
  574. const auto owner = _owner;
  575. const auto dcId = this->dcId();
  576. result.match([&](const MTPDupload_fileCdnRedirect &data) {
  577. switchToCDN(requestData, data);
  578. }, [&](const MTPDupload_file &data) {
  579. partLoaded(requestData.offset, data.vbytes().v);
  580. });
  581. // 'this' may be deleted at this point.
  582. owner->checkSendNextAfterSuccess(dcId);
  583. }
  584. void DownloadMtprotoTask::webPartLoaded(
  585. const MTPupload_WebFile &result,
  586. mtpRequestId requestId) {
  587. const auto requestData = finishSentRequest(
  588. requestId,
  589. FinishRequestReason::Success);
  590. const auto owner = _owner;
  591. const auto dcId = this->dcId();
  592. result.match([&](const MTPDupload_webFile &data) {
  593. if (setWebFileSizeHook(data.vsize().v)) {
  594. partLoaded(requestData.offset, data.vbytes().v);
  595. }
  596. });
  597. // 'this' may be deleted at this point.
  598. owner->checkSendNextAfterSuccess(dcId);
  599. }
  600. void DownloadMtprotoTask::cdnPartLoaded(const MTPupload_CdnFile &result, mtpRequestId requestId) {
  601. result.match([&](const MTPDupload_cdnFileReuploadNeeded &data) {
  602. const auto requestData = finishSentRequest(
  603. requestId,
  604. FinishRequestReason::Redirect);
  605. const auto shiftedDcId = MTP::downloadDcId(
  606. dcId(),
  607. requestData.sessionIndex);
  608. const auto requestId = api().request(MTPupload_ReuploadCdnFile(
  609. MTP_bytes(_cdnToken),
  610. data.vrequest_token()
  611. )).done([=](const MTPVector<MTPFileHash> &result, mtpRequestId id) {
  612. reuploadDone(result, id);
  613. }).fail([=](const MTP::Error &error, mtpRequestId id) {
  614. cdnPartFailed(error, id);
  615. }).toDC(shiftedDcId).send();
  616. placeSentRequest(requestId, requestData);
  617. }, [&](const MTPDupload_cdnFile &data) {
  618. const auto requestData = finishSentRequest(
  619. requestId,
  620. FinishRequestReason::Success);
  621. const auto owner = _owner;
  622. const auto dcId = this->dcId();
  623. const auto guard = gsl::finally([=] {
  624. // 'this' may be deleted at this point.
  625. owner->checkSendNextAfterSuccess(dcId);
  626. });
  627. auto key = bytes::make_span(_cdnEncryptionKey);
  628. auto iv = bytes::make_span(_cdnEncryptionIV);
  629. Expects(key.size() == MTP::CTRState::KeySize);
  630. Expects(iv.size() == MTP::CTRState::IvecSize);
  631. auto state = MTP::CTRState();
  632. auto ivec = bytes::make_span(state.ivec);
  633. std::copy(iv.begin(), iv.end(), ivec.begin());
  634. auto counterOffset = static_cast<uint32>(requestData.offset >> 4);
  635. state.ivec[15] = static_cast<uchar>(counterOffset & 0xFF);
  636. state.ivec[14] = static_cast<uchar>((counterOffset >> 8) & 0xFF);
  637. state.ivec[13] = static_cast<uchar>((counterOffset >> 16) & 0xFF);
  638. state.ivec[12] = static_cast<uchar>((counterOffset >> 24) & 0xFF);
  639. auto decryptInPlace = data.vbytes().v;
  640. auto buffer = bytes::make_detached_span(decryptInPlace);
  641. MTP::aesCtrEncrypt(buffer, key.data(), &state);
  642. switch (checkCdnFileHash(requestData.offset, buffer)) {
  643. case CheckCdnHashResult::NoHash: {
  644. _cdnUncheckedParts.emplace(requestData, decryptInPlace);
  645. requestMoreCdnFileHashes();
  646. } return;
  647. case CheckCdnHashResult::Invalid: {
  648. LOG(("API Error: Wrong cdnFileHash for offset %1."
  649. ).arg(requestData.offset));
  650. cancelOnFail();
  651. } return;
  652. case CheckCdnHashResult::Good: {
  653. partLoaded(requestData.offset, decryptInPlace);
  654. } return;
  655. }
  656. Unexpected("Result of checkCdnFileHash()");
  657. });
  658. }
  659. DownloadMtprotoTask::CheckCdnHashResult DownloadMtprotoTask::checkCdnFileHash(
  660. int64 offset,
  661. bytes::const_span buffer) {
  662. const auto cdnFileHashIt = _cdnFileHashes.find(offset);
  663. if (cdnFileHashIt == _cdnFileHashes.cend()) {
  664. return CheckCdnHashResult::NoHash;
  665. }
  666. const auto realHash = openssl::Sha256(buffer);
  667. const auto receivedHash = bytes::make_span(cdnFileHashIt->second.hash);
  668. if (bytes::compare(realHash, receivedHash)) {
  669. return CheckCdnHashResult::Invalid;
  670. }
  671. return CheckCdnHashResult::Good;
  672. }
  673. void DownloadMtprotoTask::reuploadDone(
  674. const MTPVector<MTPFileHash> &result,
  675. mtpRequestId requestId) {
  676. const auto requestData = finishSentRequest(
  677. requestId,
  678. FinishRequestReason::Redirect);
  679. addCdnHashes(result.v);
  680. makeRequest(requestData);
  681. }
  682. void DownloadMtprotoTask::getCdnFileHashesDone(
  683. const MTPVector<MTPFileHash> &result,
  684. mtpRequestId requestId) {
  685. Expects(_cdnHashesRequestId == requestId);
  686. const auto requestData = finishSentRequest(
  687. requestId,
  688. FinishRequestReason::Redirect);
  689. addCdnHashes(result.v);
  690. auto someMoreChecked = false;
  691. for (auto i = _cdnUncheckedParts.begin(); i != _cdnUncheckedParts.cend();) {
  692. const auto uncheckedData = i->first;
  693. const auto uncheckedBytes = bytes::make_span(i->second);
  694. switch (checkCdnFileHash(uncheckedData.offset, uncheckedBytes)) {
  695. case CheckCdnHashResult::NoHash: {
  696. ++i;
  697. } break;
  698. case CheckCdnHashResult::Invalid: {
  699. LOG(("API Error: Wrong cdnFileHash for offset %1."
  700. ).arg(uncheckedData.offset));
  701. cancelOnFail();
  702. return;
  703. } break;
  704. case CheckCdnHashResult::Good: {
  705. someMoreChecked = true;
  706. const auto goodOffset = uncheckedData.offset;
  707. const auto goodBytes = std::move(i->second);
  708. const auto weak = base::make_weak(this);
  709. i = _cdnUncheckedParts.erase(i);
  710. if (!feedPart(goodOffset, goodBytes) || !weak) {
  711. return;
  712. }
  713. } break;
  714. default: Unexpected("Result of checkCdnFileHash()");
  715. }
  716. }
  717. if (!someMoreChecked) {
  718. LOG(("API Error: "
  719. "Could not find cdnFileHash for offset %1 "
  720. "after getCdnFileHashes request."
  721. ).arg(requestData.offset));
  722. cancelOnFail();
  723. return;
  724. }
  725. requestMoreCdnFileHashes();
  726. }
  727. void DownloadMtprotoTask::placeSentRequest(
  728. mtpRequestId requestId,
  729. const RequestData &requestData) {
  730. if (_sentRequests.empty()) {
  731. subscribeToNonPremiumLimit();
  732. }
  733. const auto amount = _owner->changeRequestedAmount(
  734. dcId(),
  735. requestData.sessionIndex,
  736. Storage::kDownloadPartSize);
  737. const auto &[i, ok1] = _sentRequests.emplace(requestId, requestData);
  738. const auto &[j, ok2] = _requestByOffset.emplace(
  739. requestData.offset,
  740. requestId);
  741. i->second.requestedInSession = amount;
  742. i->second.sent = crl::now();
  743. Ensures(ok1 && ok2);
  744. }
  745. void DownloadMtprotoTask::subscribeToNonPremiumLimit() {
  746. if (_nonPremiumLimitSubscription) {
  747. return;
  748. }
  749. _owner->api().instance().nonPremiumDelayedRequests(
  750. ) | rpl::start_with_next([=](mtpRequestId id) {
  751. if (_sentRequests.contains(id)) {
  752. if (const auto documentId = objectId()) {
  753. const auto type = v::get<StorageFileLocation>(
  754. _location.data).type();
  755. if (type == StorageFileLocation::Type::Document) {
  756. _owner->notifyNonPremiumDelay(documentId);
  757. }
  758. }
  759. }
  760. }, _nonPremiumLimitSubscription);
  761. }
  762. auto DownloadMtprotoTask::finishSentRequest(
  763. mtpRequestId requestId,
  764. FinishRequestReason reason)
  765. -> RequestData {
  766. auto it = _sentRequests.find(requestId);
  767. Assert(it != _sentRequests.cend());
  768. if (_cdnHashesRequestId == requestId) {
  769. _cdnHashesRequestId = 0;
  770. }
  771. const auto result = it->second;
  772. _owner->changeRequestedAmount(
  773. dcId(),
  774. result.sessionIndex,
  775. -Storage::kDownloadPartSize);
  776. _sentRequests.erase(it);
  777. const auto ok = _requestByOffset.remove(result.offset);
  778. if (_sentRequests.empty()) {
  779. _nonPremiumLimitSubscription.destroy();
  780. }
  781. if (reason == FinishRequestReason::Success) {
  782. _owner->requestSucceeded(
  783. dcId(),
  784. result.sessionIndex,
  785. result.requestedInSession,
  786. result.sent);
  787. }
  788. Ensures(ok);
  789. return result;
  790. }
  791. bool DownloadMtprotoTask::haveSentRequests() const {
  792. return !_sentRequests.empty() || !_cdnUncheckedParts.empty();
  793. }
  794. bool DownloadMtprotoTask::haveSentRequestForOffset(int64 offset) const {
  795. return _requestByOffset.contains(offset)
  796. || _cdnUncheckedParts.contains({ offset, 0 });
  797. }
  798. void DownloadMtprotoTask::cancelAllRequests() {
  799. while (!_sentRequests.empty()) {
  800. cancelRequest(_sentRequests.begin()->first);
  801. }
  802. _cdnUncheckedParts.clear();
  803. }
  804. void DownloadMtprotoTask::cancelRequestForOffset(int64 offset) {
  805. const auto i = _requestByOffset.find(offset);
  806. if (i != end(_requestByOffset)) {
  807. cancelRequest(i->second);
  808. }
  809. _cdnUncheckedParts.remove({ offset, 0 });
  810. }
  811. void DownloadMtprotoTask::cancelRequest(mtpRequestId requestId) {
  812. const auto hashes = (_cdnHashesRequestId == requestId);
  813. api().request(requestId).cancel();
  814. [[maybe_unused]] const auto data = finishSentRequest(
  815. requestId,
  816. FinishRequestReason::Cancel);
  817. if (hashes && !_cdnUncheckedParts.empty()) {
  818. crl::on_main(this, [=] {
  819. requestMoreCdnFileHashes();
  820. });
  821. }
  822. }
  // Enqueues this task in the owner with the given |priority|.
  void DownloadMtprotoTask::addToQueue(int priority) {
      _owner->enqueue(this, priority);
  }
  // Removes this task from the owner's queue.
  void DownloadMtprotoTask::removeFromQueue() {
      _owner->remove(this);
  }
  // Passes a loaded part to feedPart(); its result is not used here.
  void DownloadMtprotoTask::partLoaded(
          int64 offset,
          const QByteArray &bytes) {
      feedPart(offset, bytes);
  }
  834. bool DownloadMtprotoTask::normalPartFailed(
  835. QByteArray fileReference,
  836. const MTP::Error &error,
  837. mtpRequestId requestId) {
  838. if (MTP::IsDefaultHandledError(error)) {
  839. return false;
  840. }
  841. if (error.code() == 400
  842. && error.type().startsWith(u"FILE_REFERENCE_"_q)) {
  843. api().refreshFileReference(
  844. _origin,
  845. this,
  846. requestId,
  847. fileReference);
  848. return true;
  849. }
  850. return partFailed(error, requestId);
  851. }
  852. bool DownloadMtprotoTask::partFailed(
  853. const MTP::Error &error,
  854. mtpRequestId requestId) {
  855. if (MTP::IsDefaultHandledError(error)) {
  856. return false;
  857. }
  858. cancelOnFail();
  859. return true;
  860. }
  861. bool DownloadMtprotoTask::cdnPartFailed(
  862. const MTP::Error &error,
  863. mtpRequestId requestId) {
  864. if (MTP::IsDefaultHandledError(error)) {
  865. return false;
  866. }
  867. if (error.type() == u"FILE_TOKEN_INVALID"_q
  868. || error.type() == u"REQUEST_TOKEN_INVALID"_q) {
  869. const auto requestData = finishSentRequest(
  870. requestId,
  871. FinishRequestReason::Redirect);
  872. changeCDNParams(
  873. requestData,
  874. 0,
  875. QByteArray(),
  876. QByteArray(),
  877. QByteArray(),
  878. QVector<MTPFileHash>());
  879. return true;
  880. }
  881. return partFailed(error, requestId);
  882. }
  883. void DownloadMtprotoTask::switchToCDN(
  884. const RequestData &requestData,
  885. const MTPDupload_fileCdnRedirect &redirect) {
  886. changeCDNParams(
  887. requestData,
  888. redirect.vdc_id().v,
  889. redirect.vfile_token().v,
  890. redirect.vencryption_key().v,
  891. redirect.vencryption_iv().v,
  892. redirect.vfile_hashes().v);
  893. }
  894. void DownloadMtprotoTask::addCdnHashes(
  895. const QVector<MTPFileHash> &hashes) {
  896. for (const auto &hash : hashes) {
  897. hash.match([&](const MTPDfileHash &data) {
  898. _cdnFileHashes.emplace(
  899. data.voffset().v,
  900. CdnFileHash{ data.vlimit().v, data.vhash().v });
  901. });
  902. }
  903. }
  904. void DownloadMtprotoTask::changeCDNParams(
  905. const RequestData &requestData,
  906. MTP::DcId dcId,
  907. const QByteArray &token,
  908. const QByteArray &encryptionKey,
  909. const QByteArray &encryptionIV,
  910. const QVector<MTPFileHash> &hashes) {
  911. if (dcId != 0
  912. && (encryptionKey.size() != MTP::CTRState::KeySize
  913. || encryptionIV.size() != MTP::CTRState::IvecSize)) {
  914. LOG(("Message Error: Wrong key (%1) / iv (%2) size in CDN params"
  915. ).arg(encryptionKey.size()
  916. ).arg(encryptionIV.size()));
  917. cancelOnFail();
  918. return;
  919. }
  920. auto resendAllRequests = (_cdnDcId != dcId
  921. || _cdnToken != token
  922. || _cdnEncryptionKey != encryptionKey
  923. || _cdnEncryptionIV != encryptionIV);
  924. _cdnDcId = dcId;
  925. _cdnToken = token;
  926. _cdnEncryptionKey = encryptionKey;
  927. _cdnEncryptionIV = encryptionIV;
  928. addCdnHashes(hashes);
  929. if (resendAllRequests && !_sentRequests.empty()) {
  930. auto resendRequests = std::vector<RequestData>();
  931. resendRequests.reserve(_sentRequests.size());
  932. while (!_sentRequests.empty()) {
  933. const auto requestId = _sentRequests.begin()->first;
  934. api().request(requestId).cancel();
  935. resendRequests.push_back(finishSentRequest(
  936. requestId,
  937. FinishRequestReason::Redirect));
  938. }
  939. for (const auto &requestData : resendRequests) {
  940. makeRequest(requestData);
  941. }
  942. }
  943. makeRequest(requestData);
  944. }
  945. } // namespace Storage