mirror of
https://github.com/exo-explore/exo.git
synced 2025-06-03 06:30:24 +00:00
more compact operator formatting
This commit is contained in:
parent
14f2846a9c
commit
f53056dede
@ -11,4 +11,9 @@ indent_dictionary_value = True
|
||||
allow_multiline_dictionary_keys = True
|
||||
each_dict_entry_on_separate_line = False
|
||||
allow_multiline_lambdas = True
|
||||
blank_line_before_nested_class_or_def = False
|
||||
blank_line_before_nested_class_or_def = False
|
||||
arithmetic_precedence_indication = True
|
||||
no_spaces_around_selected_binary_operators = "*,/"
|
||||
coalesce_brackets = True
|
||||
space_between_ending_comma_and_closing_bracket = False
|
||||
split_before_expression_after_opening_paren = False
|
@ -158,7 +158,7 @@ class ChatGPTAPI:
|
||||
self.inference_engine_classname = inference_engine_classname
|
||||
self.response_timeout_secs = response_timeout_secs
|
||||
self.on_chat_completion_request = on_chat_completion_request
|
||||
self.app = web.Application(client_max_size=100 * 1024 * 1024) # 100MB to support image upload
|
||||
self.app = web.Application(client_max_size=100*1024*1024) # 100MB to support image upload
|
||||
self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
|
||||
self.prev_token_lens: Dict[str, int] = {}
|
||||
self.stream_tasks: Dict[str, asyncio.Task] = {}
|
||||
@ -171,7 +171,7 @@ class ChatGPTAPI:
|
||||
)
|
||||
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
|
||||
self.static_dir = Path(__file__).parent.parent.parent / "tinychat/examples/tinychat"
|
||||
self.static_dir = Path(__file__).parent.parent.parent/"tinychat/examples/tinychat"
|
||||
self.app.router.add_get("/", self.handle_root)
|
||||
self.app.router.add_static("/", self.static_dir, name="static")
|
||||
|
||||
@ -186,7 +186,7 @@ class ChatGPTAPI:
|
||||
return middleware
|
||||
|
||||
async def handle_root(self, request):
|
||||
return web.FileResponse(self.static_dir / "index.html")
|
||||
return web.FileResponse(self.static_dir/"index.html")
|
||||
|
||||
async def handle_post_chat_token_encode(self, request):
|
||||
data = await request.json()
|
||||
|
@ -62,12 +62,12 @@ def _add_wildcard_to_directories(pattern: str) -> str:
|
||||
|
||||
def get_hf_home() -> Path:
|
||||
"""Get the Hugging Face home directory."""
|
||||
return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
|
||||
return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
|
||||
|
||||
|
||||
async def get_hf_token():
|
||||
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
|
||||
token_path = get_hf_home() / "token"
|
||||
token_path = get_hf_home()/"token"
|
||||
if await aios.path.exists(token_path):
|
||||
async with aiofiles.open(token_path, 'r') as f:
|
||||
return (await f.read()).strip()
|
||||
@ -85,7 +85,7 @@ async def get_auth_headers():
|
||||
def get_repo_root(repo_id: str) -> Path:
|
||||
"""Get the root directory for a given repo ID in the Hugging Face cache."""
|
||||
sanitized_repo_id = repo_id.replace("/", "--")
|
||||
return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
|
||||
return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
|
||||
|
||||
|
||||
async def fetch_file_list(session, repo_id, revision, path=""):
|
||||
@ -181,9 +181,9 @@ async def download_file(
|
||||
downloaded_this_session += len(chunk)
|
||||
if progress_callback and total_size:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
|
||||
speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_size = total_size - downloaded_size
|
||||
eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
|
||||
eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
|
||||
status = "in_progress" if downloaded_size < total_size else "complete"
|
||||
if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
|
||||
@ -199,9 +199,9 @@ async def download_repo_files(
|
||||
max_parallel_downloads: int = 4
|
||||
) -> Path:
|
||||
repo_root = get_repo_root(repo_id)
|
||||
refs_dir = repo_root / "refs"
|
||||
snapshots_dir = repo_root / "snapshots"
|
||||
cachedreqs_dir = repo_root / "cachedreqs"
|
||||
refs_dir = repo_root/"refs"
|
||||
snapshots_dir = repo_root/"snapshots"
|
||||
cachedreqs_dir = repo_root/"cachedreqs"
|
||||
|
||||
# Ensure directories exist
|
||||
await aios.makedirs(refs_dir, exist_ok=True)
|
||||
@ -209,7 +209,7 @@ async def download_repo_files(
|
||||
await aios.makedirs(cachedreqs_dir, exist_ok=True)
|
||||
|
||||
# Check if we have a cached commit hash
|
||||
refs_file = refs_dir / revision
|
||||
refs_file = refs_dir/revision
|
||||
if await aios.path.exists(refs_file):
|
||||
async with aiofiles.open(refs_file, 'r') as f:
|
||||
commit_hash = (await f.read()).strip()
|
||||
@ -230,13 +230,13 @@ async def download_repo_files(
|
||||
await f.write(commit_hash)
|
||||
|
||||
# Set up the snapshot directory
|
||||
snapshot_dir = snapshots_dir / commit_hash
|
||||
snapshot_dir = snapshots_dir/commit_hash
|
||||
await aios.makedirs(snapshot_dir, exist_ok=True)
|
||||
|
||||
# Set up the cached file list directory
|
||||
cached_file_list_dir = cachedreqs_dir / commit_hash
|
||||
cached_file_list_dir = cachedreqs_dir/commit_hash
|
||||
await aios.makedirs(cached_file_list_dir, exist_ok=True)
|
||||
cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
|
||||
cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Check if we have a cached file list
|
||||
@ -261,7 +261,7 @@ async def download_repo_files(
|
||||
start_time = datetime.now()
|
||||
|
||||
async def download_with_progress(file_info, progress_state):
|
||||
local_path = snapshot_dir / file_info["path"]
|
||||
local_path = snapshot_dir/file_info["path"]
|
||||
if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
|
||||
if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
|
||||
progress_state['completed_files'] += 1
|
||||
@ -269,9 +269,9 @@ async def download_repo_files(
|
||||
file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
|
||||
await progress_callback(
|
||||
RepoProgressEvent(
|
||||
@ -287,9 +287,9 @@ async def download_repo_files(
|
||||
file_progress[event.file_path] = event
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
|
||||
await progress_callback(
|
||||
RepoProgressEvent(
|
||||
@ -305,9 +305,9 @@ async def download_repo_files(
|
||||
] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
|
||||
await progress_callback(
|
||||
RepoProgressEvent(
|
||||
@ -347,11 +347,11 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[
|
||||
|
||||
# Check if the file exists
|
||||
repo_root = get_repo_root(repo_id)
|
||||
snapshot_dir = repo_root / "snapshots"
|
||||
snapshot_dir = repo_root/"snapshots"
|
||||
index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
|
||||
|
||||
if index_file:
|
||||
index_file_path = snapshot_dir / index_file
|
||||
index_file_path = snapshot_dir/index_file
|
||||
if await aios.path.exists(index_file_path):
|
||||
async with aiofiles.open(index_file_path, 'r') as f:
|
||||
index_data = json.loads(await f.read())
|
||||
|
@ -22,7 +22,7 @@ class HFShardDownloader(ShardDownloader):
|
||||
return self.completed_downloads[shard]
|
||||
if self.quick_check:
|
||||
repo_root = get_repo_root(shard.model_id)
|
||||
snapshots_dir = repo_root / "snapshots"
|
||||
snapshots_dir = repo_root/"snapshots"
|
||||
if snapshots_dir.exists():
|
||||
most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
|
||||
return most_recent_dir
|
||||
|
@ -169,7 +169,7 @@ def is_valid_uuid(val):
|
||||
|
||||
|
||||
def get_or_create_node_id():
|
||||
NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
|
||||
NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__)))/".exo_node_id"
|
||||
try:
|
||||
if NODE_ID_FILE.is_file():
|
||||
with open(NODE_ID_FILE, "r") as f:
|
||||
|
@ -10,7 +10,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
|
||||
from exo.inference.tinygrad.inference import Tokenizer
|
||||
from pathlib import Path
|
||||
|
||||
_tokenizer = Tokenizer(str(Path(model_id) / "tokenizer.model"))
|
||||
_tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))
|
||||
|
||||
prompt = "In a single word only, what is the last name of the president of the United States? "
|
||||
resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
|
||||
|
@ -59,7 +59,7 @@ class DeepseekV2Model(nn.Module):
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
cache = [None]*len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
@ -58,7 +58,7 @@ class LlamaModel(nn.Module):
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
cache = [None]*len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
@ -74,8 +74,8 @@ class VisionAttention(nn.Module):
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys
|
||||
scale = math.sqrt(1/queries.shape[-1])
|
||||
scores = (queries*scale) @ keys
|
||||
if mask is not None:
|
||||
scores = scores + mask.astype(scores.dtype)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
@ -129,7 +129,7 @@ class VisionEmbeddings(nn.Module):
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = mx.zeros((config.hidden_size, ))
|
||||
self.class_embedding = mx.zeros((config.hidden_size,))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
@ -170,12 +170,12 @@ class ClipVisionModel(nn.Module):
|
||||
x = self.embeddings(x)
|
||||
x = self.pre_layrnorm(x)
|
||||
|
||||
encoder_states = (x, ) if output_hidden_states else None
|
||||
encoder_states = (x,) if output_hidden_states else None
|
||||
|
||||
for l in self.encoder.layers:
|
||||
x = l(x, mask=None)
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (x, )
|
||||
encoder_states = encoder_states + (x,)
|
||||
|
||||
pooler_output = self.post_layernorm(x[:, 0, :])
|
||||
return pooler_output, x, encoder_states
|
||||
@ -263,12 +263,12 @@ class TextAttention(nn.Module):
|
||||
head_dim = config.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
self.q_proj = nn.Linear(dim, n_heads*head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads*head_dim, dim, bias=False)
|
||||
|
||||
rope_scale = (1 / config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1)
|
||||
rope_scale = (1/config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1)
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=config.rope_traditional,
|
||||
@ -312,7 +312,7 @@ class TextMLP(nn.Module):
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return self.down_proj(nn.silu(self.gate_proj(x))*self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
@ -382,7 +382,7 @@ class Llama(nn.Module):
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
cache = [None]*len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
@ -38,7 +38,7 @@ class StatefulShardedModel:
|
||||
if top_p > 0 and top_p < 1.0:
|
||||
token = top_p_sampling(logits, top_p, temp)
|
||||
else:
|
||||
token = mx.random.categorical(logits * (1 / temp))
|
||||
token = mx.random.categorical(logits*(1/temp))
|
||||
|
||||
return token
|
||||
|
||||
@ -74,7 +74,7 @@ class StatefulShardedModel:
|
||||
return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
|
||||
|
||||
def init_cache(self, request_id: str):
|
||||
kv_heads = ([self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
|
||||
kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
|
||||
if self.max_kv_size is not None:
|
||||
cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
|
||||
else:
|
||||
|
@ -60,7 +60,7 @@ def _get_classes(config: dict):
|
||||
|
||||
def load_config(model_path: Path) -> dict:
|
||||
try:
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
with open(model_path/"config.json", "r") as f:
|
||||
config = json.load(f)
|
||||
except FileNotFoundError:
|
||||
logging.error(f"Config file not found in {model_path}")
|
||||
@ -103,11 +103,11 @@ def load_model_shard(
|
||||
"n_layers": shard.n_layers,
|
||||
}
|
||||
|
||||
weight_files = glob.glob(str(model_path / "model*.safetensors"))
|
||||
weight_files = glob.glob(str(model_path/"model*.safetensors"))
|
||||
|
||||
if not weight_files:
|
||||
# Try weight for back-compat
|
||||
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
|
||||
weight_files = glob.glob(str(model_path/"weight*.safetensors"))
|
||||
|
||||
if not weight_files:
|
||||
logging.error(f"No safetensors found in {model_path}")
|
||||
|
@ -38,7 +38,7 @@ model.save_weights("./test_weights.npz")
|
||||
n_layers = 5
|
||||
shard1 = Shard("test", 0, n_layers // 2, n_layers)
|
||||
sharded_model1 = DummyModel(shard1)
|
||||
shard2 = Shard("test", n_layers // 2 + 1, n_layers - 1, n_layers)
|
||||
shard2 = Shard("test", n_layers//2 + 1, n_layers - 1, n_layers)
|
||||
sharded_model2 = DummyModel(shard2)
|
||||
|
||||
model.load_weights("./test_weights.npz")
|
||||
|
@ -33,9 +33,9 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
|
||||
|
||||
# load weights
|
||||
if model_path.is_dir():
|
||||
if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"), shard)
|
||||
elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"), shard)
|
||||
else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
|
||||
if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard)
|
||||
elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard)
|
||||
else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
|
||||
else:
|
||||
weights = load(str(model_path), shard)
|
||||
weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
|
||||
@ -60,7 +60,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
||||
toks = self.tokenizer.encode(prompt)
|
||||
h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()
|
||||
|
||||
if h.shape == (1, ):
|
||||
if h.shape == (1,):
|
||||
start_pos += len(toks)
|
||||
start_pos += 1
|
||||
n_captured_toks = 0
|
||||
@ -76,7 +76,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
||||
|
||||
h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
|
||||
|
||||
if h.shape == (1, ):
|
||||
if h.shape == (1,):
|
||||
start_pos += n_captured_toks
|
||||
start_pos += 1
|
||||
n_captured_toks = 0
|
||||
|
@ -5,8 +5,8 @@ from tinygrad.helpers import getenv
|
||||
|
||||
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
|
||||
freqs = 1.0 / (theta**(Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
|
||||
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
|
||||
freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
|
||||
freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
|
||||
# TODO: move dtype outside this
|
||||
return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
|
||||
|
||||
@ -14,8 +14,8 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtype
|
||||
# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
|
||||
def complex_mult(A, c, d):
|
||||
a, b = A[..., 0:1], A[..., 1:2]
|
||||
ro = a * c - b * d
|
||||
co = a * d + b * c
|
||||
ro = a*c - b*d
|
||||
co = a*d + b*c
|
||||
return ro.cat(co, dim=-1)
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||
bs, seqlen, n_kv_heads, head_dim = x.shape
|
||||
if n_rep == 1: return x
|
||||
# NOTE: this is different from x.repeat((1, 1, n_rep, 1))
|
||||
return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
|
||||
return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
|
||||
|
||||
|
||||
class Attention:
|
||||
@ -45,10 +45,10 @@ class Attention:
|
||||
self.n_rep = self.n_heads // self.n_kv_heads
|
||||
self.max_context = max_context
|
||||
|
||||
self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
|
||||
self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
|
||||
self.wq = linear(dim, self.n_heads*self.head_dim, bias=False)
|
||||
self.wk = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
|
||||
self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
|
||||
self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
|
||||
if getenv("WQKV"):
|
||||
@ -93,7 +93,7 @@ class FeedForward:
|
||||
self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
|
||||
return self.w2(self.w1(x).silu()*self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
|
||||
|
||||
|
||||
class TransformerBlock:
|
||||
@ -121,29 +121,29 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
|
||||
if af or ap:
|
||||
if not hasattr(sample, "alpha_counter"):
|
||||
setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
|
||||
logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0) * ap)
|
||||
logits = logits - (sample.alpha_counter*af + (sample.alpha_counter > 0)*ap)
|
||||
|
||||
# replace NaNs with -inf
|
||||
logits = (logits != logits).where(-float("inf"), logits)
|
||||
|
||||
# softmax
|
||||
t = (logits / temp).softmax()
|
||||
t = (logits/temp).softmax()
|
||||
|
||||
counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
|
||||
# top k
|
||||
if k:
|
||||
output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
|
||||
for i in range(k):
|
||||
t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
|
||||
output = output + t_max.unsqueeze(0).pad(((i, k - i - 1), ))
|
||||
output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1), ))
|
||||
t_argmax = (t.numel() - ((t == (t_max := t.max()))*counter2).max() - 1).cast(dtypes.default_int)
|
||||
output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
|
||||
output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
|
||||
t = (counter == t_argmax).where(0, t)
|
||||
|
||||
# approximate top p
|
||||
# because we are already limited to top k elements we can do top p "without sorting"
|
||||
output_cumsum = output[::-1]._cumsum()[::-1] + t.sum()
|
||||
output = (output_cumsum >= (1 - p)) * output
|
||||
output_indices = (output_cumsum >= (1 - p)) * output_indices
|
||||
output = (output_cumsum >= (1 - p))*output
|
||||
output_indices = (output_cumsum >= (1 - p))*output_indices
|
||||
|
||||
# sample
|
||||
output_idx = output.multinomial()
|
||||
@ -183,7 +183,7 @@ class Transformer:
|
||||
self.tok_embeddings = nn.Embedding(vocab_size, dim)
|
||||
self.output = nn.Linear(dim, vocab_size, bias=False)
|
||||
self.max_context = max_context
|
||||
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous()
|
||||
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous()
|
||||
self.forward_jit = TinyJit(self.forward) if jit else None
|
||||
self.shard = shard
|
||||
|
||||
|
@ -37,7 +37,7 @@ def load(fn: str, shard: Shard):
|
||||
if layer_num < shard.start_layer or layer_num > shard.end_layer:
|
||||
continue
|
||||
|
||||
parts[n] = load(str(Path(fn).parent / Path(n).name), shard)
|
||||
parts[n] = load(str(Path(fn).parent/Path(n).name), shard)
|
||||
filtered_weight_map[k] = n
|
||||
if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
|
||||
return {k: parts[n][k] for k, n in filtered_weight_map.items()}
|
||||
|
@ -2,32 +2,28 @@ from exo.inference.shard import Shard
|
||||
|
||||
model_base_shards = {
|
||||
### llama
|
||||
"llama-3.1-8b":
|
||||
{
|
||||
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
|
||||
"TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
|
||||
},
|
||||
"llama-3.1-70b":
|
||||
{
|
||||
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
|
||||
"TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
|
||||
},
|
||||
"llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126), },
|
||||
"llama-3-8b":
|
||||
{
|
||||
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
|
||||
"TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
|
||||
},
|
||||
"llama-3-70b":
|
||||
{
|
||||
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
|
||||
"TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
|
||||
},
|
||||
"llama-3.1-8b": {
|
||||
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
|
||||
"TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
|
||||
},
|
||||
"llama-3.1-70b": {
|
||||
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
|
||||
"TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
|
||||
},
|
||||
"llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
|
||||
"llama-3-8b": {
|
||||
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
|
||||
"TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
|
||||
},
|
||||
"llama-3-70b": {
|
||||
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
|
||||
"TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
|
||||
},
|
||||
### mistral
|
||||
"mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40), },
|
||||
"mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88), },
|
||||
"mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
|
||||
"mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
|
||||
### deepseek v2
|
||||
"deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27), },
|
||||
"deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),},
|
||||
### llava
|
||||
"llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32), },
|
||||
"llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ class GRPCPeerHandle(PeerHandle):
|
||||
return self._device_capabilities
|
||||
|
||||
async def connect(self):
|
||||
self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32 * 1024 * 1024)])
|
||||
self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
|
||||
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
|
||||
|
||||
async def is_connected(self) -> bool:
|
||||
|
@ -21,9 +21,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
self.server = grpc.aio.server(
|
||||
futures.ThreadPoolExecutor(max_workers=10),
|
||||
options=[
|
||||
("grpc.max_metadata_size", 32 * 1024 * 1024),
|
||||
("grpc.max_send_message_length", 128 * 1024 * 1024),
|
||||
("grpc.max_receive_message_length", 128 * 1024 * 1024),
|
||||
("grpc.max_metadata_size", 32*1024*1024),
|
||||
("grpc.max_send_message_length", 128*1024*1024),
|
||||
("grpc.max_receive_message_length", 128*1024*1024),
|
||||
],
|
||||
)
|
||||
node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
|
||||
|
@ -150,7 +150,7 @@ def add_NodeServiceServicer_to_server(servicer, server):
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler, ))
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
|
||||
|
||||
|
||||
|
@ -84,19 +84,17 @@ class StandardNode(Node):
|
||||
asyncio.create_task(
|
||||
self.broadcast_opaque_status(
|
||||
request_id,
|
||||
json.dumps(
|
||||
{
|
||||
"type": "node_status",
|
||||
"node_id": self.id,
|
||||
"status": "start_process_prompt",
|
||||
"base_shard": base_shard.to_dict(),
|
||||
"shard": shard.to_dict(),
|
||||
"prompt": prompt,
|
||||
"image_str": image_str,
|
||||
"inference_state": inference_state,
|
||||
"request_id": request_id,
|
||||
}
|
||||
),
|
||||
json.dumps({
|
||||
"type": "node_status",
|
||||
"node_id": self.id,
|
||||
"status": "start_process_prompt",
|
||||
"base_shard": base_shard.to_dict(),
|
||||
"shard": shard.to_dict(),
|
||||
"prompt": prompt,
|
||||
"image_str": image_str,
|
||||
"inference_state": inference_state,
|
||||
"request_id": request_id,
|
||||
}),
|
||||
)
|
||||
)
|
||||
start_time = time.perf_counter_ns()
|
||||
@ -106,21 +104,19 @@ class StandardNode(Node):
|
||||
asyncio.create_task(
|
||||
self.broadcast_opaque_status(
|
||||
request_id,
|
||||
json.dumps(
|
||||
{
|
||||
"type": "node_status",
|
||||
"node_id": self.id,
|
||||
"status": "end_process_prompt",
|
||||
"base_shard": base_shard.to_dict(),
|
||||
"shard": shard.to_dict(),
|
||||
"prompt": prompt,
|
||||
"image_str": image_str,
|
||||
"inference_state": inference_state,
|
||||
"request_id": request_id,
|
||||
"elapsed_time_ns": elapsed_time_ns,
|
||||
"result_size": resp.size if resp is not None else 0,
|
||||
}
|
||||
),
|
||||
json.dumps({
|
||||
"type": "node_status",
|
||||
"node_id": self.id,
|
||||
"status": "end_process_prompt",
|
||||
"base_shard": base_shard.to_dict(),
|
||||
"shard": shard.to_dict(),
|
||||
"prompt": prompt,
|
||||
"image_str": image_str,
|
||||
"inference_state": inference_state,
|
||||
"request_id": request_id,
|
||||
"elapsed_time_ns": elapsed_time_ns,
|
||||
"result_size": resp.size if resp is not None else 0,
|
||||
}),
|
||||
)
|
||||
)
|
||||
return resp
|
||||
@ -166,19 +162,17 @@ class StandardNode(Node):
|
||||
asyncio.create_task(
|
||||
self.broadcast_opaque_status(
|
||||
request_id,
|
||||
json.dumps(
|
||||
{
|
||||
"type": "node_status",
|
||||
"node_id": self.id,
|
||||
"status": "start_process_tensor",
|
||||
"base_shard": base_shard.to_dict(),
|
||||
"shard": shard.to_dict(),
|
||||
"tensor_size": tensor.size,
|
||||
"tensor_shape": tensor.shape,
|
||||
"request_id": request_id,
|
||||
"inference_state": inference_state,
|
||||
}
|
||||
),
|
||||
json.dumps({
|
||||
"type": "node_status",
|
||||
"node_id": self.id,
|
||||
"status": "start_process_tensor",
|
||||
"base_shard": base_shard.to_dict(),
|
||||
"shard": shard.to_dict(),
|
||||
"tensor_size": tensor.size,
|
||||
"tensor_shape": tensor.shape,
|
||||
"request_id": request_id,
|
||||
"inference_state": inference_state,
|
||||
}),
|
||||
)
|
||||
)
|
||||
start_time = time.perf_counter_ns()
|
||||
@ -188,18 +182,16 @@ class StandardNode(Node):
|
||||
asyncio.create_task(
|
||||
self.broadcast_opaque_status(
|
||||
request_id,
|
||||
json.dumps(
|
||||
{
|
||||
"type": "node_status",
|
||||
"node_id": self.id,
|
||||
"status": "end_process_tensor",
|
||||
"base_shard": base_shard.to_dict(),
|
||||
"shard": shard.to_dict(),
|
||||
"request_id": request_id,
|
||||
"elapsed_time_ns": elapsed_time_ns,
|
||||
"result_size": resp.size if resp is not None else 0,
|
||||
}
|
||||
),
|
||||
json.dumps({
|
||||
"type": "node_status",
|
||||
"node_id": self.id,
|
||||
"status": "end_process_tensor",
|
||||
"base_shard": base_shard.to_dict(),
|
||||
"shard": shard.to_dict(),
|
||||
"request_id": request_id,
|
||||
"elapsed_time_ns": elapsed_time_ns,
|
||||
"result_size": resp.size if resp is not None else 0,
|
||||
}),
|
||||
)
|
||||
)
|
||||
return resp
|
||||
@ -257,7 +249,7 @@ class StandardNode(Node):
|
||||
current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
|
||||
if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
|
||||
if current_partition_index is not None:
|
||||
next_partition_index = (current_partition_index + 1) % len(partitions)
|
||||
next_partition_index = (current_partition_index+1) % len(partitions)
|
||||
next_partition: Partition = partitions[next_partition_index]
|
||||
next_shard = shards[next_partition_index]
|
||||
if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
|
||||
|
@ -24,6 +24,6 @@ def start_metrics_server(node: Node, port: int):
|
||||
elif status == "end_process_tensor":
|
||||
elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
|
||||
PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
|
||||
PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns / 1e9) # Convert ns to seconds
|
||||
PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9) # Convert ns to seconds
|
||||
|
||||
node.on_opaque_status.register("stats").on_next(_on_opaque_status)
|
||||
|
@ -44,78 +44,78 @@ CHIP_FLOPS = {
|
||||
# Source: https://www.cpu-monkey.com
|
||||
# Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
|
||||
### M chips
|
||||
"Apple M1": DeviceFlops(fp32=2.29 * TFLOPS, fp16=4.58 * TFLOPS, int8=9.16 * TFLOPS),
|
||||
"Apple M1 Pro": DeviceFlops(fp32=5.30 * TFLOPS, fp16=10.60 * TFLOPS, int8=21.20 * TFLOPS),
|
||||
"Apple M1 Max": DeviceFlops(fp32=10.60 * TFLOPS, fp16=21.20 * TFLOPS, int8=42.40 * TFLOPS),
|
||||
"Apple M1 Ultra": DeviceFlops(fp32=21.20 * TFLOPS, fp16=42.40 * TFLOPS, int8=84.80 * TFLOPS),
|
||||
"Apple M2": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
|
||||
"Apple M2 Pro": DeviceFlops(fp32=5.68 * TFLOPS, fp16=11.36 * TFLOPS, int8=22.72 * TFLOPS),
|
||||
"Apple M2 Max": DeviceFlops(fp32=13.49 * TFLOPS, fp16=26.98 * TFLOPS, int8=53.96 * TFLOPS),
|
||||
"Apple M2 Ultra": DeviceFlops(fp32=26.98 * TFLOPS, fp16=53.96 * TFLOPS, int8=107.92 * TFLOPS),
|
||||
"Apple M3": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
|
||||
"Apple M3 Max": DeviceFlops(fp32=14.20 * TFLOPS, fp16=28.40 * TFLOPS, int8=56.80 * TFLOPS),
|
||||
"Apple M3 Pro": DeviceFlops(fp32=4.97 * TFLOPS, fp16=9.94 * TFLOPS, int8=19.88 * TFLOPS),
|
||||
"Apple M4": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
|
||||
"Apple M1": DeviceFlops(fp32=2.29*TFLOPS, fp16=4.58*TFLOPS, int8=9.16*TFLOPS),
|
||||
"Apple M1 Pro": DeviceFlops(fp32=5.30*TFLOPS, fp16=10.60*TFLOPS, int8=21.20*TFLOPS),
|
||||
"Apple M1 Max": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.20*TFLOPS, int8=42.40*TFLOPS),
|
||||
"Apple M1 Ultra": DeviceFlops(fp32=21.20*TFLOPS, fp16=42.40*TFLOPS, int8=84.80*TFLOPS),
|
||||
"Apple M2": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
|
||||
"Apple M2 Pro": DeviceFlops(fp32=5.68*TFLOPS, fp16=11.36*TFLOPS, int8=22.72*TFLOPS),
|
||||
"Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
|
||||
"Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
|
||||
"Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
|
||||
"Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
|
||||
"Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
|
||||
"Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
|
||||
### A chips
|
||||
"Apple A13 Bionic": DeviceFlops(fp32=0.69 * TFLOPS, fp16=1.38 * TFLOPS, int8=2.76 * TFLOPS),
|
||||
"Apple A14 Bionic": DeviceFlops(fp32=0.75 * TFLOPS, fp16=1.50 * TFLOPS, int8=3.00 * TFLOPS),
|
||||
"Apple A15 Bionic": DeviceFlops(fp32=1.37 * TFLOPS, fp16=2.74 * TFLOPS, int8=5.48 * TFLOPS),
|
||||
"Apple A16 Bionic": DeviceFlops(fp32=1.79 * TFLOPS, fp16=3.58 * TFLOPS, int8=7.16 * TFLOPS),
|
||||
"Apple A17 Pro": DeviceFlops(fp32=2.15 * TFLOPS, fp16=4.30 * TFLOPS, int8=8.60 * TFLOPS),
|
||||
"Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
|
||||
"Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),
|
||||
"Apple A15 Bionic": DeviceFlops(fp32=1.37*TFLOPS, fp16=2.74*TFLOPS, int8=5.48*TFLOPS),
|
||||
"Apple A16 Bionic": DeviceFlops(fp32=1.79*TFLOPS, fp16=3.58*TFLOPS, int8=7.16*TFLOPS),
|
||||
"Apple A17 Pro": DeviceFlops(fp32=2.15*TFLOPS, fp16=4.30*TFLOPS, int8=8.60*TFLOPS),
|
||||
### NVIDIA GPUs
|
||||
# RTX 40 series
|
||||
"NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58 * TFLOPS, fp16=165.16 * TFLOPS, int8=330.32 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4080": DeviceFlops(fp32=48.74 * TFLOPS, fp16=97.48 * TFLOPS, int8=194.96 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4080 SUPER": DeviceFlops(fp32=52.0 * TFLOPS, fp16=104.0 * TFLOPS, int8=208.0 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4070 TI SUPER": DeviceFlops(fp32=40.0 * TFLOPS, fp16=80.0 * TFLOPS, int8=160.0 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4070 TI": DeviceFlops(fp32=39.43 * TFLOPS, fp16=78.86 * TFLOPS, int8=157.72 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0 * TFLOPS, fp16=60.0 * TFLOPS, int8=120.0 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0 * TFLOPS, fp16=58.0 * TFLOPS, int8=116.0 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0 * TFLOPS, fp16=44.0 * TFLOPS, int8=88.0 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58*TFLOPS, fp16=165.16*TFLOPS, int8=330.32*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4080": DeviceFlops(fp32=48.74*TFLOPS, fp16=97.48*TFLOPS, int8=194.96*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4080 SUPER": DeviceFlops(fp32=52.0*TFLOPS, fp16=104.0*TFLOPS, int8=208.0*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4070 TI SUPER": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4070 TI": DeviceFlops(fp32=39.43*TFLOPS, fp16=78.86*TFLOPS, int8=157.72*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0*TFLOPS, fp16=60.0*TFLOPS, int8=120.0*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0*TFLOPS, fp16=58.0*TFLOPS, int8=116.0*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS),
|
||||
# RTX 30 series
|
||||
"NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11 * TFLOPS, fp16=18.22 * TFLOPS, int8=36.44 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0 * TFLOPS, fp16=26.0 * TFLOPS, int8=52.0 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3060 TI": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3 * TFLOPS, fp16=40.6 * TFLOPS, int8=81.2 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8 * TFLOPS, fp16=43.6 * TFLOPS, int8=87.2 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3080 (10 GB)": DeviceFlops(fp32=29.8 * TFLOPS, fp16=59.6 * TFLOPS, int8=119.2 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3080 (12 GB)": DeviceFlops(fp32=30.6 * TFLOPS, fp16=61.2 * TFLOPS, int8=122.4 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3080 TI": DeviceFlops(fp32=34.1 * TFLOPS, fp16=68.2 * TFLOPS, int8=136.4 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3090": DeviceFlops(fp32=35.6 * TFLOPS, fp16=71.2 * TFLOPS, int8=142.4 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3090 TI": DeviceFlops(fp32=40.0 * TFLOPS, fp16=80.0 * TFLOPS, int8=160.0 * TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11*TFLOPS, fp16=18.22*TFLOPS, int8=36.44*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0*TFLOPS, fp16=26.0*TFLOPS, int8=52.0*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3060 TI": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3*TFLOPS, fp16=40.6*TFLOPS, int8=81.2*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8*TFLOPS, fp16=43.6*TFLOPS, int8=87.2*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3080 (10 GB)": DeviceFlops(fp32=29.8*TFLOPS, fp16=59.6*TFLOPS, int8=119.2*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3080 (12 GB)": DeviceFlops(fp32=30.6*TFLOPS, fp16=61.2*TFLOPS, int8=122.4*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3080 TI": DeviceFlops(fp32=34.1*TFLOPS, fp16=68.2*TFLOPS, int8=136.4*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3090": DeviceFlops(fp32=35.6*TFLOPS, fp16=71.2*TFLOPS, int8=142.4*TFLOPS),
|
||||
"NVIDIA GEFORCE RTX 3090 TI": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
|
||||
# QUATRO RTX Ampere series
|
||||
"NVIDIA QUATRO RTX A2000": DeviceFlops(fp32=7.99 * TFLOPS, fp16=7.99 * TFLOPS, int8=31.91 * TFLOPS),
|
||||
"NVIDIA QUATRO RTX A4000": DeviceFlops(fp32=19.17 * TFLOPS, fp16=19.17 * TFLOPS, int8=76.68 * TFLOPS),
|
||||
"NVIDIA QUATRO RTX A4500": DeviceFlops(fp32=23.65 * TFLOPS, fp16=23.65 * TFLOPS, int8=94.6 * TFLOPS),
|
||||
"NVIDIA QUATRO RTX A5000": DeviceFlops(fp32=27.8 * TFLOPS, fp16=27.8 * TFLOPS, int8=111.2 * TFLOPS),
|
||||
"NVIDIA QUATRO RTX A6000": DeviceFlops(fp32=38.71 * TFLOPS, fp16=38.71 * TFLOPS, int8=154.84 * TFLOPS),
|
||||
"NVIDIA QUATRO RTX A2000": DeviceFlops(fp32=7.99*TFLOPS, fp16=7.99*TFLOPS, int8=31.91*TFLOPS),
|
||||
"NVIDIA QUATRO RTX A4000": DeviceFlops(fp32=19.17*TFLOPS, fp16=19.17*TFLOPS, int8=76.68*TFLOPS),
|
||||
"NVIDIA QUATRO RTX A4500": DeviceFlops(fp32=23.65*TFLOPS, fp16=23.65*TFLOPS, int8=94.6*TFLOPS),
|
||||
"NVIDIA QUATRO RTX A5000": DeviceFlops(fp32=27.8*TFLOPS, fp16=27.8*TFLOPS, int8=111.2*TFLOPS),
|
||||
"NVIDIA QUATRO RTX A6000": DeviceFlops(fp32=38.71*TFLOPS, fp16=38.71*TFLOPS, int8=154.84*TFLOPS),
|
||||
# Common Server GPUs
|
||||
"NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4 * TFLOPS, fp16=149.7 * TFLOPS, int8=299.3 * TFLOPS),
|
||||
"NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
|
||||
"NVIDIA A800 40GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
|
||||
"NVIDIA A100 80GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
|
||||
"NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
|
||||
"NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
|
||||
"NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
|
||||
"NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4*TFLOPS, fp16=149.7*TFLOPS, int8=299.3*TFLOPS),
|
||||
"NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
|
||||
"NVIDIA A800 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
|
||||
"NVIDIA A100 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
|
||||
"NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
|
||||
"NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
|
||||
"NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
|
||||
# ... add more devices if needed ...
|
||||
### AMD GPUs
|
||||
# RX 6000 series
|
||||
"AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04 * TFLOPS, fp16=46.08 * TFLOPS, int8=92.16 * TFLOPS),
|
||||
"AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74 * TFLOPS, fp16=41.48 * TFLOPS, int8=82.96 * TFLOPS),
|
||||
"AMD Radeon RX 6800": DeviceFlops(fp32=16.17 * TFLOPS, fp16=32.34 * TFLOPS, int8=64.68 * TFLOPS),
|
||||
"AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21 * TFLOPS, fp16=26.42 * TFLOPS, int8=52.84 * TFLOPS),
|
||||
"AMD Radeon RX 6700": DeviceFlops(fp32=11.4 * TFLOPS, fp16=22.8 * TFLOPS, int8=45.6 * TFLOPS),
|
||||
"AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6 * TFLOPS, fp16=21.2 * TFLOPS, int8=42.4 * TFLOPS),
|
||||
"AMD Radeon RX 6600": DeviceFlops(fp32=8.93 * TFLOPS, fp16=17.86 * TFLOPS, int8=35.72 * TFLOPS),
|
||||
"AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77 * TFLOPS, fp16=11.54 * TFLOPS, int8=23.08 * TFLOPS),
|
||||
"AMD Radeon RX 6400": DeviceFlops(fp32=3.57 * TFLOPS, fp16=7.14 * TFLOPS, int8=14.28 * TFLOPS),
|
||||
"AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04*TFLOPS, fp16=46.08*TFLOPS, int8=92.16*TFLOPS),
|
||||
"AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74*TFLOPS, fp16=41.48*TFLOPS, int8=82.96*TFLOPS),
|
||||
"AMD Radeon RX 6800": DeviceFlops(fp32=16.17*TFLOPS, fp16=32.34*TFLOPS, int8=64.68*TFLOPS),
|
||||
"AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21*TFLOPS, fp16=26.42*TFLOPS, int8=52.84*TFLOPS),
|
||||
"AMD Radeon RX 6700": DeviceFlops(fp32=11.4*TFLOPS, fp16=22.8*TFLOPS, int8=45.6*TFLOPS),
|
||||
"AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6*TFLOPS, fp16=21.2*TFLOPS, int8=42.4*TFLOPS),
|
||||
"AMD Radeon RX 6600": DeviceFlops(fp32=8.93*TFLOPS, fp16=17.86*TFLOPS, int8=35.72*TFLOPS),
|
||||
"AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77*TFLOPS, fp16=11.54*TFLOPS, int8=23.08*TFLOPS),
|
||||
"AMD Radeon RX 6400": DeviceFlops(fp32=3.57*TFLOPS, fp16=7.14*TFLOPS, int8=14.28*TFLOPS),
|
||||
# RX 7000 series
|
||||
"AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4 * TFLOPS, fp16=122.8 * TFLOPS, int8=245.6 * TFLOPS),
|
||||
"AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4 * TFLOPS, fp16=106.8 * TFLOPS, int8=213.6 * TFLOPS),
|
||||
"AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6 * TFLOPS, fp16=85.2 * TFLOPS, int8=170.4 * TFLOPS),
|
||||
"AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2 * TFLOPS, fp16=68.4 * TFLOPS, int8=136.8 * TFLOPS),
|
||||
"AMD Radeon RX 7600": DeviceFlops(fp32=21.5 * TFLOPS, fp16=43.0 * TFLOPS, int8=86.0 * TFLOPS),
|
||||
"AMD Radeon RX 7500": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS),
|
||||
"AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4*TFLOPS, fp16=122.8*TFLOPS, int8=245.6*TFLOPS),
|
||||
"AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4*TFLOPS, fp16=106.8*TFLOPS, int8=213.6*TFLOPS),
|
||||
"AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6*TFLOPS, fp16=85.2*TFLOPS, int8=170.4*TFLOPS),
|
||||
"AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2*TFLOPS, fp16=68.4*TFLOPS, int8=136.8*TFLOPS),
|
||||
"AMD Radeon RX 7600": DeviceFlops(fp32=21.5*TFLOPS, fp16=43.0*TFLOPS, int8=86.0*TFLOPS),
|
||||
"AMD Radeon RX 7500": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
|
||||
# ... add more devices if needed ...
|
||||
### Qualcomm embedded chips: TODO
|
||||
}
|
||||
@ -151,7 +151,7 @@ def mac_device_capabilities() -> DeviceCapabilities:
|
||||
memory_units = memory_str.split()
|
||||
memory_value = int(memory_units[0])
|
||||
if memory_units[1] == "GB":
|
||||
memory = memory_value * 1024
|
||||
memory = memory_value*1024
|
||||
else:
|
||||
memory = memory_value
|
||||
|
||||
|
@ -22,8 +22,8 @@ class PartitioningStrategy(ABC):
|
||||
def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
|
||||
shards = []
|
||||
for i, partition in enumerate(partitions):
|
||||
start_layer = int(partition.start * num_layers)
|
||||
end_layer = int(partition.end * num_layers) - 1
|
||||
start_layer = int(partition.start*num_layers)
|
||||
end_layer = int(partition.end*num_layers) - 1
|
||||
|
||||
# Ensure the last partition covers up to num_layers - 1
|
||||
if i == len(partitions) - 1:
|
||||
|
@ -12,7 +12,7 @@ class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
|
||||
partitions = []
|
||||
start = 0
|
||||
for node in nodes:
|
||||
end = round(start + (node[1].memory / total_memory), 5)
|
||||
end = round(start + (node[1].memory/total_memory), 5)
|
||||
partitions.append(Partition(node[0], start, end))
|
||||
start = end
|
||||
return partitions
|
||||
|
@ -80,7 +80,7 @@ Activation Lock Status: Disabled
|
||||
self.assertEqual(result.model, "MacBook Pro")
|
||||
self.assertEqual(result.chip, "Apple M3 Max")
|
||||
self.assertEqual(result.memory, 131072) # 128 GB in MB
|
||||
self.assertEqual(result.flops, DeviceFlops(fp32=14.20 * TFLOPS, fp16=28.40 * TFLOPS, int8=56.80 * TFLOPS))
|
||||
self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS))
|
||||
self.assertEqual(
|
||||
str(result),
|
||||
"Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
|
||||
|
@ -56,8 +56,8 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
|
||||
def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str):
|
||||
shards = []
|
||||
for i, partition in enumerate(partitions):
|
||||
start_layer = int(partition.start * num_layers)
|
||||
end_layer = int(partition.end * num_layers) - 1
|
||||
start_layer = int(partition.start*num_layers)
|
||||
end_layer = int(partition.end*num_layers) - 1
|
||||
shards.append(Shard(model_id, start_layer, end_layer, num_layers))
|
||||
return shards
|
||||
|
||||
|
@ -49,7 +49,7 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
|
||||
DeviceCapabilities(
|
||||
model="MacBook Pro",
|
||||
chip="test1",
|
||||
memory=128 * 1024 * 1024 * 1024,
|
||||
memory=128*1024*1024*1024,
|
||||
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
|
||||
),
|
||||
)
|
||||
@ -58,7 +58,7 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
|
||||
DeviceCapabilities(
|
||||
model="Mac Studio",
|
||||
chip="test2",
|
||||
memory=192 * 1024 * 1024 * 1024,
|
||||
memory=192*1024*1024*1024,
|
||||
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
|
||||
),
|
||||
)
|
||||
@ -67,7 +67,7 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
|
||||
DeviceCapabilities(
|
||||
model="MacBook Pro",
|
||||
chip="test3",
|
||||
memory=128 * 1024 * 1024 * 1024,
|
||||
memory=128*1024*1024*1024,
|
||||
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
|
||||
),
|
||||
)
|
||||
|
@ -66,19 +66,19 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
|
||||
self.topology = Topology()
|
||||
self.topology.update_node(
|
||||
"node1",
|
||||
DeviceCapabilities(model="ModelA", chip="ChipA", memory=8 * 1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)),
|
||||
DeviceCapabilities(model="ModelA", chip="ChipA", memory=8*1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)),
|
||||
)
|
||||
self.topology.update_node(
|
||||
"node2",
|
||||
DeviceCapabilities(model="ModelB", chip="ChipB", memory=16 * 1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)),
|
||||
DeviceCapabilities(model="ModelB", chip="ChipB", memory=16*1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)),
|
||||
)
|
||||
self.topology.update_node(
|
||||
"node3",
|
||||
DeviceCapabilities(model="ModelC", chip="ChipC", memory=32 * 1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)),
|
||||
DeviceCapabilities(model="ModelC", chip="ChipC", memory=32*1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)),
|
||||
)
|
||||
self.topology.update_node(
|
||||
"node4",
|
||||
DeviceCapabilities(model="ModelD", chip="ChipD", memory=64 * 1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)),
|
||||
DeviceCapabilities(model="ModelD", chip="ChipD", memory=64*1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)),
|
||||
)
|
||||
|
||||
self.top_viz = TopologyViz()
|
||||
|
@ -99,7 +99,7 @@ class TopologyViz:
|
||||
# Process prompt
|
||||
prompt_lines = prompt.split('\n')
|
||||
if len(prompt_lines) > max_lines // 2:
|
||||
prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
|
||||
prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...']
|
||||
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
|
||||
prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
|
||||
|
||||
@ -139,7 +139,7 @@ class TopologyViz:
|
||||
max_line_length = max(len(line) for line in exo_lines)
|
||||
for i, line in enumerate(exo_lines):
|
||||
centered_line = line.center(max_line_length)
|
||||
start_x = (100 - max_line_length) // 2 + 15
|
||||
start_x = (100-max_line_length) // 2 + 15
|
||||
colored_line = Text(centered_line, style=yellow_style)
|
||||
for j, char in enumerate(str(colored_line)):
|
||||
if 0 <= start_x + j < 100 and i < len(visualization):
|
||||
@ -161,18 +161,18 @@ class TopologyViz:
|
||||
|
||||
# Calculate total FLOPS and position on the bar
|
||||
total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
|
||||
bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
|
||||
bar_pos = (math.tanh(total_flops/20 - 2) + 1)/2
|
||||
|
||||
# Add GPU poor/rich bar
|
||||
bar_width = 30
|
||||
bar_start_x = (100 - bar_width) // 2
|
||||
bar_start_x = (100-bar_width) // 2
|
||||
bar_y = info_start_y + len(info_lines) + 1
|
||||
|
||||
# Create a gradient bar using emojis
|
||||
gradient_bar = Text()
|
||||
emojis = ["🟥", "🟧", "🟨", "🟩"]
|
||||
for i in range(bar_width):
|
||||
emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
|
||||
emoji_index = min(int(i/(bar_width/len(emojis))), len(emojis) - 1)
|
||||
gradient_bar.append(emojis[emoji_index])
|
||||
|
||||
# Add the gradient bar to the visualization
|
||||
@ -183,10 +183,10 @@ class TopologyViz:
|
||||
|
||||
# Add labels
|
||||
visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = "GPU poor"
|
||||
visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2:bar_start_x + bar_width * 2 + 11] = "GPU rich"
|
||||
visualization[bar_y - 1][bar_start_x + bar_width*2 + 2:bar_start_x + bar_width*2 + 11] = "GPU rich"
|
||||
|
||||
# Add position indicator and FLOPS value
|
||||
pos_x = bar_start_x + int(bar_pos * bar_width)
|
||||
pos_x = bar_start_x + int(bar_pos*bar_width)
|
||||
flops_str = f"{total_flops:.2f} TFLOPS"
|
||||
visualization[bar_y - 1][pos_x] = "▼"
|
||||
visualization[bar_y + 1][pos_x - len(flops_str) // 2:pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
|
||||
@ -198,9 +198,9 @@ class TopologyViz:
|
||||
for i, partition in enumerate(self.partitions):
|
||||
device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
|
||||
|
||||
angle = 2 * math.pi * i / num_partitions
|
||||
x = int(center_x + radius_x * math.cos(angle))
|
||||
y = int(center_y + radius_y * math.sin(angle))
|
||||
angle = 2*math.pi*i/num_partitions
|
||||
x = int(center_x + radius_x*math.cos(angle))
|
||||
y = int(center_y + radius_y*math.sin(angle))
|
||||
|
||||
# Place node with different color for active node and this node
|
||||
if partition.node_id == self.topology.active_node_id:
|
||||
@ -220,8 +220,8 @@ class TopologyViz:
|
||||
# Calculate info position based on angle
|
||||
info_distance_x = radius_x + 6
|
||||
info_distance_y = radius_y + 3
|
||||
info_x = int(center_x + info_distance_x * math.cos(angle))
|
||||
info_y = int(center_y + info_distance_y * math.sin(angle))
|
||||
info_x = int(center_x + info_distance_x*math.cos(angle))
|
||||
info_y = int(center_y + info_distance_y*math.sin(angle))
|
||||
|
||||
# Adjust text position to avoid overwriting the node icon and prevent cutoff
|
||||
if info_x < x:
|
||||
@ -230,9 +230,9 @@ class TopologyViz:
|
||||
info_x = min(99 - len(max(node_info, key=len)), info_x)
|
||||
|
||||
# Adjust for top and bottom nodes
|
||||
if 5 * math.pi / 4 < angle < 7 * math.pi / 4:
|
||||
if 5*math.pi/4 < angle < 7*math.pi/4:
|
||||
info_x += 4
|
||||
elif math.pi / 4 < angle < 3 * math.pi / 4:
|
||||
elif math.pi/4 < angle < 3*math.pi/4:
|
||||
info_x += 3
|
||||
info_y -= 2
|
||||
|
||||
@ -243,16 +243,16 @@ class TopologyViz:
|
||||
visualization[info_y + j][info_x + k] = char
|
||||
|
||||
# Draw line to next node
|
||||
next_i = (i + 1) % num_partitions
|
||||
next_angle = 2 * math.pi * next_i / num_partitions
|
||||
next_x = int(center_x + radius_x * math.cos(next_angle))
|
||||
next_y = int(center_y + radius_y * math.sin(next_angle))
|
||||
next_i = (i+1) % num_partitions
|
||||
next_angle = 2*math.pi*next_i/num_partitions
|
||||
next_x = int(center_x + radius_x*math.cos(next_angle))
|
||||
next_y = int(center_y + radius_y*math.sin(next_angle))
|
||||
|
||||
# Simple line drawing
|
||||
steps = max(abs(next_x - x), abs(next_y - y))
|
||||
for step in range(1, steps):
|
||||
line_x = int(x + (next_x - x) * step / steps)
|
||||
line_y = int(y + (next_y - y) * step / steps)
|
||||
line_x = int(x + (next_x-x)*step/steps)
|
||||
line_y = int(y + (next_y-y)*step/steps)
|
||||
if 0 <= line_y < 48 and 0 <= line_x < 100:
|
||||
visualization[line_y][line_x] = "-"
|
||||
|
||||
@ -280,7 +280,7 @@ class TopologyViz:
|
||||
|
||||
for file_path, file_progress in download_progress.file_progress.items():
|
||||
if file_progress.status != "complete":
|
||||
progress = int(file_progress.downloaded / file_progress.total * 30)
|
||||
progress = int(file_progress.downloaded/file_progress.total*30)
|
||||
bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
|
||||
percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
|
||||
summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
|
||||
@ -294,7 +294,7 @@ class TopologyViz:
|
||||
device = self.topology.nodes.get(node_id)
|
||||
partition = next((p for p in self.partitions if p.node_id == node_id), None)
|
||||
partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
|
||||
percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
|
||||
percentage = progress.downloaded_bytes/progress.total_bytes*100 if progress.total_bytes > 0 else 0
|
||||
speed = pretty_print_bytes_per_second(progress.overall_speed)
|
||||
device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
|
||||
progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
|
||||
|
Loading…
Reference in New Issue
Block a user