diff --git a/pkg/nfs/nfs.go b/pkg/nfs/nfs.go index f6558c36..94adcce0 100644 --- a/pkg/nfs/nfs.go +++ b/pkg/nfs/nfs.go @@ -34,10 +34,11 @@ type Driver struct { perm *uint32 //ids *identityServer - ns *NodeServer - cap map[csi.VolumeCapability_AccessMode_Mode]bool - cscap []*csi.ControllerServiceCapability - nscap []*csi.NodeServiceCapability + ns *NodeServer + cap map[csi.VolumeCapability_AccessMode_Mode]bool + cscap []*csi.ControllerServiceCapability + nscap []*csi.NodeServiceCapability + volumeLocks *VolumeLocks } const ( @@ -87,6 +88,7 @@ func NewNFSdriver(nodeID, endpoint string, perm *uint32) *Driver { csi.NodeServiceCapability_RPC_GET_VOLUME_STATS, csi.NodeServiceCapability_RPC_UNKNOWN, }) + n.volumeLocks = NewVolumeLocks() return n } diff --git a/pkg/nfs/nfs_test.go b/pkg/nfs/nfs_test.go index db0d0338..e789a700 100644 --- a/pkg/nfs/nfs_test.go +++ b/pkg/nfs/nfs_test.go @@ -59,7 +59,7 @@ func NewEmptyDriver(emptyField string) *Driver { perm: perm, } } - + d.volumeLocks = NewVolumeLocks() return d } diff --git a/pkg/nfs/nodeserver.go b/pkg/nfs/nodeserver.go index cff6a48d..88c9ad84 100644 --- a/pkg/nfs/nodeserver.go +++ b/pkg/nfs/nodeserver.go @@ -65,6 +65,11 @@ func (ns *NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublis return &csi.NodePublishVolumeResponse{}, nil } + if acquired := ns.Driver.volumeLocks.TryAcquire(volumeID); !acquired { + return nil, status.Errorf(codes.Aborted, volumeOperationAlreadyExistsFmt, volumeID) + } + defer ns.Driver.volumeLocks.Release(volumeID) + mountOptions := req.GetVolumeCapability().GetMount().GetMountFlags() if req.GetReadonly() { mountOptions = append(mountOptions, "ro") @@ -117,6 +122,11 @@ func (ns *NodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpu return nil, status.Error(codes.NotFound, "Volume not mounted") } + if acquired := ns.Driver.volumeLocks.TryAcquire(volumeID); !acquired { + return nil, status.Errorf(codes.Aborted, volumeOperationAlreadyExistsFmt, volumeID) + } + defer ns.Driver.volumeLocks.Release(volumeID) + klog.V(2).Infof("NodeUnpublishVolume: CleanupMountPoint %s on volumeID(%s)", targetPath, volumeID) err = mount.CleanupMountPoint(targetPath, ns.mounter, false) if err != nil { diff --git a/pkg/nfs/nodeserver_test.go b/pkg/nfs/nodeserver_test.go index 6a0cb633..5bd2a01d 100644 --- a/pkg/nfs/nodeserver_test.go +++ b/pkg/nfs/nodeserver_test.go @@ -19,6 +19,7 @@ package nfs import ( "context" "errors" + "fmt" "os" "reflect" "testing" @@ -35,15 +36,22 @@ const ( ) func TestNodePublishVolume(t *testing.T) { + ns, err := getTestNodeServer() + if err != nil { + t.Fatalf(err.Error()) + } + volumeCap := csi.VolumeCapability_AccessMode{Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER} alreadyMountedTarget := testutil.GetWorkDirPath("false_is_likely_exist_target", t) targetTest := testutil.GetWorkDirPath("target_test", t) tests := []struct { desc string + setup func() req csi.NodePublishVolumeRequest skipOnWindows bool expectedErr error + cleanup func() }{ { desc: "[Error] Volume capabilities missing", @@ -61,6 +69,19 @@ func TestNodePublishVolume(t *testing.T) { VolumeId: "vol_1"}, expectedErr: status.Error(codes.InvalidArgument, "Target path not provided"), }, + { + desc: "[Error] Volume operation in progress", + setup: func() { + ns.Driver.volumeLocks.TryAcquire("vol_1") + }, + req: csi.NodePublishVolumeRequest{VolumeCapability: &csi.VolumeCapability{AccessMode: &volumeCap}, + VolumeId: "vol_1", + TargetPath: targetTest}, + expectedErr: status.Error(codes.Aborted, fmt.Sprintf(volumeOperationAlreadyExistsFmt, "vol_1")), + cleanup: func() { + ns.Driver.volumeLocks.Release("vol_1") + }, + }, { desc: "[Success] Stage target path missing", req: csi.NodePublishVolumeRequest{VolumeCapability: &csi.VolumeCapability{AccessMode: &volumeCap}, @@ -97,16 +118,17 @@ func TestNodePublishVolume(t *testing.T) { // setup _ = makeDir(alreadyMountedTarget) - ns, err := getTestNodeServer() - if err != nil { - t.Fatalf(err.Error()) - } - for _, tc := range tests { + if tc.setup != nil { + tc.setup() + } _, err := ns.NodePublishVolume(context.Background(), &tc.req) if !reflect.DeepEqual(err, tc.expectedErr) { t.Errorf("Desc:%v\nUnexpected error: %v\nExpected: %v", tc.desc, err, tc.expectedErr) } + if tc.cleanup != nil { + tc.cleanup() + } } // Clean up @@ -118,14 +140,22 @@ func TestNodePublishVolume(t *testing.T) { } func TestNodeUnpublishVolume(t *testing.T) { + ns, err := getTestNodeServer() + if err != nil { + t.Fatalf(err.Error()) + } + errorTarget := testutil.GetWorkDirPath("error_is_likely_target", t) targetTest := testutil.GetWorkDirPath("target_test", t) targetFile := testutil.GetWorkDirPath("abc.go", t) + alreadyMountedTarget := testutil.GetWorkDirPath("false_is_likely_exist_target", t) tests := []struct { desc string + setup func() req csi.NodeUnpublishVolumeRequest expectedErr error + cleanup func() }{ { desc: "[Error] Volume ID missing", @@ -147,21 +177,33 @@ func TestNodeUnpublishVolume(t *testing.T) { req: csi.NodeUnpublishVolumeRequest{TargetPath: targetFile, VolumeId: "vol_1"}, expectedErr: status.Error(codes.NotFound, "Volume not mounted"), }, + { + desc: "[Error] Volume operation in progress", + setup: func() { + ns.Driver.volumeLocks.TryAcquire("vol_1") + }, + req: csi.NodeUnpublishVolumeRequest{TargetPath: alreadyMountedTarget, VolumeId: "vol_1"}, + expectedErr: status.Error(codes.Aborted, fmt.Sprintf(volumeOperationAlreadyExistsFmt, "vol_1")), + cleanup: func() { + ns.Driver.volumeLocks.Release("vol_1") + }, + }, } // Setup _ = makeDir(errorTarget) - ns, err := getTestNodeServer() - if err != nil { - t.Fatalf(err.Error()) - } - for _, tc := range tests { + if tc.setup != nil { + tc.setup() + } _, err := ns.NodeUnpublishVolume(context.Background(), &tc.req) if !reflect.DeepEqual(err, tc.expectedErr) { t.Errorf("Desc:%v\nUnexpected error: %v\nExpected: %v", tc.desc, err, tc.expectedErr) } + if tc.cleanup != nil { + tc.cleanup() + } } // Clean up diff --git a/pkg/nfs/utils.go b/pkg/nfs/utils.go index c2d526ce..3b3c68ef 100644 --- a/pkg/nfs/utils.go +++ b/pkg/nfs/utils.go @@ -19,11 +19,13 @@ package nfs import ( "fmt" "strings" + "sync" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/kubernetes-csi/csi-lib-utils/protosanitizer" "golang.org/x/net/context" "google.golang.org/grpc" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/klog/v2" ) @@ -92,3 +94,34 @@ func logGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, h } return resp, err } + +const ( + volumeOperationAlreadyExistsFmt = "An operation with the given Volume ID %s already exists" +) + +type VolumeLocks struct { + locks sets.String + mux sync.Mutex +} + +func NewVolumeLocks() *VolumeLocks { + return &VolumeLocks{ + locks: sets.NewString(), + } +} + +func (vl *VolumeLocks) TryAcquire(volumeID string) bool { + vl.mux.Lock() + defer vl.mux.Unlock() + if vl.locks.Has(volumeID) { + return false + } + vl.locks.Insert(volumeID) + return true +} + +func (vl *VolumeLocks) Release(volumeID string) { + vl.mux.Lock() + defer vl.mux.Unlock() + vl.locks.Delete(volumeID) +}