From 511be84e596edfe3d3cf3ec58b9290cddc73ae1a Mon Sep 17 00:00:00 2001 From: yqrashawn Date: Thu, 5 Jun 2025 15:29:04 +0800 Subject: [PATCH 01/11] feat: add @status-im/trpc-webext --- .envrc | 1 + .gitignore | 2 + .vscode/settings.json | 1 + apps/api/package.json | 6 +- apps/wallet/package.json | 5 +- apps/wallet/src/data/api.ts | 31 ++- apps/wallet/src/data/webext.ts | 15 ++ packages/trpc-webext/.prettierrc | 5 + packages/trpc-webext/eslint.config.mjs | 9 + packages/trpc-webext/package.json | 49 ++++ packages/trpc-webext/src/adapter/index.ts | 263 ++++++++++++++++++++++ packages/trpc-webext/src/index.ts | 2 + packages/trpc-webext/src/link/index.ts | 146 ++++++++++++ packages/trpc-webext/tsconfig.json | 8 + pnpm-lock.yaml | 123 +++++----- pnpm-workspace.yaml | 1 + 16 files changed, 596 insertions(+), 71 deletions(-) create mode 100644 .envrc create mode 100644 apps/wallet/src/data/webext.ts create mode 100644 packages/trpc-webext/.prettierrc create mode 100644 packages/trpc-webext/eslint.config.mjs create mode 100644 packages/trpc-webext/package.json create mode 100644 packages/trpc-webext/src/adapter/index.ts create mode 100644 packages/trpc-webext/src/index.ts create mode 100644 packages/trpc-webext/src/link/index.ts create mode 100644 packages/trpc-webext/tsconfig.json diff --git a/.envrc b/.envrc new file mode 100644 index 000000000..a8a760658 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +layout node diff --git a/.gitignore b/.gitignore index 68b8bba1f..702388202 100644 --- a/.gitignore +++ b/.gitignore @@ -104,3 +104,5 @@ web-build/ # Contentlayer .contentlayer + +/.lsp/ diff --git a/.vscode/settings.json b/.vscode/settings.json index 546596f37..38da763fe 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -13,6 +13,7 @@ "./packages/components", "./packages/status-js", "./packages/wallet", + "./packages/trpc-webext", "./apps/connector", "./apps/portfolio", "./apps/wallet", diff --git a/apps/api/package.json b/apps/api/package.json index 654bb640e..f5f122d95 100644 --- a/apps/api/package.json +++ b/apps/api/package.json @@ -16,9 +16,9 @@ }, "dependencies": { "@status-im/wallet": "workspace:*", - "@trpc/client": "10.45.2", - "@trpc/server": "10.45.2", - "@trpc/next": "10.45.2", + "@trpc/client": "11.3.0", + "@trpc/server": "11.3.0", + "@trpc/next": "11.3.0", "react": "^19.0.0", "react-dom": "^19.0.0", "next": "15.3.0", diff --git a/apps/wallet/package.json b/apps/wallet/package.json index 519915d81..be5ba8f05 100644 --- a/apps/wallet/package.json +++ b/apps/wallet/package.json @@ -36,11 +36,12 @@ "@status-im/components": "workspace:*", "@status-im/icons": "workspace:*", "@status-im/wallet": "workspace:*", + "@status-im/trpc-webext": "workspace:*", "@tanstack/react-query": "^5.66.0", "@tanstack/react-query-devtools": "^5.66.0", "@tanstack/react-router": "^1.109.2", - "@trpc/client": "10.45.2", - "@trpc/server": "10.45.2", + "@trpc/client": "11.3.0", + "@trpc/server": "11.3.0", "@trustwallet/wallet-core": "^4.3.6", "@wxt-dev/storage": "^1.1.0", "@zxcvbn-ts/core": "^3.0.4", diff --git a/apps/wallet/src/data/api.ts b/apps/wallet/src/data/api.ts index 2485a1ec7..997189d05 100644 --- a/apps/wallet/src/data/api.ts +++ b/apps/wallet/src/data/api.ts @@ -1,14 +1,14 @@ // import { Cardano } from '@cardano-sdk/core' // import { SodiumBip32Ed25519 } from '@cardano-sdk/crypto' // import { AddressType, InMemoryKeyAgent } from '@cardano-sdk/key-management' -import { createTRPCProxyClient } from '@trpc/client' +import { createWebExtHandler, webExtensionLink } from '@status-im/trpc-webext' +import { createTRPCClient } from '@trpc/client' import { initTRPC } from '@trpc/server' import superjson from 'superjson' -import { createChromeHandler } from 'trpc-chrome/adapter' +import { browser } from 'wxt/browser' import { z } from 'zod' import * as bitcoin from './bitcoin/bitcoin' -import { chromeLinkWithRetries } from './chromeLink' import * as ethereum from './ethereum/ethereum' import { getKeystore } from './keystore' import * as solana from './solana/solana' @@ -16,12 +16,17 @@ import { getWalletCore, // type WalletCore } from './wallet' +import { runtimePortToClientContextType } from './webext' -const createContext = async () => { +import type { CreateWebExtContextOptions } from '@status-im/trpc-webext/adapter' + +const createContext = async (webextOpts?: CreateWebExtContextOptions) => { const keyStore = await getKeystore() const walletCore = await getWalletCore() return { + ...webextOpts, + contextType: runtimePortToClientContextType(webextOpts?.req), keyStore, walletCore, } @@ -598,8 +603,11 @@ const apiRouter = router({ export type APIRouter = typeof apiRouter export async function createAPI() { - // @ts-expect-error: fixme!: - createChromeHandler({ router: apiRouter, createContext }) + createWebExtHandler({ + router: apiRouter, + createContext, + runtime: browser.runtime, + }) const ctx = await createContext() const api = createCallerFactory(apiRouter)(ctx) @@ -608,8 +616,13 @@ export async function createAPI() { } export function createAPIClient() { - return createTRPCProxyClient({ - links: [chromeLinkWithRetries()], - transformer: superjson, + return createTRPCClient({ + links: [ + webExtensionLink({ + runtime: browser.runtime, + timeoutMS: 30000, + transformer: superjson, + }), + ], }) } diff --git a/apps/wallet/src/data/webext.ts b/apps/wallet/src/data/webext.ts new file mode 100644 index 000000000..03111abfd --- /dev/null +++ b/apps/wallet/src/data/webext.ts @@ -0,0 +1,15 @@ +import type { Runtime } from 'wxt/browser' + +export type TRPCClientContextType = 'POPUP' | 'SIDE_PANEL' | 'PAGE' | 'TAB' + +export function runtimePortToClientContextType( + port?: Runtime.Port, +): TRPCClientContextType | undefined { + const { origin } = globalThis.location + if (!port) return + if (port.sender?.url?.startsWith(`${origin}/sidepanel.html`)) + return 'SIDE_PANEL' + if (port.sender?.url?.startsWith(`${origin}/popup.html`)) return 'POPUP' + if (port.sender?.url?.startsWith(`${origin}/page.html`)) return 'PAGE' + return 'TAB' +} diff --git a/packages/trpc-webext/.prettierrc b/packages/trpc-webext/.prettierrc new file mode 100644 index 000000000..759232e7c --- /dev/null +++ b/packages/trpc-webext/.prettierrc @@ -0,0 +1,5 @@ +{ + "semi": false, + "singleQuote": true, + "arrowParens": "avoid" +} diff --git a/packages/trpc-webext/eslint.config.mjs b/packages/trpc-webext/eslint.config.mjs new file mode 100644 index 000000000..d09afcb25 --- /dev/null +++ b/packages/trpc-webext/eslint.config.mjs @@ -0,0 +1,9 @@ +import configs from '@status-im/eslint-config' + +/** @type {import('eslint').Linter.Config[]} */ +export default [ + ...configs, + { + files: ['**/*.ts', '**/*.mts', '**/*.mjs', '**/*.tsx'], + }, +] diff --git a/packages/trpc-webext/package.json b/packages/trpc-webext/package.json new file mode 100644 index 000000000..c6ae9400d --- /dev/null +++ b/packages/trpc-webext/package.json @@ -0,0 +1,49 @@ +{ + "name": "@status-im/trpc-webext", + "description": "description", + "version": "0.0.1", + "license": "MIT", + "keywords": [ + "trpc", + "extension", + "webext", + "webextension" + ], + "main": "./dist/index.js", + "module": "./dist/index.js", + "types": "./dist/index.d.ts", + "exports": { + ".": { + "types": "./dist/index.d.ts", + "import": "./dist/index.js", + "require": "./dist/index.js" + }, + "./adapter": { + "types": "./dist/adapter/index.d.ts", + "import": "./dist/adapter/index.js", + "require": "./dist/adapter/index.js" + }, + "./link": { + "types": "./dist/link/index.d.ts", + "import": "./dist/link/index.js", + "require": "./dist/link/index.js" + } + }, + "files": [ + "dist" + ], + "scripts": { + "preinstall": "npx only-allow pnpm", + "dev": "tsc -w", + "build": "tsc", + "lint": "eslint src", + "format": "prettier --write ." + }, + "peerDependencies": { + "@trpc/client": "^11.0.0", + "@trpc/server": "^11.0.0" + }, + "devDependencies": { + "@types/webextension-polyfill": "^0.12.3" + } +} diff --git a/packages/trpc-webext/src/adapter/index.ts b/packages/trpc-webext/src/adapter/index.ts new file mode 100644 index 000000000..e5bc95cd1 --- /dev/null +++ b/packages/trpc-webext/src/adapter/index.ts @@ -0,0 +1,263 @@ +import { isObservable } from '@trpc/server/observable' +import { + getErrorShape, + getTRPCErrorFromUnknown, + TRPCError, +} from '@trpc/server/unstable-core-do-not-import' + +import type { Unsubscribable } from '@trpc/server/observable' +import type { + AnyProcedure, + AnyRouter, + BaseHandlerOptions, + ProcedureType, + TRPCClientOutgoingMessage, + TRPCResponseMessage, +} from '@trpc/server/unstable-core-do-not-import' +import type { Runtime } from 'webextension-polyfill' + +export type CreateWebExtContextOptions = { + req: Runtime.Port + res: undefined +} + +export type CreateWebExtHandlerOptions = + BaseHandlerOptions & { + createContext?: ( + opts: CreateWebExtContextOptions, + ) => Promise | unknown + runtime: Runtime.Static + } + +type PortInfo = { + subscriptions: Map +} + +type PortInfos = Map + +const PORTS: PortInfos = new Map() + +type ClientMessage = { + trpc: TRPCClientOutgoingMessage +} + +function onPortMessage( + opts: CreateWebExtHandlerOptions, +) { + const { router, createContext, onError } = opts + const { transformer } = router._def._config + + return async (message: unknown, port: Runtime.Port) => { + if (!(message as ClientMessage)?.trpc) return + const clientMessage = message as ClientMessage + const { trpc } = clientMessage + if (!('id' in trpc) || trpc.id === null || trpc.id === undefined) return + + const portInfo = PORTS.get(port) + if (!portInfo) return + + const { subscriptions } = portInfo + const { id, jsonrpc, method } = trpc + + const sendResponse = ( + response: Omit, + ) => + port.postMessage({ + trpc: { id, jsonrpc, ...response }, + }) + + let params: { path: string; input: unknown } | undefined + let input: unknown + let ctx: unknown + + try { + if (method === 'subscription.stop') { + const subscription = subscriptions.get(id) + if (subscription) { + subscription.unsubscribe() + sendResponse({ + result: { + type: 'stopped', + }, + }) + } + subscriptions.delete(id) + return + } + + params = trpc.params as { path: string; input: unknown } + + input = transformer.input.deserialize(params.input) + + ctx = await createContext?.({ + req: port, + res: undefined, + }) + const caller = router.createCaller(ctx) + + const segments = params.path.split('.') + const procedureFn = segments.reduce( + (acc, segment) => acc[segment], + caller as unknown, + ) as AnyProcedure + + const result = await procedureFn(input) + + if (method !== 'subscription') { + const data = transformer.output.serialize(result) + sendResponse({ + result: { + type: 'data', + data, + }, + }) + return + } + + if (!isObservable(result)) { + throw new TRPCError({ + message: `Subscription ${params.path} did not return an observable`, + code: 'INTERNAL_SERVER_ERROR', + }) + } + + const subscription = result.subscribe({ + next: data => { + sendResponse({ + result: { + type: 'data', + data, + }, + }) + }, + error: cause => { + const error = getTRPCErrorFromUnknown(cause) + + onError?.({ + error, + type: method as ProcedureType, + path: params?.path, + input, + ctx, + req: port, + }) + + sendResponse({ + error: getErrorShape({ + error, + type: method as ProcedureType, + path: params?.path, + input, + ctx, + config: router._def._config, + }), + }) + }, + complete: () => { + sendResponse({ + result: { + type: 'stopped', + }, + }) + }, + }) + + if (subscriptions.has(id)) { + subscription.unsubscribe() + sendResponse({ + result: { + type: 'stopped', + }, + }) + throw new TRPCError({ + message: `Duplicate id ${id}`, + code: 'BAD_REQUEST', + }) + } + + subscriptions.set(id, subscription) + + sendResponse({ + result: { + type: 'started', + }, + }) + return + } catch (cause) { + const error = getTRPCErrorFromUnknown(cause) + + onError?.({ + error, + type: method as ProcedureType, + path: params?.path, + input, + ctx, + req: port, + }) + + sendResponse({ + error: getErrorShape({ + error, + type: method as ProcedureType, + path: params?.path, + input, + ctx, + config: router._def._config, + }), + }) + } + } +} + +function onPortDisconnect( + onMessage: ReturnType>, +) { + return (port: Runtime.Port) => { + const portInfo = PORTS.get(port) + if (!portInfo) return + port.onMessage.removeListener(onMessage) + + const { subscriptions } = portInfo + subscriptions.forEach(sub => sub.unsubscribe()) + + PORTS.delete(port) + } +} + +function onPortConnect( + opts: CreateWebExtHandlerOptions, +) { + return (port: Runtime.Port) => { + const portInfo: PortInfo = { + subscriptions: new Map(), + } + + PORTS.set(port, portInfo) + const onMessage = onPortMessage(opts) + port.onDisconnect.addListener(onPortDisconnect(onMessage)) + port.onMessage.addListener(onMessage) + } +} + +/** + * Creates a tRPC handler for web extension communication + * + * Sets up listeners for port connections and handles tRPC procedure calls + * from various web extension contexts (content scripts, popup, side panel, etc.) + * + * @param opts - Configuration options including router, runtime, and context creator + * + * @example + * =typescript + * createWebExtHandler({ + * router: appRouter, + * runtime: browser.runtime, + * createContext, + * }); + * = + */ +export const createWebExtHandler = ( + opts: CreateWebExtHandlerOptions, +) => { + opts.runtime.onConnect.addListener(onPortConnect(opts)) +} diff --git a/packages/trpc-webext/src/index.ts b/packages/trpc-webext/src/index.ts new file mode 100644 index 000000000..fe19b41d3 --- /dev/null +++ b/packages/trpc-webext/src/index.ts @@ -0,0 +1,2 @@ +export { createWebExtHandler } from './adapter' +export { webExtensionLink } from './link' diff --git a/packages/trpc-webext/src/link/index.ts b/packages/trpc-webext/src/link/index.ts new file mode 100644 index 000000000..4b5c62aca --- /dev/null +++ b/packages/trpc-webext/src/link/index.ts @@ -0,0 +1,146 @@ +import { TRPCClientError } from '@trpc/client' +import { observable } from '@trpc/server/observable' + +import type { Operation, TRPCLink } from '@trpc/client' +import type { AnyTRPCRouter } from '@trpc/server' +import type { Observer } from '@trpc/server/observable' +import type { TRPCResponseMessage } from '@trpc/server/rpc' +import type { DataTransformer } from '@trpc/server/unstable-core-do-not-import' +import type { Runtime } from 'webextension-polyfill' + +export type WebExtensionLinkOptions = { + runtime: Runtime.Static + timeoutMS?: number + transformer: DataTransformer +} + +export type BackgroundMessage = { + trpc: TRPCResponseMessage +} + +let portToBackground: Runtime.Port | null = null + +interface ResultListeners { + [id: number]: { + timestamp: number // used to cleanup listeners + observer: Observer + type: Operation['type'] + } +} + +const resultListeners: ResultListeners = {} + +function connectToBackground({ runtime }: WebExtensionLinkOptions) { + if (!portToBackground) portToBackground = runtime.connect() +} + +function portOnMessageFromBackground( + transformer: WebExtensionLinkOptions['transformer'], +) { + return (message: unknown) => { + if (!(message as BackgroundMessage)?.trpc) return + const backgroundMessage = message as BackgroundMessage + const { trpc } = backgroundMessage + if (!('id' in trpc) || trpc.id === null || trpc.id === undefined) return + if (!(trpc.id in resultListeners)) return + const { observer, type } = resultListeners[trpc.id as number] + + if ('error' in trpc) { + // Check if it's already a SuperJSONResult or needs deserialization + const error = + typeof trpc.error === 'object' && 'json' in trpc.error + ? transformer.deserialize(trpc.error) + : trpc.error + observer.error(TRPCClientError.from({ ...trpc, error })) + return + } + + observer.next({ + result: { + ...trpc.result, + ...((!trpc.result.type || trpc.result.type === 'data') && { + type: 'data', + data: transformer.deserialize(trpc.result.data), + }), + } as unknown, + }) + + if (type !== 'subscription' || trpc.result.type === 'stopped') { + observer.complete() + } + } +} + +let clearListenersIntervalId: ReturnType | undefined + +function clearListenersIntervalFn(timeoutMS: number) { + return () => { + const timedOutAt = new Date().getTime() - timeoutMS + + for (const id in resultListeners) { + if (resultListeners[id].timestamp < timedOutAt) { + delete resultListeners[id] + } + } + } +} + +function setupClearListenersInterval(timeoutMS = 10000) { + if (clearListenersIntervalId) clearInterval(clearListenersIntervalId) + clearListenersIntervalId = setInterval( + clearListenersIntervalFn(timeoutMS), + 1000, + ) +} + +export function webExtensionLink( + opts: WebExtensionLinkOptions, +): TRPCLink { + const { timeoutMS, transformer } = opts + + setupClearListenersInterval(timeoutMS) + + connectToBackground(opts) + const onMessage = portOnMessageFromBackground(transformer) + + const portOnDisconnect = (port: Runtime.Port) => { + port.onDisconnect.removeListener(portOnDisconnect) + port.onMessage.removeListener(onMessage) + portToBackground = null + } + + portToBackground?.onDisconnect.addListener(portOnDisconnect) + portToBackground?.onMessage.addListener(onMessage) + + return () => { + return ({ op }) => { + const { id, type, path } = op + + const input = transformer.serialize(op.input) || op.input + + const trpcPayload = { + id, + jsonrpc: undefined, + method: type, + params: { path, input }, + } + + const postMessagePayload = { + trpc: trpcPayload, + } + + return observable(observer => { + resultListeners[id] = { + observer, + type, + timestamp: new Date().getTime(), + } + portToBackground?.postMessage(postMessagePayload) + + return () => { + delete resultListeners[id] + } + }) + } + } +} diff --git a/packages/trpc-webext/tsconfig.json b/packages/trpc-webext/tsconfig.json new file mode 100644 index 000000000..877fdf8cd --- /dev/null +++ b/packages/trpc-webext/tsconfig.json @@ -0,0 +1,8 @@ +{ + "extends": "../../tsconfig.base", + "compilerOptions": { + "noEmit": false, + "outDir": "./dist" + }, + "include": ["src"] +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 86aca7186..976d30d72 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -91,14 +91,14 @@ importers: specifier: workspace:* version: link:../../packages/wallet '@trpc/client': - specifier: 10.45.2 - version: 10.45.2(@trpc/server@10.45.2) + specifier: 11.3.0 + version: 11.3.0(@trpc/server@11.3.0(typescript@5.8.3))(typescript@5.8.3) '@trpc/next': - specifier: 10.45.2 - version: 10.45.2(@tanstack/react-query@5.75.5(react@19.1.0))(@trpc/client@10.45.2(@trpc/server@10.45.2))(@trpc/react-query@11.1.0(@tanstack/react-query@5.75.5(react@19.1.0))(@trpc/client@10.45.2(@trpc/server@10.45.2))(@trpc/server@10.45.2)(react-dom@19.1.0(react@19.1.0))(react@19.1.0)(typescript@5.8.3))(@trpc/server@10.45.2)(next@15.3.0(@opentelemetry/api@1.9.0)(react-dom@19.1.0(react@19.1.0))(react@19.1.0)(sass@1.80.4))(react-dom@19.1.0(react@19.1.0))(react@19.1.0) + specifier: 11.3.0 + version: 11.3.0(@tanstack/react-query@5.75.5(react@19.1.0))(@trpc/client@11.3.0(@trpc/server@11.3.0(typescript@5.8.3))(typescript@5.8.3))(@trpc/server@11.3.0(typescript@5.8.3))(next@15.3.0(@opentelemetry/api@1.9.0)(react-dom@19.1.0(react@19.1.0))(react@19.1.0)(sass@1.80.4))(react-dom@19.1.0(react@19.1.0))(react@19.1.0)(typescript@5.8.3) '@trpc/server': - specifier: 10.45.2 - version: 10.45.2 + specifier: 11.3.0 + version: 11.3.0(typescript@5.8.3) next: specifier: 15.3.0 version: 15.3.0(@babel/core@7.25.2)(@opentelemetry/api@1.9.0)(react-dom@19.1.0(react@19.1.0))(react@19.1.0)(sass@1.80.4) @@ -867,6 +867,9 @@ importers: '@status-im/icons': specifier: workspace:* version: link:../../packages/icons + '@status-im/trpc-webext': + specifier: workspace:* + version: link:../../packages/trpc-webext '@status-im/wallet': specifier: workspace:* version: link:../../packages/wallet @@ -880,11 +883,11 @@ importers: specifier: ^1.109.2 version: 1.120.2(react-dom@19.1.0(react@19.1.0))(react@19.1.0) '@trpc/client': - specifier: 10.45.2 - version: 10.45.2(@trpc/server@10.45.2) + specifier: 11.3.0 + version: 11.3.0(@trpc/server@11.3.0(typescript@5.8.3))(typescript@5.8.3) '@trpc/server': - specifier: 10.45.2 - version: 10.45.2 + specifier: 11.3.0 + version: 11.3.0(typescript@5.8.3) '@trustwallet/wallet-core': specifier: ^4.3.6 version: 4.3.6 @@ -941,7 +944,7 @@ importers: version: 2.2.2 trpc-chrome: specifier: ^1.0.0 - version: 1.0.0(@trpc/client@10.45.2(@trpc/server@10.45.2))(@trpc/server@10.45.2) + version: 1.0.0(@trpc/client@11.3.0(@trpc/server@11.3.0(typescript@5.8.3))(typescript@5.8.3))(@trpc/server@11.3.0(typescript@5.8.3)) ts-pattern: specifier: ^5.7.1 version: 5.7.1 @@ -1331,6 +1334,19 @@ importers: specifier: ^9.1.7 version: 9.1.7 + packages/trpc-webext: + dependencies: + '@trpc/client': + specifier: ^11.0.0 + version: 11.3.0(@trpc/server@11.3.0(typescript@5.8.3))(typescript@5.8.3) + '@trpc/server': + specifier: ^11.0.0 + version: 11.3.0(typescript@5.8.3) + devDependencies: + '@types/webextension-polyfill': + specifier: ^0.12.3 + version: 0.12.3 + packages/wallet: dependencies: '@hookform/resolvers': @@ -1386,7 +1402,7 @@ importers: version: link:../icons '@trpc/react-query': specifier: 10.45.2 - version: 10.45.2(@tanstack/react-query@5.75.5(react@18.3.1))(@trpc/client@11.1.0(@trpc/server@10.45.2)(typescript@5.8.3))(@trpc/server@10.45.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + version: 10.45.2(@tanstack/react-query@5.75.5(react@18.3.1))(@trpc/client@11.3.0(@trpc/server@10.45.2)(typescript@5.8.3))(@trpc/server@10.45.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) '@trpc/server': specifier: 10.45.2 version: 10.45.2 @@ -7775,27 +7791,28 @@ packages: resolution: {integrity: sha512-GgdKez5AaQRIm0kFNp7BZnxFQ2F7LZ7g3rOQ/v11oYZR3jhH7JPGM+7hZQjYqFXD/5TK/of7hepu418K2fghvg==} engines: {node: '>=12.4.0'} - '@trpc/client@10.45.2': - resolution: {integrity: sha512-ykALM5kYWTLn1zYuUOZ2cPWlVfrXhc18HzBDyRhoPYN0jey4iQHEFSEowfnhg1RvYnrAVjNBgHNeSAXjrDbGwg==} - peerDependencies: - '@trpc/server': 10.45.2 - - '@trpc/client@11.1.0': - resolution: {integrity: sha512-Q3pL4p7AddxI/ZJTEFo1utKSdasDFjZPECIPsKDkthEt52k530JkYVltTdLkYFKrNWXKKBo8MN7NwchelczoRw==} + '@trpc/client@11.3.0': + resolution: {integrity: sha512-O8fih4Tj+mGqN7bPM1inXJ7/3SFYI3/0D/AKYDc/G6Bq7GNNX+1st0YMMdtFk+Vv3e8K95h6233ZOwRX0qiwOg==} peerDependencies: - '@trpc/server': 11.1.0 + '@trpc/server': 11.3.0 typescript: '>=5.7.2' - '@trpc/next@10.45.2': - resolution: {integrity: sha512-RSORmfC+/nXdmRY1pQ0AalsVgSzwNAFbZLYHiTvPM5QQ8wmMEHilseCYMXpu0se/TbPt9zVR6Ka2d7O6zxKkXg==} + '@trpc/next@11.3.0': + resolution: {integrity: sha512-utpUjUUaP0MR57EBF/QOv4Xriozc/rB671lfCeriaw3MuHGvZgA4czu11UAO++9tkY/wudqvpR60r8UQHDv1ag==} peerDependencies: - '@tanstack/react-query': ^4.18.0 - '@trpc/client': 10.45.2 - '@trpc/react-query': 10.45.2 - '@trpc/server': 10.45.2 + '@tanstack/react-query': ^5.59.15 + '@trpc/client': 11.3.0 + '@trpc/react-query': 11.3.0 + '@trpc/server': 11.3.0 next: '*' react: '>=16.8.0' react-dom: '>=16.8.0' + typescript: '>=5.7.2' + peerDependenciesMeta: + '@tanstack/react-query': + optional: true + '@trpc/react-query': + optional: true '@trpc/react-query@10.45.2': resolution: {integrity: sha512-BAqb9bGZIscroradlNx+Cc9522R+idY3BOSf5z0jHUtkxdMbjeGKxSSMxxu7JzoLqSIEC+LVzL3VvF8sdDWaZQ==} @@ -7806,19 +7823,14 @@ packages: react: '>=16.8.0' react-dom: '>=16.8.0' - '@trpc/react-query@11.1.0': - resolution: {integrity: sha512-qdqKdFM8hVy/YSBCg1/3VO+IgB6Nbul3Fk1SA3lefGf0bkYZdWVVyKab8HBAfOWlMsuRufhVLPdKYmnjzBrK9g==} - peerDependencies: - '@tanstack/react-query': ^5.67.1 - '@trpc/client': 11.1.0 - '@trpc/server': 11.1.0 - react: '>=18.2.0' - react-dom: '>=18.2.0' - typescript: '>=5.7.2' - '@trpc/server@10.45.2': resolution: {integrity: sha512-wOrSThNNE4HUnuhJG6PfDRp4L2009KDVxsd+2VYH8ro6o/7/jwYZ8Uu5j+VaW+mOmc8EHerHzGcdbGNQSAUPgg==} + '@trpc/server@11.3.0': + resolution: {integrity: sha512-E5y94QLxgYr+T5FOWiMqNrCW0d2CJfhxNB7bS9xQzPfPMQeo4dpX8/s3nAxITJ9ARrORhK3Efpqs3/bPI9UOTQ==} + peerDependencies: + typescript: '>=5.7.2' + '@trustwallet/wallet-core@4.3.6': resolution: {integrity: sha512-X+n2CzDhIfUtnQtqqM3Su6XmBdzUijMu8uQEnA9yQWfOv7d33LfaQ9vNbzvpFk4I52K7n6AWLaylM1unuIuTJQ==} @@ -26990,44 +27002,41 @@ snapshots: '@tinyhttp/url@1.3.0': {} - '@trpc/client@10.45.2(@trpc/server@10.45.2)': + '@trpc/client@11.3.0(@trpc/server@10.45.2)(typescript@5.8.3)': dependencies: '@trpc/server': 10.45.2 + typescript: 5.8.3 - '@trpc/client@11.1.0(@trpc/server@10.45.2)(typescript@5.8.3)': + '@trpc/client@11.3.0(@trpc/server@11.3.0(typescript@5.8.3))(typescript@5.8.3)': dependencies: - '@trpc/server': 10.45.2 + '@trpc/server': 11.3.0(typescript@5.8.3) typescript: 5.8.3 - '@trpc/next@10.45.2(@tanstack/react-query@5.75.5(react@19.1.0))(@trpc/client@10.45.2(@trpc/server@10.45.2))(@trpc/react-query@11.1.0(@tanstack/react-query@5.75.5(react@19.1.0))(@trpc/client@10.45.2(@trpc/server@10.45.2))(@trpc/server@10.45.2)(react-dom@19.1.0(react@19.1.0))(react@19.1.0)(typescript@5.8.3))(@trpc/server@10.45.2)(next@15.3.0(@opentelemetry/api@1.9.0)(react-dom@19.1.0(react@19.1.0))(react@19.1.0)(sass@1.80.4))(react-dom@19.1.0(react@19.1.0))(react@19.1.0)': + '@trpc/next@11.3.0(@tanstack/react-query@5.75.5(react@19.1.0))(@trpc/client@11.3.0(@trpc/server@11.3.0(typescript@5.8.3))(typescript@5.8.3))(@trpc/server@11.3.0(typescript@5.8.3))(next@15.3.0(@opentelemetry/api@1.9.0)(react-dom@19.1.0(react@19.1.0))(react@19.1.0)(sass@1.80.4))(react-dom@19.1.0(react@19.1.0))(react@19.1.0)(typescript@5.8.3)': dependencies: - '@tanstack/react-query': 5.75.5(react@19.1.0) - '@trpc/client': 10.45.2(@trpc/server@10.45.2) - '@trpc/react-query': 11.1.0(@tanstack/react-query@5.75.5(react@19.1.0))(@trpc/client@10.45.2(@trpc/server@10.45.2))(@trpc/server@10.45.2)(react-dom@19.1.0(react@19.1.0))(react@19.1.0)(typescript@5.8.3) - '@trpc/server': 10.45.2 + '@trpc/client': 11.3.0(@trpc/server@11.3.0(typescript@5.8.3))(typescript@5.8.3) + '@trpc/server': 11.3.0(typescript@5.8.3) next: 15.3.0(@babel/core@7.25.2)(@opentelemetry/api@1.9.0)(react-dom@19.1.0(react@19.1.0))(react@19.1.0)(sass@1.80.4) react: 19.1.0 react-dom: 19.1.0(react@19.1.0) + typescript: 5.8.3 + optionalDependencies: + '@tanstack/react-query': 5.75.5(react@19.1.0) - '@trpc/react-query@10.45.2(@tanstack/react-query@5.75.5(react@18.3.1))(@trpc/client@11.1.0(@trpc/server@10.45.2)(typescript@5.8.3))(@trpc/server@10.45.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + '@trpc/react-query@10.45.2(@tanstack/react-query@5.75.5(react@18.3.1))(@trpc/client@11.3.0(@trpc/server@10.45.2)(typescript@5.8.3))(@trpc/server@10.45.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': dependencies: '@tanstack/react-query': 5.75.5(react@18.3.1) - '@trpc/client': 11.1.0(@trpc/server@10.45.2)(typescript@5.8.3) + '@trpc/client': 11.3.0(@trpc/server@10.45.2)(typescript@5.8.3) '@trpc/server': 10.45.2 react: 18.3.1 react-dom: 18.3.1(react@18.3.1) - '@trpc/react-query@11.1.0(@tanstack/react-query@5.75.5(react@19.1.0))(@trpc/client@10.45.2(@trpc/server@10.45.2))(@trpc/server@10.45.2)(react-dom@19.1.0(react@19.1.0))(react@19.1.0)(typescript@5.8.3)': + '@trpc/server@10.45.2': {} + + '@trpc/server@11.3.0(typescript@5.8.3)': dependencies: - '@tanstack/react-query': 5.75.5(react@19.1.0) - '@trpc/client': 10.45.2(@trpc/server@10.45.2) - '@trpc/server': 10.45.2 - react: 19.1.0 - react-dom: 19.1.0(react@19.1.0) typescript: 5.8.3 - '@trpc/server@10.45.2': {} - '@trustwallet/wallet-core@4.3.6': dependencies: protobufjs: 7.2.5 @@ -39543,10 +39552,10 @@ snapshots: trough@2.1.0: {} - trpc-chrome@1.0.0(@trpc/client@10.45.2(@trpc/server@10.45.2))(@trpc/server@10.45.2): + trpc-chrome@1.0.0(@trpc/client@11.3.0(@trpc/server@11.3.0(typescript@5.8.3))(typescript@5.8.3))(@trpc/server@11.3.0(typescript@5.8.3)): dependencies: - '@trpc/client': 10.45.2(@trpc/server@10.45.2) - '@trpc/server': 10.45.2 + '@trpc/client': 11.3.0(@trpc/server@11.3.0(typescript@5.8.3))(typescript@5.8.3) + '@trpc/server': 11.3.0(typescript@5.8.3) truncate-utf8-bytes@1.0.2: dependencies: diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 66ce268b9..a27cc127d 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -5,6 +5,7 @@ packages: - 'packages/icons' - 'packages/components' - 'packages/wallet' + - 'packages/trpc-webext' - 'apps/connector' - 'apps/portfolio' - 'apps/wallet' From d925995a5c318c25e0d241365a58da44b4293aab Mon Sep 17 00:00:00 2001 From: yqrashawn Date: Fri, 6 Jun 2025 10:09:47 +0800 Subject: [PATCH 02/11] feat: use trpc's builtin callProcedure --- packages/trpc-webext/src/adapter/index.ts | 33 +++++++++++------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/packages/trpc-webext/src/adapter/index.ts b/packages/trpc-webext/src/adapter/index.ts index e5bc95cd1..0ff59dcb5 100644 --- a/packages/trpc-webext/src/adapter/index.ts +++ b/packages/trpc-webext/src/adapter/index.ts @@ -1,5 +1,6 @@ import { isObservable } from '@trpc/server/observable' import { + callProcedure, getErrorShape, getTRPCErrorFromUnknown, TRPCError, @@ -7,7 +8,6 @@ import { import type { Unsubscribable } from '@trpc/server/observable' import type { - AnyProcedure, AnyRouter, BaseHandlerOptions, ProcedureType, @@ -66,7 +66,6 @@ function onPortMessage( trpc: { id, jsonrpc, ...response }, }) - let params: { path: string; input: unknown } | undefined let input: unknown let ctx: unknown @@ -85,23 +84,21 @@ function onPortMessage( return } - params = trpc.params as { path: string; input: unknown } - - input = transformer.input.deserialize(params.input) + input = transformer.input.deserialize(trpc.params.input) ctx = await createContext?.({ req: port, res: undefined, }) - const caller = router.createCaller(ctx) - - const segments = params.path.split('.') - const procedureFn = segments.reduce( - (acc, segment) => acc[segment], - caller as unknown, - ) as AnyProcedure - const result = await procedureFn(input) + const result = await callProcedure({ + router, + path: trpc.params.path, + getRawInput: async () => input, + ctx, + type: method, + signal: undefined, + }) if (method !== 'subscription') { const data = transformer.output.serialize(result) @@ -116,7 +113,7 @@ function onPortMessage( if (!isObservable(result)) { throw new TRPCError({ - message: `Subscription ${params.path} did not return an observable`, + message: `Subscription ${trpc.params.path} did not return an observable`, code: 'INTERNAL_SERVER_ERROR', }) } @@ -136,7 +133,7 @@ function onPortMessage( onError?.({ error, type: method as ProcedureType, - path: params?.path, + path: trpc.params?.path, input, ctx, req: port, @@ -146,7 +143,7 @@ function onPortMessage( error: getErrorShape({ error, type: method as ProcedureType, - path: params?.path, + path: trpc.params?.path, input, ctx, config: router._def._config, @@ -189,7 +186,7 @@ function onPortMessage( onError?.({ error, type: method as ProcedureType, - path: params?.path, + path: 'params' in trpc ? trpc.params?.path : undefined, input, ctx, req: port, @@ -199,7 +196,7 @@ function onPortMessage( error: getErrorShape({ error, type: method as ProcedureType, - path: params?.path, + path: 'params' in trpc ? trpc.params?.path : undefined, input, ctx, config: router._def._config, From 7432ebc90c53d24651b2d012733743b3b3e9443b Mon Sep 17 00:00:00 2001 From: yqrashawn Date: Fri, 6 Jun 2025 10:33:14 +0800 Subject: [PATCH 03/11] feat: refactor to use more trpc types --- packages/trpc-webext/src/adapter/index.ts | 467 ++++++++++++++-------- packages/trpc-webext/src/link/index.ts | 148 ++++--- 2 files changed, 407 insertions(+), 208 deletions(-) diff --git a/packages/trpc-webext/src/adapter/index.ts b/packages/trpc-webext/src/adapter/index.ts index 0ff59dcb5..eabcd9db7 100644 --- a/packages/trpc-webext/src/adapter/index.ts +++ b/packages/trpc-webext/src/adapter/index.ts @@ -16,174 +16,321 @@ import type { } from '@trpc/server/unstable-core-do-not-import' import type { Runtime } from 'webextension-polyfill' -export type CreateWebExtContextOptions = { +export interface CreateWebExtContextOptions { req: Runtime.Port res: undefined } -export type CreateWebExtHandlerOptions = - BaseHandlerOptions & { - createContext?: ( - opts: CreateWebExtContextOptions, - ) => Promise | unknown - runtime: Runtime.Static - } +export interface CreateWebExtHandlerOptions + extends BaseHandlerOptions { + createContext?: ( + opts: CreateWebExtContextOptions, + ) => Promise | unknown + runtime: Runtime.Static +} -type PortInfo = { +interface ClientMessage { + trpc: TRPCClientOutgoingMessage +} + +interface PortSubscriptions { subscriptions: Map } -type PortInfos = Map +interface TRPCResponseSender { + (response: Omit): void +} -const PORTS: PortInfos = new Map() +interface MessageHandlerContext { + router: TRouter + createContext?: CreateWebExtHandlerOptions['createContext'] + onError?: CreateWebExtHandlerOptions['onError'] + transformer: TRouter['_def']['_config']['transformer'] +} -type ClientMessage = { - trpc: TRPCClientOutgoingMessage +function isClientMessage(message: unknown): message is ClientMessage { + return ( + typeof message === 'object' && + message !== null && + 'trpc' in message && + typeof (message as ClientMessage).trpc === 'object' + ) } -function onPortMessage( - opts: CreateWebExtHandlerOptions, -) { - const { router, createContext, onError } = opts - const { transformer } = router._def._config +function hasTRPCId( + trpc: TRPCClientOutgoingMessage, +): trpc is TRPCClientOutgoingMessage & { + id: NonNullable +} { + return 'id' in trpc && trpc.id !== null && trpc.id !== undefined +} - return async (message: unknown, port: Runtime.Port) => { - if (!(message as ClientMessage)?.trpc) return - const clientMessage = message as ClientMessage - const { trpc } = clientMessage - if (!('id' in trpc) || trpc.id === null || trpc.id === undefined) return +function isSubscriptionStop( + trpc: TRPCClientOutgoingMessage, +): trpc is TRPCClientOutgoingMessage & { method: 'subscription.stop' } { + return 'method' in trpc && trpc.method === 'subscription.stop' +} - const portInfo = PORTS.get(port) - if (!portInfo) return +function isSubscriptionMethod(method: string): method is 'subscription' { + return method === 'subscription' +} - const { subscriptions } = portInfo - const { id, jsonrpc, method } = trpc +class PortConnectionManager { + private readonly portSubscriptions = new Map< + Runtime.Port, + PortSubscriptions + >() - const sendResponse = ( - response: Omit, - ) => - port.postMessage({ - trpc: { id, jsonrpc, ...response }, - }) + addPort(port: Runtime.Port): void { + const portInfo: PortSubscriptions = { + subscriptions: new Map(), + } + this.portSubscriptions.set(port, portInfo) + } - let input: unknown - let ctx: unknown + getPortSubscriptions(port: Runtime.Port): PortSubscriptions | undefined { + return this.portSubscriptions.get(port) + } - try { - if (method === 'subscription.stop') { - const subscription = subscriptions.get(id) - if (subscription) { - subscription.unsubscribe() - sendResponse({ - result: { - type: 'stopped', - }, - }) - } - subscriptions.delete(id) - return - } + removePort(port: Runtime.Port): void { + const portInfo = this.portSubscriptions.get(port) + if (portInfo) { + portInfo.subscriptions.forEach(subscription => subscription.unsubscribe()) + this.portSubscriptions.delete(port) + } + } - input = transformer.input.deserialize(trpc.params.input) + addSubscription( + port: Runtime.Port, + id: number | string, + subscription: Unsubscribable, + ): boolean { + const portInfo = this.portSubscriptions.get(port) + if (!portInfo) return false + + if (portInfo.subscriptions.has(id)) { + const existingSub = portInfo.subscriptions.get(id) + existingSub?.unsubscribe() + } - ctx = await createContext?.({ - req: port, - res: undefined, - }) + portInfo.subscriptions.set(id, subscription) + return true + } - const result = await callProcedure({ - router, - path: trpc.params.path, - getRawInput: async () => input, - ctx, - type: method, - signal: undefined, - }) + removeSubscription(port: Runtime.Port, id: number | string): boolean { + const portInfo = this.portSubscriptions.get(port) + if (!portInfo) return false - if (method !== 'subscription') { - const data = transformer.output.serialize(result) - sendResponse({ - result: { - type: 'data', - data, - }, - }) - return - } + const subscription = portInfo.subscriptions.get(id) + if (subscription) { + subscription.unsubscribe() + portInfo.subscriptions.delete(id) + return true + } + return false + } +} - if (!isObservable(result)) { - throw new TRPCError({ - message: `Subscription ${trpc.params.path} did not return an observable`, - code: 'INTERNAL_SERVER_ERROR', - }) - } +function createResponseSender( + port: Runtime.Port, + id: NonNullable, + jsonrpc: TRPCClientOutgoingMessage['jsonrpc'], +): TRPCResponseSender { + return response => { + port.postMessage({ + trpc: { id, jsonrpc, ...response }, + }) + } +} - const subscription = result.subscribe({ - next: data => { - sendResponse({ - result: { - type: 'data', - data, - }, - }) - }, - error: cause => { - const error = getTRPCErrorFromUnknown(cause) - - onError?.({ - error, - type: method as ProcedureType, - path: trpc.params?.path, - input, - ctx, - req: port, - }) - - sendResponse({ - error: getErrorShape({ - error, - type: method as ProcedureType, - path: trpc.params?.path, - input, - ctx, - config: router._def._config, - }), - }) - }, - complete: () => { - sendResponse({ - result: { - type: 'stopped', - }, - }) +async function handleSubscriptionStop( + portManager: PortConnectionManager, + port: Runtime.Port, + id: NonNullable, + sendResponse: TRPCResponseSender, +): Promise { + const removed = portManager.removeSubscription(port, id) + if (removed) { + sendResponse({ + result: { + type: 'stopped', + }, + }) + } +} + +async function handleRegularProcedure( + context: MessageHandlerContext, + trpc: TRPCClientOutgoingMessage, + port: Runtime.Port, + sendResponse: TRPCResponseSender, +): Promise { + if (!('params' in trpc)) { + throw new TRPCError({ + message: 'Missing params in request', + code: 'BAD_REQUEST', + }) + } + + const input = context.transformer.input.deserialize(trpc.params.input) + const ctx = await context.createContext?.({ + req: port, + res: undefined, + }) + + const result = await callProcedure({ + router: context.router, + path: trpc.params.path, + getRawInput: async () => input, + ctx, + type: trpc.method as ProcedureType, + signal: undefined, + }) + + const serializedData = context.transformer.output.serialize(result) + sendResponse({ + result: { + type: 'data', + data: serializedData, + }, + }) +} + +async function handleSubscription( + context: MessageHandlerContext, + portManager: PortConnectionManager, + trpc: TRPCClientOutgoingMessage, + port: Runtime.Port, + id: NonNullable, + sendResponse: TRPCResponseSender, +): Promise { + if (!('params' in trpc)) { + throw new TRPCError({ + message: 'Missing params in subscription request', + code: 'BAD_REQUEST', + }) + } + + const input = context.transformer.input.deserialize(trpc.params.input) + const ctx = await context.createContext?.({ + req: port, + res: undefined, + }) + + const result = await callProcedure({ + router: context.router, + path: trpc.params.path, + getRawInput: async () => input, + ctx, + type: 'subscription', + signal: undefined, + }) + + if (!isObservable(result)) { + throw new TRPCError({ + message: `Subscription ${trpc.params.path} did not return an observable`, + code: 'INTERNAL_SERVER_ERROR', + }) + } + + const subscription = result.subscribe({ + next: data => { + sendResponse({ + result: { + type: 'data', + data, }, }) + }, + error: cause => { + const error = getTRPCErrorFromUnknown(cause) - if (subscriptions.has(id)) { - subscription.unsubscribe() - sendResponse({ - result: { - type: 'stopped', - }, - }) - throw new TRPCError({ - message: `Duplicate id ${id}`, - code: 'BAD_REQUEST', - }) - } - - subscriptions.set(id, subscription) + context.onError?.({ + error, + type: 'subscription', + path: trpc.params?.path, + input, + ctx, + req: port, + }) + sendResponse({ + error: getErrorShape({ + error, + type: 'subscription', + path: trpc.params?.path, + input, + ctx, + config: context.router._def._config, + }), + }) + }, + complete: () => { sendResponse({ result: { - type: 'started', + type: 'stopped', }, }) - return + }, + }) + + const subscriptionAdded = portManager.addSubscription(port, id, subscription) + if (!subscriptionAdded) { + subscription.unsubscribe() + throw new TRPCError({ + message: 'Failed to register subscription', + code: 'INTERNAL_SERVER_ERROR', + }) + } + + sendResponse({ + result: { + type: 'started', + }, + }) +} + +function createMessageHandler( + context: MessageHandlerContext, + portManager: PortConnectionManager, +) { + return async (message: unknown, port: Runtime.Port): Promise => { + if (!isClientMessage(message)) return + + const { trpc } = message + if (!hasTRPCId(trpc)) return + + const portInfo = portManager.getPortSubscriptions(port) + if (!portInfo) return + + const { id, jsonrpc, method } = trpc + const sendResponse = createResponseSender(port, id, jsonrpc) + + let input: unknown + let ctx: unknown + + try { + if (isSubscriptionStop(trpc)) { + await handleSubscriptionStop(portManager, port, id, sendResponse) + return + } + + if (isSubscriptionMethod(method)) { + await handleSubscription( + context, + portManager, + trpc, + port, + id, + sendResponse, + ) + } else { + await handleRegularProcedure(context, trpc, port, sendResponse) + } } catch (cause) { const error = getTRPCErrorFromUnknown(cause) - onError?.({ + context.onError?.({ error, type: method as ProcedureType, path: 'params' in trpc ? trpc.params?.path : undefined, @@ -199,39 +346,34 @@ function onPortMessage( path: 'params' in trpc ? trpc.params?.path : undefined, input, ctx, - config: router._def._config, + config: context.router._def._config, }), }) } } } -function onPortDisconnect( - onMessage: ReturnType>, +function createDisconnectHandler( + portManager: PortConnectionManager, + onMessage: ReturnType>, ) { - return (port: Runtime.Port) => { - const portInfo = PORTS.get(port) - if (!portInfo) return + return (port: Runtime.Port): void => { port.onMessage.removeListener(onMessage) - - const { subscriptions } = portInfo - subscriptions.forEach(sub => sub.unsubscribe()) - - PORTS.delete(port) + portManager.removePort(port) } } -function onPortConnect( - opts: CreateWebExtHandlerOptions, +function createConnectHandler( + context: MessageHandlerContext, + portManager: PortConnectionManager, ) { - return (port: Runtime.Port) => { - const portInfo: PortInfo = { - subscriptions: new Map(), - } + return (port: Runtime.Port): void => { + portManager.addPort(port) - PORTS.set(port, portInfo) - const onMessage = onPortMessage(opts) - port.onDisconnect.addListener(onPortDisconnect(onMessage)) + const onMessage = createMessageHandler(context, portManager) + const onDisconnect = createDisconnectHandler(portManager, onMessage) + + port.onDisconnect.addListener(onDisconnect) port.onMessage.addListener(onMessage) } } @@ -253,8 +395,21 @@ function onPortConnect( * }); * = */ -export const createWebExtHandler = ( +export function createWebExtHandler( opts: CreateWebExtHandlerOptions, -) => { - opts.runtime.onConnect.addListener(onPortConnect(opts)) +): void { + const { router, createContext, onError, runtime } = opts + const { transformer } = router._def._config + + const context: MessageHandlerContext = { + router, + createContext, + onError, + transformer, + } + + const portManager = new PortConnectionManager() + const onConnect = createConnectHandler(context, portManager) + + runtime.onConnect.addListener(onConnect) } diff --git a/packages/trpc-webext/src/link/index.ts b/packages/trpc-webext/src/link/index.ts index 4b5c62aca..76bc0053f 100644 --- a/packages/trpc-webext/src/link/index.ts +++ b/packages/trpc-webext/src/link/index.ts @@ -1,129 +1,172 @@ import { TRPCClientError } from '@trpc/client' import { observable } from '@trpc/server/observable' -import type { Operation, TRPCLink } from '@trpc/client' +import type { Operation, OperationResultEnvelope, TRPCLink } from '@trpc/client' import type { AnyTRPCRouter } from '@trpc/server' import type { Observer } from '@trpc/server/observable' import type { TRPCResponseMessage } from '@trpc/server/rpc' import type { DataTransformer } from '@trpc/server/unstable-core-do-not-import' import type { Runtime } from 'webextension-polyfill' -export type WebExtensionLinkOptions = { +export interface WebExtensionLinkOptions { runtime: Runtime.Static timeoutMS?: number transformer: DataTransformer } -export type BackgroundMessage = { +export interface BackgroundMessage { trpc: TRPCResponseMessage } -let portToBackground: Runtime.Port | null = null +type ResultListener = { + timestamp: number + observer: Observer< + OperationResultEnvelope>, + TRPCClientError + > + type: Operation['type'] +} interface ResultListeners { - [id: number]: { - timestamp: number // used to cleanup listeners - observer: Observer - type: Operation['type'] - } + [id: number]: ResultListener } +let portToBackground: Runtime.Port | null = null const resultListeners: ResultListeners = {} -function connectToBackground({ runtime }: WebExtensionLinkOptions) { - if (!portToBackground) portToBackground = runtime.connect() +function connectToBackground({ runtime }: WebExtensionLinkOptions): void { + if (!portToBackground) { + portToBackground = runtime.connect() + } +} + +function isBackgroundMessage(message: unknown): message is BackgroundMessage { + return ( + typeof message === 'object' && + message !== null && + 'trpc' in message && + typeof (message as any).trpc === 'object' + ) } -function portOnMessageFromBackground( - transformer: WebExtensionLinkOptions['transformer'], -) { +function isTRPCResponseWithId( + trpc: TRPCResponseMessage, +): trpc is TRPCResponseMessage & { id: number } { + return ( + 'id' in trpc && + trpc.id !== null && + trpc.id !== undefined && + typeof trpc.id === 'number' + ) +} + +function createPortMessageHandler( + transformer: DataTransformer, +): (message: unknown) => void { return (message: unknown) => { - if (!(message as BackgroundMessage)?.trpc) return - const backgroundMessage = message as BackgroundMessage - const { trpc } = backgroundMessage - if (!('id' in trpc) || trpc.id === null || trpc.id === undefined) return - if (!(trpc.id in resultListeners)) return - const { observer, type } = resultListeners[trpc.id as number] + if (!isBackgroundMessage(message)) return + + const { trpc } = message + if (!isTRPCResponseWithId(trpc)) return + + const listener = resultListeners[trpc.id] + if (!listener) return + + const { observer, type } = listener if ('error' in trpc) { - // Check if it's already a SuperJSONResult or needs deserialization - const error = - typeof trpc.error === 'object' && 'json' in trpc.error - ? transformer.deserialize(trpc.error) - : trpc.error + // Handle error response + const error = shouldDeserialize(trpc.error) + ? transformer.deserialize(trpc.error) + : trpc.error + observer.error(TRPCClientError.from({ ...trpc, error })) return } + // Handle success response observer.next({ result: { ...trpc.result, ...((!trpc.result.type || trpc.result.type === 'data') && { - type: 'data', + type: 'data' as const, data: transformer.deserialize(trpc.result.data), }), - } as unknown, + }, }) + // Complete for non-subscription or stopped subscription if (type !== 'subscription' || trpc.result.type === 'stopped') { observer.complete() } } } +function shouldDeserialize(error: unknown): error is { json: unknown } { + return typeof error === 'object' && error !== null && 'json' in error +} + let clearListenersIntervalId: ReturnType | undefined -function clearListenersIntervalFn(timeoutMS: number) { +function createListenerCleaner(timeoutMS: number): () => void { return () => { - const timedOutAt = new Date().getTime() - timeoutMS + const timedOutAt = Date.now() - timeoutMS - for (const id in resultListeners) { - if (resultListeners[id].timestamp < timedOutAt) { - delete resultListeners[id] + for (const [id, listener] of Object.entries(resultListeners)) { + if (listener.timestamp < timedOutAt) { + delete resultListeners[Number(id)] } } } } -function setupClearListenersInterval(timeoutMS = 10000) { - if (clearListenersIntervalId) clearInterval(clearListenersIntervalId) - clearListenersIntervalId = setInterval( - clearListenersIntervalFn(timeoutMS), - 1000, - ) +function setupClearListenersInterval(timeoutMS = 10000): void { + if (clearListenersIntervalId) { + clearInterval(clearListenersIntervalId) + } + + clearListenersIntervalId = setInterval(createListenerCleaner(timeoutMS), 1000) +} + +function createPortDisconnectHandler( + port: Runtime.Port, + onMessage: (message: unknown) => void, +): (disconnectedPort: Runtime.Port) => void { + return (disconnectedPort: Runtime.Port) => { + disconnectedPort.onDisconnect.removeListener( + createPortDisconnectHandler(port, onMessage), + ) + disconnectedPort.onMessage.removeListener(onMessage) + portToBackground = null + } } export function webExtensionLink( opts: WebExtensionLinkOptions, ): TRPCLink { - const { timeoutMS, transformer } = opts + const { timeoutMS = 10000, transformer } = opts setupClearListenersInterval(timeoutMS) - connectToBackground(opts) - const onMessage = portOnMessageFromBackground(transformer) - const portOnDisconnect = (port: Runtime.Port) => { - port.onDisconnect.removeListener(portOnDisconnect) - port.onMessage.removeListener(onMessage) - portToBackground = null - } + const onMessage = createPortMessageHandler(transformer) + const onDisconnect = createPortDisconnectHandler(portToBackground!, onMessage) - portToBackground?.onDisconnect.addListener(portOnDisconnect) + portToBackground?.onDisconnect.addListener(onDisconnect) portToBackground?.onMessage.addListener(onMessage) return () => { return ({ op }) => { - const { id, type, path } = op + const { id, type, path, input } = op - const input = transformer.serialize(op.input) || op.input + const serializedInput = transformer.serialize(input) ?? input const trpcPayload = { id, jsonrpc: undefined, method: type, - params: { path, input }, - } + params: { path, input: serializedInput }, + } as const const postMessagePayload = { trpc: trpcPayload, @@ -133,8 +176,9 @@ export function webExtensionLink( resultListeners[id] = { observer, type, - timestamp: new Date().getTime(), + timestamp: Date.now(), } + portToBackground?.postMessage(postMessagePayload) return () => { From 2a12a5ed234191871a3697dba2fb1ecdc3aeac4e Mon Sep 17 00:00:00 2001 From: yqrashawn Date: Fri, 6 Jun 2025 13:26:39 +0800 Subject: [PATCH 04/11] feat: add unit tests --- packages/trpc-webext/package.json | 7 +- .../trpc-webext/src/adapter/index.test.ts | 499 ++++++++++ packages/trpc-webext/src/link/index.test.ts | 909 ++++++++++++++++++ packages/trpc-webext/src/link/index.ts | 4 +- packages/trpc-webext/test-setup.ts | 12 + packages/trpc-webext/vitest.config.ts | 9 + pnpm-lock.yaml | 3 + 7 files changed, 1439 insertions(+), 4 deletions(-) create mode 100644 packages/trpc-webext/src/adapter/index.test.ts create mode 100644 packages/trpc-webext/src/link/index.test.ts create mode 100644 packages/trpc-webext/test-setup.ts create mode 100644 packages/trpc-webext/vitest.config.ts diff --git a/packages/trpc-webext/package.json b/packages/trpc-webext/package.json index c6ae9400d..186f8e39a 100644 --- a/packages/trpc-webext/package.json +++ b/packages/trpc-webext/package.json @@ -37,13 +37,16 @@ "dev": "tsc -w", "build": "tsc", "lint": "eslint src", - "format": "prettier --write ." + "format": "prettier --write .", + "test": "vitest", + "test:watch": "vitest --watch" }, "peerDependencies": { "@trpc/client": "^11.0.0", "@trpc/server": "^11.0.0" }, "devDependencies": { - "@types/webextension-polyfill": "^0.12.3" + "@types/webextension-polyfill": "^0.12.3", + "zod": "^3.23.8" } } diff --git a/packages/trpc-webext/src/adapter/index.test.ts b/packages/trpc-webext/src/adapter/index.test.ts new file mode 100644 index 000000000..2fc14ab3d --- /dev/null +++ b/packages/trpc-webext/src/adapter/index.test.ts @@ -0,0 +1,499 @@ +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest' +import { initTRPC, TRPCError } from '@trpc/server' +import { observable } from '@trpc/server/observable' +import { z } from 'zod' +import { createWebExtHandler } from './' +import type { Runtime } from 'webextension-polyfill' + +// Mock runtime +const mockRuntime = { + onConnect: { + addListener: vi.fn(), + removeListener: vi.fn(), + hasListener: vi.fn(), + }, +} as unknown as Runtime.Static + +// Mock port +const createMockPort = (): Runtime.Port => ({ + name: 'test-port', + postMessage: vi.fn(), + onMessage: { + addListener: vi.fn(), + removeListener: vi.fn(), + hasListener: vi.fn(), + }, + onDisconnect: { + addListener: vi.fn(), + removeListener: vi.fn(), + hasListener: vi.fn(), + }, + disconnect: vi.fn(), + sender: undefined, + error: undefined, +}) + +// Test router setup +const t = initTRPC.create() +const testRouter = t.router({ + greeting: t.procedure + .input(z.object({ name: z.string() })) + .query(({ input }) => `Hello ${input.name}!`), + + count: t.procedure + .input(z.object({ start: z.number() })) + .subscription(({ input }) => { + return observable(emit => { + let count = input.start + const timer = setInterval(() => { + emit.next(count++) + if (count > input.start + 3) { + emit.complete() + } + }, 100) + return () => clearInterval(timer) + }) + }), + + error: t.procedure.query(() => { + throw new TRPCError({ + code: 'INTERNAL_SERVER_ERROR', + message: 'Test error', + }) + }), + + mutation: t.procedure + .input(z.object({ value: z.string() })) + .mutation(({ input }) => ({ result: input.value.toUpperCase() })), +}) + +describe('createWebExtHandler', () => { + let mockPort: Runtime.Port + let onConnectListener: (port: Runtime.Port) => void + let onMessageListener: ( + message: unknown, + port: Runtime.Port, + ) => Promise | void + let onDisconnectListener: (port: Runtime.Port) => void + + beforeEach(() => { + vi.clearAllMocks() + mockPort = createMockPort() + + // Capture the listeners when they're added + ;(mockRuntime.onConnect.addListener as Mock).mockImplementation( + listener => { + onConnectListener = listener + }, + ) + ;(mockPort.onMessage.addListener as Mock).mockImplementation(listener => { + onMessageListener = listener + }) + ;(mockPort.onDisconnect.addListener as Mock).mockImplementation( + listener => { + onDisconnectListener = listener + }, + ) + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + describe('Handler Setup', () => { + it('should register onConnect listener', () => { + createWebExtHandler({ + router: testRouter, + runtime: mockRuntime, + }) + + expect(mockRuntime.onConnect.addListener).toHaveBeenCalledOnce() + expect(typeof onConnectListener).toBe('function') + }) + + it('should handle port connection', () => { + createWebExtHandler({ + router: testRouter, + runtime: mockRuntime, + }) + + onConnectListener(mockPort) + + expect(mockPort.onMessage.addListener).toHaveBeenCalledOnce() + expect(mockPort.onDisconnect.addListener).toHaveBeenCalledOnce() + }) + + it('should clean up on port disconnection', () => { + createWebExtHandler({ + router: testRouter, + runtime: mockRuntime, + }) + + onConnectListener(mockPort) + onDisconnectListener(mockPort) + + expect(mockPort.onMessage.removeListener).toHaveBeenCalledOnce() + }) + }) + + describe('Message Handling', () => { + beforeEach(() => { + createWebExtHandler({ + router: testRouter, + runtime: mockRuntime, + }) + onConnectListener(mockPort) + }) + + it('should ignore non-tRPC messages', async () => { + await onMessageListener({ notTrpc: true }, mockPort) + expect(mockPort.postMessage).not.toHaveBeenCalled() + }) + + it('should ignore messages without ID', async () => { + await onMessageListener( + { + trpc: { method: 'query', params: { path: 'greeting' } }, + }, + mockPort, + ) + expect(mockPort.postMessage).not.toHaveBeenCalled() + }) + + it('should handle query procedure', async () => { + const message = { + trpc: { + id: 1, + jsonrpc: '2.0' as const, + method: 'query' as const, + params: { + path: 'greeting', + input: { name: 'World' }, + }, + }, + } + + await onMessageListener(message, mockPort) + + expect(mockPort.postMessage).toHaveBeenCalledWith({ + trpc: { + id: 1, + jsonrpc: '2.0', + result: { + type: 'data', + data: 'Hello World!', + }, + }, + }) + }) + + it('should handle mutation procedure', async () => { + const message = { + trpc: { + id: 2, + jsonrpc: '2.0' as const, + method: 'mutation' as const, + params: { + path: 'mutation', + input: { value: 'test' }, + }, + }, + } + + await onMessageListener(message, mockPort) + + expect(mockPort.postMessage).toHaveBeenCalledWith({ + trpc: { + id: 2, + jsonrpc: '2.0', + result: { + type: 'data', + data: { result: 'TEST' }, + }, + }, + }) + }) + + it('should handle procedure errors', async () => { + const message = { + trpc: { + id: 3, + jsonrpc: '2.0' as const, + method: 'query' as const, + params: { + path: 'error', + input: null, + }, + }, + } + + await onMessageListener(message, mockPort) + + expect(mockPort.postMessage).toHaveBeenCalledWith({ + trpc: { + id: 3, + jsonrpc: '2.0', + error: expect.objectContaining({ + code: -32603, + message: 'Test error', + }), + }, + }) + }) + + it('should handle missing params', async () => { + const message = { + trpc: { + id: 4, + jsonrpc: '2.0' as const, + method: 'query' as const, + // Missing params + }, + } + + await onMessageListener(message, mockPort) + + expect(mockPort.postMessage).toHaveBeenCalledWith({ + trpc: { + id: 4, + jsonrpc: '2.0', + error: expect.objectContaining({ + code: -32600, + message: 'Missing params in request', + }), + }, + }) + }) + }) + + describe('Subscription Handling', () => { + beforeEach(() => { + createWebExtHandler({ + router: testRouter, + runtime: mockRuntime, + }) + onConnectListener(mockPort) + }) + + it('should handle subscription start', async () => { + const message = { + trpc: { + id: 5, + jsonrpc: '2.0' as const, + method: 'subscription' as const, + params: { + path: 'count', + input: { start: 0 }, + }, + }, + } + + await onMessageListener(message, mockPort) + + // Should send started message + expect(mockPort.postMessage).toHaveBeenCalledWith({ + trpc: { + id: 5, + jsonrpc: '2.0', + result: { + type: 'started', + }, + }, + }) + + // Wait for subscription data + await new Promise(resolve => setTimeout(resolve, 150)) + + // Should have sent data messages + expect(mockPort.postMessage).toHaveBeenCalledWith( + expect.objectContaining({ + trpc: expect.objectContaining({ + id: 5, + result: { + type: 'data', + data: 0, + }, + }), + }), + ) + }) + + it('should handle subscription stop', async () => { + // Start subscription first + const startMessage = { + trpc: { + id: 6, + jsonrpc: '2.0' as const, + method: 'subscription' as const, + params: { + path: 'count', + input: { start: 0 }, + }, + }, + } + + await onMessageListener(startMessage, mockPort) + + // Stop subscription + const stopMessage = { + trpc: { + id: 6, + jsonrpc: '2.0' as const, + method: 'subscription.stop' as const, + }, + } + + await onMessageListener(stopMessage, mockPort) + + expect(mockPort.postMessage).toHaveBeenCalledWith({ + trpc: { + id: 6, + jsonrpc: '2.0', + result: { + type: 'stopped', + }, + }, + }) + }) + + it('should handle subscription with missing params', async () => { + const message = { + trpc: { + id: 7, + jsonrpc: '2.0' as const, + method: 'subscription' as const, + // Missing params + }, + } + + await onMessageListener(message, mockPort) + + expect(mockPort.postMessage).toHaveBeenCalledWith({ + trpc: { + id: 7, + jsonrpc: '2.0', + error: expect.objectContaining({ + code: -32600, + message: 'Missing params in subscription request', + }), + }, + }) + }) + }) + + describe('Context Creation', () => { + it('should call createContext with port', async () => { + const createContext = vi.fn().mockResolvedValue({ userId: 'test' }) + + createWebExtHandler({ + router: testRouter, + runtime: mockRuntime, + createContext, + }) + + onConnectListener(mockPort) + + const message = { + trpc: { + id: 8, + jsonrpc: '2.0' as const, + method: 'query' as const, + params: { + path: 'greeting', + input: { name: 'World' }, + }, + }, + } + + await onMessageListener(message, mockPort) + + expect(createContext).toHaveBeenCalledWith({ + req: mockPort, + res: undefined, + }) + }) + }) + + describe('Error Handling', () => { + it('should call onError callback', async () => { + const onError = vi.fn() + + createWebExtHandler({ + router: testRouter, + runtime: mockRuntime, + onError, + }) + + onConnectListener(mockPort) + + const message = { + trpc: { + id: 9, + jsonrpc: '2.0' as const, + method: 'query' as const, + params: { + path: 'error', + input: null, + }, + }, + } + + await onMessageListener(message, mockPort) + + expect(onError).toHaveBeenCalledWith( + expect.objectContaining({ + error: expect.any(TRPCError), + type: 'query', + path: 'error', + }), + ) + }) + }) + + describe('Port Connection Manager', () => { + it('should clean up subscriptions on disconnect', async () => { + createWebExtHandler({ + router: testRouter, + runtime: mockRuntime, + }) + + onConnectListener(mockPort) + + // Start a subscription + const message = { + trpc: { + id: 10, + jsonrpc: '2.0' as const, + method: 'subscription' as const, + params: { + path: 'count', + input: { start: 0 }, + }, + }, + } + + await onMessageListener(message, mockPort) + + // Disconnect port + onDisconnectListener(mockPort) + + // Subscription should be cleaned up (no more messages) + const initialCallCount = (mockPort.postMessage as Mock).mock.calls.length + + await new Promise(resolve => setTimeout(resolve, 200)) + + // Should not have sent more messages after disconnect + expect((mockPort.postMessage as Mock).mock.calls.length).toBe( + initialCallCount, + ) + }) + }) +}) diff --git a/packages/trpc-webext/src/link/index.test.ts b/packages/trpc-webext/src/link/index.test.ts new file mode 100644 index 000000000..49c983f97 --- /dev/null +++ b/packages/trpc-webext/src/link/index.test.ts @@ -0,0 +1,909 @@ +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest' +import { TRPCClientError } from '@trpc/client' +import { webExtensionLink } from './' +import type { Runtime } from 'webextension-polyfill' +import type { Operation } from '@trpc/client' + +// Mock runtime +const mockRuntime = { + connect: vi.fn(), +} as unknown as Runtime.Static + +// Mock port +const createMockPort = (): Runtime.Port => ({ + name: 'test-port', + postMessage: vi.fn(), + onMessage: { + addListener: vi.fn(), + removeListener: vi.fn(), + hasListener: vi.fn(), + }, + onDisconnect: { + addListener: vi.fn(), + removeListener: vi.fn(), + hasListener: vi.fn(), + }, + disconnect: vi.fn(), + sender: undefined, + error: undefined, +}) + +// Mock transformer +const mockTransformer = { + serialize: vi.fn(data => data), + deserialize: vi.fn(data => data), +} + +describe('webExtensionLink', () => { + let mockPort: Runtime.Port + let onMessageListener: (message: unknown) => void + let onDisconnectListener: (port: Runtime.Port) => void + + beforeEach(() => { + vi.clearAllMocks() + vi.clearAllTimers() + + mockPort = createMockPort() + + // Setup runtime.connect to return our mock port + ;(mockRuntime.connect as Mock).mockReturnValue(mockPort) + + // Capture listeners when they're added + ;(mockPort.onMessage.addListener as Mock).mockImplementation(listener => { + onMessageListener = listener + }) + ;(mockPort.onDisconnect.addListener as Mock).mockImplementation( + listener => { + onDisconnectListener = listener + }, + ) + + // Reset transformer mocks + mockTransformer.serialize.mockImplementation(data => data) + mockTransformer.deserialize.mockImplementation(data => data) + }) + + afterEach(() => { + // Trigger disconnect to reset port state + if (onDisconnectListener && mockPort) { + onDisconnectListener(mockPort) + } + + vi.restoreAllMocks() + vi.clearAllTimers() + }) + + describe('Link Creation', () => { + it('should create a link function', () => { + const link = webExtensionLink({ + runtime: mockRuntime, + transformer: mockTransformer, + }) + + expect(typeof link).toBe('function') + }) + + it('should connect to background script', () => { + webExtensionLink({ + runtime: mockRuntime, + transformer: mockTransformer, + }) + + expect(mockRuntime.connect).toHaveBeenCalledOnce() + }) + + it('should setup port listeners', () => { + webExtensionLink({ + runtime: mockRuntime, + transformer: mockTransformer, + }) + + expect(mockPort.onMessage.addListener).toHaveBeenCalledOnce() + expect(mockPort.onDisconnect.addListener).toHaveBeenCalledOnce() + }) + + it('should setup cleanup interval with default timeout', () => { + vi.useFakeTimers() + + webExtensionLink({ + runtime: mockRuntime, + transformer: mockTransformer, + }) + + expect(vi.getTimerCount()).toBe(1) + + vi.useRealTimers() + }) + + it('should setup cleanup interval with custom timeout', () => { + vi.useFakeTimers() + + webExtensionLink({ + runtime: mockRuntime, + transformer: mockTransformer, + timeoutMS: 5000, + }) + + expect(vi.getTimerCount()).toBe(1) + + vi.useRealTimers() + }) + }) + + describe('Operation Handling', () => { + let link: ReturnType + let operationLink: ReturnType> + + beforeEach(() => { + link = webExtensionLink({ + runtime: mockRuntime, + transformer: mockTransformer, + }) + operationLink = link() + }) + + it('should serialize input before sending', () => { + const operation: Operation = { + id: 1, + type: 'query', + path: 'test.procedure', + input: { name: 'test' }, + context: {}, + signal: null, + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + // Subscribe to trigger the operation + observable.subscribe({}) + + expect(mockTransformer.serialize).toHaveBeenCalledWith({ name: 'test' }) + expect(mockPort.postMessage).toHaveBeenCalledWith({ + trpc: { + id: 1, + jsonrpc: undefined, + method: 'query', + params: { + path: 'test.procedure', + input: { name: 'test' }, + }, + }, + }) + }) + + it('should handle null input', () => { + const operation: Operation = { + id: 2, + type: 'mutation', + path: 'test.procedure', + input: null, + context: {}, + signal: null, + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe({}) + + expect(mockTransformer.serialize).toHaveBeenCalledWith(null) + expect(mockPort.postMessage).toHaveBeenCalledWith({ + trpc: { + id: 2, + jsonrpc: undefined, + method: 'mutation', + params: { + path: 'test.procedure', + input: null, + }, + }, + }) + }) + + it('should handle subscription operations', () => { + const operation: Operation = { + id: 3, + type: 'subscription', + path: 'test.subscription', + input: { topic: 'test' }, + context: {}, + signal: null, + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe({}) + + expect(mockPort.postMessage).toHaveBeenCalledWith({ + trpc: { + id: 3, + jsonrpc: undefined, + method: 'subscription', + params: { + path: 'test.subscription', + input: { topic: 'test' }, + }, + }, + }) + }) + }) + + describe('Message Response Handling', () => { + let link: ReturnType + let operationLink: ReturnType> + + beforeEach(() => { + link = webExtensionLink({ + runtime: mockRuntime, + transformer: mockTransformer, + }) + operationLink = link() + }) + + it('should ignore non-tRPC messages', () => { + const operation: Operation = { + id: 1, + type: 'query', + path: 'test', + input: null, + context: {}, + signal: null, + } + + const observer = { + next: vi.fn(), + error: vi.fn(), + complete: vi.fn(), + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe(observer) + + // Send non-tRPC message + onMessageListener({ notTrpc: true }) + + expect(observer.next).not.toHaveBeenCalled() + expect(observer.error).not.toHaveBeenCalled() + }) + + it('should ignore messages without ID', () => { + const operation: Operation = { + id: 1, + type: 'query', + path: 'test', + input: null, + context: {}, + signal: null, + } + + const observer = { + next: vi.fn(), + error: vi.fn(), + complete: vi.fn(), + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe(observer) + + // Send message without ID + onMessageListener({ + trpc: { + result: { type: 'data', data: 'test' }, + }, + }) + + expect(observer.next).not.toHaveBeenCalled() + }) + + it('should ignore messages with wrong ID', () => { + const operation: Operation = { + id: 1, + type: 'query', + path: 'test', + input: null, + context: {}, + signal: null, + } + + const observer = { + next: vi.fn(), + error: vi.fn(), + complete: vi.fn(), + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe(observer) + + // Send message with different ID + onMessageListener({ + trpc: { + id: 999, + result: { type: 'data', data: 'test' }, + }, + }) + + expect(observer.next).not.toHaveBeenCalled() + }) + + it('should handle successful data response', () => { + const operation: Operation = { + id: 1, + type: 'query', + path: 'test', + input: null, + context: {}, + signal: null, + } + + const observer = { + next: vi.fn(), + error: vi.fn(), + complete: vi.fn(), + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe(observer) + + // Send successful response + onMessageListener({ + trpc: { + id: 1, + result: { type: 'data', data: 'Hello World' }, + }, + }) + + expect(mockTransformer.deserialize).toHaveBeenCalledWith('Hello World') + expect(observer.next).toHaveBeenCalledWith({ + result: { + type: 'data', + data: 'Hello World', + }, + }) + expect(observer.complete).toHaveBeenCalledOnce() + }) + + it('should handle response without explicit type', () => { + const operation: Operation = { + id: 1, + type: 'query', + path: 'test', + input: null, + context: {}, + signal: null, + } + + const observer = { + next: vi.fn(), + error: vi.fn(), + complete: vi.fn(), + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe(observer) + + // Send response without type + onMessageListener({ + trpc: { + id: 1, + result: { data: 'Hello World' }, + }, + }) + + expect(observer.next).toHaveBeenCalledWith({ + result: { + type: 'data', + data: 'Hello World', + }, + }) + }) + + it('should handle error response', () => { + const operation: Operation = { + id: 1, + type: 'query', + path: 'test', + input: null, + context: {}, + signal: null, + } + + const observer = { + next: vi.fn(), + error: vi.fn(), + complete: vi.fn(), + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe(observer) + + const errorResponse = { + code: -32603, + message: 'Internal error', + data: { custom: 'error data' }, + } + + // Send error response + onMessageListener({ + trpc: { + id: 1, + error: errorResponse, + }, + }) + + expect(observer.error).toHaveBeenCalledWith(expect.any(TRPCClientError)) + expect(observer.next).not.toHaveBeenCalled() + expect(observer.complete).not.toHaveBeenCalled() + }) + + it('should deserialize error with json property', () => { + mockTransformer.deserialize.mockReturnValue({ + message: 'Deserialized error', + }) + + const operation: Operation = { + id: 1, + type: 'query', + path: 'test', + input: null, + context: {}, + signal: null, + } + + const observer = { + error: vi.fn(), + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe(observer) + + // Send error with json property + onMessageListener({ + trpc: { + id: 1, + error: { + json: { message: 'Serialized error' }, + }, + }, + }) + + expect(mockTransformer.deserialize).toHaveBeenCalledWith({ + json: { message: 'Serialized error' }, + }) + }) + }) + + describe('Subscription Handling', () => { + let link: ReturnType + let operationLink: ReturnType> + + beforeEach(() => { + link = webExtensionLink({ + runtime: mockRuntime, + transformer: mockTransformer, + }) + operationLink = link() + }) + + it('should not complete subscription on data', () => { + const operation: Operation = { + id: 1, + type: 'subscription', + path: 'test.sub', + input: null, + context: {}, + signal: null, + } + + const observer = { + next: vi.fn(), + complete: vi.fn(), + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe(observer) + + // Send subscription data + onMessageListener({ + trpc: { + id: 1, + result: { type: 'data', data: 'subscription data' }, + }, + }) + + expect(observer.next).toHaveBeenCalled() + expect(observer.complete).not.toHaveBeenCalled() + }) + + it('should complete subscription on stopped', () => { + const operation: Operation = { + id: 1, + type: 'subscription', + path: 'test.sub', + input: null, + context: {}, + signal: null, + } + + const observer = { + next: vi.fn(), + complete: vi.fn(), + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe(observer) + + // Send stopped message + onMessageListener({ + trpc: { + id: 1, + result: { type: 'stopped' }, + }, + }) + + expect(observer.next).toHaveBeenCalled() + expect(observer.complete).toHaveBeenCalled() + }) + + it('should complete non-subscription on data', () => { + const operation: Operation = { + id: 1, + type: 'query', + path: 'test', + input: null, + context: {}, + signal: null, + } + + const observer = { + next: vi.fn(), + complete: vi.fn(), + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe(observer) + + // Send data response + onMessageListener({ + trpc: { + id: 1, + result: { type: 'data', data: 'query result' }, + }, + }) + + expect(observer.next).toHaveBeenCalled() + expect(observer.complete).toHaveBeenCalled() + }) + }) + + describe('Cleanup and Memory Management', () => { + let link: ReturnType + let operationLink: ReturnType> + + beforeEach(() => { + vi.useFakeTimers() + link = webExtensionLink({ + runtime: mockRuntime, + transformer: mockTransformer, + timeoutMS: 1000, + }) + operationLink = link() + }) + + afterEach(() => { + vi.useRealTimers() + }) + + it('should clean up listener on unsubscribe', () => { + const operation: Operation = { + id: 1, + type: 'query', + path: 'test', + input: null, + context: {}, + signal: null, + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + const subscription = observable.subscribe({}) + + // Unsubscribe + subscription.unsubscribe() + + // Try to send message - should be ignored + onMessageListener({ + trpc: { + id: 1, + result: { type: 'data', data: 'test' }, + }, + }) + + // Observer shouldn't be called since listener was cleaned up + expect(mockPort.postMessage).toHaveBeenCalledOnce() // Only the initial postMessage + }) + + it('should clean up timed out listeners', () => { + const operation1: Operation = { + id: 1, + type: 'query', + path: 'test1', + input: null, + context: {}, + signal: null, + } + + const operation2: Operation = { + id: 2, + type: 'query', + path: 'test2', + input: null, + context: {}, + signal: null, + } + + // Create two operations + const observable1 = operationLink({ + op: operation1, + next: vi.fn(), + }) + + observable1.subscribe({}) + + // Advance time past timeout + vi.advanceTimersByTime(1500) + + // Create another operation after timeout + const observable2 = operationLink({ + op: operation2, + next: vi.fn(), + }) + + const observer2 = { + next: vi.fn(), + } + + observable2.subscribe(observer2) + + // Run cleanup interval + vi.advanceTimersByTime(1000) + + // Response to first operation should be ignored (timed out) + onMessageListener({ + trpc: { + id: 1, + result: { type: 'data', data: 'test1' }, + }, + }) + + // Response to second operation should work + onMessageListener({ + trpc: { + id: 2, + result: { type: 'data', data: 'test2' }, + }, + }) + + expect(observer2.next).toHaveBeenCalledWith({ + result: { + type: 'data', + data: 'test2', + }, + }) + }) + }) + + describe('Port Disconnection', () => { + let link: ReturnType + + beforeEach(() => { + link = webExtensionLink({ + runtime: mockRuntime, + transformer: mockTransformer, + }) + }) + + it('should clean up listeners on disconnect', () => { + // Trigger disconnect + onDisconnectListener(mockPort) + + expect(mockPort.onMessage.removeListener).toHaveBeenCalledOnce() + expect(mockPort.onDisconnect.removeListener).toHaveBeenCalledOnce() + }) + + it('should reset port to null on disconnect', () => { + // Create another link after disconnect + onDisconnectListener(mockPort) + + const newMockPort = createMockPort() + ;(mockRuntime.connect as Mock).mockReturnValue(newMockPort) + + // Create new link - should reconnect + const newLink = webExtensionLink({ + runtime: mockRuntime, + transformer: mockTransformer, + }) + + expect(mockRuntime.connect).toHaveBeenCalledTimes(2) + }) + }) + + describe('Edge Cases', () => { + let link: ReturnType + let operationLink: ReturnType> + + beforeEach(() => { + link = webExtensionLink({ + runtime: mockRuntime, + transformer: mockTransformer, + }) + operationLink = link() + }) + + it('should handle transformer that returns undefined', () => { + mockTransformer.serialize.mockReturnValue(undefined) + + const operation: Operation = { + id: 1, + type: 'query', + path: 'test', + input: { test: 'data' }, + context: {}, + signal: null, + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe({}) + + expect(mockPort.postMessage).toHaveBeenCalledWith({ + trpc: { + id: 1, + jsonrpc: undefined, + method: 'query', + params: { + path: 'test', + input: { test: 'data' }, // Falls back to original input + }, + }, + }) + }) + + it('should handle malformed response messages', () => { + const operation: Operation = { + id: 1, + type: 'query', + path: 'test', + input: null, + context: {}, + signal: null, + } + + const observer = { + next: vi.fn(), + error: vi.fn(), + complete: vi.fn(), + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe(observer) + + // Send message with result but no data property + onMessageListener({ + trpc: { + id: 1, + result: {}, // Empty result object + }, + }) + + // Should now handle gracefully with default empty object + expect(mockTransformer.deserialize).toHaveBeenCalledWith({}) + expect(observer.next).toHaveBeenCalledWith({ + result: { + type: 'data', + data: {}, // Default empty object after deserialization + }, + }) + expect(observer.complete).toHaveBeenCalledOnce() + expect(observer.error).not.toHaveBeenCalled() + }) + + it('should handle string ID', () => { + const operation: Operation = { + id: 1, + type: 'query', + path: 'test', + input: null, + context: {}, + signal: null, + } + + const observer = { + next: vi.fn(), + } + + const observable = operationLink({ + op: operation, + next: vi.fn(), + }) + + observable.subscribe(observer) + + // Send response with string ID that matches number + onMessageListener({ + trpc: { + id: '1', // String instead of number + result: { type: 'data', data: 'test' }, + }, + }) + + // Should be ignored since IDs don't match exactly + expect(observer.next).not.toHaveBeenCalled() + }) + }) +}) diff --git a/packages/trpc-webext/src/link/index.ts b/packages/trpc-webext/src/link/index.ts index 76bc0053f..73453dcff 100644 --- a/packages/trpc-webext/src/link/index.ts +++ b/packages/trpc-webext/src/link/index.ts @@ -88,9 +88,9 @@ function createPortMessageHandler( observer.next({ result: { ...trpc.result, - ...((!trpc.result.type || trpc.result.type === 'data') && { + ...((!trpc.result?.type || trpc.result.type === 'data') && { type: 'data' as const, - data: transformer.deserialize(trpc.result.data), + data: transformer.deserialize(trpc.result?.data || {}), }), }, }) diff --git a/packages/trpc-webext/test-setup.ts b/packages/trpc-webext/test-setup.ts new file mode 100644 index 000000000..613dde84c --- /dev/null +++ b/packages/trpc-webext/test-setup.ts @@ -0,0 +1,12 @@ +import { vi } from 'vitest' + +// Mock webextension-polyfill types +vi.mock('webextension-polyfill', () => ({ + runtime: { + onConnect: { + addListener: vi.fn(), + removeListener: vi.fn(), + hasListener: vi.fn(), + }, + }, +})) diff --git a/packages/trpc-webext/vitest.config.ts b/packages/trpc-webext/vitest.config.ts new file mode 100644 index 000000000..6d0aa5446 --- /dev/null +++ b/packages/trpc-webext/vitest.config.ts @@ -0,0 +1,9 @@ +import { defineConfig } from 'vitest/config' + +export default defineConfig({ + test: { + globals: true, + environment: 'happy-dom', + setupFiles: ['./test-setup.ts'], + }, +}) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 976d30d72..37aeb7891 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1346,6 +1346,9 @@ importers: '@types/webextension-polyfill': specifier: ^0.12.3 version: 0.12.3 + zod: + specifier: 3.23.8 + version: 3.23.8 packages/wallet: dependencies: From 36c7abf4846a9d9835236f0b991b5417f9f8c7fb Mon Sep 17 00:00:00 2001 From: yqrashawn Date: Fri, 6 Jun 2025 14:20:35 +0800 Subject: [PATCH 05/11] feat: safeDeserialize, safeSerialize --- packages/trpc-webext/package.json | 2 +- .../trpc-webext/src/adapter/index.test.ts | 14 ++++---- packages/trpc-webext/src/adapter/index.ts | 8 +++-- packages/trpc-webext/src/link/index.test.ts | 30 ++++++++--------- packages/trpc-webext/src/link/index.ts | 17 ++++++---- packages/trpc-webext/src/utils.ts | 33 +++++++++++++++++++ 6 files changed, 73 insertions(+), 31 deletions(-) create mode 100644 packages/trpc-webext/src/utils.ts diff --git a/packages/trpc-webext/package.json b/packages/trpc-webext/package.json index 186f8e39a..817562c96 100644 --- a/packages/trpc-webext/package.json +++ b/packages/trpc-webext/package.json @@ -38,7 +38,7 @@ "build": "tsc", "lint": "eslint src", "format": "prettier --write .", - "test": "vitest", + "test": "vitest run", "test:watch": "vitest --watch" }, "peerDependencies": { diff --git a/packages/trpc-webext/src/adapter/index.test.ts b/packages/trpc-webext/src/adapter/index.test.ts index 2fc14ab3d..bc3b70f53 100644 --- a/packages/trpc-webext/src/adapter/index.test.ts +++ b/packages/trpc-webext/src/adapter/index.test.ts @@ -1,16 +1,18 @@ +import { initTRPC, TRPCError } from '@trpc/server' +import { observable } from '@trpc/server/observable' import { + afterEach, + beforeEach, describe, - it, expect, - vi, - beforeEach, - afterEach, + it, type Mock, + vi, } from 'vitest' -import { initTRPC, TRPCError } from '@trpc/server' -import { observable } from '@trpc/server/observable' import { z } from 'zod' + import { createWebExtHandler } from './' + import type { Runtime } from 'webextension-polyfill' // Mock runtime diff --git a/packages/trpc-webext/src/adapter/index.ts b/packages/trpc-webext/src/adapter/index.ts index eabcd9db7..ae14f2c69 100644 --- a/packages/trpc-webext/src/adapter/index.ts +++ b/packages/trpc-webext/src/adapter/index.ts @@ -6,6 +6,8 @@ import { TRPCError, } from '@trpc/server/unstable-core-do-not-import' +import { safeDeserialize, safeSerialize } from '../utils' + import type { Unsubscribable } from '@trpc/server/observable' import type { AnyRouter, @@ -172,7 +174,7 @@ async function handleRegularProcedure( }) } - const input = context.transformer.input.deserialize(trpc.params.input) + const input = safeDeserialize(context.transformer.input, trpc.params.input) const ctx = await context.createContext?.({ req: port, res: undefined, @@ -187,7 +189,7 @@ async function handleRegularProcedure( signal: undefined, }) - const serializedData = context.transformer.output.serialize(result) + const serializedData = safeSerialize(context.transformer.output, result) sendResponse({ result: { type: 'data', @@ -211,7 +213,7 @@ async function handleSubscription( }) } - const input = context.transformer.input.deserialize(trpc.params.input) + const input = safeDeserialize(context.transformer.input, trpc.params.input) const ctx = await context.createContext?.({ req: port, res: undefined, diff --git a/packages/trpc-webext/src/link/index.test.ts b/packages/trpc-webext/src/link/index.test.ts index 49c983f97..f02ac918e 100644 --- a/packages/trpc-webext/src/link/index.test.ts +++ b/packages/trpc-webext/src/link/index.test.ts @@ -1,16 +1,18 @@ +import { TRPCClientError } from '@trpc/client' import { + afterEach, + beforeEach, describe, - it, expect, - vi, - beforeEach, - afterEach, + it, type Mock, + vi, } from 'vitest' -import { TRPCClientError } from '@trpc/client' + import { webExtensionLink } from './' + +import type { Operation, TRPCClientRuntime } from '@trpc/client' import type { Runtime } from 'webextension-polyfill' -import type { Operation } from '@trpc/client' // Mock runtime const mockRuntime = { @@ -147,7 +149,7 @@ describe('webExtensionLink', () => { runtime: mockRuntime, transformer: mockTransformer, }) - operationLink = link() + operationLink = link({} as TRPCClientRuntime) }) it('should serialize input before sending', () => { @@ -253,7 +255,7 @@ describe('webExtensionLink', () => { runtime: mockRuntime, transformer: mockTransformer, }) - operationLink = link() + operationLink = link({} as TRPCClientRuntime) }) it('should ignore non-tRPC messages', () => { @@ -525,7 +527,7 @@ describe('webExtensionLink', () => { runtime: mockRuntime, transformer: mockTransformer, }) - operationLink = link() + operationLink = link({} as TRPCClientRuntime) }) it('should not complete subscription on data', () => { @@ -642,7 +644,7 @@ describe('webExtensionLink', () => { transformer: mockTransformer, timeoutMS: 1000, }) - operationLink = link() + operationLink = link({} as TRPCClientRuntime) }) afterEach(() => { @@ -752,10 +754,8 @@ describe('webExtensionLink', () => { }) describe('Port Disconnection', () => { - let link: ReturnType - beforeEach(() => { - link = webExtensionLink({ + webExtensionLink({ runtime: mockRuntime, transformer: mockTransformer, }) @@ -777,7 +777,7 @@ describe('webExtensionLink', () => { ;(mockRuntime.connect as Mock).mockReturnValue(newMockPort) // Create new link - should reconnect - const newLink = webExtensionLink({ + webExtensionLink({ runtime: mockRuntime, transformer: mockTransformer, }) @@ -795,7 +795,7 @@ describe('webExtensionLink', () => { runtime: mockRuntime, transformer: mockTransformer, }) - operationLink = link() + operationLink = link({} as TRPCClientRuntime) }) it('should handle transformer that returns undefined', () => { diff --git a/packages/trpc-webext/src/link/index.ts b/packages/trpc-webext/src/link/index.ts index 73453dcff..95b583ec8 100644 --- a/packages/trpc-webext/src/link/index.ts +++ b/packages/trpc-webext/src/link/index.ts @@ -1,11 +1,16 @@ import { TRPCClientError } from '@trpc/client' import { observable } from '@trpc/server/observable' +import { safeDeserialize, safeSerialize } from '../utils' + import type { Operation, OperationResultEnvelope, TRPCLink } from '@trpc/client' import type { AnyTRPCRouter } from '@trpc/server' import type { Observer } from '@trpc/server/observable' import type { TRPCResponseMessage } from '@trpc/server/rpc' -import type { DataTransformer } from '@trpc/server/unstable-core-do-not-import' +import type { + DataTransformer, + TRPCResult, +} from '@trpc/server/unstable-core-do-not-import' import type { Runtime } from 'webextension-polyfill' export interface WebExtensionLinkOptions { @@ -45,7 +50,7 @@ function isBackgroundMessage(message: unknown): message is BackgroundMessage { typeof message === 'object' && message !== null && 'trpc' in message && - typeof (message as any).trpc === 'object' + typeof (message as BackgroundMessage).trpc === 'object' ) } @@ -77,7 +82,7 @@ function createPortMessageHandler( if ('error' in trpc) { // Handle error response const error = shouldDeserialize(trpc.error) - ? transformer.deserialize(trpc.error) + ? safeDeserialize(transformer, trpc.error) : trpc.error observer.error(TRPCClientError.from({ ...trpc, error })) @@ -90,9 +95,9 @@ function createPortMessageHandler( ...trpc.result, ...((!trpc.result?.type || trpc.result.type === 'data') && { type: 'data' as const, - data: transformer.deserialize(trpc.result?.data || {}), + data: safeDeserialize(transformer, trpc.result?.data || {}), }), - }, + } as TRPCResult, }) // Complete for non-subscription or stopped subscription @@ -159,7 +164,7 @@ export function webExtensionLink( return ({ op }) => { const { id, type, path, input } = op - const serializedInput = transformer.serialize(input) ?? input + const serializedInput = safeSerialize(transformer, input) ?? input const trpcPayload = { id, diff --git a/packages/trpc-webext/src/utils.ts b/packages/trpc-webext/src/utils.ts new file mode 100644 index 000000000..e7730c54c --- /dev/null +++ b/packages/trpc-webext/src/utils.ts @@ -0,0 +1,33 @@ +import { TRPCError } from '@trpc/server' + +import type { DataTransformer } from '@trpc/server/unstable-core-do-not-import' + +export function safeDeserialize( + transformer: DataTransformer, + data: unknown, +): T { + try { + return transformer.deserialize(data) + } catch (error) { + throw new TRPCError({ + code: 'BAD_REQUEST', + message: 'Failed to deserialize input data', + cause: error, + }) + } +} + +export function safeSerialize( + transformer: DataTransformer, + data: T, +): unknown { + try { + return transformer.serialize(data) + } catch (error) { + throw new TRPCError({ + code: 'BAD_REQUEST', + message: 'Failed to serialize output data', + cause: error, + }) + } +} From 07402f6ea8a516ddd5dde6b598f2925896e37cf9 Mon Sep 17 00:00:00 2001 From: yqrashawn Date: Fri, 6 Jun 2025 15:31:04 +0800 Subject: [PATCH 06/11] docs: readme --- packages/trpc-webext/README.md | 110 +++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 packages/trpc-webext/README.md diff --git a/packages/trpc-webext/README.md b/packages/trpc-webext/README.md new file mode 100644 index 000000000..14f93caff --- /dev/null +++ b/packages/trpc-webext/README.md @@ -0,0 +1,110 @@ +# @status-im/trpc-webext + +A tRPC adapter for web extensions that enables type-safe communication between different extension contexts (background, content scripts, popup, etc.). + +## Installation + +```sh +pnpm add @status-im/trpc-webext +``` + +## Basic Usage + +### 1. Create your tRPC router (typically in background script) + +```typescript +import { initTRPC } from '@trpc/server' +import { createWebExtHandler } from '@status-im/trpc-webext/adapter' +import { browser } from 'webextension-polyfill' +import superjson from 'superjson' + +// Initialize tRPC +const t = initTRPC.context().create({ + transformer: superjson, + isServer: false, + allowOutsideOfServer: true, +}) + +// Define your router +const appRouter = t.router({ + greeting: t.procedure + .input(z.object({ name: z.string() })) + .query(({ input }) => { + return =Hello ${input.name}!= + }), +}) + +// Create context function +const createContext = async (opts) => { + return { + // Add your context data here + userId: 'user Alice', + } +} + +// Set up the handler in background script +createWebExtHandler({ + router: appRouter, + createContext, + runtime: browser.runtime, +}) + +export type AppRouter = typeof appRouter +``` + +### 2. Create a client (in popup, content script, etc.) + +```typescript +import { createTRPCClient } from '@trpc/client' +import { webExtensionLink } from '@status-im/trpc-webext/link' +import { browser } from 'webextension-polyfill' +import superjson from 'superjson' +import type { AppRouter } from './background' + +const client = createTRPCClient({ + links: [ + webExtensionLink({ + runtime: browser.runtime, + transformer: superjson, // same transformer as the server + timeoutMS: 30000, // optional, defaults to 10000ms + }), + ], +}) + +// Use the client +async function example() { + const result = await client.greeting.query({ name: 'World' }) + console.log(result) // "Hello World!" +} +``` + +## Key Features + +- **Type Safety**: Full TypeScript support with end-to-end type safety +- **Real-time Communication**: Support for subscriptions using observables +- **Multiple Contexts**: Works across all web extension contexts (background, popup, content scripts, options page, etc.) +- **Data Transformation**: Built-in support for data transformers like SuperJSON +- **Error Handling**: Proper error propagation and handling +- **Connection Management**: Automatic cleanup of connections and subscriptions + +## Configuration Options + +### `createWebExtHandler` options: + +- `router`: Your tRPC router +- `createContext`: Function to create request context +- `runtime`: Browser runtime (e.g., `browser.runtime`) +- `onError`: Optional error handler + +### `webExtensionLink` options: + +- `runtime`: Browser runtime (e.g., `browser.runtime`) +- `transformer`: Data transformer (e.g., SuperJSON) +- `timeoutMS`: Request timeout in milliseconds (default: 10000) + +## Notes + +- The handler should be set up in your background script +- Clients can be created in any extension context +- Make sure to use the same transformer on both ends +- Subscriptions are automatically cleaned up when ports disconnect From d48b7a6b58929ce761853ef1234d7097fe5770e9 Mon Sep 17 00:00:00 2001 From: yqrashawn Date: Fri, 6 Jun 2025 15:43:13 +0800 Subject: [PATCH 07/11] docs: add changelog --- packages/trpc-webext/CHANGELOG.md | 1 + packages/trpc-webext/package.json | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 packages/trpc-webext/CHANGELOG.md diff --git a/packages/trpc-webext/CHANGELOG.md b/packages/trpc-webext/CHANGELOG.md new file mode 100644 index 000000000..0dd24a04b --- /dev/null +++ b/packages/trpc-webext/CHANGELOG.md @@ -0,0 +1 @@ +# @status-im/trpc-webext diff --git a/packages/trpc-webext/package.json b/packages/trpc-webext/package.json index 817562c96..487ea38b4 100644 --- a/packages/trpc-webext/package.json +++ b/packages/trpc-webext/package.json @@ -48,5 +48,8 @@ "devDependencies": { "@types/webextension-polyfill": "^0.12.3", "zod": "^3.23.8" + }, + "publishConfig": { + "access": "public" } } From 6db6dc3e2e5646473e6546f5c5dc3964dedddf34 Mon Sep 17 00:00:00 2001 From: yqrashawn Date: Fri, 6 Jun 2025 15:54:02 +0800 Subject: [PATCH 08/11] chore: changeset --- .changeset/light-waves-cross.md | 5 +++++ package.json | 1 + packages/trpc-webext/package.json | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 .changeset/light-waves-cross.md diff --git a/.changeset/light-waves-cross.md b/.changeset/light-waves-cross.md new file mode 100644 index 000000000..57cf4d001 --- /dev/null +++ b/.changeset/light-waves-cross.md @@ -0,0 +1,5 @@ +--- +'@status-im/trpc-webext': patch +--- + +First version of @status-im/trpc-webext diff --git a/package.json b/package.json index e90ae8b79..a7dae432f 100644 --- a/package.json +++ b/package.json @@ -9,6 +9,7 @@ "packages/icons", "packages/components", "packages/wallet", + "packages/trpc-webext", "apps/connector", "apps/portfolio", "apps/wallet", diff --git a/packages/trpc-webext/package.json b/packages/trpc-webext/package.json index 487ea38b4..f35f6a6c0 100644 --- a/packages/trpc-webext/package.json +++ b/packages/trpc-webext/package.json @@ -1,7 +1,7 @@ { "name": "@status-im/trpc-webext", "description": "description", - "version": "0.0.1", + "version": "0.0.0", "license": "MIT", "keywords": [ "trpc", From db1822381e96d0ccf1197517a8000d0b03bce27c Mon Sep 17 00:00:00 2001 From: yqrashawn Date: Fri, 13 Jun 2025 17:02:51 +0800 Subject: [PATCH 09/11] feat: password auth resolve #698 --- apps/wallet/src/data/api.ts | 33 ++++++++------- .../data/trpc/middlewares/password-auth.ts | 41 +++++++++++++++++++ 2 files changed, 60 insertions(+), 14 deletions(-) create mode 100644 apps/wallet/src/data/trpc/middlewares/password-auth.ts diff --git a/apps/wallet/src/data/api.ts b/apps/wallet/src/data/api.ts index 997189d05..47f3aecba 100644 --- a/apps/wallet/src/data/api.ts +++ b/apps/wallet/src/data/api.ts @@ -12,6 +12,7 @@ import * as bitcoin from './bitcoin/bitcoin' import * as ethereum from './ethereum/ethereum' import { getKeystore } from './keystore' import * as solana from './solana/solana' +import { createPasswordAuthPlugin } from './trpc/middlewares/password-auth' import { getWalletCore, // type WalletCore @@ -34,6 +35,8 @@ const createContext = async (webextOpts?: CreateWebExtContextOptions) => { type Context = Awaited> +const passwordAuthPlugin = createPasswordAuthPlugin() + /** * @see https://trpc.io/docs/server/routers#runtime-configuration */ @@ -43,14 +46,15 @@ const t = initTRPC.context().create({ allowOutsideOfServer: true, }) -// const publicProcedure = t.procedure +const publicProcedure = t.procedure.concat(passwordAuthPlugin) + const { createCallerFactory, router } = t // todo: lock with password as trpc auth procedure // todo?: expose password in context or use other (session) token derived from it for encrypting and storing const apiRouter = router({ wallet: router({ - all: t.procedure.query(async ({ ctx }) => { + all: publicProcedure.query(async ({ ctx }) => { const { keyStore } = ctx const wallets = await keyStore.loadAll() @@ -61,7 +65,7 @@ const apiRouter = router({ // todo: validation (e.g. password, mnemonic, already exists) // todo: words count option // todo: handle cancelation - add: t.procedure + add: publicProcedure .input( z.object({ password: z.string(), @@ -70,6 +74,7 @@ const apiRouter = router({ ) .mutation(async ({ input, ctx }) => { const { walletCore, keyStore } = ctx + console.log('ctx = ', ctx) const wallet = walletCore.HDWallet.create(128, input.password) const mnemonic = wallet.mnemonic() @@ -110,7 +115,7 @@ const apiRouter = router({ } }), - get: t.procedure + get: publicProcedure .input( z.object({ walletId: z.string(), @@ -129,7 +134,7 @@ const apiRouter = router({ } }), - import: t.procedure + import: publicProcedure .input( z.object({ mnemonic: z.string(), @@ -182,7 +187,7 @@ const apiRouter = router({ }), account: router({ - all: t.procedure + all: publicProcedure .input( z.object({ walletId: z.string(), @@ -197,7 +202,7 @@ const apiRouter = router({ }), ethereum: router({ - add: t.procedure + add: publicProcedure .input( z.object({ walletId: z.string(), @@ -265,7 +270,7 @@ const apiRouter = router({ }), // note: our first tx https://holesky.etherscan.io/tx/0xdc2aa244933260c50e665aa816767dce6b76d5d498e6358392d5f79bfc9626d5 - send: t.procedure + send: publicProcedure .input( z.object({ walletId: z.string(), @@ -333,7 +338,7 @@ const apiRouter = router({ bitcoin: router({ // note?: create all variants (e.g. segwit, nested segwit, legacy, taproot) for each added account by default - add: t.procedure + add: publicProcedure .input( z.object({ walletId: z.string(), @@ -391,7 +396,7 @@ const apiRouter = router({ }), // note: our first tx https://mempool.space/testnet4/tx/4d1797f4a6e92ab5164cfa8030e5954670f162e2aae792c8d6d6a81aae32fbd4 - send: t.procedure + send: publicProcedure .input( z.object({ walletId: z.string(), @@ -435,7 +440,7 @@ const apiRouter = router({ }), solana: router({ - add: t.procedure + add: publicProcedure .input( z.object({ walletId: z.string(), @@ -459,7 +464,7 @@ const apiRouter = router({ }), // note: our first tx https://solscan.io/tx/LNgKUb6bewbcgVXi9NBF4qYNJC5kjMPpH5GDVZBsVXFC7MDhYtdygkuP1avq7c31bHDkr9pkKYvMSdT16mt294g?cluster=devnet - send: t.procedure + send: publicProcedure .input( z.object({ walletId: z.string(), @@ -503,7 +508,7 @@ const apiRouter = router({ }), cardano: router({ - add: t.procedure + add: publicProcedure .input( z.object({ walletId: z.string(), @@ -557,7 +562,7 @@ const apiRouter = router({ }), privateKey: router({ - import: t.procedure + import: publicProcedure .input( z.object({ privateKey: z.string(), diff --git a/apps/wallet/src/data/trpc/middlewares/password-auth.ts b/apps/wallet/src/data/trpc/middlewares/password-auth.ts new file mode 100644 index 000000000..c503c81d5 --- /dev/null +++ b/apps/wallet/src/data/trpc/middlewares/password-auth.ts @@ -0,0 +1,41 @@ +import { initTRPC } from '@trpc/server' + +import type { KeyStore } from '@trustwallet/wallet-core' + +export interface PasswordAuthParams { + password?: string + walletId?: string +} + +type Context = { + keyStore: KeyStore.Default +} + +export function createPasswordAuthPlugin() { + const t = initTRPC.context().create({ + isServer: false, + allowOutsideOfServer: true, + }) + + return t.procedure.use(async opts => { + const { ctx } = opts + const { keyStore } = ctx + const params = (await opts.getRawInput()) as PasswordAuthParams + if ( + typeof params?.password !== 'string' || + typeof params?.walletId !== 'string' + ) + return opts.next() + + let validPassword: undefined | string + + await keyStore + .export(params.walletId, params.password) + .then(() => { + validPassword = params.password + }) + .catch(() => {}) + + return opts.next({ ctx: { validPassword } }) + }) +} From 480c8f02c4e6b7d8332ae8704a743a777fcaf2b6 Mon Sep 17 00:00:00 2001 From: yqrashawn Date: Tue, 17 Jun 2025 09:42:47 +0800 Subject: [PATCH 10/11] feat(bg): use ctx.validPassword instead of input.password --- apps/wallet/src/data/api.ts | 95 +++++++++++++++---- .../data/trpc/middlewares/password-auth.ts | 4 + 2 files changed, 78 insertions(+), 21 deletions(-) diff --git a/apps/wallet/src/data/api.ts b/apps/wallet/src/data/api.ts index 47f3aecba..3af272cfa 100644 --- a/apps/wallet/src/data/api.ts +++ b/apps/wallet/src/data/api.ts @@ -3,7 +3,7 @@ // import { AddressType, InMemoryKeyAgent } from '@cardano-sdk/key-management' import { createWebExtHandler, webExtensionLink } from '@status-im/trpc-webext' import { createTRPCClient } from '@trpc/client' -import { initTRPC } from '@trpc/server' +import { initTRPC, TRPCError } from '@trpc/server' import superjson from 'superjson' import { browser } from 'wxt/browser' import { z } from 'zod' @@ -19,6 +19,7 @@ import { } from './wallet' import { runtimePortToClientContextType } from './webext' +import type { ValidPasswordContext } from './trpc/middlewares/password-auth' import type { CreateWebExtContextOptions } from '@status-im/trpc-webext/adapter' const createContext = async (webextOpts?: CreateWebExtContextOptions) => { @@ -33,7 +34,7 @@ const createContext = async (webextOpts?: CreateWebExtContextOptions) => { } } -type Context = Awaited> +type Context = Awaited> & ValidPasswordContext const passwordAuthPlugin = createPasswordAuthPlugin() @@ -46,7 +47,12 @@ const t = initTRPC.context().create({ allowOutsideOfServer: true, }) -const publicProcedure = t.procedure.concat(passwordAuthPlugin) +const trpcGlobalPlugins = [passwordAuthPlugin] + +const publicProcedure = trpcGlobalPlugins.reduce( + (procedure, plugin) => procedure.concat(plugin), + t.procedure, +) const { createCallerFactory, router } = t @@ -74,7 +80,6 @@ const apiRouter = router({ ) .mutation(async ({ input, ctx }) => { const { walletCore, keyStore } = ctx - console.log('ctx = ', ctx) const wallet = walletCore.HDWallet.create(128, input.password) const mnemonic = wallet.mnemonic() @@ -127,10 +132,16 @@ const apiRouter = router({ const wallet = await keyStore.load(input.walletId) + if (!ctx.validPassword) + throw new TRPCError({ + message: 'Invalid password', + code: 'UNAUTHORIZED', + }) + return { id: wallet.id, name: wallet.name, - mnemonic: await keyStore.exportMnemonic(wallet.id, input.password), + mnemonic: await keyStore.exportMnemonic(wallet.id, ctx.validPassword), } }), @@ -213,18 +224,24 @@ const apiRouter = router({ .mutation(async ({ input, ctx }) => { const { keyStore, walletCore } = ctx + if (!ctx.validPassword) + throw new TRPCError({ + message: 'Invalid password', + code: 'UNAUTHORIZED', + }) + const wallet = await keyStore.load(input.walletId) // todo!: test calling multiple times // const { id } = await keyStore.addAccounts( // wallet.id, - // input.password, + // ctx.validPassword, // [walletCore.CoinType.ethereum], // ) const { id } = await keyStore.addAccountsWithDerivations( wallet.id, - input.password, + ctx.validPassword, [ { // coin: wallet.activeAccounts[0].coin, @@ -237,7 +254,7 @@ const apiRouter = router({ // note: add account with custom derivation path // const mnemonic = (await keyStore.export( // wallet.id, - // input.password, + // ctx.validPassword, // )) as string // // fixme: calculate index based on last account // const index = 0 @@ -246,19 +263,19 @@ const apiRouter = router({ // const key = walletCore.StoredKey.importHDWallet( // mnemonic, // input.name, - // Buffer.from(input.password), + // Buffer.from(ctx.validPassword), // walletCore.CoinType.ethereum, // ) // const privateKey = key - // .wallet(Buffer.from(input.password)) + // .wallet(Buffer.from(ctx.validPassword)) // .getKey(walletCore.CoinType.ethereum, derivationPath) // // note!: would be categorized separatley from mnemonic wallet and as as private key, so if used instead of adding accounts add private keys from the start // const { id } = await keyStore.importKey( // privateKey.data(), // 'untitled', - // input.password, + // ctx.validPassword, // walletCore.CoinType.ethereum, // walletCore.StoredKeyEncryption.aes256Ctr, // ) @@ -296,25 +313,31 @@ const apiRouter = router({ throw new Error('From address not found') } + if (!ctx.validPassword) + throw new TRPCError({ + message: 'Invalid password', + code: 'UNAUTHORIZED', + }) + // const mnemonic = (await keyStore.export( // wallet.id, - // input.password, + // ctx.validPassword, // )) as string // const key = walletCore.StoredKey.importHDWallet( // mnemonic, // wallet.name, - // Buffer.from(input.password), + // Buffer.from(ctx.validPassword), // walletCore.CoinType.ethereum, // ) // const privateKey = key - // .wallet(Buffer.from(input.password)) + // .wallet(Buffer.from(ctx.validPassword)) // .getKey(walletCore.CoinType.ethereum, account.derivationPath) const privateKey = await keyStore.getKey( wallet.id, - input.password, + ctx.validPassword, account, ) @@ -350,9 +373,15 @@ const apiRouter = router({ const wallet = await keyStore.load(input.walletId) + if (!ctx.validPassword) + throw new TRPCError({ + message: 'Invalid password', + code: 'UNAUTHORIZED', + }) + const { id } = await keyStore.addAccountsWithDerivations( wallet.id, - input.password, + ctx.validPassword, [ { coin: walletCore.CoinType.bitcoin, @@ -381,7 +410,7 @@ const apiRouter = router({ // note!: second default derivation; does not add new account // await keyStore.addAccountsWithDerivations( // wallet.id, - // input.password, + // ctx.validPassword, // [ // { // coin: walletCore.CoinType.bitcoin, @@ -419,9 +448,15 @@ const apiRouter = router({ throw new Error('From address not found') } + if (!ctx.validPassword) + throw new TRPCError({ + message: 'Invalid password', + code: 'UNAUTHORIZED', + }) + const privateKey = await keyStore.getKey( wallet.id, - input.password, + ctx.validPassword, account, ) @@ -452,9 +487,15 @@ const apiRouter = router({ const wallet = await keyStore.load(input.walletId) + if (!ctx.validPassword) + throw new TRPCError({ + message: 'Invalid password', + code: 'UNAUTHORIZED', + }) + const { id } = await keyStore.addAccounts( wallet.id, - input.password, + ctx.validPassword, [walletCore.CoinType.solana], ) @@ -487,9 +528,15 @@ const apiRouter = router({ throw new Error('From address not found') } + if (!ctx.validPassword) + throw new TRPCError({ + message: 'Invalid password', + code: 'UNAUTHORIZED', + }) + const privateKey = await keyStore.getKey( wallet.id, - input.password, + ctx.validPassword, account, ) @@ -520,9 +567,15 @@ const apiRouter = router({ const wallet = await keyStore.load(input.walletId) + if (!ctx.validPassword) + throw new TRPCError({ + message: 'Invalid password', + code: 'UNAUTHORIZED', + }) + const { id } = await keyStore.addAccounts( wallet.id, - input.password, + ctx.validPassword, [walletCore.CoinType.cardano], ) diff --git a/apps/wallet/src/data/trpc/middlewares/password-auth.ts b/apps/wallet/src/data/trpc/middlewares/password-auth.ts index c503c81d5..7740c141b 100644 --- a/apps/wallet/src/data/trpc/middlewares/password-auth.ts +++ b/apps/wallet/src/data/trpc/middlewares/password-auth.ts @@ -11,6 +11,10 @@ type Context = { keyStore: KeyStore.Default } +export type ValidPasswordContext = { + validPassword?: string +} + export function createPasswordAuthPlugin() { const t = initTRPC.context().create({ isServer: false, From aa64eca69b2c2b0b1cf22ffcecff844c833ddbe8 Mon Sep 17 00:00:00 2001 From: yqrashawn Date: Tue, 17 Jun 2025 09:43:35 +0800 Subject: [PATCH 11/11] feat: get wallet hook (used only in dev for now) --- apps/wallet/src/hooks/use-get-wallet.tsx | 29 ++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 apps/wallet/src/hooks/use-get-wallet.tsx diff --git a/apps/wallet/src/hooks/use-get-wallet.tsx b/apps/wallet/src/hooks/use-get-wallet.tsx new file mode 100644 index 000000000..322da2e8d --- /dev/null +++ b/apps/wallet/src/hooks/use-get-wallet.tsx @@ -0,0 +1,29 @@ +import { queryOptions, useQuery } from '@tanstack/react-query' + +import { useAPI } from '../providers/api-client' + +export const useGetWallet = (walletId: string, password: string) => { + const api = useAPI() + + const result = useQuery( + queryOptions({ + enabled: Boolean(walletId && password), + queryKey: ['get-wallet', walletId], + queryFn: async () => { + const { mnemonic } = await api.wallet.get.query({ + walletId, + password, + }) + + return mnemonic + }, + staleTime: 60 * 60 * 1000, // 1 hour + gcTime: 60 * 60 * 1000, // 1 hour + refetchOnMount: false, + refetchOnWindowFocus: false, + refetchOnReconnect: false, + }), + ) + + return result +}