diff --git a/cmd/gpu_plugin/gpu_plugin.go b/cmd/gpu_plugin/gpu_plugin.go index 5c2e36c8..cacc9ea8 100644 --- a/cmd/gpu_plugin/gpu_plugin.go +++ b/cmd/gpu_plugin/gpu_plugin.go @@ -51,11 +51,15 @@ const ( scanPeriod = 5 * time.Second ) +type cliOptions struct { + sharedDevNum int +} + type devicePlugin struct { sysfsDir string devfsDir string - sharedDevNum int + options cliOptions gpuDeviceReg *regexp.Regexp controlDeviceReg *regexp.Regexp @@ -64,11 +68,11 @@ type devicePlugin struct { scanDone chan bool } -func newDevicePlugin(sysfsDir, devfsDir string, sharedDevNum int) *devicePlugin { +func newDevicePlugin(sysfsDir, devfsDir string, options cliOptions) *devicePlugin { return &devicePlugin{ sysfsDir: sysfsDir, devfsDir: devfsDir, - sharedDevNum: sharedDevNum, + options: options, gpuDeviceReg: regexp.MustCompile(gpuDeviceRE), controlDeviceReg: regexp.MustCompile(controlDeviceRE), scanTicker: time.NewTicker(scanPeriod), @@ -168,7 +172,7 @@ func (dp *devicePlugin) scan() (dpapi.DeviceTree, error) { if len(nodes) > 0 { deviceInfo := dpapi.NewDeviceInfo(pluginapi.Healthy, nodes, nil, nil) - for i := 0; i < dp.sharedDevNum; i++ { + for i := 0; i < dp.options.sharedDevNum; i++ { devID := fmt.Sprintf("%s-%d", f.Name(), i) // Currently only one device type (i915) is supported. // TODO: check model ID to differentiate device models. @@ -186,19 +190,19 @@ func (dp *devicePlugin) scan() (dpapi.DeviceTree, error) { } func main() { - var sharedDevNum int + var opts cliOptions - flag.IntVar(&sharedDevNum, "shared-dev-num", 1, "number of containers sharing the same GPU device") + flag.IntVar(&opts.sharedDevNum, "shared-dev-num", 1, "number of containers sharing the same GPU device") flag.Parse() - if sharedDevNum < 1 { + if opts.sharedDevNum < 1 { klog.Warning("The number of containers sharing the same GPU must greater than zero") os.Exit(1) } klog.V(1).Info("GPU device plugin started") - plugin := newDevicePlugin(sysfsDrmDirectory, devfsDriDirectory, sharedDevNum) + plugin := newDevicePlugin(sysfsDrmDirectory, devfsDriDirectory, opts) manager := dpapi.NewManager(namespace, plugin) manager.Run() } diff --git a/cmd/gpu_plugin/gpu_plugin_test.go b/cmd/gpu_plugin/gpu_plugin_test.go index e9616dde..5e418bd1 100644 --- a/cmd/gpu_plugin/gpu_plugin_test.go +++ b/cmd/gpu_plugin/gpu_plugin_test.go @@ -155,6 +155,8 @@ func TestScan(t *testing.T) { }, } + opts := cliOptions{sharedDevNum: 1} + for _, tc := range tcases { t.Run(tc.name, func(t *testing.T) { root, err := ioutil.TempDir("", "test_new_device_plugin") @@ -169,7 +171,7 @@ func TestScan(t *testing.T) { t.Errorf("unexpected error: %+v", err) } - plugin := newDevicePlugin(sysfs, devfs, 1) + plugin := newDevicePlugin(sysfs, devfs, opts) notifier := &mockNotifier{ scanDone: plugin.scanDone,