Home >  > VNPY源码(六)BacktesterEngine回测引擎

VNPY源码(六)BacktestingEngine回测引擎

0

提示:
回测时最好使用脚本:用UI界面回测时,出错后经常没有任何提示,排查起来非常痛苦。

首先看一段使用回测引擎的回测脚本:

from vnpy.app.cta_strategy.backtesting import BacktestingEngine
from vnpy.app.cta_strategy.strategies.boll_channel_strategy import BollChannelStrategy
from datetime import datetime


# Build the engine and configure a daily-bar ("d") backtest of
# 000001.SZSE from 2018-03-23 to 2018-04-23 with zero rate and slippage.
engine = BacktestingEngine()
engine.set_parameters(
    vt_symbol="000001.SZSE",
    interval="d",
    start=datetime(2018, 3, 23),
    end=datetime(2018, 4, 23),
    rate=0,
    slippage=0,
    size=300,
    pricetick=0.2,
    capital=1_000_000,
)

# Attach the strategy class with an empty setting dict.
engine.add_strategy(BollChannelStrategy, {})

# Load history, replay it, then compute, print and plot the results.
engine.load_data()
engine.run_backtesting()
df = engine.calculate_result()
engine.calculate_statistics()
engine.show_chart()

一、首先看一下set_parameters函数。
首先要明确,这里用的是BacktestingEngine,它定义在vnpy/app/cta_strategy/backtesting.py中;而vnpy/app/cta_backtester/engine.py中还有一个名字相近的BacktesterEngine,两者很容易搞混。

这个函数没有什么特别之处,只是把传入的参数赋值给引擎的对应属性。

二、add_strategy
传入策略类和参数字典,实例化一个策略,策略名取策略类的类名。以上面的脚本为例,相当于执行了BollChannelStrategy(engine, "BollChannelStrategy", vt_symbol, setting),其中第一个参数engine就是回测引擎自身(self)。

def add_strategy(self, strategy_class: type, setting: dict):
    """Instantiate *strategy_class* and attach it to this engine.

    The strategy is constructed as
    ``strategy_class(engine, strategy_name, vt_symbol, setting)`` where the
    strategy name is the class's ``__name__`` and the symbol is this
    engine's ``vt_symbol``. The class itself is kept in ``strategy_class``.
    """
    strategy_name = strategy_class.__name__
    self.strategy_class = strategy_class
    self.strategy = strategy_class(self, strategy_name, self.vt_symbol, setting)

三、load_data
通过数据库的ORM取出DbBarData(tick模式下为DbTickData),逐条调用to_bar(或to_tick)方法生成BarData(或TickData),最终汇总到self.history_data中。为了能输出加载进度,数据是按每30天一段分批加载的。

def load_data(self):
    """Load the configured history range into ``self.history_data``.

    ``self.end`` defaults to now when unset. If ``self.start`` is not before
    ``self.end`` an error message is output and nothing is loaded.
    Otherwise previously loaded data is cleared and the range is fetched in
    30-day windows (bars via ``load_bar_data`` in ``BacktestingMode.BAR``
    mode, ticks via ``load_tick_data`` otherwise), outputting a progress
    line after each window.
    """
    self.output("开始加载历史数据")

    if not self.end:
        self.end = datetime.now()

    if self.start >= self.end:
        self.output("起始日期必须小于结束日期")
        return

    self.history_data.clear()       # Clear previously loaded history data

    # Load 30 days of data each time and allow for progress update
    progress_delta = timedelta(days=30)
    total_delta = self.end - self.start

    start = self.start
    end = self.start + progress_delta
    progress = 0

    while start < self.end:
        end = min(end, self.end)  # Make sure end time stays within set range

        if self.mode == BacktestingMode.BAR:
            data = load_bar_data(
                self.symbol,
                self.exchange,
                self.interval,
                start,
                end
            )
        else:
            data = load_tick_data(
                self.symbol,
                self.exchange,
                start,
                end
            )

        # Consecutive windows share an edge: this window's start is the
        # previous window's end. The database query includes both ends
        # (e.g. ``datetime__gte``/``datetime__lte`` in the Mongo backend),
        # so a record stamped exactly on that edge was already loaded by the
        # previous window; skip it so it is not replayed twice.
        # NOTE(review): assumes every backend treats ``end`` inclusively.
        if start != self.start:
            data = [d for d in data if d.datetime != start]

        self.history_data.extend(data)

        progress += progress_delta / total_delta
        progress = min(progress, 1)
        progress_bar = "#" * int(progress * 10)
        self.output(f"加载进度:{progress_bar} [{progress:.0%}]")

        start = end
        end += progress_delta

    self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")

1.load_bar_data

@lru_cache(maxsize=999)
def load_bar_data(
    symbol: str,
    exchange: Exchange,
    interval: Interval,
    start: datetime,
    end: datetime
):
    """Fetch the bars of one symbol/interval/time window from the database.

    Thin wrapper around ``database_manager.load_bar_data``. Results are
    memoised by ``lru_cache`` (up to 999 distinct argument tuples), so
    asking for the same window again does not query the database again.
    """
    bars = database_manager.load_bar_data(symbol, exchange, interval, start, end)
    return bars

它其实只是调用database_manager的load_bar_data方法。查看文件头部可以发现from vnpy.trader.database import database_manager,但在vnpy/trader/database目录下并没有名为database_manager的模块文件,只在__init__.py中找到这样一句:database_manager: "BaseDatabaseManager" = init(settings=settings)。按照官方的说法:

这里的database_manager是在database包内部代码中定义的,它会根据GlobalSetting中的数据库配置创建相应的数据库管理对象,init函数返回的就是这个对象。

然后在vnpy/trader/database/database_mongo.py中找到了下面的代码:

class MongoManager(BaseDatabaseManager):
    """MongoDB implementation of ``BaseDatabaseManager``."""

    def load_bar_data(
        self,
        symbol: str,
        exchange: Exchange,
        interval: Interval,
        start: datetime,
        end: datetime,
    ) -> Sequence[BarData]:
        """Return the bars of ``symbol``/``exchange``/``interval`` with
        ``start <= datetime <= end``, converted from stored ``DbBarData``
        documents via ``to_bar``.
        """
        query = DbBarData.objects(
            symbol=symbol,
            exchange=exchange.value,   # enums are stored by their raw value
            interval=interval.value,
            datetime__gte=start,       # inclusive lower bound
            datetime__lte=end,         # inclusive upper bound
        )
        return [record.to_bar() for record in query]

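这里的DbBarData.objects(...)是MongoEngine(MongoDB的ODM库)的查询语法,datetime__gte和datetime__lte分别表示datetime大于等于start、小于等于end,即查询区间两端都是闭区间。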

2.to_bar函数:

def to_bar(self):
    """
    Convert this DbBarData record into a BarData object.

    The stored raw exchange and interval values are turned back into the
    ``Exchange`` and ``Interval`` enums; ``gateway_name`` is always "DB".
    """
    return BarData(
        symbol=self.symbol,
        exchange=Exchange(self.exchange),
        datetime=self.datetime,
        interval=Interval(self.interval),
        volume=self.volume,
        open_interest=self.open_interest,
        open_price=self.open_price,
        high_price=self.high_price,
        low_price=self.low_price,
        close_price=self.close_price,
        gateway_name="DB",
    )

再看看to_tick函数

    def to_tick(self):
        """
        Convert this DbTickData record into a TickData object.

        Level-1 quote fields are always copied. Depth levels 2-5 (bid/ask
        prices and volumes) are copied only when ``bid_price_2`` is truthy,
        presumably because a falsy value means no depth data was stored.
        """
        tick = TickData(
            symbol=self.symbol,
            exchange=Exchange(self.exchange),
            datetime=self.datetime,
            name=self.name,
            volume=self.volume,
            open_interest=self.open_interest,
            last_price=self.last_price,
            last_volume=self.last_volume,
            limit_up=self.limit_up,
            limit_down=self.limit_down,
            open_price=self.open_price,
            high_price=self.high_price,
            low_price=self.low_price,
            pre_close=self.pre_close,
            bid_price_1=self.bid_price_1,
            ask_price_1=self.ask_price_1,
            bid_volume_1=self.bid_volume_1,
            ask_volume_1=self.ask_volume_1,
            gateway_name="DB",
        )

        if self.bid_price_2:
            # Copy levels 2..5: all bid prices, then ask prices, then bid
            # volumes, then ask volumes.
            for field in ("bid_price", "ask_price", "bid_volume", "ask_volume"):
                for level in range(2, 6):
                    attr = f"{field}_{level}"
                    setattr(tick, attr, getattr(self, attr))

        return tick

四、run_backtesting
这个函数的作用是:先用history_data中前self.days天的数据初始化策略,然后遍历剩余的数据进行回放;回放时每根bar会先撮合限价单、再撮合停止单,最后执行策略的on_bar函数。

def run_backtesting(self):
    """Initialise the strategy on early history, then replay the rest.

    Records from the first ``self.days`` calendar days of
    ``self.history_data`` are passed to ``self.callback`` while the strategy
    initialises. The remaining records are passed to ``self.new_bar``
    (``BacktestingMode.BAR``) or ``self.new_tick`` (otherwise) after the
    strategy has been started. If initialisation consumed all records, a
    message is output and nothing is replayed.
    """
    if self.mode == BacktestingMode.BAR:
        func = self.new_bar
    else:
        func = self.new_tick

    self.strategy.on_init()

    # Use the first [days] of history data for initializing strategy
    day_count = 0
    ix = 0

    for ix, data in enumerate(self.history_data):
        # Compare calendar dates, not day-of-month: two records a whole
        # month apart have the same ``.day`` but are different days.
        if self.datetime and data.datetime.date() != self.datetime.date():
            day_count += 1
            if day_count >= self.days:
                # history_data[ix] is the first record NOT fed to callback.
                break

        self.datetime = data.datetime
        self.callback(data)
    else:
        # No break: every record (if any) was consumed by initialisation.
        # Without this, ix would point at the last record and it would be
        # replayed a second time below.
        ix = len(self.history_data)

    self.strategy.inited = True
    self.output("策略初始化完成")

    self.strategy.on_start()
    self.strategy.trading = True
    self.output("开始回放历史数据")

    # Use the rest of history data for running backtesting
    backtesting_data = self.history_data[ix:]
    if not backtesting_data:
        self.output("历史数据不足,回测终止")
        return

    for data in backtesting_data:
        func(data)

    self.output("历史数据回放结束")

1.history_data
这里的history_data是在执行engine.load_data()后得到的。需要注意,虽然前面写了ix = 0,但ix会在初始化的for循环中被enumerate不断更新:当初始化天数用满时循环break,此时ix指向第一条没有用于初始化的数据。所以history_data[ix:]是初始化之后剩下的数据,并不是全部数据。

2.new_bar函数
先撮合限价单,再撮合停止单,然后执行策略的on_bar函数进行策略判断,最后调用update_daily_close(从函数名看,应是用这根bar的收盘价更新当日的收盘价)。

def new_bar(self, bar: BarData):
    """Advance the backtest by one bar.

    The bar becomes the engine's current bar/time. Pending limit orders and
    then stop orders are matched against it before the strategy's
    ``on_bar`` is called, and finally the bar's close price is passed to
    ``update_daily_close``.
    """
    self.bar, self.datetime = bar, bar.datetime

    # Order matching happens before the strategy sees the bar.
    self.cross_limit_order()
    self.cross_stop_order()

    self.strategy.on_bar(bar)

    self.update_daily_close(bar.close_price)

五、calculate_result
这个函数实现逐日盯市盈亏计算,返回一个DataFrame(同时保存在self.daily_df中);如果没有任何成交记录,则直接返回None。

def calculate_result(self):
    """Compute mark-to-market (daily settled) PnL and return it as a DataFrame.

    Each trade in ``self.trades`` is added to the daily result of its trade
    date. Daily PnL is then calculated in ``self.daily_results`` iteration
    order, carrying each day's close price and end position forward as the
    next day's ``pre_close``/``start_pos``. Every attribute of a daily
    result becomes a column, indexed by ``date``. The frame is stored in
    ``self.daily_df`` and returned; when there are no trades nothing is
    computed and ``None`` is returned.
    """
    self.output("开始计算逐日盯市盈亏")

    if not self.trades:
        self.output("成交记录为空,无法计算")
        return

    # Bucket every trade into the daily result of the day it happened.
    for trade in self.trades.values():
        self.daily_results[trade.datetime.date()].add_trade(trade)

    # Roll the previous close and position forward day by day.
    pre_close, start_pos = 0, 0
    for daily_result in self.daily_results.values():
        daily_result.calculate_pnl(
            pre_close, start_pos, self.size, self.rate, self.slippage
        )
        pre_close, start_pos = daily_result.close_price, daily_result.end_pos

    # One column per daily-result attribute, one row per day.
    columns = defaultdict(list)
    for daily_result in self.daily_results.values():
        for key, value in vars(daily_result).items():
            columns[key].append(value)

    self.daily_df = DataFrame.from_dict(columns).set_index("date")

    self.output("逐日盯市盈亏计算完成")
    return self.daily_df

1.self.trades
这里出现了self.trades,它是一个字典,成交记录是在cross_limit_order、cross_stop_order这两个函数中写入的。这两个函数比较复杂,打算在下一章专门分析。

2.add_trade
这里出现了一个add_trade函数,非常简单,差不多就是一个append的封装。

    def add_trade(self, trade: TradeData):
        """Record *trade* as one of this day's trades (kept in ``self.trades``)."""
        self.trades.extend((trade,))

3.calculate_pnl
当日总盈亏 = 持仓盈亏 + 交易盈亏。持仓盈亏是开盘时已有的仓位按(今收 − 昨收)计算的盈亏;交易盈亏是当日每笔新成交按(今收 − 成交价)计算的盈亏。总盈亏再扣除手续费和滑点,就得到净盈亏。

def calculate_pnl(
    self,
    pre_close: float,
    start_pos: float,
    size: int,
    rate: float,
    slippage: float,
):
    """Compute this day's PnL from the carried-in position and the day's trades.

    holding_pnl: the position held at the start of the day, marked from
        ``pre_close`` to this day's ``close_price``.
    trading_pnl: each of this day's trades, marked from its fill price to
        ``close_price``.
    Per trade, turnover (price * volume * size), commission (turnover * rate)
    and slippage cost (volume * size * slippage) are added onto the existing
    attribute values. ``end_pos`` is ``start_pos`` adjusted by the signed
    trade volumes, and ``net_pnl`` is ``total_pnl - commission - slippage``.
    """
    self.pre_close = pre_close

    # Holding PnL of the position carried in from the previous day.
    self.start_pos = start_pos
    self.end_pos = start_pos
    self.holding_pnl = self.start_pos * (self.close_price - self.pre_close) * size

    # Trading PnL of the trades made during the day.
    self.trade_count = len(self.trades)

    for trade in self.trades:
        signed_volume = trade.volume if trade.direction == Direction.LONG else -trade.volume
        trade_turnover = trade.price * trade.volume * size

        self.trading_pnl += signed_volume * (self.close_price - trade.price) * size
        self.end_pos += signed_volume
        self.turnover += trade_turnover
        self.commission += trade_turnover * rate
        self.slippage += trade.volume * size * slippage

    # Net PnL deducts commission and slippage costs.
    self.total_pnl = self.trading_pnl + self.holding_pnl
    self.net_pnl = self.total_pnl - self.commission - self.slippage

六、calculate_statistics
这个函数的功能是计算策略统计指标,返回一个统计指标的字典:

def calculate_statistics(self, df: DataFrame = None, output=True):
    """Compute summary statistics of the backtest from the daily PnL frame.

    Args:
        df: Daily result frame as produced by ``calculate_result``; it needs
            the ``net_pnl``, ``commission``, ``slippage``, ``turnover`` and
            ``trade_count`` columns. Defaults to ``self.daily_df``. The frame
            is modified in place: ``balance``, ``return``, ``highlevel``,
            ``drawdown`` and ``ddpercent`` columns are added.
        output: When true, the statistics are also written via ``self.output``.

    Returns:
        dict mapping statistic names to values. When there is no data (no
        frame, or an empty one) every statistic is 0 and the dates are "".
    """
    self.output("开始计算策略统计指标")

    # Check DataFrame input exterior
    if df is None:
        df = self.daily_df

    # No trades -> no daily frame (or an empty one): report all zeros.
    if df is None or df.empty:
        start_date = ""
        end_date = ""
        total_days = 0
        profit_days = 0
        loss_days = 0
        end_balance = 0
        max_drawdown = 0
        max_ddpercent = 0
        total_net_pnl = 0
        daily_net_pnl = 0
        total_commission = 0
        daily_commission = 0
        total_slippage = 0
        daily_slippage = 0
        total_turnover = 0
        daily_turnover = 0
        total_trade_count = 0
        daily_trade_count = 0
        total_return = 0
        annual_return = 0
        daily_return = 0
        return_std = 0
        sharpe_ratio = 0
        return_drawdown_ratio = 0
    else:
        # Calculate balance related time series data
        df["balance"] = df["net_pnl"].cumsum() + self.capital
        df["return"] = np.log(df["balance"] / df["balance"].shift(1)).fillna(0)
        # Running maximum of the balance (window spans the whole frame).
        df["highlevel"] = (
            df["balance"].rolling(
                min_periods=1, window=len(df), center=False).max()
        )
        df["drawdown"] = df["balance"] - df["highlevel"]
        df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100

        # Calculate statistics value
        start_date = df.index[0]
        end_date = df.index[-1]

        total_days = len(df)
        profit_days = len(df[df["net_pnl"] > 0])
        loss_days = len(df[df["net_pnl"] < 0])

        end_balance = df["balance"].iloc[-1]
        max_drawdown = df["drawdown"].min()
        max_ddpercent = df["ddpercent"].min()

        total_net_pnl = df["net_pnl"].sum()
        daily_net_pnl = total_net_pnl / total_days

        total_commission = df["commission"].sum()
        daily_commission = total_commission / total_days

        total_slippage = df["slippage"].sum()
        daily_slippage = total_slippage / total_days

        total_turnover = df["turnover"].sum()
        daily_turnover = total_turnover / total_days

        total_trade_count = df["trade_count"].sum()
        daily_trade_count = total_trade_count / total_days

        # Annualised with 240 trading days per year.
        total_return = (end_balance / self.capital - 1) * 100
        annual_return = total_return / total_days * 240
        daily_return = df["return"].mean() * 100
        return_std = df["return"].std() * 100

        if return_std:
            sharpe_ratio = daily_return / return_std * np.sqrt(240)
        else:
            sharpe_ratio = 0

        # A balance that never drew down gives max_ddpercent == 0; dividing
        # by it would produce inf/nan, so report 0 instead.
        if max_ddpercent:
            return_drawdown_ratio = -total_return / max_ddpercent
        else:
            return_drawdown_ratio = 0

    # Output
    if output:
        self.output("-" * 30)
        self.output(f"首个交易日:\t{start_date}")
        self.output(f"最后交易日:\t{end_date}")

        self.output(f"总交易日:\t{total_days}")
        self.output(f"盈利交易日:\t{profit_days}")
        self.output(f"亏损交易日:\t{loss_days}")

        self.output(f"起始资金:\t{self.capital:,.2f}")
        self.output(f"结束资金:\t{end_balance:,.2f}")

        self.output(f"总收益率:\t{total_return:,.2f}%")
        self.output(f"年化收益:\t{annual_return:,.2f}%")
        self.output(f"最大回撤: \t{max_drawdown:,.2f}")
        self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")

        self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
        self.output(f"总手续费:\t{total_commission:,.2f}")
        self.output(f"总滑点:\t{total_slippage:,.2f}")
        self.output(f"总成交金额:\t{total_turnover:,.2f}")
        self.output(f"总成交笔数:\t{total_trade_count}")

        self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
        self.output(f"日均手续费:\t{daily_commission:,.2f}")
        self.output(f"日均滑点:\t{daily_slippage:,.2f}")
        self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
        self.output(f"日均成交笔数:\t{daily_trade_count}")

        self.output(f"日均收益率:\t{daily_return:,.2f}%")
        self.output(f"收益标准差:\t{return_std:,.2f}%")
        self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}")
        self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}")

    statistics = {
        "start_date": start_date,
        "end_date": end_date,
        "total_days": total_days,
        "profit_days": profit_days,
        "loss_days": loss_days,
        "capital": self.capital,
        "end_balance": end_balance,
        "max_drawdown": max_drawdown,
        "max_ddpercent": max_ddpercent,
        "total_net_pnl": total_net_pnl,
        "daily_net_pnl": daily_net_pnl,
        "total_commission": total_commission,
        "daily_commission": daily_commission,
        "total_slippage": total_slippage,
        "daily_slippage": daily_slippage,
        "total_turnover": total_turnover,
        "daily_turnover": daily_turnover,
        "total_trade_count": total_trade_count,
        "daily_trade_count": daily_trade_count,
        "total_return": total_return,
        "annual_return": annual_return,
        "daily_return": daily_return,
        "return_std": return_std,
        "sharpe_ratio": sharpe_ratio,
        "return_drawdown_ratio": return_drawdown_ratio,
    }

    return statistics

七、show_chart
最后是显示图表的函数,依次画出资金曲线(Balance)、回撤(Drawdown)、每日盈亏(Daily Pnl)和每日盈亏分布(Daily Pnl Distribution)四张子图。

def show_chart(self, df: DataFrame = None):
    """Plot balance, drawdown, daily PnL and the daily PnL distribution.

    Uses ``df`` when given, otherwise ``self.daily_df``; returns without
    plotting when neither is available. The frame must already have the
    ``balance`` and ``drawdown`` columns that ``calculate_statistics`` adds.
    """
    data = self.daily_df if df is None else df
    if data is None:
        return

    plt.figure(figsize=(10, 16))

    # Four stacked panels: 4 rows x 1 column, numbered 1..4 from the top.
    balance_ax = plt.subplot(4, 1, 1)
    balance_ax.set_title("Balance")
    data["balance"].plot(legend=True)

    drawdown_ax = plt.subplot(4, 1, 2)
    drawdown_ax.set_title("Drawdown")
    drawdown_ax.fill_between(range(len(data)), data["drawdown"].values)

    pnl_ax = plt.subplot(4, 1, 3)
    pnl_ax.set_title("Daily Pnl")
    data["net_pnl"].plot(kind="bar", legend=False, grid=False, xticks=[])

    hist_ax = plt.subplot(4, 1, 4)
    hist_ax.set_title("Daily Pnl Distribution")
    data["net_pnl"].hist(bins=50)

    plt.show()
本文暂无标签

发表评论

*

*