Skip to content

Commit 1a7eaa8

Browse files
committed
feat: Validate shard count on startup
1 parent d813f0d commit 1a7eaa8

File tree

2 files changed

+86
-7
lines changed

2 files changed

+86
-7
lines changed

src/app_state.rs

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use futures::future::join_all;
2+
use miette::Result;
23
use moka::future::Cache;
34
use mpchash::HashRing;
45
use sqlx::SqlitePool;
@@ -44,7 +45,7 @@ impl AppState {
4445
pub async fn new(
4546
cfg: Cfg,
4647
shard_receivers_out: &mut Vec<mpsc::Receiver<ShardWriteOperation>>,
47-
) -> Self {
48+
) -> Result<Self> {
4849
let mut shard_senders_vec = Vec::new();
4950
// Clear the output vector first to ensure it's empty
5051
shard_receivers_out.clear();
@@ -58,6 +59,9 @@ impl AppState {
5859
// Create data directory if it doesn't exist
5960
fs::create_dir_all(&cfg.data_dir).expect("Failed to create data directory");
6061

62+
// Validate the number of shards in the data directory
63+
validate_shard_count(&cfg.data_dir, cfg.num_shards)?;
64+
6165
let mut db_pools_futures = vec![];
6266
for i in 0..cfg.num_shards {
6367
let data_dir = cfg.data_dir.clone();
@@ -168,7 +172,7 @@ impl AppState {
168172
ring.add(ShardNode(i as u64));
169173
}
170174

171-
AppState {
175+
Ok(AppState {
172176
cfg,
173177
shard_senders: shard_senders_vec,
174178
db_pools,
@@ -178,7 +182,7 @@ impl AppState {
178182
compressor,
179183
metrics,
180184
ring,
181-
}
185+
})
182186
}
183187

184188
pub fn get_shard(&self, key: &str) -> usize {
@@ -191,3 +195,76 @@ impl AppState {
191195
format!("{}:{}", namespace, key)
192196
}
193197
}
198+
199+
/// Validates that the number of shard files in the data directory matches the expected count
200+
fn validate_shard_count(data_dir: &str, expected_shards: usize) -> Result<()> {
201+
use std::path::Path;
202+
203+
let data_path = Path::new(data_dir);
204+
205+
// If the directory doesn't exist yet, that's fine - it will be created
206+
if !data_path.exists() {
207+
return Ok(());
208+
}
209+
210+
// If the directory is empty, that's also fine - no validation needed
211+
if let Ok(entries) = fs::read_dir(data_path) {
212+
if entries.count() == 0 {
213+
return Ok(());
214+
}
215+
}
216+
217+
// Count existing shard files
218+
let mut actual_shards = 0;
219+
for i in 0..expected_shards {
220+
let shard_path = data_path.join(format!("shard_{}.db", i));
221+
if shard_path.exists() {
222+
actual_shards += 1;
223+
}
224+
}
225+
226+
// Check if we have more shards than expected
227+
// Look for any shard files beyond the expected range
228+
let mut extra_shards = Vec::new();
229+
if let Ok(entries) = fs::read_dir(data_path) {
230+
for entry in entries.flatten() {
231+
let file_name = entry.file_name();
232+
let file_name_str = file_name.to_string_lossy();
233+
if file_name_str.starts_with("shard_") && file_name_str.ends_with(".db") {
234+
if let Some(num_str) = file_name_str
235+
.strip_prefix("shard_")
236+
.and_then(|s| s.strip_suffix(".db"))
237+
{
238+
if let Ok(shard_num) = num_str.parse::<usize>() {
239+
if shard_num >= expected_shards {
240+
extra_shards.push(shard_num);
241+
}
242+
}
243+
}
244+
}
245+
}
246+
}
247+
248+
// Report any mismatches
249+
if actual_shards != expected_shards || !extra_shards.is_empty() {
250+
let mut error_msg = format!(
251+
"Shard count mismatch: expected {} shards, but found {} existing shards",
252+
expected_shards, actual_shards
253+
);
254+
255+
if !extra_shards.is_empty() {
256+
error_msg.push_str(&format!(". Found extra shard files: {:?}", extra_shards));
257+
}
258+
259+
if actual_shards < expected_shards {
260+
error_msg.push_str(&format!(
261+
". Missing {} shard files.",
262+
expected_shards - actual_shards
263+
));
264+
}
265+
266+
return Err(miette::miette!(error_msg));
267+
}
268+
269+
Ok(())
270+
}

src/main.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ async fn main() -> Result<()> {
8080

8181
// Handle different commands
8282
match args.command {
83-
Some(Command::Shard(shard_cmd)) => {
84-
handle_shard_command(shard_cmd, &args.config).await
85-
}
83+
Some(Command::Shard(shard_cmd)) => handle_shard_command(shard_cmd, &args.config).await,
8684
Some(Command::Serve(_)) | None => {
8785
// Default to serving if no command specified
8886
run_server(&args.config).await
@@ -132,7 +130,11 @@ async fn run_server(config_path: &str) -> Result<()> {
132130
let mut shard_receivers = Vec::with_capacity(cfg.num_shards);
133131

134132
// Initialize AppState, AppState::new will populate shard_receivers
135-
let shared_state = Arc::new(AppState::new(cfg.clone(), &mut shard_receivers).await);
133+
let shared_state = Arc::new(
134+
AppState::new(cfg.clone(), &mut shard_receivers)
135+
.await
136+
.wrap_err("initializing AppState")?,
137+
);
136138

137139
// Initialize metrics
138140
shared_state.metrics.record_server_startup();

0 commit comments

Comments
 (0)