From 0e85efab3d3ed13be5d13d4432692c48dd195372 Mon Sep 17 00:00:00 2001 From: umagnus Date: Fri, 15 Nov 2024 03:43:58 +0000 Subject: [PATCH] add timeout on mount --- pkg/nfs/nodeserver.go | 7 ++++++ pkg/nfs/utils.go | 27 ++++++++++++++++++++++ pkg/nfs/utils_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+) diff --git a/pkg/nfs/nodeserver.go b/pkg/nfs/nodeserver.go index 8e330303..afbde9d0 100644 --- a/pkg/nfs/nodeserver.go +++ b/pkg/nfs/nodeserver.go @@ -130,6 +130,13 @@ func (ns *NodeServer) NodePublishVolume(_ context.Context, req *csi.NodePublishV } klog.V(2).Infof("NodePublishVolume: volumeID(%v) source(%s) targetPath(%s) mountflags(%v)", volumeID, source, targetPath, mountOptions) + execFunc := func() error { + return ns.mounter.Mount(source, targetPath, "nfs", mountOptions) + } + timeoutFunc := func() error { return fmt.Errorf("time out") } + if err := WaitUntilTimeout(90*time.Second, execFunc, timeoutFunc); err != nil { + return nil, status.Error(codes.Internal, fmt.Sprintf("volume(%s) mount %q on %q failed with %v", volumeID, source, targetPath, err)) + } err = ns.mounter.Mount(source, targetPath, "nfs", mountOptions) if err != nil { if os.IsPermission(err) { diff --git a/pkg/nfs/utils.go b/pkg/nfs/utils.go index 3e54923e..4c5ad1c6 100644 --- a/pkg/nfs/utils.go +++ b/pkg/nfs/utils.go @@ -221,3 +221,30 @@ func getRootDir(path string) string { parts := strings.Split(path, "/") return parts[0] } + +// ExecFunc returns a exec function's output and error +type ExecFunc func() (err error) + +// TimeoutFunc returns output and error if an ExecFunc timeout +type TimeoutFunc func() (err error) + +// WaitUntilTimeout waits for the exec function to complete or return timeout error +func WaitUntilTimeout(timeout time.Duration, execFunc ExecFunc, timeoutFunc TimeoutFunc) error { + // Create a channel to receive the result of the exec function + done := make(chan bool) + var err error + + // Start the exec function in a goroutine + go func() { + err = execFunc() + done <- true + }() + + // Wait for the function to complete or time out + select { + case <-done: + return err + case <-time.After(timeout): + return timeoutFunc() + } +} diff --git a/pkg/nfs/utils_test.go b/pkg/nfs/utils_test.go index 1d843085..66aa3fc4 100644 --- a/pkg/nfs/utils_test.go +++ b/pkg/nfs/utils_test.go @@ -428,3 +428,55 @@ func TestGetRootPath(t *testing.T) { } } } + +func TestWaitUntilTimeout(t *testing.T) { + tests := []struct { + desc string + timeout time.Duration + execFunc ExecFunc + timeoutFunc TimeoutFunc + expectedErr error + }{ + { + desc: "execFunc returns error", + timeout: 1 * time.Second, + execFunc: func() error { + return fmt.Errorf("execFunc error") + }, + timeoutFunc: func() error { + return fmt.Errorf("timeout error") + }, + expectedErr: fmt.Errorf("execFunc error"), + }, + { + desc: "execFunc timeout", + timeout: 1 * time.Second, + execFunc: func() error { + time.Sleep(2 * time.Second) + return nil + }, + timeoutFunc: func() error { + return fmt.Errorf("timeout error") + }, + expectedErr: fmt.Errorf("timeout error"), + }, + { + desc: "execFunc completed successfully", + timeout: 1 * time.Second, + execFunc: func() error { + return nil + }, + timeoutFunc: func() error { + return fmt.Errorf("timeout error") + }, + expectedErr: nil, + }, + } + + for _, test := range tests { + err := WaitUntilTimeout(test.timeout, test.execFunc, test.timeoutFunc) + if err != nil && (err.Error() != test.expectedErr.Error()) { + t.Errorf("unexpected error: %v, expected error: %v", err, test.expectedErr) + } + } +}