diff --git a/pkg/nfs/controllerserver.go b/pkg/nfs/controllerserver.go index 63986cfb..103e80d1 100644 --- a/pkg/nfs/controllerserver.go +++ b/pkg/nfs/controllerserver.go @@ -75,7 +75,8 @@ func (cs *ControllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol if len(name) == 0 { return nil, status.Error(codes.InvalidArgument, "CreateVolume name must be provided") } - if err := cs.validateVolumeCapabilities(req.GetVolumeCapabilities()); err != nil { + + if err := isValidVolumeCapabilities(req.GetVolumeCapabilities()); err != nil { return nil, status.Error(codes.InvalidArgument, err.Error()) } @@ -208,8 +209,8 @@ func (cs *ControllerServer) ValidateVolumeCapabilities(ctx context.Context, req if len(req.GetVolumeId()) == 0 { return nil, status.Error(codes.InvalidArgument, "Volume ID missing in request") } - if req.GetVolumeCapabilities() == nil { - return nil, status.Error(codes.InvalidArgument, "Volume capabilities missing in request") + if err := isValidVolumeCapabilities(req.GetVolumeCapabilities()); err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) } return &csi.ValidateVolumeCapabilitiesResponse{ @@ -252,41 +253,6 @@ func (cs *ControllerServer) ControllerExpandVolume(ctx context.Context, req *csi return nil, status.Error(codes.Unimplemented, "") } -func (cs *ControllerServer) validateVolumeCapabilities(caps []*csi.VolumeCapability) error { - if len(caps) == 0 { - return fmt.Errorf("volume capabilities must be provided") - } - - for _, c := range caps { - if err := cs.validateVolumeCapability(c); err != nil { - return err - } - } - return nil -} - -func (cs *ControllerServer) validateVolumeCapability(c *csi.VolumeCapability) error { - if c == nil { - return fmt.Errorf("volume capability must be provided") - } - - // Validate access mode - accessMode := c.GetAccessMode() - if accessMode == nil { - return fmt.Errorf("volume capability access mode not set") - } - if !cs.Driver.cap[accessMode.Mode] { - return fmt.Errorf("driver does not support access mode: %v", accessMode.Mode.String()) - } - - // Validate access type - accessType := c.GetAccessType() - if accessType == nil { - return fmt.Errorf("volume capability access type not set") - } - return nil -} - // Mount nfs server at base-dir func (cs *ControllerServer) internalMount(ctx context.Context, vol *nfsVolume, volumeContext map[string]string, volCap *csi.VolumeCapability) error { sharePath := filepath.Join(string(filepath.Separator) + vol.baseDir) @@ -422,3 +388,16 @@ func getNfsVolFromID(id string) (*nfsVolume, error) { subDir: subDir, }, nil } + +// isValidVolumeCapabilities validates the given VolumeCapability array is valid +func isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) error { + if len(volCaps) == 0 { + return fmt.Errorf("volume capabilities missing in request") + } + for _, c := range volCaps { + if c.GetBlock() != nil { + return fmt.Errorf("block volume capability not supported") + } + } + return nil +} diff --git a/pkg/nfs/controllerserver_test.go b/pkg/nfs/controllerserver_test.go index 8b104101..2444a4fa 100644 --- a/pkg/nfs/controllerserver_test.go +++ b/pkg/nfs/controllerserver_test.go @@ -165,24 +165,6 @@ func TestCreateVolume(t *testing.T) { }, expectErr: true, }, - { - name: "invalid volume capability", - req: &csi.CreateVolumeRequest{ - Name: testCSIVolume, - VolumeCapabilities: []*csi.VolumeCapability{ - { - AccessMode: &csi.VolumeCapability_AccessMode{ - Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER, - }, - }, - }, - Parameters: map[string]string{ - paramServer: testServer, - paramShare: testBaseDir, - }, - }, - expectErr: true, - }, { name: "invalid create context", req: &csi.CreateVolumeRequest{ @@ -317,78 +299,6 @@ func TestDeleteVolume(t *testing.T) { } } -func TestValidateVolumeCapabilities(t *testing.T) { - capabilities := []*csi.VolumeCapability{ - { - AccessType: &csi.VolumeCapability_Mount{ - Mount: &csi.VolumeCapability_MountVolume{}, - }, - AccessMode: &csi.VolumeCapability_AccessMode{ - Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER, - }, - }, - } - - cases := []struct { - desc string - req *csi.ValidateVolumeCapabilitiesRequest - resp *csi.ValidateVolumeCapabilitiesResponse - expectedErr error - }{ - { - desc: "Volume ID missing", - req: &csi.ValidateVolumeCapabilitiesRequest{}, - resp: nil, - expectedErr: status.Error(codes.InvalidArgument, "Volume ID missing in request"), - }, - { - desc: "Volume capabilities missing", - req: &csi.ValidateVolumeCapabilitiesRequest{VolumeId: testVolumeID}, - resp: nil, - expectedErr: status.Error(codes.InvalidArgument, "Volume capabilities missing in request"), - }, - { - desc: "valid request", - req: &csi.ValidateVolumeCapabilitiesRequest{ - VolumeId: testVolumeID, - VolumeCapabilities: capabilities, - }, - resp: &csi.ValidateVolumeCapabilitiesResponse{ - Confirmed: &csi.ValidateVolumeCapabilitiesResponse_Confirmed{VolumeCapabilities: capabilities}, - }, - expectedErr: nil, - }, - { - desc: "valid request with newTestVolumeID", - req: &csi.ValidateVolumeCapabilitiesRequest{ - VolumeId: newTestVolumeID, - VolumeCapabilities: capabilities, - }, - resp: &csi.ValidateVolumeCapabilitiesResponse{ - Confirmed: &csi.ValidateVolumeCapabilitiesResponse_Confirmed{VolumeCapabilities: capabilities}, - }, - expectedErr: nil, - }, - } - - for _, test := range cases { - test := test //pin - t.Run(test.desc, func(t *testing.T) { - cs := initTestController(t) - resp, err := cs.ValidateVolumeCapabilities(context.TODO(), test.req) - if test.expectedErr == nil && err != nil { - t.Errorf("test %q failed: %v", test.desc, err) - } - if test.expectedErr != nil && err == nil { - t.Errorf("test %q failed; expected error %v, got success", test.desc, test.expectedErr) - } - if !reflect.DeepEqual(resp, test.resp) { - t.Errorf("test %q failed: got resp %+v, expected %+v", test.desc, resp, test.resp) - } - }) - } -} - func TestControllerGetCapabilities(t *testing.T) { cases := []struct { desc string @@ -526,3 +436,46 @@ func TestNfsVolFromId(t *testing.T) { }) } } + +func TestIsValidVolumeCapabilities(t *testing.T) { + mountVolumeCapabilities := []*csi.VolumeCapability{ + { + AccessType: &csi.VolumeCapability_Mount{ + Mount: &csi.VolumeCapability_MountVolume{}, + }, + }, + } + blockVolumeCapabilities := []*csi.VolumeCapability{ + { + AccessType: &csi.VolumeCapability_Block{ + Block: &csi.VolumeCapability_BlockVolume{}, + }, + }, + } + + cases := []struct { + desc string + volCaps []*csi.VolumeCapability + expectErr error + }{ + { + volCaps: mountVolumeCapabilities, + expectErr: nil, + }, + { + volCaps: blockVolumeCapabilities, + expectErr: fmt.Errorf("block volume capability not supported"), + }, + { + volCaps: []*csi.VolumeCapability{}, + expectErr: fmt.Errorf("volume capabilities missing in request"), + }, + } + + for _, test := range cases { + err := isValidVolumeCapabilities(test.volCaps) + if !reflect.DeepEqual(err, test.expectErr) { + t.Errorf("[test: %s] Unexpected error: %v, expected error: %v", test.desc, err, test.expectErr) + } + } +} diff --git a/pkg/nfs/nfs.go b/pkg/nfs/nfs.go index 0de496a1..9c977cb8 100644 --- a/pkg/nfs/nfs.go +++ b/pkg/nfs/nfs.go @@ -41,7 +41,6 @@ type Driver struct { //ids *identityServer ns *NodeServer - cap map[csi.VolumeCapability_AccessMode_Mode]bool cscap []*csi.ControllerServiceCapability nscap []*csi.NodeServiceCapability volumeLocks *VolumeLocks @@ -70,20 +69,8 @@ func NewDriver(options *DriverOptions) *Driver { endpoint: options.Endpoint, mountPermissions: options.MountPermissions, workingMountDir: options.WorkingMountDir, - cap: map[csi.VolumeCapability_AccessMode_Mode]bool{}, } - vcam := []csi.VolumeCapability_AccessMode_Mode{ - csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, - csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY, - csi.VolumeCapability_AccessMode_SINGLE_NODE_SINGLE_WRITER, - csi.VolumeCapability_AccessMode_SINGLE_NODE_MULTI_WRITER, - csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY, - csi.VolumeCapability_AccessMode_MULTI_NODE_SINGLE_WRITER, - csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER, - } - n.AddVolumeCapabilityAccessModes(vcam) - n.AddControllerServiceCapabilities([]csi.ControllerServiceCapability_RPC_Type{ csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, csi.ControllerServiceCapability_RPC_SINGLE_NODE_MULTI_WRITER, @@ -124,15 +111,6 @@ func (n *Driver) Run(testMode bool) { s.Wait() } -func (n *Driver) AddVolumeCapabilityAccessModes(vc []csi.VolumeCapability_AccessMode_Mode) []*csi.VolumeCapability_AccessMode { - var vca []*csi.VolumeCapability_AccessMode - for _, c := range vc { - vca = append(vca, &csi.VolumeCapability_AccessMode{Mode: c}) - n.cap[c] = true - } - return vca -} - func (n *Driver) AddControllerServiceCapabilities(cl []csi.ControllerServiceCapability_RPC_Type) { var csc []*csi.ControllerServiceCapability for _, c := range cl { diff --git a/pkg/nfs/nfs_test.go b/pkg/nfs/nfs_test.go index 61fd58d3..e1035832 100644 --- a/pkg/nfs/nfs_test.go +++ b/pkg/nfs/nfs_test.go @@ -38,21 +38,18 @@ func NewEmptyDriver(emptyField string) *Driver { name: DefaultDriverName, version: "", nodeID: fakeNodeID, - cap: map[csi.VolumeCapability_AccessMode_Mode]bool{}, } case "name": d = &Driver{ name: "", version: driverVersion, nodeID: fakeNodeID, - cap: map[csi.VolumeCapability_AccessMode_Mode]bool{}, } default: d = &Driver{ name: DefaultDriverName, version: driverVersion, nodeID: fakeNodeID, - cap: map[csi.VolumeCapability_AccessMode_Mode]bool{}, } } d.volumeLocks = NewVolumeLocks()