Skip to content

Commit 45ea78c

Browse files
committed
bug #173 [AIBundle] Fix stream result profiling (valtzu)
This PR was merged into the main branch. Discussion ---------- [AIBundle] Fix stream result profiling | Q | A | ------------- | --- | Bug fix? | yes | New feature? | no | Docs? | no | Issues | Fix #161 | License | MIT Address 2 streaming-related issues: 1. When using `symfony/ai` from within `StreamedResponse` callback (in http controller), no data was collected due to not implementing `LateDataCollectorInterface` 2. An attempt to save `StreamResult` to the profiler storage always failed, because `\Generator`s are not serializable Both are used in the Turbo Stream Bot demo, and now the profiler works. Though the toolbar still does not auto-update, but that is currently expected behavior. Commits ------- 53a05cb [AIBundle] Fix stream result profiling
2 parents 0f569f4 + 53a05cb commit 45ea78c

File tree

3 files changed

+120
-5
lines changed

3 files changed

+120
-5
lines changed

src/ai-bundle/src/Profiler/DataCollector.php

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
use Symfony\Component\DependencyInjection\Attribute\TaggedIterator;
1919
use Symfony\Component\HttpFoundation\Request;
2020
use Symfony\Component\HttpFoundation\Response;
21+
use Symfony\Component\HttpKernel\DataCollector\LateDataCollectorInterface;
2122

2223
/**
2324
* @author Christopher Hertel <[email protected]>
2425
*
2526
* @phpstan-import-type PlatformCallData from TraceablePlatform
2627
* @phpstan-import-type ToolCallData from TraceableToolbox
2728
*/
28-
final class DataCollector extends AbstractDataCollector
29+
final class DataCollector extends AbstractDataCollector implements LateDataCollectorInterface
2930
{
3031
/**
3132
* @var TraceablePlatform[]
@@ -53,6 +54,11 @@ public function __construct(
5354
}
5455

5556
public function collect(Request $request, Response $response, ?\Throwable $exception = null): void
57+
{
58+
$this->lateCollect();
59+
}
60+
61+
public function lateCollect(): void
5662
{
5763
$this->data = [
5864
'tools' => $this->defaultToolBox->getTools(),
@@ -102,7 +108,14 @@ private function awaitCallResults(TraceablePlatform $platform): array
102108
{
103109
$calls = $platform->calls;
104110
foreach ($calls as $key => $call) {
105-
$call['result'] = $call['result']->await()->getContent();
111+
$result = $call['result']->await();
112+
113+
if (isset($platform->resultCache[$result])) {
114+
$call['result'] = $platform->resultCache[$result];
115+
} else {
116+
$call['result'] = $result->getContent();
117+
}
118+
106119
$calls[$key] = $call;
107120
}
108121

src/ai-bundle/src/Profiler/TraceablePlatform.php

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
use Symfony\AI\Platform\Message\Content\File;
1515
use Symfony\AI\Platform\Model;
1616
use Symfony\AI\Platform\PlatformInterface;
17+
use Symfony\AI\Platform\Result\ResultInterface;
1718
use Symfony\AI\Platform\Result\ResultPromise;
19+
use Symfony\AI\Platform\Result\StreamResult;
1820

1921
/**
2022
* @author Christopher Hertel <[email protected]>
@@ -32,27 +34,50 @@ final class TraceablePlatform implements PlatformInterface
3234
* @var PlatformCallData[]
3335
*/
3436
public array $calls = [];
37+
/**
38+
* @var \WeakMap<ResultInterface, string>
39+
*/
40+
public \WeakMap $resultCache;
3541

3642
public function __construct(
3743
private readonly PlatformInterface $platform,
3844
) {
45+
$this->resultCache = new \WeakMap();
3946
}
4047

4148
public function invoke(Model $model, array|string|object $input, array $options = []): ResultPromise
4249
{
43-
$result = $this->platform->invoke($model, $input, $options);
50+
$resultPromise = $this->platform->invoke($model, $input, $options);
4451

4552
if ($input instanceof File) {
4653
$input = $input::class.': '.$input->getFormat();
4754
}
4855

56+
if ($options['stream'] ?? false) {
57+
$originalStream = $resultPromise->asStream();
58+
$resultPromise = new ResultPromise(fn () => $this->createTraceableStreamResult($originalStream), $resultPromise->getRawResult(), $options);
59+
}
60+
4961
$this->calls[] = [
5062
'model' => $model,
5163
'input' => \is_object($input) ? clone $input : $input,
5264
'options' => $options,
53-
'result' => $result,
65+
'result' => $resultPromise,
5466
];
5567

56-
return $result;
68+
return $resultPromise;
69+
}
70+
71+
private function createTraceableStreamResult(\Generator $originalStream): StreamResult
72+
{
73+
return $result = new StreamResult((function () use (&$result, $originalStream) {
74+
$this->resultCache[$result] = '';
75+
foreach ($originalStream as $chunk) {
76+
yield $chunk;
77+
if (\is_string($chunk)) {
78+
$this->resultCache[$result] .= $chunk;
79+
}
80+
}
81+
})());
5782
}
5883
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\AIBundle\Tests\Profiler;
13+
14+
use PHPUnit\Framework\Attributes\CoversClass;
15+
use PHPUnit\Framework\Attributes\UsesClass;
16+
use PHPUnit\Framework\TestCase;
17+
use Symfony\AI\Agent\Toolbox\ToolboxInterface;
18+
use Symfony\AI\AIBundle\Profiler\DataCollector;
19+
use Symfony\AI\AIBundle\Profiler\TraceablePlatform;
20+
use Symfony\AI\Platform\Message\Content\Text;
21+
use Symfony\AI\Platform\Message\Message;
22+
use Symfony\AI\Platform\Message\MessageBag;
23+
use Symfony\AI\Platform\Model;
24+
use Symfony\AI\Platform\PlatformInterface;
25+
use Symfony\AI\Platform\Result\RawResultInterface;
26+
use Symfony\AI\Platform\Result\ResultPromise;
27+
use Symfony\AI\Platform\Result\StreamResult;
28+
use Symfony\AI\Platform\Result\TextResult;
29+
30+
#[CoversClass(DataCollector::class)]
31+
#[UsesClass(TraceablePlatform::class)]
32+
#[UsesClass(ResultPromise::class)]
33+
class DataCollectorTest extends TestCase
34+
{
35+
public function testCollectsDataForNonStreamingResponse()
36+
{
37+
$platform = $this->createMock(PlatformInterface::class);
38+
$traceablePlatform = new TraceablePlatform($platform);
39+
$messageBag = new MessageBag(Message::ofUser(new Text('Hello')));
40+
$result = new TextResult('Assistant response');
41+
42+
$platform->method('invoke')->willReturn(new ResultPromise(static fn () => $result, $this->createStub(RawResultInterface::class)));
43+
44+
$result = $traceablePlatform->invoke($this->createStub(Model::class), $messageBag, ['stream' => false]);
45+
$this->assertSame('Assistant response', $result->asText());
46+
47+
$dataCollector = new DataCollector([$traceablePlatform], $this->createStub(ToolboxInterface::class), []);
48+
$dataCollector->lateCollect();
49+
50+
$this->assertCount(1, $dataCollector->getPlatformCalls());
51+
$this->assertSame('Assistant response', $dataCollector->getPlatformCalls()[0]['result']);
52+
}
53+
54+
public function testCollectsDataForStreamingResponse()
55+
{
56+
$platform = $this->createMock(PlatformInterface::class);
57+
$traceablePlatform = new TraceablePlatform($platform);
58+
$messageBag = new MessageBag(Message::ofUser(new Text('Hello')));
59+
$result = new StreamResult(
60+
(function () {
61+
yield 'Assistant ';
62+
yield 'response';
63+
})(),
64+
);
65+
66+
$platform->method('invoke')->willReturn(new ResultPromise(static fn () => $result, $this->createStub(RawResultInterface::class)));
67+
68+
$result = $traceablePlatform->invoke($this->createStub(Model::class), $messageBag, ['stream' => true]);
69+
$this->assertSame('Assistant response', implode('', iterator_to_array($result->asStream())));
70+
71+
$dataCollector = new DataCollector([$traceablePlatform], $this->createStub(ToolboxInterface::class), []);
72+
$dataCollector->lateCollect();
73+
74+
$this->assertCount(1, $dataCollector->getPlatformCalls());
75+
$this->assertSame('Assistant response', $dataCollector->getPlatformCalls()[0]['result']);
76+
}
77+
}

0 commit comments

Comments
 (0)