diff --git a/ROCm_performance.md b/ROCm_performance.md index c6b7a0ec77a03..d9a3d854c129f 100644 --- a/ROCm_performance.md +++ b/ROCm_performance.md @@ -48,7 +48,7 @@ To obtain all the shapes of gemms during the execution of the model, set the env Next, run gradlib to obtain the best solutions of these shapes: ``` -python3 gradlib/gradlib/fp8_gemm_tuner.py --input_file /tmp/fp8_shapes.csv --tuned_file /tmp/tuned_fp8_16.csv +python3 gradlib/gradlib/gemm_tuner.py --input_file /tmp/fp8_shapes.csv --tuned_file /tmp/tuned_fp8_16.csv --indtype fp8 --outdtype f16 ``` where `/tmp/tuned_fp8_16` will be used by our fp8 gemm linear layer. diff --git a/gradlib/csrc/hipbsolgemm.cu b/gradlib/csrc/hipbsolgemm.cu index 7888abb6e923c..34e95ba34dbb4 100644 --- a/gradlib/csrc/hipbsolgemm.cu +++ b/gradlib/csrc/hipbsolgemm.cu @@ -208,12 +208,18 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper( matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA, sizeof(scaleA))); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB, sizeof(scaleB))); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scaleC, sizeof(scaleC))); + if (scaleA != nullptr) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA, sizeof(scaleA))); + } + if (scaleB != nullptr) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB, sizeof(scaleB))); + } + if (scaleC != nullptr) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scaleC, sizeof(scaleC))); + } // nvtxRangePop(); // if heuristic does not exist in the map, do search and push into the map // auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; diff --git a/gradlib/gradlib/GemmTuner.py b/gradlib/gradlib/GemmTuner.py index 31de84cbe2052..eb5a20f0d2a29 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/gradlib/GemmTuner.py @@ -12,33 +12,33 @@ rtol = 1e-5 atol = 1 -dtype = torch.float16 class Gemm: - def __init__(self, m, n, k, dtype, rocblas_decode=False): + def __init__(self, m, n, k, indtype, outdtype, rocblas_decode=False): self.m = m self.k = k self.n = n - self.dtype = dtype + self.indtype = indtype + self.outdtype = outdtype + self.use_rocblas = (indtype == outdtype + and indtype is not torch.float8_e4m3fnuz) self.nb = 37 self.inp = torch.randn((self.n, self.k), - dtype=self.dtype, - device='cuda') + device='cuda').to(self.indtype) self.weights = torch.randn((self.m, self.k), - dtype=self.dtype, - device='cuda') + device='cuda').to(self.indtype) # weights2 is used in measurement/warm iters to ensure # HBM fetch for weight tensors self.weights2 = torch.randn((self.nb, self.m, self.k), - dtype=self.dtype, - device='cuda') + device='cuda').to(self.indtype) self.blob = torch.ones(128 * 1024 * 1024, dtype=torch.float32, device='cuda') self.topn = 20 #number of top solutions from each source self.hipb_sols = [] + self.rocb_sols = [] self.rtol = 1e-5 self.atol = 1 self.start = torch.cuda.Event(enable_timing=True) @@ -49,7 +49,8 @@ def __init__(self, m, n, k, dtype, rocblas_decode=False): self.rocblas_decode = rocblas_decode def find_hipblas_sols(self): - sols = 
hipbsolidxgemm.hipb_findallsols(self.inp, self.weights.t()) + sols = hipbsolidxgemm.hipb_findallsols(self.inp, self.weights.t(), + self.outdtype) print('M N K', self.m, self.n, @@ -61,34 +62,37 @@ def find_hipblas_sols(self): self.hipb_sols = sols def check_gemm_ref(self, libtype, solidx): - ref = F.linear(self.inp, self.weights) + ref = F.linear(self.inp.to(torch.float32), + self.weights.to(torch.float32)).to(self.outdtype) if libtype == 'hipblaslt': - c = hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx) + c = hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx, + self.outdtype) elif libtype == 'rocblas': c = rocsolidxgemm.rocb_mm(self.inp, self.weights.t(), solidx) if torch.allclose(c, ref, atol=self.atol, rtol=self.rtol): #print('>>>',libtype,'Solidx',solidx,'passed reference test') return True - else: - print('>>>', - libtype, - 'Solidx', - solidx, - 'FAILED reference test', - flush=True) - print(ref, flush=True) - print(c, flush=True) - return False + + print('>>>', + libtype, + 'Solidx', + solidx, + 'FAILED reference test', + flush=True) + print(ref, flush=True) + print(c, flush=True) + return False def hipb_time_sol(self, solidx, cold_iters=2, warm_iters=10): #print('>>>hipbtime',solidx) for i in range(cold_iters): - hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx) + hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx, + self.outdtype) self.start.record() for i in range(warm_iters): hipbsolidxgemm.hipb_mm( self.inp, self.weights2[random.randint(0, self.nb - 1)].t(), - solidx) + solidx, self.outdtype) self.end.record() torch.cuda.synchronize() gtime = self.start.elapsed_time(self.end) / warm_iters @@ -178,7 +182,8 @@ def functional_check_topn_fastest(self): self.hipb_top_sols = hipb_topn def find_fastest_solution(self): - self.find_rocblas_sols() + if self.use_rocblas: + self.find_rocblas_sols() if not (self.rocblas_decode and self.n == 1): self.find_hipblas_sols() self.warmup() @@ -228,9 +233,14 @@ def find_fastest_solution(self): class GemmTuner: - def __init__(self, dtype, tuned_file=None, rocblas_decode=False): + def __init__(self, + indtype, + outdtype, + tuned_file=None, + rocblas_decode=False): self.gemm_problems = pd.DataFrame(columns=['M', 'N', 'K']) - self.dtype = dtype + self.indtype = indtype + self.outdtype = outdtype self.rocblas_decode = rocblas_decode self.tuned_file = tuned_file if Path(tuned_file).is_file(): @@ -259,13 +269,15 @@ def find_best_sols(self): gemmobj = Gemm(ds['M'], ds['N'], ds['K'], - dtype=self.dtype, + indtype=self.indtype, + outdtype=self.outdtype, rocblas_decode=self.rocblas_decode) gemmobj.find_fastest_solution() soldf.loc[i, 'libtype'] = gemmobj.best_libtype soldf.loc[i, 'solidx'] = gemmobj.best_solidx soldf.loc[i, 'soltimems'] = gemmobj.best_soltime - soldf['dtype'] = self.dtype + soldf['indtype'] = self.indtype + soldf['outdtype'] = self.outdtype finaldf = pd.concat([self.gemm_problems, soldf], axis=1) finaldf = pd.concat([finaldf, self.gdf]) finaldf.to_csv(self.tuned_file, index=False) diff --git a/gradlib/gradlib/fp8_gemm_tuner.py b/gradlib/gradlib/fp8_gemm_tuner.py deleted file mode 100644 index babf304927605..0000000000000 --- a/gradlib/gradlib/fp8_gemm_tuner.py +++ /dev/null @@ -1,289 +0,0 @@ -import argparse -import json -import os -import random -from pathlib import Path - -import torch # isort: split -import hipbsolidxgemm -import pandas as pd -import torch.nn.functional as F - -hipbsolidxgemm.hipb_create_extension() - -rtol = 1e-5 -atol = 1 - - -class Fp8Gemm: - - def __init__(self, m, n, k, indtype, 
outdtype): - self.m = m - self.k = k - self.n = n - self.indtype = indtype - self.outdtype = outdtype - self.nb = 37 - self.inp = torch.randn((self.n, self.k), - device='cuda').to(self.indtype) - self.weights = torch.randn((self.m, self.k), - device='cuda').to(self.indtype) - # weights2 is used in measurement/warm iters to ensure HBM - # fetch for weight tensors - self.weights2 = torch.randn((self.nb, self.m, self.k), - device='cuda').to(self.indtype) - self.blob = torch.ones(128 * 1024 * 1024, - dtype=torch.float32, - device='cuda') - self.topn = 20 #number of top solutions from each source - self.hipb_sols = [] - self.rtol = 1e-5 - self.atol = 1 - self.start = torch.cuda.Event(enable_timing=True) - self.end = torch.cuda.Event(enable_timing=True) - - def find_hipblas_sols(self): - sols = hipbsolidxgemm.hipb_findallsols(self.inp, self.weights.t(), - self.outdtype) - print('M N K', - self.m, - self.n, - self.k, - '>>> Total hipb solutions', - len(sols), - flush=True) - #print(sols) - self.hipb_sols = sols - - def check_gemm_ref(self, libtype, solidx): - ref = F.linear(self.inp.to(torch.float32), - self.weights.to(torch.float32)).to(self.outdtype) - c = hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx, - self.outdtype) - if torch.allclose(c, ref, atol=self.atol, rtol=self.rtol): - #print('>>>',libtype,'Solidx',solidx,'passed reference test') - return True - else: - print('>>>', 'Solidx', solidx, 'FAILED reference test', flush=True) - print(ref, flush=True) - print(c, flush=True) - return False - - def hipb_time_sol(self, solidx, cold_iters=2, warm_iters=10): - #print('>>>hipbtime',solidx) - for i in range(cold_iters): - hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx, - self.outdtype) - self.start.record() - for i in range(warm_iters): - hipbsolidxgemm.hipb_mm( - self.inp, self.weights2[random.randint(0, self.nb - 1)].t(), - solidx, self.outdtype) - self.end.record() - torch.cuda.synchronize() - gtime = self.start.elapsed_time(self.end) / warm_iters - #print('>>> Solidx GTime',solidx,gtime,'ms') - return gtime - - def hipb_time_all_sols(self, fast_mode=0, top_sols=0): - coldi = 20 - warmi = 20 - if fast_mode: - coldi = 2 - warmi = 2 - solutions = self.hipb_sols - if top_sols: - solutions = self.hipb_top_sols - gtimes = {} - for solidx in solutions: - gtimes[solidx] = self.hipb_time_sol(solidx, - cold_iters=coldi, - warm_iters=warmi) - self.hipb_gtimedf = pd.DataFrame.from_dict( - gtimes, orient='index', - columns=['gtimems']).sort_values(by='gtimems') - self.hipb_gtimedf.to_csv('/tmp/hipb_gtimedf.csv') - print('>>> HipBlasLt top solutions, Fast Mode', fast_mode) - print(self.hipb_gtimedf.head(self.topn)) - - def warmup(self, warmi=500): - for i in range(warmi): - self.blob = self.blob + 0.00001 - - def functional_check_topn_fastest(self): - hipb_topn = [] - for solidx in self.hipb_gtimedf.index[:self.topn]: - if self.check_gemm_ref(libtype='hipblaslt', solidx=solidx): - hipb_topn.append(solidx) - self.hipb_top_sols = hipb_topn - - def find_fastest_solution(self): - self.find_hipblas_sols() - self.warmup() - self.hipb_time_all_sols(fast_mode=1) - self.functional_check_topn_fastest() - self.warmup() - self.hipb_time_all_sols(fast_mode=0, top_sols=1) - if len(self.hipb_gtimedf) > 0: - best_hipb_time = self.hipb_gtimedf.gtimems.iloc[0] - self.best_solidx = self.hipb_gtimedf.index[0] - self.best_soltime = best_hipb_time - else: - print('>>> No hipblas solutions found!', flush=True) - self.best_solidx = 0 - self.best_soltime = 0 - print('>>> Fastest Solution is', - 
self.best_solidx, - self.best_soltime, - flush=True) - - -class Fp8GemmTuner: - - def __init__(self, indtype, outdtype, tuned_file=None): - self.gemm_problems = pd.DataFrame(columns=['M', 'N', 'K']) - self.indtype = indtype - self.outdtype = outdtype - self.tuned_file = tuned_file - if Path(tuned_file).is_file(): - self.gdf = pd.read_csv(tuned_file) - else: - self.gdf = None - - def add_gemm(self, m, n, k): - if (self.gdf is None - or (self.gdf[(self.gdf['M'] == m) & (self.gdf['N'] == n) & - (self.gdf['K'] == k)].empty)): - entry = {'M': [m], 'N': [n], 'K': [k]} - df = pd.DataFrame(entry) - self.gemm_problems = pd.concat([self.gemm_problems, df], - ignore_index=True) - else: - print( - f">>>Info: Found Duplicate shape(M:{m}, N:{n}, K:{k}), skipping" - ) - - def find_best_sols(self): - df = self.gemm_problems - soldf = pd.DataFrame() - for i in range(len(df)): - ds = df.iloc[i] - gemmobj = Fp8Gemm(ds['M'], - ds['N'], - ds['K'], - indtype=self.indtype, - outdtype=self.outdtype) - gemmobj.find_fastest_solution() - soldf.loc[i, 'solidx'] = gemmobj.best_solidx - soldf.loc[i, 'soltimems'] = gemmobj.best_soltime - soldf['indtype'] = self.indtype - soldf['outdtype'] = self.outdtype - finaldf = pd.concat([self.gemm_problems, soldf], axis=1) - finaldf = pd.concat([finaldf, self.gdf]) - finaldf.to_csv(self.tuned_file, index=False) - print(finaldf) - - -def generate_mk_sets(model_dir, tp=1): - with open(f'{model_dir}/config.json') as f: - data = json.load(f) - hidden_size = data['hidden_size'] - intermediate_size = data['intermediate_size'] - total_num_heads = data['num_attention_heads'] - total_num_kv_heads = data['num_key_value_heads'] - head_dim = hidden_size // total_num_heads - return [((total_num_heads + (2 * total_num_kv_heads)) * head_dim // tp, - hidden_size), (hidden_size, hidden_size // tp), - (intermediate_size * 2 // tp, hidden_size), - (hidden_size, intermediate_size // tp)], hidden_size - - -def get_dtype(dtype_str): - dtype = torch.float8_e4m3fnuz - if dtype_str == 'f32': - dtype = torch.float32 - elif dtype_str == 'bf16': - dtype = torch.bfloat16 - elif dtype_str == 'f16': - dtype = torch.float16 - elif dtype_str == 'f8': - dtype = torch.float8_e4m3fnuz - else: - print('>>> Warning! 
Invalid dtype', dtype_str, - 'using default dtype f8') - return dtype - - -def list_of_ints(arg): - return list(map(int, arg.split(','))) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("--model_dir", - type=str, - default=os.getenv('GTUNE_MODEL', ""), - help="Enter the location of your model directory") - parser.add_argument("--tuned_file", - type=str, - default=os.getenv('GTUNE_TUNED', "tuned.csv"), - help="output file for tuned gemm solutions") - parser.add_argument( - "--input_file", - type=str, - default=os.getenv('GTUNE_INPUT', None), - help="list of gemms to tune for, mutually exclusive with model_dir") - parser.add_argument("--tp", - type=int, - default=os.getenv('GTUNE_TP', 1), - help="Tensor parallelism to be used.") - parser.add_argument("--indtype", - type=str, - default='f8', - help="dtype f32 f16 bf16 fp8") - parser.add_argument("--outdtype", - type=str, - default='f16', - help="dtype f32 f16 bf16 fp8") - parser.add_argument("--batch_size", - type=int, - default=os.getenv('GTUNE_BATCH_SIZE', 1), - help="Batch size to tune for") - parser.add_argument("--nsets", - type=list_of_ints, - default=[1, 512, 1024, 2048, 3072, 4096, 8192, 16384], - help="N sizes to tune for: 1,128,2048") - args = parser.parse_args() - - indtype = get_dtype(args.indtype) - outdtype = get_dtype(args.outdtype) - - gtuner = Fp8GemmTuner(indtype, outdtype, args.tuned_file) - nsets = [i * args.batch_size for i in args.nsets] - if args.input_file: - print(f">>> Loading {args.input_file}") - if not Path(args.input_file).is_file(): - print(f">>> ERROR: {args.input_file} does not exist. Exiting") - exit(1) - shapes = pd.read_csv(args.input_file) - for i in range(len(shapes)): - ds = shapes.iloc[i] - gtuner.add_gemm(ds['M'], ds['N'], ds['K']) - else: - if not args.model_dir: - print(">>> Warning! NO MODEL SPECIFIED. Tuning for LL2 13B TP1") - #LL2 13B sizes - mksets = [(15360, 5120), (5120, 5120), (27648, 5120), - (5120, 13824)] - gtuner.add_gemm(m=32000, n=1, k=5120) # logits gemm - else: - mksets, hidden_size = generate_mk_sets(args.model_dir, args.tp) - gtuner.add_gemm( - m=32000 // args.tp, n=1 * args.batch_size, k=hidden_size - ) #TODO: Handle cases where vocab_size is not divisible by tp - - for n in sorted(nsets): - for m, k in mksets: - gtuner.add_gemm(m, n, k) - - gtuner.find_best_sols() diff --git a/gradlib/gradlib/gemm_runner.py b/gradlib/gradlib/gemm_runner.py index 7ad190b95a078..eb49d446bef3f 100644 --- a/gradlib/gradlib/gemm_runner.py +++ b/gradlib/gradlib/gemm_runner.py @@ -66,6 +66,8 @@ def get_dtype(dtype_csv): dtype = torch.bfloat16 elif dtype_csv == 'torch.float32': dtype = torch.float32 + elif dtype_csv == 'torch.float8_e4m3fnuz': + dtype = torch.float8_e4m3fnuz return dtype diff --git a/gradlib/gradlib/gemm_tuner.py b/gradlib/gradlib/gemm_tuner.py index b65a4492e185e..6e62c9ac93abd 100644 --- a/gradlib/gradlib/gemm_tuner.py +++ b/gradlib/gradlib/gemm_tuner.py @@ -36,6 +36,8 @@ def get_dtype(dtype_str): dtype = torch.bfloat16 elif dtype_str == 'f16': dtype = torch.float16 + elif dtype_str == 'fp8': + dtype = torch.float8_e4m3fnuz else: print('>>> Warning! 
Invalid dtype', dtype_str, 'using default dtype f16') @@ -70,10 +72,15 @@ def load_input_gemms(input_file): type=int, default=os.getenv('GTUNE_TP', 1), help="Tensor parallelism to be used.") - parser.add_argument("--dtype", + parser.add_argument("--indtype", type=str, default='f16', - help="dtype f32 f16 bf16") + choices=["f32", "f16", "bf16", "fp8"], + help="dtype f32 f16 bf16 fp8") + parser.add_argument("--outdtype", + type=str, + choices=["f32", "f16", "bf16", "fp8"], + help="dtype f32 f16 bf16 fp8") parser.add_argument("--rocblas-decode", action="store_true", default=False, @@ -88,9 +95,12 @@ def load_input_gemms(input_file): help="N sizes to tune for: 1,128,2048") args = parser.parse_args() - dtype = get_dtype(args.dtype) + if args.outdtype is None: + args.outdtype = args.indtype + indtype = get_dtype(args.indtype) + outdtype = get_dtype(args.outdtype) - gtuner = GemmTuner(dtype, args.tuned_file, args.rocblas_decode) + gtuner = GemmTuner(indtype, outdtype, args.tuned_file, args.rocblas_decode) nsets = [i * args.batch_size for i in args.nsets] if args.input_file: print(f">>> Loading {args.input_file}") diff --git a/requirements-common.txt b/requirements-common.txt index 3ea22276f63f4..644b188b74185 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -2,7 +2,7 @@ cmake >= 3.21 ninja # For faster builds. psutil sentencepiece # Required for LLaMA tokenizer. -numpy +numpy < 2.0.0 requests py-cpuinfo transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3. diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 28edeb69eb37b..287a8a46bb394 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -15,9 +15,9 @@ try: if is_hip(): - from amdsmi import (AmdSmiException, - amdsmi_get_processor_handle_from_bdf, amdsmi_init, - amdsmi_shut_down, amdsmi_topo_get_link_type) + from amdsmi import (AmdSmiException, amdsmi_get_processor_handles, + amdsmi_init, amdsmi_shut_down, + amdsmi_topo_get_link_type) else: import pynvml @@ -62,25 +62,21 @@ def _is_full_nvlink(device_ids: List[int], world_size) -> bool: so it works on real physical device ids. """ if is_hip(): - # get devices' BDF in order to get XGMI link info from amdsmi - bdf = custom_ar.get_device_bdf(torch.cuda.current_device()) - all_bdf = [0] * world_size - dist.all_gather_object(all_bdf, bdf) - hsmi = [None] * world_size - try: - for i in range(world_size): - bdf_str = str(bytes(all_bdf[i]).decode("utf-8")) - hsmi[i] = amdsmi_get_processor_handle_from_bdf(bdf_str) - for i in range(world_size): - if i != 0: - link_type = amdsmi_topo_get_link_type(hsmi[0], hsmi[i]) - # type is 2 for XGMI - if link_type['hops'] != 1 or link_type['type'] != 2: + # On ROCm, we instead query if GPUs are connected by 1-hop XGMI + handles = [amdsmi_get_processor_handles()[i] for i in device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + link_type = amdsmi_topo_get_link_type( + handle, peer_handle) + # type is 2 for XGMI + if link_type["hops"] != 1 or link_type["type"] != 2: + return False + except AmdSmiException as error: + logger.error("AMD link detection failed.", + exc_info=error) return False - except AmdSmiException as e: - logger.warning(e) - return False - return True else: handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids] for i, handle in enumerate(handles):
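
---

For reference, a minimal sketch of driving the unified tuner from Python rather than through the `gemm_tuner.py` CLI shown in `ROCm_performance.md`. This is not part of the patch: it assumes the gradlib extensions (`hipbsolidxgemm`, `rocsolidxgemm`) are built, that `gradlib/gradlib` is importable, and that `GemmTuner.add_gemm(m, n, k)` matches the interface the CLI driver uses; treat it purely as an illustration of the new `indtype`/`outdtype` split.

```python
# Hypothetical usage sketch of the updated GemmTuner API: tune an
# fp8-input / fp16-output GEMM shape and write the result CSV.
import torch
from GemmTuner import GemmTuner  # assumes gradlib/gradlib is on PYTHONPATH

tuner = GemmTuner(indtype=torch.float8_e4m3fnuz,   # matches --indtype fp8
                  outdtype=torch.float16,          # matches --outdtype f16
                  tuned_file='/tmp/tuned_fp8_16.csv',
                  rocblas_decode=False)

# Shape to tune; add_gemm is assumed to take (m, n, k) as in the CLI path.
tuner.add_gemm(m=8192, n=1, k=8192)

# Writes M/N/K plus libtype, solidx, soltimems, indtype, outdtype columns
# to the tuned CSV, which the fp8 gemm linear layer then consumes.
tuner.find_best_sols()
```

Because `indtype` is `torch.float8_e4m3fnuz`, the tuner skips the rocBLAS search (`use_rocblas` is false) and only benchmarks hipBLASLt solutions, validating the fastest candidates against an fp32 `F.linear` reference before the final timing pass.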