add timeout on mount

This commit is contained in:
umagnus 2024-11-15 03:43:58 +00:00
parent 18fdc4a39e
commit 0e85efab3d
3 changed files with 86 additions and 0 deletions

View File

@ -130,6 +130,13 @@ func (ns *NodeServer) NodePublishVolume(_ context.Context, req *csi.NodePublishV
}
klog.V(2).Infof("NodePublishVolume: volumeID(%v) source(%s) targetPath(%s) mountflags(%v)", volumeID, source, targetPath, mountOptions)
execFunc := func() error {
return ns.mounter.Mount(source, targetPath, "nfs", mountOptions)
}
timeoutFunc := func() error { return fmt.Errorf("time out") }
if err := WaitUntilTimeout(90*time.Second, execFunc, timeoutFunc); err != nil {
return nil, status.Error(codes.Internal, fmt.Sprintf("volume(%s) mount %q on %q failed with %v", volumeID, source, targetPath, err))
}
err = ns.mounter.Mount(source, targetPath, "nfs", mountOptions)
if err != nil {
if os.IsPermission(err) {

View File

@ -221,3 +221,30 @@ func getRootDir(path string) string {
parts := strings.Split(path, "/")
return parts[0]
}
// ExecFunc returns a exec function's output and error
type ExecFunc func() (err error)
// TimeoutFunc returns output and error if an ExecFunc timeout
type TimeoutFunc func() (err error)
// WaitUntilTimeout waits for the exec function to complete or return timeout error
func WaitUntilTimeout(timeout time.Duration, execFunc ExecFunc, timeoutFunc TimeoutFunc) error {
// Create a channel to receive the result of the exec function
done := make(chan bool)
var err error
// Start the exec function in a goroutine
go func() {
err = execFunc()
done <- true
}()
// Wait for the function to complete or time out
select {
case <-done:
return err
case <-time.After(timeout):
return timeoutFunc()
}
}

View File

@ -428,3 +428,55 @@ func TestGetRootPath(t *testing.T) {
}
}
}
func TestWaitUntilTimeout(t *testing.T) {
tests := []struct {
desc string
timeout time.Duration
execFunc ExecFunc
timeoutFunc TimeoutFunc
expectedErr error
}{
{
desc: "execFunc returns error",
timeout: 1 * time.Second,
execFunc: func() error {
return fmt.Errorf("execFunc error")
},
timeoutFunc: func() error {
return fmt.Errorf("timeout error")
},
expectedErr: fmt.Errorf("execFunc error"),
},
{
desc: "execFunc timeout",
timeout: 1 * time.Second,
execFunc: func() error {
time.Sleep(2 * time.Second)
return nil
},
timeoutFunc: func() error {
return fmt.Errorf("timeout error")
},
expectedErr: fmt.Errorf("timeout error"),
},
{
desc: "execFunc completed successfully",
timeout: 1 * time.Second,
execFunc: func() error {
return nil
},
timeoutFunc: func() error {
return fmt.Errorf("timeout error")
},
expectedErr: nil,
},
}
for _, test := range tests {
err := WaitUntilTimeout(test.timeout, test.execFunc, test.timeoutFunc)
if err != nil && (err.Error() != test.expectedErr.Error()) {
t.Errorf("unexpected error: %v, expected error: %v", err, test.expectedErr)
}
}
}