diff --git a/pkg/nfs/controllerserver.go b/pkg/nfs/controllerserver.go index 309657f9..76967bab 100644 --- a/pkg/nfs/controllerserver.go +++ b/pkg/nfs/controllerserver.go @@ -20,6 +20,7 @@ import ( "fmt" "os" "path/filepath" + "regexp" "strings" "github.com/container-storage-interface/spec/lib/go/csi" @@ -359,15 +360,16 @@ func (cs *ControllerServer) getVolumeIDFromNfsVol(vol *nfsVolume) string { // Given a CSI volume id, return a nfsVolume func (cs *ControllerServer) getNfsVolFromID(id string) (*nfsVolume, error) { - tokens := strings.Split(id, "/") - if len(tokens) != totalIDElements { - return nil, fmt.Errorf("volume id %q unexpected format: got %v token(s) instead of %v", id, len(tokens), totalIDElements) + volRegex := regexp.MustCompile("^([^/]+)/(.*)/([^/]+)$") + tokens := volRegex.FindStringSubmatch(id) + if tokens == nil { + return nil, fmt.Errorf("Could not split %q into server, baseDir and subDir", id) } return &nfsVolume{ id: id, - server: tokens[idServer], - baseDir: tokens[idBaseDir], - subDir: tokens[idSubDir], + server: tokens[1], + baseDir: tokens[2], + subDir: tokens[3], }, nil } diff --git a/pkg/nfs/controllerserver_test.go b/pkg/nfs/controllerserver_test.go index 99601472..3be5fde2 100644 --- a/pkg/nfs/controllerserver_test.go +++ b/pkg/nfs/controllerserver_test.go @@ -20,6 +20,7 @@ import ( "os" "path/filepath" "reflect" + "strings" "testing" "fmt" @@ -32,10 +33,12 @@ import ( ) const ( - testServer = "test-server" - testBaseDir = "test-base-dir" - testCSIVolume = "test-csi" - testVolumeID = "test-server/test-base-dir/test-csi" + testServer = "test-server" + testBaseDir = "test-base-dir" + testBaseDirNested = "test/base/dir" + testCSIVolume = "test-csi" + testVolumeID = "test-server/test-base-dir/test-csi" + testVolumeIDNested = "test-server/test/base/dir/test-csi" ) // for Windows support in the future @@ -356,3 +359,69 @@ func TestControllerGetCapabilities(t *testing.T) { }) } } + +func TestNfsVolFromId(t *testing.T) { + cases := []struct { + name string + req string + resp *nfsVolume + expectErr bool + }{ + { + name: "ID only server", + req: testServer, + resp: nil, + expectErr: true, + }, + { + name: "ID missing subDir", + req: strings.Join([]string{testServer, testBaseDir}, "/"), + resp: nil, + expectErr: true, + }, + { + name: "valid request single baseDir", + req: testVolumeID, + resp: &nfsVolume{ + id: testVolumeID, + server: testServer, + baseDir: testBaseDir, + subDir: testCSIVolume, + }, + expectErr: false, + }, + { + name: "valid request nested baseDir", + req: testVolumeIDNested, + resp: &nfsVolume{ + id: testVolumeIDNested, + server: testServer, + baseDir: testBaseDirNested, + subDir: testCSIVolume, + }, + expectErr: false, + }, + } + + for _, test := range cases { + test := test //pin + t.Run(test.name, func(t *testing.T) { + // Setup + cs := initTestController(t) + + // Run + resp, err := cs.getNfsVolFromID(test.req) + + // Verify + if !test.expectErr && err != nil { + t.Errorf("test %q failed: %v", test.name, err) + } + if test.expectErr && err == nil { + t.Errorf("test %q failed; got success", test.name) + } + if !reflect.DeepEqual(resp, test.resp) { + t.Errorf("test %q failed: got resp %+v, expected %+v", test.name, resp, test.resp) + } + }) + } +}