From 4285ddbd7bd45a01655e3498e29a158f36e8e8f8 Mon Sep 17 00:00:00 2001 From: xtinkt Date: Tue, 2 Jul 2024 16:54:50 +0000 Subject: [PATCH 1/9] fix --- src/petals/client/inference_session.py | 21 ++++++++++++--- src/petals/server/.handler.py.swp | Bin 0 -> 20480 bytes src/petals/server/block_functions.py | 5 ++++ src/petals/server/handler.py | 1 + tests/test_speculative_generation.py | 35 +++++++++++++++++++++++++ 5 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 src/petals/server/.handler.py.swp create mode 100644 tests/test_speculative_generation.py diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 0938df207..7ccb28df2 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -84,7 +84,8 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[ break # this message means "done sending" def step( - self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str + self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, + step_id: str, last_validated_position: int ) -> torch.Tensor: """ Inference step: send a chunk of input tensors and receive a chunk of outputs @@ -94,6 +95,12 @@ def step( if self.closed: raise Exception("Session is closed, cannot perform step") + if last_validated_position is not None: + assert last_validated_position <= self._position + self._position = last_validated_position + if self.history is not None and self.history.shape[1] >= last_validated_position: + self.history = self.history[:, :last_validated_position, :] if last_validated_position > 0 else None + n_input_tokens = inputs.shape[1] if self.history is None: self.history = inputs @@ -115,6 +122,8 @@ def step( request_metadata = dict(session_id=self.session_id, step_id=step_id) if not self.stepped: request_metadata.update(self.session_metadata) + if last_validated_position is not None: + request_metadata["last_validated_position"] = last_validated_position elif self.config.use_server_to_server: next_servers = self._collect_next_servers() if next_servers: @@ -257,8 +266,13 @@ def __enter__(self) -> "InferenceSession": return self def step( - self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None + self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, + hypo_ids: Optional[torch.Tensor] = None, last_validated_position: Optional[int] = None ) -> torch.Tensor: + + if last_validated_position is not None: + self._position = last_validated_position + assert not self._closed if torch.is_grad_enabled(): logger.warning("Running inference session with grad enabled. 
Gradients will *not* be propagated correctly.") @@ -303,7 +317,8 @@ def step( server_session = self._server_sessions[server_idx] inputs = server_session.step( - inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id + inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, + step_id=step_id, last_validated_position=last_validated_position ) server_idx += 1 diff --git a/src/petals/server/.handler.py.swp b/src/petals/server/.handler.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..9c64ed739095b54ac3798b4ff034083f8ef61941 GIT binary patch literal 20480 zcmeHOYltOB74G=T!$&0XhxwtY+@RaD-tF1lh=~_QGugeH9kV;x?93*b$FA;b>BOi z_)n!9zUkZbI8}A(oKtm9)odKTbLljDvvziVdgL&by}olDyZWkg&snW4&~l*VK+A!a11$$y4zwI-InZ*T<-q?B z2lDhX=e-!t3-pje{l28}epWwAI{n^8{)_bY3H|-1Z62TVWy zq4N#>4>$7vuJcdn`af&rKcn-d^Mf+i=>Z%1x?1~cInZ*TeZvvkM9su@$0Js9U9Qfxe9OrT1 zJHR)9&jSwt9{|R{+kiI#uLCXzE(N~*a>w}~umKzZo_U$$JPJGlJPh0e900!iQpZVv zw*U*kwZLnE*8q=SfwI6?fv*62z`KD9z!+EsZU$}uI>2=R1Ac$G<9rCXA2&1@q ze&AN%YTzp17YLSo3Ah_L1Y8DO3ViQ{j`JXJ8u%OH6ORIq0AB>|2krr04g3*vMf3Ij zdhmi-PdH+NMOmn`Qa`2)0s2B3nEKIU658uTz6eF(LN%g{l3lS!ze03Gi!SK|G97s$T zvhwTnG=Hqhjg)%vFdp%Y#{=$9g7bdFTEF$i21cNOYmo8h)f>=PgCOo@ErB!;5WL7Va*;GmdR87<*XrDY=Gj2$L$ZIXIQF)C`+ zf z5p&DAngda15J6#Vl{l8fF&~h1Q4MD9R0=Jya)@Suhxpwf6V9AlIf;*>@}s|e++XRh ztSsHWjI<*%ZMlD!m3CM+->2<@+mDNhBzG1sf6XpgUq$c#6GZt&lv7zt2_K`i$-qGYhe!bzIIMwQxLs{r$bqFz()O{H0S z4#AnW3A2dG=KWC-%Lx{SW(ccGo0%*fz;x*~u}f@|Nk*os6fJHun@}dzke*hbTBTu- z2R`(&801C9%MKkoxwy3KAMc-9I=!^oKig%{4*OGyW|x(!<7k$eJI8Vc#@OLbutMoM zKr=kR_onHaNF$hOjD>ojvQ+096={O4YSo<2EUi~$(?~E0QZUu$=lP(>6KoHHn4ixZ zRizl3Dz>=R2ZGJ3M;q8-z$Zy;${+8aSUh)X)jz%XcK`Ist^Uf=-TemrNkUs|DOeCi zt>9)KZ0JlWx}z3GR1D!Ryo)q18;pJIhXn>A6*|Utlw((oS+6CGq98{(l{VxTBT>th zr1gBL%4aywU>lWvIeG@2%)vkPjnSS#%PXdXx>M33*EnZ}=qkR*IVyLFT6VOPBtKwDDh^F5?f$wZt^8Y^tKQ)67Oa8w(AO8q^`ws9I`0IZLehxead>tU){$9Wbdcf7d zi-0HK%YPEM75EMO_n!b?0UiP-Km@!BcpCouqrihe3Tyz2z|Fu<;iG>DxF5I%cnR=J z_~l;!jsh-l0Jsu(4F350flc5^^q>6l#{e43eL(xP9B4Vva-iiv%YpwM4lw;>9IJ9{ zyDVP+Ji?dxQanl!Y1CV>A#K!?g{cjYH!LQpaIu%h<^ZurxoxUwQ>s^_*miSQIRiwd zihqGg;|?VA@qk5K9tli|-;gO~Y+QY+uxG_~sab?Ft=xKC+RA`CCwQ(0=n)}=SlR1kVl^U@;g|u@yULbC5C?BK zZ68cJFq?JoFfF~O|W~`&CgTO<>vuM^?ih-N)y{NSaxTYm7$_E%Jd}+W)HHd1E4~L0S@w zaK%}r925L!DS3`8*ubd-0zinA57qQCOr4Oyk3o?kwu?x65)3KQK|PmSU>OQG@++H% zW2*w|y%~BZ12cLz-D77;4`hzQwmMqSu-@gd(R zHlcMFbd^1*k4&=_heu$pOhv?#VgGK${i$hsG(ALxVt5G+iaOP%35q&(Hcj0hOug(} zK2C;v%9x0S3nTn5vwaaN{I4TpPk2;RtE;!jz(%tjj>&+|I$ z4HV^b^(=U0{@mK3H}!OgY>JAM1Q3Yl12y|pVw&A$tGYZ53I|}d=?__D~|N&tzgyCIK;Fi&Mz!&Q-M z1PBqR7-V4zYh#@hDwDE<|6DgVQLG(aubcQFLfDT2Y|^5;2c?_u9<=FonGQcJn6p-L z=XKK+^Q|h-jWx4U=%q>@W~p_l@~BHmp0VVR3H?*7OaC8kzjL!b z|39HON?(TW|1n@4@PNO6Mu4*vW@z-izVK)(IMz=wea;0oX|#4NrE+y&eT zJcV=h_W>t>K5z~26wcIt4g3%o05<}!0se&Z^>+ie0apOO#hLm?fp-Ac051f-1%7-Q z*a5bIWq^3J0@x2p55<*N+V zH~Azpma0+_^JKGjbmy=Pr-=4r^^|8XSYqVcUJeX_Y0GJH=sgK`H+cPoldb~zaY85X zZW`>;I$SM#c3oJX8P%takzHq>HqCyd-V;H4Ijr>Ruo))TX_Aubs2n<5SYYxz!lT2C zO7rX*a${9Cp_DWeUYev^J>aED)~=ZT&spe23RwhvFl;rafr*3k=%&u4&rx^ z(>((^L&7gz`FVBFI=7kK!3_k^x28y8W*ba3M@G|BWI;NGF|*Vy2ZZxrd5N*YJqH0|}d&a>UB<(+XDah!b>mda2QV5N?gN*_RW45vJpT}8}NWpymy zyz@bwv2UShpH5#!pDL#%6wLGnT#z zxgTjE#xxg2yag*sw+y7@v!&xO!xfKhTMT`wccFIChuW)W8;LN~bna|>C=aZO1=yG- z&u8K<>Ne?Fx7C39XC5hj)s-=`PU!iy>|uSH>p~+1+?sMj9Bfkbl^MLwu%X7B`aIDd z+^W;nA5utYyG+M1a@CeP;F+Vd+7%j%)%1YgLmp{`Y_i{0+gg=b653c+=4^E!+c!t{ z9@8d0KPD1o+$jJ9C6@QE*iINl(m4bJHp>P_AubR_mOZ3S%k6Jjgi_3N@fougi2vOg ztFzxIi#=QA{-P#Ij^Po#>qMrljJ3MlLB`(-FV$A3((s%0P|hGkaaIFe_~OQXz>~Di 
zhuKFJafr?+OD|P@62yo^42`DmW89TY-&`Fh(p?5RSya`t=}oKZpY36FY8ML-j!FdY zl-2`XApfq5@Sh}Ht$oDKQc4~An)N7;ZPvOYQk-RSSDjaj1;mQTyy?$+Gfl#D{X*Us z9YV$6w?z>-M=~_GnI%O$ls7+xHw?u9*KICNjH=dZ;UcC>ElAzHjq?<76736MIlzq( zhFhz!?bJ9e5hB>;lqAozZ$0Kzo z)uy|bCjK_V^qN*-f(Jn?*SNY>GzA*Q(j?QHuA^3$LKfgi=VUC( IzAEVa8%W{PfB*mh literal 0 HcmV?d00001 diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index a79f05c94..6636f00b9 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -160,6 +160,11 @@ async def iterate_rpc_inference( point_per_piece = points / max_length if max_length > 0 else 0.0 async for request, step_metadata in input_iterator: + if "last_validated_position" in step_metadata: + last_validated_position = min(step_metadata["last_validated_position"], prefix_length) + assert prefix_length >= last_validated_position, f"prefix_length={prefix_length}, last_validated_position={last_validated_position}" + prefix_length = last_validated_position + flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors) if args_structure is not None: # TODO: kwargs currently is unused, it can be used later for peft-like adaptation diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 246565643..6b61f1ee2 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -150,6 +150,7 @@ async def rpc_inference( max_length = metadata.get("max_length") points = metadata.get("points", 0) session_id = metadata.get("session_id") + last_validated_position = metadata.get("last_validated_position", None) alloc_timeout = float(metadata.get("alloc_timeout", 0.0)) args_structure = metadata.get("args_structure") if not requested_uids: diff --git a/tests/test_speculative_generation.py b/tests/test_speculative_generation.py new file mode 100644 index 000000000..9c7f4a075 --- /dev/null +++ b/tests/test_speculative_generation.py @@ -0,0 +1,35 @@ +import random + +import pytest +import torch + +from petals import AutoDistributedConfig, RemoteSequential +from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS +from petals.server.from_pretrained import load_pretrained_block +from test_utils import * + + +@pytest.mark.forked +def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3): + config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) + remote_sequential = RemoteSequential(config) + + block_index = random.randint(0, config.num_hidden_layers - 1) + remote_block = remote_sequential[block_index] + + inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size) + short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size) + short_inputs[:, :2, :] = inputs[:, :2, :] + + initial_outputs_inference = None + secondary_outputs_inference = None + with torch.inference_mode(): + with remote_block.inference_session(max_length=inputs.shape[1]) as sess: + initial_outputs_inference = sess.step(inputs) + secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], last_validated_position=2) + result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1) + + ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) + (outputs_local,) = ref_block(short_inputs) + + assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference) From b6801af79d920aff0fe43438a0a6282e221e4ca2 Mon Sep 17 00:00:00 2001 From: xtinkt Date: Tue, 2 Jul 
2024 17:00:21 +0000 Subject: [PATCH 2/9] fix --- src/petals/server/handler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 6b61f1ee2..246565643 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -150,7 +150,6 @@ async def rpc_inference( max_length = metadata.get("max_length") points = metadata.get("points", 0) session_id = metadata.get("session_id") - last_validated_position = metadata.get("last_validated_position", None) alloc_timeout = float(metadata.get("alloc_timeout", 0.0)) args_structure = metadata.get("args_structure") if not requested_uids: From de2c38ab1bba7255dd493a6319e76bad6e89f900 Mon Sep 17 00:00:00 2001 From: xtinkt Date: Tue, 2 Jul 2024 17:13:59 +0000 Subject: [PATCH 3/9] fix --- src/petals/server/block_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index 6636f00b9..c88c8e5a6 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -161,7 +161,6 @@ async def iterate_rpc_inference( async for request, step_metadata in input_iterator: if "last_validated_position" in step_metadata: - last_validated_position = min(step_metadata["last_validated_position"], prefix_length) assert prefix_length >= last_validated_position, f"prefix_length={prefix_length}, last_validated_position={last_validated_position}" prefix_length = last_validated_position From 10de34b72a91c771b439bb0f597abab59aa5f12c Mon Sep 17 00:00:00 2001 From: xtinkt Date: Tue, 2 Jul 2024 17:24:38 +0000 Subject: [PATCH 4/9] fix --- src/petals/server/block_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index c88c8e5a6..6b329ed50 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -161,6 +161,7 @@ async def iterate_rpc_inference( async for request, step_metadata in input_iterator: if "last_validated_position" in step_metadata: + last_validated_position = step_metadata["last_validated_position"] assert prefix_length >= last_validated_position, f"prefix_length={prefix_length}, last_validated_position={last_validated_position}" prefix_length = last_validated_position From 7565ddee81f0c6370aa0c92ebac8b67e64f556f3 Mon Sep 17 00:00:00 2001 From: xtinkt Date: Fri, 5 Jul 2024 14:37:19 +0000 Subject: [PATCH 5/9] fix --- src/petals/server/.handler.py.swp | Bin 20480 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/petals/server/.handler.py.swp diff --git a/src/petals/server/.handler.py.swp b/src/petals/server/.handler.py.swp deleted file mode 100644 index 9c64ed739095b54ac3798b4ff034083f8ef61941..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20480 zcmeHOYltOB74G=T!$&0XhxwtY+@RaD-tF1lh=~_QGugeH9kV;x?93*b$FA;b>BOi z_)n!9zUkZbI8}A(oKtm9)odKTbLljDvvziVdgL&by}olDyZWkg&snW4&~l*VK+A!a11$$y4zwI-InZ*T<-q?B z2lDhX=e-!t3-pje{l28}epWwAI{n^8{)_bY3H|-1Z62TVWy zq4N#>4>$7vuJcdn`af&rKcn-d^Mf+i=>Z%1x?1~cInZ*TeZvvkM9su@$0Js9U9Qfxe9OrT1 zJHR)9&jSwt9{|R{+kiI#uLCXzE(N~*a>w}~umKzZo_U$$JPJGlJPh0e900!iQpZVv zw*U*kwZLnE*8q=SfwI6?fv*62z`KD9z!+EsZU$}uI>2=R1Ac$G<9rCXA2&1@q ze&AN%YTzp17YLSo3Ah_L1Y8DO3ViQ{j`JXJ8u%OH6ORIq0AB>|2krr04g3*vMf3Ij zdhmi-PdH+NMOmn`Qa`2)0s2B3nEKIU658uTz6eF(LN%g{l3lS!ze03Gi!SK|G97s$T zvhwTnG=Hqhjg)%vFdp%Y#{=$9g7bdFTEF$i21cNOYmo8h)f>=PgCOo@ErB!;5WL7Va*;GmdR87<*XrDY=Gj2$L$ZIXIQF)C`+ zf 
z5p&DAngda15J6#Vl{l8fF&~h1Q4MD9R0=Jya)@Suhxpwf6V9AlIf;*>@}s|e++XRh ztSsHWjI<*%ZMlD!m3CM+->2<@+mDNhBzG1sf6XpgUq$c#6GZt&lv7zt2_K`i$-qGYhe!bzIIMwQxLs{r$bqFz()O{H0S z4#AnW3A2dG=KWC-%Lx{SW(ccGo0%*fz;x*~u}f@|Nk*os6fJHun@}dzke*hbTBTu- z2R`(&801C9%MKkoxwy3KAMc-9I=!^oKig%{4*OGyW|x(!<7k$eJI8Vc#@OLbutMoM zKr=kR_onHaNF$hOjD>ojvQ+096={O4YSo<2EUi~$(?~E0QZUu$=lP(>6KoHHn4ixZ zRizl3Dz>=R2ZGJ3M;q8-z$Zy;${+8aSUh)X)jz%XcK`Ist^Uf=-TemrNkUs|DOeCi zt>9)KZ0JlWx}z3GR1D!Ryo)q18;pJIhXn>A6*|Utlw((oS+6CGq98{(l{VxTBT>th zr1gBL%4aywU>lWvIeG@2%)vkPjnSS#%PXdXx>M33*EnZ}=qkR*IVyLFT6VOPBtKwDDh^F5?f$wZt^8Y^tKQ)67Oa8w(AO8q^`ws9I`0IZLehxead>tU){$9Wbdcf7d zi-0HK%YPEM75EMO_n!b?0UiP-Km@!BcpCouqrihe3Tyz2z|Fu<;iG>DxF5I%cnR=J z_~l;!jsh-l0Jsu(4F350flc5^^q>6l#{e43eL(xP9B4Vva-iiv%YpwM4lw;>9IJ9{ zyDVP+Ji?dxQanl!Y1CV>A#K!?g{cjYH!LQpaIu%h<^ZurxoxUwQ>s^_*miSQIRiwd zihqGg;|?VA@qk5K9tli|-;gO~Y+QY+uxG_~sab?Ft=xKC+RA`CCwQ(0=n)}=SlR1kVl^U@;g|u@yULbC5C?BK zZ68cJFq?JoFfF~O|W~`&CgTO<>vuM^?ih-N)y{NSaxTYm7$_E%Jd}+W)HHd1E4~L0S@w zaK%}r925L!DS3`8*ubd-0zinA57qQCOr4Oyk3o?kwu?x65)3KQK|PmSU>OQG@++H% zW2*w|y%~BZ12cLz-D77;4`hzQwmMqSu-@gd(R zHlcMFbd^1*k4&=_heu$pOhv?#VgGK${i$hsG(ALxVt5G+iaOP%35q&(Hcj0hOug(} zK2C;v%9x0S3nTn5vwaaN{I4TpPk2;RtE;!jz(%tjj>&+|I$ z4HV^b^(=U0{@mK3H}!OgY>JAM1Q3Yl12y|pVw&A$tGYZ53I|}d=?__D~|N&tzgyCIK;Fi&Mz!&Q-M z1PBqR7-V4zYh#@hDwDE<|6DgVQLG(aubcQFLfDT2Y|^5;2c?_u9<=FonGQcJn6p-L z=XKK+^Q|h-jWx4U=%q>@W~p_l@~BHmp0VVR3H?*7OaC8kzjL!b z|39HON?(TW|1n@4@PNO6Mu4*vW@z-izVK)(IMz=wea;0oX|#4NrE+y&eT zJcV=h_W>t>K5z~26wcIt4g3%o05<}!0se&Z^>+ie0apOO#hLm?fp-Ac051f-1%7-Q z*a5bIWq^3J0@x2p55<*N+V zH~Azpma0+_^JKGjbmy=Pr-=4r^^|8XSYqVcUJeX_Y0GJH=sgK`H+cPoldb~zaY85X zZW`>;I$SM#c3oJX8P%takzHq>HqCyd-V;H4Ijr>Ruo))TX_Aubs2n<5SYYxz!lT2C zO7rX*a${9Cp_DWeUYev^J>aED)~=ZT&spe23RwhvFl;rafr*3k=%&u4&rx^ z(>((^L&7gz`FVBFI=7kK!3_k^x28y8W*ba3M@G|BWI;NGF|*Vy2ZZxrd5N*YJqH0|}d&a>UB<(+XDah!b>mda2QV5N?gN*_RW45vJpT}8}NWpymy zyz@bwv2UShpH5#!pDL#%6wLGnT#z zxgTjE#xxg2yag*sw+y7@v!&xO!xfKhTMT`wccFIChuW)W8;LN~bna|>C=aZO1=yG- z&u8K<>Ne?Fx7C39XC5hj)s-=`PU!iy>|uSH>p~+1+?sMj9Bfkbl^MLwu%X7B`aIDd z+^W;nA5utYyG+M1a@CeP;F+Vd+7%j%)%1YgLmp{`Y_i{0+gg=b653c+=4^E!+c!t{ z9@8d0KPD1o+$jJ9C6@QE*iINl(m4bJHp>P_AubR_mOZ3S%k6Jjgi_3N@fougi2vOg ztFzxIi#=QA{-P#Ij^Po#>qMrljJ3MlLB`(-FV$A3((s%0P|hGkaaIFe_~OQXz>~Di zhuKFJafr?+OD|P@62yo^42`DmW89TY-&`Fh(p?5RSya`t=}oKZpY36FY8ML-j!FdY zl-2`XApfq5@Sh}Ht$oDKQc4~An)N7;ZPvOYQk-RSSDjaj1;mQTyy?$+Gfl#D{X*Us z9YV$6w?z>-M=~_GnI%O$ls7+xHw?u9*KICNjH=dZ;UcC>ElAzHjq?<76736MIlzq( zhFhz!?bJ9e5hB>;lqAozZ$0Kzo z)uy|bCjK_V^qN*-f(Jn?*SNY>GzA*Q(j?QHuA^3$LKfgi=VUC( IzAEVa8%W{PfB*mh From 269028d0e67ff3a0475753e334a1f4c89ea823f3 Mon Sep 17 00:00:00 2001 From: xtinkt Date: Fri, 5 Jul 2024 14:43:09 +0000 Subject: [PATCH 6/9] fix --- src/petals/client/inference_session.py | 24 ++++++++++++------------ src/petals/server/block_functions.py | 8 ++++---- tests/test_speculative_generation.py | 2 +- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 7ccb28df2..c52df709a 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -85,7 +85,7 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[ def step( self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, - step_id: str, last_validated_position: int + step_id: str, start_from_position: int ) -> torch.Tensor: """ Inference step: send a chunk of input tensors and receive a chunk of outputs @@ -95,11 +95,11 @@ def step( if self.closed: raise Exception("Session 
is closed, cannot perform step") - if last_validated_position is not None: - assert last_validated_position <= self._position - self._position = last_validated_position - if self.history is not None and self.history.shape[1] >= last_validated_position: - self.history = self.history[:, :last_validated_position, :] if last_validated_position > 0 else None + if start_from_position is not None: + assert start_from_position <= self._position + self._position = start_from_position + if self.history is not None and self.history.shape[1] >= start_from_position: + self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None n_input_tokens = inputs.shape[1] if self.history is None: @@ -122,8 +122,8 @@ def step( request_metadata = dict(session_id=self.session_id, step_id=step_id) if not self.stepped: request_metadata.update(self.session_metadata) - if last_validated_position is not None: - request_metadata["last_validated_position"] = last_validated_position + if start_from_position is not None: + request_metadata["start_from_position"] = start_from_position elif self.config.use_server_to_server: next_servers = self._collect_next_servers() if next_servers: @@ -267,11 +267,11 @@ def __enter__(self) -> "InferenceSession": def step( self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, - hypo_ids: Optional[torch.Tensor] = None, last_validated_position: Optional[int] = None + hypo_ids: Optional[torch.Tensor] = None, start_from_position: Optional[int] = None ) -> torch.Tensor: - if last_validated_position is not None: - self._position = last_validated_position + if start_from_position is not None: + self._position = start_from_position assert not self._closed if torch.is_grad_enabled(): @@ -318,7 +318,7 @@ def step( server_session = self._server_sessions[server_idx] inputs = server_session.step( inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, - step_id=step_id, last_validated_position=last_validated_position + step_id=step_id, start_from_position=start_from_position ) server_idx += 1 diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index 6b329ed50..3127d6812 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -160,10 +160,10 @@ async def iterate_rpc_inference( point_per_piece = points / max_length if max_length > 0 else 0.0 async for request, step_metadata in input_iterator: - if "last_validated_position" in step_metadata: - last_validated_position = step_metadata["last_validated_position"] - assert prefix_length >= last_validated_position, f"prefix_length={prefix_length}, last_validated_position={last_validated_position}" - prefix_length = last_validated_position + if "start_from_position" in step_metadata: + start_from_position = step_metadata["start_from_position"] + assert prefix_length >= start_from_position, f"prefix_length={prefix_length}, start_from_position={start_from_position}" + prefix_length = start_from_position flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors) if args_structure is not None: diff --git a/tests/test_speculative_generation.py b/tests/test_speculative_generation.py index 9c7f4a075..e3045dea3 100644 --- a/tests/test_speculative_generation.py +++ b/tests/test_speculative_generation.py @@ -26,7 +26,7 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato with torch.inference_mode(): with remote_block.inference_session(max_length=inputs.shape[1]) as sess: 
initial_outputs_inference = sess.step(inputs) - secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], last_validated_position=2) + secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2) result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1) ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) From 9aecb3f39e208feb8287d6886d44c6f18f467a3f Mon Sep 17 00:00:00 2001 From: xtinkt Date: Fri, 5 Jul 2024 14:54:26 +0000 Subject: [PATCH 7/9] style --- src/petals/client/inference_session.py | 23 +++++++++++++++++------ src/petals/server/block_functions.py | 4 +++- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index c52df709a..4d94e7a76 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -84,8 +84,13 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[ break # this message means "done sending" def step( - self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, - step_id: str, start_from_position: int + self, + inputs: torch.Tensor, + prompts: torch.Tensor, + hypo_ids: torch.LongTensor, + *, + step_id: str, + start_from_position: int, ) -> torch.Tensor: """ Inference step: send a chunk of input tensors and receive a chunk of outputs @@ -266,8 +271,11 @@ def __enter__(self) -> "InferenceSession": return self def step( - self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, - hypo_ids: Optional[torch.Tensor] = None, start_from_position: Optional[int] = None + self, + inputs: torch.Tensor, + prompts: Optional[torch.Tensor] = None, + hypo_ids: Optional[torch.Tensor] = None, + start_from_position: Optional[int] = None, ) -> torch.Tensor: if start_from_position is not None: @@ -317,8 +325,11 @@ def step( server_session = self._server_sessions[server_idx] inputs = server_session.step( - inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, - step_id=step_id, start_from_position=start_from_position + inputs, + prompts[server_session.span.start : server_session.span.end], + hypo_ids, + step_id=step_id, + start_from_position=start_from_position, ) server_idx += 1 diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index 3127d6812..121ec8a0e 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -162,7 +162,9 @@ async def iterate_rpc_inference( async for request, step_metadata in input_iterator: if "start_from_position" in step_metadata: start_from_position = step_metadata["start_from_position"] - assert prefix_length >= start_from_position, f"prefix_length={prefix_length}, start_from_position={start_from_position}" + assert ( + prefix_length >= start_from_position, + ), f"prefix_length={prefix_length}, start_from_position={start_from_position}" prefix_length = start_from_position flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors) From a59e38a57807b94e29042ab3cc56792bf5dbca0a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 21 Jul 2024 23:06:50 +0300 Subject: [PATCH 8/9] running inference session with position getter/setter (#594) * Add option to rollback inference for a certain number of steps (#588) * fix * fix * fix * fix * fix * fix * style * test running inference session with position getter/setter * add assertion * fix typo --------- 
Co-authored-by: Anton Sinitsin <30695750+xtinkt@users.noreply.github.com> --- src/petals/client/inference_session.py | 30 ++++++++++++++++++-------- tests/test_speculative_generation.py | 4 +++- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 4d94e7a76..39148bdcc 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -83,6 +83,17 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[ if not next_input_message.uid and not next_input_message.tensors: break # this message means "done sending" + @property + def position(self): + return self._position + + @position.setter + def position(self, start_from_position: int): + assert start_from_position <= self._position + self._position = start_from_position + if self.history is not None and self.history.shape[1] >= start_from_position: + self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None + def step( self, inputs: torch.Tensor, @@ -90,7 +101,6 @@ def step( hypo_ids: torch.LongTensor, *, step_id: str, - start_from_position: int, ) -> torch.Tensor: """ Inference step: send a chunk of input tensors and receive a chunk of outputs @@ -127,8 +137,8 @@ def step( request_metadata = dict(session_id=self.session_id, step_id=step_id) if not self.stepped: request_metadata.update(self.session_metadata) - if start_from_position is not None: - request_metadata["start_from_position"] = start_from_position + if self._position is not None: + request_metadata["start_from_position"] = self._position elif self.config.use_server_to_server: next_servers = self._collect_next_servers() if next_servers: @@ -235,6 +245,13 @@ def num_blocks(self) -> int: def position(self) -> int: return self._position + @position.setter + def position(self, start_from_position: int) -> None: + self._position = start_from_position + for session in self._server_sessions: + assert isinstance(session, _ServerInferenceSession) + session.position = start_from_position + def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]: server_sessions = [] try: @@ -275,12 +292,7 @@ def step( inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None, - start_from_position: Optional[int] = None, ) -> torch.Tensor: - - if start_from_position is not None: - self._position = start_from_position - assert not self._closed if torch.is_grad_enabled(): logger.warning("Running inference session with grad enabled. 
Gradients will *not* be propagated correctly.") @@ -324,12 +336,12 @@ def step( self._update_sequence(server_idx, block_idx, attempt_no) server_session = self._server_sessions[server_idx] + assert server_session.position == self.position inputs = server_session.step( inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id, - start_from_position=start_from_position, ) server_idx += 1 diff --git a/tests/test_speculative_generation.py b/tests/test_speculative_generation.py index e3045dea3..dcb5be69a 100644 --- a/tests/test_speculative_generation.py +++ b/tests/test_speculative_generation.py @@ -26,7 +26,9 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato with torch.inference_mode(): with remote_block.inference_session(max_length=inputs.shape[1]) as sess: initial_outputs_inference = sess.step(inputs) - secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2) + + sess.position = 2 + secondary_outputs_inference = sess.step(short_inputs[:, 2:, :]) result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1) ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) From 24efbeb236c282866d0fb7ca7448cb777c4d9b36 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 7 Sep 2024 14:54:24 +0300 Subject: [PATCH 9/9] black --- src/petals/client/inference_session.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index e15096974..6151385e7 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -336,7 +336,9 @@ def step( self._update_sequence(server_idx, block_idx, attempt_no) server_session = self._server_sessions[server_idx] - assert server_session.position == self.position, f"Position mismatch: {server_session.position} and {self.position}" + assert ( + server_session.position == self.position + ), f"Position mismatch: {server_session.position} and {self.position}" inputs = server_session.step( inputs, prompts[server_session.span.start : server_session.span.end],
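Usage sketch (illustrative, not part of the patches above): the series ends up exposing rollback through the `position` setter on `InferenceSession`, so a speculative-decoding client can discard cached positions past a verified prefix and recompute only the rejected tokens. The sketch below mirrors the test added in PATCH 8/9; `AutoDistributedConfig`, `RemoteSequential`, `inference_session`, `step`, and the `position` setter are the APIs touched by these patches, while the model name, the 16-token draft, and the two-token verified prefix are assumptions made for illustration.

    # Rollback sketch; MODEL_NAME and the tensor shapes are illustrative assumptions.
    import torch

    from petals import AutoDistributedConfig, RemoteSequential

    MODEL_NAME = "petals-team/StableBeluga2"  # assumption: any model served by a reachable swarm

    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
    block = RemoteSequential(config)[0]  # a single remote transformer block, as in the test

    draft = torch.randn(1, 16, config.hidden_size)    # hidden states for 16 speculative tokens
    revised = torch.randn(1, 16, config.hidden_size)  # corrected continuation after verification
    revised[:, :2, :] = draft[:, :2, :]               # only the first 2 positions were accepted

    with torch.inference_mode():
        with block.inference_session(max_length=draft.shape[1]) as sess:
            out_draft = sess.step(draft)   # server-side cache now holds 16 positions
            sess.position = 2              # truncate client history; the next step sends
                                           # "start_from_position" so the server resets
                                           # its prefix_length to 2
            out_fixed = sess.step(revised[:, 2:, :])  # recompute only the unverified tail
            outputs = torch.cat([out_draft[:, :2, :], out_fixed], dim=1)

In the final form of the series, rollback is expressed only through this property: the per-step start_from_position argument from the earlier patches is removed, and each server session forwards the new position to the server via step metadata.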