Add interface ContainerPreStarter optionally implemented by device plugins

This commit is contained in:
Dmitry Rozhkov 2020-01-28 13:03:10 +02:00
parent d3f6401335
commit 3ff989e4b0
3 changed files with 43 additions and 19 deletions

View File

@ -94,3 +94,10 @@ type PostAllocator interface {
// adding annotations consumed by CRI hooks to the responses. // adding annotations consumed by CRI hooks to the responses.
PostAllocate(*pluginapi.AllocateResponse) error PostAllocate(*pluginapi.AllocateResponse) error
} }
// ContainerPreStarter is an optional interface implemented by device plugins.
type ContainerPreStarter interface {
// PreStartContainer defines device initialization function before container is started.
// It might include operations like card reset.
PreStartContainer(*pluginapi.PreStartContainerRequest) error
}

View File

@ -73,7 +73,7 @@ type Manager struct {
devicePlugin Scanner devicePlugin Scanner
namespace string namespace string
servers map[string]devicePluginServer servers map[string]devicePluginServer
createServer func(string, func(*pluginapi.AllocateResponse) error) devicePluginServer createServer func(string, func(*pluginapi.AllocateResponse) error, func(*pluginapi.PreStartContainerRequest) error) devicePluginServer
} }
// NewManager creates a new instance of Manager // NewManager creates a new instance of Manager
@ -108,12 +108,17 @@ func (m *Manager) handleUpdate(update updateInfo) {
klog.V(4).Info("Received dev updates:", update) klog.V(4).Info("Received dev updates:", update)
for devType, devices := range update.Added { for devType, devices := range update.Added {
var postAllocate func(*pluginapi.AllocateResponse) error var postAllocate func(*pluginapi.AllocateResponse) error
var preStartContainer func(*pluginapi.PreStartContainerRequest) error
if postAllocator, ok := m.devicePlugin.(PostAllocator); ok { if postAllocator, ok := m.devicePlugin.(PostAllocator); ok {
postAllocate = postAllocator.PostAllocate postAllocate = postAllocator.PostAllocate
} }
m.servers[devType] = m.createServer(devType, postAllocate) if containerPreStarter, ok := m.devicePlugin.(ContainerPreStarter); ok {
preStartContainer = containerPreStarter.PreStartContainer
}
m.servers[devType] = m.createServer(devType, postAllocate, preStartContainer)
go func(dt string) { go func(dt string) {
err := m.servers[dt].Serve(m.namespace) err := m.servers[dt].Serve(m.namespace)
if err != nil { if err != nil {

View File

@ -51,29 +51,32 @@ type devicePluginServer interface {
// server implements devicePluginServer and pluginapi.PluginInterfaceServer interfaces. // server implements devicePluginServer and pluginapi.PluginInterfaceServer interfaces.
type server struct { type server struct {
devType string devType string
grpcServer *grpc.Server grpcServer *grpc.Server
updatesCh chan map[string]DeviceInfo updatesCh chan map[string]DeviceInfo
devices map[string]DeviceInfo devices map[string]DeviceInfo
postAllocate func(*pluginapi.AllocateResponse) error postAllocate func(*pluginapi.AllocateResponse) error
state serverState preStartContainer func(*pluginapi.PreStartContainerRequest) error
stateMutex sync.Mutex state serverState
stateMutex sync.Mutex
} }
// newServer creates a new server satisfying the devicePluginServer interface. // newServer creates a new server satisfying the devicePluginServer interface.
func newServer(devType string, postAllocate func(*pluginapi.AllocateResponse) error) devicePluginServer { func newServer(devType string,
postAllocate func(*pluginapi.AllocateResponse) error,
preStartContainer func(*pluginapi.PreStartContainerRequest) error) devicePluginServer {
return &server{ return &server{
devType: devType, devType: devType,
updatesCh: make(chan map[string]DeviceInfo, 1), // TODO: is 1 needed? updatesCh: make(chan map[string]DeviceInfo, 1), // TODO: is 1 needed?
devices: make(map[string]DeviceInfo), devices: make(map[string]DeviceInfo),
postAllocate: postAllocate, postAllocate: postAllocate,
state: uninitialized, preStartContainer: preStartContainer,
state: uninitialized,
} }
} }
func (srv *server) GetDevicePluginOptions(ctx context.Context, empty *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) { func (srv *server) GetDevicePluginOptions(ctx context.Context, empty *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) {
klog.V(4).Info("GetDevicePluginOptions: return empty options") return &pluginapi.DevicePluginOptions{PreStartRequired: srv.preStartContainer != nil}, nil
return new(pluginapi.DevicePluginOptions), nil
} }
func (srv *server) sendDevices(stream pluginapi.DevicePlugin_ListAndWatchServer) error { func (srv *server) sendDevices(stream pluginapi.DevicePlugin_ListAndWatchServer) error {
@ -148,6 +151,10 @@ func (srv *server) Allocate(ctx context.Context, rqt *pluginapi.AllocateRequest)
} }
func (srv *server) PreStartContainer(ctx context.Context, rqt *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) { func (srv *server) PreStartContainer(ctx context.Context, rqt *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) {
if srv.preStartContainer != nil {
return new(pluginapi.PreStartContainerResponse), srv.preStartContainer(rqt)
}
return nil, errors.New("PreStartContainer() should not be called") return nil, errors.New("PreStartContainer() should not be called")
} }
@ -218,8 +225,12 @@ func (srv *server) setupAndServe(namespace string, devicePluginPath string, kube
return err return err
} }
options := &pluginapi.DevicePluginOptions{
PreStartRequired: srv.preStartContainer != nil,
}
// Register with Kubelet. // Register with Kubelet.
err = registerWithKubelet(kubeletSocket, pluginEndpoint, resourceName) err = registerWithKubelet(kubeletSocket, pluginEndpoint, resourceName, options)
if err != nil { if err != nil {
return err return err
} }
@ -266,7 +277,7 @@ func watchFile(file string) error {
} }
} }
func registerWithKubelet(kubeletSocket, pluginEndPoint, resourceName string) error { func registerWithKubelet(kubeletSocket, pluginEndPoint, resourceName string, options *pluginapi.DevicePluginOptions) error {
conn, err := grpc.Dial(kubeletSocket, grpc.WithInsecure(), conn, err := grpc.Dial(kubeletSocket, grpc.WithInsecure(),
grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("unix", addr, timeout) return net.DialTimeout("unix", addr, timeout)
@ -280,6 +291,7 @@ func registerWithKubelet(kubeletSocket, pluginEndPoint, resourceName string) err
Version: pluginapi.Version, Version: pluginapi.Version,
Endpoint: pluginEndPoint, Endpoint: pluginEndPoint,
ResourceName: resourceName, ResourceName: resourceName,
Options: options,
} }
_, err = client.Register(context.Background(), reqt) _, err = client.Register(context.Background(), reqt)