1010import id # pylint: disable=redefined-builtin
1111import requests
1212
13- _GITHUB_STEP_SUMMARY = Path (os .getenv ( 'GITHUB_STEP_SUMMARY' ) )
13+ _GITHUB_STEP_SUMMARY = Path (os .environ [ 'GITHUB_STEP_SUMMARY' ] )
1414
1515# The top-level error message that gets rendered.
1616# This message wraps one of the other templates/messages defined below.
135135""" # noqa: S105; not a password
136136
137137
138+ class TrustedPublishingClaims (t .TypedDict ):
139+ sub : str
140+ repository : str
141+ repository_owner : str
142+ repository_owner_id : str
143+ workflow_ref : str
144+ job_workflow_ref : str
145+ ref : str
146+ environment : str
147+
148+
149+ class PullRequestRepoGitHubEventObject (t .TypedDict ):
150+ fork : bool
151+
152+
153+ class PullRequestHeadGitHubEventObject (t .TypedDict ):
154+ repo : PullRequestRepoGitHubEventObject
155+
156+
157+ class PullRequestGitHubEventObject (t .TypedDict ):
158+ head : PullRequestHeadGitHubEventObject
159+
160+
161+ class ThirdPartyPullRequestGitHubEvent (t .TypedDict ):
162+ pull_request : PullRequestGitHubEventObject
163+
164+
165+ class TrustedPublishingAudience (t .TypedDict ):
166+ audience : str
167+
168+
169+ class TrustedPublishingTokenRetrievalError (t .TypedDict ):
170+ code : str
171+ description : str
172+
173+
174+ class TrustedPublishingToken (t .TypedDict ):
175+ message : str
176+ errors : list [TrustedPublishingTokenRetrievalError ]
177+ token : str
178+ success : bool
179+ expires : int
180+
181+
138182def die (msg : str ) -> t .NoReturn :
139183 with _GITHUB_STEP_SUMMARY .open ('a' , encoding = 'utf-8' ) as io :
140184 print (_ERROR_SUMMARY_MESSAGE .format (message = msg ), file = io )
@@ -155,7 +199,7 @@ def warn(msg: str) -> None:
155199 print (f'::warning::Potential workflow misconfiguration: { msg } ' , file = sys .stderr )
156200
157201
158- def debug (msg : str ):
202+ def debug (msg : str ) -> None :
159203 print (f'::debug::{ msg .title ()} ' , file = sys .stderr )
160204
161205
@@ -166,7 +210,7 @@ def get_normalized_input(name: str) -> str | None:
166210 return os .getenv (name .replace ('-' , '_' ))
167211
168212
169- def assert_successful_audience_call (resp : requests .Response , domain : str ):
213+ def assert_successful_audience_call (resp : requests .Response , domain : str ) -> None :
170214 if resp .ok :
171215 return
172216
@@ -194,17 +238,21 @@ def assert_successful_audience_call(resp: requests.Response, domain: str):
194238 )
195239
196240
197- def extract_claims (token : str ) -> dict [ str , object ] :
241+ def extract_claims (token : str ) -> TrustedPublishingClaims :
198242 _ , payload , _ = token .split ('.' , 2 )
199243
200244 # urlsafe_b64decode needs padding; JWT payloads don't contain any.
201245 payload += '=' * (4 - (len (payload ) % 4 ))
202- return json .loads (base64 .urlsafe_b64decode (payload ))
246+
247+ claims : TrustedPublishingClaims = json .loads (
248+ base64 .urlsafe_b64decode (payload ),
249+ )
250+ return claims
203251
204252
205- def render_claims (claims : dict [ str , object ] ) -> str :
253+ def render_claims (claims : TrustedPublishingClaims ) -> str :
206254 def _get (name : str ) -> str : # noqa: WPS430
207- return claims .get (name , 'MISSING' )
255+ return str ( claims .get (name , 'MISSING' ) )
208256
209257 return _RENDERED_CLAIMS .format (
210258 sub = _get ('sub' ),
@@ -218,7 +266,7 @@ def _get(name: str) -> str: # noqa: WPS430
218266 )
219267
220268
221- def warn_on_reusable_workflow (claims : dict [ str , object ] ) -> None :
269+ def warn_on_reusable_workflow (claims : TrustedPublishingClaims ) -> None :
222270 # A reusable workflow is identified by having different values
223271 # for its workflow_ref (the initiating workflow) and job_workflow_ref
224272 # (the reusable workflow).
@@ -228,7 +276,11 @@ def warn_on_reusable_workflow(claims: dict[str, object]) -> None:
228276 if workflow_ref == job_workflow_ref :
229277 return
230278
231- warn (_REUSABLE_WORKFLOW_WARNING .format_map (locals ()))
279+ warn (
280+ _REUSABLE_WORKFLOW_WARNING .format (
281+ workflow_ref = workflow_ref , job_workflow_ref = job_workflow_ref ,
282+ ),
283+ )
232284
233285
234286def event_is_third_party_pr () -> bool :
@@ -243,7 +295,9 @@ def event_is_third_party_pr() -> bool:
243295 return False
244296
245297 try :
246- event = json .loads (Path (event_path ).read_bytes ())
298+ event : ThirdPartyPullRequestGitHubEvent = json .loads (
299+ Path (event_path ).read_bytes (),
300+ )
247301 except json .JSONDecodeError :
248302 debug ('unexpected: GITHUB_EVENT_PATH does not contain valid JSON' )
249303 return False
@@ -254,8 +308,17 @@ def event_is_third_party_pr() -> bool:
254308 return False
255309
256310
311+ def _detect_credential (audience : str , / ) -> str :
312+ token = id .detect_credential (audience = audience )
313+ if token is None :
314+ raise id .IdentityError (
315+ 'Attempted to discover OIDC in broken environment' ,
316+ )
317+ return token
318+
319+
257320repository_url = get_normalized_input ('repository-url' )
258- repository_domain = urlparse (repository_url ).netloc
321+ repository_domain = str ( urlparse (repository_url ).netloc )
259322token_exchange_url = f'https://{ repository_domain } /_/oidc/mint-token'
260323
261324# Indices are expected to support `https://{domain}/_/oidc/audience`,
@@ -264,12 +327,15 @@ def event_is_third_party_pr() -> bool:
264327audience_resp = requests .get (audience_url , timeout = 5 ) # S113 wants a timeout
265328assert_successful_audience_call (audience_resp , repository_domain )
266329
267- oidc_audience = audience_resp .json ()['audience' ]
330+
331+ oidc_audience_resp : TrustedPublishingAudience = audience_resp .json ()
332+ oidc_audience = oidc_audience_resp ['audience' ]
268333
269334debug (f'selected trusted publishing exchange endpoint: { token_exchange_url } ' )
270335
336+
271337try :
272- oidc_token = id . detect_credential ( audience = oidc_audience )
338+ oidc_token = _detect_credential ( oidc_audience )
273339except id .IdentityError as identity_error :
274340 cause_msg_tmpl = (
275341 _TOKEN_RETRIEVAL_FAILED_FORK_PR_MESSAGE
@@ -285,15 +351,17 @@ def event_is_third_party_pr() -> bool:
285351oidc_claims = extract_claims (oidc_token )
286352warn_on_reusable_workflow (oidc_claims )
287353
354+ oidc_token_payload : dict [str , str ] = {'token' : oidc_token }
288355# Now we can do the actual token exchange.
289356mint_token_resp = requests .post (
290357 token_exchange_url ,
291- json = { 'token' : oidc_token } ,
358+ json = oidc_token_payload ,
292359 timeout = 5 , # S113 wants a timeout
293360)
294361
362+
295363try :
296- mint_token_payload = mint_token_resp .json ()
364+ mint_token_payload : TrustedPublishingToken = mint_token_resp .json ()
297365except requests .JSONDecodeError :
298366 # Token exchange failure normally produces a JSON error response, but
299367 # we might have hit a server error instead.
0 commit comments