Skip to content

Commit c8df3a8

Browse files
committed
PR feedback
1 parent a0129ab commit c8df3a8

File tree

4 files changed

+68
-71
lines changed

4 files changed

+68
-71
lines changed

internal/clients/elasticsearch/ml_job.go

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ func GetMLJobStats(ctx context.Context, apiClient *clients.ApiClient, jobId stri
144144
if res.StatusCode == http.StatusNotFound {
145145
return nil, diags
146146
}
147-
if fwDiags := diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to get ML job stats: %s", jobId)); fwDiags.HasError() {
148-
diags.Append(fwDiags...)
147+
diags.Append(diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to get ML job stats: %s", jobId))...)
148+
if diags.HasError() {
149149
return nil, diags
150150
}
151151

@@ -193,8 +193,7 @@ func GetDatafeed(ctx context.Context, apiClient *clients.ApiClient, datafeedId s
193193
return nil, diags
194194
}
195195

196-
fwDiags := diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to get ML datafeed: %s", datafeedId))
197-
diags.Append(fwDiags...)
196+
diags.Append(diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to get ML datafeed: %s", datafeedId))...)
198197
if diags.HasError() {
199198
return nil, diags
200199
}
@@ -241,8 +240,7 @@ func UpdateDatafeed(ctx context.Context, apiClient *clients.ApiClient, datafeedI
241240
}
242241
defer res.Body.Close()
243242

244-
fwDiags := diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to update ML datafeed: %s", datafeedId))
245-
diags.Append(fwDiags...)
243+
diags.Append(diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to update ML datafeed: %s", datafeedId))...)
246244

247245
return diags
248246
}
@@ -269,8 +267,7 @@ func DeleteDatafeed(ctx context.Context, apiClient *clients.ApiClient, datafeedI
269267
}
270268
defer res.Body.Close()
271269

272-
fwDiags := diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to delete ML datafeed: %s", datafeedId))
273-
diags.Append(fwDiags...)
270+
diags.Append(diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to delete ML datafeed: %s", datafeedId))...)
274271

275272
return diags
276273
}
@@ -302,8 +299,7 @@ func StopDatafeed(ctx context.Context, apiClient *clients.ApiClient, datafeedId
302299
}
303300
defer res.Body.Close()
304301

305-
fwDiags := diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to stop ML datafeed: %s", datafeedId))
306-
diags.Append(fwDiags...)
302+
diags.Append(diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to stop ML datafeed: %s", datafeedId))...)
307303

308304
return diags
309305
}
@@ -341,8 +337,7 @@ func StartDatafeed(ctx context.Context, apiClient *clients.ApiClient, datafeedId
341337
}
342338
defer res.Body.Close()
343339

344-
fwDiags := diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to start ML datafeed: %s", datafeedId))
345-
diags.Append(fwDiags...)
340+
diags.Append(diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to start ML datafeed: %s", datafeedId))...)
346341

347342
return diags
348343
}
@@ -374,8 +369,7 @@ func GetDatafeedStats(ctx context.Context, apiClient *clients.ApiClient, datafee
374369
return nil, diags
375370
}
376371

377-
fwDiags := diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to get ML datafeed stats: %s", datafeedId))
378-
diags.Append(fwDiags...)
372+
diags.Append(diagutil.CheckErrorFromFW(res, fmt.Sprintf("Unable to get ML datafeed stats: %s", datafeedId))...)
379373
if diags.HasError() {
380374
return nil, diags
381375
}

internal/elasticsearch/ml/job_state/read.go

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"fmt"
66

77
"github.com/elastic/terraform-provider-elasticstack/internal/clients"
8-
"github.com/elastic/terraform-provider-elasticstack/internal/clients/elasticsearch"
98
"github.com/elastic/terraform-provider-elasticstack/internal/utils/customtypes"
109
"github.com/hashicorp/terraform-plugin-framework/resource"
1110
"github.com/hashicorp/terraform-plugin-framework/types"
@@ -24,30 +23,24 @@ func (r *mlJobStateResource) Read(ctx context.Context, req resource.ReadRequest,
2423
if resp.Diagnostics.HasError() {
2524
return
2625
}
27-
jobId := compId.ResourceId
28-
29-
client, diags := clients.MaybeNewApiClientFromFrameworkResource(ctx, data.ElasticsearchConnection, r.client)
30-
resp.Diagnostics.Append(diags...)
31-
if resp.Diagnostics.HasError() {
32-
return
33-
}
3426

3527
// Get job stats to check current state
36-
currentJob, fwDiags := elasticsearch.GetMLJobStats(ctx, client, jobId)
37-
resp.Diagnostics.Append(fwDiags...)
28+
jobId := compId.ResourceId
29+
currentState, diags := r.getJobState(ctx, jobId)
30+
resp.Diagnostics.Append(diags...)
3831
if resp.Diagnostics.HasError() {
3932
return
4033
}
4134

42-
if currentJob == nil {
35+
if currentState == nil {
4336
tflog.Warn(ctx, fmt.Sprintf(`ML job "%s" not found, removing from state`, jobId))
4437
resp.State.RemoveResource(ctx)
4538
return
4639
}
4740

4841
// Update the state with current job information
4942
data.JobId = types.StringValue(jobId)
50-
data.State = types.StringValue(currentJob.State)
43+
data.State = types.StringValue(*currentState)
5144

5245
// Set defaults for computed attributes if they're not already set (e.g., during import)
5346
if data.Force.IsNull() {
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package job_state
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
"github.com/elastic/terraform-provider-elasticstack/internal/asyncutils"
8+
"github.com/elastic/terraform-provider-elasticstack/internal/clients/elasticsearch"
9+
"github.com/elastic/terraform-provider-elasticstack/internal/diagutil"
10+
"github.com/hashicorp/terraform-plugin-framework/diag"
11+
)
12+
13+
var errJobNotFound = fmt.Errorf("ML job not found")
14+
15+
// getJobState returns the current state of a job
16+
func (r *mlJobStateResource) getJobState(ctx context.Context, jobId string) (*string, diag.Diagnostics) {
17+
// Get job stats to check current state
18+
currentJob, diags := elasticsearch.GetMLJobStats(ctx, r.client, jobId)
19+
if diags.HasError() {
20+
return nil, diags
21+
}
22+
23+
if currentJob == nil {
24+
return nil, nil
25+
}
26+
27+
return &currentJob.State, nil
28+
}
29+
30+
// waitForJobState waits for a job to reach the desired state
31+
func (r *mlJobStateResource) waitForJobState(ctx context.Context, jobId, desiredState string) diag.Diagnostics {
32+
stateChecker := func(ctx context.Context) (bool, error) {
33+
currentState, diags := r.getJobState(ctx, jobId)
34+
if diags.HasError() {
35+
return false, diagutil.FwDiagsAsError(diags)
36+
}
37+
38+
if currentState == nil {
39+
return false, errJobNotFound
40+
}
41+
42+
return *currentState == desiredState, nil
43+
}
44+
45+
err := asyncutils.WaitForStateTransition(ctx, "ml_job", jobId, stateChecker)
46+
return diagutil.FrameworkDiagFromError(err)
47+
}

internal/elasticsearch/ml/job_state/update.go

Lines changed: 8 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -53,24 +53,22 @@ func (r *mlJobStateResource) update(ctx context.Context, plan tfsdk.Plan, state
5353
desiredState := data.State.ValueString()
5454

5555
// First, get the current job stats to check if the job exists and its current state
56-
currentJob, fwDiags := elasticsearch.GetMLJobStats(ctx, client, jobId)
56+
currentState, fwDiags := r.getJobState(ctx, jobId)
5757
diags.Append(fwDiags...)
5858
if diags.HasError() {
5959
return diags
6060
}
6161

62-
if currentJob == nil {
62+
if currentState == nil {
6363
diags.AddError(
6464
"ML Job not found",
6565
fmt.Sprintf("ML job %s does not exist", jobId),
6666
)
6767
return diags
6868
}
6969

70-
currentState := currentJob.State
71-
7270
// Perform state transition if needed
73-
fwDiags = r.performStateTransition(ctx, client, data, currentState, operationTimeout)
71+
fwDiags = r.performStateTransition(ctx, client, data, *currentState, operationTimeout)
7472
diags.Append(fwDiags...)
7573
if diags.HasError() {
7674
return diags
@@ -94,34 +92,6 @@ func (r *mlJobStateResource) update(ctx context.Context, plan tfsdk.Plan, state
9492
return diags
9593
}
9694

97-
// waitForJobStateTransition waits for an ML job to reach the desired state
98-
func waitForJobStateTransition(ctx context.Context, client *clients.ApiClient, jobId, desiredState string) error {
99-
const pollInterval = 2 * time.Second
100-
ticker := time.NewTicker(pollInterval)
101-
defer ticker.Stop()
102-
103-
for {
104-
select {
105-
case <-ctx.Done():
106-
return ctx.Err()
107-
case <-ticker.C:
108-
currentJob, fwDiags := elasticsearch.GetMLJobStats(ctx, client, jobId)
109-
if fwDiags.HasError() {
110-
return fmt.Errorf("failed to get job stats during state transition check")
111-
}
112-
113-
if currentJob == nil {
114-
return fmt.Errorf("job not found during state transition check")
115-
}
116-
117-
if currentJob.State == desiredState {
118-
return nil // Successfully reached desired state
119-
}
120-
tflog.Debug(ctx, fmt.Sprintf("ML job %s current state: %s, waiting for: %s", jobId, currentJob.State, desiredState))
121-
}
122-
}
123-
}
124-
12595
// performStateTransition handles the ML job state transition process
12696
func (r *mlJobStateResource) performStateTransition(ctx context.Context, client *clients.ApiClient, data MLJobStateData, currentState string, operationTimeout time.Duration) diag.Diagnostics {
12797
jobId := data.JobId.ValueString()
@@ -147,13 +117,11 @@ func (r *mlJobStateResource) performStateTransition(ctx context.Context, client
147117
// Initiate the state change
148118
switch desiredState {
149119
case "opened":
150-
diags := elasticsearch.OpenMLJob(ctx, client, jobId)
151-
if diags.HasError() {
120+
if diags := elasticsearch.OpenMLJob(ctx, client, jobId); diags.HasError() {
152121
return diags
153122
}
154123
case "closed":
155-
diags := elasticsearch.CloseMLJob(ctx, client, jobId, force, timeout) // Always allow no match
156-
if diags.HasError() {
124+
if diags := elasticsearch.CloseMLJob(ctx, client, jobId, force, timeout); diags.HasError() {
157125
return diags
158126
}
159127
default:
@@ -166,14 +134,9 @@ func (r *mlJobStateResource) performStateTransition(ctx context.Context, client
166134
}
167135

168136
// Wait for state transition to complete
169-
err := waitForJobStateTransition(ctx, client, jobId, desiredState)
170-
if err != nil {
171-
return diag.Diagnostics{
172-
diag.NewErrorDiagnostic(
173-
"State transition timeout",
174-
fmt.Sprintf("ML job %s did not transition to state %s within timeout: %s", jobId, desiredState, err.Error()),
175-
),
176-
}
137+
diags := r.waitForJobState(ctx, jobId, desiredState)
138+
if diags.HasError() {
139+
return diags
177140
}
178141

179142
tflog.Info(ctx, fmt.Sprintf("ML job %s successfully transitioned to state %s", jobId, desiredState))

0 commit comments

Comments
 (0)