Plot milepost along coastline in python - geopandas

I want to plot mileposts on a map every 100 miles along the coastline. An example is shown in the figure below:
It is easy to plot the coastlines using Cartopy, but how can I determine the locations of these mileposts and show them on a map? Ideally, the coastline segments between the points would also be colored.

You can use Shapely's interpolate method to find the locations of the mileposts. The tricky part is getting a single-part LineString for the coastline. I experimented with a few coastline shapefiles that I downloaded, but because of their complexity and length, getting a clean single-part LineString was not simple. Therefore, I digitized my own for this example (digitize.shp).
import geopandas as gpd
import numpy as np
import matplotlib
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
# Hand-digitized, simplified version of the coastline (a single-part LineString).
cl_gdf = gpd.read_file("digitize.shp")
# Project from the geographic CRS so that the interpolation distances below are
# in meters rather than degrees.
cl_gdf = cl_gdf.to_crs(9311)
# Shapely geometry of the first (only) feature; interpolate() needs a LineString.
coastline = cl_gdf.iloc[0].geometry
interval = 160934  # approx 100 miles in meters
# Distances along the line for the mileposts: 0, 100 mi, 200 mi, ... (< length).
interval_arr = np.arange(0, coastline.length, interval)
points = [coastline.interpolate(d) for d in interval_arr]
# Also mark the end of the line. interpolate(length) returns the end point;
# indexing coastline.boundary fails on Shapely >= 2.0 (multi-part geometries are
# not sequences) and on closed lines (whose boundary is empty).
points.append(coastline.interpolate(coastline.length))
# Create a GeoDataFrame from the list of interpolated Shapely points.
points_gdf = gpd.GeoDataFrame(geometry=points, crs=9311)
# Transform to WGS84 (lon/lat degrees) for plotting.
points_gdf = points_gdf.to_crs(4326)
# Lat/Long columns straight from the point geometries (x = longitude, y = latitude).
points_gdf['Lat'] = points_gdf.geometry.y
points_gdf['Long'] = points_gdf.geometry.x
ax = plt.axes(projection=ccrs.PlateCarree())
ax.coastlines()
# Zoom to the mileposts with a 1-degree margin on every side.
ax.set_extent([points_gdf['Long'].min() - 1,
               points_gdf['Long'].max() + 1,
               points_gdf['Lat'].min() - 1,
               points_gdf['Lat'].max() + 1],
              crs=ccrs.PlateCarree())
# Add the interpolated points on top of the Cartopy coastline; the data are
# lon/lat, so declare that explicitly.
ax.scatter(points_gdf['Long'], points_gdf['Lat'], color='red', s=10,
           transform=ccrs.PlateCarree())
fig = plt.gcf()
fig.set_size_inches(18.5, 10.5)
plt.show()
EDIT:
Here is a replacement for "digitize.shp" that makes the example fully runnable: use it in place of the gpd.read_file("digitize.shp") line. You can use GeoDataFrame.to_file() if you want to save cl_gdf to a shapefile.
from shapely.geometry import LineString
# Hand-digitized, simplified coastline as (longitude, latitude) vertices in
# EPSG:4326 (WGS84); this stands in for the "digitize.shp" file.
ls_coords = [(-84.35698841221803, 29.95854398344339),
(-84.05368623055452, 30.09279249008134),
(-83.96252983715839, 30.0613020996354),
(-83.7586709937452, 29.908822314318225),
(-83.44376708928581, 29.68673219222582),
(-83.40854757365547, 29.657727885236138),
(-83.16946921461201, 29.28191493609843),
(-82.7551219719023, 28.99684403311415),
(-82.66893774541866, 28.698514018363166),
(-82.68219685718537, 28.439961338912305),
(-82.79489930720241, 28.158205213869703),
(-82.63910474394356, 27.94937420354401),
(-82.54629096157659, 27.641099854967987),
(-82.63578996600191, 27.392491509342157),
(-82.26122005859231, 26.72290636512327),
(-81.52533935553987, 25.88095276793714),
(-80.85243943337927, 25.168275510476434),
(-80.7596256510123, 25.14175728694301),
(-80.65355275687861, 25.150044231797207),
(-80.59222936495757, 25.176562455330625),
(-80.49278602670725, 25.206395456805726),
(-80.65355275687864, 24.90143588617138),
(-80.53587813994908, 25.03236961486765),
(-80.23920551416893, 25.340643963443675),
(-80.27069590461487, 25.353903075210386),
(-80.37345402080688, 25.312468350939415),
(-80.104957007531, 25.963822216479077),
(-80.07180922811422, 26.91847826368225),
(-80.66846925761622, 28.60238545805451),
(-80.72150570468307, 28.86756769338872),
(-81.19883372828465, 29.663114399391368),
(-81.43749774008545, 30.392365546560455),
(-81.47727507538558, 31.095098470196124),
(-81.26512928711821, 31.406687596713834),
(-81.09276083415097, 31.844238285015287),
(-80.7612830399832, 32.175716079183054),
(-80.48947124876561, 32.48067564981741),
(-80.07180922811422, 32.61326676748452),
(-79.32929896917841, 33.07070612343603),
(-79.20996696327803, 33.22981546463656),
(-78.9779325073606, 33.61432970587117),
(-78.55364093082585, 33.8331050500219),
(-77.93046267779044, 33.919289276505516),
(-77.71831688952308, 34.28391485009006),
(-77.45976421007221, 34.469542414824005),
(-75.76259790393324, 35.20542311787645),
(-75.76259790393324, 36.21974516802982),
(-75.84215257453349, 36.58437074161438),
(-76.17031559075959, 36.939051981373886),
(-76.29959193048501, 36.96722759387815),
(-76.2946197635725, 37.12136476816616),
(-76.41726654741457, 37.16942904832049),
(-76.50179338492735, 37.23738199612488),
(-76.3377118768143, 37.5108511763133),
(-76.72885567393226, 37.79923685723926),
(-76.76200345334904, 37.878791527839525),
(-76.31119365328087, 37.68653440722222),
(-76.23163898268061, 37.89205063960623),
(-76.35428576652268, 37.951716642556434),
(-76.38411876799778, 37.95668880946896),
(-76.50676555183985, 38.02298436830251),
(-76.52665421948991, 38.05613214771928),
(-76.6028941121485, 38.10253903890277),
(-76.60952366803185, 38.12905726243619),
(-76.85978940262852, 38.16717720876549),
(-76.90785368278284, 38.19203804332807),
(-76.94265885117046, 38.2102693220073),
(-77.01724135485821, 38.313027438199306),
(-77.01724135485821, 38.31799960511182),
(-77.26916447842571, 38.3428604396744),
(-77.30396964681333, 38.38263777497453),
(-77.29568270195914, 38.52517322646668),
(-77.23270192106726, 38.60804267500862),
(-77.21612803135888, 38.63953306545456),
(-77.17966547400042, 38.613014841921135),
(-76.96254751882053, 38.4108133874788),
(-76.81338251144503, 38.27987965878253),
(-76.41063699153119, 38.311370049228465),
(-76.53825594228582, 38.710800791200626),
(-76.53494116434413, 39.2047027045106),
(-76.0808165863343, 39.532865720736694),
(-76.02446536132577, 39.370441601594486),
(-76.06755747456758, 39.254424373635764),
(-76.14048258928449, 39.101944588318595),
(-76.1835747025263, 38.956094358884776),
(-76.196833814293, 38.88316924416787),
(-76.18688948046797, 38.76383723826747),
(-76.17031559075959, 38.64782001030875),
(-76.094075698101, 38.51854367058332),
(-75.98468802602564, 38.33291610584937),
(-75.89850379954201, 38.24341710142407),
(-75.84546735247517, 38.190380654357234),
(-75.77917179364162, 38.10419642787361),
(-75.69961712304135, 38.011382645506636),
(-75.65652500979955, 37.95171664255644),
(-75.66978412156625, 37.91525408519799),
(-75.92502202307544, 37.54068417778841),
(-75.76591268187491, 37.481018174838205),
(-75.67972845539128, 37.46112950718814),
(-75.62337723038277, 37.48433295277989),
(-75.58359989508264, 37.59703540279693),
(-75.2189743214981, 38.051159980806766),
(-75.05655020235588, 38.37600821909118),
(-75.03666153470581, 38.44893333380809),
(-75.39465755240701, 39.1450367015604),
(-75.50404522448237, 39.39033026924455),
(-75.54713733772417, 39.536180498678355),
(-75.5438225597825, 39.60413344648275),
(-75.50901739139488, 39.57595783397849),
(-75.51896172521991, 39.48977360749487),
(-75.50073044654069, 39.45331105013641),
(-75.45100877741551, 39.41684849277796),
(-75.3996297193195, 39.38701549130286),
(-75.33333416048596, 39.34889554497357),
(-75.2753255465066, 39.31409037658595),
(-75.17588220825627, 39.27431304128582),
(-75.07146670309342, 39.23453570598569),
(-74.95710686410554, 39.19144359274388),
(-74.8957834721845, 39.169897536122974),
(-74.53778745448334, 39.270998263344154),
(-74.29249388679919, 39.46325538396146),
(-74.10023676618188, 39.89417651637956),
(-74.02731165146497, 40.046656301696736),
(-74.13338454559866, 40.32509764879766),
(-74.23945743973235, 40.4974661017649),
(-74.07371854264846, 40.550502548831744),
(-73.80853630731424, 40.56376166059845),
(-73.60302007493023, 40.550502548831744),
(-73.4439107337297, 40.61016855178194),
(-73.39750384254621, 40.70298233414891),
(-73.41739251019628, 40.88198034299951),
(-73.42070728813798, 40.935016790066356),
(-73.62953829846366, 40.8985542327079),
(-73.77207374995581, 40.83225867387435),
(-73.7588146381891, 40.931702012124674),
(-73.53672451609668, 41.031145350375006),
(-73.15552505280375, 41.1504773562754),
(-72.96658271012812, 41.17368080186715),
(-71.70033753640726, 41.35930836660109),
(-71.04401150395508, 41.50515859603491),
(-70.70590415390394, 41.65100882546872),
(-70.68601548625388, 41.670897493118794),
(-70.67938593037053, 41.51178815191826),
(-70.61309037153697, 41.53830637545168),
(-70.5070174774033, 41.76371127548577),
(-70.54016525682006, 41.86978416961945),
(-70.6528677068371, 42.00237528728656),
(-70.77882926862085, 42.247668854970726),
(-70.93130905393804, 42.41340775205461),
(-70.91142038628796, 42.5924057609052),
(-70.86501349510448, 42.65870131973875),
(-70.54679481270341, 43.32165690807428),
(-70.19542835088558, 43.63987559047534),
(-69.74461855081742, 43.81887359932593),
(-69.53910231843341, 43.885169158159485),
(-68.96233095658148, 44.289572067044155),
(-68.81648072714766, 44.44868140824468),
(-68.41207781826299, 44.48845874354482),
(-67.91486112701133, 44.42216318471126),
(-67.61653111226035, 44.52823607884495),
(-67.09942575335863, 44.70723408769554),
(-67.02650063864172, 44.760270534762384)]
# Single-part LineString through the vertices, in list order.
ls = LineString(ls_coords)
# One-row GeoDataFrame in geographic coordinates (EPSG:4326); use it in place of
# gpd.read_file("digitize.shp").
cl_gdf = gpd.GeoDataFrame(geometry=[ls], crs=4326)

Related

How to create multipolygon shapely geometry in one row of geopandas dataframe?

I have a multipolygon geometry which consists of a list of polygon geometries (3 in this example). I want to create a geodataframe and assign it to just one row. However, when I do the following it creates multiple rows.
Any help would be appreciated !
>>> from shapely.geometry import shape
>>> geo={'type': 'MultiPolygon', 'coordinates': [[[[-6.89361451404645, 42.3483702456613], [-6.88669573921854, 42.3512517221675], [-6.87580712586245, 42.3552054810059], [-6.86866686394711, 42.3573490202175], [-6.86634649451855, 42.3578029961007], [-6.86485929944272, 42.3578451095022], [-6.86269565884165, 42.3567199999001], [-6.84977005165036, 42.3574553272632], [-6.85736376939718, 42.3584219957405], [-6.86260312778948, 42.3586101962362], [-6.86540904591301, 42.3592523372132], [-6.86769848173221, 42.3602680395218], [-6.86944163436685, 42.3613226434774], [-6.875639521488, 42.3661014013158], [-6.87837881937712, 42.3659090495874], [-6.88128451646765, 42.3660684347702], [-6.89613000001251, 42.3681444544187], [-6.90629355446821, 42.3611397165696], [-6.91134851142469, 42.3579245448312], [-6.91590536445934, 42.3556212772542], [-6.92234623065452, 42.352847993855], [-6.92358612842103, 42.3477036195428], [-6.9224104785545, 42.3439383624458], [-6.91770201332808, 42.3426902193944], [-6.90512650312747, 42.343936876722], [-6.9031378495764, 42.3439369775367], [-6.90191893841204, 42.3435986892315], [-6.89791212110307, 42.3461065492655], [-6.89361451404645, 42.3483702456613]]], [[[-6.89906785460217, 42.4682161832092], [-6.8985645069621, 42.470555820794], [-6.89858965513706, 42.4739212739318], [-6.89896550934163, 42.4765974195716], [-6.89972491760736, 42.4787657510451], [-6.9009810281459, 42.4802088026149], [-6.90230301564329, 42.4804035982091], [-6.90426571956671, 42.4800515116974], [-6.90827881089143, 42.4788005356882], [-6.91660939687048, 42.4756837812393], [-6.92382490186326, 42.474250220658], [-6.9265398868427, 42.4731320663596], [-6.927085378973, 42.4726705844922], [-6.92751384817931, 42.4720387429714], [-6.92793713400628, 42.4707715563402], [-6.92809711825426, 42.469121060489], [-6.92777864404758, 42.4655184358035], [-6.9265916168156, 42.460554962602], [-6.92509229324126, 42.4566502755129], [-6.92312921129026, 42.4529560798592], [-6.91994977756089, 42.4483393479285], 
[-6.91714736214994, 42.445084760463], [-6.91523950238876, 42.4432753489897], [-6.91432193977259, 42.4426295412105], [-6.91346078340642, 42.4422568680072], [-6.91227481057592, 42.4422099855322], [-6.91121575180791, 42.4427776558529], [-6.91058022044174, 42.4434975209533], [-6.91000109532557, 42.4444905206296], [-6.90674756055746, 42.4519678712067], [-6.90334259105767, 42.4586766049039], [-6.90013178996497, 42.4652742523013], [-6.89906785460217, 42.4682161832092]]], [[[-6.81153692785486, 42.38608709317], [-6.81255364283188, 42.389076878809], [-6.81428465162502, 42.3928882452291], [-6.81468909293668, 42.3942970262505], [-6.81502856739456, 42.3965697891244], [-6.8150767586934, 42.3981913586016], [-6.81488185813205, 42.4007833041592], [-6.81391983219631, 42.405510635491], [-6.81354770683345, 42.4084498797579], [-6.81360710668831, 42.4124126081536], [-6.81406833183345, 42.4154174564801], [-6.81488670719631, 42.4184504208322], [-6.81700648277012, 42.4245123095467], [-6.81829511926674, 42.4294140920437], [-6.81898116245038, 42.4332326515733], [-6.8198464186291, 42.4413041712813], [-6.82043986166016, 42.4449014594344], [-6.82104066966606, 42.4459522550983], [-6.82188536934636, 42.4467141034876], [-6.82339631054119, 42.4475679287968], [-6.82520716870599, 42.4482527587704], [-6.82670080610246, 42.4485898913254], [-6.8278429373007, 42.4485693985685], [-6.83292590445095, 42.4467192778025], [-6.84029070701108, 42.4443130224279], [-6.84419676594343, 42.4433610436994], [-6.84792497293665, 42.442772924522], [-6.85233512738876, 42.4425494530441], [-6.85882703332382, 42.4427959126746], [-6.86207019904241, 42.4431158543985], [-6.86826707567392, 42.4431984565626], [-6.86971582822253, 42.4429462314883], [-6.87071979951853, 42.4423075911722], [-6.87138234339051, 42.4415737757853], [-6.87343079514828, 42.4381019289848], [-6.87540506959181, 42.4356805263046], [-6.87730718261172, 42.4339309548587], [-6.87994811762039, 42.4321507482297], [-6.88482749307854, 42.4293696834051], 
[-6.89288504801627, 42.4252654154878], [-6.90380506518841, 42.4199717157414], [-6.91512371615772, 42.414746875], [-6.90892681128518, 42.4134221875], [-6.90518184926205, 42.4129543241767], [-6.89684841190103, 42.4114517049355], [-6.89382995952878, 42.4112506910577], [-6.89044820655216, 42.4118417747366], [-6.88899405368913, 42.4113504325048], [-6.88535290642418, 42.4093367301526], [-6.87968569673017, 42.4057680005089], [-6.87414202824488, 42.4020377762816], [-6.87059213937194, 42.3993812288458], [-6.86765946068107, 42.3968980089979], [-6.86534399217227, 42.3945881167378], [-6.86414324104532, 42.3931444816688], [-6.86243928318094, 42.3904912241775], [-6.86143446636381, 42.3892932385407], [-6.86020235510347, 42.3881839241729], [-6.85874294939991, 42.387163281074], [-6.85612766376356, 42.385798575055], [-6.85410002195198, 42.3849996101285], [-6.84390772477599, 42.3822068637876], [-6.83932508071557, 42.3805370409279], [-6.83745999207373, 42.3794742338664], [-6.83322271926195, 42.3756476996947], [-6.83216467694021, 42.3748753255289], [-6.82967109419671, 42.373518847268], [-6.82611729148701, 42.3719774181725], [-6.8235117854645, 42.3715185151684], [-6.82153020724814, 42.3716568375518], [-6.81952510157509, 42.3722087310524], [-6.81681502907838, 42.3735879219019], [-6.81433436043164, 42.3754585172813], [-6.81253514636759, 42.377414128838], [-6.81141738688623, 42.3794547565721], [-6.81098108198756, 42.3815804004834], [-6.811094251368, 42.3837911369289], [-6.81153692785486, 42.38608709317]]]]}
>>> geo_shape=shape(geo)
>>> import geopandas as gpd
>>> df = gpd.GeoDataFrame({'id':1,'geometry':geo_shape})
>>> df
id geometry
0 1 POLYGON ((-6.89361 42.34837, -6.88670 42.35125...
1 1 POLYGON ((-6.89907 42.46822, -6.89856 42.47056...
2 1 POLYGON ((-6.81154 42.38609, -6.81255 42.38908...
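I found the solution: wrap both values in lists so that GeoDataFrame builds exactly one row. When the MultiPolygon is passed bare, it is treated as a sequence of its parts, which produces one row per polygon (as shown above). Hope it helps!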
>>> df = gpd.GeoDataFrame({'id':[1],'geometry':[geo_shape]})
>>> df
id geometry
0 1 MULTIPOLYGON (((-6.89361 42.34837, -6.88670 42...

GEKKO: MHE load data of previous cycle

I am developing a model predictive controller (MPC) with a moving horizon estimation (MHE) plugin for a dynamic simulation program.
My problem is that the simulation program executes the Python script at every time step, so a new GEKKO model is created at each step. Is there a way to reload the model and its data files from the previous cycle, for example by giving GEKKO the path to the data?
Best Regards,
Moritz
Try using a Pickle file to store the Gekko model. If the Gekko model archive exists then it is read back into Python.
from os.path import exists
import pickle
import numpy as np
from gekko import GEKKO
import matplotlib.pyplot as plt
if exists('m.pkl'):
    # Subsequent calls: reload the model saved by the previous call and re-solve.
    # NOTE: pickle.load can execute arbitrary code; only load a file this
    # script wrote itself, never one from an untrusted source.
    with open('m.pkl', 'rb') as f:
        m = pickle.load(f)
    m.solve()
else:
    # First call: define the model.
    m = GEKKO()
    m.time = np.linspace(0, 20, 41)
    m.p = m.MV(value=0, lb=0, ub=1)
    m.v = m.CV(value=0)
    m.Equation(5 * m.v.dt() == -m.v + 10 * m.p)
    m.options.IMODE = 6
    m.p.STATUS = 1
    m.p.DCOST = 1e-3
    m.v.STATUS = 1
    m.v.SP = 40
    m.v.TAU = 5
    m.options.CV_TYPE = 2
    m.solve()

# Save the model after every cycle so the next call continues from this one.
# The context managers close the files (bare open() calls leaked the handles).
with open('m.pkl', 'wb') as f:
    pickle.dump(m, f)

plt.figure()
plt.subplot(2, 1, 1)
plt.plot(m.time, m.p.value, 'b-', lw=2)
plt.ylabel('gas')
plt.subplot(2, 1, 2)
plt.plot(m.time, m.v.value, 'r--', lw=2)
plt.ylabel('velocity')
plt.xlabel('time')
plt.show()
Each cycle of the controller, the plot updates with the automatic time-shift of the initial condition.
This is similar to what happens in a loop with a combined MHE and MPC. As long as you include everything in the Pickle file, it should reload on the next cycle.
Here is the example code for MHE and MPC.

exclude one of the hue from seaborn catplot visualization

I want to visualize category counts with seaborn's catplot, but one of the hue levels is not important and does not need to be included in the visualization.
How can I select specific hue levels to show in catplot without changing or removing any values in the column?
You could remove the rows with that value from the dataframe. If the column is Categorical you might also need to change the categories as the legend will still contain all the categories.
Here is an example:
import seaborn as sns
import pandas as pd
tips = sns.load_dataset('tips')
# 'day' is categorical: CategoricalDtype(categories=['Thur', 'Fri', 'Sat', 'Sun'], ordered=False)
weekend_mask = tips['day'].isin(['Sat', 'Sun'])
# Take an independent copy so its categorical column can be reassigned.
tips_weekend = tips.loc[weekend_mask].copy()
# The subset still declares all four categories, which would all show up in the
# legend; drop the ones no longer present (order is kept: ['Sat', 'Sun']).
tips_weekend['day'] = tips_weekend['day'].cat.remove_unused_categories()
# Now: CategoricalDtype(categories=['Sat', 'Sun'], ordered=False)
sns.catplot(data=tips_weekend, x='smoker', y='tip', hue='day')
For the follow-up question, a histplot with multiple='fill' can show the percentage distribution:
import seaborn as sns
import pandas as pd
from matplotlib.ticker import PercentFormatter
tips = sns.load_dataset('tips')
tips_weekend = tips.copy()
# Collapse every day other than Saturday and Sunday into a single 'other' level.
tips_weekend['day'] = tips_weekend['day'].apply(lambda x: x if x in ['Sat', 'Sun'] else 'other')
# Fix a new order, with 'other' first; it is drawn transparent ('none') below.
tips_weekend['day'] = pd.Categorical(tips_weekend['day'], ['other', 'Sat', 'Sun'])
ax = sns.histplot(data=tips_weekend, x='smoker', hue='day', stat='count', multiple='fill',
                  palette=['none', 'turquoise', 'crimson'])
# Remove the first legend entry ('other'). Matplotlib 3.7 renamed
# Legend.legendHandles to Legend.legend_handles and 3.9 removed the old name,
# so prefer the new attribute and fall back for older versions.
seaborn_legend = ax.get_legend()
handles = getattr(seaborn_legend, 'legend_handles', None)
if handles is None:
    handles = seaborn_legend.legendHandles
ax.legend(handles=handles[1:], labels=['Sat', 'Sun'], title='day')
ax.yaxis.set_major_formatter(PercentFormatter(1))
# Add percentages: with multiple='fill' the bar heights are fractions of 1.
# The last container is skipped; presumably it holds the transparent 'other'
# bars -- verify if the hue order changes.
for bar_group in ax.containers[:-1]:
    ax.bar_label(bar_group, label_type='center',
                 labels=[f'{bar.get_height() * 100:.1f} %' for bar in bar_group])

Dimensionality problem in creation of patches

I have a binary image of size 650x650. I want to split it into 50x50 patches, which gives 13 × 13 = 169 patches. I want to check whether every patch contains at least one "ONE" element (a pixel with value 1).
I also need the result as index pairs for every patch.
Here there is an example of 2d:
2d example
So far so good. When I apply the view_as_blocks function from skimage.util.shape, it returns an array of shape (13, 13, 50, 50).
I search for the ones with numpy.where, but I am lost in the dimensions...
here is my code:
def distinguish_patches_for_label(image_path):
    """Binarize an image, split it into 50x50 patches and collect pixel indexes by value.

    The image is read with OpenCV and converted to grayscale. It is thresholded
    at 100 (pixels > 100 become 1, all others 0), resized to 650x650, and split
    with ``view_as_blocks`` into a grid of 50x50 patches.

    Parameters
    ----------
    image_path : str
        Path of the image passed to ``cv2.imread``.

    Returns
    -------
    tuple of (list, list)
        ``(zipped_indexes_normal, zipped_indexes_XDs)``: (row, col) pairs of the
        0-valued and 1-valued pixels, respectively. They are accumulated over
        every patch in patch order (row of patches first). Each pair is given in
        coordinates of the resized 650x650 image, so pairs from different
        patches stay distinguishable.

    Raises
    ------
    FileNotFoundError
        If ``cv2.imread`` cannot read ``image_path`` (it returns ``None``).
    """
    im1 = cv2.imread(image_path)
    if im1 is None:
        raise FileNotFoundError(f"could not read image: {image_path}")
    im2 = cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)
    im3 = (im2 > 100).astype(np.uint8)
    im4 = cv2.resize(im3, (650, 650))
    patches = view_as_blocks(im4, block_shape=(50, 50))
    # Shape is (patch rows, patch cols, patch height, patch width) = (13, 13, 50, 50).
    n_patch_rows, n_patch_cols, patch_h, patch_w = patches.shape

    zipped_indexes_normal = []
    zipped_indexes_XDs = []
    for i in range(n_patch_rows):
        # Columns come from shape[1], so non-square patch grids work too.
        for j in range(n_patch_cols):
            patch = patches[i, j]
            # Shift patch-local indexes by the patch origin.
            row_off, col_off = i * patch_h, j * patch_w
            # Accumulate instead of overwriting, so every patch contributes.
            rows, cols = np.nonzero(patch == 0)
            zipped_indexes_normal.extend(
                zip((rows + row_off).tolist(), (cols + col_off).tolist()))
            rows, cols = np.nonzero(patch == 1)
            zipped_indexes_XDs.extend(
                zip((rows + row_off).tolist(), (cols + col_off).tolist()))
    return zipped_indexes_normal, zipped_indexes_XDs
I'm not exactly sure what you are trying to get as your output, but the below will find the number of 1s in each patch:
import numpy as np
from skimage import io, util
def count_ones_in_patches(image_path, block_shape=(50, 50)):
    """Return the sum of each non-overlapping block of an image.

    For a binary 0/1 image this is the number of ones in each block.

    Parameters
    ----------
    image_path : str
        Path passed to ``skimage.io.imread``. Assumed to be a single-channel
        (2-D) image with values 0/1 -- TODO confirm. A 0/255 image gives sums
        that are 255 times the pixel count, and an RGB image is 3-D, which does
        not match a 2-D ``block_shape``.
    block_shape : tuple of int, optional
        Block size as (rows, cols); defaults to 50x50. Each image dimension must
        be divisible by the matching block dimension, as ``view_as_blocks``
        requires.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(rows // block_rows, cols // block_cols)``; entry
        ``[i, j]`` is the sum of block ``(i, j)``.
    """
    image = io.imread(image_path)
    # view_as_blocks requires the block shape as a tuple.
    patches = util.view_as_blocks(image, tuple(block_shape))
    # patches: (n_block_rows, n_block_cols, block_rows, block_cols); summing the
    # last two axes reduces every block to a single number.
    ones_per_patch = np.sum(patches, axis=(2, 3))
    return ones_per_patch
Using the numpy axis= keyword argument is a very good idea in general. =)

LSTM - LSTM - future value prediction error

After some research, I was able to predict the future value using the LSTM code below. I have also attached the Dmd1ahr.csv file in the github link that I am using.
https://github.com/ukeshchawal/hello-world/blob/master/Dmd1ahr.csv
As you all can see below, 90 data points are training sets and 91st to 100th are future value prediction.
However some of the questions that I still have are:
In order to predict these values I had to originally take more than hundred data sets (here, I have taken 500 data sets) which is not exactly what my primary goal is. Is there a way that given 500 data sets, it will predict the rest 10 or 20 out of sample data points? If yes, will you please write me a sample code where you can just take 500 data points from Dmd1ahr.csv file attached below and it will predict some future values (say 501 to 520) based on those 500 points?
The predictions are way off compared to the ones in your blog posts (which definitely points to parameter tuning — I tried changing the epochs, LSTM layers, activation, and optimizer). What other parameter tuning can I do to make it more robust?
Thank you in advance.
import numpy as np
import matplotlib.pyplot as plt
import pandas
# By tweaking the architecture it could be made more robust
np.random.seed(7)
numOfSamples = 500  # NOTE(review): not used anywhere in this script
lengthTrain = 90
lengthValidation = 100
look_back = 1  # Can be set higher, in my experiments it made performance worse though
transientTime = 90  # Time to "burn in" time series
series = pandas.read_csv('Dmd1ahr.csv')


def generateTrainData(series, i, look_back):
    """Return the look_back + 1 consecutive rows of *series* starting at row *i*.

    Below, the window without its last row becomes the model input and the
    window shifted by one step becomes the target.
    """
    return series[i:look_back + i + 1]


# Overlapping windows: the first lengthTrain start positions for training and
# the next lengthValidation start positions for testing.
trainX = np.stack([generateTrainData(series, i, look_back) for i in range(lengthTrain)])
testX = np.stack([generateTrainData(series, lengthTrain + i, look_back) for i in range(lengthValidation)])
# NOTE(review): these reshapes only work if the CSV has a single value column.
trainX = trainX.reshape((lengthTrain, look_back + 1, 1))
testX = testX.reshape((lengthValidation, look_back + 1, 1))
# Target = window shifted by one step; input = window without its last value.
trainY = trainX[:, 1:, :]
trainX = trainX[:, :-1, :]
testY = testX[:, 1:, :]
testX = testX[:, :-1, :]
############### Build Model ###############
import keras
from keras.models import Model
from keras import layers
from keras import regularizers
inputs = layers.Input(batch_shape=(1, look_back, 1), name="main_input")
inputsAux = layers.Input(batch_shape=(1, look_back, 1), name="aux_input")
# These stacked layers make the actual prediction, i.e. decide if and how much
# the series goes up or down. Each LSTM consumes the previous layer's output;
# applying every layer to `inputs` left the 300/200/100-unit layers
# disconnected, so only the 50-unit layer was used.
# layers.LSTM / layers.TimeDistributed replace the layers.recurrent /
# layers.wrappers module paths, which no longer exist in current Keras.
x = layers.LSTM(300, return_sequences=True, stateful=True)(inputs)
x = layers.LSTM(200, return_sequences=True, stateful=True)(x)
x = layers.LSTM(100, return_sequences=True, stateful=True)(x)
x = layers.LSTM(50, return_sequences=True, stateful=True)(x)
x = layers.TimeDistributed(layers.Dense(1, activation="linear",
                                        kernel_regularizer=regularizers.l2(0.005),
                                        activity_regularizer=regularizers.l1(0.005)))(x)
# Auxiliary input: the current input is fed directly to the output through a
# frozen identity-weight Dense layer. The previous step's value acts as a
# "base", and the network only has to learn whether it goes a little up or down.
auxX = layers.TimeDistributed(layers.Dense(1,
                                           kernel_initializer=keras.initializers.Constant(value=1),
                                           bias_initializer='zeros',
                                           activation="linear", trainable=False
                                           ))(inputsAux)
outputs = layers.add([x, auxX], name="main_output")
model = Model(inputs=[inputs, inputsAux], outputs=outputs)
model.compile(optimizer='adam',
              loss='mean_squared_error',
              metrics=['mean_squared_error'])
# model.summary()  # uncomment to inspect the layer stack
# The aux input is the last value of each input window, shaped (lengthTrain, 1, 1).
model.fit({"main_input": trainX, "aux_input": trainX[:, look_back - 1, :].reshape(lengthTrain, 1, 1)},
          {"main_output": trainY}, epochs=100, batch_size=1, shuffle=False)
############### make predictions ###############
burnedInPredictions = np.zeros(transientTime)
testPredictions = np.zeros(len(testX))
# Burn the series in: predict on the first transientTime test windows, using
# the real data as both inputs.
for i in range(transientTime):
    prediction = model.predict([np.array(testX[i, :, 0].reshape(1, look_back, 1)),
                                np.array(testX[i, look_back - 1, 0].reshape(1, 1, 1))])
    testPredictions[i] = prediction[0, 0, 0]
burnedInPredictions[:] = testPredictions[:transientTime]
# Prediction: no real data is used any more; the network runs on its own
# previous output, which is fed to both inputs.
for i in range(transientTime, len(testX)):
    prediction = model.predict([prediction, prediction])
    testPredictions[i] = prediction[0, 0, 0]
# For plotting: blank out the burn-in part of testPredictions except its last
# value (presumably so the two plotted curves join up).
testPredictions[:np.size(burnedInPredictions) - 1] = np.nan
############### plot results ###############
# Figure 1: the raw test inputs (first time step of each window).
test_inputs = testX[:, 0, 0]
plt.plot(test_inputs)
plt.show()
# Figure 2: burn-in predictions vs. the free-running predictions.
for curve, curve_label in ((burnedInPredictions, "training"), (testPredictions, "prediction")):
    plt.plot(curve, label=curve_label)
plt.legend()
plt.show()

Resources