-
-
Notifications
You must be signed in to change notification settings - Fork 513
PHPORM-382 Add $vectorSearch stage to the aggregation builder #2822
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Doctrine\ODM\MongoDB\Aggregation\Stage; | ||
|
||
use Doctrine\ODM\MongoDB\Aggregation\Builder; | ||
use Doctrine\ODM\MongoDB\Aggregation\Stage; | ||
use Doctrine\ODM\MongoDB\Query\Expr; | ||
use MongoDB\BSON\Binary; | ||
use MongoDB\BSON\Decimal128; | ||
use MongoDB\BSON\Int64; | ||
|
||
/** | ||
* Aggregation stage that assembles a `$vectorSearch` expression.
*
* Every option starts out as null ("not set") and is assigned through its
* fluent setter. getExpression() emits only the options that have been set,
* so an untouched stage produces an empty option array.
*
* @phpstan-type Vector list<int|Int64>|list<float|Decimal128>|list<bool|0|1>|Binary | ||
* @phpstan-type VectorSearchStageExpression array{ | ||
* '$vectorSearch': object{ | ||
* exact?: bool, | ||
* filter?: object, | ||
* index?: string, | ||
* limit?: int, | ||
* numCandidates?: int, | ||
* path?: string, | ||
* queryVector?: Vector, | ||
* } | ||
* } | ||
*/ | ||
class VectorSearch extends Stage | ||
{ | ||
// Stage options. A null value means the option has not been set and is
// therefore left out of the expression built by getExpression().
private ?bool $exact = null; | ||
private ?Expr $filter = null; | ||
private ?string $index = null; | ||
private ?int $limit = null; | ||
private ?int $numCandidates = null; | ||
private ?string $path = null; | ||
/** @phpstan-var Vector|null */ | ||
private array|Binary|null $queryVector = null; | ||
|
||
/**
* Forwards the aggregation builder to the parent Stage constructor; no
* additional initialisation happens here.
*/
public function __construct(Builder $builder) | ||
{ | ||
parent::__construct($builder); | ||
} | ||
|
||
/**
* Builds the stage expression.
*
* The result is an array with a single key (the value of getStageName())
* mapping to an array of the options that were set, in this order: exact,
* filter, index, limit, numCandidates, path, queryVector.
*
* @return array<string, array<string, mixed>>
*/
public function getExpression(): array | ||
{ | ||
$params = []; | ||
|
||
// Each option is copied only when it is non-null, so options that were
// never set do not appear in the output at all.
if ($this->exact !== null) { | ||
$params['exact'] = $this->exact; | ||
} | ||
|
||
// The filter is stored as an Expr and converted to its query
// representation via getQuery() at build time.
if ($this->filter !== null) { | ||
$params['filter'] = $this->filter->getQuery(); | ||
} | ||
|
||
if ($this->index !== null) { | ||
$params['index'] = $this->index; | ||
} | ||
|
||
if ($this->limit !== null) { | ||
$params['limit'] = $this->limit; | ||
} | ||
|
||
if ($this->numCandidates !== null) { | ||
$params['numCandidates'] = $this->numCandidates; | ||
} | ||
|
||
if ($this->path !== null) { | ||
$params['path'] = $this->path; | ||
} | ||
Comment on lines
+68
to
+70
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a TODO item? If so, is there another ticket to track this? This seems related to #2820 (comment) from the PR that introduced a |
||
|
||
// The query vector (array or Binary) is passed through unchanged.
if ($this->queryVector !== null) { | ||
$params['queryVector'] = $this->queryVector; | ||
} | ||
Comment on lines
+48
to
+74
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense to me. Maybe it's nice to document this as a standard in some sort of development decision documentation in the repo, as I think we've had conversations about this before. |
||
|
||
return [$this->getStageName() => $params]; | ||
} | ||
|
||
/**
* Sets the "exact" option and returns this stage for chaining.
*/
public function exact(bool $exact): static | ||
{ | ||
$this->exact = $exact; | ||
|
||
return $this; | ||
} | ||
|
||
/**
* Sets the "filter" option and returns this stage for chaining.
*
* The expression is stored as given; it is turned into a query array only
* when getExpression() is called.
*/
public function filter(Expr $filter): static | ||
{ | ||
$this->filter = $filter; | ||
|
||
return $this; | ||
} | ||
|
||
/**
* Sets the "index" option and returns this stage for chaining.
*/
public function index(string $index): static | ||
{ | ||
$this->index = $index; | ||
|
||
return $this; | ||
} | ||
|
||
/**
* Sets the "limit" option and returns this stage for chaining.
*
* No range validation is performed here.
*/
public function limit(int $limit): static | ||
{ | ||
$this->limit = $limit; | ||
|
||
return $this; | ||
} | ||
|
||
/**
* Sets the "numCandidates" option and returns this stage for chaining.
*
* No range validation is performed here.
*/
public function numCandidates(int $numCandidates): static | ||
{ | ||
$this->numCandidates = $numCandidates; | ||
|
||
return $this; | ||
} | ||
|
||
/**
* Sets the "path" option and returns this stage for chaining.
*/
public function path(string $path): static | ||
{ | ||
$this->path = $path; | ||
|
||
return $this; | ||
} | ||
|
||
/**
* Sets the "queryVector" option and returns this stage for chaining.
*
* Accepts either a PHP array or a BSON Binary; the value is stored and
* later emitted without conversion.
*
* @phpstan-param Vector $queryVector
*/ | ||
public function queryVector(array|Binary $queryVector): static | ||
{ | ||
$this->queryVector = $queryVector; | ||
|
||
return $this; | ||
} | ||
|
||
/**
* Returns the key under which getExpression() nests the stage options.
*/
protected function getStageName(): string | ||
{ | ||
return '$vectorSearch'; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Doctrine\ODM\MongoDB\Tests\Aggregation\Stage; | ||
|
||
use Doctrine\ODM\MongoDB\Aggregation\Stage\VectorSearch; | ||
use Doctrine\ODM\MongoDB\Tests\Aggregation\AggregationTestTrait; | ||
use Doctrine\ODM\MongoDB\Tests\BaseTestCase; | ||
use MongoDB\BSON\Binary; | ||
|
||
/**
 * Covers the expression produced by the VectorSearch stage: one test per
 * option, one for an untouched stage, and one that chains every setter.
 */
class VectorSearchTest extends BaseTestCase
{
    use AggregationTestTrait;

    /** A stage with no option set emits an empty option array. */
    public function testEmptyStage(): void
    {
        $this->assertVectorSearchOptions([], $this->newVectorSearchStage());
    }

    /** The "exact" flag is emitted as given. */
    public function testExact(): void
    {
        $this->assertVectorSearchOptions(
            ['exact' => true],
            $this->newVectorSearchStage()->exact(true),
        );
    }

    /** A filter expression is emitted as its query array. */
    public function testFilter(): void
    {
        // The same builder creates the stage and the match expression.
        $aggregation = $this->getTestAggregationBuilder();
        $criteria    = $aggregation->matchExpr()->field('status')->notEqual('inactive');

        $this->assertVectorSearchOptions(
            ['filter' => ['status' => ['$ne' => 'inactive']]],
            (new VectorSearch($aggregation))->filter($criteria),
        );
    }

    /** The index name is emitted as given. */
    public function testIndex(): void
    {
        $this->assertVectorSearchOptions(
            ['index' => 'myIndex'],
            $this->newVectorSearchStage()->index('myIndex'),
        );
    }

    /** The limit is emitted as given. */
    public function testLimit(): void
    {
        $this->assertVectorSearchOptions(
            ['limit' => 10],
            $this->newVectorSearchStage()->limit(10),
        );
    }

    /** The number of candidates is emitted as given. */
    public function testNumCandidates(): void
    {
        $this->assertVectorSearchOptions(
            ['numCandidates' => 5],
            $this->newVectorSearchStage()->numCandidates(5),
        );
    }

    /** The path is emitted as given. */
    public function testPath(): void
    {
        $this->assertVectorSearchOptions(
            ['path' => 'vectorField'],
            $this->newVectorSearchStage()->path('vectorField'),
        );
    }

    /** An array query vector is emitted unchanged. */
    public function testQueryVector(): void
    {
        $this->assertVectorSearchOptions(
            ['queryVector' => [1, 2, 3]],
            $this->newVectorSearchStage()->queryVector([1, 2, 3]),
        );
    }

    /** A Binary query vector is emitted as the very same instance. */
    public function testQueryVectorAcceptsBinary(): void
    {
        $packedVector = new Binary("\x01\x02\x03", 9);

        $this->assertVectorSearchOptions(
            ['queryVector' => $packedVector],
            $this->newVectorSearchStage()->queryVector($packedVector),
        );
    }

    /** All setters can be chained and every option ends up in the expression. */
    public function testChainingAllOptions(): void
    {
        $aggregation = $this->getTestAggregationBuilder();
        $criteria    = $aggregation->matchExpr()->field('status')->notEqual('inactive');

        $configured = new VectorSearch($aggregation);
        $configured = $configured->exact(false);
        $configured = $configured->filter($criteria);
        $configured = $configured->index('idx');
        $configured = $configured->limit(7);
        $configured = $configured->numCandidates(3);
        $configured = $configured->path('vec');
        $configured = $configured->queryVector([0.1, 0.2]);

        $this->assertVectorSearchOptions(
            [
                'exact' => false,
                'filter' => ['status' => ['$ne' => 'inactive']],
                'index' => 'idx',
                'limit' => 7,
                'numCandidates' => 3,
                'path' => 'vec',
                'queryVector' => [0.1, 0.2],
            ],
            $configured,
        );
    }

    /** Creates a stage bound to a fresh test aggregation builder. */
    private function newVectorSearchStage(): VectorSearch
    {
        return new VectorSearch($this->getTestAggregationBuilder());
    }

    /**
     * Asserts that the stage expression is exactly the given options nested
     * under the "$vectorSearch" key (strict comparison, key order included).
     *
     * @param array<string, mixed> $expectedOptions
     */
    private function assertVectorSearchOptions(array $expectedOptions, VectorSearch $stage): void
    {
        self::assertSame(['$vectorSearch' => $expectedOptions], $stage->getExpression());
    }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, this PR does not add a `Stage::vectorSearch()` method for the same reason that you deprecated `Stage::search()` in #2823, correct?