diff --git a/cmd/nfsplugin/main.go b/cmd/nfsplugin/main.go index fb7177df..14bd1c2b 100644 --- a/cmd/nfsplugin/main.go +++ b/cmd/nfsplugin/main.go @@ -82,5 +82,5 @@ func handle() { } d := nfs.NewNFSdriver(nodeID, endpoint, parsedPerm) - d.Run() + d.Run(false) } diff --git a/pkg/nfs/fake_mounter_test.go b/pkg/nfs/fake_mounter_test.go new file mode 100644 index 00000000..4b6faf5a --- /dev/null +++ b/pkg/nfs/fake_mounter_test.go @@ -0,0 +1,159 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package nfs + +import ( + "fmt" + "reflect" + "testing" + + "k8s.io/utils/mount" +) + +func TestMount(t *testing.T) { + targetTest := "./target_test" + sourceTest := "./source_test" + + tests := []struct { + desc string + source string + target string + expectedErr error + }{ + { + desc: "[Error] Mocked source error", + source: "./error_mount_source", + target: targetTest, + expectedErr: fmt.Errorf("fake Mount: source error"), + }, + { + desc: "[Error] Mocked target error", + source: sourceTest, + target: "./error_mount_target", + expectedErr: fmt.Errorf("fake Mount: target error"), + }, + { + desc: "[Success] Successful run", + source: sourceTest, + target: targetTest, + expectedErr: nil, + }, + } + + d, err := getTestNodeServer() + if err != nil { + t.Errorf("failed to get test node server") + } + fakeMounter := &fakeMounter{} + d.mounter = &mount.SafeFormatAndMount{ + Interface: fakeMounter, + } + + for _, test := range tests { + err := d.mounter.Mount(test.source, test.target, "", nil) + if !reflect.DeepEqual(err, test.expectedErr) { + t.Errorf("Unexpected error: %v", err) + } + } +} + +func TestMountSensitive(t *testing.T) { + targetTest := "./target_test" + sourceTest := "./source_test" + + tests := []struct { + desc string + source string + target string + expectedErr error + }{ + { + desc: "[Error] Mocked source error", + source: "./error_mount_sens_source", + target: targetTest, + expectedErr: fmt.Errorf("fake MountSensitive: source error"), + }, + { + desc: "[Error] Mocked target error", + source: sourceTest, + target: "./error_mount_sens_target", + expectedErr: fmt.Errorf("fake MountSensitive: target error"), + }, + { + desc: "[Success] Successful run", + source: sourceTest, + target: targetTest, + expectedErr: nil, + }, + } + + d, err := getTestNodeServer() + if err != nil { + t.Errorf("failed to get test node server") + } + fakeMounter := &fakeMounter{} + d.mounter = &mount.SafeFormatAndMount{ + Interface: fakeMounter, + } + + for _, test := range tests { + err := d.mounter.MountSensitive(test.source, test.target, "", nil, nil) + if !reflect.DeepEqual(err, test.expectedErr) { + t.Errorf("Unexpected error: %v", err) + } + } +} + +func TestIsLikelyNotMountPoint(t *testing.T) { + targetTest := "./target_test" + tests := []struct { + desc string + file string + expectedErr error + }{ + { + desc: "[Error] Mocked file error", + file: "./error_is_likely_target", + expectedErr: fmt.Errorf("fake IsLikelyNotMountPoint: fake error"), + }, + {desc: "[Success] Successful run", + file: targetTest, + expectedErr: nil, + }, + { + desc: "[Success] Successful run not a mount", + file: "./false_is_likely_target", + expectedErr: nil, + }, + } + + d, err := getTestNodeServer() + if err != nil { + t.Errorf("failed to get test node server") + } + fakeMounter := &fakeMounter{} + d.mounter = &mount.SafeFormatAndMount{ + Interface: fakeMounter, + } + + for _, test := range tests { + _, err := d.mounter.IsLikelyNotMountPoint(test.file) + if !reflect.DeepEqual(err, test.expectedErr) { + t.Errorf("Unexpected error: %v", err) + } + } +} diff --git a/pkg/nfs/nfs.go b/pkg/nfs/nfs.go index 9d71b888..65ec0ed5 100644 --- a/pkg/nfs/nfs.go +++ b/pkg/nfs/nfs.go @@ -17,6 +17,8 @@ limitations under the License. package nfs import ( + "fmt" + "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/glog" "k8s.io/utils/mount" @@ -89,7 +91,7 @@ func NewNodeServer(n *Driver, mounter mount.Interface) *NodeServer { } } -func (n *Driver) Run() { +func (n *Driver) Run(testMode bool) { n.ns = NewNodeServer(n, mount.New("")) s := NewNonBlockingGRPCServer() s.Start(n.endpoint, @@ -97,7 +99,8 @@ func (n *Driver) Run() { // NFS plugin has not implemented ControllerServer // using default controllerserver. NewControllerServer(n), - n.ns) + n.ns, + testMode) s.Wait() } @@ -121,3 +124,9 @@ func (n *Driver) AddControllerServiceCapabilities(cl []csi.ControllerServiceCapa n.cscap = csc } + +func IsCorruptedDir(dir string) bool { + _, pathErr := mount.PathExists(dir) + fmt.Printf("IsCorruptedDir(%s) returned with error: %v", dir, pathErr) + return pathErr != nil && mount.IsCorruptedMnt(pathErr) +} diff --git a/pkg/nfs/nfs_test.go b/pkg/nfs/nfs_test.go index 440736bc..243724d0 100644 --- a/pkg/nfs/nfs_test.go +++ b/pkg/nfs/nfs_test.go @@ -16,7 +16,15 @@ limitations under the License. package nfs -import "github.com/container-storage-interface/spec/lib/go/csi" +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/stretchr/testify/assert" +) const ( fakeNodeID = "fakeNodeID" @@ -54,3 +62,75 @@ func NewEmptyDriver(emptyField string) *Driver { return d } + +func TestNewFakeDriver(t *testing.T) { + d := NewEmptyDriver("version") + assert.Empty(t, d.version) + + d = NewEmptyDriver("name") + assert.Empty(t, d.name) +} + +func TestIsCorruptedDir(t *testing.T) { + existingMountPath, err := ioutil.TempDir(os.TempDir(), "csi-mount-test") + if err != nil { + t.Fatalf("failed to create tmp dir: %v", err) + } + defer os.RemoveAll(existingMountPath) + + curruptedPath := filepath.Join(existingMountPath, "curruptedPath") + if err := os.Symlink(existingMountPath, curruptedPath); err != nil { + t.Fatalf("failed to create curruptedPath: %v", err) + } + + tests := []struct { + desc string + dir string + expectedResult bool + }{ + { + desc: "NotExist dir", + dir: "/tmp/NotExist", + expectedResult: false, + }, + { + desc: "Existing dir", + dir: existingMountPath, + expectedResult: false, + }, + } + + for i, test := range tests { + isCorruptedDir := IsCorruptedDir(test.dir) + assert.Equal(t, test.expectedResult, isCorruptedDir, "TestCase[%d]: %s", i, test.desc) + } +} + +func TestRun(t *testing.T) { + testCases := []struct { + name string + testFunc func(t *testing.T) + }{ + { + name: "Successful run", + testFunc: func(t *testing.T) { + d := NewEmptyDriver("") + d.endpoint = "tcp://127.0.0.1:0" + d.Run(true) + }, + }, + { + name: "Successful run with node ID missing", + testFunc: func(t *testing.T) { + d := NewEmptyDriver("") + d.endpoint = "tcp://127.0.0.1:0" + d.nodeID = "" + d.Run(true) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) + } +} diff --git a/pkg/nfs/server.go b/pkg/nfs/server.go index eccd9328..5dede780 100644 --- a/pkg/nfs/server.go +++ b/pkg/nfs/server.go @@ -20,6 +20,7 @@ import ( "net" "os" "sync" + "time" "github.com/golang/glog" "google.golang.org/grpc" @@ -30,7 +31,7 @@ import ( // Defines Non blocking GRPC server interfaces type NonBlockingGRPCServer interface { // Start services at the endpoint - Start(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer) + Start(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer, testMode bool) // Waits for the service to stop Wait() // Stops the service gracefully @@ -49,11 +50,11 @@ type nonBlockingGRPCServer struct { server *grpc.Server } -func (s *nonBlockingGRPCServer) Start(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer) { +func (s *nonBlockingGRPCServer) Start(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer, testMode bool) { s.wg.Add(1) - go s.serve(endpoint, ids, cs, ns) + go s.serve(endpoint, ids, cs, ns, testMode) } func (s *nonBlockingGRPCServer) Wait() { @@ -68,7 +69,7 @@ func (s *nonBlockingGRPCServer) ForceStop() { s.server.Stop() } -func (s *nonBlockingGRPCServer) serve(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer) { +func (s *nonBlockingGRPCServer) serve(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer, testMode bool) { proto, addr, err := ParseEndpoint(endpoint) if err != nil { @@ -103,6 +104,17 @@ func (s *nonBlockingGRPCServer) serve(endpoint string, ids csi.IdentityServer, c csi.RegisterNodeServer(server, ns) } + // Used to stop the server while running tests + if testMode { + s.wg.Done() + go func() { + // make sure Serve() is called + s.wg.Wait() + time.Sleep(time.Millisecond * 1000) + s.server.GracefulStop() + }() + } + glog.Infof("Listening for connections on address: %#v", listener.Addr()) err = server.Serve(listener) diff --git a/test/e2e/e2e_suite_test.go b/test/e2e/e2e_suite_test.go index 49522fc3..ea837c90 100644 --- a/test/e2e/e2e_suite_test.go +++ b/test/e2e/e2e_suite_test.go @@ -89,7 +89,7 @@ var _ = ginkgo.BeforeSuite(func() { execTestCmd([]testCmd{installNFSServer, e2eBootstrap}) go func() { - nfsDriver.Run() + nfsDriver.Run(false) }() })