@@ -533,8 +533,12 @@ impl Processor {
533533 let token_a = Self :: unpack_token_account ( token_a_info, & token_swap. token_program_id ( ) ) ?;
534534 let token_b = Self :: unpack_token_account ( token_b_info, & token_swap. token_program_id ( ) ) ?;
535535 let pool_mint = Self :: unpack_mint ( pool_mint_info, & token_swap. token_program_id ( ) ) ?;
536- let pool_token_amount = to_u128 ( pool_token_amount) ?;
537- let pool_mint_supply = to_u128 ( pool_mint. supply ) ?;
536+ let current_pool_mint_supply = to_u128 ( pool_mint. supply ) ?;
537+ let ( pool_token_amount, pool_mint_supply) = if current_pool_mint_supply > 0 {
538+ ( to_u128 ( pool_token_amount) ?, current_pool_mint_supply)
539+ } else {
540+ ( calculator. new_pool_supply ( ) , calculator. new_pool_supply ( ) )
541+ } ;
538542
539543 let results = calculator
540544 . pool_tokens_to_trading_tokens (
@@ -658,13 +662,15 @@ impl Processor {
658662 )
659663 . ok_or ( SwapError :: ZeroTradingTokens ) ?;
660664 let token_a_amount = to_u64 ( results. token_a_amount ) ?;
665+ let token_a_amount = std:: cmp:: min ( token_a. amount , token_a_amount) ;
661666 if token_a_amount < minimum_token_a_amount {
662667 return Err ( SwapError :: ExceededSlippage . into ( ) ) ;
663668 }
664669 if token_a_amount == 0 && token_a. amount != 0 {
665670 return Err ( SwapError :: ZeroTradingTokens . into ( ) ) ;
666671 }
667672 let token_b_amount = to_u64 ( results. token_b_amount ) ?;
673+ let token_b_amount = std:: cmp:: min ( token_b. amount , token_b_amount) ;
668674 if token_b_amount < minimum_token_b_amount {
669675 return Err ( SwapError :: ExceededSlippage . into ( ) ) ;
670676 }
@@ -693,7 +699,6 @@ impl Processor {
693699 to_u64 ( pool_token_amount) ?,
694700 ) ?;
695701
696- let token_a_amount = std:: cmp:: min ( token_a. amount , token_a_amount) ;
697702 if token_a_amount > 0 {
698703 Self :: token_transfer (
699704 swap_info. key ,
@@ -705,7 +710,6 @@ impl Processor {
705710 token_a_amount,
706711 ) ?;
707712 }
708- let token_b_amount = std:: cmp:: min ( token_b. amount , token_b_amount) ;
709713 if token_b_amount > 0 {
710714 Self :: token_transfer (
711715 swap_info. key ,
@@ -775,19 +779,22 @@ impl Processor {
775779
776780 let pool_mint = Self :: unpack_mint ( pool_mint_info, & token_swap. token_program_id ( ) ) ?;
777781 let pool_mint_supply = to_u128 ( pool_mint. supply ) ?;
778-
779- let pool_token_amount = token_swap
780- . swap_curve ( )
781- . trading_tokens_to_pool_tokens (
782- to_u128 ( source_token_amount) ?,
783- to_u128 ( swap_token_a. amount ) ?,
784- to_u128 ( swap_token_b. amount ) ?,
785- pool_mint_supply,
786- trade_direction,
787- RoundDirection :: Floor ,
788- token_swap. fees ( ) ,
789- )
790- . ok_or ( SwapError :: ZeroTradingTokens ) ?;
782+ let pool_token_amount = if pool_mint_supply > 0 {
783+ token_swap
784+ . swap_curve ( )
785+ . trading_tokens_to_pool_tokens (
786+ to_u128 ( source_token_amount) ?,
787+ to_u128 ( swap_token_a. amount ) ?,
788+ to_u128 ( swap_token_b. amount ) ?,
789+ pool_mint_supply,
790+ trade_direction,
791+ RoundDirection :: Floor ,
792+ token_swap. fees ( ) ,
793+ )
794+ . ok_or ( SwapError :: ZeroTradingTokens ) ?
795+ } else {
796+ token_swap. swap_curve ( ) . calculator . new_pool_supply ( )
797+ } ;
791798
792799 let pool_token_amount = to_u64 ( pool_token_amount) ?;
793800 if pool_token_amount < minimum_pool_token_amount {
@@ -6730,4 +6737,176 @@ mod tests {
67306737 spl_token:: state:: Account :: unpack ( & accounts. token_b_account . data ) . unwrap ( ) ;
67316738 assert_eq ! ( swap_token_b. amount, 0 ) ;
67326739 }
6740+
6741+ #[ test]
6742+ fn test_withdraw_all_constant_price_curve ( ) {
6743+ let trade_fee_numerator = 1 ;
6744+ let trade_fee_denominator = 10 ;
6745+ let owner_trade_fee_numerator = 1 ;
6746+ let owner_trade_fee_denominator = 30 ;
6747+ let owner_withdraw_fee_numerator = 0 ;
6748+ let owner_withdraw_fee_denominator = 30 ;
6749+ let host_fee_numerator = 10 ;
6750+ let host_fee_denominator = 100 ;
6751+
6752+ // initialize "unbalanced", so that withdrawing all will have some issues
6753+ // A: 1_000_000_000
6754+ // B: 2_000_000_000 (1_000 * 2_000_000)
6755+ let swap_token_a_amount = 1_000_000_000 ;
6756+ let swap_token_b_amount = 1_000 ;
6757+ let token_b_price = 2_000_000 ;
6758+ let fees = Fees {
6759+ trade_fee_numerator,
6760+ trade_fee_denominator,
6761+ owner_trade_fee_numerator,
6762+ owner_trade_fee_denominator,
6763+ owner_withdraw_fee_numerator,
6764+ owner_withdraw_fee_denominator,
6765+ host_fee_numerator,
6766+ host_fee_denominator,
6767+ } ;
6768+
6769+ let swap_curve = SwapCurve {
6770+ curve_type : CurveType :: ConstantPrice ,
6771+ calculator : Box :: new ( ConstantPriceCurve { token_b_price } ) ,
6772+ } ;
6773+ let total_pool = swap_curve. calculator . new_pool_supply ( ) ;
6774+ let user_key = Pubkey :: new_unique ( ) ;
6775+ let withdrawer_key = Pubkey :: new_unique ( ) ;
6776+
6777+ let mut accounts = SwapAccountInfo :: new (
6778+ & user_key,
6779+ fees,
6780+ swap_curve,
6781+ swap_token_a_amount,
6782+ swap_token_b_amount,
6783+ ) ;
6784+
6785+ accounts. initialize_swap ( ) . unwrap ( ) ;
6786+
6787+ let (
6788+ token_a_key,
6789+ mut token_a_account,
6790+ token_b_key,
6791+ mut token_b_account,
6792+ _pool_key,
6793+ _pool_account,
6794+ ) = accounts. setup_token_accounts ( & user_key, & withdrawer_key, 0 , 0 , 0 ) ;
6795+
6796+ let pool_key = accounts. pool_token_key ;
6797+ let mut pool_account = accounts. pool_token_account . clone ( ) ;
6798+
6799+ // WithdrawAllTokenTypes will not take all token A and B, since their
6800+ // ratio is unbalanced. It will try to take 1_500_000_000 worth of
6801+ // each token, which means 1_500_000_000 token A, and 750 token B.
6802+ // With no slippage, this will leave 250 token B in the pool.
6803+ assert_eq ! (
6804+ Err ( SwapError :: ExceededSlippage . into( ) ) ,
6805+ accounts. withdraw_all_token_types(
6806+ & user_key,
6807+ & pool_key,
6808+ & mut pool_account,
6809+ & token_a_key,
6810+ & mut token_a_account,
6811+ & token_b_key,
6812+ & mut token_b_account,
6813+ total_pool. try_into( ) . unwrap( ) ,
6814+ swap_token_a_amount,
6815+ swap_token_b_amount,
6816+ )
6817+ ) ;
6818+
6819+ accounts
6820+ . withdraw_all_token_types (
6821+ & user_key,
6822+ & pool_key,
6823+ & mut pool_account,
6824+ & token_a_key,
6825+ & mut token_a_account,
6826+ & token_b_key,
6827+ & mut token_b_account,
6828+ total_pool. try_into ( ) . unwrap ( ) ,
6829+ 0 ,
6830+ 0 ,
6831+ )
6832+ . unwrap ( ) ;
6833+
6834+ let token_a = spl_token:: state:: Account :: unpack ( & token_a_account. data ) . unwrap ( ) ;
6835+ assert_eq ! ( token_a. amount, swap_token_a_amount) ;
6836+ let token_b = spl_token:: state:: Account :: unpack ( & token_b_account. data ) . unwrap ( ) ;
6837+ assert_eq ! ( token_b. amount, 750 ) ;
6838+ let swap_token_a =
6839+ spl_token:: state:: Account :: unpack ( & accounts. token_a_account . data ) . unwrap ( ) ;
6840+ assert_eq ! ( swap_token_a. amount, 0 ) ;
6841+ let swap_token_b =
6842+ spl_token:: state:: Account :: unpack ( & accounts. token_b_account . data ) . unwrap ( ) ;
6843+ assert_eq ! ( swap_token_b. amount, 250 ) ;
6844+
6845+ // deposit now, not enough to cover the tokens already in there
6846+ let token_b_amount = 10 ;
6847+ let token_a_amount = token_b_amount * token_b_price;
6848+ let (
6849+ token_a_key,
6850+ mut token_a_account,
6851+ token_b_key,
6852+ mut token_b_account,
6853+ pool_key,
6854+ mut pool_account,
6855+ ) = accounts. setup_token_accounts (
6856+ & user_key,
6857+ & withdrawer_key,
6858+ token_a_amount,
6859+ token_b_amount,
6860+ 0 ,
6861+ ) ;
6862+
6863+ assert_eq ! (
6864+ Err ( SwapError :: ExceededSlippage . into( ) ) ,
6865+ accounts. deposit_all_token_types(
6866+ & withdrawer_key,
6867+ & token_a_key,
6868+ & mut token_a_account,
6869+ & token_b_key,
6870+ & mut token_b_account,
6871+ & pool_key,
6872+ & mut pool_account,
6873+ 1 , // doesn't matter
6874+ token_a_amount,
6875+ token_b_amount,
6876+ )
6877+ ) ;
6878+
6879+ // deposit enough tokens, success!
6880+ let token_b_amount = 125 ;
6881+ let token_a_amount = token_b_amount * token_b_price;
6882+ let (
6883+ token_a_key,
6884+ mut token_a_account,
6885+ token_b_key,
6886+ mut token_b_account,
6887+ pool_key,
6888+ mut pool_account,
6889+ ) = accounts. setup_token_accounts (
6890+ & user_key,
6891+ & withdrawer_key,
6892+ token_a_amount,
6893+ token_b_amount,
6894+ 0 ,
6895+ ) ;
6896+
6897+ accounts
6898+ . deposit_all_token_types (
6899+ & withdrawer_key,
6900+ & token_a_key,
6901+ & mut token_a_account,
6902+ & token_b_key,
6903+ & mut token_b_account,
6904+ & pool_key,
6905+ & mut pool_account,
6906+ 1 , // doesn't matter
6907+ token_a_amount,
6908+ token_b_amount,
6909+ )
6910+ . unwrap ( ) ;
6911+ }
67336912}
0 commit comments