Skip to content

Commit 624bdf9

Browse files
committed
feat: basic support to trigger leader election
1 parent eb79c45 commit 624bdf9

File tree

4 files changed

+377
-2
lines changed

4 files changed

+377
-2
lines changed

src/client/controller.rs

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,27 @@ use crate::{
1111
messenger::RequestError,
1212
protocol::{
1313
error::Error as ProtocolError,
14-
messages::{CreateTopicRequest, CreateTopicsRequest},
15-
primitives::{Int16, Int32, NullableString, String_},
14+
messages::{
15+
CreateTopicRequest, CreateTopicsRequest, ElectLeadersRequest, ElectLeadersTopicRequest,
16+
},
17+
primitives::{Array, Int16, Int32, Int8, NullableString, String_},
1618
},
1719
validation::ExactlyOne,
1820
};
1921

22+
/// Election type of [`ControllerClient::elect_leaders`].
23+
///
24+
/// The names in this enum are borrowed from the
25+
/// [Kafka source code](https://github.com/a0x8o/kafka/blob/5383311a5cfbdaf147411004106449e3ad8081fb/core/src/main/scala/kafka/controller/KafkaController.scala#L2186-L2194>).
26+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
27+
pub enum ElectionType {
28+
/// Elects the preferred replica.
29+
Preferred,
30+
31+
/// Elects the first live replica if there are no in-sync replica.
32+
Unclean,
33+
}
34+
2035
#[derive(Debug)]
2136
pub struct ControllerClient {
2237
brokers: Arc<BrokerConnector>,
@@ -78,6 +93,57 @@ impl ControllerClient {
7893
.await
7994
}
8095

96+
/// Elect leaders for given topic and partition.
97+
pub async fn elect_leaders(
98+
&self,
99+
topic: impl Into<String> + Send,
100+
partition: i32,
101+
election_type: ElectionType,
102+
timeout_ms: i32,
103+
) -> Result<()> {
104+
let request = &ElectLeadersRequest {
105+
election_type: Int8(match election_type {
106+
ElectionType::Preferred => 0,
107+
ElectionType::Unclean => 1,
108+
}),
109+
topic_partitions: vec![ElectLeadersTopicRequest {
110+
topic: String_(topic.into()),
111+
partitions: Array(Some(vec![Int32(partition)])),
112+
tagged_fields: None,
113+
}],
114+
timeout_ms: Int32(timeout_ms),
115+
tagged_fields: None,
116+
};
117+
118+
maybe_retry(&self.backoff_config, self, "elect_leaders", || async move {
119+
let broker = self.get().await?;
120+
let response = broker.request(request).await?;
121+
122+
if let Some(protocol_error) = response.error {
123+
return Err(Error::ServerError(protocol_error, Default::default()));
124+
}
125+
126+
let topic = response
127+
.replica_election_results
128+
.exactly_one()
129+
.map_err(Error::exactly_one_topic)?;
130+
131+
let partition = topic
132+
.partition_results
133+
.exactly_one()
134+
.map_err(Error::exactly_one_partition)?;
135+
136+
match partition.error {
137+
None => Ok(()),
138+
Some(protocol_error) => Err(Error::ServerError(
139+
protocol_error,
140+
partition.error_message.0.unwrap_or_default(),
141+
)),
142+
}
143+
})
144+
.await
145+
}
146+
81147
/// Retrieve the broker ID of the controller
82148
async fn get_controller_id(&self) -> Result<i32> {
83149
let metadata = self.brokers.request_metadata(None, Some(vec![])).await?;
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
use std::io::{Read, Write};
2+
3+
use crate::protocol::{
4+
api_key::ApiKey,
5+
api_version::{ApiVersion, ApiVersionRange},
6+
error::Error,
7+
messages::{
8+
read_compact_versioned_array, read_versioned_array, write_compact_versioned_array,
9+
write_versioned_array,
10+
},
11+
primitives::{
12+
Array, CompactArrayRef, CompactNullableString, CompactString, CompactStringRef, Int16,
13+
Int32, Int8, NullableString, String_, TaggedFields,
14+
},
15+
traits::{ReadType, WriteType},
16+
};
17+
18+
use super::{
19+
ReadVersionedError, ReadVersionedType, RequestBody, WriteVersionedError, WriteVersionedType,
20+
};
21+
22+
#[derive(Debug)]
23+
pub struct ElectLeadersRequest {
24+
/// Type of elections to conduct for the partition.
25+
///
26+
/// A value of `0` elects the preferred replica. A value of `1` elects the first live replica if there are no
27+
/// in-sync replica.
28+
///
29+
/// Added in version 1.
30+
pub election_type: Int8,
31+
32+
/// The topic partitions to elect leaders.
33+
pub topic_partitions: Vec<ElectLeadersTopicRequest>,
34+
35+
/// The time in ms to wait for the election to complete.
36+
pub timeout_ms: Int32,
37+
38+
/// The tagged fields.
39+
///
40+
/// Added in version 2
41+
pub tagged_fields: Option<TaggedFields>,
42+
}
43+
44+
impl<W> WriteVersionedType<W> for ElectLeadersRequest
45+
where
46+
W: Write,
47+
{
48+
fn write_versioned(
49+
&self,
50+
writer: &mut W,
51+
version: ApiVersion,
52+
) -> Result<(), WriteVersionedError> {
53+
let v = version.0 .0;
54+
assert!(v <= 2);
55+
56+
if v >= 1 {
57+
self.election_type.write(writer)?;
58+
}
59+
60+
if v >= 2 {
61+
write_compact_versioned_array(writer, version, Some(&self.topic_partitions))?;
62+
} else {
63+
write_versioned_array(writer, version, Some(&self.topic_partitions))?;
64+
}
65+
66+
self.timeout_ms.write(writer)?;
67+
68+
if v >= 2 {
69+
match self.tagged_fields.as_ref() {
70+
Some(tagged_fields) => {
71+
tagged_fields.write(writer)?;
72+
}
73+
None => {
74+
TaggedFields::default().write(writer)?;
75+
}
76+
}
77+
}
78+
79+
Ok(())
80+
}
81+
}
82+
83+
impl RequestBody for ElectLeadersRequest {
84+
type ResponseBody = ElectLeadersResponse;
85+
86+
const API_KEY: ApiKey = ApiKey::ElectLeaders;
87+
88+
/// All versions.
89+
const API_VERSION_RANGE: ApiVersionRange =
90+
ApiVersionRange::new(ApiVersion(Int16(0)), ApiVersion(Int16(2)));
91+
92+
const FIRST_TAGGED_FIELD_IN_REQUEST_VERSION: ApiVersion = ApiVersion(Int16(2));
93+
}
94+
95+
#[derive(Debug)]
96+
pub struct ElectLeadersTopicRequest {
97+
/// The name of a topic.
98+
pub topic: String_,
99+
100+
/// The partitions of this topic whose leader should be elected.
101+
pub partitions: Array<Int32>,
102+
103+
/// The tagged fields.
104+
///
105+
/// Added in version 2
106+
pub tagged_fields: Option<TaggedFields>,
107+
}
108+
109+
impl<W> WriteVersionedType<W> for ElectLeadersTopicRequest
110+
where
111+
W: Write,
112+
{
113+
fn write_versioned(
114+
&self,
115+
writer: &mut W,
116+
version: ApiVersion,
117+
) -> Result<(), WriteVersionedError> {
118+
let v = version.0 .0;
119+
assert!(v <= 2);
120+
121+
if v >= 2 {
122+
CompactStringRef(&self.topic.0).write(writer)?;
123+
} else {
124+
self.topic.write(writer)?;
125+
}
126+
127+
if v >= 2 {
128+
CompactArrayRef(self.partitions.0.as_deref()).write(writer)?;
129+
} else {
130+
self.partitions.write(writer)?;
131+
}
132+
133+
if v >= 2 {
134+
match self.tagged_fields.as_ref() {
135+
Some(tagged_fields) => {
136+
tagged_fields.write(writer)?;
137+
}
138+
None => {
139+
TaggedFields::default().write(writer)?;
140+
}
141+
}
142+
}
143+
144+
Ok(())
145+
}
146+
}
147+
148+
#[derive(Debug)]
149+
pub struct ElectLeadersResponse {
150+
/// The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the
151+
/// request did not violate any quota.
152+
pub throttle_time_ms: Int32,
153+
154+
/// The top level response error code.
155+
///
156+
/// Added in version 1.
157+
pub error: Option<Error>,
158+
159+
/// The election results, or an empty array if the requester did not have permission and the request asks for all
160+
/// partitions.
161+
pub replica_election_results: Vec<ElectLeadersTopicResponse>,
162+
163+
/// The tagged fields.
164+
///
165+
/// Added in version 2
166+
pub tagged_fields: Option<TaggedFields>,
167+
}
168+
169+
impl<R> ReadVersionedType<R> for ElectLeadersResponse
170+
where
171+
R: Read,
172+
{
173+
fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
174+
let v = version.0 .0;
175+
assert!(v <= 2);
176+
177+
let throttle_time_ms = Int32::read(reader)?;
178+
let error = (v >= 1)
179+
.then(|| Int16::read(reader))
180+
.transpose()?
181+
.and_then(|e| Error::new(e.0));
182+
let replica_election_results = if v >= 2 {
183+
read_compact_versioned_array(reader, version)?.unwrap_or_default()
184+
} else {
185+
read_versioned_array(reader, version)?.unwrap_or_default()
186+
};
187+
let tagged_fields = (v >= 2).then(|| TaggedFields::read(reader)).transpose()?;
188+
189+
Ok(Self {
190+
throttle_time_ms,
191+
error,
192+
replica_election_results,
193+
tagged_fields,
194+
})
195+
}
196+
}
197+
198+
#[derive(Debug)]
199+
pub struct ElectLeadersTopicResponse {
200+
/// The topic name.
201+
pub topic: String_,
202+
203+
/// The results for each partition.
204+
pub partition_results: Vec<ElectLeadersPartitionResponse>,
205+
206+
/// The tagged fields.
207+
///
208+
/// Added in version 2
209+
pub tagged_fields: Option<TaggedFields>,
210+
}
211+
212+
impl<R> ReadVersionedType<R> for ElectLeadersTopicResponse
213+
where
214+
R: Read,
215+
{
216+
fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
217+
let v = version.0 .0;
218+
assert!(v <= 2);
219+
220+
let topic = if v >= 2 {
221+
String_(CompactString::read(reader)?.0)
222+
} else {
223+
String_::read(reader)?
224+
};
225+
let partition_results = if v >= 2 {
226+
read_compact_versioned_array(reader, version)?.unwrap_or_default()
227+
} else {
228+
read_versioned_array(reader, version)?.unwrap_or_default()
229+
};
230+
let tagged_fields = (v >= 2).then(|| TaggedFields::read(reader)).transpose()?;
231+
232+
Ok(Self {
233+
topic,
234+
partition_results,
235+
tagged_fields,
236+
})
237+
}
238+
}
239+
240+
#[derive(Debug)]
241+
pub struct ElectLeadersPartitionResponse {
242+
/// The partition id.
243+
pub partition_id: Int32,
244+
245+
/// The result error, or zero if there was no error.
246+
pub error: Option<Error>,
247+
248+
/// The result message, or null if there was no error.
249+
pub error_message: NullableString,
250+
251+
/// The tagged fields.
252+
///
253+
/// Added in version 2
254+
pub tagged_fields: Option<TaggedFields>,
255+
}
256+
257+
impl<R> ReadVersionedType<R> for ElectLeadersPartitionResponse
258+
where
259+
R: Read,
260+
{
261+
fn read_versioned(reader: &mut R, version: ApiVersion) -> Result<Self, ReadVersionedError> {
262+
let v = version.0 .0;
263+
assert!(v <= 2);
264+
265+
let partition_id = Int32::read(reader)?;
266+
let error = Error::new(Int16::read(reader)?.0);
267+
let error_message = if v >= 2 {
268+
NullableString(CompactNullableString::read(reader)?.0)
269+
} else {
270+
NullableString::read(reader)?
271+
};
272+
let tagged_fields = (v >= 2).then(|| TaggedFields::read(reader)).transpose()?;
273+
274+
Ok(Self {
275+
partition_id,
276+
error,
277+
error_message,
278+
tagged_fields,
279+
})
280+
}
281+
}

src/protocol/messages/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ mod create_topics;
2424
pub use create_topics::*;
2525
mod delete_records;
2626
pub use delete_records::*;
27+
mod elect_leaders;
28+
pub use elect_leaders::*;
2729
mod fetch;
2830
pub use fetch::*;
2931
mod header;

0 commit comments

Comments
 (0)