diff --git a/pkg/deviceplugin/api.go b/pkg/deviceplugin/api.go index 123416ed..41c79176 100644 --- a/pkg/deviceplugin/api.go +++ b/pkg/deviceplugin/api.go @@ -94,3 +94,10 @@ type PostAllocator interface { // adding annotations consumed by CRI hooks to the responses. PostAllocate(*pluginapi.AllocateResponse) error } + +// ContainerPreStarter is an optional interface implemented by device plugins. +type ContainerPreStarter interface { + // PreStartContainer defines device initialization function before container is started. + // It might include operations like card reset. + PreStartContainer(*pluginapi.PreStartContainerRequest) error +} diff --git a/pkg/deviceplugin/manager.go b/pkg/deviceplugin/manager.go index 28abfcad..3dd8118e 100644 --- a/pkg/deviceplugin/manager.go +++ b/pkg/deviceplugin/manager.go @@ -73,7 +73,7 @@ type Manager struct { devicePlugin Scanner namespace string servers map[string]devicePluginServer - createServer func(string, func(*pluginapi.AllocateResponse) error) devicePluginServer + createServer func(string, func(*pluginapi.AllocateResponse) error, func(*pluginapi.PreStartContainerRequest) error) devicePluginServer } // NewManager creates a new instance of Manager @@ -108,12 +108,17 @@ func (m *Manager) handleUpdate(update updateInfo) { klog.V(4).Info("Received dev updates:", update) for devType, devices := range update.Added { var postAllocate func(*pluginapi.AllocateResponse) error + var preStartContainer func(*pluginapi.PreStartContainerRequest) error if postAllocator, ok := m.devicePlugin.(PostAllocator); ok { postAllocate = postAllocator.PostAllocate } - m.servers[devType] = m.createServer(devType, postAllocate) + if containerPreStarter, ok := m.devicePlugin.(ContainerPreStarter); ok { + preStartContainer = containerPreStarter.PreStartContainer + } + + m.servers[devType] = m.createServer(devType, postAllocate, preStartContainer) go func(dt string) { err := m.servers[dt].Serve(m.namespace) if err != nil { diff --git a/pkg/deviceplugin/server.go b/pkg/deviceplugin/server.go index 793b5d2e..495ee0ee 100644 --- a/pkg/deviceplugin/server.go +++ b/pkg/deviceplugin/server.go @@ -51,29 +51,32 @@ type devicePluginServer interface { // server implements devicePluginServer and pluginapi.PluginInterfaceServer interfaces. type server struct { - devType string - grpcServer *grpc.Server - updatesCh chan map[string]DeviceInfo - devices map[string]DeviceInfo - postAllocate func(*pluginapi.AllocateResponse) error - state serverState - stateMutex sync.Mutex + devType string + grpcServer *grpc.Server + updatesCh chan map[string]DeviceInfo + devices map[string]DeviceInfo + postAllocate func(*pluginapi.AllocateResponse) error + preStartContainer func(*pluginapi.PreStartContainerRequest) error + state serverState + stateMutex sync.Mutex } // newServer creates a new server satisfying the devicePluginServer interface. -func newServer(devType string, postAllocate func(*pluginapi.AllocateResponse) error) devicePluginServer { +func newServer(devType string, + postAllocate func(*pluginapi.AllocateResponse) error, + preStartContainer func(*pluginapi.PreStartContainerRequest) error) devicePluginServer { return &server{ - devType: devType, - updatesCh: make(chan map[string]DeviceInfo, 1), // TODO: is 1 needed? - devices: make(map[string]DeviceInfo), - postAllocate: postAllocate, - state: uninitialized, + devType: devType, + updatesCh: make(chan map[string]DeviceInfo, 1), // TODO: is 1 needed? + devices: make(map[string]DeviceInfo), + postAllocate: postAllocate, + preStartContainer: preStartContainer, + state: uninitialized, } } func (srv *server) GetDevicePluginOptions(ctx context.Context, empty *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) { - klog.V(4).Info("GetDevicePluginOptions: return empty options") - return new(pluginapi.DevicePluginOptions), nil + return &pluginapi.DevicePluginOptions{PreStartRequired: srv.preStartContainer != nil}, nil } func (srv *server) sendDevices(stream pluginapi.DevicePlugin_ListAndWatchServer) error { @@ -148,6 +151,10 @@ func (srv *server) Allocate(ctx context.Context, rqt *pluginapi.AllocateRequest) } func (srv *server) PreStartContainer(ctx context.Context, rqt *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) { + if srv.preStartContainer != nil { + return new(pluginapi.PreStartContainerResponse), srv.preStartContainer(rqt) + } + return nil, errors.New("PreStartContainer() should not be called") } @@ -218,8 +225,12 @@ func (srv *server) setupAndServe(namespace string, devicePluginPath string, kube return err } + options := &pluginapi.DevicePluginOptions{ + PreStartRequired: srv.preStartContainer != nil, + } + // Register with Kubelet. - err = registerWithKubelet(kubeletSocket, pluginEndpoint, resourceName) + err = registerWithKubelet(kubeletSocket, pluginEndpoint, resourceName, options) if err != nil { return err } @@ -266,7 +277,7 @@ func watchFile(file string) error { } } -func registerWithKubelet(kubeletSocket, pluginEndPoint, resourceName string) error { +func registerWithKubelet(kubeletSocket, pluginEndPoint, resourceName string, options *pluginapi.DevicePluginOptions) error { conn, err := grpc.Dial(kubeletSocket, grpc.WithInsecure(), grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { return net.DialTimeout("unix", addr, timeout) @@ -280,6 +291,7 @@ func registerWithKubelet(kubeletSocket, pluginEndPoint, resourceName string) err Version: pluginapi.Version, Endpoint: pluginEndPoint, ResourceName: resourceName, + Options: options, } _, err = client.Register(context.Background(), reqt)