@@ -58,12 +58,12 @@ class Foo(BaseModel):
58
58
assert isinstance (validated_output ["bez" ][0 ], str )
59
59
60
60
61
- @pytest . mark . skip ( reason = "Random model infinitely recurses on complex struct. Use GPT2" )
61
+ @if_transformers_installed
62
62
def test_hugging_face_pipeline_complex_schema ():
63
63
# NOTE: This is the real GPT-2 model.
64
64
from transformers import pipeline
65
65
66
- model = pipeline ("text-generation" , "gpt2 " )
66
+ model = pipeline ("text-generation" , "distilgpt2 " )
67
67
68
68
class MultiNum (BaseModel ):
69
69
whole : int
@@ -73,10 +73,12 @@ class Tricky(BaseModel):
73
73
foo : MultiNum
74
74
75
75
g = Guard .for_pydantic (Tricky , output_formatter = "jsonformer" )
76
- response = g (model , prompt = " Sample:" )
76
+ response = g (model , messages = [{ "content" : " Sample:", "role" : "user" }] )
77
77
out = response .validated_output
78
78
assert isinstance (out , dict )
79
79
assert "foo" in out
80
80
assert isinstance (out ["foo" ], dict )
81
- assert isinstance (out ["foo" ]["whole" ], int | float )
81
+ assert isinstance (out ["foo" ]["whole" ], int ) or isinstance (
82
+ out ["foo" ]["whole" ], float
83
+ )
82
84
assert isinstance (out ["foo" ]["frac" ], float )
0 commit comments