|
8 | 8 | "encoding/json" |
9 | 9 | "errors" |
10 | 10 | "fmt" |
| 11 | + "io" |
11 | 12 | "os" |
12 | 13 | "os/signal" |
13 | 14 | "regexp" |
@@ -1624,6 +1625,78 @@ func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, prot |
1624 | 1625 | return nil |
1625 | 1626 | } |
1626 | 1627 |
|
| 1628 | +func (s *server) WriteStateBytes(srv grpc.ClientStreamingServer[tfplugin6.WriteStateBytes_RequestChunk, tfplugin6.WriteStateBytes_Response]) error { |
| 1629 | + rpc := "WriteStateBytes" |
| 1630 | + ctx := srv.Context() |
| 1631 | + ctx = s.loggingContext(ctx) |
| 1632 | + ctx = logging.RpcContext(ctx, rpc) |
| 1633 | + // ctx = logging.StateStoreContext(ctx, protoReq.TypeName) |
| 1634 | + ctx = s.stoppableContext(ctx) |
| 1635 | + // logging.ProtocolTrace(ctx, "Received request") |
| 1636 | + // defer logging.ProtocolTrace(ctx, "Served request") |
| 1637 | + |
| 1638 | + ctx = tf6serverlogging.DownstreamRequest(ctx) |
| 1639 | + |
| 1640 | + server, ok := s.downstream.(tfprotov6.StateStoreServer) |
| 1641 | + if !ok { |
| 1642 | + err := status.Error(codes.Unimplemented, "ProviderServer does not implement WriteStateBytes") |
| 1643 | + logging.ProtocolError(ctx, err.Error()) |
| 1644 | + return err |
| 1645 | + } |
| 1646 | + |
| 1647 | + iterator := func(yield func(tfprotov6.WriteStateByteChunk) bool) { |
| 1648 | + for { |
| 1649 | + chunk, err := srv.Recv() |
| 1650 | + if err == io.EOF { |
| 1651 | + break |
| 1652 | + } |
| 1653 | + if err != nil { |
| 1654 | + // attempt to send the error back to client |
| 1655 | + msgErr := srv.SendMsg(&tfplugin6.WriteStateBytes_Response{ |
| 1656 | + Diagnostics: toproto.Diagnostics([]*tfprotov6.Diagnostic{ |
| 1657 | + { |
| 1658 | + Severity: tfprotov6.DiagnosticSeverityError, |
| 1659 | + Summary: "Writing state chunk failed", |
| 1660 | + Detail: fmt.Sprintf("Attempt to write a byte chunk of state %q to %q failed: %s", |
| 1661 | + chunk.StateId, chunk.TypeName, err), |
| 1662 | + }, |
| 1663 | + }), |
| 1664 | + }) |
| 1665 | + if msgErr != nil { |
| 1666 | + err := status.Error(codes.Unimplemented, "ProviderServer does not implement WriteStateBytes") |
| 1667 | + logging.ProtocolError(ctx, err.Error()) |
| 1668 | + return |
| 1669 | + } |
| 1670 | + return |
| 1671 | + } |
| 1672 | + |
| 1673 | + ok := yield(tfprotov6.WriteStateByteChunk{ |
| 1674 | + Bytes: chunk.Bytes, |
| 1675 | + TotalLength: chunk.TotalLength, |
| 1676 | + Range: tfprotov6.StateByteRange{ |
| 1677 | + Start: chunk.Range.Start, |
| 1678 | + End: chunk.Range.End, |
| 1679 | + }, |
| 1680 | + }) |
| 1681 | + if !ok { |
| 1682 | + return |
| 1683 | + } |
| 1684 | + |
| 1685 | + } |
| 1686 | + } |
| 1687 | + |
| 1688 | + resp, err := server.WriteStateBytes(ctx, &tfprotov6.WriteStateBytesStream{ |
| 1689 | + Chunks: iterator, |
| 1690 | + }) |
| 1691 | + if err != nil { |
| 1692 | + return err |
| 1693 | + } |
| 1694 | + |
| 1695 | + return srv.SendAndClose(&tfplugin6.WriteStateBytes_Response{ |
| 1696 | + Diagnostics: toproto.Diagnostics(resp.Diagnostics), |
| 1697 | + }) |
| 1698 | +} |
| 1699 | + |
1627 | 1700 | func (s *server) GetStates(ctx context.Context, protoReq *tfplugin6.GetStates_Request) (*tfplugin6.GetStates_Response, error) { |
1628 | 1701 | rpc := "GetStates" |
1629 | 1702 | ctx = s.loggingContext(ctx) |
|
0 commit comments