class CPUFusedMOE:
    """Fused MoE for CPU.

    At construction time, pick either the grouped-GEMM kernel path (when the
    custom CPU ops are available and the weights satisfy their constraints)
    or a per-expert torch/oneDNN fallback, and prepare the layer's weights
    for the chosen path.
    """

    def __init__(self, layer: torch.nn.Module) -> None:
        use_grouped_gemm, isa = self.check_grouped_gemm(layer)
        self.isa = isa
        if use_grouped_gemm:
            self.forward_method = self.forward_grouped_gemm
            self.init_moe_grouped_gemm(layer=layer)
        else:
            self.forward_method = self.forward_torch
            self.init_moe_torch(layer=layer)

    def __call__(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        assert activation in _CPU_MOE_ACT, f"{activation} is not supported."
        assert not apply_router_weight_on_input

        # Compute the top-k expert assignments and routing weights per token,
        # then dispatch to the forward path selected in __init__.
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
        )
        return self.forward_method(
            layer,
            x,
            topk_weights,
            topk_ids,
            activation,
            global_num_experts,
        )

    def check_grouped_gemm(
        self,
        layer: torch.nn.Module,
    ) -> tuple[bool, str]:
        # The grouped-GEMM path requires the custom prepack op to be compiled in.
        if not hasattr(torch.ops._C, "prepack_moe_weight"):
            return False, "none"

        dtype = layer.w13_weight.dtype
        w13_input_size = layer.w13_weight.size(2)
        w13_output_size = layer.w13_weight.size(1)
        w2_input_size = layer.w2_weight.size(2)
        w2_output_size = layer.w2_weight.size(1)

        # Both kernels need 32-aligned output dimensions.
        if not (w13_output_size % 32 == 0 and w2_output_size % 32 == 0):
            return False, "none"

        supports_amx = torch._C._cpu._is_amx_tile_supported()
        if (
            supports_amx
            and dtype == torch.bfloat16
            and w13_input_size % 32 == 0
            and w2_input_size % 32 == 0
        ):
            return True, "amx"
        if supports_amx:
            # AMX is present but the dtype/shape constraints are not met:
            # fall back to the torch path instead of the vectorized kernel.
            return False, "none"
        return True, "vec"

    def init_moe_grouped_gemm(
        self,
        layer: torch.nn.Module,
    ) -> None:
        # Repack both weight tensors into the layout expected by the grouped
        # GEMM kernel for the detected ISA ("amx" or "vec").
        new_w13 = cpu_prepack_moe_weight(layer.w13_weight, self.isa)
        replace_parameter(layer, "w13_weight", new_w13)
        new_w2 = cpu_prepack_moe_weight(layer.w2_weight, self.isa)
        replace_parameter(layer, "w2_weight", new_w2)

    def init_moe_torch(
        self,
        layer: torch.nn.Module,
    ) -> None:
        use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()
        num_experts = layer.w13_weight.size(0)
        has_w13_bias = hasattr(layer, "w13_bias")
        has_w2_bias = hasattr(layer, "w2_bias")

        # Build one gate_up and one down projection callable per expert.
        layer.gate_up_linear = []
        layer.down_linear = []
        for i in range(num_experts):
            layer_w13_weight = layer.w13_weight[i]
            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
            layer_w2_weight = layer.w2_weight[i]
            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
            if use_onednn_mm:
                gate_up_handle = ops.create_onednn_mm(layer_w13_weight.t(), 32)
                layer.gate_up_linear.append(
                    lambda x, handle=gate_up_handle, bias=layer_w13_bias: ops.onednn_mm(
                        handle, x, bias
                    )
                )
                down_handle = ops.create_onednn_mm(layer_w2_weight.t(), 32)
                layer.down_linear.append(
                    lambda x, handle=down_handle, bias=layer_w2_bias: ops.onednn_mm(
                        handle, x, bias
                    )
                )
            else:
                # Default arguments bind each expert's weight/bias at definition
                # time, so every lambda captures its own expert.
                layer.gate_up_linear.append(
                    lambda x, w=layer_w13_weight, b=layer_w13_bias: F.linear(x, w, b)
                )
                layer.down_linear.append(
                    lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b)
                )

        if use_onednn_mm:
            # The original weights are no longer needed once the oneDNN
            # handles are created; remove them to save memory.
            layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)

        _CPU_MOE_LAYER_CACHE[id(layer)] = weakref.ref(layer)

    def forward_grouped_gemm(
        self,
        layer: torch.nn.Module,
        input: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int = -1,
    ) -> torch.Tensor:
        output = cpu_fused_moe(
            input,
            layer.w13_weight,
            layer.w2_weight,
            getattr(layer, "w13_bias", None),
            getattr(layer, "w2_bias", None),
            topk_weights,
            topk_ids,
            activation,
            self.isa,
        )
        return output

    def forward_torch(
        self,
        layer: torch.nn.Module,
        input: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int = -1,
    ) -> torch.Tensor:
        output = torch.empty_like(input)
        # The layer is passed by id; init_moe_torch registered a weak reference
        # under the same id in _CPU_MOE_LAYER_CACHE for the op to resolve.
        layer_id = id(layer)
        torch.ops.vllm.cpu_fused_moe_torch(
            layer_id,
            output,
            input,
            topk_weights,
            topk_ids,
            activation,
            global_num_experts,
        )
        return output
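

# Usage sketch (illustrative only, not part of the class above). This excerpt
# relies on module-level imports not shown here (torch, torch.nn.functional as F,
# weakref, Callable, and module-local helpers such as select_experts, ops, and
# cpu_fused_moe). It assumes an MoE `layer` module carrying w13_weight / w2_weight
# shaped [num_experts, out_features, in_features] (plus optional biases), and
# router logits computed elsewhere; `moe_layer`, `hidden_states`, and
# `gate_logits` are hypothetical names.
#
#   fused_moe = CPUFusedMOE(moe_layer)   # selects grouped-GEMM or torch path
#   out = fused_moe(
#       moe_layer,
#       hidden_states,                   # [num_tokens, hidden_size]
#       use_grouped_topk=False,
#       top_k=2,
#       router_logits=gate_logits,       # [num_tokens, num_experts]
#       renormalize=True,
#   )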