Recurrent Neural Networks for Churn Prediction

I just posted a simple implementation of WTTE-RNNs in Keras on GitHub: Keras Weibull Time-to-event Recurrent Neural Networks. I'll let you read up on the details at the link, but suffice it to say that this is a specific type of neural net that handles time-to-event prediction in a super intuitive way. If you're thinking of building a model to predict (rather than understand) churn, I'd definitely consider giving it a shot. And with Keras, implementing the model is pretty darn easy.
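To give you a flavor, here's roughly what the whole thing looks like. This is a minimal sketch of the approach rather than the exact code in the repo, and the input shape (100 timesteps of 24 features) and layer sizes are placeholder assumptions:

```python
import keras.backend as K
from keras.models import Sequential
from keras.layers import Masking, LSTM, Dense, Lambda

def weibull_loglik(y_true, ab_pred):
    # y_true[:, 0] = time to event, y_true[:, 1] = event indicator (1 = observed)
    y, u = y_true[:, 0], y_true[:, 1]
    a, b = ab_pred[:, 0], ab_pred[:, 1]
    ya = (y + K.epsilon()) / a
    # Negative Weibull log-likelihood; censored observations (u = 0)
    # contribute only the survival term
    return -K.mean(u * (K.log(b) + b * K.log(ya)) - K.pow(ya, b))

def weibull_params(x):
    # Constrain alpha > 0 via exp and beta > 0 via softplus
    return K.concatenate([K.exp(x[:, 0:1]), K.softplus(x[:, 1:2])], axis=-1)

model = Sequential()
model.add(Masking(mask_value=0.0, input_shape=(100, 24)))  # skip zero padding
model.add(LSTM(20))
model.add(Dense(2))
model.add(Lambda(weibull_params))
model.compile(loss=weibull_loglik, optimizer='rmsprop')
```

The network just predicts the two parameters of a Weibull distribution for each individual, and the loss rewards putting probability mass near the observed event times.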

As proof of the model's effectiveness, here's my demo model (with absolutely no optimization) predicting the remaining useful life of jet engines. It's not perfect by any means, but it's definitely giving a pass to engines that are in the clear, and flagging ones that are more likely to fail (plus a few false positives). I have no idea how much better it could do with some tweaking:

Demo WTTE-RNN Performance

Also, if anybody's curious as to why I've been in a bit of a post desert lately, my wife and I recently had a baby and I haven't been giving as much thought to the blog. However, I have some ideas brewing!

22 Responses

  1. Jimmy March 24, 2017 / 4:32 am

    Dear blogger,
What is the relation between test_x and test_y? I thought the values in test_y were lifetimes, but obviously they are not.

    • daynebatten March 24, 2017 / 9:52 am

In the demo code I posted, test_y is a (samples, 2) tensor containing a time-to-event (y) and a 0/1 event indicator (u), specifically for the testing data (as opposed to the training data). The event indicator is always 1 for test_y, because the true remaining lifetime is known for all of these engines.

      test_x is simply a history of sensor readings for each engine, used to make predictions about its remaining useful life.

      If the model is working well, when it’s fed test_x and asked to make predictions, it should generate predictions roughly equivalent to the lifetimes in test_y.
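      Concretely, test_y looks something like this (values are made up for illustration):

      ```python
      import numpy as np

      # One row per engine: [time_to_event, event_indicator]
      test_y = np.array([
          [112.0, 1.0],  # engine 1: 112 cycles of life remaining
          [ 98.0, 1.0],  # engine 2: 98 cycles remaining
      ])
      ```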

      Hope that helps.

      • Jimmy March 28, 2017 / 3:09 am

Thanks for your reply. It has been very helpful. I want to try to predict the real-time remaining life of the engines. Do you have any suggestions?

        • daynebatten March 28, 2017 / 7:40 am

          That’s actually exactly what the demo code is for… the alpha and beta parameters in the output describe a Weibull distribution showing the propensity for each engine to fail over future time.

          If you need a point estimate of life remaining (like I’ve used for the graph above), you can simply set the Weibull survivor function equal to .5 using the output alpha and beta and solve for time. The survivor function is exp(-(t/a)^b). Set it equal to .5 and solve for t and you get a(-ln(.5))^(1/b).
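          In Python, that's just (a quick sketch; the function name is mine):

          ```python
          import numpy as np

          def weibull_median(alpha, beta):
              # Solve exp(-(t/alpha)**beta) = 0.5 for t: the predicted median lifetime
              return alpha * (-np.log(0.5)) ** (1.0 / beta)
          ```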

          Hope that helps.

  2. Jimmy March 30, 2017 / 2:53 am

I calculated the remaining useful life and I found the values are quite far from the actual ones. Take engine 1 as an example: the predicted result was [ 112. 1. 218.93415833 4.13228035]. The t I solved for is about
218.93*(-ln(.5))^(1/4.13) ≈ 200,
    but the actual value is 31. Why is it so different?
    Thanks!

    • daynebatten March 30, 2017 / 6:14 am

      If you look at the code, you’ll see that the output prints the test_y values next to the predictions, so that 112 in your output is the actual remaining life of that engine. Not sure where you got 31 from.

      A few other things to remember:

1) This prediction is saying the engine has a 50% chance of surviving to day 200. We'd expect some engines to fail before then and some after.

      2) This is machine learning, not black magic. No predictions are going to be perfect.

      3) This is a demo and I made no attempts to optimize this model whatsoever. It’s entirely possible that making some tweaks to the model (additional layers of neurons, for example) could dramatically improve performance. Feel free to mess around.

      • Jimmy March 30, 2017 / 10:31 am

        Sorry about my stupid questions. XD
I have been learning LSTMs recently and I have not yet completed a single RNN (LSTM) model of my own. Your example is wonderful and I am sure I will learn a lot from it. The value 31 came from the data in test_x. According to the data, engine 1 ran 31 cycles and then failed, so I think the TTE value should be 31. (Am I right?) However, the value is 112 in test_y.
        Thanks!

        • daynebatten March 30, 2017 / 1:19 pm

          No problem at all. The best way to learn is by taking on new challenges and muddling through.

          Best of luck!

(About the 31: if it came from test_x, that's the number of cycles of sensor history recorded for engine 1. The test histories are cut off before failure, so engine 1 was observed for 31 cycles and then had 112 cycles of life remaining, which is the 112 in test_y.)

  3. Phillip April 1, 2017 / 4:28 am

    Mr. Batten,

I am a doctoral student at NC State interested in discussing a section of your recent dissertation with you. If you're willing, please email me at your earliest convenience.

    • daynebatten April 3, 2017 / 7:43 am

      If you’re looking to talk about this model, I’d be happy to have a conversation with you, but I haven’t written a dissertation on the subject (or any subject, for that matter)… Perhaps you’re referring to Egil Martinsson’s thesis?

I’ll send you an email offline, though I'm also mildly concerned this is trolling, what with it being April 1 and all...

  4. Natalia Connolly April 12, 2017 / 1:17 pm

    Hi Dayne,

    Thank you for making your code available on github! This is really cool stuff.

    A quick question: where did you get the true time-to-event numbers (the ones in test_y)? Was it part of the NASA dataset?

    • daynebatten April 12, 2017 / 1:46 pm

      Hi Natalia,

      Yes, it was part of the NASA data set. The files starting with RUL indicate the remaining useful life for each engine in the test data files, which start with test_.

      The organization of the data is admittedly not the best…
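      If it helps, reading the first data set in is just a couple of lines (a sketch, assuming the FD001 files from the NASA download):

      ```python
      import numpy as np

      # Whitespace-delimited text files
      test = np.loadtxt('test_FD001.txt')  # engine id, cycle, settings, sensor readings
      rul = np.loadtxt('RUL_FD001.txt')    # one value per engine: true remaining life
      ```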

      Thanks!

  5. Ashwin May 20, 2017 / 9:43 am

    Mr. Batten,

Thanks for the simple write-up and code! I am still learning the workings and have a basic question:
The RUL = a(-ln(.5))^(1/b) that you mentioned in your reply to Jimmy's post, is it absolute time or a percentage of the engine's max time?

    Cheers!
    Ashwin

    • daynebatten May 22, 2017 / 3:22 pm

      Thanks for writing. That should be absolute time!

  6. Alberto Albiol May 22, 2017 / 5:22 am

    Excellent work, I have been digging into your code to get some light on recurrent models.
    I have two questions:
1- Do you think the results would be very different without masking? You have already left-padded the sequences with zeros.

2- You are using a sliding window of 100 timesteps, but another option is a stateful RNN (as in char-rnn), so the network can remember the whole history and reset its state for each engine. Have you tried something like this? I think the advantage of your approach is that you can shuffle the training samples.

    • daynebatten May 22, 2017 / 3:26 pm

      Good questions.

      1) I’m not an expert on NNs, so I’m not 100% sure, but I expect it wouldn’t make a super-big difference. I’d imagine the network would learn that having everything set to 0s is basically useless information and begin to ignore it. Not ideal to have some of your model’s knowledge dedicated to that, though, if it’s not necessary.

2) Yes, you could use a stateful RNN, but there'd be two downsides. First, you could only back-propagate through time for as many time steps as you had in each batch. If you kept, say, 10 time steps per batch, this might not matter... but I don't know; I haven't played with this dummy model much. Second, there are just some Keras challenges involved in managing batches when you have time series of differing lengths. It's not anything that can't be accomplished, but it would be some extra coding (rough sketch below).
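      If you want to experiment, the stateful setup looks roughly like this (a hypothetical sketch; the batch size, window length, and feature count are placeholders):

      ```python
      from keras.models import Sequential
      from keras.layers import LSTM, Dense

      model = Sequential()
      # Fixed batch of 32 engines, 10-step windows, 24 features; the LSTM's
      # state carries over between batches instead of being reset each time
      model.add(LSTM(20, stateful=True, batch_input_shape=(32, 10, 24)))
      model.add(Dense(2))

      # ...feed consecutive windows in order, then between engines/epochs:
      model.reset_states()
      ```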

      Thanks for writing!

  7. Jack May 24, 2017 / 1:46 am

    Thanks for posting this. On which hardware did you train the model and how much time did it take?

    • daynebatten May 24, 2017 / 8:09 am

      I trained the dummy model on my laptop (core i5). I don’t remember exactly how long it took, but definitely no more than an hour or so.

I’ve scaled this up to a different data set with 52 observations of 67 features for ~100k individuals. As with the jet engine data, this expanded to a full historical look-back for each period in which the individual was observed. It was a pretty big model. It trains on a p2.8xlarge in EC2 (8 Tesla K80 GPUs) in a couple of hours...

  8. DM June 26, 2017 / 12:44 pm

    Hi Dayne, thanks so much for taking the time to put this together. I’m trying to walk through the work you’ve done with the goal of spinning up my own churn prediction model. Downloading the data from the NASA site and running the code from your github page, I’ve plotted the predicted vs actual time to survival, similar to the graph you have near the bottom of your post.

I see similar patterns and clusters between our two plots, except the vertical bar in my plot is centered at around 90 days whereas yours is at 60. Without asking you to debug my code, and being aware of the randomness inherent in the model, is it possible that your plot was generated from a different model than what you posted?

Like I said, I pasted your code in its entirety, so I want to verify the input to the chart before I begin diving into other possible discrepancies. Again, thanks a ton for all the work you've done; it's a great simplified version of Egil's work and has been really helpful for me to wrap my head around this.

    • DM June 26, 2017 / 5:10 pm

      Nevermind! It was an issue in my code, somehow 🙂

  9. slaw June 29, 2017 / 2:44 pm

Great post. I had a question about the normalization. If I train a model on a single feature that has raw values between 0 and 0.5, and the test data has values between 0 and 1.0, then wouldn't the separate normalizations of the training and test sets be inconsistent? It would be even worse if new incoming unlabeled data had values between 0 and 2.0, so the normalization would be inconsistent across all three data sets. Is there a good way to address this?

    • daynebatten July 6, 2017 / 7:33 am

      This is just a bug… thanks for pointing it out. I actually noticed it in a private project a couple weeks ago and haven’t gotten around to fixing it in the public repo. I’ll knock that out now.
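      For anybody following along at home, the fix is to compute the normalization constants on the training data only and reuse them everywhere (a sketch with placeholder arrays):

      ```python
      import numpy as np

      train_x = np.random.rand(100, 24) * 0.5  # placeholder training features
      test_x = np.random.rand(20, 24)          # placeholder test features (wider range)

      # Fit the normalization on the training data only...
      mean, std = train_x.mean(axis=0), train_x.std(axis=0)

      # ...then apply the same constants to train, test, and any new data
      train_x = (train_x - mean) / std
      test_x = (test_x - mean) / std
      ```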
