diff --git a/.style.yapf b/.style.yapf index e09c22391..7301910e8 100644 --- a/.style.yapf +++ b/.style.yapf @@ -11,4 +11,9 @@ indent_dictionary_value = True allow_multiline_dictionary_keys = True each_dict_entry_on_separate_line = False allow_multiline_lambdas = True -blank_line_before_nested_class_or_def = False \ No newline at end of file +blank_line_before_nested_class_or_def = False +arithmetic_precedence_indication = True +no_spaces_around_selected_binary_operators = "*,/" +coalesce_brackets = True +space_between_ending_comma_and_closing_bracket = False +split_before_expression_after_opening_paren = False \ No newline at end of file diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 6346084a0..5ec310b58 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -158,7 +158,7 @@ class ChatGPTAPI: self.inference_engine_classname = inference_engine_classname self.response_timeout_secs = response_timeout_secs self.on_chat_completion_request = on_chat_completion_request - self.app = web.Application(client_max_size=100 * 1024 * 1024) # 100MB to support image upload + self.app = web.Application(client_max_size=100*1024*1024) # 100MB to support image upload self.prompts: PrefixDict[str, PromptSession] = PrefixDict() self.prev_token_lens: Dict[str, int] = {} self.stream_tasks: Dict[str, asyncio.Task] = {} @@ -171,7 +171,7 @@ class ChatGPTAPI: ) cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options}) cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options}) - self.static_dir = Path(__file__).parent.parent.parent / "tinychat/examples/tinychat" + self.static_dir = Path(__file__).parent.parent.parent/"tinychat/examples/tinychat" self.app.router.add_get("/", self.handle_root) self.app.router.add_static("/", self.static_dir, name="static") @@ -186,7 +186,7 @@ class ChatGPTAPI: return middleware async def handle_root(self, request): - return web.FileResponse(self.static_dir / "index.html") + return web.FileResponse(self.static_dir/"index.html") async def handle_post_chat_token_encode(self, request): data = await request.json() diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py index c781665af..f07a09f9a 100644 --- a/exo/download/hf/hf_helpers.py +++ b/exo/download/hf/hf_helpers.py @@ -62,12 +62,12 @@ def _add_wildcard_to_directories(pattern: str) -> str: def get_hf_home() -> Path: """Get the Hugging Face home directory.""" - return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface")) + return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface")) async def get_hf_token(): """Retrieve the Hugging Face token from the user's HF_HOME directory.""" - token_path = get_hf_home() / "token" + token_path = get_hf_home()/"token" if await aios.path.exists(token_path): async with aiofiles.open(token_path, 'r') as f: return (await f.read()).strip() @@ -85,7 +85,7 @@ async def get_auth_headers(): def get_repo_root(repo_id: str) -> Path: """Get the root directory for a given repo ID in the Hugging Face cache.""" sanitized_repo_id = repo_id.replace("/", "--") - return get_hf_home() / "hub" / f"models--{sanitized_repo_id}" + return get_hf_home()/"hub"/f"models--{sanitized_repo_id}" async def fetch_file_list(session, repo_id, revision, path=""): @@ -181,9 +181,9 @@ async def download_file( downloaded_this_session += len(chunk) if progress_callback and total_size: elapsed_time = (datetime.now() - start_time).total_seconds() - speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0 + speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0 remaining_size = total_size - downloaded_size - eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0) + eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0) status = "in_progress" if downloaded_size < total_size else "complete" if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}") await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status)) @@ -199,9 +199,9 @@ async def download_repo_files( max_parallel_downloads: int = 4 ) -> Path: repo_root = get_repo_root(repo_id) - refs_dir = repo_root / "refs" - snapshots_dir = repo_root / "snapshots" - cachedreqs_dir = repo_root / "cachedreqs" + refs_dir = repo_root/"refs" + snapshots_dir = repo_root/"snapshots" + cachedreqs_dir = repo_root/"cachedreqs" # Ensure directories exist await aios.makedirs(refs_dir, exist_ok=True) @@ -209,7 +209,7 @@ async def download_repo_files( await aios.makedirs(cachedreqs_dir, exist_ok=True) # Check if we have a cached commit hash - refs_file = refs_dir / revision + refs_file = refs_dir/revision if await aios.path.exists(refs_file): async with aiofiles.open(refs_file, 'r') as f: commit_hash = (await f.read()).strip() @@ -230,13 +230,13 @@ async def download_repo_files( await f.write(commit_hash) # Set up the snapshot directory - snapshot_dir = snapshots_dir / commit_hash + snapshot_dir = snapshots_dir/commit_hash await aios.makedirs(snapshot_dir, exist_ok=True) # Set up the cached file list directory - cached_file_list_dir = cachedreqs_dir / commit_hash + cached_file_list_dir = cachedreqs_dir/commit_hash await aios.makedirs(cached_file_list_dir, exist_ok=True) - cached_file_list_path = cached_file_list_dir / "fetch_file_list.json" + cached_file_list_path = cached_file_list_dir/"fetch_file_list.json" async with aiohttp.ClientSession() as session: # Check if we have a cached file list @@ -261,7 +261,7 @@ async def download_repo_files( start_time = datetime.now() async def download_with_progress(file_info, progress_state): - local_path = snapshot_dir / file_info["path"] + local_path = snapshot_dir/file_info["path"] if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]: if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}") progress_state['completed_files'] += 1 @@ -269,9 +269,9 @@ async def download_repo_files( file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete") if progress_callback: elapsed_time = (datetime.now() - start_time).total_seconds() - overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0 + overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0 remaining_bytes = total_bytes - progress_state['downloaded_bytes'] - overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0) + overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0) status = "in_progress" if progress_state['completed_files'] < total_files else "complete" await progress_callback( RepoProgressEvent( @@ -287,9 +287,9 @@ async def download_repo_files( file_progress[event.file_path] = event if progress_callback: elapsed_time = (datetime.now() - start_time).total_seconds() - overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0 + overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0 remaining_bytes = total_bytes - progress_state['downloaded_bytes'] - overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0) + overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0) status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete" await progress_callback( RepoProgressEvent( @@ -305,9 +305,9 @@ async def download_repo_files( ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete") if progress_callback: elapsed_time = (datetime.now() - start_time).total_seconds() - overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0 + overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0 remaining_bytes = total_bytes - progress_state['downloaded_bytes'] - overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0) + overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0) status = "in_progress" if progress_state['completed_files'] < total_files else "complete" await progress_callback( RepoProgressEvent( @@ -347,11 +347,11 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[ # Check if the file exists repo_root = get_repo_root(repo_id) - snapshot_dir = repo_root / "snapshots" + snapshot_dir = repo_root/"snapshots" index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None) if index_file: - index_file_path = snapshot_dir / index_file + index_file_path = snapshot_dir/index_file if await aios.path.exists(index_file_path): async with aiofiles.open(index_file_path, 'r') as f: index_data = json.loads(await f.read()) diff --git a/exo/download/hf/hf_shard_download.py b/exo/download/hf/hf_shard_download.py index b87536542..71cf6fe00 100644 --- a/exo/download/hf/hf_shard_download.py +++ b/exo/download/hf/hf_shard_download.py @@ -22,7 +22,7 @@ class HFShardDownloader(ShardDownloader): return self.completed_downloads[shard] if self.quick_check: repo_root = get_repo_root(shard.model_id) - snapshots_dir = repo_root / "snapshots" + snapshots_dir = repo_root/"snapshots" if snapshots_dir.exists(): most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime) return most_recent_dir diff --git a/exo/helpers.py b/exo/helpers.py index 2190824fe..f56b21a3a 100644 --- a/exo/helpers.py +++ b/exo/helpers.py @@ -169,7 +169,7 @@ def is_valid_uuid(val): def get_or_create_node_id(): - NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id" + NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__)))/".exo_node_id" try: if NODE_ID_FILE.is_file(): with open(NODE_ID_FILE, "r") as f: diff --git a/exo/inference/debug_inference_engine.py b/exo/inference/debug_inference_engine.py index b14c5acd4..27bcb592f 100644 --- a/exo/inference/debug_inference_engine.py +++ b/exo/inference/debug_inference_engine.py @@ -10,7 +10,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e from exo.inference.tinygrad.inference import Tokenizer from pathlib import Path - _tokenizer = Tokenizer(str(Path(model_id) / "tokenizer.model")) + _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model")) prompt = "In a single word only, what is the last name of the president of the United States? " resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt) diff --git a/exo/inference/mlx/models/deepseek_v2.py b/exo/inference/mlx/models/deepseek_v2.py index 3488cb026..9ea271edf 100644 --- a/exo/inference/mlx/models/deepseek_v2.py +++ b/exo/inference/mlx/models/deepseek_v2.py @@ -59,7 +59,7 @@ class DeepseekV2Model(nn.Module): mask = mask.astype(h.dtype) if cache is None: - cache = [None] * len(self.layers) + cache = [None]*len(self.layers) for layer, c in zip(self.layers, cache): h = layer(h, mask, c) diff --git a/exo/inference/mlx/models/llama.py b/exo/inference/mlx/models/llama.py index afa7aa1ec..719d6a886 100644 --- a/exo/inference/mlx/models/llama.py +++ b/exo/inference/mlx/models/llama.py @@ -58,7 +58,7 @@ class LlamaModel(nn.Module): mask = create_attention_mask(h, cache) if cache is None: - cache = [None] * len(self.layers) + cache = [None]*len(self.layers) for layer, c in zip(self.layers, cache): h = layer(h, mask, cache=c) diff --git a/exo/inference/mlx/models/llava.py b/exo/inference/mlx/models/llava.py index c873fd135..b734b09b4 100644 --- a/exo/inference/mlx/models/llava.py +++ b/exo/inference/mlx/models/llava.py @@ -74,8 +74,8 @@ class VisionAttention(nn.Module): keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) - scale = math.sqrt(1 / queries.shape[-1]) - scores = (queries * scale) @ keys + scale = math.sqrt(1/queries.shape[-1]) + scores = (queries*scale) @ keys if mask is not None: scores = scores + mask.astype(scores.dtype) scores = mx.softmax(scores, axis=-1) @@ -129,7 +129,7 @@ class VisionEmbeddings(nn.Module): self.image_size = config.image_size self.patch_size = config.patch_size - self.class_embedding = mx.zeros((config.hidden_size, )) + self.class_embedding = mx.zeros((config.hidden_size,)) self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, @@ -170,12 +170,12 @@ class ClipVisionModel(nn.Module): x = self.embeddings(x) x = self.pre_layrnorm(x) - encoder_states = (x, ) if output_hidden_states else None + encoder_states = (x,) if output_hidden_states else None for l in self.encoder.layers: x = l(x, mask=None) if output_hidden_states: - encoder_states = encoder_states + (x, ) + encoder_states = encoder_states + (x,) pooler_output = self.post_layernorm(x[:, 0, :]) return pooler_output, x, encoder_states @@ -263,12 +263,12 @@ class TextAttention(nn.Module): head_dim = config.hidden_size // n_heads self.scale = head_dim**-0.5 - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + self.q_proj = nn.Linear(dim, n_heads*head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False) + self.o_proj = nn.Linear(n_heads*head_dim, dim, bias=False) - rope_scale = (1 / config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1) + rope_scale = (1/config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1) self.rope = nn.RoPE( head_dim, traditional=config.rope_traditional, @@ -312,7 +312,7 @@ class TextMLP(nn.Module): self.up_proj = nn.Linear(dim, hidden_dim, bias=False) def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + return self.down_proj(nn.silu(self.gate_proj(x))*self.up_proj(x)) class TransformerBlock(nn.Module): @@ -382,7 +382,7 @@ class Llama(nn.Module): mask = mask.astype(h.dtype) if cache is None: - cache = [None] * len(self.layers) + cache = [None]*len(self.layers) for layer, c in zip(self.layers, cache): h = layer(h, mask, c) diff --git a/exo/inference/mlx/sharded_model.py b/exo/inference/mlx/sharded_model.py index 3c25a09ba..c4570fbf6 100644 --- a/exo/inference/mlx/sharded_model.py +++ b/exo/inference/mlx/sharded_model.py @@ -38,7 +38,7 @@ class StatefulShardedModel: if top_p > 0 and top_p < 1.0: token = top_p_sampling(logits, top_p, temp) else: - token = mx.random.categorical(logits * (1 / temp)) + token = mx.random.categorical(logits*(1/temp)) return token @@ -74,7 +74,7 @@ class StatefulShardedModel: return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias) def init_cache(self, request_id: str): - kv_heads = ([self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads) + kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads) if self.max_kv_size is not None: cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads] else: diff --git a/exo/inference/mlx/sharded_utils.py b/exo/inference/mlx/sharded_utils.py index 49b1a792c..7fa38eaa6 100644 --- a/exo/inference/mlx/sharded_utils.py +++ b/exo/inference/mlx/sharded_utils.py @@ -60,7 +60,7 @@ def _get_classes(config: dict): def load_config(model_path: Path) -> dict: try: - with open(model_path / "config.json", "r") as f: + with open(model_path/"config.json", "r") as f: config = json.load(f) except FileNotFoundError: logging.error(f"Config file not found in {model_path}") @@ -103,11 +103,11 @@ def load_model_shard( "n_layers": shard.n_layers, } - weight_files = glob.glob(str(model_path / "model*.safetensors")) + weight_files = glob.glob(str(model_path/"model*.safetensors")) if not weight_files: # Try weight for back-compat - weight_files = glob.glob(str(model_path / "weight*.safetensors")) + weight_files = glob.glob(str(model_path/"weight*.safetensors")) if not weight_files: logging.error(f"No safetensors found in {model_path}") diff --git a/exo/inference/mlx/test_sharded_model.py b/exo/inference/mlx/test_sharded_model.py index 5c9b3da81..c9743d078 100644 --- a/exo/inference/mlx/test_sharded_model.py +++ b/exo/inference/mlx/test_sharded_model.py @@ -38,7 +38,7 @@ model.save_weights("./test_weights.npz") n_layers = 5 shard1 = Shard("test", 0, n_layers // 2, n_layers) sharded_model1 = DummyModel(shard1) -shard2 = Shard("test", n_layers // 2 + 1, n_layers - 1, n_layers) +shard2 = Shard("test", n_layers//2 + 1, n_layers - 1, n_layers) sharded_model2 = DummyModel(shard2) model.load_weights("./test_weights.npz") diff --git a/exo/inference/tinygrad/inference.py b/exo/inference/tinygrad/inference.py index caeea0cd1..c76199e61 100644 --- a/exo/inference/tinygrad/inference.py +++ b/exo/inference/tinygrad/inference.py @@ -33,9 +33,9 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No # load weights if model_path.is_dir(): - if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"), shard) - elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"), shard) - else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device) + if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard) + elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard) + else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device) else: weights = load(str(model_path), shard) weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"]) @@ -60,7 +60,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine): toks = self.tokenizer.encode(prompt) h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize() - if h.shape == (1, ): + if h.shape == (1,): start_pos += len(toks) start_pos += 1 n_captured_toks = 0 @@ -76,7 +76,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine): h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize() - if h.shape == (1, ): + if h.shape == (1,): start_pos += n_captured_toks start_pos += 1 n_captured_toks = 0 diff --git a/exo/inference/tinygrad/models/llama.py b/exo/inference/tinygrad/models/llama.py index a304bfd24..ef876c317 100644 --- a/exo/inference/tinygrad/models/llama.py +++ b/exo/inference/tinygrad/models/llama.py @@ -5,8 +5,8 @@ from tinygrad.helpers import getenv # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor: - freqs = 1.0 / (theta**(Tensor.arange(0, dim, 2)[:(dim // 2)] / dim)) - freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0) + freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim)) + freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0) # TODO: move dtype outside this return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2) @@ -14,8 +14,8 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtype # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc) def complex_mult(A, c, d): a, b = A[..., 0:1], A[..., 1:2] - ro = a * c - b * d - co = a * d + b * c + ro = a*c - b*d + co = a*d + b*c return ro.cat(co, dim=-1) @@ -34,7 +34,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor: bs, seqlen, n_kv_heads, head_dim = x.shape if n_rep == 1: return x # NOTE: this is different from x.repeat((1, 1, n_rep, 1)) - return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim) + return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim) class Attention: @@ -45,10 +45,10 @@ class Attention: self.n_rep = self.n_heads // self.n_kv_heads self.max_context = max_context - self.wq = linear(dim, self.n_heads * self.head_dim, bias=False) - self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = linear(self.n_heads * self.head_dim, dim, bias=False) + self.wq = linear(dim, self.n_heads*self.head_dim, bias=False) + self.wk = linear(dim, self.n_kv_heads*self.head_dim, bias=False) + self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False) + self.wo = linear(self.n_heads*self.head_dim, dim, bias=False) def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor: if getenv("WQKV"): @@ -93,7 +93,7 @@ class FeedForward: self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit def __call__(self, x: Tensor) -> Tensor: - return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)] + return self.w2(self.w1(x).silu()*self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)] class TransformerBlock: @@ -121,29 +121,29 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float): if af or ap: if not hasattr(sample, "alpha_counter"): setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous()) - logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0) * ap) + logits = logits - (sample.alpha_counter*af + (sample.alpha_counter > 0)*ap) # replace NaNs with -inf logits = (logits != logits).where(-float("inf"), logits) # softmax - t = (logits / temp).softmax() + t = (logits/temp).softmax() counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous() # top k if k: output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous() for i in range(k): - t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int) - output = output + t_max.unsqueeze(0).pad(((i, k - i - 1), )) - output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1), )) + t_argmax = (t.numel() - ((t == (t_max := t.max()))*counter2).max() - 1).cast(dtypes.default_int) + output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),)) + output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),)) t = (counter == t_argmax).where(0, t) # approximate top p # because we are already limited to top k elements we can do top p "without sorting" output_cumsum = output[::-1]._cumsum()[::-1] + t.sum() - output = (output_cumsum >= (1 - p)) * output - output_indices = (output_cumsum >= (1 - p)) * output_indices + output = (output_cumsum >= (1 - p))*output + output_indices = (output_cumsum >= (1 - p))*output_indices # sample output_idx = output.multinomial() @@ -183,7 +183,7 @@ class Transformer: self.tok_embeddings = nn.Embedding(vocab_size, dim) self.output = nn.Linear(dim, vocab_size, bias=False) self.max_context = max_context - self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous() + self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous() self.forward_jit = TinyJit(self.forward) if jit else None self.shard = shard diff --git a/exo/inference/tinygrad/tinygrad_helpers.py b/exo/inference/tinygrad/tinygrad_helpers.py index 1e7f0fcfb..d3aa234e1 100644 --- a/exo/inference/tinygrad/tinygrad_helpers.py +++ b/exo/inference/tinygrad/tinygrad_helpers.py @@ -37,7 +37,7 @@ def load(fn: str, shard: Shard): if layer_num < shard.start_layer or layer_num > shard.end_layer: continue - parts[n] = load(str(Path(fn).parent / Path(n).name), shard) + parts[n] = load(str(Path(fn).parent/Path(n).name), shard) filtered_weight_map[k] = n if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}") return {k: parts[n][k] for k, n in filtered_weight_map.items()} diff --git a/exo/models.py b/exo/models.py index 0bf167408..d355e88de 100644 --- a/exo/models.py +++ b/exo/models.py @@ -2,32 +2,28 @@ from exo.inference.shard import Shard model_base_shards = { ### llama - "llama-3.1-8b": - { - "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), - "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32), - }, - "llama-3.1-70b": - { - "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), - "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80), - }, - "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126), }, - "llama-3-8b": - { - "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), - "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32), - }, - "llama-3-70b": - { - "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), - "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80), - }, + "llama-3.1-8b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), + "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32), + }, + "llama-3.1-70b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), + "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80), + }, + "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),}, + "llama-3-8b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), + "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32), + }, + "llama-3-70b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), + "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80), + }, ### mistral - "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40), }, - "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88), }, + "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, + "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),}, ### deepseek v2 - "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27), }, + "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),}, ### llava - "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32), }, + "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),}, } diff --git a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py index e3e98dceb..0629dc777 100644 --- a/exo/networking/grpc/grpc_peer_handle.py +++ b/exo/networking/grpc/grpc_peer_handle.py @@ -27,7 +27,7 @@ class GRPCPeerHandle(PeerHandle): return self._device_capabilities async def connect(self): - self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32 * 1024 * 1024)]) + self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)]) self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel) async def is_connected(self) -> bool: diff --git a/exo/networking/grpc/grpc_server.py b/exo/networking/grpc/grpc_server.py index d2f100299..1481ef512 100644 --- a/exo/networking/grpc/grpc_server.py +++ b/exo/networking/grpc/grpc_server.py @@ -21,9 +21,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer): self.server = grpc.aio.server( futures.ThreadPoolExecutor(max_workers=10), options=[ - ("grpc.max_metadata_size", 32 * 1024 * 1024), - ("grpc.max_send_message_length", 128 * 1024 * 1024), - ("grpc.max_receive_message_length", 128 * 1024 * 1024), + ("grpc.max_metadata_size", 32*1024*1024), + ("grpc.max_send_message_length", 128*1024*1024), + ("grpc.max_receive_message_length", 128*1024*1024), ], ) node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server) diff --git a/exo/networking/grpc/node_service_pb2_grpc.py b/exo/networking/grpc/node_service_pb2_grpc.py index bf505418c..ea1d3c98f 100644 --- a/exo/networking/grpc/node_service_pb2_grpc.py +++ b/exo/networking/grpc/node_service_pb2_grpc.py @@ -150,7 +150,7 @@ def add_NodeServiceServicer_to_server(servicer, server): ), } generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler, )) + server.add_generic_rpc_handlers((generic_handler,)) server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers) diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py index 37c0b9ad4..6d1f427f5 100644 --- a/exo/orchestration/standard_node.py +++ b/exo/orchestration/standard_node.py @@ -84,19 +84,17 @@ class StandardNode(Node): asyncio.create_task( self.broadcast_opaque_status( request_id, - json.dumps( - { - "type": "node_status", - "node_id": self.id, - "status": "start_process_prompt", - "base_shard": base_shard.to_dict(), - "shard": shard.to_dict(), - "prompt": prompt, - "image_str": image_str, - "inference_state": inference_state, - "request_id": request_id, - } - ), + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "start_process_prompt", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "prompt": prompt, + "image_str": image_str, + "inference_state": inference_state, + "request_id": request_id, + }), ) ) start_time = time.perf_counter_ns() @@ -106,21 +104,19 @@ class StandardNode(Node): asyncio.create_task( self.broadcast_opaque_status( request_id, - json.dumps( - { - "type": "node_status", - "node_id": self.id, - "status": "end_process_prompt", - "base_shard": base_shard.to_dict(), - "shard": shard.to_dict(), - "prompt": prompt, - "image_str": image_str, - "inference_state": inference_state, - "request_id": request_id, - "elapsed_time_ns": elapsed_time_ns, - "result_size": resp.size if resp is not None else 0, - } - ), + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "end_process_prompt", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "prompt": prompt, + "image_str": image_str, + "inference_state": inference_state, + "request_id": request_id, + "elapsed_time_ns": elapsed_time_ns, + "result_size": resp.size if resp is not None else 0, + }), ) ) return resp @@ -166,19 +162,17 @@ class StandardNode(Node): asyncio.create_task( self.broadcast_opaque_status( request_id, - json.dumps( - { - "type": "node_status", - "node_id": self.id, - "status": "start_process_tensor", - "base_shard": base_shard.to_dict(), - "shard": shard.to_dict(), - "tensor_size": tensor.size, - "tensor_shape": tensor.shape, - "request_id": request_id, - "inference_state": inference_state, - } - ), + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "start_process_tensor", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "tensor_size": tensor.size, + "tensor_shape": tensor.shape, + "request_id": request_id, + "inference_state": inference_state, + }), ) ) start_time = time.perf_counter_ns() @@ -188,18 +182,16 @@ class StandardNode(Node): asyncio.create_task( self.broadcast_opaque_status( request_id, - json.dumps( - { - "type": "node_status", - "node_id": self.id, - "status": "end_process_tensor", - "base_shard": base_shard.to_dict(), - "shard": shard.to_dict(), - "request_id": request_id, - "elapsed_time_ns": elapsed_time_ns, - "result_size": resp.size if resp is not None else 0, - } - ), + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "end_process_tensor", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "request_id": request_id, + "elapsed_time_ns": elapsed_time_ns, + "result_size": resp.size if resp is not None else 0, + }), ) ) return resp @@ -257,7 +249,7 @@ class StandardNode(Node): current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None) if DEBUG >= 1: print(f"Current partition index: {current_partition_index}") if current_partition_index is not None: - next_partition_index = (current_partition_index + 1) % len(partitions) + next_partition_index = (current_partition_index+1) % len(partitions) next_partition: Partition = partitions[next_partition_index] next_shard = shards[next_partition_index] if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}") diff --git a/exo/stats/metrics.py b/exo/stats/metrics.py index caa125a6a..f29533ff7 100644 --- a/exo/stats/metrics.py +++ b/exo/stats/metrics.py @@ -24,6 +24,6 @@ def start_metrics_server(node: Node, port: int): elif status == "end_process_tensor": elapsed_time_ns = status_data.get("elapsed_time_ns", 0) PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc() - PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns / 1e9) # Convert ns to seconds + PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9) # Convert ns to seconds node.on_opaque_status.register("stats").on_next(_on_opaque_status) diff --git a/exo/topology/device_capabilities.py b/exo/topology/device_capabilities.py index ccae16963..6b8de77f1 100644 --- a/exo/topology/device_capabilities.py +++ b/exo/topology/device_capabilities.py @@ -44,78 +44,78 @@ CHIP_FLOPS = { # Source: https://www.cpu-monkey.com # Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative ### M chips - "Apple M1": DeviceFlops(fp32=2.29 * TFLOPS, fp16=4.58 * TFLOPS, int8=9.16 * TFLOPS), - "Apple M1 Pro": DeviceFlops(fp32=5.30 * TFLOPS, fp16=10.60 * TFLOPS, int8=21.20 * TFLOPS), - "Apple M1 Max": DeviceFlops(fp32=10.60 * TFLOPS, fp16=21.20 * TFLOPS, int8=42.40 * TFLOPS), - "Apple M1 Ultra": DeviceFlops(fp32=21.20 * TFLOPS, fp16=42.40 * TFLOPS, int8=84.80 * TFLOPS), - "Apple M2": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS), - "Apple M2 Pro": DeviceFlops(fp32=5.68 * TFLOPS, fp16=11.36 * TFLOPS, int8=22.72 * TFLOPS), - "Apple M2 Max": DeviceFlops(fp32=13.49 * TFLOPS, fp16=26.98 * TFLOPS, int8=53.96 * TFLOPS), - "Apple M2 Ultra": DeviceFlops(fp32=26.98 * TFLOPS, fp16=53.96 * TFLOPS, int8=107.92 * TFLOPS), - "Apple M3": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS), - "Apple M3 Max": DeviceFlops(fp32=14.20 * TFLOPS, fp16=28.40 * TFLOPS, int8=56.80 * TFLOPS), - "Apple M3 Pro": DeviceFlops(fp32=4.97 * TFLOPS, fp16=9.94 * TFLOPS, int8=19.88 * TFLOPS), - "Apple M4": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS), + "Apple M1": DeviceFlops(fp32=2.29*TFLOPS, fp16=4.58*TFLOPS, int8=9.16*TFLOPS), + "Apple M1 Pro": DeviceFlops(fp32=5.30*TFLOPS, fp16=10.60*TFLOPS, int8=21.20*TFLOPS), + "Apple M1 Max": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.20*TFLOPS, int8=42.40*TFLOPS), + "Apple M1 Ultra": DeviceFlops(fp32=21.20*TFLOPS, fp16=42.40*TFLOPS, int8=84.80*TFLOPS), + "Apple M2": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS), + "Apple M2 Pro": DeviceFlops(fp32=5.68*TFLOPS, fp16=11.36*TFLOPS, int8=22.72*TFLOPS), + "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS), + "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS), + "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS), + "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS), + "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS), + "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS), ### A chips - "Apple A13 Bionic": DeviceFlops(fp32=0.69 * TFLOPS, fp16=1.38 * TFLOPS, int8=2.76 * TFLOPS), - "Apple A14 Bionic": DeviceFlops(fp32=0.75 * TFLOPS, fp16=1.50 * TFLOPS, int8=3.00 * TFLOPS), - "Apple A15 Bionic": DeviceFlops(fp32=1.37 * TFLOPS, fp16=2.74 * TFLOPS, int8=5.48 * TFLOPS), - "Apple A16 Bionic": DeviceFlops(fp32=1.79 * TFLOPS, fp16=3.58 * TFLOPS, int8=7.16 * TFLOPS), - "Apple A17 Pro": DeviceFlops(fp32=2.15 * TFLOPS, fp16=4.30 * TFLOPS, int8=8.60 * TFLOPS), + "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS), + "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS), + "Apple A15 Bionic": DeviceFlops(fp32=1.37*TFLOPS, fp16=2.74*TFLOPS, int8=5.48*TFLOPS), + "Apple A16 Bionic": DeviceFlops(fp32=1.79*TFLOPS, fp16=3.58*TFLOPS, int8=7.16*TFLOPS), + "Apple A17 Pro": DeviceFlops(fp32=2.15*TFLOPS, fp16=4.30*TFLOPS, int8=8.60*TFLOPS), ### NVIDIA GPUs # RTX 40 series - "NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58 * TFLOPS, fp16=165.16 * TFLOPS, int8=330.32 * TFLOPS), - "NVIDIA GEFORCE RTX 4080": DeviceFlops(fp32=48.74 * TFLOPS, fp16=97.48 * TFLOPS, int8=194.96 * TFLOPS), - "NVIDIA GEFORCE RTX 4080 SUPER": DeviceFlops(fp32=52.0 * TFLOPS, fp16=104.0 * TFLOPS, int8=208.0 * TFLOPS), - "NVIDIA GEFORCE RTX 4070 TI SUPER": DeviceFlops(fp32=40.0 * TFLOPS, fp16=80.0 * TFLOPS, int8=160.0 * TFLOPS), - "NVIDIA GEFORCE RTX 4070 TI": DeviceFlops(fp32=39.43 * TFLOPS, fp16=78.86 * TFLOPS, int8=157.72 * TFLOPS), - "NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0 * TFLOPS, fp16=60.0 * TFLOPS, int8=120.0 * TFLOPS), - "NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0 * TFLOPS, fp16=58.0 * TFLOPS, int8=116.0 * TFLOPS), - "NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0 * TFLOPS, fp16=44.0 * TFLOPS, int8=88.0 * TFLOPS), + "NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58*TFLOPS, fp16=165.16*TFLOPS, int8=330.32*TFLOPS), + "NVIDIA GEFORCE RTX 4080": DeviceFlops(fp32=48.74*TFLOPS, fp16=97.48*TFLOPS, int8=194.96*TFLOPS), + "NVIDIA GEFORCE RTX 4080 SUPER": DeviceFlops(fp32=52.0*TFLOPS, fp16=104.0*TFLOPS, int8=208.0*TFLOPS), + "NVIDIA GEFORCE RTX 4070 TI SUPER": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS), + "NVIDIA GEFORCE RTX 4070 TI": DeviceFlops(fp32=39.43*TFLOPS, fp16=78.86*TFLOPS, int8=157.72*TFLOPS), + "NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0*TFLOPS, fp16=60.0*TFLOPS, int8=120.0*TFLOPS), + "NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0*TFLOPS, fp16=58.0*TFLOPS, int8=116.0*TFLOPS), + "NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS), # RTX 30 series - "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11 * TFLOPS, fp16=18.22 * TFLOPS, int8=36.44 * TFLOPS), - "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0 * TFLOPS, fp16=26.0 * TFLOPS, int8=52.0 * TFLOPS), - "NVIDIA GEFORCE RTX 3060 TI": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS), - "NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3 * TFLOPS, fp16=40.6 * TFLOPS, int8=81.2 * TFLOPS), - "NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8 * TFLOPS, fp16=43.6 * TFLOPS, int8=87.2 * TFLOPS), - "NVIDIA GEFORCE RTX 3080 (10 GB)": DeviceFlops(fp32=29.8 * TFLOPS, fp16=59.6 * TFLOPS, int8=119.2 * TFLOPS), - "NVIDIA GEFORCE RTX 3080 (12 GB)": DeviceFlops(fp32=30.6 * TFLOPS, fp16=61.2 * TFLOPS, int8=122.4 * TFLOPS), - "NVIDIA GEFORCE RTX 3080 TI": DeviceFlops(fp32=34.1 * TFLOPS, fp16=68.2 * TFLOPS, int8=136.4 * TFLOPS), - "NVIDIA GEFORCE RTX 3090": DeviceFlops(fp32=35.6 * TFLOPS, fp16=71.2 * TFLOPS, int8=142.4 * TFLOPS), - "NVIDIA GEFORCE RTX 3090 TI": DeviceFlops(fp32=40.0 * TFLOPS, fp16=80.0 * TFLOPS, int8=160.0 * TFLOPS), + "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11*TFLOPS, fp16=18.22*TFLOPS, int8=36.44*TFLOPS), + "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0*TFLOPS, fp16=26.0*TFLOPS, int8=52.0*TFLOPS), + "NVIDIA GEFORCE RTX 3060 TI": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS), + "NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3*TFLOPS, fp16=40.6*TFLOPS, int8=81.2*TFLOPS), + "NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8*TFLOPS, fp16=43.6*TFLOPS, int8=87.2*TFLOPS), + "NVIDIA GEFORCE RTX 3080 (10 GB)": DeviceFlops(fp32=29.8*TFLOPS, fp16=59.6*TFLOPS, int8=119.2*TFLOPS), + "NVIDIA GEFORCE RTX 3080 (12 GB)": DeviceFlops(fp32=30.6*TFLOPS, fp16=61.2*TFLOPS, int8=122.4*TFLOPS), + "NVIDIA GEFORCE RTX 3080 TI": DeviceFlops(fp32=34.1*TFLOPS, fp16=68.2*TFLOPS, int8=136.4*TFLOPS), + "NVIDIA GEFORCE RTX 3090": DeviceFlops(fp32=35.6*TFLOPS, fp16=71.2*TFLOPS, int8=142.4*TFLOPS), + "NVIDIA GEFORCE RTX 3090 TI": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS), # QUATRO RTX Ampere series - "NVIDIA QUATRO RTX A2000": DeviceFlops(fp32=7.99 * TFLOPS, fp16=7.99 * TFLOPS, int8=31.91 * TFLOPS), - "NVIDIA QUATRO RTX A4000": DeviceFlops(fp32=19.17 * TFLOPS, fp16=19.17 * TFLOPS, int8=76.68 * TFLOPS), - "NVIDIA QUATRO RTX A4500": DeviceFlops(fp32=23.65 * TFLOPS, fp16=23.65 * TFLOPS, int8=94.6 * TFLOPS), - "NVIDIA QUATRO RTX A5000": DeviceFlops(fp32=27.8 * TFLOPS, fp16=27.8 * TFLOPS, int8=111.2 * TFLOPS), - "NVIDIA QUATRO RTX A6000": DeviceFlops(fp32=38.71 * TFLOPS, fp16=38.71 * TFLOPS, int8=154.84 * TFLOPS), + "NVIDIA QUATRO RTX A2000": DeviceFlops(fp32=7.99*TFLOPS, fp16=7.99*TFLOPS, int8=31.91*TFLOPS), + "NVIDIA QUATRO RTX A4000": DeviceFlops(fp32=19.17*TFLOPS, fp16=19.17*TFLOPS, int8=76.68*TFLOPS), + "NVIDIA QUATRO RTX A4500": DeviceFlops(fp32=23.65*TFLOPS, fp16=23.65*TFLOPS, int8=94.6*TFLOPS), + "NVIDIA QUATRO RTX A5000": DeviceFlops(fp32=27.8*TFLOPS, fp16=27.8*TFLOPS, int8=111.2*TFLOPS), + "NVIDIA QUATRO RTX A6000": DeviceFlops(fp32=38.71*TFLOPS, fp16=38.71*TFLOPS, int8=154.84*TFLOPS), # Common Server GPUs - "NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4 * TFLOPS, fp16=149.7 * TFLOPS, int8=299.3 * TFLOPS), - "NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS), - "NVIDIA A800 40GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS), - "NVIDIA A100 80GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS), - "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS), - "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS), - "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS), + "NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4*TFLOPS, fp16=149.7*TFLOPS, int8=299.3*TFLOPS), + "NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A800 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A100 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), # ... add more devices if needed ... ### AMD GPUs # RX 6000 series - "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04 * TFLOPS, fp16=46.08 * TFLOPS, int8=92.16 * TFLOPS), - "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74 * TFLOPS, fp16=41.48 * TFLOPS, int8=82.96 * TFLOPS), - "AMD Radeon RX 6800": DeviceFlops(fp32=16.17 * TFLOPS, fp16=32.34 * TFLOPS, int8=64.68 * TFLOPS), - "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21 * TFLOPS, fp16=26.42 * TFLOPS, int8=52.84 * TFLOPS), - "AMD Radeon RX 6700": DeviceFlops(fp32=11.4 * TFLOPS, fp16=22.8 * TFLOPS, int8=45.6 * TFLOPS), - "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6 * TFLOPS, fp16=21.2 * TFLOPS, int8=42.4 * TFLOPS), - "AMD Radeon RX 6600": DeviceFlops(fp32=8.93 * TFLOPS, fp16=17.86 * TFLOPS, int8=35.72 * TFLOPS), - "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77 * TFLOPS, fp16=11.54 * TFLOPS, int8=23.08 * TFLOPS), - "AMD Radeon RX 6400": DeviceFlops(fp32=3.57 * TFLOPS, fp16=7.14 * TFLOPS, int8=14.28 * TFLOPS), + "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04*TFLOPS, fp16=46.08*TFLOPS, int8=92.16*TFLOPS), + "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74*TFLOPS, fp16=41.48*TFLOPS, int8=82.96*TFLOPS), + "AMD Radeon RX 6800": DeviceFlops(fp32=16.17*TFLOPS, fp16=32.34*TFLOPS, int8=64.68*TFLOPS), + "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21*TFLOPS, fp16=26.42*TFLOPS, int8=52.84*TFLOPS), + "AMD Radeon RX 6700": DeviceFlops(fp32=11.4*TFLOPS, fp16=22.8*TFLOPS, int8=45.6*TFLOPS), + "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6*TFLOPS, fp16=21.2*TFLOPS, int8=42.4*TFLOPS), + "AMD Radeon RX 6600": DeviceFlops(fp32=8.93*TFLOPS, fp16=17.86*TFLOPS, int8=35.72*TFLOPS), + "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77*TFLOPS, fp16=11.54*TFLOPS, int8=23.08*TFLOPS), + "AMD Radeon RX 6400": DeviceFlops(fp32=3.57*TFLOPS, fp16=7.14*TFLOPS, int8=14.28*TFLOPS), # RX 7000 series - "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4 * TFLOPS, fp16=122.8 * TFLOPS, int8=245.6 * TFLOPS), - "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4 * TFLOPS, fp16=106.8 * TFLOPS, int8=213.6 * TFLOPS), - "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6 * TFLOPS, fp16=85.2 * TFLOPS, int8=170.4 * TFLOPS), - "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2 * TFLOPS, fp16=68.4 * TFLOPS, int8=136.8 * TFLOPS), - "AMD Radeon RX 7600": DeviceFlops(fp32=21.5 * TFLOPS, fp16=43.0 * TFLOPS, int8=86.0 * TFLOPS), - "AMD Radeon RX 7500": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS), + "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4*TFLOPS, fp16=122.8*TFLOPS, int8=245.6*TFLOPS), + "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4*TFLOPS, fp16=106.8*TFLOPS, int8=213.6*TFLOPS), + "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6*TFLOPS, fp16=85.2*TFLOPS, int8=170.4*TFLOPS), + "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2*TFLOPS, fp16=68.4*TFLOPS, int8=136.8*TFLOPS), + "AMD Radeon RX 7600": DeviceFlops(fp32=21.5*TFLOPS, fp16=43.0*TFLOPS, int8=86.0*TFLOPS), + "AMD Radeon RX 7500": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS), # ... add more devices if needed ... ### Qualcomm embedded chips: TODO } @@ -151,7 +151,7 @@ def mac_device_capabilities() -> DeviceCapabilities: memory_units = memory_str.split() memory_value = int(memory_units[0]) if memory_units[1] == "GB": - memory = memory_value * 1024 + memory = memory_value*1024 else: memory = memory_value diff --git a/exo/topology/partitioning_strategy.py b/exo/topology/partitioning_strategy.py index 72a2fd3e3..29c3dc6a9 100644 --- a/exo/topology/partitioning_strategy.py +++ b/exo/topology/partitioning_strategy.py @@ -22,8 +22,8 @@ class PartitioningStrategy(ABC): def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]: shards = [] for i, partition in enumerate(partitions): - start_layer = int(partition.start * num_layers) - end_layer = int(partition.end * num_layers) - 1 + start_layer = int(partition.start*num_layers) + end_layer = int(partition.end*num_layers) - 1 # Ensure the last partition covers up to num_layers - 1 if i == len(partitions) - 1: diff --git a/exo/topology/ring_memory_weighted_partitioning_strategy.py b/exo/topology/ring_memory_weighted_partitioning_strategy.py index 4bf93ccd5..6550aeb19 100644 --- a/exo/topology/ring_memory_weighted_partitioning_strategy.py +++ b/exo/topology/ring_memory_weighted_partitioning_strategy.py @@ -12,7 +12,7 @@ class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy): partitions = [] start = 0 for node in nodes: - end = round(start + (node[1].memory / total_memory), 5) + end = round(start + (node[1].memory/total_memory), 5) partitions.append(Partition(node[0], start, end)) start = end return partitions diff --git a/exo/topology/test_device_capabilities.py b/exo/topology/test_device_capabilities.py index 5841156a0..5f8b4c3ac 100644 --- a/exo/topology/test_device_capabilities.py +++ b/exo/topology/test_device_capabilities.py @@ -80,7 +80,7 @@ Activation Lock Status: Disabled self.assertEqual(result.model, "MacBook Pro") self.assertEqual(result.chip, "Apple M3 Max") self.assertEqual(result.memory, 131072) # 128 GB in MB - self.assertEqual(result.flops, DeviceFlops(fp32=14.20 * TFLOPS, fp16=28.40 * TFLOPS, int8=56.80 * TFLOPS)) + self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS)) self.assertEqual( str(result), "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS", diff --git a/exo/topology/test_map_partitions.py b/exo/topology/test_map_partitions.py index de77a6a4b..5254915e6 100644 --- a/exo/topology/test_map_partitions.py +++ b/exo/topology/test_map_partitions.py @@ -56,8 +56,8 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase): def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str): shards = [] for i, partition in enumerate(partitions): - start_layer = int(partition.start * num_layers) - end_layer = int(partition.end * num_layers) - 1 + start_layer = int(partition.start*num_layers) + end_layer = int(partition.end*num_layers) - 1 shards.append(Shard(model_id, start_layer, end_layer, num_layers)) return shards diff --git a/exo/topology/test_ring_memory_weighted_partitioning_strategy.py b/exo/topology/test_ring_memory_weighted_partitioning_strategy.py index 33209b660..fd466f367 100644 --- a/exo/topology/test_ring_memory_weighted_partitioning_strategy.py +++ b/exo/topology/test_ring_memory_weighted_partitioning_strategy.py @@ -49,7 +49,7 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase): DeviceCapabilities( model="MacBook Pro", chip="test1", - memory=128 * 1024 * 1024 * 1024, + memory=128*1024*1024*1024, flops=DeviceFlops(fp32=0, fp16=0, int8=0), ), ) @@ -58,7 +58,7 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase): DeviceCapabilities( model="Mac Studio", chip="test2", - memory=192 * 1024 * 1024 * 1024, + memory=192*1024*1024*1024, flops=DeviceFlops(fp32=0, fp16=0, int8=0), ), ) @@ -67,7 +67,7 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase): DeviceCapabilities( model="MacBook Pro", chip="test3", - memory=128 * 1024 * 1024 * 1024, + memory=128*1024*1024*1024, flops=DeviceFlops(fp32=0, fp16=0, int8=0), ), ) diff --git a/exo/viz/test_topology_viz.py b/exo/viz/test_topology_viz.py index faddfb979..e57de1ae3 100644 --- a/exo/viz/test_topology_viz.py +++ b/exo/viz/test_topology_viz.py @@ -66,19 +66,19 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase): self.topology = Topology() self.topology.update_node( "node1", - DeviceCapabilities(model="ModelA", chip="ChipA", memory=8 * 1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)), + DeviceCapabilities(model="ModelA", chip="ChipA", memory=8*1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)), ) self.topology.update_node( "node2", - DeviceCapabilities(model="ModelB", chip="ChipB", memory=16 * 1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)), + DeviceCapabilities(model="ModelB", chip="ChipB", memory=16*1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)), ) self.topology.update_node( "node3", - DeviceCapabilities(model="ModelC", chip="ChipC", memory=32 * 1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)), + DeviceCapabilities(model="ModelC", chip="ChipC", memory=32*1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)), ) self.topology.update_node( "node4", - DeviceCapabilities(model="ModelD", chip="ChipD", memory=64 * 1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)), + DeviceCapabilities(model="ModelD", chip="ChipD", memory=64*1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)), ) self.top_viz = TopologyViz() diff --git a/exo/viz/topology_viz.py b/exo/viz/topology_viz.py index a79e73a87..3664f3783 100644 --- a/exo/viz/topology_viz.py +++ b/exo/viz/topology_viz.py @@ -99,7 +99,7 @@ class TopologyViz: # Process prompt prompt_lines = prompt.split('\n') if len(prompt_lines) > max_lines // 2: - prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...'] + prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...'] prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue") prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white") @@ -139,7 +139,7 @@ class TopologyViz: max_line_length = max(len(line) for line in exo_lines) for i, line in enumerate(exo_lines): centered_line = line.center(max_line_length) - start_x = (100 - max_line_length) // 2 + 15 + start_x = (100-max_line_length) // 2 + 15 colored_line = Text(centered_line, style=yellow_style) for j, char in enumerate(str(colored_line)): if 0 <= start_x + j < 100 and i < len(visualization): @@ -161,18 +161,18 @@ class TopologyViz: # Calculate total FLOPS and position on the bar total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions) - bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2 + bar_pos = (math.tanh(total_flops/20 - 2) + 1)/2 # Add GPU poor/rich bar bar_width = 30 - bar_start_x = (100 - bar_width) // 2 + bar_start_x = (100-bar_width) // 2 bar_y = info_start_y + len(info_lines) + 1 # Create a gradient bar using emojis gradient_bar = Text() emojis = ["🟥", "🟧", "🟨", "🟩"] for i in range(bar_width): - emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1) + emoji_index = min(int(i/(bar_width/len(emojis))), len(emojis) - 1) gradient_bar.append(emojis[emoji_index]) # Add the gradient bar to the visualization @@ -183,10 +183,10 @@ class TopologyViz: # Add labels visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = "GPU poor" - visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2:bar_start_x + bar_width * 2 + 11] = "GPU rich" + visualization[bar_y - 1][bar_start_x + bar_width*2 + 2:bar_start_x + bar_width*2 + 11] = "GPU rich" # Add position indicator and FLOPS value - pos_x = bar_start_x + int(bar_pos * bar_width) + pos_x = bar_start_x + int(bar_pos*bar_width) flops_str = f"{total_flops:.2f} TFLOPS" visualization[bar_y - 1][pos_x] = "▼" visualization[bar_y + 1][pos_x - len(flops_str) // 2:pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str @@ -198,9 +198,9 @@ class TopologyViz: for i, partition in enumerate(self.partitions): device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES) - angle = 2 * math.pi * i / num_partitions - x = int(center_x + radius_x * math.cos(angle)) - y = int(center_y + radius_y * math.sin(angle)) + angle = 2*math.pi*i/num_partitions + x = int(center_x + radius_x*math.cos(angle)) + y = int(center_y + radius_y*math.sin(angle)) # Place node with different color for active node and this node if partition.node_id == self.topology.active_node_id: @@ -220,8 +220,8 @@ class TopologyViz: # Calculate info position based on angle info_distance_x = radius_x + 6 info_distance_y = radius_y + 3 - info_x = int(center_x + info_distance_x * math.cos(angle)) - info_y = int(center_y + info_distance_y * math.sin(angle)) + info_x = int(center_x + info_distance_x*math.cos(angle)) + info_y = int(center_y + info_distance_y*math.sin(angle)) # Adjust text position to avoid overwriting the node icon and prevent cutoff if info_x < x: @@ -230,9 +230,9 @@ class TopologyViz: info_x = min(99 - len(max(node_info, key=len)), info_x) # Adjust for top and bottom nodes - if 5 * math.pi / 4 < angle < 7 * math.pi / 4: + if 5*math.pi/4 < angle < 7*math.pi/4: info_x += 4 - elif math.pi / 4 < angle < 3 * math.pi / 4: + elif math.pi/4 < angle < 3*math.pi/4: info_x += 3 info_y -= 2 @@ -243,16 +243,16 @@ class TopologyViz: visualization[info_y + j][info_x + k] = char # Draw line to next node - next_i = (i + 1) % num_partitions - next_angle = 2 * math.pi * next_i / num_partitions - next_x = int(center_x + radius_x * math.cos(next_angle)) - next_y = int(center_y + radius_y * math.sin(next_angle)) + next_i = (i+1) % num_partitions + next_angle = 2*math.pi*next_i/num_partitions + next_x = int(center_x + radius_x*math.cos(next_angle)) + next_y = int(center_y + radius_y*math.sin(next_angle)) # Simple line drawing steps = max(abs(next_x - x), abs(next_y - y)) for step in range(1, steps): - line_x = int(x + (next_x - x) * step / steps) - line_y = int(y + (next_y - y) * step / steps) + line_x = int(x + (next_x-x)*step/steps) + line_y = int(y + (next_y-y)*step/steps) if 0 <= line_y < 48 and 0 <= line_x < 100: visualization[line_y][line_x] = "-" @@ -280,7 +280,7 @@ class TopologyViz: for file_path, file_progress in download_progress.file_progress.items(): if file_progress.status != "complete": - progress = int(file_progress.downloaded / file_progress.total * 30) + progress = int(file_progress.downloaded/file_progress.total*30) bar = f"[{'=' * progress}{' ' * (30 - progress)}]" percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%" summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage) @@ -294,7 +294,7 @@ class TopologyViz: device = self.topology.nodes.get(node_id) partition = next((p for p in self.partitions if p.node_id == node_id), None) partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else "" - percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0 + percentage = progress.downloaded_bytes/progress.total_bytes*100 if progress.total_bytes > 0 else 0 speed = pretty_print_bytes_per_second(progress.overall_speed) device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}" progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"