Skip to content

Commit dd73357

Browse files
authored
Merge pull request #2624 from pyth-network/twap-multiple-price-feed-ids
feat(twap): update TWAP price feed logic to handle multiple price feeds and improve validation
2 parents 3dcf09e + a0a205b commit dd73357

File tree

7 files changed

+452
-107
lines changed

7 files changed

+452
-107
lines changed

target_chains/ethereum/contracts/contracts/pyth/Pyth.sol

Lines changed: 125 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,37 @@ abstract contract Pyth is
120120
return getTotalFee(totalNumUpdates);
121121
}
122122

123+
function getTwapUpdateFee(
124+
bytes[] calldata updateData
125+
) public view override returns (uint feeAmount) {
126+
uint totalNumUpdates = 0;
127+
// For TWAP updates, updateData is always length 2 (start and end points),
128+
// but each VAA can contain multiple price feeds. We only need to count
129+
// the number of updates in the first VAA since both VAAs will have the
130+
// same number of price feeds.
131+
if (
132+
updateData[0].length > 4 &&
133+
UnsafeCalldataBytesLib.toUint32(updateData[0], 0) ==
134+
ACCUMULATOR_MAGIC
135+
) {
136+
(
137+
uint offset,
138+
UpdateType updateType
139+
) = extractUpdateTypeFromAccumulatorHeader(updateData[0]);
140+
if (updateType != UpdateType.WormholeMerkle) {
141+
revert PythErrors.InvalidUpdateData();
142+
}
143+
totalNumUpdates += parseWormholeMerkleHeaderNumUpdates(
144+
updateData[0],
145+
offset
146+
);
147+
} else {
148+
revert PythErrors.InvalidUpdateData();
149+
}
150+
151+
return getTotalFee(totalNumUpdates);
152+
}
153+
123154
// This is an overwrite of the same method in AbstractPyth.sol
124155
// to be more gas efficient.
125156
function updatePriceFeedsIfNecessary(
@@ -372,18 +403,18 @@ abstract contract Pyth is
372403
);
373404
}
374405

375-
function processSingleTwapUpdate(
406+
function extractTwapPriceInfos(
376407
bytes calldata updateData
377408
)
378409
private
379410
view
380411
returns (
381412
/// @return newOffset The next position in the update data after processing this TWAP update
382-
/// @return twapPriceInfo The extracted time-weighted average price information
383-
/// @return priceId The unique identifier for this price feed
413+
/// @return priceInfos Array of extracted TWAP price information
414+
/// @return priceIds Array of corresponding price feed IDs
384415
uint newOffset,
385-
PythStructs.TwapPriceInfo memory twapPriceInfo,
386-
bytes32 priceId
416+
PythStructs.TwapPriceInfo[] memory twapPriceInfos,
417+
bytes32[] memory priceIds
387418
)
388419
{
389420
UpdateType updateType;
@@ -417,12 +448,22 @@ abstract contract Pyth is
417448
revert PythErrors.InvalidUpdateData();
418449
}
419450

420-
// Extract start TWAP data with robust error checking
421-
(offset, twapPriceInfo, priceId) = extractTwapPriceInfoFromMerkleProof(
422-
digest,
423-
encoded,
424-
offset
425-
);
451+
// Initialize arrays to store all price infos and ids from this update
452+
twapPriceInfos = new PythStructs.TwapPriceInfo[](numUpdates);
453+
priceIds = new bytes32[](numUpdates);
454+
455+
// Extract each TWAP price info from the merkle proof
456+
for (uint i = 0; i < numUpdates; i++) {
457+
PythStructs.TwapPriceInfo memory twapPriceInfo;
458+
bytes32 priceId;
459+
(
460+
offset,
461+
twapPriceInfo,
462+
priceId
463+
) = extractTwapPriceInfoFromMerkleProof(digest, encoded, offset);
464+
twapPriceInfos[i] = twapPriceInfo;
465+
priceIds[i] = priceId;
466+
}
426467

427468
if (offset != encoded.length) {
428469
revert PythErrors.InvalidTwapUpdateData();
@@ -439,71 +480,89 @@ abstract contract Pyth is
439480
override
440481
returns (PythStructs.TwapPriceFeed[] memory twapPriceFeeds)
441482
{
442-
// TWAP requires exactly 2 updates - one for the start point and one for the end point
443-
// to calculate the time-weighted average price between those two points
483+
// TWAP requires exactly 2 updates: one for the start point and one for the end point
444484
if (updateData.length != 2) {
445485
revert PythErrors.InvalidUpdateData();
446486
}
447487

448-
uint requiredFee = getUpdateFee(updateData);
488+
uint requiredFee = getTwapUpdateFee(updateData);
449489
if (msg.value < requiredFee) revert PythErrors.InsufficientFee();
450490

451-
unchecked {
452-
twapPriceFeeds = new PythStructs.TwapPriceFeed[](priceIds.length);
453-
for (uint i = 0; i < updateData.length - 1; i++) {
454-
if (
455-
(updateData[i].length > 4 &&
456-
UnsafeCalldataBytesLib.toUint32(updateData[i], 0) ==
457-
ACCUMULATOR_MAGIC) &&
458-
(updateData[i + 1].length > 4 &&
459-
UnsafeCalldataBytesLib.toUint32(updateData[i + 1], 0) ==
460-
ACCUMULATOR_MAGIC)
461-
) {
462-
uint offsetStart;
463-
uint offsetEnd;
464-
bytes32 priceIdStart;
465-
bytes32 priceIdEnd;
466-
PythStructs.TwapPriceInfo memory twapPriceInfoStart;
467-
PythStructs.TwapPriceInfo memory twapPriceInfoEnd;
468-
(
469-
offsetStart,
470-
twapPriceInfoStart,
471-
priceIdStart
472-
) = processSingleTwapUpdate(updateData[i]);
473-
(
474-
offsetEnd,
475-
twapPriceInfoEnd,
476-
priceIdEnd
477-
) = processSingleTwapUpdate(updateData[i + 1]);
478-
479-
if (priceIdStart != priceIdEnd)
480-
revert PythErrors.InvalidTwapUpdateDataSet();
481-
482-
validateTwapPriceInfo(twapPriceInfoStart, twapPriceInfoEnd);
483-
484-
uint k = findIndexOfPriceId(priceIds, priceIdStart);
485-
486-
// If priceFeed[k].id != 0 then it means that there was a valid
487-
// update for priceIds[k] and we don't need to process this one.
488-
if (k == priceIds.length || twapPriceFeeds[k].id != 0) {
489-
continue;
490-
}
491-
492-
twapPriceFeeds[k] = calculateTwap(
493-
priceIdStart,
494-
twapPriceInfoStart,
495-
twapPriceInfoEnd
496-
);
497-
} else {
498-
revert PythErrors.InvalidUpdateData();
499-
}
491+
// Process start update data
492+
PythStructs.TwapPriceInfo[] memory startTwapPriceInfos;
493+
bytes32[] memory startPriceIds;
494+
{
495+
uint offsetStart;
496+
(
497+
offsetStart,
498+
startTwapPriceInfos,
499+
startPriceIds
500+
) = extractTwapPriceInfos(updateData[0]);
501+
}
502+
503+
// Process end update data
504+
PythStructs.TwapPriceInfo[] memory endTwapPriceInfos;
505+
bytes32[] memory endPriceIds;
506+
{
507+
uint offsetEnd;
508+
(offsetEnd, endTwapPriceInfos, endPriceIds) = extractTwapPriceInfos(
509+
updateData[1]
510+
);
511+
}
512+
513+
// Verify that we have the same number of price feeds in start and end updates
514+
if (startPriceIds.length != endPriceIds.length) {
515+
revert PythErrors.InvalidTwapUpdateDataSet();
516+
}
517+
518+
// Hermes always returns price feeds in the same order for start and end updates
519+
// This allows us to assume startPriceIds[i] == endPriceIds[i] for efficiency
520+
for (uint i = 0; i < startPriceIds.length; i++) {
521+
if (startPriceIds[i] != endPriceIds[i]) {
522+
revert PythErrors.InvalidTwapUpdateDataSet();
500523
}
524+
}
501525

502-
for (uint k = 0; k < priceIds.length; k++) {
503-
if (twapPriceFeeds[k].id == 0) {
504-
revert PythErrors.PriceFeedNotFoundWithinRange();
526+
// Initialize the output array
527+
twapPriceFeeds = new PythStructs.TwapPriceFeed[](priceIds.length);
528+
529+
// For each requested price ID, find matching start and end data points
530+
for (uint i = 0; i < priceIds.length; i++) {
531+
bytes32 requestedPriceId = priceIds[i];
532+
int startIdx = -1;
533+
534+
// Find the index of this price ID in the startPriceIds array
535+
// (which is the same as the endPriceIds array based on our validation above)
536+
for (uint j = 0; j < startPriceIds.length; j++) {
537+
if (startPriceIds[j] == requestedPriceId) {
538+
startIdx = int(j);
539+
break;
505540
}
506541
}
542+
543+
// If we found the price ID
544+
if (startIdx >= 0) {
545+
uint idx = uint(startIdx);
546+
// Validate the pair of price infos
547+
validateTwapPriceInfo(
548+
startTwapPriceInfos[idx],
549+
endTwapPriceInfos[idx]
550+
);
551+
552+
// Calculate TWAP from these data points
553+
twapPriceFeeds[i] = calculateTwap(
554+
requestedPriceId,
555+
startTwapPriceInfos[idx],
556+
endTwapPriceInfos[idx]
557+
);
558+
}
559+
}
560+
561+
// Ensure all requested price IDs were found
562+
for (uint k = 0; k < priceIds.length; k++) {
563+
if (twapPriceFeeds[k].id == 0) {
564+
revert PythErrors.PriceFeedNotFoundWithinRange();
565+
}
507566
}
508567
}
509568

0 commit comments

Comments
 (0)