# Read a trained neural net and make predictions.
# See the JooneTools API documentation for details.
# This example is a rewrite of the Java example
# which is located in samples/engine/helpers
# of the Joone package: http://www.jooneworld.com/
# Note: only joone-engine.jar is used.
# S.Chekanov. jHepWork
#
from  java.io import *
from  org.joone.engine import *
from  org.joone.helpers.factory import JooneTools
from  org.joone.io import *;
from  org.joone.net import *
from  org.joone.util import NormalizerPlugIn


# Directory holding every file used by this example (ends with a separator,
# so file names can be appended directly).
filePath = SystemDir + fSep + "macros" + fSep + "examples" + fSep + "neural_net" + fSep

# Files inside filePath:
#   fileName -- input data to run through the network (predictions input)
#   logName  -- log file that receives the predictions
#   nnOutput -- input file holding the trained NN
fileName = filePath + "wine_forecast.txt"
logName = filePath + "wine_forecast.log"
nnOutput = filePath + "nn_wine.snet"


# Output stream to the log file; it receives every message and prediction
# written by the rest of this script.
sout = PrintStream(FileOutputStream(logName))

# Read columns 1-14 of the forecast file: column 1 is the target value,
# columns 2-14 are the network inputs (see the selectors below).
fileIn = FileInputSynapse()
fileIn.setInputFile(File(fileName))
fileIn.setAdvancedColumnSelector("1-14")

# Input data (columns 2-14) normalized between -1 and 1.
normIn = NormalizerPlugIn()
normIn.setAdvancedSerieSelector("2-14")
normIn.setMin(-1)
normIn.setMax(1)
fileIn.addPlugIn(normIn)

# Target data (column 1) normalized between 0 and 1. The range is set
# explicitly so the code matches this stated intent instead of silently
# relying on the plug-in's default range.
normOut = NormalizerPlugIn()
normOut.setAdvancedSerieSelector("1")
normOut.setMin(0)
normOut.setMax(1)
fileIn.addPlugIn(normOut)

# Extract the test data (24 rows). Arguments to getDataFromStream are:
# stream, first row, last row, first column, last column.
firstRow = 1
lastRow = 24
inputTest = JooneTools.getDataFromStream(fileIn, firstRow, lastRow, 2, 14)
desiredTest = JooneTools.getDataFromStream(fileIn, firstRow, lastRow, 1, 1)


# Now assume that you have data of the same type that was used for training;
# read the trained network from the file.
mess = "read the trained NN from the file nn_wine.snet "
print(mess)
sout.print(mess)

try:
    nnet = JooneTools.load(nnOutput)
    sout.print("\n\nNN from the file=\n" + nnOutput + "\n")

    # Keep writing to the already-open 'sout'. Opening logName a second time
    # would truncate the log (losing the lines above) and leak the first
    # stream.

    # Gets and prints the final training values stored in the network.
    attrib = nnet.getDescriptor()
    out_text1 = "\nLast training rmse=" + str(attrib.getTrainingError())
    out_text1 = out_text1 + " at epoch " + str(attrib.getLastEpoch())
    print(out_text1)
    sout.print(out_text1 + "\n")

    # Now compare outputs. This code assumes each row of 'out' holds the
    # network outputs followed by the desired values (equal halves), so only
    # the first half -- the forecast -- is written to the log.
    out = JooneTools.compare(nnet, inputTest, desiredTest)
    sout.print("\nPredictions  for " + str(len(out)) + " rows:\n")
    if len(out) > 0:
        # '//' keeps integer division explicit (plain '/' only works on Py2).
        cols = len(out[0]) // 2
        for row in out:
            sout.print('\n NN Forecast: ')
            for x in range(cols):
                if x > 0:
                    # Separate values so multi-output forecasts stay readable.
                    sout.print(' ')
                sout.print(row[x])
finally:
    # Flush and release the log file before it is handed to the viewer,
    # even if loading or comparing the network failed.
    sout.close()

print("Log file is " + logName)
view.open(logName, 0)

# jHepWork @S.Chekanov