Skip to content

Commit ae6c6e6

Browse files
committed
PHPORM-382 Add $vectorSearch stage to the aggregation builder
1 parent e3352c0 commit ae6c6e6

File tree

2 files changed

+227
-0
lines changed

2 files changed

+227
-0
lines changed
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Doctrine\ODM\MongoDB\Aggregation\Stage;
6+
7+
use Doctrine\ODM\MongoDB\Aggregation\Builder;
8+
use Doctrine\ODM\MongoDB\Aggregation\Stage;
9+
use Doctrine\ODM\MongoDB\Query\Expr;
10+
use MongoDB\BSON\Decimal128;
11+
use MongoDB\BSON\Int64;
12+
13+
/**
14+
* @phpstan-type Vector list<int|Int64>|list<float|Decimal128>|list<bool|0|1>
15+
* @phpstan-type VectorSearchStageExpression array{
16+
* '$vectorSearch': object{
17+
* exact?: bool,
18+
* filter?: object,
19+
* index?: string,
20+
* limit?: int,
21+
* numCandidates?: int,
22+
* path?: string,
23+
* queryVector?: Vector,
24+
* }
25+
* }
26+
* @extends Stage<VectorSearchStageExpression>
27+
*/
28+
class VectorSearch extends Stage
29+
{
30+
private ?bool $exact = null;
31+
private ?Expr $filter = null;
32+
private ?string $index = null;
33+
private ?int $limit = null;
34+
private ?int $numCandidates = null;
35+
private ?string $path = null;
36+
/** @phpstan-var Vector */
37+
private ?array $queryVector = null;
38+
39+
public function __construct(Builder $builder)
40+
{
41+
parent::__construct($builder);
42+
}
43+
44+
public function getExpression(): array
45+
{
46+
$params = [];
47+
48+
if ($this->exact !== null) {
49+
$params['exact'] = $this->exact;
50+
}
51+
52+
if ($this->filter !== null) {
53+
$params['filter'] = $this->filter->getQuery();
54+
}
55+
56+
if ($this->index !== null) {
57+
$params['index'] = $this->index;
58+
}
59+
60+
if ($this->limit !== null) {
61+
$params['limit'] = $this->limit;
62+
}
63+
64+
if ($this->numCandidates !== null) {
65+
$params['numCandidates'] = $this->numCandidates;
66+
}
67+
68+
if ($this->path !== null) {
69+
$params['path'] = $this->path;
70+
}
71+
72+
if ($this->queryVector !== null) {
73+
$params['queryVector'] = $this->queryVector;
74+
}
75+
76+
return [$this->getStageName() => $params];
77+
}
78+
79+
public function exact(bool $exact): static
80+
{
81+
$this->exact = $exact;
82+
83+
return $this;
84+
}
85+
86+
public function filter(Expr $filter): static
87+
{
88+
$this->filter = $filter;
89+
90+
return $this;
91+
}
92+
93+
public function index(string $index): static
94+
{
95+
$this->index = $index;
96+
97+
return $this;
98+
}
99+
100+
public function limit(int $limit): static
101+
{
102+
$this->limit = $limit;
103+
104+
return $this;
105+
}
106+
107+
public function numCandidates(int $numCandidates): static
108+
{
109+
$this->numCandidates = $numCandidates;
110+
111+
return $this;
112+
}
113+
114+
public function path(string $path): static
115+
{
116+
$this->path = $path;
117+
118+
return $this;
119+
}
120+
121+
/** @param list<int|Int64>|list<float|Decimal128>|list<bool|0|1> $queryVector */
122+
public function queryVector(array $queryVector): static
123+
{
124+
$this->queryVector = $queryVector;
125+
126+
return $this;
127+
}
128+
129+
protected function getStageName(): string
130+
{
131+
return '$vectorSearch';
132+
}
133+
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Doctrine\ODM\MongoDB\Tests\Aggregation\Stage;
6+
7+
use Doctrine\ODM\MongoDB\Aggregation\Stage\VectorSearch;
8+
use Doctrine\ODM\MongoDB\Tests\Aggregation\AggregationTestTrait;
9+
use Doctrine\ODM\MongoDB\Tests\BaseTestCase;
10+
11+
class VectorSearchTest extends BaseTestCase
12+
{
13+
use AggregationTestTrait;
14+
15+
public function testEmptyStage(): void
16+
{
17+
$stage = new VectorSearch($this->getTestAggregationBuilder());
18+
self::assertSame(['$vectorSearch' => []], $stage->getExpression());
19+
}
20+
21+
public function testExact(): void
22+
{
23+
$stage = new VectorSearch($this->getTestAggregationBuilder());
24+
$stage->exact(true);
25+
self::assertSame(['$vectorSearch' => ['exact' => true]], $stage->getExpression());
26+
}
27+
28+
public function testFilter(): void
29+
{
30+
$builder = $this->getTestAggregationBuilder();
31+
$stage = new VectorSearch($builder);
32+
$stage->filter($builder->matchExpr()->field('status')->notEqual('inactive'));
33+
self::assertSame(['$vectorSearch' => ['filter' => ['status' => ['$ne' => 'inactive']]]], $stage->getExpression());
34+
}
35+
36+
public function testIndex(): void
37+
{
38+
$stage = new VectorSearch($this->getTestAggregationBuilder());
39+
$stage->index('myIndex');
40+
self::assertSame(['$vectorSearch' => ['index' => 'myIndex']], $stage->getExpression());
41+
}
42+
43+
public function testLimit(): void
44+
{
45+
$stage = new VectorSearch($this->getTestAggregationBuilder());
46+
$stage->limit(10);
47+
self::assertSame(['$vectorSearch' => ['limit' => 10]], $stage->getExpression());
48+
}
49+
50+
public function testNumCandidates(): void
51+
{
52+
$stage = new VectorSearch($this->getTestAggregationBuilder());
53+
$stage->numCandidates(5);
54+
self::assertSame(['$vectorSearch' => ['numCandidates' => 5]], $stage->getExpression());
55+
}
56+
57+
public function testPath(): void
58+
{
59+
$stage = new VectorSearch($this->getTestAggregationBuilder());
60+
$stage->path('vectorField');
61+
self::assertSame(['$vectorSearch' => ['path' => 'vectorField']], $stage->getExpression());
62+
}
63+
64+
public function testQueryVector(): void
65+
{
66+
$stage = new VectorSearch($this->getTestAggregationBuilder());
67+
$stage->queryVector([1, 2, 3]);
68+
self::assertSame(['$vectorSearch' => ['queryVector' => [1, 2, 3]]], $stage->getExpression());
69+
}
70+
71+
public function testChainingAllOptions(): void
72+
{
73+
$builder = $this->getTestAggregationBuilder();
74+
$stage = (new VectorSearch($builder))
75+
->exact(false)
76+
->filter($builder->matchExpr()->field('status')->notEqual('inactive'))
77+
->index('idx')
78+
->limit(7)
79+
->numCandidates(3)
80+
->path('vec')
81+
->queryVector([0.1, 0.2]);
82+
self::assertSame([
83+
'$vectorSearch' => [
84+
'exact' => false,
85+
'filter' => ['status' => ['$ne' => 'inactive']],
86+
'index' => 'idx',
87+
'limit' => 7,
88+
'numCandidates' => 3,
89+
'path' => 'vec',
90+
'queryVector' => [0.1, 0.2],
91+
],
92+
], $stage->getExpression());
93+
}
94+
}

0 commit comments

Comments
 (0)