Home >  > 老男孩量化金融二(matplotlib及双均线)

老男孩量化金融二(matplotlib及双均线)

0

一、图表
1.图表类型

 
import matplotlib.pyplot as plt
 plt.hist()   #频数直方图
 plt.plot()   #线图,传入序列,元组、列表、numpy.ndarray
 plt.pie()
 plt.bar()
 plt.show()
 plt.scatter()

2.画图

fig = plt.figure()  创建一块画布
#将fig分成2*2,1表示是第一个图
ax1 = fig.add_subplot(2,2,1)

二、双均线策略

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

df = pd.read_csv('601318.csv', index_col='date', parse_dates=['date'])

df['ma5'] = np.nan
df['ma10'] = np.nan

# 第一步计算ma
# 循环计算,速度非常慢
#df.loc只能传行或列的名进去,loc左边是行,右边是列

# for i in range(4, len(df)):
#     df.loc[df.index[i], 'ma5'] = df['close'][i-4:i+1].mean()
# for i in range(9, len(df)):
#     df.loc[df.index[i], 'ma10'] = df['close'][i-9:i+1].mean()

# 方案2:cumsum

# close     =  [10, 11, 12, 13, 14, 15, 16]
# close.cumsum=[10, 21, 33, 46, 60, 75, 91]
#                                -   -   -
#               [nan,nan,nan,nan,0,  10, 21, 33, 46, 60, 75, 91]

# sr = df['close'].cumsum()
# df['ma5'] = (sr - sr.shift(1).fillna(0).shift(4))/5
# df['ma10'] = (sr - sr.shift(1).fillna(0).shift(9))/10

# 方案3:rolling

df['ma5'] = df['close'].rolling(5).mean()
df['ma10'] = df['close'].rolling(10).mean()

df = df.dropna()

df[['ma5', 'ma10']].plot()
plt.show()
# 第二部 判断金叉死叉
# 方案一
# 金叉 短期<=长期 短期>长期
# 死叉 短期>=长期 短期<长期
# sr = df['ma5'] <= df['ma10']
#
# golden_cross = []
# death_cross = []
# for i in range(1, len(sr)):
#     # if sr.iloc[i] == True and sr.iloc[i + 1] == False: 开始想的是加1,但是索引溢出
#     if sr.iloc[i - 1] == True and sr.iloc[i] == False:
#         golden_cross.append(sr.index[i])
#     if sr.iloc[i - 1] == False and sr.iloc[i] == True:
#         death_cross.append(sr.index[i])

# 方案2

golden_cross = df[(df['ma5'] <= df['ma10']) & (df['ma5'] > df['ma10']).shift(1)].index
death_cross = df[(df['ma5'] >= df['ma10']) & (df['ma5'] < df['ma10']).shift(1)].index

三、一个简单的回测框架
成果展示:

代码:

import pandas as pd
import matplotlib.pyplot as plt
import tushare
import datetime
import dateutil

'''
获取所有的股票交易日,交易日信息保存在csv文件
'''
try:
    trade_cal = pd.read_csv("trade_cal.csv")
except:
    trade_cal = tushare.trade_cal()
    trade_cal.to_csv("trade_cal.csv")

class Context:
    def __init__(self, cash, start_date, end_date):
        '''
        保存股票信息
        :param cash: 现金量
        :param start_date: 量化策略开始时间
        :param end_date: 量化策略结束时间
        :param positions: 持仓股票和对应的数量
        :param benchmark: 参考股票
        :param date_range: 开始-结束之间的所有交易日
        :param dt:  当前日期 (循环时当前日期会发生变化)
        '''
        self.cash = cash
        self.start_date = start_date
        self.end_date = end_date
        self.positions = {}  # 持仓信息
        self.benchmark = None
        self.date_range = trade_cal[(trade_cal['isOpen']==1)&
                                    (trade_cal['calendarDate']>=start_date)&
                                     (trade_cal['calendarDate']<=end_date)]['calendarDate'].values
        self.dt = None

class G:
    '''
    保存用户的全局参数
    '''
    pass

'''
默认的初始化信息
'''
g = G()
CASH = 100000
START_DATE = '2016-01-07'
END_DATE = '2017-01-31'
context = Context(CASH,START_DATE,END_DATE)


def attribute_history(security,
                      count,
                      field=('open','close','high','low','volume')):
    '''
    获取某股票count天的历史行情,每运行一次该函数,日期范围后移

    :param security: 股票代码
    :param count: 天数
    :param field: 字段
    :return:
    '''
    end_date = (context.dt - datetime.timedelta(days=1)).strftime('%Y-%m-%d')
    start_date = trade_cal[(trade_cal['isOpen']==1)&
                           (trade_cal['calendarDate']<=end_date)][-count:]['calendarDate'].iloc[0]
    return attribute_daterange_history(security,start_date,end_date,field)

def attribute_daterange_history(security,
                                start_date,end_date,
                                field=('open','close','high','low','volume')):
    '''
    底层,获取某股票某一段时间的历史行情
    :param security:
    :param start_date:
    :param end_date:
    :param field:
    :return:
    '''
    df = tushare.get_k_data(security,start_date,end_date)
    df.index = df['date']
    return df[list(field)]


def get_today_data(security):
    '''
    获取context的"当天"的股票信息,停牌返回Null
    :param security:
    :return:
    '''
    try:
        today = context.dt.strftime('%Y-%m-%d')
        df = tushare.get_k_data(security,today,today)
        df.index = df['date']
        data = df.loc[today]
    except KeyError:  # 股票停牌
        data = pd.Series()
    return data


def _order(today_data, security, amount):
    '''
    底层买股票的函数
    :param today_data: "当天"的股票价格OCHL
    :param security: 股票代码
    :param amount: 交易股数,正数为买入,负数为卖出
    :return:
    '''
    p = today_data['open']
    # 找不到该股票默认为0股
    old_amount = context.positions.get(security, 0)

    if len(today_data) == 0:
        print("今日停牌")
        return
    if context.cash - amount * p < 0:
        amount = context.cash // p
        print('%s:现金不足,已调整为%d' %(today_data['date'],amount))
    if amount % 100 != 0:
        # 买或卖不是100的倍数就调整为100的倍数,卖光则不调整
        if amount != -old_amount:
            # 2345 => 2300
            amount = int(amount / 100) * 100
            print('%s:不是100的倍数,已调整为%d' %(today_data['date'],amount))
    if old_amount < -amount:
        amount = -old_amount
        print('%s:卖出股票不能超过持仓数,已调整为%d'%(today_data['date'],amount))

    # 更新持仓信息
    context.positions[security] = old_amount + amount
    # 更新钱
    context.cash -= amount*p
    # 持仓为0就删掉
    if context.positions[security] == 0:
        del context.positions[security]

def order(security, amount):
    # 买入股票。amount为正表示买入,负表示卖出
    today_data = get_today_data(security)
    _order(today_data, security, amount)
def order_target(security, amount):
    # 把股票交易到多少股,不能为负数,比原来小是卖出,比原来大是买入
    if amount < 0:
        print("数量不能为负,已调整为0")
        amount = 0
    today_data = get_today_data(security)
    hold_amount = context.positions.get(security, 0) # TODO: T + 1 closeable total
    delta_amount = amount - hold_amount
    _order(today_data,security,delta_amount)
def order_value(security, value):
    # 买多少钱的股票或者卖多少钱的股票
    today_data = get_today_data(security)
    amount = value / today_data['open']
    _order(today_data,security,amount)
def order_target_value(security, value):
    # 买到或者卖到多少钱
    if value < 0:
        print("价值不能为负,已调整为0")
        value = 0
    today_data = get_today_data(security)
    hold_value = context.positions.get(security,0) * today_data['open']
    dalta_value = value - hold_value
    order_value(security,dalta_value)

def run():
    plt_df = pd.DataFrame(index=pd.to_datetime(context.date_range),
                          columns=['value'])
    # 最初的钱,算收益率用
    init_value = context.cash
    # 保存停牌前一天的股票价格
    last_price = {}
    # 用户接口1
    initialize(context)
    for dt in context.date_range:
        context.dt = dateutil.parser.parse(dt)
        # 用户接口2
        handle_data(context)
        # 股票和现金的总价值
        value = context.cash
        for stock in context.positions:
            # 考虑停牌的情况
            today_data = get_today_data(stock)
            if len(today_data) == 0:
                p = last_price[stock]
            else:
                p = today_data['open']
                last_price[stock] = p
            value += p * context.positions[stock]
        plt_df.loc[dt, 'value'] = value
    plt_df['ratio'] = (plt_df['value']-init_value) / init_value

    bm_df = attribute_daterange_history(context.benchmark,
                                context.start_date,
                                context.end_date)
    bm_init = bm_df['open'][0]
    plt_df['benchmark_raito'] = (bm_df['open']-bm_init) / bm_init
    print(plt_df)
    plt_df[['ratio','benchmark_raito']].plot()
    plt.show()

'''
initialize和handle_data是用户的操作
'''
def initialize(context):
    context.benchmark = '601318'
    g.p1 = 5
    g.p2 = 60
    g.security = '601318'
def handle_data(context):
    hist = attribute_history(g.security, g.p2)
    ma5 = hist['close'][-g.p1:].mean()
    ma60 = hist['close'].mean()

    if ma5 > ma60 and g.security not in context.positions:
        order_value(g.security, context.cash)
    elif ma5 < ma60 and g.security in context.positions:
        order_target(g.security,0)

if __name__ == '__main__':
    run()

相关说明:
tushare.trade_cal() # 获取交易日信息,输出结果为:

calendarDate isOpen
0 1990/12/19 1
1 1990/12/20 1
2 1990/12/21 1
3 1990/12/22 0
4 1990/12/23 0
5 1990/12/24 1
6 1990/12/25 1
7 1990/12/26 1
8 1990/12/27 1
9 1990/12/28 1
10 1990/12/29 0
11 1990/12/30 0
12 1990/12/31 1
13 1991/1/1 0

本文暂无标签

发表评论

*

*