
RAGFlow Workflow Execution Mechanism Explained

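Table of Contents

  1. Execution Overview
  2. From DSL to Execution
  3. Dependency Resolution and Scheduling
  4. Component Execution Flow
  5. Special Component Handling
  6. Data Flow Mechanism
  7. Error Handling and Recovery
  8. Real-Time Status Synchronization
  9. Performance Optimization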

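1. Execution Overview

1.1 Overall Execution Architecture

```mermaid
flowchart TB
    subgraph "Frontend Trigger Layer"
        A1[User clicks Run] --> A2[Collect workflow DSL]
        A2 --> A3[Send execution request]
    end

    subgraph "API Gateway Layer"
        B1[Flask route receives request] --> B2[Parameter validation]
        B2 --> B3[Permission check]
        B3 --> B4[Create execution context]
    end

    subgraph "Execution Engine Layer"
        C1[Canvas initialization] --> C2[DSL parsing]
        C2 --> C3[Component instantiation]
        C3 --> C4[Dependency graph construction]
        C4 --> C5[Execution scheduler]
    end

    subgraph "Component Execution Layer"
        D1[Component preparation] --> D2[Fetch input data]
        D2 --> D3[Parameter validation]
        D3 --> D4[Core logic execution]
        D4 --> D5[Output data processing]
    end

    subgraph "Result Feedback Layer"
        E1[Status update] --> E2[SSE push]
        E2 --> E3[Frontend state sync]
        E3 --> E4[Real-time UI update]
    end

    A3 --> B1
    B4 --> C1
    C5 --> D1
    D5 --> E1
    E4 --> A1
```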

1.2 Core Execution Sequence

```mermaid
sequenceDiagram
    participant User
    participant Frontend
    participant API as Flask API
    participant Canvas as Canvas Engine
    participant Components
    participant SSE as SSE Stream

    User->>Frontend: Click to run the workflow
    Frontend->>API: POST /v1/canvas/completion

    Note over API,Canvas: Initialization phase
    API->>Canvas: Create Canvas instance
    Canvas->>Canvas: Parse DSL JSON
    Canvas->>Components: Instantiate all components
    Canvas->>Canvas: Build dependency graph

    Note over Canvas,Components: Execution phase
    Canvas->>Components: Execute Begin component
    Components->>SSE: Push start status
    SSE->>Frontend: Real-time status update

    loop Execute components in turn
        Canvas->>Components: Check dependencies and execute component
        Components->>Components: Process input data
        Components->>Components: Execute core logic
        Components->>SSE: Push execution status
        SSE->>Frontend: Real-time progress update
        Components->>Canvas: Return execution result
    end

    Note over Canvas,SSE: Completion phase
    Canvas->>SSE: Push final result
    SSE->>Frontend: Show completion status
    Frontend->>User: Display workflow result
```

2. From DSL to Execution

2.1 DSL Structure

The front-end workflow editor produces a DSL JSON document with the following structure:

```json
{
  "components": {
    "begin_001": {
      "obj": {
        "component_name": "Begin",
        "params": {
          "prologue": "Hello, I am your AI assistant."
        }
      },
      "downstream": ["retrieval_002"],
      "upstream": []
    },
    "retrieval_002": {
      "obj": {
        "component_name": "Retrieval",
        "params": {
          "kb_ids": ["kb_123"],
          "similarity_threshold": 0.8,
          "top_n": 5
        }
      },
      "downstream": ["generate_003"],
      "upstream": ["begin_001"]
    },
    "generate_003": {
      "obj": {
        "component_name": "Generate",
        "params": {
          "llm_id": "openai_gpt4",
          "prompt": "Answer the question based on the following knowledge: {retrieval_002:0}",
          "temperature": 0.7
        }
      },
      "downstream": ["answer_004"],
      "upstream": ["retrieval_002"]
    },
    "answer_004": {
      "obj": {
        "component_name": "Answer",
        "params": {}
      },
      "downstream": [],
      "upstream": ["generate_003"]
    }
  },
  "history": [],
  "path": [],
  "answer": [],
  "reference": [],
  "messages": []
}
```
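Every node lists both `upstream` and `downstream`, so the two lists must agree for the dependency graph built later to be correct. The following is a minimal sketch of a consistency check. It assumes the DSL has already been parsed into a dict shaped like the example above. `check_dsl_links` is a hypothetical helper, not part of RAGFlow.

```python
def check_dsl_links(dsl: dict) -> list[str]:
    """Return the problems found in the upstream/downstream links of a DSL dict.

    Hypothetical helper: it only inspects the "components" mapping and assumes
    each entry has optional "upstream" and "downstream" lists of component IDs.
    """
    components = dsl.get("components", {})
    problems = []
    for cid, cpn in components.items():
        for down in cpn.get("downstream", []):
            if down not in components:
                problems.append(f"{cid} -> {down}: unknown downstream component")
            elif cid not in components[down].get("upstream", []):
                problems.append(f"{cid} -> {down}: missing from {down}'s upstream list")
        for up in cpn.get("upstream", []):
            if up not in components:
                problems.append(f"{up} -> {cid}: unknown upstream component")
            elif cid not in components[up].get("downstream", []):
                problems.append(f"{up} -> {cid}: missing from {up}'s downstream list")
    return problems


dsl = {
    "components": {
        "begin_001": {"downstream": ["retrieval_002"], "upstream": []},
        "retrieval_002": {"downstream": ["generate_003"], "upstream": ["begin_001"]},
        "generate_003": {"downstream": [], "upstream": []},  # deliberately broken link
    }
}
problems = check_dsl_links(dsl)
# problems == ["retrieval_002 -> generate_003: missing from generate_003's upstream list"]
```

Running a check like this before instantiating components turns a silent scheduling problem into an explicit validation error.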

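2.2 Canvas Initialization

```python
# agent/canvas.py
import json

from agent.component import component_class  # resolves a component class by name


class Canvas:
    def __init__(self, dsl: str, tenant_id=None):
        """
        Canvas initialization flow
        """
        # 1. Parse the DSL JSON
        self.dsl = json.loads(dsl)
        self.tenant_id = tenant_id

        # 2. Initialize core data structures
        self.components = {}          # component instances, keyed by component ID
        self.path = []                # execution path
        self.history = []             # conversation history
        self.answer = []              # queue of Answer components
        self.reference = []           # reference (citation) information

        # 3. Load and instantiate components
        self.load()

        # 4. Check that the required components are present
        self._validate_required_components()

    def load(self):
        """
        Load and instantiate components
        """
        for component_id, cpn in self.dsl["components"].items():
            try:
                # Get the component type
                component_name = cpn["obj"]["component_name"]

                # Dynamically resolve the component class and its parameter class.
                # New variable names keep the component_class() lookup function from
                # being shadowed (the original reassigned component_class itself).
                cpn_cls = component_class(component_name)
                param_cls = component_class(component_name + "Param")

                # Create the parameter instance
                param = param_cls()
                param.update(cpn["obj"]["params"])
                param.check()  # validate parameters

                # Create the component instance
                component = cpn_cls(
                    canvas=self,
                    component_id=component_id,
                    param=param
                )

                # Set upstream/downstream relationships
                component._upstream = cpn.get("upstream", [])
                component._downstream = cpn.get("downstream", [])

                # Store the component instance
                cpn["obj"] = component
                self.components[component_id] = cpn

            except Exception as e:
                raise ValueError(f"Failed to initialize component {component_id}: {e}") from e

    def _validate_required_components(self):
        """
        Check that the required components exist
        """
        # Check for a Begin component
        begin_components = [
            cid for cid, cpn in self.components.items()
            if cpn["obj"].component_name == "Begin"
        ]
        if not begin_components:
            raise ValueError("A workflow must contain at least one Begin component")

        # Check for an Answer component
        answer_components = [
            cid for cid, cpn in self.components.items()
            if cpn["obj"].component_name == "Answer"
        ]
        if not answer_components:
            raise ValueError("A workflow must contain at least one Answer component")
```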
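`load()` pairs every component name `X` with a parameter class `XParam` and resolves both by name. The following self-contained sketch shows that lookup using a plain dictionary registry. `BeginParam`, `Begin`, `REGISTRY`, `build_component`, and this `component_class` are stand-ins for illustration, not RAGFlow's own classes.

```python
class BeginParam:
    """Stand-in parameter class: holds settings and validates them."""

    def __init__(self):
        self.prologue = ""

    def update(self, conf: dict):
        # Copy only attributes the class already defines
        for key, value in conf.items():
            if hasattr(self, key):
                setattr(self, key, value)
        return self

    def check(self):
        if not isinstance(self.prologue, str):
            raise ValueError("prologue must be a string")


class Begin:
    component_name = "Begin"

    def __init__(self, canvas, component_id, param):
        self._canvas = canvas
        self._id = component_id
        self._param = param


# Name -> class registry; a real system might call getattr() on a module instead
REGISTRY = {cls.__name__: cls for cls in (Begin, BeginParam)}


def component_class(name: str):
    try:
        return REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown component class: {name}") from None


def build_component(component_id: str, obj: dict, canvas=None):
    # Same order as Canvas.load(): Param class, update, check, then component
    name = obj["component_name"]
    param = component_class(name + "Param")()
    param.update(obj.get("params", {}))
    param.check()
    return component_class(name)(canvas=canvas, component_id=component_id, param=param)


begin = build_component("begin_001", {"component_name": "Begin",
                                      "params": {"prologue": "Hello"}})
```

Keeping the `X`/`XParam` naming convention means adding a component type only requires defining two classes. The loader itself does not change.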

3. Dependency Resolution and Scheduling

3.1 Building the Dependency Graph

```python
def build_dependency_graph(self):
    """
    Build the component dependency graph.
    For each component, record the components it depends on ("dependencies")
    and the components that depend on it ("dependents").
    """
    dependency_graph = {}

    for component_id, cpn in self.components.items():
        upstream = cpn.get("upstream", [])
        dependency_graph[component_id] = {
            "dependencies": upstream,
            "dependents": [],
            "status": "pending"  # pending, ready, running, completed, failed
        }

    # Build the reverse (dependent) relationships
    for component_id, info in dependency_graph.items():
        for dep_id in info["dependencies"]:
            if dep_id in dependency_graph:
                dependency_graph[dep_id]["dependents"].append(component_id)

    return dependency_graph

def get_ready_components(self, dependency_graph):
    """
    Return the components that can run now: pending components
    whose dependencies have all completed
    """
    ready_components = []

    for component_id, info in dependency_graph.items():
        if info["status"] == "pending":
            # Check that all dependencies have completed
            # (dependencies missing from the graph are ignored)
            dependencies_completed = all(
                dependency_graph[dep_id]["status"] == "completed"
                for dep_id in info["dependencies"]
                if dep_id in dependency_graph
            )

            if dependencies_completed:
                ready_components.append(component_id)
                info["status"] = "ready"

    return ready_components
```
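Together, these two functions amount to Kahn-style topological scheduling: the scheduler repeatedly picks the pending nodes whose dependencies are complete. The standalone sketch below works on a plain `{component_id: [upstream ids]}` mapping, a simplification of `self.components`. It returns the batches in which the scheduler would run components and detects cycles. `execution_order` is a hypothetical helper.

```python
def execution_order(upstream: dict[str, list[str]]) -> list[list[str]]:
    """Return batches of component IDs; each batch depends only on earlier batches.

    Unknown dependencies are ignored, matching get_ready_components().
    Raises RuntimeError if some components can never become ready (a cycle).
    """
    status = {cid: "pending" for cid in upstream}
    batches = []
    while any(s == "pending" for s in status.values()):
        # Compute the whole batch before marking anything, so that nodes
        # inside one batch never depend on each other
        ready = sorted(
            cid for cid, s in status.items()
            if s == "pending" and all(
                status[dep] == "completed"
                for dep in upstream[cid] if dep in status
            )
        )
        if not ready:
            stuck = sorted(c for c, s in status.items() if s == "pending")
            raise RuntimeError(f"Cyclic dependency among: {stuck}")
        for cid in ready:
            status[cid] = "completed"  # pretend the component ran successfully
        batches.append(ready)
    return batches


graph = {
    "begin_001": [],
    "retrieval_002": ["begin_001"],
    "generate_003": ["retrieval_002"],
    "answer_004": ["generate_003"],
}
order = execution_order(graph)
# order == [["begin_001"], ["retrieval_002"], ["generate_003"], ["answer_004"]]
```

A diamond-shaped graph (`a` feeding both `b` and `c`, which both feed `d`) yields `[["a"], ["b", "c"], ["d"]]`. The middle batch is the kind of group that could run in parallel.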

3.2 Execution Scheduling Strategy

```python
def run(self, stream=True, **kwargs):
    """
    Main entry point for workflow execution (a generator that yields status chunks)
    """
    # 1. Resume after an Answer component that was waiting for user input
    if self.answer:
        yield from self._handle_answer_continuation(**kwargs)
        return

    # 2. Initialize the execution path
    if not self.path:
        self.path = [[]]

    # 3. Build the dependency graph
    dependency_graph = self.build_dependency_graph()

    # 4. Scheduling loop; the iteration cap guards against infinite loops
    max_iterations = len(self.components) * 2
    iteration = 0

    while iteration < max_iterations:
        # Get the components that are ready to run
        ready_components = self.get_ready_components(dependency_graph)

        if not ready_components:
            # Check whether every component has finished
            all_completed = all(
                info["status"] in ["completed", "failed"]
                for info in dependency_graph.values()
            )
            if all_completed:
                break
            # Pending components remain but none can run: likely a cyclic
            # dependency or an upstream component that never completed
            raise RuntimeError("Workflow execution is deadlocked")

        # Execute the ready components one by one (this loop is sequential)
        for component_id in ready_components:
            try:
                dependency_graph[component_id]["status"] = "running"

                # Push the running status
                yield {
                    "running_status": True,
                    "content": f"Running component: {component_id}",
                    "component_id": component_id
                }

                # Execute the component
                result = self._execute_component(component_id, **kwargs)

                # Mark it completed even without a result; otherwise it stays
                # "running" and its dependents never become ready
                dependency_graph[component_id]["status"] = "completed"
                if not result:
                    continue

                # Special component handling
                if self._is_answer_component(component_id):
                    # An Answer component waits for user input
                    yield from self._handle_answer_component(component_id, result)
                    return
                elif self._is_terminal_component(component_id):
                    # Terminal component: emit the final result
                    yield {
                        "running_status": False,
                        "content": result["content"],
                        "reference": result.get("reference", []),
                        "component_id": component_id
                    }
                    return
                else:
                    # Intermediate component: keep streaming output
                    yield {
                        "running_status": True,
                        "content": result["content"],
                        "component_id": component_id
                    }

            except Exception as e:
                dependency_graph[component_id]["status"] = "failed"
                yield {
                    "running_status": False,
                    "content": f"Component execution failed: {e}",
                    "error": str(e),
                    "component_id": component_id
                }
                return

        iteration += 1

    # Execution finished
    yield {
        "running_status": False,
        "content": "Workflow execution completed",
        "path": self.path
    }
```
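`run()` is a generator, so callers receive a sequence of chunks. Intermediate chunks carry `running_status: True`, and the final one carries `running_status: False`. A non-streaming caller therefore still has to drain the generator. The following minimal sketch shows a consumer that separates progress updates from the final result. `fake_run` and `collect` are hypothetical stand-ins: `fake_run` only mimics the chunk shape of `Canvas.run()`.

```python
from typing import Iterator


def fake_run(component_ids: list[str]) -> Iterator[dict]:
    """Hypothetical stand-in for Canvas.run(): yields chunks of the same shape."""
    for cid in component_ids:
        yield {"running_status": True, "content": f"Running component: {cid}",
               "component_id": cid}
    yield {"running_status": False, "content": "Workflow execution completed",
           "path": [component_ids]}


def collect(chunks: Iterator[dict]) -> tuple[list[str], dict]:
    """Split a chunk stream into progress messages and the final chunk."""
    progress, final = [], {}
    for chunk in chunks:
        if chunk.get("running_status"):
            progress.append(chunk["content"])
        else:
            final = chunk  # the first non-running chunk ends the run
            break
    return progress, final


progress, final = collect(fake_run(["begin_001", "retrieval_002"]))
```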

4. Component Execution Flow

4.1 Executing a Single Component

```python
import time


def _execute_component(self, component_id, **kwargs):
    """
    Execute a single component
    """
    component_info = self.components[component_id]
    component = component_info["obj"]

    # 1. Record the execution path
    self.path[-1].append(component_id)

    # 2. Get the component input
    try:
        component_input = component.get_input()
    except Exception as e:
        raise RuntimeError(f"Failed to get component input: {e}") from e

    # 3. Run the component's core logic and measure how long it takes
    started = time.perf_counter()
    try:
        result_df = component._run(self.history, **kwargs)
    except Exception as e:
        raise RuntimeError(f"Component execution failed: {e}") from e
    elapsed = time.perf_counter() - started

    if result_df is None or result_df.empty:
        raise ValueError("Component returned an empty result")

    # 4. Build the result
    result = {
        "content": result_df.iloc[0]["content"],
        "component_id": component_id,
        "reference": result_df.iloc[0].get("reference", []),
        "metadata": {
            "execution_time": elapsed,  # duration in seconds
            "input_size": len(str(component_input)),
            "output_size": len(str(result_df))
        }
    }

    # 5. Update the component state
    component_info["result"] = result_df
    component_info["execution_time"] = time.time()  # completion timestamp

    return result
```
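The `execution_time` metadata is most useful as a duration. `time.perf_counter()` is monotonic and intended for measuring intervals. `time.time()` is wall-clock time and better suited to timestamps. The following small sketch records both around a call. `run_timed` is a hypothetical helper.

```python
import time
from typing import Any, Callable


def run_timed(fn: Callable[..., Any], *args, **kwargs) -> tuple[Any, float, float]:
    """Call fn and return (result, elapsed_seconds, finished_at_timestamp)."""
    start = time.perf_counter()          # monotonic clock for the duration
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed, time.time()  # wall-clock time for the timestamp


result, elapsed, finished_at = run_timed(sum, [1, 2, 3])
```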

4.2 Fetching Component Input

```python
# agent/component/base.py
import pandas as pd


class ComponentBase:
    def get_input(self):
        """
        Get the component's input data.
        Priority: explicit query parameter > upstream component outputs > empty DataFrame
        """
        # 1. Use the explicit query parameter, if one is set
        if hasattr(self._param, 'query') and self._param.query:
            return pd.DataFrame([{
                "content": self._param.query,
                "component_id": "user_input",
                "reference": []
            }])

        # 2. Collect upstream component outputs
        if not self._upstream:
            # No upstream components: return an empty DataFrame
            return pd.DataFrame()

        inputs = []
        for upstream_id in self._upstream:
            upstream_component = self._canvas.components.get(upstream_id)
            # Skip upstream components that have not produced a result yet
            if upstream_component and "result" in upstream_component:
                upstream_result = upstream_component["result"]
                inputs.append(upstream_result)

        # 3. Concatenate all upstream inputs
        if inputs:
            combined_input = pd.concat(inputs, ignore_index=True)
            return combined_input
        else:
            return pd.DataFrame()

    def output(self, start_at=0):
        """
        Get the component's output data, starting at row start_at
        """
        if "result" not in self._canvas.components[self._id]:
            raise ValueError(f"Component {self._id} has not been executed yet")

        result_df = self._canvas.components[self._id]["result"]
        return result_df.iloc[start_at:] if start_at > 0 else result_df
```
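The precedence in `get_input()` can be shown without pandas: an explicit query comes first, then the concatenated outputs of upstream components in `_upstream` order. The sketch below assumes each component's result is a list of row dicts instead of a DataFrame. `resolve_input` and the component ID `web_005` are hypothetical.

```python
def resolve_input(query, upstream_ids, results):
    """Mimic get_input(): explicit query > upstream outputs > empty list.

    results maps component_id -> list of row dicts; upstream components that
    have no result yet are skipped, as in get_input().
    """
    if query:
        return [{"content": query, "component_id": "user_input", "reference": []}]
    rows = []
    for upstream_id in upstream_ids:
        # Extending a list plays the role of pd.concat(..., ignore_index=True)
        rows.extend(results.get(upstream_id, []))
    return rows


results = {
    "retrieval_002": [{"content": "chunk A", "component_id": "retrieval_002"}],
    "web_005": [{"content": "chunk B", "component_id": "web_005"}],
}
rows = resolve_input(None, ["retrieval_002", "not_run_yet", "web_005"], results)
```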

5. Special Component Handling

5.1 Switch Component: Conditional Routing

```python
# agent/component/switch.py
import re

import pandas as pd

from agent.component.base import ComponentBase


class Switch(ComponentBase):
    def _run(self, history, **kwargs):
        """
        Switch component logic:
        evaluate the conditions and route to the output path of the first match
        """
        input_df = self.get_input()

        if input_df.empty:
            raise ValueError("The Switch component requires input data")

        input_content = input_df.iloc[0]["content"]

        # Evaluate conditions in order; the first match wins
        for condition in self._param.conditions:
            if self._evaluate_condition(input_content, condition):
                # Matching condition found: set the output path
                self._set_output_path(condition.output_component)

                return pd.DataFrame([{
                    "content": input_content,
                    "component_id": self._id,
                    "matched_condition": condition.name,
                    "output_path": condition.output_component,
                    "reference": input_df.iloc[0].get("reference", [])
                }])

        # No condition matched: use the default path (or the first downstream component)
        default_output = self._param.default_output or (
            self._downstream[0] if self._downstream else None
        )
        if default_output is None:
            raise ValueError("No condition matched and no default output is configured")
        self._set_output_path(default_output)

        return pd.DataFrame([{
            "content": input_content,
            "component_id": self._id,
            "matched_condition": "default",
            "output_path": default_output,
            "reference": input_df.iloc[0].get("reference", [])
        }])

    def _evaluate_condition(self, content, condition):
        """
        Evaluate a single condition
        """
        if condition.type == "contains":
            return condition.value in content
        elif condition.type == "regex":
            return bool(re.search(condition.value, content))
        elif condition.type == "length":
            return len(content) >= int(condition.value)
        elif condition.type == "sentiment":
            # Sentiment analysis logic (helper not shown)
            return self._analyze_sentiment(content) == condition.value
        else:
            return False

    def _set_output_path(self, target_component):
        """
        Set the output path dynamically
        """
        # Update the connection in the Canvas
        self._canvas._update_dynamic_connections(self._id, target_component)
```
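The following standalone sketch shows the first-match routing in `Switch`. It assumes conditions are simple `(type, value, target)` records. The `Condition` dataclass, `matches`, and `route` are illustrative names, and the route targets are made up. The sentiment branch is omitted because it needs a model.

```python
import re
from dataclasses import dataclass


@dataclass
class Condition:
    type: str      # "contains", "regex" or "length"
    value: str
    target: str    # downstream component to route to


def matches(content: str, cond: Condition) -> bool:
    if cond.type == "contains":
        return cond.value in content
    if cond.type == "regex":
        return re.search(cond.value, content) is not None
    if cond.type == "length":
        return len(content) >= int(cond.value)
    return False  # unknown condition types never match


def route(content: str, conditions: list[Condition], default: str) -> str:
    # Conditions are checked in order; the first match wins
    for cond in conditions:
        if matches(content, cond):
            return cond.target
    return default


conditions = [
    Condition("regex", r"^\s*refund", "refund_flow"),
    Condition("contains", "invoice", "billing_flow"),
    Condition("length", "200", "long_text_flow"),
]
```

Because evaluation stops at the first match, condition order is part of the configuration. For example, "refund my invoice" goes to `refund_flow`, not `billing_flow`.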

5.2 Categorize Component: Multi-Way Dispatch

```python
# agent/component/categorize.py
import json

import pandas as pd

from agent.component.base import ComponentBase


class Categorize(ComponentBase):
    def _run(self, history, **kwargs):
        """
        Categorize component logic:
        classify the input and dispatch it to the matching processing paths
        """
        input_df = self.get_input()

        if input_df.empty:
            raise ValueError("The Categorize component requires input data")

        input_content = input_df.iloc[0]["content"]

        # Classify the content with an LLM
        classification_result = self._classify_content(input_content)

        # Create a branch for every category above the confidence threshold
        for category, confidence in classification_result.items():
            if confidence > self._param.confidence_threshold:
                target_component = self._param.category_mapping.get(category)
                if target_component:
                    # Create a branch execution path
                    self._canvas._create_branch_path(
                        self._id,
                        target_component,
                        category
                    )

        return pd.DataFrame([{
            "content": input_content,
            "component_id": self._id,
            "classifications": classification_result,
            "reference": input_df.iloc[0].get("reference", [])
        }])

    def _classify_content(self, content):
        """
        Classify content with an LLM
        """
        classification_prompt = f"""
Classify the following content, choosing from these categories: {list(self._param.categories.keys())}

Content: {content}

Return JSON giving a confidence (between 0 and 1) for each category:
{{"category1": 0.8, "category2": 0.1, ...}}
"""

        # Call the LLM to classify
        llm_response = self._call_llm(classification_prompt)

        try:
            scores = json.loads(llm_response)
        except (json.JSONDecodeError, TypeError):
            # Classification failed: fall back to the default category
            return {"default": 1.0}
        return scores if isinstance(scores, dict) else {"default": 1.0}
```
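LLMs often wrap JSON in surrounding prose, so a bare `json.loads` on the whole response falls back to the default category more often than necessary. The following sketch shows a more tolerant parser plus the threshold fan-out. `parse_scores` and `select_categories` are hypothetical helpers. Extracting the outermost `{...}` span is a heuristic that assumes the reply contains a single JSON object.

```python
import json
import re


def parse_scores(llm_response: str) -> dict[str, float]:
    """Extract a {category: confidence} dict from an LLM reply, else {"default": 1.0}."""
    match = re.search(r"\{.*\}", llm_response, re.DOTALL)  # first "{" to last "}"
    if match:
        try:
            data = json.loads(match.group(0))
        except json.JSONDecodeError:
            data = None
        if isinstance(data, dict):
            scores = {}
            for key, value in data.items():
                try:
                    scores[str(key)] = float(value)
                except (TypeError, ValueError):
                    continue  # skip non-numeric confidences
            if scores:
                return scores
    return {"default": 1.0}


def select_categories(scores, mapping, threshold):
    """Return (category, target) pairs above the threshold, highest confidence first."""
    return [
        (cat, mapping[cat])
        for cat, conf in sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
        if conf > threshold and cat in mapping
    ]


reply = 'Here is the result: {"billing": 0.85, "tech": 0.4, "other": 0.05}. Hope this helps.'
scores = parse_scores(reply)
```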

5.3 Iteration Component: Loop Processing

```python
# agent/component/iteration.py
import time

import pandas as pd

from agent.component.base import ComponentBase


class Iteration(ComponentBase):
    def _run(self, history, **kwargs):
        """
        Iteration component logic:
        process the input data item by item in a loop
        """
        input_df = self.get_input()

        if input_df.empty:
            raise ValueError("The Iteration component requires input data")

        input_content = input_df.iloc[0]["content"]

        # Split the input on the delimiter
        items = self._split_input(input_content)

        results = []

        # Process each item, at most max_iterations of them
        for i, item in enumerate(items):
            if i >= self._param.max_iterations:
                break

            # Run the item in a child execution context
            iteration_result = self._execute_iteration(item, i)
            results.append(iteration_result)

            # Check the stop condition
            if self._should_stop_iteration(iteration_result):
                break

        # Merge all iteration results
        final_result = self._merge_iteration_results(results)

        return pd.DataFrame([{
            "content": final_result,
            "component_id": self._id,
            "iteration_count": len(results),  # items actually processed
            "items_processed": len(items),    # total items after splitting
            "reference": []
        }])

    def _split_input(self, content):
        """
        Split the input on the delimiter, dropping blank items
        """
        delimiter = self._param.delimiter or "\n"
        return [item.strip() for item in content.split(delimiter) if item.strip()]

    def _execute_iteration(self, item, index):
        """
        Run a single iteration
        """
        # Create a temporary sub-workflow
        sub_canvas = self._create_sub_canvas(item)

        # Run the sub-workflow. Canvas.run() is a generator, so drain it
        # and keep the last chunk, which carries the final result
        chunks = list(sub_canvas.run(stream=False))
        sub_result = chunks[-1] if chunks else None

        return {
            "index": index,
            "input": item,
            "output": sub_result,
            "timestamp": time.time()
        }

    def _should_stop_iteration(self, result):
        """
        Check whether the iteration should stop
        """
        if self._param.stop_condition:
            return self._evaluate_stop_condition(result)
        return False
```
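The loop in `Iteration._run` combines three limits: the delimiter split, `max_iterations`, and an optional stop condition. The following pure-function sketch shows how they interact. The per-item `process` callable stands in for running a sub-canvas, and `iterate` and its parameters are hypothetical.

```python
from typing import Callable, Optional


def iterate(content: str, process: Callable[[str], str], delimiter: str = "\n",
            max_iterations: int = 10,
            stop: Optional[Callable[[str], bool]] = None) -> list[dict]:
    # Split on the delimiter and drop blank items, like _split_input()
    items = [item.strip() for item in content.split(delimiter) if item.strip()]
    results = []
    for index, item in enumerate(items[:max_iterations]):
        output = process(item)
        results.append({"index": index, "input": item, "output": output})
        if stop is not None and stop(output):
            break  # stop condition met: skip the remaining items
    return results


runs = iterate("apple\n\nbanana\ncherry\ndate", str.upper, max_iterations=3,
               stop=lambda out: out == "BANANA")
```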

6. Data Flow Mechanism

6.1 Passing Data Between Components

```python
import pandas as pd


def _handle_data_flow(self, source_component_id, target_component_id, data):
    """
    Handle data flowing between components
    """
    # 1. Normalize the data format
    standardized_data = self._standardize_data_format(data)

    # 2. Convert data types
    converted_data = self._convert_data_types(
        standardized_data,
        target_component_id
    )

    # 3. Validate the data against the target's schema
    self._validate_data_schema(converted_data, target_component_id)

    # 4. Deliver the data
    target_component = self.components[target_component_id]["obj"]
    target_component._receive_data(converted_data)

    return converted_data

def _standardize_data_format(self, data):
    """
    Normalize data into a DataFrame
    """
    if isinstance(data, pd.DataFrame):
        return data
    elif isinstance(data, dict):
        return pd.DataFrame([data])
    elif isinstance(data, list):
        return pd.DataFrame(data)
    elif isinstance(data, str):
        return pd.DataFrame([{"content": data}])
    else:
        return pd.DataFrame([{"content": str(data)}])

def _convert_data_types(self, data, target_component_id):
    """
    Convert column types to what the target component expects
    """
    target_component = self.components[target_component_id]["obj"]

    # Get the input schema the component expects, e.g. {"content": str}
    expected_schema = target_component.get_expected_input_schema()

    # Work on a copy: _standardize_data_format() returns DataFrames unchanged,
    # so converting in place would also modify the source component's result
    data = data.copy()
    for column, expected_type in expected_schema.items():
        if column in data.columns:
            data[column] = data[column].astype(expected_type)

    return data
```
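`_standardize_data_format` maps every supported input shape onto a table of rows with a `content` column. The sketch below applies the same rules to plain lists of dicts, so they can be checked without pandas. `to_records` is a hypothetical helper. One deliberate difference: it wraps non-dict list items as `{"content": item}`. `pd.DataFrame(["a", "b"])` would instead create a column named `0`.

```python
def to_records(data) -> list[dict]:
    """Normalize data to a list of row dicts, mirroring _standardize_data_format()."""
    if isinstance(data, dict):
        return [data]                      # a single row
    if isinstance(data, list):
        # Dicts become rows; other items are wrapped as content
        return [item if isinstance(item, dict) else {"content": item} for item in data]
    if isinstance(data, str):
        return [{"content": data}]
    return [{"content": str(data)}]        # anything else is stringified
```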

6.2 Resolving Parameter References

```python
import re


def resolve_parameter_references(self, param_value):
    """
    Resolve component references inside a parameter value.
    Supported forms: {component_id:index}, {begin@param_name} and {component_id}
    """
    if not isinstance(param_value, str):
        return param_value

    # Match component reference patterns
    component_ref_pattern = r'\{([^}]+)\}'
    matches = re.findall(component_ref_pattern, param_value)

    resolved_value = param_value

    for match in matches:
        if '@' in match:
            # Begin component parameter reference
            component_id, param_name = match.split('@', 1)
            referenced_value = self._get_begin_parameter(component_id, param_name)
        elif ':' in match:
            # Component output reference; split on the last ":" so the index is isolated
            component_id, index = match.rsplit(':', 1)
            referenced_value = self._get_component_output(component_id, int(index))
        else:
            # Plain component reference (takes the first output)
            referenced_value = self._get_component_output(match, 0)

        # Substitute the reference
        resolved_value = resolved_value.replace(f'{{{match}}}', str(referenced_value))

    return resolved_value

def _get_component_output(self, component_id, index):
    """
    Get the output of the given component
    """
    if component_id not in self.components:
        raise ValueError(f"Component {component_id} does not exist")

    component_info = self.components[component_id]

    if "result" not in component_info:
        raise ValueError(f"Component {component_id} has not been executed yet")

    result_df = component_info["result"]

    if index >= len(result_df):
        raise ValueError(f"Output index {index} of component {component_id} is out of range")

    return result_df.iloc[index]["content"]
```
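Replacing references one `str.replace` call at a time works. A single `re.sub` pass with a callback, however, resolves each placeholder exactly once and never rescans text that has already been substituted. The following standalone sketch assumes component outputs are given as lists of strings and Begin parameters as a nested dict. `resolve` and the sample values are hypothetical.

```python
import re

REF_PATTERN = re.compile(r"\{([^}]+)\}")


def resolve(template: str, outputs: dict[str, list[str]],
            begin_params: dict[str, dict[str, str]]) -> str:
    def substitute(match):
        ref = match.group(1)
        if "@" in ref:                      # {begin@param_name}
            component_id, name = ref.split("@", 1)
            return str(begin_params[component_id][name])
        if ":" in ref:                      # {component_id:index}
            component_id, index = ref.rsplit(":", 1)
            return outputs[component_id][int(index)]
        return outputs[ref][0]              # {component_id}: first output

    # re.sub calls substitute() once per match, left to right
    return REF_PATTERN.sub(substitute, template)


outputs = {"retrieval_002": ["RAGFlow supports agents.", "Second chunk."]}
begin_params = {"begin": {"user_name": "Ada"}}
text = resolve("Hi {begin@user_name}. Context: {retrieval_002:0}", outputs, begin_params)
```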

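7. Error Handling and Recovery

7.1 Error Classification and Handling Strategies

```python
import time


class WorkflowErrorHandler:
    def __init__(self, canvas):
        self.canvas = canvas
        self.error_history = []

    def handle_error(self, error, component_id, context):
        """
        Single entry point for error handling
        """
        error_info = {
            "error": error,
            "component_id": component_id,
            "timestamp": time.time(),
            "context": context,
            "error_type": self._classify_error(error)
        }

        self.error_history.append(error_info)

        # Choose a handling strategy based on the error type
        if error_info["error_type"] == "parameter_error":
            return self._handle_parameter_error(error_info)
        elif error_info["error_type"] == "network_error":
            return self._handle_network_error(error_info)
        elif error_info["error_type"] == "resource_error":
            return self._handle_resource_error(error_info)
        elif error_info["error_type"] == "logic_error":
            return self._handle_logic_error(error_info)
        else:
            return self._handle_unknown_error(error_info)

    def _classify_error(self, error):
        """
        Classify an error by keywords in its message (checked in order; first match wins)
        """
        error_msg = str(error).lower()

        if any(keyword in error_msg for keyword in ["parameter", "validation", "required"]):
            return "parameter_error"
        elif any(keyword in error_msg for keyword in ["network", "connection", "timeout"]):
            return "network_error"
        elif any(keyword in error_msg for keyword in ["memory", "disk", "resource"]):
            return "resource_error"
        elif any(keyword in error_msg for keyword in ["logic", "runtime", "execution"]):
            return "logic_error"
        else:
            return "unknown_error"

    def _handle_parameter_error(self, error_info):
        """
        Handle parameter errors (not recoverable)
        """
        return {
            "action": "stop",
            "message": f"Parameter configuration error: {error_info['error']}",
            "suggestion": "Check the component's parameter configuration",
            "recoverable": False
        }

    def _handle_network_error(self, error_info):
        """
        Handle network errors (retry with exponential backoff)
        """
        retry_count = self._get_retry_count(error_info["component_id"])
        max_retries = 3

        if retry_count < max_retries:
            return {
                "action": "retry",
                "message": f"Network error, retrying ({retry_count + 1}/{max_retries})",
                "delay": 2 ** retry_count,  # exponential backoff: 1, 2, 4 seconds
                "recoverable": True
            }
        else:
            return {
                "action": "skip",
                "message": "Network error retry limit reached; skipping this component",
                "recoverable": False
            }

    def _handle_logic_error(self, error_info):
        """
        Handle logic errors
        """
        # Try degraded (fallback) execution if the component provides one
        component = self.canvas.components[error_info["component_id"]]["obj"]

        if hasattr(component, 'fallback_execution'):
            return {
                "action": "fallback",
                "message": "Main logic failed; trying fallback logic",
                "recoverable": True
            }
        else:
            return {
                "action": "stop",
                "message": f"Component logic error: {error_info['error']}",
                "recoverable": False
            }

    def _handle_resource_error(self, error_info):
        """
        Handle resource errors (assumption: treated as not recoverable)
        """
        return {
            "action": "stop",
            "message": f"Insufficient resources: {error_info['error']}",
            "recoverable": False
        }

    def _handle_unknown_error(self, error_info):
        """
        Handle unclassified errors (assumption: stop execution)
        """
        return {
            "action": "stop",
            "message": f"Unknown error: {error_info['error']}",
            "recoverable": False
        }

    def _get_retry_count(self, component_id):
        """
        Count earlier network errors for this component (assumption: retries are
        derived from error_history; the current error was appended last, so skip it)
        """
        return sum(
            1 for info in self.error_history[:-1]
            if info["component_id"] == component_id
            and info["error_type"] == "network_error"
        )
```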
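The network-error branch only returns a decision (`retry` with `delay = 2 ** retry_count`). Some caller still has to act on it. The following sketch shows such a driver loop. It assumes network errors are recognized by exception type rather than by message keywords, and the sleep function is injectable so the backoff can be tested without waiting. `call_with_retry` is hypothetical.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_retry(fn: Callable[[], T], max_retries: int = 3,
                    retry_on: tuple = (ConnectionError, TimeoutError),
                    sleep: Callable[[float], None] = time.sleep) -> T:
    """Call fn, retrying network-type errors after delays of 1, 2, 4, ... seconds."""
    retry_count = 0
    while True:
        try:
            return fn()
        except retry_on:
            if retry_count >= max_retries:
                raise  # retry limit reached: let the caller skip or stop
            sleep(2 ** retry_count)  # same backoff schedule as _handle_network_error
            retry_count += 1
```

With `max_retries=3`, a call that keeps failing is attempted four times in total, with waits of 1, 2, and 4 seconds between attempts. The last exception is then re-raised.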

7.2 Checkpoint and Resume

```python
import copy
import time

import pandas as pd


def save_execution_checkpoint(self):
    """
    Save an execution checkpoint
    """
    checkpoint = {
        "timestamp": time.time(),
        "path": copy.deepcopy(self.path),  # snapshot; later execution must not mutate it
        "completed_components": [],
        "component_results": {},
        "execution_state": {}
    }

    # Collect the results of completed components
    for component_id, component_info in self.components.items():
        if "result" in component_info:
            checkpoint["completed_components"].append(component_id)
            checkpoint["component_results"][component_id] = {
                # "records" orient (one dict per row) survives a JSON round trip
                "result": component_info["result"].to_dict(orient="records"),
                "execution_time": component_info.get("execution_time")
            }

    # Persist to storage
    self._save_checkpoint_to_storage(checkpoint)

    return checkpoint

def restore_from_checkpoint(self, checkpoint):
    """
    Resume execution from a checkpoint
    """
    # Restore the execution path
    self.path = copy.deepcopy(checkpoint["path"])

    # Restore component results (list of row records -> DataFrame)
    for component_id, result_info in checkpoint["component_results"].items():
        if component_id in self.components:
            self.components[component_id]["result"] = pd.DataFrame(result_info["result"])
            self.components[component_id]["execution_time"] = result_info["execution_time"]

    # Update component states
    self._update_component_states_from_checkpoint(checkpoint)
```
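A checkpoint only helps if it survives serialization and a crash during the write. The following sketch shows a JSON round trip that writes to a temporary file and then renames it with `os.replace`, so readers never see a half-written checkpoint. It assumes component results are already stored as lists of row records, so no pandas is needed here. The layout follows `save_execution_checkpoint`. `dump_checkpoint`, `load_checkpoint`, and the file location are hypothetical.

```python
import json
import os
import tempfile
import time


def dump_checkpoint(path_on_disk: str, execution_path, component_results) -> None:
    checkpoint = {
        "timestamp": time.time(),
        "path": execution_path,
        "completed_components": sorted(component_results),
        "component_results": component_results,
    }
    # Write to a temp file in the same directory, then rename it into place
    directory = os.path.dirname(os.path.abspath(path_on_disk))
    fd, tmp = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump(checkpoint, f, ensure_ascii=False)
    os.replace(tmp, path_on_disk)


def load_checkpoint(path_on_disk: str) -> dict:
    with open(path_on_disk, encoding="utf-8") as f:
        return json.load(f)
```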

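8. Real-Time Status Synchronization

8.1 SSE Event Push

```python
# api/apps/canvas_app.py
import json
import time

from flask import Response, request
from flask_login import current_user, login_required

from agent.canvas import Canvas

# canvas_app (the Blueprint) and get_canvas_dsl() are defined elsewhere in the app


@canvas_app.route("/v1/canvas/completion", methods=["POST"])
@login_required
def completion():
    """
    Workflow execution endpoint.
    Supports real-time push via Server-Sent Events.
    """
    req = request.json
    canvas_id = req.get("canvas_id")
    message = req.get("message", "")
    stream = req.get("stream", True)
    # Forward the remaining fields, minus the ones passed explicitly below;
    # passing **req directly would give run() duplicate "stream"/"message" arguments
    extra = {k: v for k, v in req.items() if k not in ("canvas_id", "message", "stream")}

    try:
        # Load the workflow definition
        canvas_dsl = get_canvas_dsl(canvas_id)

        # Create the Canvas instance
        canvas = Canvas(canvas_dsl, current_user.id)

        if stream:
            # Streaming execution
            def generate():
                try:
                    for chunk in canvas.run(stream=True, message=message, **extra):
                        # Build the SSE payload
                        sse_data = {
                            "code": 0,
                            "data": {
                                "answer": chunk.get("content", ""),
                                "running_status": chunk.get("running_status", False),
                                "component_id": chunk.get("component_id"),
                                "reference": chunk.get("reference", []),
                                "path": chunk.get("path", []),
                                "timestamp": time.time()
                            }
                        }

                        yield f"data: {json.dumps(sse_data, ensure_ascii=False)}\n\n"

                except Exception as e:
                    # Report the error to the client as a final event
                    error_data = {
                        "code": 500,
                        "message": str(e),
                        "data": {"error": str(e)}
                    }
                    yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"

            return Response(
                generate(),
                mimetype='text/event-stream',
                headers={
                    'Cache-Control': 'no-cache',
                    'Connection': 'keep-alive',
                    'Access-Control-Allow-Origin': '*'
                }
            )
        else:
            # Non-streaming execution: run() is a generator, so drain it
            # and return the last chunk, which carries the final result
            chunks = list(canvas.run(stream=False, message=message, **extra))
            return {"code": 0, "data": chunks[-1] if chunks else {}}

    except Exception as e:
        return {"code": 500, "message": str(e)}
```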
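Each SSE event emitted by the endpoint is a `data:` line followed by a blank line. The following sketch shows a framing helper that also handles payloads containing newlines: in the SSE format, each line of a multi-line payload needs its own `data:` prefix. `format_sse` is hypothetical. `json.dumps` output never contains raw newlines, so the endpoint's single-line framing is already safe for JSON payloads.

```python
import json
from typing import Optional


def format_sse(payload, event: Optional[str] = None) -> str:
    """Frame one Server-Sent Event; non-string payloads are JSON-encoded."""
    text = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False)
    lines = []
    if event:
        lines.append(f"event: {event}")
    # One "data:" line per payload line; the client joins them with "\n"
    lines.extend(f"data: {line}" for line in text.split("\n"))
    return "\n".join(lines) + "\n\n"  # the blank line terminates the event
```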

8.2 Frontend State Synchronization

```typescript
// web/src/hooks/use-send-message-with-sse.ts
import { useState } from 'react';
import { EventSourceParserStream } from 'eventsource-parser/stream';

// ExecutionParams and getToken() are defined elsewhere in the app

interface ExecutionState {
  status: 'idle' | 'running' | 'completed' | 'error';
  currentStep: string | null;
  progress: number;
  result: string | null;
  references: unknown[];
  error: string | null;
}

export const useSendMessageWithSse = (
  totalComponents: number,
  onStateUpdate?: (data: any) => void
) => {
  const [executionState, setExecutionState] = useState<ExecutionState>({
    status: 'idle',
    currentStep: null,
    progress: 0,
    result: null,
    references: [],
    error: null
  });

  const sendMessage = async (params: ExecutionParams) => {
    try {
      setExecutionState(prev => ({ ...prev, status: 'running' }));

      // POST the request; the response body is an SSE stream
      const response = await fetch('/v1/canvas/completion', {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
          'Authorization': `Bearer ${getToken()}`
        },
        body: JSON.stringify(params)
      });

      if (!response.ok) {
        throw new Error(`HTTP ${response.status}: ${response.statusText}`);
      }

      // Decode bytes to text, then parse the text into SSE events
      const reader = response.body
        ?.pipeThrough(new TextDecoderStream())
        .pipeThrough(new EventSourceParserStream())
        .getReader();

      if (!reader) {
        throw new Error('Response has no body');
      }

      while (true) {
        const { done, value } = await reader.read();

        if (done) break;

        if (value?.type === 'event' && value.data) {
          const eventData = JSON.parse(value.data);

          if (eventData.code === 0) {
            const { data } = eventData;

            // Update the execution state
            setExecutionState(prev => ({
              ...prev,
              currentStep: data.component_id,
              progress: calculateProgress(data.path, totalComponents),
              result: data.running_status ? prev.result : data.answer,
              error: null
            }));

            // Notify the caller so the UI can update
            onStateUpdate?.(data);

            // Stop once the backend reports that execution has finished
            if (!data.running_status) {
              setExecutionState(prev => ({
                ...prev,
                status: 'completed',
                result: data.answer,
                references: data.reference ?? []
              }));
              break;
            }
          } else {
            // Error event from the backend
            setExecutionState(prev => ({
              ...prev,
              status: 'error',
              error: eventData.message
            }));
            break;
          }
        }
      }

    } catch (error) {
      setExecutionState(prev => ({
        ...prev,
        status: 'error',
        error: error instanceof Error ? error.message : String(error)
      }));
    }
  };

  return { sendMessage, executionState };
};

function calculateProgress(path: string[][], totalComponents: number): number {
  if (!path || path.length === 0 || totalComponents <= 0) return 0;

  // Share of components recorded on the current execution path
  const currentPath = path[path.length - 1];
  return Math.min((currentPath.length / totalComponents) * 100, 100);
}
```
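On the client, `EventSourceParserStream` turns the text stream back into events. The following simplified Python sketch shows that parsing step, which can be useful for testing the endpoint from a script. It handles only the `data:` and `event:` fields, skips comment lines, and joins multi-line data with `\n`. An event not followed by a blank line is not dispatched. `parse_sse` is hypothetical.

```python
from typing import Iterable, Iterator


def parse_sse(lines: Iterable[str]) -> Iterator[dict]:
    """Yield {"event": ..., "data": ...} for each blank-line-terminated event."""
    event, data = "message", []
    for raw in lines:
        line = raw.rstrip("\r\n")
        if line == "":
            if data:  # a blank line dispatches the buffered event
                yield {"event": event, "data": "\n".join(data)}
            event, data = "message", []
            continue
        if line.startswith(":"):
            continue  # comment line, often used as a keep-alive
        field, _, value = line.partition(":")
        if value.startswith(" "):
            value = value[1:]  # strip one leading space after the colon
        if field == "data":
            data.append(value)
        elif field == "event":
            event = value
```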

9. 性能优化机制

9.1 并行执行优化

python
import asyncio
from concurrent.futures import ThreadPoolExecutor

class ParallelExecutor:
    def __init__(self, canvas, max_workers=4):
        self.canvas = canvas
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
    
    async def execute_parallel_components(self, component_ids):
        """
        Run mutually independent components concurrently.
        """
        # Schedule one task per component; they start running immediately
        tasks = []
        for component_id in component_ids:
            task = asyncio.create_task(
                self._execute_component_async(component_id)
            )
            tasks.append((component_id, task))
        
        # Wait for every task; a failure is recorded for that component only
        results = {}
        for component_id, task in tasks:
            try:
                result = await task
                results[component_id] = result
            except Exception as e:
                results[component_id] = {"error": str(e)}
        
        return results
    
    async def _execute_component_async(self, component_id):
        """
        Run one synchronous component without blocking the event loop.
        """
        loop = asyncio.get_running_loop()
        
        # Component logic is synchronous, so offload it to the thread pool
        result = await loop.run_in_executor(
            self.executor,
            self.canvas._execute_component,
            component_id
        )
        
        return result

def find_parallel_groups(dependency_graph):
    """
    Find groups of components that can run in parallel.

    dependency_graph maps component_id -> {"dependencies": [upstream ids]}.
    Every component in a group depends only on components in earlier groups.
    """
    parallel_groups = []
    processed = set()
    remaining = set(dependency_graph)
    
    # Peel off dependency levels one at a time (Kahn-style layering)
    while remaining:
        level_components = sorted(
            component_id
            for component_id in remaining
            if all(dep in processed for dep in dependency_graph[component_id]["dependencies"])
        )
        
        if not level_components:
            # No component is ready: a cycle or an unknown dependency blocks progress
            raise ValueError(f"Cannot schedule components: {sorted(remaining)}")
        
        parallel_groups.append(level_components)
        processed.update(level_components)
        remaining.difference_update(level_components)
    
    return parallel_groups
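Given levels such as those returned by `find_parallel_groups`, the scheduling loop can be sketched as follows. Levels run strictly in order, while the components inside one level run concurrently in worker threads. The sketch uses `asyncio.to_thread` (Python 3.9+) in place of the explicit `ThreadPoolExecutor` above. `run_levels` and `fake_component` are hypothetical names, and the per-component error dictionary mirrors `ParallelExecutor`.

```python
import asyncio
import time
from typing import Callable


async def run_levels(levels: list, run: Callable[[str], object]) -> dict:
    """Run dependency levels in order; components inside one level run concurrently.

    `run` is a blocking per-component function (a hypothetical stand-in for the
    canvas's component execution). A failure is recorded for that component
    instead of cancelling the rest of its level.
    """
    results: dict = {}
    for level in levels:
        # to_thread keeps blocking calls off the event loop; gather keeps input order
        outputs = await asyncio.gather(
            *(asyncio.to_thread(run, cid) for cid in level),
            return_exceptions=True,
        )
        for cid, out in zip(level, outputs):
            results[cid] = {"error": str(out)} if isinstance(out, Exception) else out
    return results


def fake_component(cid: str) -> str:
    time.sleep(0.01)  # simulate blocking I/O such as an LLM or retrieval call
    if cid == "broken":
        raise RuntimeError("boom")
    return f"{cid}:ok"


levels = [["begin"], ["retrieval_a", "retrieval_b", "broken"], ["generate"]]
print(asyncio.run(run_levels(levels, fake_component)))
```

This sketch still runs `generate` after `broken` fails. A real engine would decide whether to skip downstream components of a failed node.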

9.2 Result Caching

python
import hashlib
import json
import time
from functools import wraps

class ComponentCache:
    def __init__(self, max_size=1000, ttl=3600):
        self.cache = {}
        self.max_size = max_size
        self.ttl = ttl
        self.access_times = {}
    
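    def get_cache_key(self, component_id, params, input_data):
        """
        Build a deterministic cache key from the component ID, parameters and input.
        """
        cache_data = {
            "component_id": component_id,
            "params": params,
            # Assumes str(input_data) captures the full input; objects whose
            # str() truncates (e.g. large DataFrames) could cause key collisions
            "input_hash": hashlib.md5(str(input_data).encode()).hexdigest()
        }
        
        # sort_keys makes the key independent of dict ordering; default=str
        # keeps non-JSON-serializable parameter values from raising
        cache_str = json.dumps(cache_data, sort_keys=True, default=str)
        return hashlib.sha256(cache_str.encode()).hexdigest()
    
    def get(self, cache_key):
        """
        Return the cached value, or None if it is missing or expired.
        """
        if cache_key in self.cache:
            entry = self.cache[cache_key]
            
            # Check the TTL
            if time.time() - entry["timestamp"] < self.ttl:
                self.access_times[cache_key] = time.time()
                return entry["data"]
            else:
                # Expired: remove the entry
                del self.cache[cache_key]
                self.access_times.pop(cache_key, None)
        
        return None
    
    def set(self, cache_key, data):
        """
        Store a value, evicting the least recently used entry when full.
        """
        # Enforce the size limit (overwriting an existing key needs no eviction)
        if cache_key not in self.cache and len(self.cache) >= self.max_size:
            self._evict_lru()
        
        self.cache[cache_key] = {
            "data": data,
            "timestamp": time.time()
        }
        self.access_times[cache_key] = time.time()
    
    def _evict_lru(self):
        """
        LRU eviction: drop the entry with the oldest access time.
        """
        if not self.access_times:
            return
        
        # Linear scan for the least recently accessed key
        lru_key = min(self.access_times.items(), key=lambda x: x[1])[0]
        
        self.cache.pop(lru_key, None)
        del self.access_times[lru_key]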

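# Decorator that caches a component's result
# (only appropriate for deterministic components)
def cache_component_result(cache_instance):
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            # Build the cache key from the component's ID, parameters and input
            cache_key = cache_instance.get_cache_key(
                self._id,
                self._param.to_dict(),
                self.get_input()
            )
            
            # Return the cached result if present
            # (a cached None is indistinguishable from a miss)
            cached_result = cache_instance.get(cache_key)
            if cached_result is not None:
                return cached_result
            
            # Cache miss: run the component logic
            result = func(self, *args, **kwargs)
            
            # Store the result for later calls
            cache_instance.set(cache_key, result)
            
            return result
        return wrapper
    return decorator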
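`ComponentCache._evict_lru` finds the victim with a linear `min()` scan over `access_times`. A common alternative, swapped in here and not taken from RAGFlow, keeps entries in a `collections.OrderedDict` so that recency updates and eviction are O(1). The class name `LRUTTLCache` and the injectable `clock` parameter are hypothetical. The clock defaults to `time.monotonic` so that wall-clock adjustments do not expire entries early.

```python
import time
from collections import OrderedDict
from typing import Any, Callable, Optional


class LRUTTLCache:
    """LRU cache with a per-entry TTL; OrderedDict order tracks recency."""

    def __init__(self, max_size: int = 1000, ttl: float = 3600.0,
                 clock: Callable[[], float] = time.monotonic):
        self.max_size = max_size
        self.ttl = ttl
        self.clock = clock
        self._data: "OrderedDict[str, tuple]" = OrderedDict()

    def get(self, key: str) -> Optional[Any]:
        entry = self._data.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if self.clock() - stored_at >= self.ttl:
            del self._data[key]          # expired: drop it
            return None
        self._data.move_to_end(key)      # mark as most recently used
        return value

    def set(self, key: str, value: Any) -> None:
        if key in self._data:
            self._data.move_to_end(key)  # overwriting also counts as a use
        self._data[key] = (self.clock(), value)
        while len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict the least recently used entry
```

Usage: after `cache = LRUTTLCache(max_size=2)`, calling `cache.set("a", 1)`, `cache.set("b", 2)`, `cache.get("a")` and then `cache.set("c", 3)` evicts `"b"`, because `"a"` was touched more recently.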

9.3 Resource Management

python
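import asyncio
import gc

import aiohttp
import psutil


class ResourceManager:
    def __init__(self):
        self.connection_pools = {}
        self.memory_monitor = MemoryMonitor()
        self.execution_limiter = asyncio.Semaphore(10)  # at most 10 components execute at once
    
    async def execute_with_resource_control(self, component_id, execute_func):
        """
        Execute a component under concurrency and memory limits.
        """
        async with self.execution_limiter:
            # Reclaim memory first if usage is above 80%
            if self.memory_monitor.get_usage() > 0.8:
                await self._cleanup_memory()
            
            # Run the component
            try:
                return await execute_func()
            finally:
                # Release per-component temporary resources
                # (_cleanup_component_resources is assumed to be defined elsewhere)
                self._cleanup_component_resources(component_id)
    
    def get_connection_pool(self, service_type, config):
        """
        Return a shared connection pool, creating it on first use.
        """
        pool_key = f"{service_type}_{hash(str(config))}"
        
        if pool_key not in self.connection_pools:
            if service_type == "http":
                # aiohttp expects ClientSession to be created while an event loop is running
                self.connection_pools[pool_key] = aiohttp.ClientSession(
                    connector=aiohttp.TCPConnector(limit=100),
                    timeout=aiohttp.ClientTimeout(total=30)
                )
            elif service_type == "database":
                # create_db_pool is a placeholder for a driver-specific pool factory
                self.connection_pools[pool_key] = create_db_pool(config)
            else:
                raise ValueError(f"Unsupported service type: {service_type}")
        
        return self.connection_pools[pool_key]
    
    async def _cleanup_memory(self):
        """
        Try to reclaim memory.
        """
        # Drop expired cache entries (cache_manager is assumed to be a global cache object)
        cache_manager.cleanup_expired()
        
        # Force a garbage-collection pass
        gc.collect()
        
        # Pause briefly to let the system settle
        await asyncio.sleep(0.1)


class MemoryMonitor:
    def get_usage(self):
        """
        Return system-wide memory usage as a fraction between 0 and 1.
        """
        return psutil.virtual_memory().percent / 100.0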
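`ResourceManager` caps concurrent component executions with `asyncio.Semaphore(10)`. The isolated sketch below shows the same pattern with a counter that records the peak number of jobs in flight, so the cap can be observed directly. `run_bounded` and `make_job` are hypothetical names used only for illustration.

```python
import asyncio


async def run_bounded(jobs, limit: int):
    """Run coroutine factories with at most `limit` in flight.

    Returns the results (in input order) and the peak concurrency observed.
    """
    sem = asyncio.Semaphore(limit)
    in_flight = 0
    peak = 0

    async def guarded(job):
        nonlocal in_flight, peak
        async with sem:  # waits here once `limit` jobs hold the semaphore
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                return await job()
            finally:
                in_flight -= 1

    results = await asyncio.gather(*(guarded(j) for j in jobs))
    return results, peak


def make_job(i):
    async def job():
        await asyncio.sleep(0.01)  # stand-in for a component's I/O
        return i * i
    return job


results, peak = asyncio.run(run_bounded([make_job(i) for i in range(8)], limit=3))
print(results, peak)
```

All eight jobs are submitted at once, but only three run at a time. The other five wait on the semaphore rather than piling extra load onto downstream services.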

Summary

RAGFlow's workflow execution mechanism is an execution engine built around the following characteristics:

Core features

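  1. Declarative DSL: the visual canvas in the frontend is serialized into an executable JSON DSL
  2. Dependency-driven: components are scheduled according to their dependency relationships
  3. Real-time feedback: SSE streaming reports execution status as it happens
  4. Error recovery: multi-level error handling and resumable execution
  5. Performance: parallel execution, result caching, and resource management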

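Execution flow at a glance

User trigger → DSL parsing → dependency analysis → component scheduling → parallel execution → status push → result aggregation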
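The flow above can be traced end to end with a small sequential sketch. It reads the `components` → `obj` → `component_name` layout shown in Section 2.1. The `downstream` edge list, the event dictionaries and the `run` callback are assumptions for illustration, not RAGFlow's actual Canvas API. Components that become ready together run one after another here; a real engine could run them in parallel, as in Section 9.1.

```python
from typing import Callable, Iterator


def execute_dsl(dsl: dict, run: Callable[[str, str, list], str]) -> Iterator[dict]:
    """Walk a tiny DSL in dependency order, yielding status events and a final answer."""
    components = dsl["components"]
    # Dependency analysis: invert the (assumed) "downstream" edges into upstream sets
    upstream = {cid: set() for cid in components}
    for cid, comp in components.items():
        for nxt in comp.get("downstream", []):
            upstream[nxt].add(cid)

    outputs: dict = {}
    pending = set(components)
    while pending:
        # Scheduling: a component is ready once every upstream output exists
        ready = sorted(c for c in pending if upstream[c].issubset(outputs))
        if not ready:
            raise ValueError("cycle in DSL: no component is ready")
        for cid in ready:
            # Status push: announce the component before running it
            yield {"component_id": cid, "running_status": True}
            inputs = [outputs[u] for u in sorted(upstream[cid])]
            outputs[cid] = run(cid, components[cid]["obj"]["component_name"], inputs)
            pending.discard(cid)

    # Aggregation: components without downstream edges form the final answer
    sinks = sorted(c for c in components if not components[c].get("downstream"))
    yield {"running_status": False, "answer": "\n".join(outputs[c] for c in sinks)}


dsl = {"components": {
    "begin_001": {"obj": {"component_name": "Begin", "params": {}}, "downstream": ["retrieval_001"]},
    "retrieval_001": {"obj": {"component_name": "Retrieval", "params": {}}, "downstream": ["generate_001"]},
    "generate_001": {"obj": {"component_name": "Generate", "params": {}}, "downstream": []},
}}


def fake_run(cid, name, inputs):
    # Hypothetical component body: record the call chain instead of doing real work
    return f"{name}({', '.join(inputs)})"


for event in execute_dsl(dsl, fake_run):
    print(event)
```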

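Key advantages

  • Extensibility: the component-based architecture lets new capabilities be added as new components
  • Reliability: error handling and recovery at multiple levels
  • Performance: parallel execution of independent components plus result caching
  • Usability: real-time status feedback and clear error messages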

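This execution engine gives RAGFlow flexible workflow orchestration, covering scenarios from simple document retrieval to complex multimodal AI applications.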