diff --git a/bin/src/py.rs b/bin/src/py.rs index fc97c9e..adb464d 100644 --- a/bin/src/py.rs +++ b/bin/src/py.rs @@ -51,15 +51,15 @@ fn get_export( let n_args = f.args.args.len(); let has_return = f.returns.is_some(); - if is_plugin_fn && n_args > 0 { - anyhow::bail!( - "plugin_fn expects a function with no arguments, {func} should have no arguments" - ); - } + // if is_plugin_fn && n_args > 0 { + // anyhow::bail!( + // "plugin_fn expects a function with no arguments, {func} should have no arguments" + // ); + // } Ok(Export { name: func, is_plugin_fn, - params: vec![wagen::ValType::I64; n_args], + params: vec![wagen::ValType::I64; 0], results: if is_plugin_fn { vec![wagen::ValType::I32] } else if has_return { diff --git a/examples/plugin-fn-param-extractors/out.wasm b/examples/plugin-fn-param-extractors/out.wasm new file mode 100644 index 0000000..d2d4fe2 Binary files /dev/null and b/examples/plugin-fn-param-extractors/out.wasm differ diff --git a/examples/plugin-fn-param-extractors/plugin.py b/examples/plugin-fn-param-extractors/plugin.py new file mode 100644 index 0000000..d75a000 --- /dev/null +++ b/examples/plugin-fn-param-extractors/plugin.py @@ -0,0 +1,27 @@ +import extism +from dataclasses import dataclass + +@dataclass +class Count(extism.Json): + count: int + +@dataclass +class CountVowelsInput(extism.Json): + text: str + +@extism.plugin_fn +def count_vowels(cfg: extism.Config, input: str) -> Count: + msg = cfg.get_str("message") + extism.log(extism.LogLevel.Info, f"Config: {msg}") + extism.log(extism.LogLevel.Info, f"Input: {input}") + + +@extism.plugin_fn +def count_vowels_dataclass(input: CountVowelsInput) -> Count: + extism.log(extism.LogLevel.Info, f"Json Input: {input.text}") + + +@extism.plugin_fn +def count_vowels_http(http: extism.Http) -> Count: + resp = http.request("http://www.example.com") + extism.log(extism.LogLevel.Info, f"Response: {resp}") \ No newline at end of file diff --git a/examples/plugin-fn-param-extractors/readme.md b/examples/plugin-fn-param-extractors/readme.md new file mode 100644 index 0000000..a500473 --- /dev/null +++ b/examples/plugin-fn-param-extractors/readme.md @@ -0,0 +1,28 @@ + +## Building the example + +```sh +./build.py +./extism-py examples/plugin-fn-param-extractors/plugin.py -o examples/plugin-fn-param-extractors/out.wasm + +``` + +## Calling the example functions + +```sh +extism call examples/plugin-fn-param-extractors/out.wasm count_vowels \ + --wasi \ + --input='Hello World Test!' \ + --log-level=info \ + --config message="hello" + +extism call examples/plugin-fn-param-extractors/out.wasm count_vowels_dataclass \ + --wasi \ + --input='{"text": "Hello"}' \ + --log-level=info + +extism call examples/plugin-fn-param-extractors/out.wasm count_vowels_http \ + --wasi \ + --allow-host '*' \ + --log-level=info +``` \ No newline at end of file diff --git a/lib/src/prelude.py b/lib/src/prelude.py index c7c2aa1..65affc2 100644 --- a/lib/src/prelude.py +++ b/lib/src/prelude.py @@ -1,10 +1,12 @@ from typing import Union, Optional import json +import inspect from enum import Enum from abc import ABC, abstractmethod from datetime import datetime from base64 import b64encode, b64decode from dataclasses import is_dataclass +from functools import partial import extism_ffi as ffi @@ -213,12 +215,37 @@ def wrapper(*args): def plugin_fn(func): """Annotate a function that will be called by Extism""" global __exports - __exports.append(func) - - def inner(): - return func() - - return inner + def _handle_arg(arg_name: str, arg_type: type): + match (arg_name, arg_type): + case (_, _type) if _type is Config: + return Config + case (_, _type) if _type is Http: + return Http + case _: + raise ValueError(f"Unsupported argument") + + + + sig = inspect.signature(func) + annotated_args = {k: v.annotation for k, v in sig.parameters.items()} + if "input" in annotated_args: + input_arg = annotated_args.pop("input") + else: + input_arg = None + func_args = {k: _handle_arg(k, v) for k, v in annotated_args.items()} + func_args = {k: v for k, v in func_args.items() if v is not None} + annotated_func = partial(func, **func_args) + + def _defered_input(): + fn_input = input(input_arg) + return annotated_func(input=fn_input) + + if input_arg: + __exports.append(_defered_input) + return _defered_input + else: + __exports.append(annotated_func) + return annotated_func def shared_fn(f): diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..58a7f8f --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "1.81.0" +targets = ["wasm32-wasi"] \ No newline at end of file