Skip to content

Commit b06e415

Browse files
authored
Added Geo filter (#8)
1 parent 0299ece commit b06e415

File tree

6 files changed

+355
-1
lines changed

6 files changed

+355
-1
lines changed

src/Enum/SearchField.php

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
namespace Vladvildanov\PredisVl\Enum;
44

5+
use Predis\Command\Argument\Search\SchemaFields\GeoField;
56
use Predis\Command\Argument\Search\SchemaFields\NumericField;
67
use Predis\Command\Argument\Search\SchemaFields\TagField;
78
use Predis\Command\Argument\Search\SchemaFields\TextField;
@@ -16,6 +17,7 @@ enum SearchField
1617
case text;
1718
case numeric;
1819
case vector;
20+
case geo;
1921

2022
/**
2123
* Returns field class corresponding to given case.
@@ -29,6 +31,7 @@ public function fieldMapping(): string
2931
self::text => TextField::class,
3032
self::numeric => NumericField::class,
3133
self::vector => VectorField::class,
34+
self::geo => GeoField::class,
3235
};
3336
}
3437
}

src/Enum/Unit.php

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
<?php
2+
3+
namespace Vladvildanov\PredisVl\Enum;
4+
5+
use Vladvildanov\PredisVl\Enum\Traits\EnumNames;
6+
7+
enum Unit: string
8+
{
9+
use EnumNames;
10+
11+
case kilometers = 'km';
12+
case meters = 'm';
13+
case miles = 'mi';
14+
case foots = 'ft';
15+
}

src/Query/Filter/GeoFilter.php

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
<?php
2+
3+
namespace Vladvildanov\PredisVl\Query\Filter;
4+
5+
use JetBrains\PhpStorm\ArrayShape;
6+
use Vladvildanov\PredisVl\Enum\Condition;
7+
use Vladvildanov\PredisVl\Enum\Unit;
8+
9+
class GeoFilter extends AbstractFilter
10+
{
11+
/**
12+
* Creates geo filter based on condition.
13+
* Values should be provided as a specific-shaped array described as ArrayShape attribute.
14+
*
15+
* Only equal, notEqual conditions are allowed.
16+
*
17+
* @param string $fieldName
18+
* @param Condition $condition
19+
* @param array{lon: float, lat: float, radius: int, unit: Unit} $value
20+
*/
21+
public function __construct(
22+
string $fieldName,
23+
Condition $condition,
24+
#[ArrayShape([
25+
'lon' => 'float',
26+
'lat' => 'float',
27+
'radius' => 'int',
28+
'unit' => Unit::class
29+
])] $value
30+
) {
31+
parent::__construct($fieldName, $condition, $value);
32+
}
33+
34+
/**
35+
* @inheritDoc
36+
*/
37+
public function toExpression(): string
38+
{
39+
$condition = $this->conditionMappings[$this->condition->value];
40+
41+
$lon = (float) $this->value['lon'];
42+
$lat = (float) $this->value['lat'];
43+
44+
return "$condition@$this->fieldName:[{$lon} {$lat} {$this->value['radius']} {$this->value['unit']->value}]";
45+
}
46+
}

tests/Feature/Index/SearchIndexTest.php

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
use Vladvildanov\PredisVl\Enum\Condition;
66
use Vladvildanov\PredisVl\Enum\Logical;
7+
use Vladvildanov\PredisVl\Enum\Unit;
78
use Vladvildanov\PredisVl\Feature\FeatureTestCase;
89
use Predis\Client;
910
use Vladvildanov\PredisVl\Index\SearchIndex;
1011
use Vladvildanov\PredisVl\Query\Filter\FilterInterface;
12+
use Vladvildanov\PredisVl\Query\Filter\GeoFilter;
1113
use Vladvildanov\PredisVl\Query\Filter\NumericFilter;
1214
use Vladvildanov\PredisVl\Query\Filter\TagFilter;
1315
use Vladvildanov\PredisVl\Query\Filter\TextFilter;
@@ -515,6 +517,94 @@ public function testVectorQueryHashWithTextFilter(
515517
}
516518
}
517519

520+
/**
521+
* @dataProvider vectorGeoFilterProvider
522+
* @param FilterInterface|null $filter
523+
* @param array $expectedResponse
524+
* @return void
525+
*/
526+
public function testVectorQueryHashWithGeoFilter(
527+
?FilterInterface $filter,
528+
array $expectedResponse
529+
): void {
530+
$schema = [
531+
'index' => [
532+
'name' => 'products',
533+
'prefix' => 'product:',
534+
],
535+
'fields' => [
536+
'id' => [
537+
'type' => 'text',
538+
],
539+
'price' => [
540+
'type' => 'numeric',
541+
],
542+
'location' => [
543+
'type' => 'geo',
544+
],
545+
'description_embedding' => [
546+
'type' => 'vector',
547+
'dims' => 3,
548+
'datatype' => 'float32',
549+
'algorithm' => 'flat',
550+
'distance_metric' => 'cosine'
551+
],
552+
],
553+
];
554+
555+
$index = new SearchIndex($this->client, $schema);
556+
$this->assertEquals('OK', $index->create());
557+
558+
$this->assertTrue($index->load(
559+
'1',
560+
[
561+
'id' => '1', 'price' => 10, 'location' => '10.111,11.111',
562+
'description_embedding' => VectorHelper::toBytes([0.001, 0.002, 0.003])
563+
])
564+
);
565+
$this->assertTrue($index->load(
566+
'2',
567+
[
568+
'id' => '2', 'price' => 20, 'location' => '10.222,11.222',
569+
'description_embedding' => VectorHelper::toBytes([0.001, 0.002, 0.003])
570+
])
571+
);
572+
$this->assertTrue($index->load(
573+
'3',
574+
[
575+
'id' => '3', 'price' => 30, 'location' => '10.333,11.333',
576+
'description_embedding' => VectorHelper::toBytes([0.001, 0.002, 0.003])
577+
])
578+
);
579+
$this->assertTrue($index->load(
580+
'4',
581+
[
582+
'id' => '4', 'price' => 40, 'location' => '10.444,11.444',
583+
'description_embedding' => VectorHelper::toBytes([0.001, 0.002, 0.003])
584+
])
585+
);
586+
587+
$query = new VectorQuery(
588+
[0.001, 0.002, 0.03],
589+
'description_embedding',
590+
null,
591+
10,
592+
true,
593+
2,
594+
$filter
595+
);
596+
597+
$response = $index->query($query);
598+
$this->assertSame($expectedResponse['count'], $response['count']);
599+
600+
foreach ($expectedResponse['results'] as $key => $value) {
601+
$this->assertSame(
602+
$expectedResponse['results'][$key]['location'],
603+
$response['results'][$key]['location']
604+
);
605+
}
606+
}
607+
518608
/**
519609
* @return void
520610
*/
@@ -925,6 +1015,100 @@ public function testVectorQueryJsonIndexWithTextFilter(
9251015
}
9261016
}
9271017

1018+
/**
1019+
* @dataProvider vectorGeoFilterProvider
1020+
* @param FilterInterface|null $filter
1021+
* @param array $expectedResponse
1022+
* @return void
1023+
* @throws \JsonException
1024+
*/
1025+
public function testVectorQueryJsonIndexWithGeoFilter(
1026+
?FilterInterface $filter,
1027+
array $expectedResponse
1028+
): void {
1029+
$schema = [
1030+
'index' => [
1031+
'name' => 'products',
1032+
'prefix' => 'product:',
1033+
'storage_type' => 'json'
1034+
],
1035+
'fields' => [
1036+
'$.id' => [
1037+
'type' => 'text',
1038+
],
1039+
'$.price' => [
1040+
'type' => 'numeric',
1041+
],
1042+
'$.location' => [
1043+
'type' => 'geo',
1044+
'alias' => 'location',
1045+
],
1046+
'$.description_embedding' => [
1047+
'type' => 'vector',
1048+
'dims' => 3,
1049+
'datatype' => 'float32',
1050+
'algorithm' => 'flat',
1051+
'distance_metric' => 'cosine',
1052+
'alias' => 'vector_embedding',
1053+
],
1054+
],
1055+
];
1056+
1057+
$index = new SearchIndex($this->client, $schema);
1058+
$this->assertEquals('OK', $index->create());
1059+
1060+
$this->assertTrue($index->load(
1061+
'1',
1062+
json_encode([
1063+
'id' => '1', 'price' => 10, 'location' => ['10.111,11.111'],
1064+
'description_embedding' => [0.001, 0.002, 0.003],
1065+
], JSON_THROW_ON_ERROR)
1066+
));
1067+
$this->assertTrue($index->load(
1068+
'2',
1069+
json_encode([
1070+
'id' => '2', 'price' => 20, 'location' => ['10.222,11.222'],
1071+
'description_embedding' => [0.001, 0.002, 0.003],
1072+
], JSON_THROW_ON_ERROR)
1073+
));
1074+
$this->assertTrue($index->load(
1075+
'3',
1076+
json_encode([
1077+
'id' => '3', 'price' => 30, 'location' => ['10.333,11.333'],
1078+
'description_embedding' => [0.001, 0.002, 0.003],
1079+
], JSON_THROW_ON_ERROR)
1080+
));
1081+
$this->assertTrue($index->load(
1082+
'4',
1083+
json_encode([
1084+
'id' => '4', 'price' => 40, 'location' => ['10.444,11.444'],
1085+
'description_embedding' => [0.001, 0.002, 0.003],
1086+
], JSON_THROW_ON_ERROR)
1087+
));
1088+
1089+
$query = new VectorQuery(
1090+
[0.001, 0.002, 0.03],
1091+
'vector_embedding',
1092+
null,
1093+
10,
1094+
true,
1095+
2,
1096+
$filter
1097+
);
1098+
1099+
$response = $index->query($query);
1100+
$this->assertSame($expectedResponse['count'], $response['count']);
1101+
1102+
foreach ($response['results'] as $key => $value) {
1103+
$decodedResponse = json_decode($value['$'], true, 512, JSON_THROW_ON_ERROR);
1104+
1105+
$this->assertSame(
1106+
$expectedResponse['results'][$key]['location'],
1107+
$decodedResponse['location'][0]
1108+
);
1109+
}
1110+
}
1111+
9281112
/**
9291113
* @return void
9301114
* @throws \JsonException
@@ -1484,4 +1668,66 @@ public static function vectorTextFilterProvider(): array
14841668
],
14851669
];
14861670
}
1671+
1672+
public static function vectorGeoFilterProvider(): array
1673+
{
1674+
return [
1675+
'default' => [
1676+
null,
1677+
[
1678+
'count' => 4,
1679+
'results' => [
1680+
'product:1' => [
1681+
'location' => '10.111,11.111',
1682+
],
1683+
'product:2' => [
1684+
'location' => '10.222,11.222',
1685+
],
1686+
'product:3' => [
1687+
'location' => '10.333,11.333',
1688+
],
1689+
'product:4' => [
1690+
'location' => '10.444,11.444',
1691+
],
1692+
]
1693+
]
1694+
],
1695+
'equal_radius' => [
1696+
new GeoFilter(
1697+
'location',
1698+
Condition::equal,
1699+
['lon' => 10.000, 'lat' => 12.000, 'radius' => 85, 'unit' => Unit::kilometers]
1700+
),
1701+
[
1702+
'count' => 2,
1703+
'results' => [
1704+
'product:3' => [
1705+
'location' => '10.333,11.333',
1706+
],
1707+
'product:4' => [
1708+
'location' => '10.444,11.444',
1709+
],
1710+
]
1711+
]
1712+
],
1713+
'not_equal_radius' => [
1714+
new GeoFilter(
1715+
'location',
1716+
Condition::notEqual,
1717+
['lon' => 10.000, 'lat' => 12.000, 'radius' => 85, 'unit' => Unit::kilometers]
1718+
),
1719+
[
1720+
'count' => 2,
1721+
'results' => [
1722+
'product:1' => [
1723+
'location' => '10.111,11.111',
1724+
],
1725+
'product:2' => [
1726+
'location' => '10.222,11.222',
1727+
],
1728+
]
1729+
]
1730+
],
1731+
];
1732+
}
14871733
}

tests/Unit/Enum/SearchFieldTest.php

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
namespace Vladvildanov\PredisVl\Unit\Enum;
44

55
use PHPUnit\Framework\TestCase;
6+
use Predis\Command\Argument\Search\SchemaFields\GeoField;
67
use Predis\Command\Argument\Search\SchemaFields\NumericField;
78
use Predis\Command\Argument\Search\SchemaFields\TagField;
89
use Predis\Command\Argument\Search\SchemaFields\TextField;
@@ -27,7 +28,7 @@ public function testFieldMapping(SearchField $enum, string $expectedClass): void
2728
*/
2829
public function testNames(): void
2930
{
30-
$this->assertSame(['tag', 'text', 'numeric', 'vector'], SearchField::names());
31+
$this->assertSame(['tag', 'text', 'numeric', 'vector', 'geo'], SearchField::names());
3132
}
3233

3334
/**
@@ -48,6 +49,7 @@ public static function mappingProvider(): array
4849
'text' => [SearchField::text, TextField::class],
4950
'numeric' => [SearchField::numeric, NumericField::class],
5051
'vector' => [SearchField::vector, VectorField::class],
52+
'geo' => [SearchField::geo, GeoField::class],
5153
];
5254
}
5355

@@ -58,6 +60,7 @@ public static function fromNameProvider(): array
5860
['text', SearchField::text],
5961
['numeric', SearchField::numeric],
6062
['vector', SearchField::vector],
63+
['geo', SearchField::geo],
6164
];
6265
}
6366
}

0 commit comments

Comments
 (0)