diff --git a/Makefile b/Makefile index 38b7abfdc..6cb0bc140 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,10 @@ ROOT_DIR ?= $(shell git rev-parse --show-toplevel) SCRIPTS_BASE ?= $(ROOT_DIR)/scripts +BIN_DIR ?= $(ROOT_DIR)/bin + +# https://github.com/golangci/golangci-lint/releases +GOLANGCI_LINT_VERSION = 1.64.8 +GOLANGCI_LINT = $(BIN_DIR)/golangci-lint # SETUP AND TOOL INITIALIZATION TASKS project-help: @@ -8,10 +13,14 @@ project-help: project-tools: @$(SCRIPTS_BASE)/project.sh tools +# GOLANGCI-LINT INSTALLATION +$(GOLANGCI_LINT): + @GOLANGCI_LINT_VERSION=$(GOLANGCI_LINT_VERSION) $(SCRIPTS_BASE)/install-golangci-lint.sh + # LINT -lint-golangci-lint: +lint-golangci-lint: $(GOLANGCI_LINT) @echo "Linting with golangci-lint" - @$(SCRIPTS_BASE)/lint-golangci-lint.sh + @$(SCRIPTS_BASE)/lint-golangci-lint.sh $(GOLANGCI_LINT) lint-tf: @echo "Linting terraform files" diff --git a/scripts/install-golangci-lint.sh b/scripts/install-golangci-lint.sh new file mode 100755 index 000000000..725a25309 --- /dev/null +++ b/scripts/install-golangci-lint.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +set -e +. $(dirname ${0})/utility.sh + +BINARY_NAME=golangci-lint +INSTALL_TO=${BIN_DIR}/${BINARY_NAME} + +install() { + echo " installing ${BINARY_NAME} ${GOLANGCI_LINT_VERSION}" + + TYPE=windows + if [[ "${OSTYPE}" == linux* ]]; then + TYPE=linux + elif [[ "${OSTYPE}" == darwin* ]]; then + TYPE=darwin + fi + + case $(uname -m) in + arm64|aarch64) + ARCH=arm64 + ;; + *) + ARCH=amd64 + ;; + esac + + BASE_URL=https://github.com/golangci/golangci-lint/releases/download/v${GOLANGCI_LINT_VERSION} + URL=${BASE_URL}/golangci-lint-${GOLANGCI_LINT_VERSION}-${TYPE}-${ARCH}.tar.gz + echo " Downloading: ${URL}" + download ${URL} | tar --extract --gzip --strip-components 1 --preserve-permissions -C ${BIN_DIR} -f- + + # Ensure the binary has the correct name + if [ -f "${BIN_DIR}/golangci-lint" ] && [ "${BIN_DIR}/golangci-lint" != "${INSTALL_TO}" ]; then + mv "${BIN_DIR}/golangci-lint" "${INSTALL_TO}" + fi +} + +get_version() { + ${INSTALL_TO} version 2>/dev/null | awk '{print $4}' +} + +update_if_necessary ${GOLANGCI_LINT_VERSION} diff --git a/scripts/lint-golangci-lint.sh b/scripts/lint-golangci-lint.sh index c2ffd78fe..c2930464f 100755 --- a/scripts/lint-golangci-lint.sh +++ b/scripts/lint-golangci-lint.sh @@ -1,18 +1,19 @@ #!/usr/bin/env bash # This script lints the SDK modules and the internal examples -# Pre-requisites: golangci-lint +# Pre-requisites: golangci-lint (provided by Makefile or system) set -eo pipefail ROOT_DIR=$(git rev-parse --show-toplevel) GOLANG_CI_YAML_PATH="${ROOT_DIR}/golang-ci.yaml" GOLANG_CI_ARGS="--allow-parallel-runners --timeout=5m --config=${GOLANG_CI_YAML_PATH}" -if type -p golangci-lint >/dev/null; then - : -else - echo "golangci-lint not installed, unable to proceed." +# Use provided golangci-lint binary or fallback to system installation +GOLANGCI_LINT_BIN="${1:-golangci-lint}" + +if [ ! -x "${GOLANGCI_LINT_BIN}" ] && ! type -p "${GOLANGCI_LINT_BIN}" >/dev/null; then + echo "golangci-lint not found at ${GOLANGCI_LINT_BIN} and not installed in PATH, unable to proceed." exit 1 fi cd ${ROOT_DIR} -golangci-lint run ${GOLANG_CI_ARGS} +${GOLANGCI_LINT_BIN} run ${GOLANG_CI_ARGS} diff --git a/scripts/utility.sh b/scripts/utility.sh new file mode 100755 index 000000000..c46fcb8d8 --- /dev/null +++ b/scripts/utility.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +# Common utility functions for tool installation scripts + +ROOT_DIR=$(git rev-parse --show-toplevel) +BIN_DIR="${ROOT_DIR}/bin" + +# Ensure bin directory exists +mkdir -p "${BIN_DIR}" + +# Download function using curl +download() { + local URL=$1 + if command -v curl &> /dev/null; then + curl -sSfL "${URL}" + elif command -v wget &> /dev/null; then + wget -qO- "${URL}" + else + echo "Error: Neither curl nor wget found. Please install one of them." + exit 1 + fi +} + +# Update tool if necessary +update_if_necessary() { + local EXPECTED_VERSION=$1 + + if [ -x "${INSTALL_TO}" ]; then + CURRENT_VERSION=$(get_version 2>/dev/null || echo "") + if [ "${CURRENT_VERSION}" = "${EXPECTED_VERSION}" ]; then + echo " ${BINARY_NAME} ${EXPECTED_VERSION} already installed" + return 0 + else + echo " ${BINARY_NAME} version mismatch (current: ${CURRENT_VERSION}, expected: ${EXPECTED_VERSION})" + echo " updating to ${EXPECTED_VERSION}..." + fi + fi + + install + + INSTALLED_VERSION=$(get_version 2>/dev/null || echo "unknown") + if [ "${INSTALLED_VERSION}" = "${EXPECTED_VERSION}" ]; then + echo " ${BINARY_NAME} ${EXPECTED_VERSION} installed successfully" + else + echo " Warning: installed version (${INSTALLED_VERSION}) does not match expected version (${EXPECTED_VERSION})" + fi +} diff --git a/stackit/internal/services/dns/recordset/resource.go b/stackit/internal/services/dns/recordset/resource.go index 9a9046350..abd6c172b 100644 --- a/stackit/internal/services/dns/recordset/resource.go +++ b/stackit/internal/services/dns/recordset/resource.go @@ -2,7 +2,9 @@ package dns import ( "context" + "errors" "fmt" + "net/http" "strings" "github.com/hashicorp/terraform-plugin-framework-validators/int64validator" @@ -16,6 +18,7 @@ import ( "github.com/hashicorp/terraform-plugin-framework/schema/validator" "github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-log/tflog" + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" "github.com/stackitcloud/stackit-sdk-go/services/dns" "github.com/stackitcloud/stackit-sdk-go/services/dns/wait" "github.com/stackitcloud/terraform-provider-stackit/stackit/internal/conversion" @@ -200,6 +203,13 @@ func (r *recordSetResource) Create(ctx context.Context, req resource.CreateReque return } + // Get a fresh copy from plan for minimal state + var minimalModel Model + resp.Diagnostics.Append(req.Plan.Get(ctx, &minimalModel)...) + if resp.Diagnostics.HasError() { + return + } + projectId := model.ProjectId.ValueString() zoneId := model.ZoneId.ValueString() ctx = tflog.SetField(ctx, "project_id", projectId) @@ -219,18 +229,36 @@ func (r *recordSetResource) Create(ctx context.Context, req resource.CreateReque } // Write id attributes to state before polling via the wait handler - just in case anything goes wrong during the wait handler - utils.SetAndLogStateFields(ctx, &resp.Diagnostics, &resp.State, map[string]any{ - "project_id": projectId, - "zone_id": zoneId, - "record_set_id": *recordSetResp.Rrset.Id, - }) + recordSetId := *recordSetResp.Rrset.Id + minimalModel.RecordSetId = types.StringValue(recordSetId) + minimalModel.Id = utils.BuildInternalTerraformId(projectId, zoneId, recordSetId) + + // Set all unknown/null fields to null before saving state + if err := utils.SetModelFieldsToNull(ctx, &minimalModel); err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating record set", fmt.Sprintf("Setting model fields to null: %v", err)) + return + } + + diags = resp.State.Set(ctx, minimalModel) + resp.Diagnostics.Append(diags...) if resp.Diagnostics.HasError() { return } + if !utils.ShouldWait() { + tflog.Info(ctx, "Skipping wait; async mode for Crossplane/Upjet") + return + } + waitResp, err := wait.CreateRecordSetWaitHandler(ctx, r.client, projectId, zoneId, *recordSetResp.Rrset.Id).WaitWithContext(ctx) if err != nil { - core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating record set", fmt.Sprintf("Instance creation waiting: %v", err)) + if utils.ShouldIgnoreWaitError(err) { + tflog.Warn(ctx, fmt.Sprintf("Record set creation waiting failed: %v. The record set creation was triggered but waiting for completion was interrupted. The record set may still be creating.", err)) + return + } + + core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating record set", fmt.Sprintf("Record set creation waiting: %v", err)) + return } @@ -266,6 +294,12 @@ func (r *recordSetResource) Read(ctx context.Context, req resource.ReadRequest, recordSetResp, err := r.client.GetRecordSet(ctx, projectId, zoneId, recordSetId).Execute() if err != nil { + var oapiErr *oapierror.GenericOpenAPIError + ok := errors.As(err, &oapiErr) + if ok && (oapiErr.StatusCode == http.StatusNotFound || oapiErr.StatusCode == http.StatusGone) { + resp.State.RemoveResource(ctx) + return + } core.LogAndAddError(ctx, &resp.Diagnostics, "Error reading record set", fmt.Sprintf("Calling API: %v", err)) return } @@ -319,9 +353,21 @@ func (r *recordSetResource) Update(ctx context.Context, req resource.UpdateReque core.LogAndAddError(ctx, &resp.Diagnostics, "Error updating record set", err.Error()) return } + + if !utils.ShouldWait() { + if utils.ShouldIgnoreWaitError(err) { + tflog.Warn(ctx, fmt.Sprintf("Record set update waiting failed: %v. The record set update was triggered but waiting for completion was interrupted. The record set may still be updating.", err)) + return + } + + core.LogAndAddError(ctx, &resp.Diagnostics, "Error updating record set", fmt.Sprintf("Record set update waiting: %v", err)) + + return + } + waitResp, err := wait.PartialUpdateRecordSetWaitHandler(ctx, r.client, projectId, zoneId, recordSetId).WaitWithContext(ctx) if err != nil { - core.LogAndAddError(ctx, &resp.Diagnostics, "Error updating record set", fmt.Sprintf("Instance update waiting: %v", err)) + tflog.Warn(ctx, fmt.Sprintf("Record set update waiting failed: %v. The record set update was triggered but waiting for completion was interrupted. The record set may still be updating.", err)) return } @@ -358,11 +404,31 @@ func (r *recordSetResource) Delete(ctx context.Context, req resource.DeleteReque // Delete existing record set _, err := r.client.DeleteRecordSet(ctx, projectId, zoneId, recordSetId).Execute() if err != nil { + // If resource is already gone (404 or 410), treat as success for idempotency + var oapiErr *oapierror.GenericOpenAPIError + ok := errors.As(err, &oapiErr) + if ok && (oapiErr.StatusCode == http.StatusNotFound || oapiErr.StatusCode == http.StatusGone) { + tflog.Info(ctx, "Record set already deleted") + return + } core.LogAndAddError(ctx, &resp.Diagnostics, "Error deleting record set", fmt.Sprintf("Calling API: %v", err)) + return + } + + if !utils.ShouldWait() { + tflog.Info(ctx, "Skipping wait; async mode for Crossplane/Upjet") + return } + _, err = wait.DeleteRecordSetWaitHandler(ctx, r.client, projectId, zoneId, recordSetId).WaitWithContext(ctx) if err != nil { - core.LogAndAddError(ctx, &resp.Diagnostics, "Error deleting record set", fmt.Sprintf("Instance deletion waiting: %v", err)) + if utils.ShouldIgnoreWaitError(err) { + tflog.Warn(ctx, fmt.Sprintf("Record set deletion waiting failed: %v. The record set deletion was triggered but waiting for completion was interrupted. The record set may still be deleting.", err)) + return + } + + core.LogAndAddError(ctx, &resp.Diagnostics, "Error deleting record set", fmt.Sprintf("Record set deletion waiting: %v", err)) + return } tflog.Info(ctx, "DNS record set deleted") @@ -380,11 +446,23 @@ func (r *recordSetResource) ImportState(ctx context.Context, req resource.Import return } - utils.SetAndLogStateFields(ctx, &resp.Diagnostics, &resp.State, map[string]interface{}{ - "project_id": idParts[0], - "zone_id": idParts[1], - "record_set_id": idParts[2], - }) + var model Model + model.ProjectId = types.StringValue(idParts[0]) + model.ZoneId = types.StringValue(idParts[1]) + model.RecordSetId = types.StringValue(idParts[2]) + model.Id = utils.BuildInternalTerraformId(idParts[0], idParts[1], idParts[2]) + + if err := utils.SetModelFieldsToNull(ctx, &model); err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error importing zone", fmt.Sprintf("Setting model fields to null: %v", err)) + return + } + + diags := resp.State.Set(ctx, model) + resp.Diagnostics.Append(diags...) + if diags.HasError() { + return + } + tflog.Info(ctx, "DNS record set state imported") } diff --git a/stackit/internal/services/dns/zone/resource.go b/stackit/internal/services/dns/zone/resource.go index cfddeafea..52cbdb988 100644 --- a/stackit/internal/services/dns/zone/resource.go +++ b/stackit/internal/services/dns/zone/resource.go @@ -2,8 +2,10 @@ package dns import ( "context" + "errors" "fmt" "math" + "net/http" "strings" dnsUtils "github.com/stackitcloud/terraform-provider-stackit/stackit/internal/services/dns/utils" @@ -21,6 +23,7 @@ import ( "github.com/hashicorp/terraform-plugin-framework/schema/validator" "github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-log/tflog" + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" "github.com/stackitcloud/stackit-sdk-go/services/dns" "github.com/stackitcloud/stackit-sdk-go/services/dns/wait" "github.com/stackitcloud/terraform-provider-stackit/stackit/internal/conversion" @@ -284,6 +287,13 @@ func (r *zoneResource) Create(ctx context.Context, req resource.CreateRequest, r return } + // Get a fresh copy from plan for minimal state + var minimalModel Model + resp.Diagnostics.Append(req.Plan.Get(ctx, &minimalModel)...) + if resp.Diagnostics.HasError() { + return + } + projectId := model.ProjectId.ValueString() ctx = tflog.SetField(ctx, "project_id", projectId) @@ -300,19 +310,43 @@ func (r *zoneResource) Create(ctx context.Context, req resource.CreateRequest, r return } - // Write id attributes to state before polling via the wait handler - just in case anything goes wrong during the wait handler + // Save minimal state immediately after API call succeeds to ensure idempotency zoneId := *createResp.Zone.Id - utils.SetAndLogStateFields(ctx, &resp.Diagnostics, &resp.State, map[string]interface{}{ - "project_id": projectId, - "zone_id": zoneId, - }) - if resp.Diagnostics.HasError() { + minimalModel.ZoneId = types.StringValue(zoneId) + minimalModel.Id = utils.BuildInternalTerraformId(projectId, zoneId) + + // Set all unknown/null fields to null before saving state + if err := utils.SetModelFieldsToNull(ctx, &minimalModel); err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating zone", fmt.Sprintf("Setting model fields to null: %v", err)) + return + } + + diags := resp.State.Set(ctx, minimalModel) + resp.Diagnostics.Append(diags...) + if diags.HasError() { + return + } + + if !utils.ShouldWait() { + tflog.Info(ctx, "Skipping wait; async mode for Crossplane/Upjet") return } waitResp, err := wait.CreateZoneWaitHandler(ctx, r.client, projectId, zoneId).WaitWithContext(ctx) if err != nil { + if utils.ShouldIgnoreWaitError(err) { + tflog.Warn( + ctx, + fmt.Sprintf( + "Zone creation waiting failed: %v. The zone creation was triggered but waiting for completion was interrupted. The zone may still be creating.", + err, + ), + ) + return + } + core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating zone", fmt.Sprintf("Zone creation waiting: %v", err)) + return } @@ -345,10 +379,17 @@ func (r *zoneResource) Read(ctx context.Context, req resource.ReadRequest, resp zoneResp, err := r.client.GetZone(ctx, projectId, zoneId).Execute() if err != nil { + var oapiErr *oapierror.GenericOpenAPIError + ok := errors.As(err, &oapiErr) + if ok && (oapiErr.StatusCode == http.StatusNotFound || oapiErr.StatusCode == http.StatusGone) { + resp.State.RemoveResource(ctx) + return + } core.LogAndAddError(ctx, &resp.Diagnostics, "Error reading zone", fmt.Sprintf("Calling API: %v", err)) return } - if zoneResp != nil && zoneResp.Zone.State != nil && *zoneResp.Zone.State == dns.ZONESTATE_DELETE_SUCCEEDED { + if zoneResp != nil && zoneResp.Zone.State != nil && + *zoneResp.Zone.State == dns.ZONESTATE_DELETE_SUCCEEDED { resp.State.RemoveResource(ctx) return } @@ -394,9 +435,27 @@ func (r *zoneResource) Update(ctx context.Context, req resource.UpdateRequest, r core.LogAndAddError(ctx, &resp.Diagnostics, "Error updating zone", fmt.Sprintf("Calling API: %v", err)) return } + + if !utils.ShouldWait() { + tflog.Info(ctx, "Skipping wait; async mode for Crossplane/Upjet") + return + } + waitResp, err := wait.PartialUpdateZoneWaitHandler(ctx, r.client, projectId, zoneId).WaitWithContext(ctx) if err != nil { + if utils.ShouldIgnoreWaitError(err) { + tflog.Warn( + ctx, + fmt.Sprintf( + "Zone update waiting failed: %v. The zone update was triggered but waiting for completion was interrupted. The zone may still be updating.", + err, + ), + ) + return + } + core.LogAndAddError(ctx, &resp.Diagnostics, "Error updating zone", fmt.Sprintf("Zone update waiting: %v", err)) + return } @@ -431,12 +490,38 @@ func (r *zoneResource) Delete(ctx context.Context, req resource.DeleteRequest, r // Delete existing zone _, err := r.client.DeleteZone(ctx, projectId, zoneId).Execute() if err != nil { + // If resource is already gone (404 or 410), treat as success for idempotency + var oapiErr *oapierror.GenericOpenAPIError + ok := errors.As(err, &oapiErr) + if ok && + (oapiErr.StatusCode == http.StatusNotFound || oapiErr.StatusCode == http.StatusGone) { + tflog.Info(ctx, "DNS zone already deleted") + return + } core.LogAndAddError(ctx, &resp.Diagnostics, "Error deleting zone", fmt.Sprintf("Calling API: %v", err)) return } + + if !utils.ShouldWait() { + tflog.Info(ctx, "Skipping wait; async mode for Crossplane/Upjet") + return + } + _, err = wait.DeleteZoneWaitHandler(ctx, r.client, projectId, zoneId).WaitWithContext(ctx) if err != nil { + if utils.ShouldIgnoreWaitError(err) { + tflog.Warn( + ctx, + fmt.Sprintf( + "Zone deletion waiting failed: %v. The zone deletion was triggered but waiting for completion was interrupted. The zone may still be deleting.", + err, + ), + ) + return + } + core.LogAndAddError(ctx, &resp.Diagnostics, "Error deleting zone", fmt.Sprintf("Zone deletion waiting: %v", err)) + return } @@ -456,10 +541,21 @@ func (r *zoneResource) ImportState(ctx context.Context, req resource.ImportState return } - utils.SetAndLogStateFields(ctx, &resp.Diagnostics, &resp.State, map[string]interface{}{ - "project_id": idParts[0], - "zone_id": idParts[1], - }) + var model Model + model.ProjectId = types.StringValue(idParts[0]) + model.ZoneId = types.StringValue(idParts[1]) + model.Id = utils.BuildInternalTerraformId(idParts[0], idParts[1]) + + if err := utils.SetModelFieldsToNull(ctx, &model); err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error importing zone", fmt.Sprintf("Setting model fields to null: %v", err)) + return + } + + diags := resp.State.Set(ctx, model) + resp.Diagnostics.Append(diags...) + if diags.HasError() { + return + } tflog.Info(ctx, "DNS zone state imported") } diff --git a/stackit/internal/utils/utils.go b/stackit/internal/utils/utils.go index a76141134..852e6b2bb 100644 --- a/stackit/internal/utils/utils.go +++ b/stackit/internal/utils/utils.go @@ -4,6 +4,10 @@ import ( "context" "errors" "fmt" + "net" + "net/http" + "os" + "reflect" "regexp" "strings" @@ -182,3 +186,600 @@ func SetAndLogStateFields(ctx context.Context, diags *diag.Diagnostics, state *t diags.Append(state.SetAttribute(ctx, path.Root(key), val)...) } } + +// SetModelFieldsToNull sets all Unknown or Null fields in a model struct to their appropriate Null values. +// This is useful when saving minimal state after API calls to ensure idempotency. +// The model parameter must be a pointer to a struct containing Terraform framework types. +// This function recursively processes nested objects, lists, sets, and maps. +func SetModelFieldsToNull(ctx context.Context, model any) error { + if model == nil { + return fmt.Errorf("model cannot be nil") + } + + v := reflect.ValueOf(model) + if v.Kind() != reflect.Ptr { + return fmt.Errorf("model must be a pointer, got %v", v.Kind()) + } + + v = v.Elem() + if !v.IsValid() || v.Kind() != reflect.Struct { + return fmt.Errorf("model must point to a struct, got %v", v.Kind()) + } + + t := v.Type() + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldType := t.Field(i) + + if !field.CanInterface() || !field.CanSet() { + continue + } + + fieldValue := field.Interface() + + // Check if the field implements IsUnknown and IsNull + isUnknownMethod := field.MethodByName("IsUnknown") + isNullMethod := field.MethodByName("IsNull") + + if !isUnknownMethod.IsValid() || !isNullMethod.IsValid() { + continue + } + + // Call IsUnknown() and IsNull() + isUnknownResult := isUnknownMethod.Call(nil) + isNullResult := isNullMethod.Call(nil) + + if len(isUnknownResult) == 0 || len(isNullResult) == 0 { + continue + } + + isUnknown := isUnknownResult[0].Bool() + isNull := isNullResult[0].Bool() + + // If the field is Unknown or Null at the top level, convert it to Null + if isUnknown || isNull { + if err := setFieldToNull(ctx, field, fieldValue, &fieldType); err != nil { + return err + } + continue + } + + // If the field is Known and not Null, recursively process it + if err := processKnownField(ctx, field, fieldValue, &fieldType); err != nil { + return err + } + } + + return nil +} + +// setFieldToNull sets a field to its appropriate Null value based on type +func setFieldToNull(ctx context.Context, field reflect.Value, fieldValue any, fieldType *reflect.StructField) error { + switch v := fieldValue.(type) { + case basetypes.StringValue: + field.Set(reflect.ValueOf(types.StringNull())) + + case basetypes.BoolValue: + field.Set(reflect.ValueOf(types.BoolNull())) + + case basetypes.Int64Value: + field.Set(reflect.ValueOf(types.Int64Null())) + + case basetypes.Float64Value: + field.Set(reflect.ValueOf(types.Float64Null())) + + case basetypes.NumberValue: + field.Set(reflect.ValueOf(types.NumberNull())) + + case basetypes.ListValue: + elemType := v.ElementType(ctx) + field.Set(reflect.ValueOf(types.ListNull(elemType))) + + case basetypes.SetValue: + elemType := v.ElementType(ctx) + field.Set(reflect.ValueOf(types.SetNull(elemType))) + + case basetypes.MapValue: + elemType := v.ElementType(ctx) + field.Set(reflect.ValueOf(types.MapNull(elemType))) + + case basetypes.ObjectValue: + attrTypes := v.AttributeTypes(ctx) + field.Set(reflect.ValueOf(types.ObjectNull(attrTypes))) + + default: + tflog.Debug(ctx, fmt.Sprintf("SetModelFieldsToNull: skipping field %s of unsupported type %T", fieldType.Name, fieldValue)) + } + return nil +} + +// processKnownField recursively processes known (non-null, non-unknown) fields +// to handle nested structures like objects within lists, maps, etc. +func processKnownField(ctx context.Context, field reflect.Value, fieldValue any, fieldType *reflect.StructField) error { + switch v := fieldValue.(type) { + case basetypes.ObjectValue: + // Recursively process object fields + return processObjectValue(ctx, field, v, fieldType) + + case basetypes.ListValue: + // Recursively process list elements + return processListValue(ctx, field, v, fieldType) + + case basetypes.SetValue: + // Recursively process set elements + return processSetValue(ctx, field, v, fieldType) + + case basetypes.MapValue: + // Recursively process map values + return processMapValue(ctx, field, v, fieldType) + + default: + // Primitive types (String, Bool, Int64, etc.) don't need recursion + return nil + } +} + +// processObjectValue recursively processes fields within an ObjectValue +func processObjectValue(ctx context.Context, field reflect.Value, objValue basetypes.ObjectValue, fieldType *reflect.StructField) error { + attrs := objValue.Attributes() + attrTypes := objValue.AttributeTypes(ctx) + modified := false + newAttrs := make(map[string]attr.Value, len(attrs)) + + for key, attrVal := range attrs { + // Check if the attribute has IsUnknown and IsNull methods + attrValReflect := reflect.ValueOf(attrVal) + isUnknownMethod := attrValReflect.MethodByName("IsUnknown") + isNullMethod := attrValReflect.MethodByName("IsNull") + + if !isUnknownMethod.IsValid() || !isNullMethod.IsValid() { + newAttrs[key] = attrVal + continue + } + + isUnknownResult := isUnknownMethod.Call(nil) + isNullResult := isNullMethod.Call(nil) + + if len(isUnknownResult) == 0 || len(isNullResult) == 0 { + newAttrs[key] = attrVal + continue + } + + isUnknown := isUnknownResult[0].Bool() + isNull := isNullResult[0].Bool() + + // Convert Unknown or Null attributes to Null + if isUnknown || isNull { + nullVal := createNullValue(ctx, attrVal, attrTypes[key]) + if nullVal != nil { + newAttrs[key] = nullVal + modified = true + } else { + newAttrs[key] = attrVal + } + } else { + // Recursively process known attributes + processedVal, wasModified, err := processAttributeValueWithFlag(ctx, attrVal, attrTypes[key]) + if err != nil { + return err + } + newAttrs[key] = processedVal + if wasModified { + modified = true + } + } + } + + // Only update the field if something changed + if modified { + newObj, diags := types.ObjectValue(attrTypes, newAttrs) + if diags.HasError() { + return fmt.Errorf("creating new object value for field %s: %v", fieldType.Name, diags.Errors()) + } + field.Set(reflect.ValueOf(newObj)) + } + + return nil +} + +// processListValue recursively processes elements within a ListValue +func processListValue(ctx context.Context, field reflect.Value, listValue basetypes.ListValue, fieldType *reflect.StructField) error { + elements := listValue.Elements() + if len(elements) == 0 { + return nil + } + + elemType := listValue.ElementType(ctx) + modified := false + newElements := make([]attr.Value, len(elements)) + + for i, elem := range elements { + // Check if element is Unknown or Null + elemReflect := reflect.ValueOf(elem) + isUnknownMethod := elemReflect.MethodByName("IsUnknown") + isNullMethod := elemReflect.MethodByName("IsNull") + + if !isUnknownMethod.IsValid() || !isNullMethod.IsValid() { + newElements[i] = elem + continue + } + + isUnknownResult := isUnknownMethod.Call(nil) + isNullResult := isNullMethod.Call(nil) + + if len(isUnknownResult) == 0 || len(isNullResult) == 0 { + newElements[i] = elem + continue + } + + isUnknown := isUnknownResult[0].Bool() + isNull := isNullResult[0].Bool() + + if isUnknown || isNull { + nullVal := createNullValue(ctx, elem, elemType) + if nullVal != nil { + newElements[i] = nullVal + modified = true + } else { + newElements[i] = elem + } + } else { + // Recursively process known elements (objects, lists, etc.) + processedElem, wasModified, err := processAttributeValueWithFlag(ctx, elem, elemType) + if err != nil { + return err + } + newElements[i] = processedElem + if wasModified { + modified = true + } + } + } + + // Only update if something changed + if modified { + newList, diags := types.ListValue(elemType, newElements) + if diags.HasError() { + return fmt.Errorf("creating new list value for field %s: %v", fieldType.Name, diags.Errors()) + } + field.Set(reflect.ValueOf(newList)) + } + + return nil +} + +// processSetValue recursively processes elements within a SetValue +func processSetValue(ctx context.Context, field reflect.Value, setValue basetypes.SetValue, fieldType *reflect.StructField) error { + elements := setValue.Elements() + if len(elements) == 0 { + return nil + } + + elemType := setValue.ElementType(ctx) + modified := false + newElements := make([]attr.Value, len(elements)) + + for i, elem := range elements { + elemReflect := reflect.ValueOf(elem) + isUnknownMethod := elemReflect.MethodByName("IsUnknown") + isNullMethod := elemReflect.MethodByName("IsNull") + + if !isUnknownMethod.IsValid() || !isNullMethod.IsValid() { + newElements[i] = elem + continue + } + + isUnknownResult := isUnknownMethod.Call(nil) + isNullResult := isNullMethod.Call(nil) + + if len(isUnknownResult) == 0 || len(isNullResult) == 0 { + newElements[i] = elem + continue + } + + isUnknown := isUnknownResult[0].Bool() + isNull := isNullResult[0].Bool() + + if isUnknown || isNull { + nullVal := createNullValue(ctx, elem, elemType) + if nullVal != nil { + newElements[i] = nullVal + modified = true + } else { + newElements[i] = elem + } + } else { + processedElem, wasModified, err := processAttributeValueWithFlag(ctx, elem, elemType) + if err != nil { + return err + } + newElements[i] = processedElem + if wasModified { + modified = true + } + } + } + + if modified { + newSet, diags := types.SetValue(elemType, newElements) + if diags.HasError() { + return fmt.Errorf("creating new set value for field %s: %v", fieldType.Name, diags.Errors()) + } + field.Set(reflect.ValueOf(newSet)) + } + + return nil +} + +// processMapValue recursively processes values within a MapValue +func processMapValue(ctx context.Context, field reflect.Value, mapValue basetypes.MapValue, fieldType *reflect.StructField) error { + elements := mapValue.Elements() + if len(elements) == 0 { + return nil + } + + elemType := mapValue.ElementType(ctx) + modified := false + newElements := make(map[string]attr.Value, len(elements)) + + for key, val := range elements { + valReflect := reflect.ValueOf(val) + isUnknownMethod := valReflect.MethodByName("IsUnknown") + isNullMethod := valReflect.MethodByName("IsNull") + + if !isUnknownMethod.IsValid() || !isNullMethod.IsValid() { + newElements[key] = val + continue + } + + isUnknownResult := isUnknownMethod.Call(nil) + isNullResult := isNullMethod.Call(nil) + + if len(isUnknownResult) == 0 || len(isNullResult) == 0 { + newElements[key] = val + continue + } + + isUnknown := isUnknownResult[0].Bool() + isNull := isNullResult[0].Bool() + + if isUnknown || isNull { + nullVal := createNullValue(ctx, val, elemType) + if nullVal != nil { + newElements[key] = nullVal + modified = true + } else { + newElements[key] = val + } + } else { + processedVal, wasModified, err := processAttributeValueWithFlag(ctx, val, elemType) + if err != nil { + return err + } + newElements[key] = processedVal + if wasModified { + modified = true + } + } + } + + if modified { + newMap, diags := types.MapValue(elemType, newElements) + if diags.HasError() { + return fmt.Errorf("creating new map value for field %s: %v", fieldType.Name, diags.Errors()) + } + field.Set(reflect.ValueOf(newMap)) + } + + return nil +} + +// processAttributeValueWithFlag recursively processes a single attribute value +// Returns the processed value, a flag indicating if it was modified, and an error +func processAttributeValueWithFlag(ctx context.Context, attrVal attr.Value, attrType attr.Type) (attr.Value, bool, error) { + switch v := attrVal.(type) { + case basetypes.ObjectValue: + // Recursively process object attributes + attrs := v.Attributes() + objType, ok := attrType.(types.ObjectType) + if !ok { + return attrVal, false, nil + } + attrTypes := objType.AttrTypes + modified := false + newAttrs := make(map[string]attr.Value, len(attrs)) + + for key, subAttr := range attrs { + subAttrReflect := reflect.ValueOf(subAttr) + isUnknownMethod := subAttrReflect.MethodByName("IsUnknown") + isNullMethod := subAttrReflect.MethodByName("IsNull") + + if !isUnknownMethod.IsValid() || !isNullMethod.IsValid() { + newAttrs[key] = subAttr + continue + } + + isUnknownResult := isUnknownMethod.Call(nil) + isNullResult := isNullMethod.Call(nil) + + if len(isUnknownResult) == 0 || len(isNullResult) == 0 { + newAttrs[key] = subAttr + continue + } + + isUnknown := isUnknownResult[0].Bool() + isNull := isNullResult[0].Bool() + + if isUnknown || isNull { + nullVal := createNullValue(ctx, subAttr, attrTypes[key]) + if nullVal != nil { + newAttrs[key] = nullVal + modified = true + } else { + newAttrs[key] = subAttr + } + } else { + processedSubAttr, wasModified, err := processAttributeValueWithFlag(ctx, subAttr, attrTypes[key]) + if err != nil { + return attrVal, false, err + } + newAttrs[key] = processedSubAttr + if wasModified { + modified = true + } + } + } + + if modified { + newObj, diags := types.ObjectValue(attrTypes, newAttrs) + if diags.HasError() { + return attrVal, false, fmt.Errorf("creating new object value: %v", diags.Errors()) + } + return newObj, true, nil + } + return attrVal, false, nil + + case basetypes.ListValue: + // Recursively process list elements + elements := v.Elements() + if len(elements) == 0 { + return attrVal, false, nil + } + + elemType := v.ElementType(ctx) + modified := false + newElements := make([]attr.Value, len(elements)) + + for i, elem := range elements { + elemReflect := reflect.ValueOf(elem) + isUnknownMethod := elemReflect.MethodByName("IsUnknown") + isNullMethod := elemReflect.MethodByName("IsNull") + + if !isUnknownMethod.IsValid() || !isNullMethod.IsValid() { + newElements[i] = elem + continue + } + + isUnknownResult := isUnknownMethod.Call(nil) + isNullResult := isNullMethod.Call(nil) + + if len(isUnknownResult) == 0 || len(isNullResult) == 0 { + newElements[i] = elem + continue + } + + isUnknown := isUnknownResult[0].Bool() + isNull := isNullResult[0].Bool() + + if isUnknown || isNull { + nullVal := createNullValue(ctx, elem, elemType) + if nullVal != nil { + newElements[i] = nullVal + modified = true + } else { + newElements[i] = elem + } + } else { + processedElem, wasModified, err := processAttributeValueWithFlag(ctx, elem, elemType) + if err != nil { + return attrVal, false, err + } + newElements[i] = processedElem + if wasModified { + modified = true + } + } + } + + if modified { + newList, diags := types.ListValue(elemType, newElements) + if diags.HasError() { + return attrVal, false, fmt.Errorf("creating new list value: %v", diags.Errors()) + } + return newList, true, nil + } + return attrVal, false, nil + + default: + // Primitive types don't need further processing + return attrVal, false, nil + } +} + +// createNullValue creates a null value of the appropriate type +func createNullValue(_ context.Context, val attr.Value, attrType attr.Type) attr.Value { + switch val.(type) { + case basetypes.StringValue: + return types.StringNull() + case basetypes.BoolValue: + return types.BoolNull() + case basetypes.Int64Value: + return types.Int64Null() + case basetypes.Float64Value: + return types.Float64Null() + case basetypes.NumberValue: + return types.NumberNull() + case basetypes.ListValue: + if listType, ok := attrType.(types.ListType); ok { + return types.ListNull(listType.ElemType) + } + return nil + case basetypes.SetValue: + if setType, ok := attrType.(types.SetType); ok { + return types.SetNull(setType.ElemType) + } + return nil + case basetypes.MapValue: + if mapType, ok := attrType.(types.MapType); ok { + return types.MapNull(mapType.ElemType) + } + return nil + case basetypes.ObjectValue: + if objType, ok := attrType.(types.ObjectType); ok { + return types.ObjectNull(objType.AttrTypes) + } + return nil + default: + return nil + } +} + +// ShouldWait checks the STACKIT_TF_WAIT_FOR_READY environment variable to determine +// if the provider should wait for resources to be ready after creation/update. +// Returns true if the variable is unset or set to "true" (case-insensitive). +// Returns false if the variable is set to any other value. +// This is typically used to skip waiting in async mode for Crossplane/Upjet. +func ShouldWait() bool { + v := os.Getenv("STACKIT_TF_WAIT_FOR_READY") + return v == "" || strings.EqualFold(v, "true") +} + +// ShouldIgnoreWaitError determines if a wait error should be ignored. +// Returns true for transient errors like context timeouts, network errors, and API 5xx errors. +// These errors are considered recoverable and shouldn't fail the operation. +func ShouldIgnoreWaitError(err error) bool { + // Context errors (deadline exceeded, canceled) + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return true + } + + // Network errors (timeout, connection refused, etc.) + var netErr net.Error + if errors.As(err, &netErr) { + return true + } + + // API 5xx errors + var oapiErr *oapierror.GenericOpenAPIError + if errors.As(err, &oapiErr) { + if oapiErr.StatusCode >= http.StatusInternalServerError { + return true + } + + if oapiErr.StatusCode == http.StatusNotFound || oapiErr.StatusCode == http.StatusGone { + return true + } + } + + return false +} diff --git a/stackit/internal/utils/utils_test.go b/stackit/internal/utils/utils_test.go index 0dc5bf5b3..9e2fd3270 100644 --- a/stackit/internal/utils/utils_test.go +++ b/stackit/internal/utils/utils_test.go @@ -2,8 +2,12 @@ package utils import ( "context" + "errors" "fmt" + "net/http" + "os" "reflect" + "strings" "testing" "github.com/hashicorp/terraform-plugin-framework/diag" @@ -17,6 +21,7 @@ import ( "github.com/hashicorp/terraform-plugin-framework/tfsdk" "github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-framework/types/basetypes" + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" "github.com/stackitcloud/stackit-sdk-go/core/utils" ) @@ -610,3 +615,1069 @@ func TestSetAndLogStateFields(t *testing.T) { }) } } + +func TestSetModelFieldsToNull(t *testing.T) { + ctx := context.Background() + + type TestModel struct { + StringField types.String `tfsdk:"string_field"` + BoolField types.Bool `tfsdk:"bool_field"` + Int64Field types.Int64 `tfsdk:"int64_field"` + Float64Field types.Float64 `tfsdk:"float64_field"` + ListField types.List `tfsdk:"list_field"` + SetField types.Set `tfsdk:"set_field"` + MapField types.Map `tfsdk:"map_field"` + ObjectField types.Object `tfsdk:"object_field"` + } + + tests := []struct { + name string + input *TestModel + expected *TestModel + expectError bool + }{ + { + name: "all unknown fields should be set to null", + input: &TestModel{ + StringField: types.StringUnknown(), + BoolField: types.BoolUnknown(), + Int64Field: types.Int64Unknown(), + Float64Field: types.Float64Unknown(), + ListField: types.ListUnknown(types.StringType), + SetField: types.SetUnknown(types.StringType), + MapField: types.MapUnknown(types.StringType), + ObjectField: types.ObjectUnknown(map[string]attr.Type{"field1": types.StringType}), + }, + expected: &TestModel{ + StringField: types.StringNull(), + BoolField: types.BoolNull(), + Int64Field: types.Int64Null(), + Float64Field: types.Float64Null(), + ListField: types.ListNull(types.StringType), + SetField: types.SetNull(types.StringType), + MapField: types.MapNull(types.StringType), + ObjectField: types.ObjectNull(map[string]attr.Type{"field1": types.StringType}), + }, + expectError: false, + }, + { + name: "all null fields should remain null", + input: &TestModel{ + StringField: types.StringNull(), + BoolField: types.BoolNull(), + Int64Field: types.Int64Null(), + Float64Field: types.Float64Null(), + ListField: types.ListNull(types.StringType), + SetField: types.SetNull(types.StringType), + MapField: types.MapNull(types.StringType), + ObjectField: types.ObjectNull(map[string]attr.Type{"field1": types.StringType}), + }, + expected: &TestModel{ + StringField: types.StringNull(), + BoolField: types.BoolNull(), + Int64Field: types.Int64Null(), + Float64Field: types.Float64Null(), + ListField: types.ListNull(types.StringType), + SetField: types.SetNull(types.StringType), + MapField: types.MapNull(types.StringType), + ObjectField: types.ObjectNull(map[string]attr.Type{"field1": types.StringType}), + }, + expectError: false, + }, + { + name: "known fields should not be modified", + input: &TestModel{ + StringField: types.StringValue("test"), + BoolField: types.BoolValue(true), + Int64Field: types.Int64Value(42), + Float64Field: types.Float64Value(3.14), + ListField: types.ListValueMust(types.StringType, []attr.Value{types.StringValue("item")}), + SetField: types.SetValueMust(types.StringType, []attr.Value{types.StringValue("item")}), + MapField: types.MapValueMust(types.StringType, map[string]attr.Value{"key": types.StringValue("value")}), + ObjectField: types.ObjectValueMust(map[string]attr.Type{"field1": types.StringType}, map[string]attr.Value{"field1": types.StringValue("value")}), + }, + expected: &TestModel{ + StringField: types.StringValue("test"), + BoolField: types.BoolValue(true), + Int64Field: types.Int64Value(42), + Float64Field: types.Float64Value(3.14), + ListField: types.ListValueMust(types.StringType, []attr.Value{types.StringValue("item")}), + SetField: types.SetValueMust(types.StringType, []attr.Value{types.StringValue("item")}), + MapField: types.MapValueMust(types.StringType, map[string]attr.Value{"key": types.StringValue("value")}), + ObjectField: types.ObjectValueMust(map[string]attr.Type{"field1": types.StringType}, map[string]attr.Value{"field1": types.StringValue("value")}), + }, + expectError: false, + }, + { + name: "mixed fields - some unknown, some known", + input: &TestModel{ + StringField: types.StringUnknown(), + BoolField: types.BoolValue(true), + Int64Field: types.Int64Unknown(), + Float64Field: types.Float64Value(2.71), + ListField: types.ListNull(types.StringType), + SetField: types.SetValueMust(types.StringType, []attr.Value{types.StringValue("item")}), + MapField: types.MapUnknown(types.StringType), + ObjectField: types.ObjectValueMust(map[string]attr.Type{"field1": types.StringType}, map[string]attr.Value{"field1": types.StringValue("value")}), + }, + expected: &TestModel{ + StringField: types.StringNull(), + BoolField: types.BoolValue(true), + Int64Field: types.Int64Null(), + Float64Field: types.Float64Value(2.71), + ListField: types.ListNull(types.StringType), + SetField: types.SetValueMust(types.StringType, []attr.Value{types.StringValue("item")}), + MapField: types.MapNull(types.StringType), + ObjectField: types.ObjectValueMust(map[string]attr.Type{"field1": types.StringType}, map[string]attr.Value{"field1": types.StringValue("value")}), + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := SetModelFieldsToNull(ctx, tt.input) + + if tt.expectError { + if err == nil { + t.Fatal("expected error but got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Compare each field + if diff := cmp.Diff(tt.input.StringField, tt.expected.StringField); diff != "" { + t.Errorf("StringField mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(tt.input.BoolField, tt.expected.BoolField); diff != "" { + t.Errorf("BoolField mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(tt.input.Int64Field, tt.expected.Int64Field); diff != "" { + t.Errorf("Int64Field mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(tt.input.Float64Field, tt.expected.Float64Field); diff != "" { + t.Errorf("Float64Field mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(tt.input.ListField, tt.expected.ListField); diff != "" { + t.Errorf("ListField mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(tt.input.SetField, tt.expected.SetField); diff != "" { + t.Errorf("SetField mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(tt.input.MapField, tt.expected.MapField); diff != "" { + t.Errorf("MapField mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(tt.input.ObjectField, tt.expected.ObjectField); diff != "" { + t.Errorf("ObjectField mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestSetModelFieldsToNull_Errors(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + input any + wantError string + }{ + { + name: "nil model", + input: nil, + wantError: "model cannot be nil", + }, + { + name: "non-pointer", + input: struct{}{}, + wantError: "model must be a pointer", + }, + { + name: "pointer to non-struct", + input: func() *string { s := "test"; return &s }(), + wantError: "model must point to a struct", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := SetModelFieldsToNull(ctx, tt.input) + if err == nil { + t.Fatal("expected error but got nil") + } + if !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("expected error containing %q, got %q", tt.wantError, err.Error()) + } + }) + } +} + +func TestSetModelFieldsToNull_ComplexStructures(t *testing.T) { + ctx := context.Background() + + // Test nested objects + t.Run("object with unknown fields inside known object", func(t *testing.T) { + type NestedModel struct { + NestedObject types.Object `tfsdk:"nested_object"` + } + + input := &NestedModel{ + NestedObject: types.ObjectValueMust( + map[string]attr.Type{ + "field1": types.StringType, + "field2": types.Int64Type, + }, + map[string]attr.Value{ + "field1": types.StringUnknown(), + "field2": types.Int64Value(42), + }, + ), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify the object was modified + attrs := input.NestedObject.Attributes() + if !attrs["field1"].IsNull() { + t.Error("field1 should be null after processing unknown field in nested object") + } + if attrs["field2"].IsNull() { + t.Error("field2 should remain non-null") + } + }) + + // Test list with unknown elements + t.Run("list with unknown and null elements", func(t *testing.T) { + type ListModel struct { + MyList types.List `tfsdk:"my_list"` + } + + input := &ListModel{ + MyList: types.ListValueMust( + types.StringType, + []attr.Value{ + types.StringValue("known"), + types.StringUnknown(), + types.StringNull(), + types.StringValue("another_known"), + }, + ), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + elements := input.MyList.Elements() + if len(elements) != 4 { + t.Fatalf("expected 4 elements, got %d", len(elements)) + } + + // Check that unknown was converted to null + if !elements[1].IsNull() { + t.Error("element at index 1 (was unknown) should be null") + } + // Check that null remained null + if !elements[2].IsNull() { + t.Error("element at index 2 (was null) should remain null") + } + // Check known values remain unchanged + if elements[0].IsNull() || elements[3].IsNull() { + t.Error("known elements should not be null") + } + }) + + // Test list of objects with unknown fields + t.Run("list of objects with unknown fields", func(t *testing.T) { + type ListOfObjectsModel struct { + Objects types.List `tfsdk:"objects"` + } + + objectType := types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "name": types.StringType, + "age": types.Int64Type, + }, + } + + input := &ListOfObjectsModel{ + Objects: types.ListValueMust( + objectType, + []attr.Value{ + types.ObjectValueMust( + objectType.AttrTypes, + map[string]attr.Value{ + "name": types.StringValue("Alice"), + "age": types.Int64Unknown(), + }, + ), + types.ObjectValueMust( + objectType.AttrTypes, + map[string]attr.Value{ + "name": types.StringUnknown(), + "age": types.Int64Value(30), + }, + ), + }, + ), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + elements := input.Objects.Elements() + if len(elements) != 2 { + t.Fatalf("expected 2 elements, got %d", len(elements)) + } + + // Check first object - age should be null + obj1, ok := elements[0].(types.Object) + if !ok { + t.Fatalf("expected element 0 to be types.Object, got %T", elements[0]) + } + if !obj1.Attributes()["age"].IsNull() { + t.Error("first object's age field should be null") + } + if obj1.Attributes()["name"].IsNull() { + t.Error("first object's name field should not be null") + } + + // Check second object - name should be null + obj2, ok := elements[1].(types.Object) + if !ok { + t.Fatalf("expected element 1 to be types.Object, got %T", elements[1]) + } + if !obj2.Attributes()["name"].IsNull() { + t.Error("second object's name field should be null") + } + if obj2.Attributes()["age"].IsNull() { + t.Error("second object's age field should not be null") + } + }) + + // Test deeply nested objects + t.Run("deeply nested objects", func(t *testing.T) { + type DeepModel struct { + Level1 types.Object `tfsdk:"level1"` + } + + level3Type := types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "deep_field": types.StringType, + }, + } + + level2Type := types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "level3": level3Type, + }, + } + + level1Type := types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "level2": level2Type, + }, + } + + input := &DeepModel{ + Level1: types.ObjectValueMust( + level1Type.AttrTypes, + map[string]attr.Value{ + "level2": types.ObjectValueMust( + level2Type.AttrTypes, + map[string]attr.Value{ + "level3": types.ObjectValueMust( + level3Type.AttrTypes, + map[string]attr.Value{ + "deep_field": types.StringUnknown(), + }, + ), + }, + ), + }, + ), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Navigate to the deep field + level2, ok := input.Level1.Attributes()["level2"].(types.Object) + if !ok { + t.Fatalf("expected level2 to be types.Object, got %T", input.Level1.Attributes()["level2"]) + } + level3, ok := level2.Attributes()["level3"].(types.Object) + if !ok { + t.Fatalf("expected level3 to be types.Object, got %T", level2.Attributes()["level3"]) + } + deepField := level3.Attributes()["deep_field"] + + if !deepField.IsNull() { + t.Error("deep_field should be null after processing") + } + }) + + // Test list of lists (nested lists) + t.Run("list of lists with unknown elements", func(t *testing.T) { + type NestedListModel struct { + OuterList types.List `tfsdk:"outer_list"` + } + + innerListType := types.ListType{ElemType: types.StringType} + + input := &NestedListModel{ + OuterList: types.ListValueMust( + innerListType, + []attr.Value{ + types.ListValueMust( + types.StringType, + []attr.Value{ + types.StringValue("a"), + types.StringUnknown(), + }, + ), + types.ListValueMust( + types.StringType, + []attr.Value{ + types.StringUnknown(), + types.StringValue("b"), + }, + ), + }, + ), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + outerElements := input.OuterList.Elements() + + // Check first inner list + innerList1, ok := outerElements[0].(types.List) + if !ok { + t.Fatalf("expected outerElements[0] to be types.List, got %T", outerElements[0]) + } + innerElements1 := innerList1.Elements() + if !innerElements1[1].IsNull() { + t.Error("second element of first inner list should be null") + } + + // Check second inner list + innerList2, ok := outerElements[1].(types.List) + if !ok { + t.Fatalf("expected outerElements[1] to be types.List, got %T", outerElements[1]) + } + innerElements2 := innerList2.Elements() + if !innerElements2[0].IsNull() { + t.Error("first element of second inner list should be null") + } + }) + + // Test map with object values containing unknown fields + t.Run("map with object values containing unknown fields", func(t *testing.T) { + type MapModel struct { + MyMap types.Map `tfsdk:"my_map"` + } + + objectType := types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "field1": types.StringType, + "field2": types.BoolType, + }, + } + + input := &MapModel{ + MyMap: types.MapValueMust( + objectType, + map[string]attr.Value{ + "key1": types.ObjectValueMust( + objectType.AttrTypes, + map[string]attr.Value{ + "field1": types.StringValue("known"), + "field2": types.BoolUnknown(), + }, + ), + "key2": types.ObjectValueMust( + objectType.AttrTypes, + map[string]attr.Value{ + "field1": types.StringUnknown(), + "field2": types.BoolValue(true), + }, + ), + }, + ), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + elements := input.MyMap.Elements() + + // Check key1 object + obj1, ok := elements["key1"].(types.Object) + if !ok { + t.Fatalf("expected elements[\"key1\"] to be types.Object, got %T", elements["key1"]) + } + if !obj1.Attributes()["field2"].IsNull() { + t.Error("key1 object's field2 should be null") + } + if obj1.Attributes()["field1"].IsNull() { + t.Error("key1 object's field1 should not be null") + } + + // Check key2 object + obj2, ok := elements["key2"].(types.Object) + if !ok { + t.Fatalf("expected elements[\"key2\"] to be types.Object, got %T", elements["key2"]) + } + if !obj2.Attributes()["field1"].IsNull() { + t.Error("key2 object's field1 should be null") + } + if obj2.Attributes()["field2"].IsNull() { + t.Error("key2 object's field2 should not be null") + } + }) + + // Test set with unknown elements + t.Run("set with unknown elements", func(t *testing.T) { + type SetModel struct { + MySet types.Set `tfsdk:"my_set"` + } + + input := &SetModel{ + MySet: types.SetValueMust( + types.StringType, + []attr.Value{ + types.StringValue("known"), + types.StringUnknown(), + types.StringNull(), + }, + ), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + elements := input.MySet.Elements() + + // Count null elements (should have at least 2: the original null and the converted unknown) + nullCount := 0 + for _, elem := range elements { + if elem.IsNull() { + nullCount++ + } + } + + if nullCount < 2 { + t.Errorf("expected at least 2 null elements, got %d", nullCount) + } + }) + + // Test set of objects with unknown fields + t.Run("set of objects with unknown fields", func(t *testing.T) { + type SetOfObjectsModel struct { + Objects types.Set `tfsdk:"objects"` + } + + objectType := types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "id": types.StringType, + "name": types.StringType, + }, + } + + input := &SetOfObjectsModel{ + Objects: types.SetValueMust( + objectType, + []attr.Value{ + types.ObjectValueMust( + objectType.AttrTypes, + map[string]attr.Value{ + "id": types.StringValue("1"), + "name": types.StringUnknown(), + }, + ), + types.ObjectValueMust( + objectType.AttrTypes, + map[string]attr.Value{ + "id": types.StringUnknown(), + "name": types.StringValue("Test"), + }, + ), + }, + ), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + elements := input.Objects.Elements() + if len(elements) != 2 { + t.Fatalf("expected 2 elements, got %d", len(elements)) + } + + // Check that unknown fields within objects were converted to null + for _, elem := range elements { + obj, ok := elem.(types.Object) + if !ok { + t.Fatalf("expected element to be types.Object, got %T", elem) + } + attrs := obj.Attributes() + + // At least one field in each object should be null (the unknown one) + if !attrs["name"].IsNull() && !attrs["id"].IsNull() { + t.Error("expected at least one field to be null in each object") + } + } + }) + + // Test map with list values containing objects + t.Run("map with list values containing objects with unknown fields", func(t *testing.T) { + type ComplexMapModel struct { + MyMap types.Map `tfsdk:"my_map"` + } + + objectType := types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "prop": types.StringType, + }, + } + listOfObjectsType := types.ListType{ElemType: objectType} + + input := &ComplexMapModel{ + MyMap: types.MapValueMust( + listOfObjectsType, + map[string]attr.Value{ + "key1": types.ListValueMust( + objectType, + []attr.Value{ + types.ObjectValueMust( + objectType.AttrTypes, + map[string]attr.Value{ + "prop": types.StringUnknown(), + }, + ), + }, + ), + }, + ), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + elements := input.MyMap.Elements() + list, ok := elements["key1"].(types.List) + if !ok { + t.Fatalf("expected elements[\"key1\"] to be types.List, got %T", elements["key1"]) + } + listElements := list.Elements() + obj, ok := listElements[0].(types.Object) + if !ok { + t.Fatalf("expected listElements[0] to be types.Object, got %T", listElements[0]) + } + + if !obj.Attributes()["prop"].IsNull() { + t.Error("prop field should be null after processing") + } + }) + + // Test top-level null object (should remain null) + t.Run("top-level null object", func(t *testing.T) { + type NullObjectModel struct { + MyObject types.Object `tfsdk:"my_object"` + } + + attrTypes := map[string]attr.Type{"field": types.StringType} + input := &NullObjectModel{ + MyObject: types.ObjectNull(attrTypes), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !input.MyObject.IsNull() { + t.Error("top-level null object should remain null") + } + }) + + // Test top-level unknown list (should be converted to null) + t.Run("top-level unknown list", func(t *testing.T) { + type UnknownListModel struct { + MyList types.List `tfsdk:"my_list"` + } + + input := &UnknownListModel{ + MyList: types.ListUnknown(types.StringType), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !input.MyList.IsNull() { + t.Error("top-level unknown list should be converted to null") + } + if input.MyList.IsUnknown() { + t.Error("top-level list should no longer be unknown") + } + }) + + // Test empty list (should remain unchanged) + t.Run("empty list", func(t *testing.T) { + type EmptyListModel struct { + MyList types.List `tfsdk:"my_list"` + } + + input := &EmptyListModel{ + MyList: types.ListValueMust(types.StringType, []attr.Value{}), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if input.MyList.IsNull() { + t.Error("empty list should not become null") + } + if len(input.MyList.Elements()) != 0 { + t.Error("list should remain empty") + } + }) + + // Test object with all null fields + t.Run("object with all null fields", func(t *testing.T) { + type AllNullFieldsModel struct { + MyObject types.Object `tfsdk:"my_object"` + } + + attrTypes := map[string]attr.Type{ + "field1": types.StringType, + "field2": types.Int64Type, + } + + input := &AllNullFieldsModel{ + MyObject: types.ObjectValueMust( + attrTypes, + map[string]attr.Value{ + "field1": types.StringNull(), + "field2": types.Int64Null(), + }, + ), + } + + err := SetModelFieldsToNull(ctx, input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + attrs := input.MyObject.Attributes() + if !attrs["field1"].IsNull() || !attrs["field2"].IsNull() { + t.Error("all fields should remain null") + } + }) +} + +func TestShouldWait(t *testing.T) { + tests := []struct { + name string + envValue string + setEnv bool + expected bool + }{ + { + name: "env not set - should wait", + setEnv: false, + expected: true, + }, + { + name: "env set to empty string - should wait", + envValue: "", + setEnv: true, + expected: true, + }, + { + name: "env set to 'true' - should wait", + envValue: "true", + setEnv: true, + expected: true, + }, + { + name: "env set to 'TRUE' - should wait (case insensitive)", + envValue: "TRUE", + setEnv: true, + expected: true, + }, + { + name: "env set to 'True' - should wait (case insensitive)", + envValue: "True", + setEnv: true, + expected: true, + }, + { + name: "env set to 'false' - should not wait", + envValue: "false", + setEnv: true, + expected: false, + }, + { + name: "env set to 'FALSE' - should not wait", + envValue: "FALSE", + setEnv: true, + expected: false, + }, + { + name: "env set to '0' - should not wait", + envValue: "0", + setEnv: true, + expected: false, + }, + { + name: "env set to 'no' - should not wait", + envValue: "no", + setEnv: true, + expected: false, + }, + { + name: "env set to random value - should not wait", + envValue: "random", + setEnv: true, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save original env value + originalValue, wasSet := os.LookupEnv("STACKIT_TF_WAIT_FOR_READY") + defer func() { + if wasSet { + _ = os.Setenv("STACKIT_TF_WAIT_FOR_READY", originalValue) + } else { + _ = os.Unsetenv("STACKIT_TF_WAIT_FOR_READY") + } + }() + + // Set up test environment + if tt.setEnv { + _ = os.Setenv("STACKIT_TF_WAIT_FOR_READY", tt.envValue) + } else { + _ = os.Unsetenv("STACKIT_TF_WAIT_FOR_READY") + } + + // Test + result := ShouldWait() + if result != tt.expected { + t.Errorf("ShouldWait() = %v, want %v", result, tt.expected) + } + }) + } +} + +// mockNetError implements net.Error for testing +type mockNetError struct { + err string + isTimeout bool + isTemp bool +} + +func (e *mockNetError) Error() string { return e.err } +func (e *mockNetError) Timeout() bool { return e.isTimeout } +func (e *mockNetError) Temporary() bool { return e.isTemp } + +func TestShouldIgnoreWaitError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error - should not ignore", + err: nil, + expected: false, + }, + { + name: "context deadline exceeded - should ignore", + err: context.DeadlineExceeded, + expected: true, + }, + { + name: "context canceled - should ignore", + err: context.Canceled, + expected: true, + }, + { + name: "wrapped context deadline exceeded - should ignore", + err: fmt.Errorf("operation failed: %w", context.DeadlineExceeded), + expected: true, + }, + { + name: "wrapped context canceled - should ignore", + err: fmt.Errorf("request canceled: %w", context.Canceled), + expected: true, + }, + { + name: "network timeout error - should ignore", + err: &mockNetError{ + err: "network timeout", + isTimeout: true, + isTemp: true, + }, + expected: true, + }, + { + name: "network temporary error - should ignore", + err: &mockNetError{ + err: "connection reset", + isTimeout: false, + isTemp: true, + }, + expected: true, + }, + { + name: "wrapped network error - should ignore", + err: fmt.Errorf("failed to connect: %w", &mockNetError{ + err: "connection refused", + isTimeout: false, + isTemp: true, + }), + expected: true, + }, + { + name: "API 500 internal server error - should ignore", + err: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusInternalServerError, + }, + expected: true, + }, + { + name: "API 502 bad gateway - should ignore", + err: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusBadGateway, + }, + expected: true, + }, + { + name: "API 503 service unavailable - should ignore", + err: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusServiceUnavailable, + }, + expected: true, + }, + { + name: "API 504 gateway timeout - should ignore", + err: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusGatewayTimeout, + }, + expected: true, + }, + { + name: "API 404 not found - should ignore", + err: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusNotFound, + }, + expected: true, + }, + { + name: "API 410 gone - should ignore", + err: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusGone, + }, + expected: true, + }, + { + name: "wrapped API 500 error - should ignore", + err: fmt.Errorf("request failed: %w", &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusInternalServerError, + }), + expected: true, + }, + { + name: "wrapped API 404 error - should ignore", + err: fmt.Errorf("resource not found: %w", &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusNotFound, + }), + expected: true, + }, + { + name: "API 400 bad request - should not ignore", + err: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusBadRequest, + }, + expected: false, + }, + { + name: "API 401 unauthorized - should not ignore", + err: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusUnauthorized, + }, + expected: false, + }, + { + name: "API 403 forbidden - should not ignore", + err: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusForbidden, + }, + expected: false, + }, + { + name: "API 409 conflict - should not ignore", + err: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusConflict, + }, + expected: false, + }, + { + name: "API 422 unprocessable entity - should not ignore", + err: &oapierror.GenericOpenAPIError{ + StatusCode: http.StatusUnprocessableEntity, + }, + expected: false, + }, + { + name: "generic error - should not ignore", + err: errors.New("some random error"), + expected: false, + }, + { + name: "wrapped generic error - should not ignore", + err: fmt.Errorf("operation failed: %w", errors.New("some error")), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ShouldIgnoreWaitError(tt.err) + if result != tt.expected { + t.Errorf("ShouldIgnoreWaitError() = %v, want %v for error: %v", result, tt.expected, tt.err) + } + }) + } +}