Skip to content

Commit abf074e

Browse files
formatting
1 parent 398aa76 commit abf074e

File tree

4 files changed

+25
-12
lines changed

4 files changed

+25
-12
lines changed

README.md

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Redis Model Store
22

3+
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
4+
![Language](https://img.shields.io/github/languages/top/redis-applied-ai/redis-model-store)
5+
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
6+
![GitHub last commit](https://img.shields.io/github/last-commit/redis-applied-ai/redis-model-store)
7+
[![pypi](https://badge.fury.io/py/redisvl.svg)](https://pypi.org/project/redis-model-store/)
8+
39
`redis-model-store` is a simple Python library designed to handle versioning and serialization of AI/ML models into Redis. It provides a streamlined way to manage your machine learning models in Redis.
410

511
## Features
@@ -8,7 +14,6 @@
814
- **Sharding for Large Models**: Splits large serialized payloads into manageable chunks to optimize Redis storage.
915
- **Version Management**: Automatically manages model versions in Redis, allowing you to store and retrieve specific versions.
1016

11-
---
1217

1318
## Installation
1419
```bash
@@ -49,12 +54,16 @@ model = RandomForestClassifier()
4954
model.fit(X_train, y_train)
5055

5156
# Save the model to Redis
52-
version = model_store.save_model(model, "random_forest")
57+
version = model_store.save_model(model, name="random_forest", description="Random forest classifier model")
5358
```
5459

5560
### Load models
5661
```python
5762
# Grab the latest model
58-
model = model_store.load_model("random_forest")
63+
model = model_store.load_model(name="random_forest")
64+
65+
# Grab a specific model version
66+
model = model_store.load_model(name="random_forest", version=version)
5967
```
6068

69+
##

model_store/serialize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
class SerializationError(Exception):
66
"""Raised when model serialization or deserialization fails."""
7+
78
pass
89

910

model_store/shard_manager.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def __init__(self, shard_size: int, serializer: Serializer = PickleSerializer())
3636
self.shard_size = shard_size
3737
self.serializer = serializer
3838

39-
def _shardify(self, data: bytes) -> Iterator[bytes]:
39+
@staticmethod
40+
def _shardify(data: bytes, shard_size: int) -> Iterator[bytes]:
4041
"""
4142
Split serialized data into fixed-size shards.
4243
@@ -47,8 +48,8 @@ def _shardify(self, data: bytes) -> Iterator[bytes]:
4748
bytes: Successive shards of the data, each up to shard_size in length.
4849
"""
4950
total_size = len(data)
50-
for start in range(0, total_size, self.shard_size):
51-
yield data[start : start + self.shard_size]
51+
for start in range(0, total_size, shard_size):
52+
yield data[start : start + shard_size]
5253

5354
@staticmethod
5455
def shard_key(model_name: str, model_version: str, idx: int) -> str:
@@ -65,7 +66,7 @@ def shard_key(model_name: str, model_version: str, idx: int) -> str:
6566
"""
6667
return f"shard:{model_name}:{model_version}:{idx}"
6768

68-
def to_shards(self, model: Any) -> List[bytes]:
69+
def to_shards(self, model: Any) -> Iterator[bytes]:
6970
"""
7071
Convert model into smaller chunks (shards) ready for storage.
7172
@@ -74,14 +75,13 @@ def to_shards(self, model: Any) -> List[bytes]:
7475
7576
Returns:
7677
List[bytes]: List of binary shards derived from the model.
77-
TODO -- returns a generator here
7878
79-
Raises:
80-
SerializationError: If model serialization fails.
79+
Yields:
80+
bytes: Successive shards of the data, each up to shard_size in length.
8181
"""
8282
try:
8383
serialized_data = self.serializer.dumps(model)
84-
return self._shardify(serialized_data)
84+
return self._shardify(serialized_data, self.shard_size)
8585
except Exception as e:
8686
raise SerializationError(f"Failed to serialize model: {str(e)}") from e
8787

model_store/store.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
class ModelStoreError(Exception):
1313
"""Raised when model store I/O operations fail."""
14+
1415
pass
1516

1617

@@ -214,7 +215,9 @@ def load_model(self, name: str, version: Optional[str] = None) -> Any:
214215
logger.debug(f"Retrieving specified model version: {version}")
215216
model_version = self.model_registry.get_version(name, version)
216217

217-
logger.info(f"Retrieved model version metadata ({current_timestamp()-st:.4f}s)")
218+
logger.info(
219+
f"Retrieved model version metadata ({current_timestamp()-st:.4f}s)"
220+
)
218221
model = self._from_redis(model_version.shard_keys)
219222

220223
total_duration = current_timestamp() - total_start

0 commit comments

Comments
 (0)