1- from typing import Union
1+ from typing import Any , Dict , Union
22
33import torch
44from neuronx_distributed .operators .argmax import argmax as nxd_argmax
55from neuronx_distributed .operators .topk import topk as nxd_topk
66from neuronx_distributed .parallel_layers import parallel_state
77from torch_neuronx .xla_impl .ops import xla_hlo_call
88
9- from neuronx_distributed_inference .models .config import NeuronConfig
9+ from neuronx_distributed_inference .models .config import NeuronConfig , OnDeviceSamplingConfig
1010
1111
1212@xla_hlo_call
@@ -18,6 +18,62 @@ def rand_like(tensor):
1818 return dtype [shape ].Rng (minimum , maximum , distribution = 1 ) # Uniform distribution
1919
2020
def validate_sampling_params(
    params: torch.Tensor,
    on_device_sampling_config: Union[Dict[str, Any], OnDeviceSamplingConfig],
) -> None:
    """
    Validates sampling parameters for language models.

    Args:
        params (torch.Tensor): Tensor of shape (batch_size, 3) containing sampling parameters
            in the order: top-k, top-p, temperature.
        on_device_sampling_config: Source of the ``global_topk`` limit. When it is an
            ``OnDeviceSamplingConfig`` the ``global_topk`` attribute is read; otherwise it is
            indexed with the ``"global_topk"`` key. The value is used as the inclusive upper
            bound for every top-k entry.

    Raises:
        ValueError: If ``params`` is not a 2-D tensor with exactly 3 columns, if any top-k
            value is neither -1 nor a whole number in (0, global_topk], if any top-p value
            is outside (0.0, 1.0], or if any temperature is not strictly positive.
    """
    # Check the rank before touching shape[1]: on a 0-D or 1-D tensor that index would
    # raise IndexError instead of the ValueError this function documents.
    if params.dim() != 2 or params.shape[1] != 3:
        raise ValueError(f"Expected tensor of shape (batch_size, 3), but got {params.shape}")

    # Cast to float32 so every column is compared (and floored) in one dtype.
    params = params.to(torch.float32)

    # Unpack parameters (column order is fixed: top-k, top-p, temperature).
    top_k, top_p, temperature = params[:, 0], params[:, 1], params[:, 2]

    if isinstance(on_device_sampling_config, OnDeviceSamplingConfig):
        global_top_k = on_device_sampling_config.global_topk
    else:
        global_top_k = on_device_sampling_config["global_topk"]

    # Validate top-k value range: -1 is accepted as-is, anything else must lie in
    # (0, global_top_k].
    valid_top_k = (top_k == -1) | ((top_k > 0) & (top_k <= global_top_k))
    if not torch.all(valid_top_k):
        raise ValueError(
            f"Invalid top-k values found. top-k must be -1 or greater than 0 but less than or equal to {global_top_k=}. Found {top_k=}."
        )

    # Checks that top-k values can be represented as integers (no fractional part).
    if not torch.equal(top_k, top_k.floor()):
        raise ValueError(
            f"Invalid top-k values found. top-k values should be able to be represented as integer values, but found decimal parts. Found {top_k=}."
        )

    # Validate top-p
    valid_top_p = (top_p > 0.0) & (top_p <= 1.0)
    if not torch.all(valid_top_p):
        raise ValueError(
            f"Invalid top-p values found. top-p must be in the range (0.0, 1.0]. Found {top_p=}."
        )

    # Validate temperature
    valid_temp = temperature > 0.0
    if not torch.all(valid_temp):
        raise ValueError(
            f"Invalid temperature values found. Temperature must be strictly greater than 0.0. Found {temperature=}."
        )
75+
76+
2177def prepare_sampling_params (batch_size , top_k = [1 ], top_p = [1.0 ], temperature = [1.0 ]):
2278 top_k = prepare_tensor (top_k )
2379 top_p = prepare_tensor (top_p )
0 commit comments